# EGG Progress

In [14]:
import egg.core as core
from torch.nn import functional as F
from graph.dataset import FamilyGraphDataset
from agents import Sender, Receiver
from options import Options
from dataloaders import create_data_loaders

options = Options()

A Dataset is created by using the graph generating functions from "graph". The output is similar to the datasets used in PyTorch Geometric. A single graph Data object also has corresponding labels, which is a tensor of [num_nodes x 0] where one position is 1 (e.g. tensor([0, 0, 0, 1, 0])) for 5 nodes with target node 4. This is needed by the sender because I want it to send a description of the target node. It is later also needed in the loss function.

In [16]:
dataset = FamilyGraphDataset(root='/Users/meeslindeman/Library/Mobile Documents/com~apple~CloudDocs/Thesis/Code/data', number_of_graphs=100, generations=3)
graph = dataset[0]

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print()
print(graph)
print('=============================================================')
# Gather some statistics about the first graph.
print(f'Number of nodes: {graph.num_nodes}')
print(f'Number of edges: {graph.num_edges}')
print(f'Labels: {graph.labels}')


Dataset: FamilyGraphDataset(100):
Number of graphs: 100
Number of features: 2

Data(x=[9, 2], edge_index=[2, 22], edge_attr=[22], labels=[9])
Number of nodes: 9
Number of edges: 22
Labels: tensor([0, 1, 0, 0, 0, 0, 0, 0, 0])


Processing...
Done!


Batching is done with dataloaders from dataloaders.py. They split the dataset iby 0.8/0.2 for training/testing. Batch size 3 therefore outputs 3 batches, totaling 8 graphs.

In [20]:
train_loader, test_loader = create_data_loaders(dataset, batch_size=3)

total_batches = len(train_loader)
print("Total Batches:", total_batches)

for batch in train_loader:
    print(f'Number of graphs in the current batch: {batch.num_graphs}')
    print(batch)
    break

Total Batches: 27
Number of graphs in the current batch: 3
DataBatch(x=[18, 2], edge_index=[2, 46], edge_attr=[46], labels=[18], batch=[18], ptr=[4])


The sender and receiver are instiated, embedding size is used for the size of the graph embedding. The final Linear layer will give output in hidden size as is expected by the EGG wrappers (see agents.py).

In [21]:
sender = Sender(embedding_size=options.embedding_size, hidden_size=options.hidden_size, temperature=options.temp) 
print(sender)
receiver = Receiver(embedding_size=options.embedding_size, hidden_size=options.hidden_size) 
print(receiver)

Sender(
  (conv1): GATv2Conv(2, 32, heads=2)
  (conv2): GATv2Conv(-1, 32, heads=2)
  (fc): Linear(in_features=64, out_features=16, bias=True)
)
Receiver(
  (conv1): GATv2Conv(2, 32, heads=2)
  (conv2): GATv2Conv(-1, 32, heads=2)
  (fc): Linear(in_features=16, out_features=64, bias=True)
)


Then we wrap the agents:

In [22]:
sender_wrapped = core.RnnSenderGS(sender, options.vocab_size, options.embedding_size, options.hidden_size, max_len=5, temperature=1.0, cell=options.sender_cell)
sender_gs = core.GumbelSoftmaxWrapper(sender, temperature=1.0)
print(sender_wrapped)

receiver_wrapped = core.RnnReceiverGS(receiver, options.vocab_size, options.embedding_size, options.hidden_size, cell=options.sender_cell)
receiver_gs = core.SymbolReceiverWrapper(receiver, vocab_size=options.vocab_size, agent_input_size=options.hidden_size)
print(receiver_wrapped)

RnnSenderGS(
  (agent): Sender(
    (conv1): GATv2Conv(2, 32, heads=2)
    (conv2): GATv2Conv(-1, 32, heads=2)
    (fc): Linear(in_features=64, out_features=16, bias=True)
  )
  (hidden_to_output): Linear(in_features=16, out_features=20, bias=True)
  (embedding): Linear(in_features=20, out_features=32, bias=True)
  (cell): GRUCell(32, 16)
)
RnnReceiverGS(
  (agent): Receiver(
    (conv1): GATv2Conv(2, 32, heads=2)
    (conv2): GATv2Conv(-1, 32, heads=2)
    (fc): Linear(in_features=16, out_features=64, bias=True)
  )
  (cell): GRUCell(32, 16)
  (embedding): Linear(in_features=20, out_features=32, bias=True)
)


The output of the receiver is of shape (nodes x max_len+1 x num_classes). Num_classes is 2 since I want the receiver to ouput probabilities for each class (target node or no target node). I do this because the loss function that is eventually used requires receiver output and labels to calculate the loss. Previously I had the receiver output one probabilty per node but this didn't work.

In [24]:
# Sender produces a message
sender_output = sender_wrapped(graph)
#print("Sender's message:", sender_output)
print("Sender's shape:", sender_output.shape) # batch size x max_len+1 x vocab size

# Receiver tries to identify the target node
receiver_output = receiver_wrapped(sender_output, graph)
#print("Receiver's output:", receiver_output)
print("Receiver's shape:", receiver_output.shape) # nodes x max_len+1 x num_classes

Sender's shape: torch.Size([1, 6, 20])
Receiver's shape: torch.Size([9, 6, 2])


As you can see in the loss function, we need the receiver output and labels to calculate the loss. The interaction, however, gives problems as you can see later on.

In [25]:
def loss_nll(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input):
    """
    NLL loss - differentiable and can be used with both GS and Reinforce
    """
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (labels == receiver_output.argmax(dim=1)).float().mean()
    return nll, {"acc": acc}

game = core.SenderReceiverRnnGS(sender_wrapped, receiver_wrapped, loss_nll)

loss, interaction = game(sender_input=graph, labels=graph.labels, receiver_input=graph, aux_input=None)
print(loss)
print("====================================")
print(interaction)

tensor(-0.1111, grad_fn=<MeanBackward0>)
Interaction(sender_input=Data(x=[9, 2], edge_index=[2, 22], edge_attr=[22], labels=[9]), receiver_input=Data(x=[9, 2], edge_index=[2, 22], edge_attr=[22], labels=[9]), labels=tensor([0, 1, 0, 0, 0, 0, 0, 0, 0]), aux_input=None, message=tensor([[[9.5213e-03, 7.4491e-03, 2.2523e-03, 6.8643e-02, 7.9756e-03,
          5.2132e-03, 5.1704e-02, 2.5034e-04, 8.4826e-02, 3.3761e-03,
          5.6289e-03, 4.5174e-03, 6.8774e-04, 3.9150e-02, 6.3565e-02,
          5.2750e-03, 1.4432e-03, 2.5779e-04, 3.9017e-01, 2.4809e-01],
         [5.2824e-03, 4.1476e-02, 7.9926e-03, 2.7547e-01, 3.0722e-02,
          2.6026e-03, 1.5556e-01, 3.4094e-03, 5.9376e-02, 9.7753e-02,
          6.4214e-02, 9.1295e-02, 3.9879e-04, 7.0249e-03, 1.7666e-02,
          5.0102e-03, 1.1797e-03, 5.7303e-03, 3.4935e-02, 9.2902e-02],
         [4.3903e-02, 2.2493e-02, 1.7936e-03, 3.8288e-02, 1.9507e-02,
          8.2004e-03, 7.8642e-02, 7.0019e-03, 2.2240e-02, 1.5049e-02,
          1.6740e-02,

I had to modify the Batch class that is used in the egg.Trainer function because it used to unwrap an expected batch with tensors, which Phong showed in his method. I modified it to handle the DataBatch object by Torch Geometric differently. Now it unwraps the graphs and the labels. I need the receiver to handle the same graph (we spoke about masking earlier but this gave so many problem that I held it off for now), so the sender input and receiver input are the same.

In [26]:
from typing import Any, Dict, Optional

import torch

from egg.core.util import move_to

from torch_geometric.data import Batch


class CustomBatch:
    def __init__(
        self,
        data_batch: Batch,
        labels: Optional[torch.Tensor] = None,
        receiver_input: Optional[Any] = None,
        aux_input: Optional[Dict[Any, Any]] = None,

    ):
        self.data_batch = data_batch
        for i in range(len(self.data_batch)):
            self.sender_input = data_batch.get_example(i)
            self.labels = data_batch.labels[:self.sender_input.num_nodes]
            self.receiver_input = data_batch.get_example(i)
            self.aux_input = None

    def __getitem__(self, idx):
        """
        >>> b = Batch(torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3]), {})
        >>> b[0]
        tensor([1.])
        >>> b[1]
        tensor([2.])
        >>> b[2]
        tensor([3.])
        >>> b[3]
        {}
        >>> b[6]
        Traceback (most recent call last):
            ...
        IndexError: Trying to access a wrong index in the batch
        """
        if idx == 0:
            return self.sender_input
        elif idx == 1:
            return self.labels
        elif idx == 2:
            return self.receiver_input
        elif idx == 3:
            return self.aux_input
        else:
            raise IndexError("Trying to access a wrong index in the batch")

    def __iter__(self):
        """
        >>> _ = torch.manual_seed(111)
        >>> sender_input = torch.rand(2, 2)
        >>> labels = torch.rand(2, 2)
        >>> batch = Batch(sender_input, labels)
        >>> it = batch.__iter__()
        >>> it_sender_input = next(it)
        >>> torch.allclose(sender_input, it_sender_input)
        True
        >>> it_labels = next(it)
        >>> torch.allclose(labels, it_labels)
        True
        """
        return iter(
            [self.sender_input, self.labels, self.receiver_input, self.aux_input]
        )

    def to(self, device: torch.device):
        """Method to move all (nested) tensors of the batch to a specific device.
        This operation doest not change the original batch element and returns a new Batch instance.
        """
        self.sender_input = move_to(self.sender_input, device)
        self.labels = move_to(self.labels, device)
        self.receiver_input = move_to(self.receiver_input, device)
        self.aux_input = move_to(self.aux_input, device)
        return self

In [27]:
print("Batch type as it is passed to the CustomBatch class:")
print("========================================")
print(batch)
print()

print("Batch information after it is passed to the CustomBatch class:")
print("========================================")
class_output = CustomBatch(batch)
print("Sender input:", class_output.sender_input)
print("Labels:", class_output.labels)
print("Receiver input:", class_output.receiver_input)
print("Aux input:", class_output.aux_input)
print()

print("Output of the game (as it is called in the train function):")
print("========================================")
loss, interaction = game(*class_output)
print(loss)
print(interaction)

Batch type as it is passed to the CustomBatch class:
DataBatch(x=[18, 2], edge_index=[2, 46], edge_attr=[46], labels=[18], batch=[18], ptr=[4])

Batch information after it is passed to the CustomBatch class:
Sender input: Data(x=[5, 2], edge_index=[2, 12], edge_attr=[12], labels=[5])
Labels: tensor([0, 0, 1, 0, 0])
Receiver input: Data(x=[5, 2], edge_index=[2, 12], edge_attr=[12], labels=[5])
Aux input: None

Output of the game (as it is called in the train function):
tensor(-0.2000, grad_fn=<MeanBackward0>)
Interaction(sender_input=Data(x=[5, 2], edge_index=[2, 12], edge_attr=[12], labels=[5]), receiver_input=Data(x=[5, 2], edge_index=[2, 12], edge_attr=[12], labels=[5]), labels=tensor([0, 0, 1, 0, 0]), aux_input=None, message=tensor([[[2.0549e-02, 1.1745e-02, 3.7538e-04, 4.5080e-01, 5.9699e-03,
          2.2580e-03, 8.0914e-02, 2.4243e-04, 6.1061e-03, 9.8548e-03,
          7.7322e-03, 7.4713e-04, 1.0265e-04, 3.5962e-02, 2.3592e-01,
          2.2852e-02, 6.2252e-04, 1.9337e-03, 9.3854

The Trainer function also has to be modified to pass the batch type through the right Batch class, which is now CustomBatch.

In [28]:
# From EGG:

import os
import pathlib
from typing import List, Optional

try:
    # requires python >= 3.7
    from contextlib import nullcontext
except ImportError:
    # not exactly the same, but will do for our purposes
    from contextlib import suppress as nullcontext

import torch
from torch.utils.data import DataLoader

from egg.core.batch import Batch
from egg.core.callbacks import (
    Callback,
    Checkpoint,
    CheckpointSaver,
    ConsoleLogger,
    TensorboardLogger,
)
from egg.core.distributed import get_preemptive_checkpoint_dir
from egg.core.interaction import Interaction
from egg.core.util import get_opts, move_to

try:
    from torch.cuda.amp import GradScaler, autocast
except ImportError:
    pass

class CustomTrainer:
    """
    Implements the training logic. Some common configuration (checkpointing frequency, path, validation frequency)
    is done by checking util.common_opts that is set via the CL.
    """

    def __init__(
        self,
        game: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_data: DataLoader,
        optimizer_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        validation_data: Optional[DataLoader] = None,
        device: torch.device = None,
        callbacks: Optional[List[Callback]] = None,
        grad_norm: float = None,
        aggregate_interaction_logs: bool = True,
    ):
        """
        :param game: A nn.Module that implements forward(); it is expected that forward returns a tuple of (loss, d),
            where loss is differentiable loss to be minimized and d is a dictionary (potentially empty) with auxiliary
            metrics that would be aggregated and reported
        :param optimizer: An instance of torch.optim.Optimizer
        :param optimizer_scheduler: An optimizer scheduler to adjust lr throughout training
        :param train_data: A DataLoader for the training set
        :param validation_data: A DataLoader for the validation set (can be None)
        :param device: A torch.device on which to tensors should be stored
        :param callbacks: A list of egg.core.Callback objects that can encapsulate monitoring or checkpointing
        """
        self.game = game
        self.optimizer = optimizer
        self.optimizer_scheduler = optimizer_scheduler
        self.train_data = train_data
        self.validation_data = validation_data
        common_opts = get_opts()
        self.validation_freq = common_opts.validation_freq
        self.device = common_opts.device if device is None else device

        self.should_stop = False
        self.start_epoch = 0  # Can be overwritten by checkpoint loader
        self.callbacks = callbacks if callbacks else []
        self.grad_norm = grad_norm
        self.aggregate_interaction_logs = aggregate_interaction_logs

        self.update_freq = common_opts.update_freq

        if common_opts.load_from_checkpoint is not None:
            print(
                f"# Initializing model, trainer, and optimizer from {common_opts.load_from_checkpoint}"
            )
            self.load_from_checkpoint(common_opts.load_from_checkpoint)

        self.distributed_context = common_opts.distributed_context
        if self.distributed_context.is_distributed:
            print("# Distributed context: ", self.distributed_context)

        if self.distributed_context.is_leader and not any(
            isinstance(x, CheckpointSaver) for x in self.callbacks
        ):
            if common_opts.preemptable:
                assert (
                    common_opts.checkpoint_dir
                ), "checkpointing directory has to be specified"
                d = get_preemptive_checkpoint_dir(common_opts.checkpoint_dir)
                self.checkpoint_path = d
                self.load_from_latest(d)
            else:
                self.checkpoint_path = (
                    None
                    if common_opts.checkpoint_dir is None
                    else pathlib.Path(common_opts.checkpoint_dir)
                )

            if self.checkpoint_path:
                checkpointer = CheckpointSaver(
                    checkpoint_path=self.checkpoint_path,
                    checkpoint_freq=common_opts.checkpoint_freq,
                )
                self.callbacks.append(checkpointer)

        if self.distributed_context.is_leader and common_opts.tensorboard:
            assert (
                common_opts.tensorboard_dir
            ), "tensorboard directory has to be specified"
            tensorboard_logger = TensorboardLogger()
            self.callbacks.append(tensorboard_logger)

        if self.callbacks is None:
            self.callbacks = [
                ConsoleLogger(print_train_loss=False, as_json=False),
            ]

        if self.distributed_context.is_distributed:
            device_id = self.distributed_context.local_rank
            torch.cuda.set_device(device_id)
            self.game.to(device_id)

            # NB: here we are doing something that is a bit shady:
            # 1/ optimizer was created outside of the Trainer instance, so we don't really know
            #    what parameters it optimizes. If it holds something what is not within the Game instance
            #    then it will not participate in distributed training
            # 2/ if optimizer only holds a subset of Game parameters, it works, but somewhat non-documentedly.
            #    In fact, optimizer would hold parameters of non-DistributedDataParallel version of the Game. The
            #    forward/backward calls, however, would happen on the DistributedDataParallel wrapper.
            #    This wrapper would sync gradients of the underlying tensors - which are the ones that optimizer
            #    holds itself.  As a result it seems to work, but only because DDP doesn't take any tensor ownership.

            self.game = torch.nn.parallel.DistributedDataParallel(
                self.game,
                device_ids=[device_id],
                output_device=device_id,
                find_unused_parameters=True,
            )
            self.optimizer.state = move_to(self.optimizer.state, device_id)

        else:
            self.game.to(self.device)
            # NB: some optimizers pre-allocate buffers before actually doing any steps
            # since model is placed on GPU within Trainer, this leads to having optimizer's state and model parameters
            # on different devices. Here, we protect from that by moving optimizer's internal state to the proper device
            self.optimizer.state = move_to(self.optimizer.state, self.device)

        if common_opts.fp16:
            self.scaler = GradScaler()
        else:
            self.scaler = None

    def eval(self, data=None):
        mean_loss = 0.0
        interactions = []
        n_batches = 0
        validation_data = self.validation_data if data is None else data
        self.game.eval()
        with torch.no_grad():
            for batch in validation_data:
                if not isinstance(batch, Batch):
                    batch = CustomBatch(batch) # MODIFIED
                batch = batch.to(self.device)
                optimized_loss, interaction = self.game(*batch)
                if (
                    self.distributed_context.is_distributed
                    and self.aggregate_interaction_logs
                ):
                    interaction = Interaction.gather_distributed_interactions(
                        interaction
                    )
                interaction = interaction.to("cpu")
                mean_loss += optimized_loss

                for callback in self.callbacks:
                    callback.on_batch_end(
                        interaction, optimized_loss, n_batches, is_training=False
                    )

                interactions.append(interaction)
                n_batches += 1

        mean_loss /= n_batches
        full_interaction = Interaction.from_iterable(interactions)

        return mean_loss.item(), full_interaction

    def train_epoch(self):
        mean_loss = 0
        n_batches = 0
        interactions = []

        self.game.train()

        self.optimizer.zero_grad()

        for batch_id, batch in enumerate(self.train_data):
            if not isinstance(batch, Batch):
                batch = CustomBatch(batch) # MODIFIED
            batch = batch.to(self.device)


            context = autocast() if self.scaler else nullcontext()
            with context:
                optimized_loss, interaction = self.game(*batch)

                if self.update_freq > 1:
                    # throughout EGG, we minimize _mean_ loss, not sum
                    # hence, we need to account for that when aggregating grads
                    optimized_loss = optimized_loss / self.update_freq

            if self.scaler:
                self.scaler.scale(optimized_loss).backward()
            else:
                optimized_loss.backward()

            if batch_id % self.update_freq == self.update_freq - 1:
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)

                if self.grad_norm:
                    torch.nn.utils.clip_grad_norm_(
                        self.game.parameters(), self.grad_norm
                    )
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()

            n_batches += 1
            mean_loss += optimized_loss.detach()
            if (
                self.distributed_context.is_distributed
                and self.aggregate_interaction_logs
            ):
                interaction = Interaction.gather_distributed_interactions(interaction)
            interaction = interaction.to("cpu")

            for callback in self.callbacks:
                callback.on_batch_end(interaction, optimized_loss, batch_id)

            interactions.append(interaction)

        if self.optimizer_scheduler:
            self.optimizer_scheduler.step()

        mean_loss /= n_batches
        full_interaction = Interaction.from_iterable(interactions)
        return mean_loss.item(), full_interaction

    def train(self, n_epochs):
        for callback in self.callbacks:
            callback.on_train_begin(self)

        for epoch in range(self.start_epoch, n_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch + 1)

            train_loss, train_interaction = self.train_epoch()

            for callback in self.callbacks:
                callback.on_epoch_end(train_loss, train_interaction, epoch + 1)

            validation_loss = validation_interaction = None
            if (
                self.validation_data is not None
                and self.validation_freq > 0
                and (epoch + 1) % self.validation_freq == 0
            ):
                for callback in self.callbacks:
                    callback.on_validation_begin(epoch + 1)
                validation_loss, validation_interaction = self.eval()

                for callback in self.callbacks:
                    callback.on_validation_end(
                        validation_loss, validation_interaction, epoch + 1
                    )

            if self.should_stop:
                for callback in self.callbacks:
                    callback.on_early_stopping(
                        train_loss,
                        train_interaction,
                        epoch + 1,
                        validation_loss,
                        validation_interaction,
                    )
                break

        for callback in self.callbacks:
            callback.on_train_end()

    def load(self, checkpoint: Checkpoint):
        self.game.load_state_dict(checkpoint.model_state_dict)
        self.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
        if checkpoint.optimizer_scheduler_state_dict:
            self.optimizer_scheduler.load_state_dict(
                checkpoint.optimizer_scheduler_state_dict
            )
        self.start_epoch = checkpoint.epoch

    def load_from_checkpoint(self, path):
        """
        Loads the game, agents, and optimizer state from a file
        :param path: Path to the file
        """
        print(f"# loading trainer state from {path}")
        checkpoint = torch.load(path)
        self.load(checkpoint)

    def load_from_latest(self, path):
        latest_file, latest_time = None, None

        for file in path.glob("*.tar"):
            creation_time = os.stat(file).st_ctime
            if latest_time is None or creation_time > latest_time:
                latest_file, latest_time = file, creation_time

        if latest_file is not None:
            self.load_from_checkpoint(latest_file)

In [29]:
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import RelaxedOneHotCategorical

from egg.core.interaction import LoggingStrategy

class CustomSenderReceiverRnnGS(nn.Module):
    """
    This class implements the Sender/Receiver game mechanics for the Sender/Receiver game with variable-length
    communication messages and Gumber-Softmax relaxation of the channel. The vocabulary term with id `0` is assumed
    to the end-of-sequence symbol. It is assumed that communication is stopped either after all the message is processed
    or when the end-of-sequence symbol is met.

    >>> class Sender(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc = nn.Linear(10, 5)
    ...     def forward(self, x, _input=None, aux_input=None):
    ...         return self.fc(x)
    >>> sender = Sender()
    >>> sender = RnnSenderGS(sender, vocab_size=2, embed_dim=3, hidden_size=5, max_len=3, temperature=5.0, cell='gru')
    >>> class Receiver(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc = nn.Linear(7, 10)
    ...     def forward(self, x, _input=None, aux_input=None):
    ...         return self.fc(x)
    >>> receiver = RnnReceiverGS(Receiver(), vocab_size=2, embed_dim=4, hidden_size=7, cell='rnn')
    >>> def loss(sender_input, _message, _receiver_input, receiver_output, labels, aux_input):
    ...     return (sender_input - receiver_output).pow(2.0).mean(dim=1), {'aux': torch.zeros(sender_input.size(0))}
    >>> game = SenderReceiverRnnGS(sender, receiver, loss)
    >>> loss, interaction = game(torch.ones((3, 10)), None, None)  # batch of 3 10d vectors
    >>> interaction.aux['aux'].detach()
    tensor([0., 0., 0.])
    >>> loss.item() > 0
    True
    """

    def __init__(
        self,
        sender,
        receiver,
        loss,
        length_cost=0.0,
        train_logging_strategy: Optional[LoggingStrategy] = None,
        test_logging_strategy: Optional[LoggingStrategy] = None,
    ):
        """
        :param sender: sender agent
        :param receiver: receiver agent
        :param loss:  the optimized loss that accepts
            sender_input: input of Sender
            message: the is sent by Sender
            receiver_input: input of Receiver from the dataset
            receiver_output: output of Receiver
            labels: labels assigned to Sender's input data
          and outputs a tuple of (1) a loss tensor of shape (batch size, 1) (2) the dict with auxiliary information
          of the same shape. The loss will be minimized during training, and the auxiliary information aggregated over
          all batches in the dataset.
        :param length_cost: the penalty applied to Sender for each symbol produced
        :param train_logging_strategy, test_logging_strategy: specify what parts of interactions to persist for
            later analysis in the callbacks.

        """
        super(CustomSenderReceiverRnnGS, self).__init__()
        self.sender = sender
        self.receiver = receiver
        self.loss = loss
        self.length_cost = length_cost
        self.train_logging_strategy = (
            LoggingStrategy()
            if train_logging_strategy is None
            else train_logging_strategy
        )
        self.test_logging_strategy = (
            LoggingStrategy()
            if test_logging_strategy is None
            else test_logging_strategy
        )

    def forward(self, sender_input, labels, receiver_input=None, aux_input=None):
        message = self.sender(sender_input, aux_input)
        receiver_output = self.receiver(message, receiver_input, aux_input)

        loss = 0
        not_eosed_before = torch.ones(receiver_output.size(0)).to(
            receiver_output.device
        )
        expected_length = 0.0

        aux_info = {}
        z = 0.0
        for step in range(receiver_output.size(1)):
            step_loss, step_aux = self.loss(
                sender_input,
                message[:, step, ...],
                receiver_input,
                receiver_output[:, step, ...],
                labels,
                aux_input,
            )
            eos_mask = message[:, step, 0]  # always eos == 0

            add_mask = eos_mask * not_eosed_before
            z += add_mask
            loss += step_loss * add_mask + self.length_cost * (1.0 + step) * add_mask
            expected_length += add_mask.detach() * (1.0 + step)

            for name, value in step_aux.items():
                aux_info[name] = value * add_mask + aux_info.get(name, 0.0)

            not_eosed_before = not_eosed_before * (1.0 - eos_mask)

        # the remainder of the probability mass
        loss += (
            step_loss * not_eosed_before
            + self.length_cost * (step + 1.0) * not_eosed_before
        )
        expected_length += (step + 1) * not_eosed_before

        z += not_eosed_before
        assert z.allclose(
            torch.ones_like(z)
        ), f"lost probability mass, {z.min()}, {z.max()}"

        for name, value in step_aux.items():
            aux_info[name] = value * not_eosed_before + aux_info.get(name, 0.0)

        aux_info["length"] = expected_length

        logging_strategy = (
            self.train_logging_strategy if self.training else self.test_logging_strategy
        )
        interaction = logging_strategy.filtered_interaction(
            sender_input=torch.zeros(sender_input.num_nodes), # MODIFIED
            receiver_input=torch.zeros(sender_input.num_nodes), # MODIFIED
            labels=labels,
            aux_input=aux_input,
            receiver_output=receiver_output.detach(),
            message=message.detach(),
            message_length=expected_length.detach(),
            aux=aux_info,
        )

        return loss.mean(), interaction

In [30]:
game = CustomSenderReceiverRnnGS(sender_wrapped, receiver_wrapped, loss_nll)

loss, interaction = game(sender_input=graph, labels=graph.labels, receiver_input=graph, aux_input=None)
print("====================================")
print(interaction)

Interaction(sender_input=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), receiver_input=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), labels=tensor([0, 1, 0, 0, 0, 0, 0, 0, 0]), aux_input=None, message=tensor([[[3.2741e-03, 5.3805e-03, 5.6546e-04, 1.4919e-02, 1.2929e-02,
          8.1527e-02, 4.0610e-01, 1.3966e-04, 9.5604e-03, 2.7534e-02,
          1.9643e-02, 1.5255e-04, 1.5031e-04, 1.1153e-03, 6.6448e-02,
          4.3851e-03, 8.4110e-04, 2.4364e-04, 6.6930e-02, 2.7816e-01],
         [2.4063e-03, 2.7710e-03, 2.4723e-04, 8.2089e-01, 3.8277e-03,
          1.8013e-03, 8.5198e-02, 8.3164e-03, 6.7180e-03, 1.0799e-02,
          1.7178e-02, 8.2034e-04, 1.3892e-04, 1.0117e-03, 6.8539e-03,
          2.0015e-03, 3.8317e-04, 9.2605e-04, 1.8848e-02, 8.8655e-03],
         [8.0806e-03, 1.2795e-02, 3.2145e-03, 1.7831e-02, 2.6261e-02,
          1.5558e-01, 5.2083e-02, 1.8758e-02, 1.5623e-02, 1.4591e-02,
          8.9768e-03, 5.2383e-02, 1.0774e-02, 4.9124e-02, 2.7331e-01,
          8.1389e-03, 1.4472

So I have modified the Batch, Trainer and Game classes from EGG which is not ideal. However I keep getting stuck when I want to apply Phong's method. This is mainly because I can't get my agents to handle batches. They are disigned to do one graph at a time since I calculate the message and therefore the labels per graph.

In [31]:
opts = core.init(params=['--random_seed=7', 
                         '--lr=1e-3',   
                         f'--batch_size={options.batch_size}',
                         '--optimizer=adam'])

optimizer = torch.optim.Adam(game.parameters())

trainer = CustomTrainer(
    game=game, 
    optimizer=optimizer, 
    train_data=train_loader,
    validation_data=test_loader, 
    callbacks=[core.ConsoleLogger(as_json=True, print_train_loss=True)]
)

trainer.train(n_epochs=30)

{"loss": -0.15072019398212433, "acc": 0.8541667461395264, "length": 5.5017266273498535, "mode": "train", "epoch": 1}
{"loss": -0.13650795817375183, "acc": 0.8545454740524292, "length": 6.0, "mode": "test", "epoch": 1}
{"loss": -0.14927984774112701, "acc": 0.8564102053642273, "length": 5.320054531097412, "mode": "train", "epoch": 2}
{"loss": -0.13650791347026825, "acc": 0.8545454740524292, "length": 6.0, "mode": "test", "epoch": 2}
{"loss": -0.13827162981033325, "acc": 0.8413461446762085, "length": 5.404200553894043, "mode": "train", "epoch": 3}
{"loss": -0.13650794327259064, "acc": 0.8545454740524292, "length": 6.0, "mode": "test", "epoch": 3}
{"loss": -0.1443415880203247, "acc": 0.8557214140892029, "length": 5.441940784454346, "mode": "train", "epoch": 4}
{"loss": -0.13650794327259064, "acc": 0.8545454740524292, "length": 6.0, "mode": "test", "epoch": 4}
{"loss": -0.15709877014160156, "acc": 0.8709677457809448, "length": 5.48160982131958, "mode": "train", "epoch": 5}
{"loss": -0.13650