In [15]:
import torch
import torch.nn as nn
from vector_quantize_pytorch import VectorQuantize

# Stand-alone vector quantizer used by the toy Sender -> Receiver demo in this cell.
#   dim=16           -> width of the vectors being quantized (the Sender's default output_dim)
#   codebook_size=32 -> how many discrete codes are available
#   decay=0.8        -> exponential-moving-average decay used for the codebook update
vq_layer = VectorQuantize(dim=16, codebook_size=32, decay=0.8)

class Sender(nn.Module):
    """Toy sender agent: a single affine layer that maps an input vector to a continuous message.

    Args:
        input_dim: Width of the incoming feature vector (default 10).
        output_dim: Width of the emitted message (default 16).
    """

    def __init__(self, input_dim=10, output_dim=16):
        super().__init__()
        # The attribute stays named `fc`, so parameters are exposed as `fc.weight` / `fc.bias`.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project `x` through the linear layer and return the result."""
        message = self.fc(x)
        return message
    
class Receiver(nn.Module):
    """Toy receiver agent: a single affine layer that maps a message to an output vector.

    Args:
        input_dim: Width of the incoming message (default 16).
        output_dim: Width of the produced output (default 1).
    """

    def __init__(self, input_dim=16, output_dim=1):
        super().__init__()
        # The attribute stays named `fc`, so parameters are exposed as `fc.weight` / `fc.bias`.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project the message `x` through the linear layer and return the result."""
        prediction = self.fc(x)
        return prediction

# Build both toy agents with their default sizes (sender 10 -> 16, receiver 16 -> 1).
sender, receiver = Sender(), Receiver()

# One random 10-dimensional observation for the sender to encode.
# NOTE(review): no RNG seed is set in the visible part of this cell, so the printed
# values are expected to differ between runs.
input_data = torch.randn(1, 10)

# Step 1: continuous message produced by the sender.
sender_output = sender(input_data)
print("Sender Output (Continuous):", sender_output)

# Step 2: quantize the message; the third value returned by the layer is discarded.
quantized_output, indices, _ = vq_layer(sender_output)
print("Quantized Output:", quantized_output)
print("Quantization Indices:", indices)

# Step 3: the receiver is fed the quantized message, not the continuous one.
receiver_output = receiver(quantized_output)
print("Receiver Output:", receiver_output)


Sender Output (Continuous): tensor([[ 0.6104, -0.1569, -0.1290, -0.6602, -0.7849,  0.3486, -0.7926, -0.1663,
          0.0035,  0.2201,  1.1152, -0.6269, -0.0760, -1.2679,  0.3927, -0.2035]],
       grad_fn=<AddmmBackward0>)
Quantized Output: tensor([[ 0.0984,  0.0992,  0.0761, -0.0454, -0.0139, -0.0485, -0.0839,  0.0201,
          0.0010,  0.0679,  0.0735, -0.0461, -0.0019, -0.0817,  0.0678,  0.0102]],
       grad_fn=<ViewBackward0>)
Quantization Indices: tensor([30])
Receiver Output: tensor([[0.0918]], grad_fn=<AddmmBackward0>)


In [13]:
from options import Options
from archs.run_series import run_experiment

import torch
import torch.nn as nn
import torch.nn.functional as F
from archs.network import GAT, Transform

from archs.distractors import select_distractors

# Number of codebook entries; SenderRel passes this to VectorQuantize as `codebook_size`.
code_book = 512

class SenderRel(nn.Module):
    """Sender agent that encodes a (target, ego) node pair as a vector-quantized message.

    A graph layer embeds the batched graph, the embeddings of the target node and the ego
    node of every graph are concatenated, projected to ``hidden_size`` by a linear layer and
    finally passed through a ``VectorQuantize`` layer.

    Args:
        num_node_features: Stored on the instance and forwarded as the graph layer's first
            argument.
        embedding_size: Forwarded as the graph layer's second argument; the linear layer
            expects ``2 * embedding_size`` input features.
        heads: Stored on the instance and forwarded as the graph layer's third argument.
        layer: ``'transform'`` selects ``Transform``; any other value selects ``GAT``.
        hidden_size: Output width of the linear layer and ``dim`` of the quantizer.
        temperature: Stored as ``self.temp``; not read anywhere inside this class.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, temperature):
        super().__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.hidden_size = hidden_size
        self.temp = temperature

        # Sub-modules are created in the order layer -> fc -> vq_layer.
        if layer == 'transform':
            self.layer = Transform(self.num_node_features, embedding_size, heads)
        else:
            self.layer = GAT(self.num_node_features, embedding_size, heads)
        self.fc = nn.Linear(2 * embedding_size, hidden_size)
        # The codebook size is taken from the module-level `code_book` constant.
        self.vq_layer = VectorQuantize(dim=hidden_size, codebook_size=code_book, decay=0.8)

    def forward(self, x, _aux_input):
        """Build the quantized message for every graph in the batch.

        Args:
            x: Not used by this method.
            _aux_input: Batched graph object; ``ptr``, ``target_node_idx`` and
                ``ego_node_idx`` are read from it and it is handed to the graph layer.

        Returns:
            The first value returned by the quantizer, expected to be
            batch_size x hidden_size.
        """
        graph_batch = _aux_input
        # presumably `ptr` holds cumulative per-graph node offsets, so dropping its last
        # entry gives the first row of each graph -- TODO confirm against the data loader.
        graph_offsets = graph_batch.ptr[:-1]
        local_target = graph_batch.target_node_idx
        local_ego = graph_batch.ego_node_idx

        node_embeddings = self.layer(graph_batch)

        # Shift per-graph indices into rows of the batched embedding matrix.
        ego_rows = local_ego + graph_offsets
        target_rows = local_target + graph_offsets

        # Target embedding first, ego embedding second, joined along the feature axis.
        pair_features = torch.cat([node_embeddings[target_rows], node_embeddings[ego_rows]], dim=1)
        projected = self.fc(pair_features)

        # NOTE(review): the second and third values returned by the quantizer are discarded;
        # if the third one is an auxiliary loss it never reaches the optimizer -- confirm.
        quantized_output, _unused_indices, _unused_extra = self.vq_layer(projected)
        return quantized_output

class ReceiverRel(nn.Module):
    """Receiver agent that scores candidate nodes of each graph against the sender's message.

    A graph layer embeds the batched graph, ``select_distractors`` picks the candidate rows,
    and the message (projected to ``embedding_size``) is compared with every candidate by a
    dot product. The scores are normalised per graph with ``log_softmax``.

    Args:
        num_node_features: Stored on the instance and forwarded as the graph layer's first
            argument.
        embedding_size: Forwarded as the graph layer's second argument; also the output
            width of the message projection.
        heads: Stored on the instance and forwarded as the graph layer's third argument.
        layer: ``'transform'`` selects ``Transform``; any other value selects ``GAT``.
        hidden_size: Width of the incoming message.
        distractors: Count handed to ``select_distractors`` outside evaluation mode.
    """

    def __init__(self, num_node_features, embedding_size, heads, layer, hidden_size, distractors):
        super().__init__()
        self.num_node_features = num_node_features
        self.heads = heads
        self.distractors = distractors

        # Sub-modules are created in the order layer -> fc.
        if layer == 'transform':
            self.layer = Transform(self.num_node_features, embedding_size, heads)
        else:
            self.layer = GAT(self.num_node_features, embedding_size, heads)
        self.fc = nn.Linear(hidden_size, embedding_size)

    def forward(self, message, _input, _aux_input):
        """Return per-graph log-probabilities over the selected candidate nodes.

        Args:
            message: Sender message; projected by ``self.fc`` before scoring.
            _input: Not used by this method.
            _aux_input: Batched graph object; ``num_graphs`` is read from it, plus
                ``evaluation`` (optional) and ``target_node`` (only in evaluation mode).

        Returns:
            Tensor of shape (num_graphs, candidates_per_graph) holding log-softmax scores
            with a tiny random perturbation added.
        """
        graph_batch = _aux_input
        node_embeddings = self.layer(graph_batch)

        # A missing `evaluation` attribute is treated as training mode.
        is_evaluation = getattr(graph_batch, 'evaluation', False)
        if is_evaluation:
            n_distractors = len(graph_batch.target_node) - 1
        else:
            n_distractors = self.distractors
        candidate_rows, _ = select_distractors(
            graph_batch,
            n_distractors,
            evaluation=is_evaluation,
        )

        candidates = node_embeddings[candidate_rows]
        n_graphs = graph_batch.num_graphs
        # assumes every graph contributes the same number of candidate rows -- TODO confirm
        per_graph = candidates.size(0) // n_graphs
        candidates = candidates.view(n_graphs, per_graph, -1)

        # (n_graphs, embedding) -> (n_graphs, embedding, 1) so bmm yields one score per candidate.
        query = self.fc(message).unsqueeze(2)
        scores = torch.bmm(candidates, query).squeeze(-1)
        log_probabilities = F.log_softmax(scores, dim=1)

        # Small random perturbation on top of the log-probabilities.
        # NOTE(review): if these are float32 values with magnitude around 1, a 1e-10 offset is
        # below the representable spacing (~1e-7) and likely changes nothing -- confirm the scale.
        noise = torch.randn_like(log_probabilities)
        return log_probabilities + 1e-10 * noise

In [14]:
# Experiment configuration: 10 epochs, 4 distractors, graph pruning switched on.
options_input = Options(
    n_epochs=10,
    distractors=4,
    prune_graph=True,
)

# The second argument is a path under results/ named after `options_input.need_probs`
# (presumably the output location -- confirm against run_experiment).
results = run_experiment(
    options_input,
    f'results/{options_input.need_probs}',
)

Dataset: data/uniform
epoch=1, mode=train, loss=1.6087253093719482, acc=0.2029999941587448
epoch=1, mode=test, loss=1.6153666973114014, acc=0.19900000095367432
epoch=2, mode=train, loss=1.4476107358932495, acc=0.32419997453689575
epoch=2, mode=test, loss=1.5324656963348389, acc=0.3399999737739563
epoch=3, mode=train, loss=1.2681119441986084, acc=0.3837999999523163
epoch=3, mode=test, loss=1.3369066715240479, acc=0.39100003242492676
epoch=4, mode=train, loss=1.193585991859436, acc=0.40539997816085815
epoch=4, mode=test, loss=1.3519017696380615, acc=0.3619999885559082
epoch=5, mode=train, loss=1.1216416358947754, acc=0.4431999921798706, eval_acc=0.03125, complexity=0, information_loss=0
epoch=5, mode=test, loss=1.425414800643921, acc=0.3720000088214874
epoch=6, mode=train, loss=1.0616799592971802, acc=0.47540000081062317
epoch=6, mode=test, loss=0.9986928701400757, acc=0.5010000467300415
epoch=7, mode=train, loss=0.9840986728668213, acc=0.5248000025749207
epoch=7, mode=test, loss=0.99428