In [4]:
import torch
import torch.nn as nn
from vector_quantize_pytorch import VectorQuantize

# Stand-alone vector quantizer used by the toy sender -> receiver demo in this cell.
#   dim           - dimension of the input features
#   codebook_size - number of possible discrete codes
#   decay         - decay for the exponential moving average in the codebook update
vq_layer = VectorQuantize(dim=16, codebook_size=32, decay=0.8)

class Sender(nn.Module):
    """Toy sender: one linear layer that turns an input into a continuous message."""

    def __init__(self, input_dim=10, output_dim=16):
        """Create the projection from ``input_dim`` to ``output_dim`` features."""
        super().__init__()
        # Kept under the name ``fc`` so state-dict keys stay ``fc.weight`` / ``fc.bias``.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        message = self.fc(x)
        return message
    
class Receiver(nn.Module):
    """Toy receiver: one linear layer that maps a message to its output."""

    def __init__(self, input_dim=16, output_dim=1):
        """Create the projection from ``input_dim`` to ``output_dim`` features."""
        super().__init__()
        # Kept under the name ``fc`` so state-dict keys stay ``fc.weight`` / ``fc.bias``.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        prediction = self.fc(x)
        return prediction

# Toy end-to-end pass: sender -> vector quantizer -> receiver.
# Both agents are built with their default layer sizes (10 -> 16 and 16 -> 1).
# NOTE(review): no random seed is set in this cell, so the printed values
# differ from run to run.
sender = Sender()
receiver = Receiver()
input_data = torch.randn((1, 10))  # one random 10-feature sample fed to the sender

# Continuous message produced by the sender's linear layer.
sender_output = sender(input_data)
print("Sender Output (Continuous):", sender_output)

# Quantize the message. The call yields three values, bound here to the
# quantized tensor, the selected codebook indices, and commit_loss.
quantized_output, indices, commit_loss = vq_layer(sender_output)
print("Quantized Output:", quantized_output)
print("Quantization Indices:", indices)
print("Test:", commit_loss)  # the "Test:" label prints the commit_loss value

# The receiver consumes the quantized message, not the continuous one.
receiver_output = receiver(quantized_output)
print("Receiver Output:", receiver_output)


Sender Output (Continuous): tensor([[ 0.2297, -0.0911, -0.9183,  0.9135,  0.6772, -1.4203, -0.7435, -0.0276,
         -0.1114, -0.2450, -0.7967, -0.2640, -1.5369, -0.3758,  0.2145,  1.5033]],
       grad_fn=<AddmmBackward0>)
Quantized Output: tensor([[ 0.0656, -0.0100, -0.0638,  0.1032, -0.0405, -0.0616, -0.0976,  0.0048,
          0.0480,  0.0240,  0.0214, -0.0821, -0.1011,  0.0019,  0.0987,  0.0625]],
       grad_fn=<ViewBackward0>)
Quantization Indices: tensor([22])
Test: tensor([0.5808], grad_fn=<AddBackward0>)
Receiver Output: tensor([[-0.1836]], grad_fn=<AddmmBackward0>)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from vector_quantize_pytorch import VectorQuantize

from archs.distractors import select_distractors
from archs.network import GAT, Transform
from archs.run_series import run_experiment
from options import Options

# Number of codebook entries for the vector quantizer built in SenderRel
# (passed there as codebook_size).
code_book = 512

class SenderRel(nn.Module):
    """Sender agent that emits a vector-quantized message about a (target, ego) node pair.

    The batched graph is encoded by ``self.layer``; the embeddings of each
    graph's target node and ego node are concatenated, projected to
    ``hidden_size`` by ``self.fc`` and snapped to a codebook entry by
    ``self.vq_layer``.

    Attributes:
        last_indices: Codebook indices chosen in the most recent forward
            pass (``None`` before the first call).
        last_commit_loss: Commitment loss from the most recent forward pass
            (``None`` before the first call).  It is still attached to the
            autograd graph so a training loop can add it to its objective;
            it is replaced on every forward call.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, temperature):
        """Build the graph encoder, the projection and the quantizer.

        Args:
            num_node_features: Forwarded to the graph encoder.
            embedding_size: Forwarded to the graph encoder; ``self.fc``
                expects ``2 * embedding_size`` input features.
            heads: Forwarded to the graph encoder.
            layer: ``'transform'`` selects ``Transform``; any other value
                selects ``GAT``.
            hidden_size: Size of the projected (and quantized) message.
            temperature: Stored as ``self.temp``; not read inside this class.
        """
        super().__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        encoder_cls = Transform if layer == 'transform' else GAT
        self.layer = encoder_cls(self.num_node_features, embedding_size, heads)
        # Input is the target embedding concatenated with the ego embedding.
        self.fc = nn.Linear(2 * embedding_size, hidden_size)

        self.vq_layer = VectorQuantize(
            dim=hidden_size,
            codebook_size=code_book,
            decay=0.8,
        )

        # Side outputs of the quantizer, refreshed on every forward pass.
        # Previously these were dropped, so the commitment loss could never
        # reach the training objective.
        self.last_indices = None
        self.last_commit_loss = None

    def forward(self, x, _aux_input):
        """Encode each graph's target/ego pair and return the quantized message.

        Args:
            x: Not used; every input is read from ``_aux_input``.
            _aux_input: Batched graph object exposing ``ptr``,
                ``target_node_idx`` and ``ego_node_idx`` and accepted by
                ``self.layer``.  (Presumably a torch_geometric batch —
                TODO confirm.)

        Returns:
            The quantized message, batch_size x hidden_size.  The codebook
            indices and commitment loss are stored on ``self.last_indices``
            and ``self.last_commit_loss``.
        """
        data = _aux_input
        batch_ptr, target_node_idx, ego_idx = data.ptr, data.target_node_idx, data.ego_node_idx

        h = self.layer(data)

        # Shift the per-graph indices by each graph's start offset (all but
        # the last entry of ptr) so they address rows of the batched ``h``.
        # Assumes the stored indices are local to their graph — TODO confirm.
        graph_offsets = batch_ptr[:-1]
        adjusted_ego_idx = ego_idx + graph_offsets
        adjusted_target_node_idx = target_node_idx + graph_offsets

        target_embedding = torch.cat((h[adjusted_target_node_idx], h[adjusted_ego_idx]), dim=1)

        output = self.fc(target_embedding)

        quantized_output, indices, commit_loss = self.vq_layer(output)
        self.last_indices = indices
        self.last_commit_loss = commit_loss

        return quantized_output  # batch_size x hidden_size

class ReceiverRel(nn.Module):
    """Receiver agent that scores candidate nodes against the sender's message.

    The batched graph is encoded by ``self.layer``, candidate node embeddings
    are gathered with ``select_distractors``, and every candidate is scored by
    its dot product with the linearly projected message.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, distractors):
        """Build the graph encoder and the message projection.

        Args:
            num_node_features: Forwarded to the graph encoder.
            embedding_size: Forwarded to the graph encoder; also the output
                size of ``self.fc``.
            heads: Forwarded to the graph encoder.
            layer: ``'transform'`` selects ``Transform``; any other value
                selects ``GAT``.
            hidden_size: Input size of ``self.fc`` (the message width).
            distractors: Distractor count passed to ``select_distractors``
                when the batch is not flagged as evaluation.
        """
        super().__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.distractors = distractors

        encoder_cls = Transform if layer == 'transform' else GAT
        self.layer = encoder_cls(self.num_node_features, embedding_size, heads)
        # Projects the message to embedding_size so it can be dotted with
        # the candidate embeddings.
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        """Return per-graph log-probabilities over the candidate nodes.

        Args:
            message: Sender message; assumed batch_size x hidden_size —
                TODO confirm against the sender.
            _input: Not used.
            _aux_input: Batched graph object accepted by ``self.layer`` and
                ``select_distractors``; must expose ``num_graphs`` and, when
                its ``evaluation`` attribute is truthy, ``target_node``.

        Returns:
            ``log_softmax`` over dim 1 of the candidate scores, plus random
            noise scaled by 1e-10.
        """
        graph_batch = _aux_input
        node_embeddings = self.layer(graph_batch)

        # Evaluation batches use len(target_node) - 1 distractors; otherwise
        # the configured count is used.
        is_evaluation = getattr(graph_batch, 'evaluation', False)
        if is_evaluation:
            n_distractors = len(graph_batch.target_node) - 1
        else:
            n_distractors = self.distractors
        candidate_idx, _ = select_distractors(
            graph_batch,
            n_distractors,
            evaluation=is_evaluation,
        )

        gathered = node_embeddings[candidate_idx]
        n_graphs = graph_batch.num_graphs
        candidates_per_graph = gathered.size(0) // n_graphs
        candidates = gathered.view(n_graphs, candidates_per_graph, -1)

        # Turn the projected message into a column vector per graph for bmm.
        projected = self.fc(message).unsqueeze(2)
        scores = torch.bmm(candidates, projected).squeeze(-1)
        log_probs = F.log_softmax(scores, dim=1)

        # add small random noise
        jitter = torch.randn_like(log_probs) * 1e-10
        return log_probs + jitter

In [3]:
# Configure the run (10 epochs, 4 distractors, prune_graph switched on) and
# launch it; the second argument is a results path built from
# options_input.need_probs.
# NOTE(review): the recorded output of this cell ends in
# "TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple".
# Whether run_experiment uses the SenderRel/ReceiverRel classes defined in
# this notebook is not visible here — confirm where the tuple originates.
options_input = Options(n_epochs=10, distractors=4, prune_graph=True)
results = run_experiment(options_input, f'results/{options_input.need_probs}')

Dataset: data/uniform


TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple