In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import egg.core as core
from options import Options
opts = Options()

from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
from torch_geometric.data import Batch, Dataset, Data
from torch_geometric.data.data import BaseData
from torch_geometric.nn import GATv2Conv
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import torch
import random
from torch_geometric.data import Dataset
from graph.build import create_family_tree, create_data_object

class FamilyGraphDataset(Dataset):
    """
    Dataset class for generating family graph data.

    Args:
        root (str): Root directory path.
        number_of_graphs (int): Number of graphs to generate.
        generations (int): Number of generations in each family tree.

    Returns:
        Data(x=[8, 2], edge_index=[2, 20], edge_attr=[20], labels=[8])
    """
    def __init__(self, root, number_of_graphs, generations, transform=None, pre_transform=None):
        self.number_of_graphs = number_of_graphs
        self.generations = generations
        super(FamilyGraphDataset, self).__init__(root, transform, pre_transform)
        self.data = None
        self.process()

    @property
    def processed_file_names(self):
        return ['family_graphs.pt']

    def len(self):
        return len(self.data)

    def get(self, idx):
        return self.data[idx]
    
    def generate_labels(self, num_nodes):
        target_node_idx = random.randint(0, num_nodes - 1)
        return target_node_idx
    
    def process(self):
        if not os.path.isfile(self.processed_paths[0]):
            self.data = []
            for _ in range(self.number_of_graphs):
                family_tree = create_family_tree(self.generations)
                graph_data = create_data_object(family_tree)

                # Generate random labels for each node
                target_node_idx = self.generate_labels(graph_data.num_nodes)

                # Store the labels as an attribute of the graph_data
                graph_data.target_node_idx = target_node_idx

                self.data.append(graph_data)

            torch.save(self.data, self.processed_paths[0])
        else:
            self.data = torch.load(self.processed_paths[0])

In [3]:
dataset = FamilyGraphDataset(root='/Users/meeslindeman/Library/Mobile Documents/com~apple~CloudDocs/Thesis/Code/data', number_of_graphs=10000, generations=3)
print(dataset[0])
print(dataset[1])

Data(x=[18, 2], edge_index=[2, 42], edge_attr=[42], target_node_idx=3)
Data(x=[5, 2], edge_index=[2, 12], edge_attr=[12], target_node_idx=2)


In [4]:
from torch_geometric.utils import scatter

class GAT(torch.nn.Module):
    def __init__(self, num_node_features, embedding_size):
        super().__init__()
        self.conv1 = GATv2Conv(num_node_features, embedding_size, edge_dim=1, heads=1, concat=True)
        self.conv2 = GATv2Conv(-1, embedding_size, edge_dim=1, heads=1, concat=True)

    def forward(self, data):
        h, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        h = self.conv1(x=h, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.relu(h)

        h = self.conv2(x=h, edge_index=edge_index, edge_attr=edge_attr)     
        h = F.relu(h) 

        x = F.dropout(h, training=self.training)
        x = scatter(x, data.batch, dim=0, reduce='sum')  # size: data.num_graphs * n_hidden
        return x

In [5]:
class Sender(nn.Module):
    def __init__(self, num_node_features, n_hidden, game_size, temperature):
        super(Sender, self).__init__()
        self.temp = temperature
        self.game_size = game_size
        self.gcn = GAT(num_node_features, n_hidden)

        self.fc1 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, _aux_input):
        data = _aux_input
        assert data.num_graphs % self.game_size == 0
        x = self.gcn(data)[::self.game_size]  # we just need the target graph, hence we take only the first graph of every game_size graphs
        x = self.fc1(x)
        return x  # size: n_games * n_hidden  (note: n_games = batch_size // game_size)

class Receiver(nn.Module):
    def __init__(self, num_node_features, n_hidden, game_size):
        super().__init__()
        self.game_size = game_size
        self.gcn = GAT(num_node_features, n_hidden)
        
        self.fc1 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, _input, _aux_input):
        # x is the tensor of shape n_games * n_hidden -- each row is an embedding decoded from the message sent by the sender
        cands = self.gcn(_aux_input)  # graph embeddings for all batch_size graphs; size: batch_size * n_hidden
        cands = cands.view(cands.shape[0] // self.game_size, self.game_size, -1)  # size: n_games * game_size * n_hidden 
        dots = torch.matmul(cands, torch.unsqueeze(x, dim=-1))  # size: n_games * game_size * 1
        return dots.squeeze()  # size: n_games * game_size: each row is a list of scores for a game (each score tells how good the corresponding candidate is) 

In [6]:
class Collater:
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.game_size = game_size
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        elem = batch[0]
        if isinstance(elem, BaseData):
            batch = batch[:((len(batch) // self.game_size) * self.game_size)]  # we throw away the last batch_size % game_size
            batch = Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
            # we return a tuple (sender_input, labels, receiver_input, aux_input)
            # we use aux_input to store minibatch of graphs
            return (
                torch.zeros(len(batch) // self.game_size, 1),  # we don't need sender_input --> create a fake one
                torch.zeros(len(batch) // self.game_size).long(),  # the target is aways the first graph among game_size graphs
                None,  # we don't care about receiver_input
                batch  # this is a compact data for batch_size graphs 
            )

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def collate_fn(self, batch: List[Any]) -> Any:
        if isinstance(self.dataset, OnDiskDataset):
            return self(self.dataset.multi_get(batch))
        return self(batch)


class DataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        self.game_size = game_size
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        self.collator = Collater(game_size, dataset, follow_batch, exclude_keys)

        if isinstance(dataset, OnDiskDataset):
            dataset = range(len(dataset))

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=self.collator.collate_fn,
            **kwargs,
        )

In [7]:
from sklearn.model_selection import train_test_split

# Split the dataset into training and validation sets
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Print the lengths of the training and validation sets
print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))

train_loader = DataLoader(game_size=opts.game_size, dataset=train_data, batch_size=opts.batch_size, shuffle=True)
val_loader = DataLoader(game_size=opts.game_size, dataset=val_data, batch_size=opts.batch_size, shuffle=True)

Training set length: 8000
Validation set length: 2000


In [8]:
def loss(
    _sender_input,
    _message,
    _receiver_input,
    receiver_output,
    labels,
    _aux_input,
):
    acc = (receiver_output.argmax(dim=1) == labels).detach().float()
    loss = F.cross_entropy(receiver_output, labels, reduction="none")
    return loss, {"acc": acc}

In [9]:
# we create the two agents
receiver = Receiver(dataset[0].num_node_features, n_hidden=opts.hidden_size, game_size=opts.game_size)
sender = Sender(dataset[0].num_node_features, n_hidden=opts.hidden_size, game_size=opts.game_size, temperature=opts.gs_tau)

sender_w = core.RnnSenderReinforce(
    sender,
    vocab_size=opts.vocab_size,
    embed_dim=opts.embedding_size,
    hidden_size=opts.hidden_size,
    cell=opts.sender_cell,
    max_len=opts.max_len
)
receiver_w = core.RnnReceiverDeterministic(
    receiver,
    vocab_size=opts.vocab_size,
    embed_dim=opts.embedding_size,
    hidden_size=opts.hidden_size,
    cell=opts.sender_cell
)

game = core.SenderReceiverRnnReinforce(sender_w, receiver_w, loss, receiver_entropy_coeff=0)

opts = core.init(params=['--random_seed=7', 
                         '--lr=1e-2',   
                         f'--batch_size={opts.batch_size}',
                         '--optimizer=adam'])

optimizer = torch.optim.Adam(game.parameters())

trainer = core.Trainer(
    game=game, 
    optimizer=optimizer, 
    train_data=train_loader,
    validation_data=val_loader, 
    callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)]
)

In [10]:
trainer.train(n_epochs=10)

{"loss": 250.5278778076172, "acc": 0.04749999940395355, "sender_entropy": 2.1901023387908936, "receiver_entropy": 0.0, "length": 1.9850000143051147, "mode": "train", "epoch": 1}
{"loss": 204.7561492919922, "acc": 0.07999999821186066, "sender_entropy": 2.150277614593506, "receiver_entropy": 0.0, "length": 2.0, "mode": "test", "epoch": 1}
{"loss": 185.22689819335938, "acc": 0.08250000327825546, "sender_entropy": 2.143399715423584, "receiver_entropy": 0.0, "length": 1.9774999618530273, "mode": "train", "epoch": 2}
{"loss": 100.58894348144531, "acc": 0.019999999552965164, "sender_entropy": 2.1419339179992676, "receiver_entropy": 0.0, "length": 2.0, "mode": "test", "epoch": 2}
{"loss": 122.75272369384766, "acc": 0.17749999463558197, "sender_entropy": 2.162497043609619, "receiver_entropy": 0.0, "length": 1.9824999570846558, "mode": "train", "epoch": 3}
{"loss": 76.38700866699219, "acc": 0.05000000074505806, "sender_entropy": 2.159968137741089, "receiver_entropy": 0.0, "length": 2.0, "mode": 