# Game
We create a simple game in which 
- There are two agents (sender and receiver)
- In each game, the sender is given a target graph, the receiver is given the target graph and some distractor graphs
- The sender generates a message describing the whole target graph, and the receiver needs to point out which graph is the target (target graph vs distractor graphs)

In this game, our graphs are from the `MNIST` dataset, which is for graph classification. (See [here](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.GNNBenchmarkDataset.html))

We start with importing libs

In [1]:
import argparse

import torch
import torch.nn.functional as F
import egg.core as core
import torch.nn as nn
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.nn import GCNConv
from torch_geometric.utils import scatter

from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset

  from .autonotebook import tqdm as notebook_tqdm


# Data
Each game has a target graph and a number of distractor graphs. In total there are `game_size` graphs. We stack `game_size` graphs as bellow: 
- the first graph is always the target (note: as long as we don't make use of position, it doesn't matter where the target graph is located in the stack)
- the next `game_size - 1` graphs are distractors

A minibatch is a list of `batch_size` graphs in which the first `game_size` graphs are for the first game, the second `game_size` graphs are for the second game, and so on. Therefore, a minibatch contains graphs for `batch_size // game_size` games. (Note: we will ignore the last `batch_size % game_size` graphs). 

We will make use of the mini-batching mechanism of pytorch_geometric [here](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#mini-batches). The below is two classes for data loading. They are modification of the classes from pytorch_geometric (see the comments for the differences).

In [15]:
class Collater:
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.game_size = game_size
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        # Check if the batch is smaller than game_size
        if len(batch) < self.game_size:
            return print("Test")  # Or return an appropriate placeholder
        
        elem = batch[0]
        if isinstance(elem, BaseData):
            batch = batch[:((len(batch) // self.game_size) * self.game_size)]  # we throw away the last batch_size % game_size
            batch = Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
            # we return a tuple (sender_input, labels, receiver_input, aux_input)
            # we use aux_input to store minibatch of graphs
            return (
                torch.zeros(len(batch) // self.game_size, 1),  # we don't need sender_input --> create a fake one
                torch.zeros(len(batch) // self.game_size).long(),  # the target is aways the first graph among game_size graphs
                None,  # we don't care about receiver_input
                batch  # this is a compact data for batch_size graphs 
            )

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def collate_fn(self, batch: List[Any]) -> Any:
        if isinstance(self.dataset, OnDiskDataset):
            return self(self.dataset.multi_get(batch))
        return self(batch)


class DataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        game_size: int,  # the number of graphs for a game
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        self.game_size = game_size
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        self.collator = Collater(game_size, dataset, follow_batch, exclude_keys)

        if isinstance(dataset, OnDiskDataset):
            dataset = range(len(dataset))

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=self.collator.collate_fn,
            **kwargs,
        )

# Agents 
Now we declare agent classes. Firstly, we need to build graph neural networks. Here for simplicity we use one graph-conv layer (like [here](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#learning-methods-on-graphs)). The class computes a global graph embedding for each graph.

In [3]:
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, n_hidden):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, n_hidden)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # check number of graphs via: data.num_graphs
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = scatter(x, data.batch, dim=0, reduce='mean')  # size: data.num_graphs * n_hidden
        return x

Note the line `x = scatter(x, data.batch, dim=0, reduce='mean')` -- which is to calculate graph embeddings: the graph embedding is the mean of the node embeddings. See [here](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#mini-batches) for why using `scatter`.

The two agents (sender and receiver) each has one GCN. Because we have to follow the notations of EGG, we will make use of `aux_input` to send a minibatch of graphs to them.

In [4]:
class Sender(nn.Module):
    def __init__(self, num_node_features, n_hidden, game_size):
        super().__init__()
        self.game_size = game_size
        self.gcn = GCN(num_node_features, n_hidden)
        self.fc1 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, _aux_input):
        # _aux_input is a minibatch of n_games x game_size graphs
        # we don't care about x
        data = _aux_input
        assert data.num_graphs % self.game_size == 0
        x = self.gcn(data)[::self.game_size]  # we just need the target graph, hence we take only the first graph of every game_size graphs
        return self.fc1(x)  # size: n_games * n_hidden  (note: n_games = batch_size // game_size)


class Receiver(nn.Module):
    def __init__(self, num_node_features, n_hidden, game_size):
        super().__init__()
        self.game_size = game_size
        self.gcn = GCN(num_node_features, n_hidden)
        self.fc1 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, _input, _aux_input):
        # x is the tensor of shape n_games * n_hidden -- each row is an embedding decoded from the message sent by the sender
        cands = self.gcn(_aux_input)  # graph embeddings for all batch_size graphs; size: batch_size * n_hidden
        cands = cands.view(cands.shape[0] // self.game_size, self.game_size, -1)  # size: n_games * game_size * n_hidden 
        dots = torch.matmul(cands, torch.unsqueeze(x, dim=-1))  # size: n_games * game_size * 1
        print(dots.squeeze().shape)
        return dots.squeeze()  # size: n_games * game_size: each row is a list of scores for a game (each score tells how good the corresponding candidate is) 

In [5]:
from torch_geometric.data import Dataset
from graph.build import create_family_tree, create_data_object
import random
import os

class FamilyGraphDataset(Dataset):
    """
    Dataset class for generating family graph data.

    Args:
        root (str): Root directory path.
        number_of_graphs (int): Number of graphs to generate.
        generations (int): Number of generations in each family tree.

    Returns:
        Data(x=[8, 2], edge_index=[2, 20], edge_attr=[20], labels=[8])
    """
    def __init__(self, root, number_of_graphs, generations, transform=None, pre_transform=None):
        self.number_of_graphs = number_of_graphs
        self.generations = generations
        super(FamilyGraphDataset, self).__init__(root, transform, pre_transform)
        self.data = None
        self.process()

    @property
    def processed_file_names(self):
        return ['family_graphs.pt']

    def len(self):
        return len(self.data)

    def get(self, idx):
        return self.data[idx]
    
    def generate_labels(self, num_nodes):
        target_node_idx = random.randint(0, num_nodes - 1)
        return target_node_idx
    
    def process(self):
        if not os.path.isfile(self.processed_paths[0]):
            self.data = []
            for _ in range(self.number_of_graphs):
                family_tree = create_family_tree(self.generations)
                graph_data = create_data_object(family_tree)

                # Generate random labels for each node
                target_node_idx = self.generate_labels(graph_data.num_nodes)

                # Store the labels as an attribute of the graph_data
                graph_data.target_node_idx = target_node_idx

                self.data.append(graph_data)

            torch.save(self.data, self.processed_paths[0])
        else:
            self.data = torch.load(self.processed_paths[0])

# Training 
We reuse code from [here](https://github.com/facebookresearch/EGG/blob/main/egg/zoo/basic_games/README.md#discrimination-game) from EGG. 

In [16]:
from sklearn.model_selection import train_test_split

def get_params(params):
    parser = argparse.ArgumentParser()

    # arguments concerning the training method
    parser.add_argument(
        "--mode",
        type=str,
        default="rf",
        help="Selects whether Reinforce or Gumbel-Softmax relaxation is used for training {rf, gs} (default: rf)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender, only relevant in Gumbel-Softmax (gs) mode (default: 1.0)",
    )
    # arguments concerning the agent architectures
    parser.add_argument(
        "--sender_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Sender {rnn, gru, lstm} (default: rnn)",
    )
    parser.add_argument(
        "--receiver_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Receiver {rnn, gru, lstm} (default: rnn)",
    )
    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Receiver (default: 10)",
    )
    parser.add_argument(
        "--sender_embedding",
        type=int,
        default=10,
        help="Output dimensionality of the layer that embeds symbols produced at previous step in Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_embedding",
        type=int,
        default=10,
        help="Output dimensionality of the layer that embeds the message symbols for Receiver (default: 10)",
    )
    # arguments controlling the script output
    parser.add_argument(
        "--print_validation_events",
        default=False,
        action="store_true",
        help="If this flag is passed, at the end of training the script prints the input validation data, the corresponding messages produced by the Sender, and the output probabilities produced by the Receiver (default: do not print)",
    )
    parser.add_argument(
        "--game_size",
        type=int,
        default=4,
        help="The number of graphs in a game (including a target and distractors) (default: 4)",
    )
    args = core.init(parser, params)
    return args


def main(params):
    opts = get_params(params)
    print(opts, flush=True)
    game_size = opts.game_size

    # we care about the communication success: the accuracy that the receiver can distinguish the target from distractors
    def loss(
        _sender_input,
        _message,
        _receiver_input,
        receiver_output,
        labels,
        _aux_input,
    ):
        acc = (receiver_output.argmax(dim=1) == labels).detach().float()
        loss = F.cross_entropy(receiver_output, labels, reduction="none")
        return loss, {"acc": acc}

    dataset = FamilyGraphDataset(root='/Users/meeslindeman/Library/Mobile Documents/com~apple~CloudDocs/Thesis/Code/data', number_of_graphs=2048, generations=3)

    # Split the dataset into training and validation sets
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

    train_loader = DataLoader(game_size=opts.game_size, dataset=train_data, batch_size=opts.batch_size, shuffle=True)
    val_loader = DataLoader(game_size=opts.game_size, dataset=val_data, batch_size=opts.batch_size, shuffle=True)
    
    # we create dataset and dataloader
    # dataname = 'MNIST'
    # root = f'/tmp/{dataname}'
    # train_dataset = GNNBenchmarkDataset(root=root, name=dataname, split='train')
    # train_loader = DataLoader(game_size, train_dataset, batch_size=opts.batch_size, shuffle=True)
    # val_dataset = GNNBenchmarkDataset(root=root, name=dataname, split='val')
    # val_loader = DataLoader(game_size, val_dataset, batch_size=opts.batch_size, shuffle=True)

    # we create the two agents
    receiver = Receiver(dataset[0].num_node_features, n_hidden=opts.receiver_hidden, game_size=game_size)
    sender = Sender(dataset[0].num_node_features, n_hidden=opts.sender_hidden, game_size=game_size)

    sender = core.RnnSenderReinforce(
        sender,
        vocab_size=opts.vocab_size,
        embed_dim=opts.sender_embedding,
        hidden_size=opts.sender_hidden,
        cell=opts.sender_cell,
        max_len=opts.max_len,
    )
    receiver = core.RnnReceiverDeterministic(
        receiver,
        vocab_size=opts.vocab_size,
        embed_dim=opts.receiver_embedding,
        hidden_size=opts.receiver_hidden,
        cell=opts.receiver_cell,
    )
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        receiver_entropy_coeff=0,
    )
    callbacks = []

    optimizer = core.build_optimizer(game.parameters())
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=val_loader,
        callbacks=callbacks
        + [core.ConsoleLogger(print_train_loss=True, as_json=True)],
    )

    trainer.train(n_epochs=opts.n_epochs)

Before running the training, we need to modify the EGG's code a bit. Open `[your_dir]/site-packages/egg/core/interaction.py`, change line 209 to `aux_input[k] = None`. 

Let's run the game. Notice that the accuracy on the val set increases. (You should keep in mind that the accuracy is *not* for graph classification. Instead it is for how well the receiver can distinguish the target from distractors.)

In [17]:
main(["--n_epochs", "1"])

Namespace(mode='rf', temperature=1.0, sender_cell='rnn', receiver_cell='rnn', sender_hidden=10, receiver_hidden=10, sender_embedding=10, receiver_embedding=10, print_validation_events=False, game_size=4, random_seed=821793608, checkpoint_dir=None, preemptable=False, checkpoint_freq=0, validation_freq=1, n_epochs=1, load_from_checkpoint=None, no_cuda=True, batch_size=32, optimizer='adam', lr=0.01, update_freq=1, vocab_size=10, max_len=1, tensorboard=False, tensorboard_dir='runs/', distributed_port=18363, fp16=False, cuda=False, device=device(type='cpu'), distributed_context=DistributedContext(is_distributed=False, rank=0, local_rank=0, world_size=1, mode='none'))
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.Size([8, 4])
torch.

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)