# Template for variable length Signalling Game

This template is based on the [MNIST autoencoder tutorial](https://github.com/facebookresearch/EGG/blob/master/tutorials/EGG%20walkthrough%20with%20a%20MNIST%20autoencoder.ipynb) and [signal game implementation](https://github.com/facebookresearch/EGG/blob/master/egg/zoo/signal_game) provided by the [EGG library](https://github.com/facebookresearch/EGG).

Some code is provided by Mathieu Bartels and Liselore Borel Rinkes at the UvA.

Make sure you have a directory `SignalGame` in your Drive!

In [2]:
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
from tqdm import tqdm
import os
import pickle
import matplotlib.pyplot as plt
import random
import numpy as np
import scipy.spatial.distance as distance
import scipy.stats
import scipy
import egg.core as core
import egg.zoo as zoo

## Configuration
Make sure to define some important configuration parameters in a convenient place, such as the number of images.

In [3]:
import types
import json

# For convenience and reproducibility, we set some EGG-level command line arguments here
opts = core.init(params=['--random_seed=7', # will initialize numpy, torch, and python RNGs
                         '--lr=1e-3', # sets the learning rate for the selected optimizer 
                         '--batch_size=64',
                         '--vocab_size=100',
                         '--max_len=10',
                         '--n_epochs=15',
                         '--tensorboard',
                         ]) 

_args_dict = {
    "architecture" : {
        "embed_size"      : 64,
        "hidden_sender"   : 200,
        "hidden_receiver" : 200,
        "cell_type"       : 'gru',
    },
    "game" : {
        "num_classes"     : 100, # defined by CIFAR-100
        "game_size"       : 2,
        # OTHER
        "sender_has_distractor" : False,
    },
    "training" : {
        "temperature"     : 1,
    },
}

args = json.loads(json.dumps(_args_dict), object_hook=lambda item: types.SimpleNamespace(**item))

# TODO: other configurations?

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Vision
The vision model is used in this template to create embeddings for the input images. The code below imports a vision module for the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). However, you can also choose for making the visual unit part of the agent architecture and update/train it during the game (as in the MNIST tutorial).

In [4]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Vision(nn.Module):

    def __init__(self, block, layers, num_classes=100):
        super(Vision, self).__init__()
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x

    def classify(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
 
# load pre-trained parameters
num_classes = 100
restnet_location = "https://github.com/chenyaofo/CIFAR-pretrained-models/releases/download/resnet/cifar100-resnet56-2f147f26.pth"
vision = Vision(BasicBlock, [9, 9, 9], num_classes=num_classes).to(device)
vision.load_state_dict(model_zoo.load_url(restnet_location))

vision.eval()

Vision(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=Fals

## Data
Implement a custom dataset to be used in the Signalling Game, building on top of a Dataset of your choice (e.g. CIFAR-100).
*You could use the below code to train a featuriser for CIFAR-100 or implement your own!*

In [5]:
#kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
#kwargs = {'num_workers': 1, 'pin_memory': True} if device==torch.device("cuda") else {}

transform = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2009, 0.1984, 0.2023]
    ),
])

train_set = datasets.CIFAR100('./data', train=True, download=True, transform=transform)

test_set = datasets.CIFAR100('./data', train=False, transform=transform)

test_set_originals = datasets.CIFAR100('./data', train=False, transform=transforms.ToTensor())

checking_test_loader = torch.utils.data.DataLoader(test_set, shuffle=False,
                                         batch_size=opts.batch_size, num_workers=2)

# Set to True if you want to evaluate the vision_model
eval_vision_model = False

if eval_vision_model:
  mean_loss, mean_acc, n_batches = 0, 0, 0
  for batch_idx, (data, target) in enumerate(checking_test_loader):
      with torch.no_grad():
        data, target = data.to(device), target.to(device)
        output = vision.classify(data)
        loss = F.cross_entropy(output, target)

        mean_loss += loss.mean().item()
        mean_acc += (target == output.argmax(dim=1)).float().mean().item()
        n_batches += 1
      
  print(f' mean loss: {mean_loss / n_batches}, mean acc: {mean_acc / n_batches}')

Files already downloaded and verified


In [6]:
def get_random_indices(excluded, range, amount):
  indices = random.sample(range, amount)
  while excluded in indices:
    indices = random.sample(range, amount)
  return indices

class SignalGameDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, game_size, vision, embedding_size, dataset_name, device, sender_has_distractor=True, classes=None):
        self.dataset = dataset
        self.game_size = game_size
        self.vision = vision
        self.dataset_name = dataset_name
        self.embedding_size = embedding_size
        self.embeddings = self.pre_process_image_embeddings(16)
        self.sender_has_distractor = sender_has_distractor
        self.device = device

    def pre_process_image_embeddings(self, batch_size):
        if os.path.isfile(f"image_embeddings_{self.dataset_name}.pkl"):
            return torch.load( open(f"image_embeddings_{self.dataset_name}.pkl", "rb" ) )
        trainloader = torch.utils.data.DataLoader(self.dataset, shuffle=False,
                                          batch_size=batch_size, num_workers=2)
        
        image_embeddings = torch.zeros((len(self.dataset), self.embedding_size))
        labels = torch.zeros(len(self.dataset))
        for i, (x, y) in enumerate(tqdm(trainloader)):
          x = x.to(device)
          with torch.no_grad():
            embedding = self.vision(x).cpu()
          image_embeddings[i*batch_size:(i+1) * batch_size, :] = embedding
          labels[i*batch_size:(i+1) * batch_size] = y
        
        torch.save(image_embeddings, open(f"image_embeddings_{self.dataset_name}.pkl", "wb" ))
        return image_embeddings

    def get_item_info(self, index):
        image, classlabel = self.dataset[index]
        return image, classlabel

    def __len__(self):
      return len(self.dataset)

    def __getitem__(self, item):
        dataset = self.embeddings
        game_size = self.game_size
        target_image = dataset[item]

        indices = get_random_indices(item, range(self.__len__()), game_size-1)
        images = [target_image] + [dataset[indice] for indice in indices]

        sender_images = torch.stack(images, dim=0)

        perm = torch.randperm(game_size)
        receiver_imgs = sender_images[perm]
        target = torch.argmin(perm)
        
        if not self.sender_has_distractor:
            sender_images = target_image

        return sender_images, target, receiver_imgs

trainset = SignalGameDataset(train_set, args.game.game_size, vision, args.architecture.embed_size, 'train',
                             device,sender_has_distractor=args.game.sender_has_distractor)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True,
                                          batch_size=opts.batch_size, num_workers=0) #was 2

testset = SignalGameDataset(test_set, args.game.game_size, vision, args.architecture.embed_size, 'test',
                            device,sender_has_distractor=args.game.sender_has_distractor)
testloader = torch.utils.data.DataLoader(testset, shuffle=False,
                                         batch_size=opts.batch_size, num_workers=0) # was 2

## Agent design
We can use the Rnn Wrappers from EGG to implement a recurrent neural agent on top of your Sender and Receiver design. Here the Gumbel Softmax (GS) is used, but you can also use a [Reinforce Wrapper](https://github.com/facebookresearch/EGG/blob/master/egg/core/reinforce_wrappers.py).

See [RnnSenderGS and RnnReceiverGS](https://github.com/facebookresearch/EGG/blob/master/egg/core/gs_wrappers.py) in the documentation for what happens under the hood.

### Sender

In [7]:
class Sender(nn.Module):
    def __init__(self, embed_size, game_size, hidden_sender):
      super(Sender, self).__init__()
      self.embed_size = embed_size
      self.game_size = game_size
      
      self.lin4 = nn.Linear(embed_size * game_size, hidden_sender, bias=True)
        
    def forward(self, imgs):
      # imgs shape sender torch.Size([batch, embed_size])
      out = self.lin4(imgs.view(-1, self.embed_size * self.game_size)).tanh()
      # out shape sender torch.Size([64, 200])
      return out

if args.game.sender_has_distractor:
    sender = Sender(args.architecture.embed_size, args.game.game_size, args.architecture.hidden_sender)
else:
    sender = Sender(args.architecture.embed_size, 1, args.architecture.hidden_sender)

sender = core.RnnSenderGS(sender, opts.vocab_size, args.architecture.embed_size, args.architecture.hidden_sender, cell=args.architecture.cell_type,
                        max_len=opts.max_len, temperature=args.training.temperature, straight_through=True)

### Receiver


In [8]:
class Receiver(nn.Module):
  def __init__(self, hidden_receiver, embed_size):
    super(Receiver, self).__init__()
    self.embedding_size = embed_size

    self.fc1 = nn.Linear(embed_size, hidden_receiver)

  def forward(self, message, imgs):
    # imgs shape  torch.Size([64, 2, 64])
    # torch.Size([batch_size, game_size, hidden_size])
    embedded_input = self.fc1(imgs).tanh()
    
    # message shape torch.Size([batch, hidden])
    # torch.Size([batch_size, game_size, 1])
    energies = torch.matmul(embedded_input, torch.unsqueeze(message, dim=-1))

    return energies.squeeze()

receiver = Receiver(args.architecture.hidden_receiver, args.architecture.embed_size)
receiver = core.RnnReceiverGS(receiver, opts.vocab_size, args.architecture.embed_size,
                    args.architecture.hidden_receiver, cell=args.architecture.cell_type)

# Loss function

In [9]:
import torch.nn.functional as F

def loss(_sender_input,  _message, _receiver_input, receiver_output, _labels):
    acc = (receiver_output.argmax(dim=1) == _labels).detach().float()
    loss = F.cross_entropy(receiver_output, _labels, reduction="none")
    # print('Loss: ', loss.mean().cpu().item(), 'Acc: ', acc.mean().cpu().item())
    return loss, {'acc': acc}

## Game setup and training
Use `core.SenderReceiverRnnGS` for creating a variable message length game, set the optimizer and other options, and start training the agents.

In [10]:
%load_ext tensorboard

In [11]:
model_prefix = f"multi_target_max_len_{opts.max_len}"
models_path = "/content/gdrive/My Drive/SignalGame/models"

checkpointer = core.callbacks.CheckpointSaver(checkpoint_path=models_path, checkpoint_freq=0, prefix=model_prefix)
early_stopper = core.early_stopping.EarlyStopperAccuracy(threshold = 0.1)
logger = core.callbacks.ConsoleLogger(print_train_loss=True)
temperature_update = core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)

#callbacks = [checkpointer, early_stopper, logger, temperature_update]
checkpointer_class = core.callbacks.CheckpointSaver(checkpoint_path=models_path, checkpoint_freq=0, prefix=model_prefix+"_class")


In [12]:
game = core.SenderReceiverRnnGS(sender, receiver, loss)
optimizer = torch.optim.Adam(game.parameters())

callbacks = [core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1),
             core.ConsoleLogger(as_json=True, print_train_loss=True),
             core.TensorboardLogger(),
             core.EarlyStopperAccuracy(0.97)]

trainer = core.Trainer(
    game=game, optimizer=optimizer, train_data=trainloader,
    validation_data=testloader, callbacks=callbacks)

In [13]:
trainer.train(n_epochs=opts.n_epochs)

{"loss": 0.3063197135925293, "acc": 0.8298199772834778, "length": 7.104320049285889, "mode": "train", "epoch": 1}
{"loss": 0.1404149830341339, "acc": 0.9340000152587891, "length": 10.236499786376953, "mode": "test", "epoch": 1}
{"loss": 0.08692803978919983, "acc": 0.9606999754905701, "length": 9.655579566955566, "mode": "train", "epoch": 2}
{"loss": 0.07190242409706116, "acc": 0.9686999917030334, "length": 10.068300247192383, "mode": "test", "epoch": 2}
{"loss": 0.05964502692222595, "acc": 0.9736199975013733, "length": 10.071399688720703, "mode": "train", "epoch": 3}
{"loss": 0.04568062350153923, "acc": 0.9797000288963318, "length": 9.040800094604492, "mode": "test", "epoch": 3}


In [14]:
%tensorboard --logdir ./runs

ERROR: Could not find `tensorboard`. Please ensure that your PATH
contains an executable `tensorboard` program, or explicitly specify
the path to a TensorBoard binary by setting the `TENSORBOARD_BINARY`
environment variable.

## Evaluation of the emergent languages
Next up, start analysing the languages and other behaviour you are interested in.

In [1]:
from egg.zoo.objects_game.util import dump_sender_receiver
# Check out https://github.com/facebookresearch/EGG/blob/aba2489e78f0b6e202bf2c6138fde41bf9056cb5/egg/zoo/objects_game/util.py

sender_inputs, messages, receiver_inputs, receiver_outputs, labels = dump_sender_receiver(game, testloader, True, variable_length=True, device=device)
messages = [message.cpu().numpy() for message in messages]
for m in messages:
  print(str(m))

NameError: name 'game' is not defined

### Zero-shot evaluation
Evaluate the messages on a held out test set.

In [16]:
raise NotImplementedError

NotImplementedError: 