# Template for variable length Signalling Game

This template is based on the [MNIST autoencoder tutorial](https://github.com/facebookresearch/EGG/blob/master/tutorials/EGG%20walkthrough%20with%20a%20MNIST%20autoencoder.ipynb) and [signal game implementation](https://github.com/facebookresearch/EGG/blob/master/egg/zoo/signal_game) provided by the [EGG library](https://github.com/facebookresearch/EGG).

Some of the code was provided by Mathieu Bartels and Liselore Borel Rinkes at the UvA.

Make sure you have a directory `SignalGame` at the top level of your Google Drive (`My Drive/SignalGame`)! The notebook works from that folder.

In [None]:
# if you are running this notebook via Google Colab, you have to install EGG first
!pip install git+https://github.com/facebookresearch/EGG.git

# Setup connection to Google Drive
from google.colab import drive
drive.mount('/content/drive/')
# Work inside the SignalGame folder on Drive; trained models go into its `models` subfolder.
%cd "/content/drive/My Drive/SignalGame"
!mkdir -p 'models'
# List the folder contents as a quick sanity check.
!ls


# also you'll need to change the runtime to GPU (Runtime -> Change runtime type -> Hardware Accelerator -> GPU)

In [None]:
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
from tqdm import tqdm
import os
import pickle
import matplotlib.pyplot as plt
import random
import numpy as np
import scipy.spatial.distance as distance
import scipy.stats
import scipy
import egg.core as core
import egg.zoo as zoo

## Configuration
Define the important configuration parameters, such as the number of images per game, in one convenient place.

In [None]:
import types
import json

# For convenience and reproducibility, we set some EGG-level command line arguments here
opts = core.init(params=['--random_seed=7', # will initialize numpy, torch, and python RNGs
                         '--lr=1e-3', # sets the learning rate for the selected optimizer
                         '--batch_size=64',
                         '--vocab_size=100',
                         '--max_len=10',
                         '--n_epochs=15',
                         '--tensorboard',
                         ])

# Other configurations that are not part of the above command line arguments we define separately.
# Feel free to use a different format.
_args_dict = {
    "architecture" : {
        "embed_size"      : 64,
        "hidden_sender"   : 200,
        "hidden_receiver" : 200,
        "cell_type"       : 'gru',
    },
    "game" : {
        "num_imgs"        : 2, # number of images the game is played with
    },
    "training" : {
        "temperature"     : 1,
        "decay"           : 0.9,
        "early_stop_accuracy" : 0.97,
    },
}

# A trick for having a hierarchical argument namespace from the above dict:
# json's object_hook turns every (nested) dict into a SimpleNamespace, so the
# values are read as e.g. args.architecture.cell_type.
args = json.loads(json.dumps(_args_dict), object_hook=lambda item: types.SimpleNamespace(**item))

print("Cell type of the agents:", args.architecture.cell_type)

# Later cells (e.g. loading the vision module) use `device`, so it is defined
# before the TODO guard below instead of being stranded behind it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: other configurations?
raise NotImplementedError


## Vision
In this template we use the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) for the images and a fixed [pre-trained vision module](https://github.com/chenyaofo/CIFAR-pretrained-models/).
The vision module encodes the images in the dataset before the agents see them.
You can also make the vision module part of the agents and update its parameters while the game is played (as in the MNIST tutorial).

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d``, used for residual shortcut projections."""
    pointwise = nn.Conv2d(in_planes, out_planes, (1, 1), stride=stride, bias=False)
    return pointwise


class BasicBlock(nn.Module):
    """Residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, add the shortcut, ReLU.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut
            when the main path changes shape; ``None`` means identity shortcut.
    """

    # Output channels are planes * expansion; this block keeps the width unchanged.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names are kept as-is: they become the state_dict keys that
        # the pretrained checkpoint is loaded into.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # The shortcut is the raw input unless a projection was supplied.
        residual += x if self.downsample is None else self.downsample(x)
        return self.relu(residual)

class Vision(nn.Module):
    """CIFAR-style ResNet used as an image featuriser.

    Structure: a 3x3 stem convolution (3 -> 16 channels), three stages of
    ``block``s with 16, 32 and 64 base channels (stages 2 and 3 use stride 2),
    global average pooling, and a linear classification head ``fc``.

    Args:
        block: residual block class (e.g. ``BasicBlock``) exposing an
            ``expansion`` attribute.
        layers: number of blocks in each of the three stages.
        num_classes: output size of the classification head ``fc``.
    """

    def __init__(self, block, layers, num_classes=100):
        super(Vision, self).__init__()
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # He (Kaiming) init for convolutions; BatchNorm starts as the identity affine map.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with ``planes`` base channels.

        Only the first block may change stride or channel count; in that case it
        gets a 1x1 conv + BatchNorm projection so its shortcut matches the new
        shape. Updates ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return pooled image features of shape ``(batch, 64 * block.expansion)``.

        ``x`` is expected to be a ``(batch, 3, H, W)`` image tensor (the stem
        convolution takes 3 input channels). The ``fc`` head is NOT applied;
        use :meth:`classify` for class logits.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x

    def classify(self, x):
        """Return ``num_classes`` logits: the ``fc`` head applied to the features from ``forward``."""
        # Reuse the feature extractor instead of duplicating it, so the two paths cannot drift apart.
        return self.fc(self.forward(x))
 
# load pre-trained parameters
num_classes = 100
restnet_location = "https://github.com/chenyaofo/CIFAR-pretrained-models/releases/download/resnet/cifar100-resnet56-2f147f26.pth"
# [9, 9, 9] blocks per stage gives the ResNet-56 layout named in the checkpoint URL.
vision = Vision(BasicBlock, [9, 9, 9], num_classes=num_classes).to(device)
# map_location moves the checkpoint tensors onto `device`, so loading also works on
# a CPU-only runtime even if the checkpoint was saved from GPU tensors.
vision.load_state_dict(model_zoo.load_url(restnet_location, map_location=device))

# Evaluation mode: BatchNorm uses its stored running statistics.
vision.eval()

## Data
Implement a custom dataset for the Signalling Game, built on top of a `Dataset` of your choice.
*You can use the code below, which prepares CIFAR-100 images for the vision featuriser, or implement your own!*

In [None]:
# Per-channel normalisation shared by the training and test pipelines.
normalize = transforms.Normalize(
    mean=[0.5071, 0.4865, 0.4409],
    std=[0.2009, 0.1984, 0.2023]
)

# Training pipeline: random crop/flip augmentation, then normalisation.
transform = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Test pipeline: no random augmentation, so evaluation images (and the features
# the vision module computes for them) are the same in every epoch.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

cifar_train_set = datasets.CIFAR100('./data', train=True, download=True, transform=transform)

cifar_test_set = datasets.CIFAR100('./data', train=False, download=True, transform=test_transform)

In [None]:
class SignalGameDataset(torch.utils.data.Dataset):
    """Dataset of signalling-game rounds built on top of an image dataset.

    Template stub: every method raises ``NotImplementedError`` until it is
    filled in.

    Args:
        dataset: underlying image dataset (e.g. ``cifar_train_set``).
        num_imgs: number of images per game round.
        vision: the image featuriser.
        classes: optional; presumably restricts which classes are used (not
            implemented yet).
    """

    def __init__(self, dataset, num_imgs, vision, classes=None):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def __getitem__(self, item):
        # TODO: build one game round and `return sender_imgs, target, receiver_imgs`
        # (this order matches the sender input / labels / receiver input of a batch).
        raise NotImplementedError()

# Game datasets and their loaders; only the training data is shuffled.
# NOTE(review): if SignalGameDataset runs `vision` on the GPU inside __getitem__,
# CUDA cannot be used from forked worker processes; precompute features in
# __init__ or use num_workers=0 in that case.
trainset = SignalGameDataset(cifar_train_set, args.game.num_imgs, vision)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=opts.batch_size,
    shuffle=True,
    num_workers=2,
)

testset = SignalGameDataset(cifar_test_set, args.game.num_imgs, vision)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=opts.batch_size,
    shuffle=False,
    num_workers=2,
)

## Agent design
We use EGG's RNN wrappers to build recurrent agents on top of your Sender and Receiver designs. Here the Gumbel-Softmax (GS) relaxation is used, but you can also use a [Reinforce Wrapper](https://github.com/facebookresearch/EGG/blob/master/egg/core/reinforce_wrappers.py).

See [RnnSenderGS and RnnReceiverGS](https://github.com/facebookresearch/EGG/blob/master/egg/core/gs_wrappers.py) in the documentation for what happens under the hood.

### Sender

In [None]:
class Sender(nn.Module):
    """Sender agent body, wrapped by ``core.RnnSenderGS`` in the next statement.

    Its output is passed to the RNN message generator. That output is presumably
    the initial hidden state, of size ``hidden_sender``; confirm in EGG's
    ``gs_wrappers``. Template stub: both methods raise ``NotImplementedError``
    until filled in.
    """

    def __init__(self, embed_size, num_imgs, hidden_sender):
        super().__init__()
        raise NotImplementedError()

    def forward(self, imgs):
        raise NotImplementedError()

# Build the sender body, then wrap it so it emits variable-length messages
# through the Gumbel-Softmax RNN wrapper (straight-through estimator enabled).
sender = Sender(
    args.architecture.embed_size,
    args.game.num_imgs,
    args.architecture.hidden_sender,
)
sender = core.RnnSenderGS(
    sender,
    opts.vocab_size,
    args.architecture.embed_size,
    args.architecture.hidden_sender,
    cell=args.architecture.cell_type,
    max_len=opts.max_len,
    temperature=args.training.temperature,
    straight_through=True,
)

### Receiver


In [None]:
class Receiver(nn.Module):
    """Receiver agent body, wrapped by ``core.RnnReceiverGS`` in the next statement.

    ``forward`` gets ``hidden_state`` and the receiver's ``imgs``. The hidden
    state is presumably the RNN wrapper's encoding of the message; confirm in
    EGG's ``gs_wrappers``. Template stub: both methods raise
    ``NotImplementedError`` until filled in.
    """

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        raise NotImplementedError()

    def forward(self, hidden_state, imgs):
        raise NotImplementedError()

# Build the receiver body, then wrap it so it reads variable-length messages
# through the Gumbel-Softmax RNN wrapper.
receiver = Receiver(
    args.architecture.hidden_receiver,
    args.architecture.embed_size,
)
receiver = core.RnnReceiverGS(
    receiver,
    opts.vocab_size,
    args.architecture.embed_size,
    args.architecture.hidden_receiver,
    cell=args.architecture.cell_type,
)

## Loss function

In [None]:
def loss(_sender_input, _message, _receiver_input, receiver_output, _labels):
    """Game loss handed to ``core.SenderReceiverRnnGS``.

    Template stub: raises ``NotImplementedError`` until filled in. Once
    implemented it should end with ``return loss, {'acc': accuracy}``; the
    ``'acc'`` entry is what accuracy-based early stopping reads.
    """
    raise NotImplementedError()

## Game setup and training
Use `core.SenderReceiverRnnGS` to create a game with variable-length messages, set up the optimizer and other options, and start training the agents.

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Checkpoints are written to the `models` folder of the SignalGame Drive directory,
# with file names prefixed by the maximum message length (adapt as needed).
model_prefix = f"maxlen_{opts.max_len}"
models_path = "/content/drive/My Drive/SignalGame/models"

# checkpoint_freq=0: presumably no periodic checkpoints, only the one written at
# the end of training; see EGG's CheckpointSaver.
checkpointer = core.callbacks.CheckpointSaver(
    checkpoint_path=models_path,
    checkpoint_freq=0,
    prefix=model_prefix,
)

In [None]:
# Combine the wrapped agents and the loss into a variable-length message game.
game = core.SenderReceiverRnnGS(sender, receiver, loss)
# Use the learning rate configured via core.init (--lr) rather than Adam's built-in default.
optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

callbacks = [
    # Anneal the sender's Gumbel-Softmax temperature by `decay` (never below 0.1).
    core.TemperatureUpdater(agent=game.sender, decay=args.training.decay, minimum=0.1),
    core.ConsoleLogger(as_json=True, print_train_loss=True),
    core.TensorboardLogger(),
    # Stops training once accuracy (the 'acc' entry returned by `loss`) reaches the threshold.
    core.EarlyStopperAccuracy(args.training.early_stop_accuracy),
    checkpointer,
]

trainer = core.Trainer(
    game=game, optimizer=optimizer, train_data=trainloader,
    validation_data=testloader, callbacks=callbacks)

In [None]:
# Train for the configured number of epochs; the trainer's callbacks
# (e.g. accuracy-based early stopping) may end training sooner.
trainer.train(n_epochs=opts.n_epochs)

In [None]:
# Open the TensorBoard dashboard. Assumes core.TensorboardLogger writes under
# ./runs (presumably EGG's default --tensorboard_dir); adjust if you change it.
%tensorboard --logdir ./runs

## Evaluation of the emergent languages
Next up, start analysing the languages and other behaviour you are interested in.