In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import egg.core as core

import random

from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
from torch_geometric.data import Batch, Dataset, Data
from torch_geometric.data.data import BaseData
from torch_geometric.nn import GATv2Conv
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset

  from .autonotebook import tqdm as notebook_tqdm


## Dataset generation

In [2]:
from dataclasses import dataclass

"""
Set the options here. I used max_len=1 for now to have a better view when printing the messages. Otherwise they become very long.
game_size doesn't matter for this set-up, I use it in the other graph game.
For better readablity, I set the embedding size to 10 (this influences training of course).
"""

@dataclass
class Options:
    # Agents
    embedding_size: int = 10 # Default: 50
    hidden_size: int = 20 # Default: 20
    sender_cell: str = 'rnn' # 'rnn', 'gru', 'lstm'
    max_len: int = 1 # Default: 1
    gs_tau: int = 1.0 # Default: 1.0

    # Training
    n_epochs: int = 10
    game_size: int = 20 # Default: 4
    vocab_size: int = 100 # Default: 100
    batch_size: int = 1 # Default: 16
    training_mode: str = 'gs'  # 'rf' for Reinforce or 'gs' for Gumbel-Softmax

opts = Options()

In [3]:
class FamilyMember:
    """
    Represents a family member with gender, age, spouse, and children.
    """

    def __init__(self, gender, age):
        self.gender = gender
        self.age = age
        self.spouse = None
        self.children = []

    def create_spouse(self):
        """
        Creates a spouse for the family member based on their gender and age.
        Returns the created spouse.
        """
        spouse_gender = 'f' if self.gender == 'm' else 'm'
        spouse_age = random.randint(max(18, self.age - 5), min(self.age + 5, 100))
        spouse = FamilyMember(spouse_gender, spouse_age)
        self.spouse = spouse
        spouse.spouse = self
        return spouse

    def create_children(self, max_children=4):
        """
        Creates children for the family member and their spouse.
        The number of children is randomly determined.
        """
        children_count = 2 #random.randint(1, max_children) # Change this to adjust the size of the graph.
        youngest_parent_age = min(self.age, self.spouse.age)
        for _ in range(children_count):
            child_gender = random.choice(['m', 'f'])
            max_child_age = max(0, youngest_parent_age - 20)
            min_child_age = max(0, youngest_parent_age - 30)
            child_age = random.randint(min_child_age, max_child_age)
            child = FamilyMember(child_gender, child_age)
            self.children.append(child)
            self.spouse.children.append(child)

def create_family_tree(generations):
    age_range = (80,100)
    root_age = random.randint(*age_range)
    root_member = FamilyMember('m', root_age)

    spouse = root_member.create_spouse()
    current_generation = [(root_member, spouse)]
    
    all_members = {0: root_member, 1: spouse}
    next_index = 2  # Start indexing from 2 as 0 and 1 are already used

    for gen in range(1, generations):
        next_generation = []
        for parent1, parent2 in current_generation:
            parent1.create_children()
            for child in parent1.children:
                all_members[next_index] = child
                next_index += 1
                if gen < generations - 1:
                    spouse = child.create_spouse()
                    all_members[next_index] = spouse
                    next_generation.append((child, spouse))
                    next_index += 1
        current_generation = next_generation

    return all_members

def create_data_object(all_members):
    # Convert genders to a binary representation and collect node features
    gender_to_binary = {'m': 0, 'f': 1}
    x = [[gender_to_binary[member.gender], member.age] for index, member in all_members.items()]

    # Prepare edge_index and edge_attr
    edge_index = []
    edge_attr = []

    for index, member in all_members.items():
        if member.spouse:
            spouse_index = list(all_members.keys())[list(all_members.values()).index(member.spouse)]
            # Add edges for spouses in both directions with the 'married' attribute
            edge_index.append([index, spouse_index])
            edge_index.append([spouse_index, index])
            edge_attr.extend([0, 0])  # 0 for 'married'

        for child in member.children:
            child_index = list(all_members.keys())[list(all_members.values()).index(child)]
            # Add edges from children to this member with the 'childOf' attribute
            edge_index.append([child_index, index])
            edge_attr.append(1)  # 1 for 'childOf'

    # Convert to PyTorch tensors
    x = torch.tensor(x, dtype=torch.float)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.long)

    # Create the data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data

In [4]:
import os

class FamilyGraphDataset(Dataset):
    """
    Dataset class for generating family graph data.

    Args:
        root (str): Root directory path.
        number_of_graphs (int): Number of graphs to generate.
        generations (int): Number of generations in each family tree.

    Returns:
        Data(x=[8, 2], edge_index=[2, 20], edge_attr=[20], labels=[8])
    """
    def __init__(self, root, number_of_graphs, generations, transform=None, pre_transform=None):
        self.number_of_graphs = number_of_graphs
        self.generations = generations
        super(FamilyGraphDataset, self).__init__(root, transform, pre_transform)
        self.data = None
        self.process()

    @property
    def processed_file_names(self):
        return ['family_graphs.pt']

    def len(self):
        return len(self.data)

    def get(self, idx):
        return self.data[idx]
    
    def generate_labels(self, num_nodes):
        target_node_idx = random.randint(0, num_nodes - 1)
        return target_node_idx
    
    def process(self):
        if not os.path.isfile(self.processed_paths[0]):
            self.data = []
            for _ in range(self.number_of_graphs):
                family_tree = create_family_tree(self.generations)
                graph_data = create_data_object(family_tree)

                # Generate random labels for each node
                target_node_idx = self.generate_labels(graph_data.num_nodes)

                # Store the labels as an attribute of the graph_data
                graph_data.target_node_idx = target_node_idx

                self.data.append(graph_data)

            torch.save(self.data, self.processed_paths[0])
        else:
            self.data = torch.load(self.processed_paths[0])

You can set the number of graphs and generations here. In the FamilyMember class you can set the number of children. For now, to make graphs of 4 nodes the number of children is 2 and generations is 2. Meaning two parents and two children = 4.  

In [5]:
dataset = FamilyGraphDataset(root='.../data', number_of_graphs=100, generations=2)
print(dataset[0])

Data(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=0)


In [6]:
total_nodes = 0
for i in range(0, len(dataset)):
    total_nodes += dataset[i].x.shape[0]  # shape[0] gives the number of nodes in each graph

average_nodes = total_nodes / len(dataset)
print("Average number of nodes:", average_nodes)

Average number of nodes: 4.0


## Agents

Here are the agents. I realised later that I did the embedding layers seperatley for each agent. Meaning the sender embeds the graph, then the receiver embeds the graph on its own. This could possibly lead to the receiver having very different embedding of the graph and thus having more difficuly understanding the message and finding the target node. So now the GAT layer is outside the agents.

In [7]:
class GAT(torch.nn.Module):
    def __init__(self, num_node_features, embedding_size):
        super().__init__()
        self.conv1 = GATv2Conv(num_node_features, embedding_size, edge_dim=1, heads=2, concat=True)
        self.conv2 = GATv2Conv(-1, embedding_size, edge_dim=1, heads=2, concat=True)

    def forward(self, data):
        h, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        h = self.conv1(x=h, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.relu(h)

        h = self.conv2(x=h, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.relu(h)
        return h

In [8]:
class Sender(nn.Module):
    def __init__(self, embedding_size, hidden_size, temperature):
        super(Sender, self).__init__()
        self.temp = temperature
        self.hidden_size = hidden_size

        self.gat = GAT(2, embedding_size) # 2 is num_node features, a little ugly to do it like this but it doesn't change 
        self.fc = nn.Linear((embedding_size * 2), hidden_size) # because I use 2 heads and concatenate them, embedding_size * 2

    def forward(self, x, _aux_input):
        data = _aux_input

        target_node_idx = data.target_node_idx

        h = self.gat(data)

        target_embedding = h[target_node_idx]           

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size)

class Receiver(nn.Module):
    def __init__(self, embedding_size, hidden_size):
        super(Receiver, self).__init__()

        self.gat = GAT(2, embedding_size)
        self.fc = nn.Linear(hidden_size, (embedding_size * 2))

    def forward(self, message, _input, _aux_input):
        data = _aux_input

        h = self.gat(data)   

        message_embedding = self.fc(message)                 

        dot_products = torch.matmul(h, message_embedding.t()).t()   

        probabilities = F.log_softmax(dot_products, dim=1)                      

        return probabilities

## Ouput for a single graph

You can see the graph and its corresponding target node below, indexing starts at 0, so if the target node is 0 it takes the 1st node of the graph. The range is set in the dataset building function above:
`target_node_idx = random.randint(0, num_nodes - 1)`. The graph is passed to both the sender and receiver, with the receiver also getting the sender message of course.

In [9]:
sender = Sender(embedding_size=opts.embedding_size, hidden_size=opts.hidden_size, temperature=opts.gs_tau) 
receiver = Receiver(embedding_size=opts.embedding_size, hidden_size=opts.hidden_size) 

graph = dataset[0]
print(graph)

# Sender produces a message
sender_output = sender(None, graph)
print("Sender's message:", sender_output)
print("Sender's shape: ", sender_output.shape) # (n graphs, hidden size)

# Receiver tries to identify the target node
receiver_output = receiver(sender_output, None, graph)
print("Receiver's output:", receiver_output) # returns log probs for each node
print("Receiver's shape: ", receiver_output.shape) # (n graphs, n nodes)

Data(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=0)
Sender's message: tensor([[ 4.9285,  0.3794,  0.9473,  2.3869, -0.5996,  4.6859,  3.9219,  3.5665,
         -4.3103,  0.3931, -1.7749,  5.4911,  7.6793, -5.1749,  4.6028, -3.2139,
         10.2020, -5.5374, -2.0164, -8.2498]], grad_fn=<ViewBackward0>)
Sender's shape:  torch.Size([1, 20])
Receiver's output: tensor([[-4.7771e+01, -4.7566e+01, -2.0946e-03, -6.1694e+00]],
       grad_fn=<LogSoftmaxBackward0>)
Receiver's shape:  torch.Size([1, 4])


In [10]:
"""
The following is for printing the embeddings without having to cope with it printing every epoch later.
"""

class SenderPrint(nn.Module):
    def __init__(self, embedding_size, hidden_size, temperature):
        super(SenderPrint, self).__init__()
        self.temp = temperature
        self.hidden_size = hidden_size

        self.gcn = GAT(2, embedding_size) # 2 is num_node features, a little ugly to do it like this but it doesn't change 
        self.fc = nn.Linear((embedding_size * 2), hidden_size) # because I use 2 heads and concatenate them, embedding_size * 2

    def forward(self, x, _aux_input):
        data = _aux_input

        target_node_idx = data.target_node_idx

        h = self.gcn(data)

        target_embedding = h[target_node_idx]           

        output = self.fc(target_embedding)                           

        return output.view(-1, self.hidden_size), h, target_embedding

In [11]:
sender_print = SenderPrint(embedding_size=opts.embedding_size, hidden_size=opts.hidden_size, temperature=opts.gs_tau)

sender_output, h, target_emb = sender_print(None, graph)

print("Input: ", graph)
print("h:", h)
print("target_embedding:", target_emb)

Input:  Data(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=0)
h: tensor([[23.4326,  0.0000,  0.0000,  0.0000, 36.9734,  1.6236,  0.0000,  6.3429,
          0.0000,  8.4839, 13.2722,  0.0000, 36.4619,  0.0000,  4.6447, 20.2840,
         25.0810,  1.2073,  0.0000,  8.3897],
        [23.4102,  0.0000,  0.0000,  0.0000, 36.9806,  1.6271,  0.0000,  6.2997,
          0.0000,  8.5159, 13.2796,  0.0000, 36.4762,  0.0000,  4.6396, 20.2842,
         25.0921,  1.2070,  0.0000,  8.3960],
        [19.2740,  0.0000,  0.0000,  0.0000, 30.0949,  3.3817,  0.0000,  5.4792,
          0.0000,  7.7906, 11.4679,  0.0000, 33.1418,  0.0000,  4.5721, 18.7955,
         23.6619,  0.7205,  0.0000,  8.3582],
        [20.4764,  0.0000,  0.0000,  0.0000, 31.9581,  3.5842,  0.0000,  5.8217,
          0.0000,  8.2490, 12.1901,  0.0000, 35.2062,  0.0000,  4.8491, 19.9479,
         25.1093,  0.7920,  0.0000,  8.8734]], grad_fn=<ReluBackward0>)
target_embedding: tensor([23.4326,  0.0000,  0.0000,  0.0000, 3

So the embedding h is size (nodes x embedding size). The input is printed above, you can check if the target embedding taken to the linear function is indeed the right index from the entire embedding.

In [12]:
sender_gs = core.RnnSenderGS(sender, opts.vocab_size, opts.embedding_size, opts.hidden_size, max_len=opts.max_len, temperature=opts.gs_tau, cell=opts.sender_cell)
receiver_gs = core.RnnReceiverGS(receiver, opts.vocab_size, opts.embedding_size, opts.hidden_size, cell=opts.sender_cell)

# Sender produces a message
sender_output = sender_gs(None, graph)
# print("Sender's message:", sender_output)
print("Sender's shape:", sender_output.shape) # (n graphs, max_len+1, vocab size)

# Receiver tries to identify the target node
receiver_output = receiver_gs(sender_output, None, graph)
# print("Receiver's output:", receiver_output)
print("Receiver's shape:", receiver_output.shape) # (n graphs, max_len+1, n nodes)

Sender's shape: torch.Size([1, 2, 100])
Receiver's shape: torch.Size([1, 2, 4])


In [13]:
print(graph)

def loss_nll(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    """
    NLL loss - differentiable and can be used with both GS and Reinforce
    """
    print(labels)
    labels = torch.tensor([labels]).long()
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (labels == receiver_output.argmax(dim=1)).float().mean()
    return nll, {"acc": acc}

game = core.SenderReceiverRnnGS(sender_gs, receiver_gs, loss_nll)

loss, interaction = game(sender_input=None, labels=graph.target_node_idx, receiver_input=None, aux_input=graph)
print(loss)
print("====================================")
# print(interaction)

Data(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=0)
0
0
tensor(1.1499, grad_fn=<MeanBackward0>)


This was an overview of the output for a single graph. Because the labels are in the same object as the graph data they never get mixed up (I believe). 

## Batching

The following is when batching the graphs. Like I explained I need a size of 1 because otherwise indexing gets messed up. I will try to explain what happens in that case later. For now we use a size of 1.

In [14]:
"""
This is a modified version of your dataloader.
"""

class Collater:
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.game_size = game_size
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        elem = batch[0]
        if isinstance(elem, BaseData):
            batch = batch[:((len(batch) // self.game_size) * self.game_size)]  # we throw away the last batch_size % game_size
            batch = Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
            # we return a tuple (sender_input, labels, receiver_input, aux_input)
            # we use aux_input to store minibatch of graphs
            return (
                torch.zeros(len(batch) // self.game_size, 1),  # we don't need sender_input --> create a fake one
                batch.target_node_idx,  # the target is aways the first graph among game_size graphs
                None,  # we don't care about receiver_input
                batch  # this is a compact data for batch_size graphs 
            )

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def collate_fn(self, batch: List[Any]) -> Any:
        if isinstance(self.dataset, OnDiskDataset):
            return self(self.dataset.multi_get(batch))
        return self(batch)


class DataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        self.game_size = game_size
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        self.collator = Collater(game_size, dataset, follow_batch, exclude_keys)

        if isinstance(dataset, OnDiskDataset):
            dataset = range(len(dataset))

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=self.collator.collate_fn,
            **kwargs,
        )

In [15]:
from sklearn.model_selection import train_test_split

# Split the dataset into training and validation sets
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Print the lengths of the training and validation sets
print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))


Training set length: 80
Validation set length: 20


In [16]:
"""
game_size = 1 because we don't need distractor graphs. The nodes are distractors.
"""

train_loader = DataLoader(game_size=1, dataset=train_data, batch_size=opts.batch_size, shuffle=True)
val_loader = DataLoader(game_size=1, dataset=val_data, batch_size=opts.batch_size, shuffle=True)

## Outputs for batch

The batch contains some irrelevant data. The second element is the label (or target node), this time returned as tensor. If it were a batch of more graphs it could look like this: tensor([3, 1, 2, 2]) for example. Meaning 4 graphs in the batch, each element defining the target node in an individual graph. The final element is the input data for the sender and receiver classes.

In [17]:
batch = next(iter(train_loader))
print(batch)
print("====================================")
print(*batch, sep="\n")

(tensor([[0.]]), tensor([1]), None, DataBatch(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=[1], batch=[4], ptr=[2]))
tensor([[0.]])
tensor([1])
None
DataBatch(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=[1], batch=[4], ptr=[2])


In [18]:
# Sender produces a message
sender_output = sender(None, batch[3])
#print("Sender's message:", sender_output)
print("Sender's shape: ", sender_output.shape)

# Receiver tries to identify the target node
receiver_output = receiver(sender_output, None, batch[3])
#print("Receiver's output:", receiver_output)
print("Receiver's shape: ", receiver_output.shape)

Sender's shape:  torch.Size([1, 20])
Receiver's shape:  torch.Size([1, 4])


In [19]:
# Sender produces a message
sender_output = sender_gs(None, batch[3])
#print("Sender's message:", sender_output)
print("Sender's shape:", sender_output.shape) # batch size x max_len+1 x vocab size

# Receiver tries to identify the target node
receiver_output = receiver_gs(sender_output, None, batch[3])
#print("Receiver's output:", receiver_output)
print("Receiver's output shape:", receiver_output.shape)

Sender's shape: torch.Size([1, 2, 100])
Receiver's output shape: torch.Size([1, 2, 4])


The output dimensions are same as when we used a single graph of course, since we have 1 graph in the batch. If you were to increase batch size the dimensions would change from 1 to 4 for example.

In [20]:
def loss_nll(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    """
    NLL loss - differentiable and can be used with both GS and Reinforce
    """
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (labels == receiver_output.argmax(dim=1)).float().mean()
    return nll, {"acc": acc}

game = core.SenderReceiverRnnGS(sender_gs, receiver_gs, loss_nll)

loss, interaction = game(*batch)
print(loss)
print("====================================")
#print(interaction)

tensor(1.1475, grad_fn=<MeanBackward0>)


In [21]:
game = core.SenderReceiverRnnGS(sender_gs, receiver_gs, loss_nll)

opts = core.init(params=['--random_seed=7', 
                         '--lr=1e-2',   
                         f'--batch_size={opts.batch_size}',
                         '--optimizer=adam'])

optimizer = torch.optim.Adam(game.parameters())

trainer = core.Trainer(
    game=game, 
    optimizer=optimizer, 
    train_data=train_loader,
    validation_data=val_loader, 
    callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)]
)

trainer.train(n_epochs=30)

{"loss": 1.4233074188232422, "acc": 0.2751663327217102, "length": 1.9960514307022095, "mode": "train", "epoch": 1}
{"loss": 1.4135353565216064, "acc": 0.30000001192092896, "length": 2.0, "mode": "test", "epoch": 1}
{"loss": 1.4137192964553833, "acc": 0.26232555508613586, "length": 1.9924099445343018, "mode": "train", "epoch": 2}
{"loss": 1.386225938796997, "acc": 0.3499999940395355, "length": 2.0, "mode": "test", "epoch": 2}
{"loss": 1.3950414657592773, "acc": 0.20050778985023499, "length": 1.9913307428359985, "mode": "train", "epoch": 3}
{"loss": 1.3741767406463623, "acc": 0.44999998807907104, "length": 2.0, "mode": "test", "epoch": 3}
{"loss": 1.3842742443084717, "acc": 0.29994508624076843, "length": 1.9954662322998047, "mode": "train", "epoch": 4}
{"loss": 1.3663982152938843, "acc": 0.4000000059604645, "length": 2.0, "mode": "test", "epoch": 4}
{"loss": 1.3679454326629639, "acc": 0.2534807324409485, "length": 1.9879167079925537, "mode": "train", "epoch": 5}
{"loss": 1.23709225654602

## Notes

- Training improved when I put the GAT layer outside of the agents, but sometimes it gets to 0.6 acc and sometimes 0.35 so it's still a little weird. Also it goes down after a few epochs...
- Batching is still not possible for sizes larger than 1. Which is not a very big problem but can become one if I need to run more epochs and bigger datasets.
- Phong set `aux_input[k] = None` in interactions.py in order to be able to run with graphs, I did the same but I also set `receiver_output = None`. This is because when I used graphs of different sizes I kept getting errors, the output has a different size depending on the amount of nodes, and it kept expecting a tensor with dimensions of the first output. This is just for logging so I don't think it matters much.

I will try to explain my batching problem below.

In [22]:
train_loader = DataLoader(game_size=1, dataset=train_data, batch_size=2, shuffle=True)
val_loader = DataLoader(game_size=1, dataset=val_data, batch_size=2, shuffle=True)

batch = next(iter(train_loader))
print(batch)
print("====================================")
print(*batch, sep="\n")

(tensor([[0.],
        [0.]]), tensor([3, 3]), None, DataBatch(x=[8, 2], edge_index=[2, 16], edge_attr=[16], target_node_idx=[2], batch=[8], ptr=[3]))
tensor([[0.],
        [0.]])
tensor([3, 3])
None
DataBatch(x=[8, 2], edge_index=[2, 16], edge_attr=[16], target_node_idx=[2], batch=[8], ptr=[3])


The batch now contains 2 graphs, meaning 4x2=8 nodes. The labels are a tensor resulting from concatenating the specific target nodes of an individual graph. 

In [23]:
sender_print = SenderPrint(embedding_size=10, hidden_size=20, temperature=1.0)

sender_output, h, target_emb = sender_print(None, batch[3])

print("Input: ", graph)
print("h:", h)
print("target_embedding:", target_emb)

Input:  Data(x=[4, 2], edge_index=[2, 8], edge_attr=[8], target_node_idx=0)
h: tensor([[3.4949e+00, 1.0353e+01, 3.4148e+01, 0.0000e+00, 0.0000e+00, 3.1618e+01,
         2.0880e+01, 4.9061e+00, 0.0000e+00, 0.0000e+00, 7.5956e+00, 6.9070e+00,
         0.0000e+00, 2.1370e+00, 1.2696e+01, 1.4778e+01, 2.1560e-01, 1.2968e+00,
         0.0000e+00, 0.0000e+00],
        [3.4949e+00, 1.0353e+01, 3.4148e+01, 0.0000e+00, 0.0000e+00, 3.1618e+01,
         2.0880e+01, 4.9061e+00, 0.0000e+00, 0.0000e+00, 7.6422e+00, 6.9264e+00,
         0.0000e+00, 2.1913e+00, 1.2723e+01, 1.4853e+01, 2.0809e-01, 1.3289e+00,
         0.0000e+00, 0.0000e+00],
        [3.6427e+00, 1.0457e+01, 3.4259e+01, 0.0000e+00, 0.0000e+00, 3.1818e+01,
         2.0967e+01, 4.9298e+00, 0.0000e+00, 0.0000e+00, 5.3782e+00, 4.7993e+00,
         0.0000e+00, 1.0397e+00, 1.0988e+01, 1.1264e+01, 3.0407e-01, 0.0000e+00,
         5.0659e-01, 0.0000e+00],
        [3.3749e+00, 1.0269e+01, 3.4058e+01, 0.0000e+00, 0.0000e+00, 3.1455e+01,
         

As you can see the output becomes (8 x embedding_size) because now there are 8 nodes. If you look at the indexing, for me the labels were 0 and 1, it takes the 0th and 1st element of the entire embedding. Meaning 2 nodes are passed from the first graph, and no nodes from the second graph. I need some way of knowing when a graph ends, which is easier when it is always the same size, but if I want more variation it doesn't work. I can't figure out how to solve this. Scatter would unwrap the batch but like I said it returns a mean or sum, which leaves out the possibility of identifyig the target node.

In [24]:
# Sender produces a message
sender_output = sender(None, batch[3])
#print("Sender's message:", sender_output)
print("Sender's shape: ", sender_output.shape)

# Receiver tries to identify the target node
receiver_output = receiver(sender_output, None, batch[3])
#print("Receiver's output:", receiver_output)
print("Receiver's shape: ", receiver_output.shape)

Sender's shape:  torch.Size([2, 20])
Receiver's shape:  torch.Size([2, 8])
