In [2]:
import torch
from vector_quantize_pytorch import VectorQuantize

# Vector-quantization layer configured for 256-dimensional inputs.
vq = VectorQuantize(
    dim=256,
    # Number of entries in the codebook.
    codebook_size=512,
    # Exponential moving average decay; lower means the dictionary will change faster.
    decay=0.8,
    # Weight applied to the commitment loss.
    commitment_weight=1.0,
)

# Random demo input: 1 sequence of 1024 vectors, each of size 256.
x = torch.randn(1, 1024, 256)

# Expected result shapes per the original note: (1, 1024, 256), (1, 1024), (1).
quantized, indices, commit_loss = vq(x)

In [None]:
from options import Options
from archs.run_series import run_series

import torch
import torch.nn as nn
import torch.nn.functional as F
from archs.network import GAT, Transform

from archs.distractors import select_distractors, select_fixed_distractors

class SenderRel(nn.Module):
    """Sender module that turns a (target node, ego node) pair into a message vector.

    The graph batch is encoded by either a ``Transform`` or a ``GAT`` layer;
    the embeddings of each graph's target node and ego node are concatenated
    and projected to ``hidden_size`` by a linear layer.

    Args:
        num_node_features: Passed through as the first argument of the encoder layer.
        embedding_size: Passed as the second argument of the encoder layer; the
            linear projection expects ``2 * embedding_size`` input features.
        heads: Passed as the third argument of the encoder layer.
        layer: ``'transform'`` selects ``Transform``; any other value selects ``GAT``.
        hidden_size: Output feature size of the linear projection.
        temperature: Stored as ``self.temp``; not read anywhere in this class.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, temperature):
        super().__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        # Pick the graph encoder; everything except 'transform' falls back to GAT.
        if layer == 'transform':
            self.layer = Transform(self.num_node_features, embedding_size, heads)
        else:
            self.layer = GAT(self.num_node_features, embedding_size, heads)

        # The projection consumes target and ego embeddings side by side,
        # hence the doubled input width.
        self.fc = nn.Linear(2 * embedding_size, hidden_size)

    def forward(self, x, _aux_input):
        """Build one message per graph in the batch.

        Args:
            x: Unused by this method.
            _aux_input: Batch object exposing ``ptr``, ``target_node_idx`` and
                ``ego_node_idx``; it is also handed to the encoder layer as-is.
                NOTE(review): presumably a torch_geometric ``Batch`` whose
                ``ptr`` holds per-graph node offsets -- confirm with the caller.

        Returns:
            Output of the linear projection (batch_size x hidden_size according
            to the original author's note).
        """
        graph_batch = _aux_input
        ptr = graph_batch.ptr
        local_target = graph_batch.target_node_idx
        local_ego = graph_batch.ego_node_idx

        node_embeddings = self.layer(graph_batch)

        # Shift per-graph indices by each graph's starting offset (all entries
        # of ptr except the last) so they index into the encoder output.
        graph_offsets = ptr[:-1]
        ego_rows = local_ego + graph_offsets
        target_rows = local_target + graph_offsets

        # Target embedding first, ego embedding second, joined along dim 1.
        pair_features = torch.cat(
            (node_embeddings[target_rows], node_embeddings[ego_rows]),
            dim=1,
        )

        return self.fc(pair_features)

class ReceiverRel(nn.Module):
    """Receiver module that scores candidate nodes against an incoming message.

    The graph batch is encoded by either a ``Transform`` or a ``GAT`` layer,
    candidate node embeddings are gathered via ``select_fixed_distractors``,
    and each candidate is scored by its dot product with the linearly
    projected message.

    Args:
        num_node_features: Passed through as the first argument of the encoder layer.
        embedding_size: Passed as the second argument of the encoder layer and
            used as the output size of the message projection.
        heads: Passed as the third argument of the encoder layer.
        layer: ``'transform'`` selects ``Transform``; any other value selects ``GAT``.
        hidden_size: Input feature size of the message projection.
        distractors: Count handed to ``select_fixed_distractors`` when the batch
            is not flagged as evaluation.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, distractors):
        super(ReceiverRel, self).__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.distractors = distractors

        self.layer = Transform(self.num_node_features, embedding_size, heads) if layer == 'transform' else GAT(self.num_node_features, embedding_size, heads)
        # Maps the message from hidden_size to embedding_size so it can be
        # dotted with the candidate embeddings.
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        """Return log-probabilities over the candidate nodes of each graph.

        Args:
            message: Tensor fed to the message projection.
                NOTE(review): assumed (batch_size, hidden_size) -- confirm.
            _input: Unused by this method.
            _aux_input: Batch object exposing ``num_graphs`` and ``target_node``
                (and optionally ``evaluation``); it is also handed to the
                encoder layer and to ``select_fixed_distractors`` as-is.

        Returns:
            Tensor of shape (batch_size, num_candidates) with log-softmax scores
            over the candidate axis, perturbed by 1e-10-scaled Gaussian noise.
        """
        data = _aux_input
        h = self.layer(data)

        # A missing 'evaluation' attribute means a training-style batch.
        evaluation = getattr(data, 'evaluation', False)
        # Evaluation batches use len(target_node) - 1 instead of the configured count.
        num_distractors = len(data.target_node) - 1 if evaluation else self.distractors

        indices, _ = select_fixed_distractors(
            data,
            num_distractors,
            evaluation=evaluation
        )

        embeddings = h[indices]

        # Regroup the flat list of gathered rows into equally sized
        # per-graph candidate sets.
        batch_size = data.num_graphs
        num_candidates = embeddings.size(0) // batch_size
        embeddings = embeddings.view(batch_size, num_candidates, -1)

        # Add a trailing axis so the projected message works as the right-hand
        # operand of the batched matrix product below.
        message = self.fc(message).unsqueeze(2)

        dot_products = torch.bmm(embeddings, message).squeeze(-1)
        log_probabilities = F.log_softmax(dot_products, dim=1)

        # Add small random noise (kept from the original implementation).
        # NOTE(review): rationale is not visible here -- presumably tie-breaking; confirm.
        log_probabilities = log_probabilities + 1e-10 * torch.randn_like(log_probabilities)

        return log_probabilities