**Project description** :
In the following notebook, we investigate to what extent languages induced by a two-player (sender to receiver) communication game share the linguistic properties of natural languages. More precisely, we study the role of the communication channel between the sender and the receiver.

**Notebook permanent link** : 
If you wish to access this notebook in the future, please use this link : https://colab.research.google.com/drive/1lNgQnARZxYwPOxq_3LvjDSsjro-y2MCR?usp=sharing

## 🚨 Instructions 🚨



```
In order to run this notebook properly on your Google Drive, please follow these steps :
    1. First, make sure that you have enabled access to a GPU hardware accelerator. To do so, go to : 
      'Runtime' > 'Change runtime type' > 'Hardware accelerator' > 'GPU'
    2. Then, before running 

```

If you wish to see the results of the experiments implemented in the notebook below directly, please consult this project's [Weights & Biases dashboard](https://wandb.ai/volut3s/nlp-emergent-languages?workspace=user-volut3s). For a complete overview of the experiments' code, feel free to consult our [GitHub repository](https://github.com/excitingtimes/encom).

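*Note* : run the following JavaScript snippet in your browser's developer console to prevent Google Colab's VM from disconnecting you after a period of inactivity (it clicks the connect button once every minute) :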
```
function KeepClicking(){
console.log("Clicking");
document.querySelector("colab-connect-button").click()
}
setInterval(KeepClicking,60000)
```




### Main folders of the project

In [None]:
PROJECT_DIR = "drive/My Drive/Projects/nlp_emergent_languages/"
CHECKPOINTS_DIR = PROJECT_DIR + 'checkpoints/'
INTERACTIONS_DIR = PROJECT_DIR + 'interactions/' # Note : EGG appears to store interactions one directory level below the path provided
DATASETS_DIR = PROJECT_DIR + 'datasets/'
PRETRAINED_MODELS_DIR = '/content/' + PROJECT_DIR + 'pretrained_models/'
FINETUNED_MODELS_DIR = '/content/' + PROJECT_DIR + 'finetuned_models/'

DATASET_IMAGENET_DIR = 'imagenet/imagenet/'
DATASET_TINY_IMAGENET_DIR = 'tiny-imagenet/tiny-imagenet-200/'

### Dynamically define a checkpoint if you want to continue an experiment or evaluate a pair of pretrained agents

In [None]:
# CHECKPOINT_DIR = PROJECT_DIR + 'runs/'
# CHECKPOINT_PATH = None # CHECKPOINT_DIR + 'your_checkpoint_folder_here/'

### Wandb-related parameters

In [None]:
WANDB_API_KEY = "your-own-api-key"
WANDB_PROJECT = "nlp-emergent-languages"
WANDB_ENTITY = "volut3s"
WANDB_NOTES = "We assess the robustness and generalization / compositionality capabilities of emergent languages \
in a two-agent signaling game under channel noisiness constraints"
WANDB_EXPERIMENT_NAME="No channel constraints"
WANDB_EXPERIMENT_GROUP="vision-model-pretraining"  # 'reconstruction-game', 'discrimination-game', 'vision-model-pretraining', 'ablation-study'

#### Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

#### Utility function for memory management :

In [None]:
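import gc

def free_memory(
    model, 
    *args,
):
    # Frees GPU memory in order to launch other training experiments.
    # Note : deleting names inside this function cannot remove the caller's own references,
    # so the caller should also `del` its variables (or set them to None) for memory to be reclaimed.
    print("Freeing memory ...")
    if model is not None:
        model.cpu()  # Move the model's parameters and buffers off the GPU
    del model, args  # Drop this function's local references
    gc.collect()  # Collect unreachable objects that may still hold CUDA tensors
    torch.cuda.empty_cache()  # Release cached, unused blocks held by PyTorch's allocator
    print("Successfully freed some memory !")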

---
# 0. 📚 Installing useful dependencies 📚
---

In [None]:
# EGG : Emergence of lanGuage in Games environment (see https://github.com/facebookresearch/EGG for more information)
!pip install --quiet git+https://github.com/facebookresearch/EGG.git
!pip install --quiet torchvision
!pip install --quiet wandb
!pip install --quiet pytorch_lightning
!pip install --quiet h5py

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2


import egg.core as core

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision as tv
from torchvision import models
from torchvision import datasets
from torchvision import transforms as T

import wandb
import matplotlib.pyplot as plt
import random
import numpy as np
import seaborn as sns
import plotly.express as px

from PIL import ImageFilter
from operator import itemgetter

import copy
import time
import os
import re
import csv
from pylab import rcParams
from functools import partial
from itertools import product
from typing import Tuple, Optional, Union
from collections import OrderedDict
from torch.distributions import Categorical

c = copy.deepcopy # To copy neural modules without tying the weights (independent copies)

def cycle(
    loader,
): # Iterates over a dataloader
    while True:
        for data in loader:
            yield data

rcParams['figure.figsize'] = 5, 10

# For convenience and reproducibility, we set some EGG-level command line arguments here
opts = core.init(params=['--random_seed=7',  # Will initialize numpy, torch, and python RNGs
                         '--lr=1e-3',        # Sets the learning rate for the selected optimizer 
                         '--batch_size=32',
                         '--optimizer=adam', # 'sgd', 'adagrad', 'adam'
                         # '--fp16',
                         '--update_freq=1',  # Learnable weights are updated every `update_freq` batches
                         # No shell parses these strings, so quotes would become part of the path itself
                         f'--checkpoint_dir={PROJECT_DIR}',
                         '--checkpoint_freq=10',
                         '--validation_freq=5',
                         # f'--load_from_checkpoint={CHECKPOINT_PATH}',
                         # '--tensorboard',
                         ])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
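The `cycle` helper above turns a finite dataloader into an endless stream of batches, which is handy when training for a fixed number of optimization steps rather than a fixed number of epochs. A minimal, self-contained sketch of that usage pattern with `itertools.islice`, in which a plain list of hypothetical toy batches stands in for a `DataLoader` (`cycle_batches` simply mirrors `cycle`):

```python
from itertools import islice

def cycle_batches(loader):
    # Same behavior as `cycle`: re-iterate over the loader forever
    while True:
        for batch in loader:
            yield batch

# Hypothetical stand-in for a DataLoader yielding 3 batches per epoch
toy_loader = ["batch_0", "batch_1", "batch_2"]

# Draw exactly 7 batches, wrapping around epoch boundaries
steps = list(islice(cycle_batches(toy_loader), 7))
print(steps)
```

With a real `DataLoader` built with `shuffle=True`, each wrap-around calls the loader's iterator again, so the batches are reshuffled at every pass.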

#### Random **seed initialization** (for *reproducibility*) :

In [None]:
SET_RANDOM_SEED = False

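if SET_RANDOM_SEED:
    import hashlib

    hashed_sentence = 'emergent languages are very cool, yes indeed !'
    # Python's built-in hash() is salted per process for strings (see PYTHONHASHSEED),
    # so we derive the seed from a stable SHA-256 digest instead
    get_seed = lambda s: int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16) % (2**32 - 1)
    SEED = get_seed(hashed_sentence)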

    # Setting the random seeds of Numpy, PyTorch and Random
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    random.seed(SEED)

    def seed_worker(
        worker_id,
    ):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    # Initialize the following parameters in the dataloader :
    worker_init_fn = seed_worker
    generator = g

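    # Deterministic algorithms are left disabled for efficiency purposes; set this to True for
    # bitwise reproducibility, at the cost of speed (operations lacking a deterministic version then raise an error)
    torch.use_deterministic_algorithms(False)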

In [None]:
def wandb_connect():
    wandb_conx = wandb.login(key = WANDB_API_KEY)
    print(f"Connected to Wandb online interface : {wandb_conx}")

wandb_connect()

---
#1. 👁 Vision feature extractor 👁
---

## A. 🎦 Pretraining our own vision module (*proof of principle, do not use this technique further down*)

### Defining the vision module architecture (custom example)

In [None]:
class Vision(nn.Module):
    def __init__(
        self,
    ):
        super(Vision, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)

    def forward(
        self, 
        x,
    ):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        return x

class PretrainNet(nn.Module):
    def __init__(
        self, 
        vision_module, 
        dim_vision_out=500,
    ):
        super(PretrainNet, self).__init__()
        self.vision_module = vision_module
        self.dim_vision_out = dim_vision_out
        self.fc = nn.Linear(self.dim_vision_out, 10)
        
    def forward(
        self, 
        x,
    ):
        x = self.vision_module(x)
        x = self.fc(F.leaky_relu(x))

        return x
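The hard-coded `4*4*50` in `Vision` follows from the output-size formula of an unpadded convolution or pooling layer, floor((n - kernel) / stride) + 1. A minimal sketch tracing a square input through the four layers of `Vision` (the helper names `out_size` and `vision_flat_dim` are hypothetical, introduced only for illustration):

```python
def out_size(n: int, kernel: int, stride: int) -> int:
    # Side length after an unpadded, undilated conv / pooling layer
    return (n - kernel) // stride + 1

def vision_flat_dim(side: int, channels_out: int = 50) -> int:
    # Trace the spatial size through Vision : conv(5, 1) -> pool(2, 2) -> conv(5, 1) -> pool(2, 2)
    n = out_size(side, kernel=5, stride=1)  # conv1
    n = out_size(n, kernel=2, stride=2)     # max_pool2d
    n = out_size(n, kernel=5, stride=1)     # conv2
    n = out_size(n, kernel=2, stride=2)     # max_pool2d
    return channels_out * n * n

print(vision_flat_dim(28))  # 800, i.e. 4*4*50 for 28x28 MNIST / KMNIST digits
print(vision_flat_dim(32))  # 1250, i.e. 5*5*50 for 32x32 images
```

As written, `Vision` therefore only accepts 28×28 single-channel images; 32×32 inputs such as CIFAR-10 would need both a different `fc1` input size and a 3-channel `conv1`.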

### Defining a data augmentation strategy

In [None]:
# Deterministic preprocessing (resize the shorter side to 256 px, then center-crop to 224 px); note that this is not random augmentation
transform_train = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
transform_val_test = T.ToTensor()
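To make the geometry of `T.Resize(256)` followed by `T.CenterCrop(224)` concrete, here is a minimal pure-Python sketch of the resulting image size and crop box. It mirrors our understanding of torchvision's integer-size behavior (shorter side scaled to `size`, longer side truncated with `int()`, crop offsets rounded); the helper names are hypothetical:

```python
def resize_shorter_side(width: int, height: int, size: int) -> tuple:
    # Scale the shorter side to `size`, keeping the aspect ratio (longer side truncated)
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size

def center_crop_box(width: int, height: int, crop: int) -> tuple:
    # (left, top, right, bottom) of a centered square crop
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop

w, h = resize_shorter_side(500, 375, 256)  # a typical 500x375 photo
box = center_crop_box(w, h, 224)
print((w, h), box)
```

Every image thus ends up as a 224×224 tensor, whatever its original aspect ratio, while the same image always yields the same crop.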

### Instantiating training and test dataloaders :

In [None]:
def generate_dataloaders_from_remote(
    dataset, 
    transform_train, 
    transform_val_test, 
    name='mnist', 
    num_workers=1, 
    pin_memory=True, 
    val_test_ratio=(0.75, 0.25),
):
    kwargs = {'num_workers': num_workers, # Number of subprocesses loading batches in parallel
            'pin_memory': pin_memory} if torch.cuda.is_available() else {}

    batch_size = opts.batch_size

    if name == 'places-365':
        train_dataset = dataset(DATASETS_DIR + name
                                + '/', split='train-standard', download=True, transform=transform_train)
    elif name == 'svhn':
        train_dataset = dataset(DATASETS_DIR + name + '/', split='train', download=True, transform=transform_train)
    elif name == 'inaturalist':
        train_dataset = dataset(DATASETS_DIR + name + '/', version='2021_train', download=True, transform=transform_train)
    elif name == 'fake-data':
        train_dataset = dataset(size=1_000, image_size=(3, 26, 26), num_classes=10, transform=transform_train)
    elif name == 'caltech-101':
        full_dataset = dataset(DATASETS_DIR + name + '/', target_type='category', download=True, transform=transform_train)
    elif name == 'caltech-256':
        full_dataset = dataset(DATASETS_DIR + name + '/', download=True, transform=transform_train)
    else:
        train_dataset = dataset(DATASETS_DIR + name + '/', train=True, download=True, transform=transform_train)

    if name in ['caltech-101', 'caltech-256']:
        len_train = int(0.85 * len(full_dataset))
        len_val_test = len(full_dataset) - len_train
        print("Caltech dataset : len_train={}, len_val_test={}".format(len_train, len_val_test))
        train_dataset, __test_dataset__ = torch.utils.data.random_split(full_dataset, lengths=[len_train, len_val_test])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, 
                                            shuffle=True, 
                                            # worker_init_fn=worker_init_fn,
                                            # generator=generator,
                                            **kwargs)
    
    len_train = len(train_dataset)
    print("Loading dataset : [{}]".format(name))
    print("\nNumber of training samples : {}".format(len_train))

    if name =='places-365':
        __test_dataset__ = dataset(DATASETS_DIR + name + '/', split='val', download=True, transform=transform_val_test)
    elif name =='svhn':
        __test_dataset__ = dataset(DATASETS_DIR + name + '/', split='test', download=True, transform=transform_val_test)
    elif name == 'inaturalist':
        __test_dataset__ = dataset(DATASETS_DIR + name + '/', version='2021_valid', download=True, transform=transform_val_test)
    elif name == 'fake-data':
        __test_dataset__ = dataset(size=100, image_size=(3, 26, 26), num_classes=10, transform=transform_val_test)
    elif name in ['caltech-101', 'caltech-256']:
        pass
    else:
        __test_dataset__ = dataset(DATASETS_DIR + name + '/', train=False, download=True, transform=transform_val_test)

    len_val = int(val_test_ratio[0] * len(__test_dataset__))
    len_test = len(__test_dataset__) - len_val
    print("Number of validation samples : {}".format(len_val))
    print("Number of test samples : {}".format(len_test))

    val_dataset, test_dataset = torch.utils.data.random_split(__test_dataset__, [len_val, len_test])

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size, 
                                             shuffle=True, 
                                             # worker_init_fn=worker_init_fn,
                                             # generator=generator,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size, 
                                              shuffle=True, 
                                              # worker_init_fn=worker_init_fn,
                                              # generator=generator,
                                              **kwargs)

    print("\nBatch size : {}".format(batch_size))
    print("\nLength of training dataloader (in batches) : {}".format(len(train_loader)))
    print("Length of validation dataloader (in batches) : {}".format(len(val_loader)))
    print("Length of test dataloader (in batches) : {}".format(len(test_loader)))

    split_dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    return split_dataloaders
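`generate_dataloaders_from_remote` keeps the official training split for training and divides the official test split into validation and test subsets. A minimal sketch of the resulting sizes, reproducing the function's arithmetic (the helpers `split_lengths` and `n_batches` are hypothetical):

```python
def split_lengths(n_samples: int, val_test_ratio=(0.75, 0.25)) -> tuple:
    # Validation gets the floor of the first ratio, test gets the remainder
    # (so the second ratio is never actually read)
    len_val = int(val_test_ratio[0] * n_samples)
    len_test = n_samples - len_val
    return len_val, len_test

def n_batches(n_samples: int, batch_size: int = 32, drop_last: bool = False) -> int:
    # Number of batches a DataLoader yields over n_samples
    full, rest = divmod(n_samples, batch_size)
    return full + (1 if rest and not drop_last else 0)

# e.g. the 10,000-image official test split of CIFAR-10 or MNIST
len_val, len_test = split_lengths(10_000)
print(len_val, len_test, n_batches(len_val), n_batches(len_test))
```

Note that `random_split` draws a new permutation on each run unless it is given a seeded `generator`, so the validation / test partition changes across sessions by default.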

### Training our custom vision model :

In [None]:
# Defining the optimization routine of the model
vision = Vision()
class_prediction = PretrainNet(vision) #  note that we pass vision - which we want to pretrain
optimizer = core.build_optimizer(class_prediction.parameters()) #  uses command-line parameters we passed to core.init
class_prediction = class_prediction.to(device)

In [None]:
transform = T.ToTensor()
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
transform_train = T.ToTensor()

batch_size = opts.batch_size # set via the CL arguments above
train_loader = torch.utils.data.DataLoader(
        datasets.KMNIST('./data', train=True, download=True,
           transform=transform),
           batch_size=batch_size, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
        datasets.KMNIST('./data', train=False, transform=transform),
           batch_size=batch_size, shuffle=False, **kwargs)

In [None]:
for epoch in range(10):
    mean_loss, n_batches = 0, 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # print(data.shape, target.shape)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = class_prediction(data)
        # print(output.shape)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        
        mean_loss += loss.mean().item()
        n_batches += 1
        
    print(f'Train Epoch: {epoch}, mean loss: {mean_loss / n_batches}')

## B. 🌏 Import a SOTA pretrained model (*preferred technique*)

### [*Optional, only run once*] --- Run the code below in order to expand *tiny-imagenet.zip* ---

In [None]:
# !pwd
# !ls drive/My\ Drive/Projects/nlp_emergent_languages/datasets/
# !unzip drive/My\ Drive/Projects/nlp_emergent_languages/datasets/tiny-imagenet -d drive/My\ Drive/Projects/nlp_emergent_languages/datasets/tiny-imagenet

### Data augmentation for the visual domain :

In [None]:
class GaussianBlur():
    def __init__(
        self, 
        sigma=[0.1, 2.0],
    ):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

class TransformsAugment():
    def __init__(
        self, 
        size, 
        multi_channel=True,
    ):
        print("Transforms Augment")
        s = 1
        color_jitter = T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        transformations = [
            T.RandomResizedCrop(size=size),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
            T.RandomHorizontalFlip(),  # with 0.5 probability
            T.ToTensor(),
        ]
        # We "pseudo-colorize" the image by broadcasting to the three dimensions
        if not multi_channel:
            # Option 1 : we simply broadcast the single grayscale channel over all three channels
            # (suboptimal, since no actual color information is recovered)
            # transformations.append(T.Lambda(lambda x: x[0:1, :, :]))
            transformations.append(T.Lambda(lambda x: x.repeat(3,1,1)))

            print("Not multi-channel")

        """
        transformations.append(
            T.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        )
        """
        self.transform = T.Compose(transformations)

    def __call__(self, x):
        x_trans = self.transform(x)
        return x_trans

### Defining a data augmentation strategy :

In [None]:
# transform_train = TransformsAugment(size=256, multi_channel=True)
# transform_val_test = T.ToTensor()
# transform_train = T.ToTensor()
# transform_val_test = T.ToTensor()

### Instantiating training and test dataloaders :

In [None]:
def generate_dataloaders_from_local(
    path, 
    transform_train, 
    transform_val_test, 
    num_workers=2, 
    pin_memory=True, 
    val_test_ratio=(0.75, 0.25),
):
    kwargs = {'num_workers': num_workers, # Number of subprocesses loading batches in parallel
            'pin_memory': pin_memory} if torch.cuda.is_available() else {}

    batch_size = opts.batch_size

    train_dataset = datasets.ImageFolder(root=path + 'train/', transform=transform_train)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, 
                                            shuffle=True, 
                                            # worker_init_fn=worker_init_fn,
                                            # generator=generator,
                                            **kwargs)

    __test_dataset__ = datasets.ImageFolder(path + 'val/', transform=transform_val_test)

    len_val = int(val_test_ratio[0] * len(__test_dataset__))
    len_test = len(__test_dataset__) - len_val
    print("Number of validation samples : {}".format(len_val))
    print("Number of test samples : {}".format(len_test))

    val_dataset, test_dataset = torch.utils.data.random_split(__test_dataset__, [len_val, len_test])

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size, 
                                             shuffle=True, 
                                             # worker_init_fn=worker_init_fn,
                                             # generator=generator,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size, 
                                              shuffle=True, 
                                              # worker_init_fn=worker_init_fn,
                                              # generator=generator,
                                              **kwargs)

    print("\nLength of training dataloader : {}".format(len(train_loader)))
    print("Length of validation dataloader : {}".format(len(val_loader)))
    print("Length of test dataloader : {}".format(len(test_loader)))

    split_dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    return split_dataloaders
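`generate_dataloaders_from_local` relies on `datasets.ImageFolder`, which expects one sub-folder per class under both `train/` and `val/`. As far as we know, the official Tiny ImageNet archive instead ships `val/` as a flat `val/images/` folder plus a tab-separated `val_annotations.txt` (file name, WordNet id, bounding box), so `ImageFolder` would see a single class. A minimal sketch that regroups the validation images once, assuming that layout (the function name is hypothetical):

```python
import os
import shutil

def regroup_tiny_imagenet_val(val_dir: str) -> int:
    """Move val/images/<file> to val/<wnid>/<file> using val_annotations.txt.

    Assumes the layout described above; returns the number of images moved.
    """
    annotations = os.path.join(val_dir, "val_annotations.txt")
    images_dir = os.path.join(val_dir, "images")
    moved = 0
    with open(annotations) as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            file_name, wnid = fields[0], fields[1]
            src = os.path.join(images_dir, file_name)
            if not os.path.exists(src):
                continue  # already moved on a previous run
            class_dir = os.path.join(val_dir, wnid)
            os.makedirs(class_dir, exist_ok=True)
            shutil.move(src, os.path.join(class_dir, file_name))
            moved += 1
    # Remove the emptied images/ folder so that ImageFolder does not treat it as a class
    if os.path.isdir(images_dir) and not os.listdir(images_dir):
        os.rmdir(images_dir)
    return moved
```

It would be called as `regroup_tiny_imagenet_val(DATASETS_DIR + DATASET_TINY_IMAGENET_DIR + 'val/')`; `ImageFolder` only treats directories as classes, so the annotation file left in `val/` is ignored.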

### Downloading a few natural images datasets

In [None]:
split_dataloaders = {
    'train': None,
    'val': None,
    'test': None,
}

image_datasets = {
    'cifar-10': split_dataloaders,
    'cifar-100': split_dataloaders,
    'mnist': split_dataloaders,
    'tiny-imagenet': split_dataloaders,
    'fashion-mnist': split_dataloaders,
    'q-mnist': split_dataloaders,
    'k-mnist': split_dataloaders,
    'svhn': split_dataloaders,
    'caltech-101': split_dataloaders,
}

In [None]:
image_datasets.update({
    'imagenet': split_dataloaders,
    'places-365': split_dataloaders,
    'inaturalist': split_dataloaders,
    'fake-data': split_dataloaders,
})

#### 🏞 CIFAR-10 & CIFAR-100 🏞

In [None]:
image_datasets['cifar-10'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.CIFAR10, 
    transform_train, 
    transform_val_test, 
    name='cifar-10',
)
image_datasets['cifar-100'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.CIFAR100, 
    transform_train, 
    transform_val_test, 
    name='cifar-100',
)

#### 🔢 MNIST 🔢

In [None]:
image_datasets['mnist'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.MNIST,
    transform_train, 
    transform_val_test, 
    name='mnist',
)

#### 🪐 ImageNet & TinyImageNet 🪐 *(we retrieve both of them from a local repository)*

In [None]:
"""
image_datasets['imagenet'] = lambda transform_train, transform_val_test: generate_dataloaders_from_local(
    DATASETS_DIR + DATASET_IMAGENET_DIR,
    transform_train, 
    transform_val_test,
)
image_datasets['tiny-imagenet'] = lambda transform_train, transform_val_test: generate_dataloaders_from_local(
    DATASETS_DIR + DATASET_TINY_IMAGENET_DIR,
    transform_train, 
    transform_val_test,
)
"""

#### 🥋 FashionMNIST 🥋

In [None]:
image_datasets['fashion-mnist'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.FashionMNIST,
    transform_train, 
    transform_val_test,
    name='fashion-mnist',
)

#### 🌉 Places365 🌉 

In [None]:
"""
image_datasets['places-365'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.Places365, 
    transform_train, 
    transform_val_test, 
    name='places-365',
)
"""

#### 🌿 iNaturalist 🌿

In [None]:
"""
image_datasets['inaturalist'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.INaturalist, 
    transform_train, 
    transform_val_test, 
    name='inaturalist',
)
"""

#### 🤡 FakeData 🤡

In [None]:
"""
image_datasets['fake-data'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.FakeData, 
    transform_train, 
    transform_val_test, 
    name='fake-data',
)
"""

#### 🧾 QMNIST 🧾

In [None]:
image_datasets['q-mnist'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.QMNIST, 
    transform_train, 
    transform_val_test, 
    name='q-mnist',
)

#### 🔖 KMNIST 🔖

In [None]:
image_datasets['k-mnist'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.KMNIST, 
    transform_train, 
    transform_val_test, 
    name='k-mnist',
)

#### 🗻 SVHN 🗻

In [None]:
image_datasets['svhn'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.SVHN,
    transform_train, 
    transform_val_test, 
    name='svhn',
)

#### 🪑 Caltech-101 🪑

In [None]:
"""
image_datasets['caltech-101'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.Caltech101, 
    transform_train, 
    transform_val_test, 
    name='caltech-101',
)
"""

#### 💺 Caltech-256 💺

In [None]:
"""
image_datasets['caltech-256'] = lambda transform_train, transform_val_test: generate_dataloaders_from_remote(
    datasets.Caltech256, 
    transform_train, 
    transform_val_test, 
    name='caltech-256',
)
"""

#### Standard image sizes for all datasets

In [None]:
sizes = {
    'cifar-10': 32,
    'cifar-100': 32,
    'mnist': 28,
    'imagenet': 256,
    'tiny-imagenet': 64,
    'fashion-mnist': 28,
    'places-365': 256,
    'inaturalist': 256, # the dimension to resize the images to, but they may be actually higher-def (up to 2,048 px)
    'fake-data': 224,
    'q-mnist': 28,
    'k-mnist': 28,
    'svhn': 32,
    'caltech-101': 64, # image sizes vary across the dataset, so all images are resized to a common 64 px
    'caltech-256': 64, # same
}
n_channels = {
    'cifar-10': 3,
    'cifar-100': 3,
    'mnist': 1,
    'imagenet': 3,
    'tiny-imagenet': 3,
    'fashion-mnist': 1,
    'places-365': 3,
    'inaturalist': 3,
    'fake-data': 3,
    'q-mnist': 1,
    'k-mnist': 1,
    'svhn': 3, # SVHN images are RGB
    'caltech-101': 3,
    'caltech-256': 3,
}
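For the reconstruction game, the receiver's `output_size` must equal the flattened size of the image it reconstructs, channels × side², which is also the `reconstruction_dim` used by the BCE loss further down. A minimal sketch computing it from dictionaries shaped like `sizes` and `n_channels`; the `example_*` dictionaries are illustrative excerpts (named so as not to overwrite the real ones) and assume square images, with SVHN counted as RGB:

```python
# Illustrative excerpts of the `sizes` / `n_channels` dictionaries
example_sizes = {'mnist': 28, 'cifar-10': 32, 'svhn': 32, 'tiny-imagenet': 64}
example_channels = {'mnist': 1, 'cifar-10': 3, 'svhn': 3, 'tiny-imagenet': 3}

def flat_image_dim(name: str) -> int:
    # Number of values needed to describe one square image of the given dataset
    return example_channels[name] * example_sizes[name] ** 2

for dataset_name in example_sizes:
    print(dataset_name, flat_image_dim(dataset_name))
```

The default `output_size=784` of the `Receiver` below thus corresponds to 28×28 grayscale images such as MNIST.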

### Downloading a few pretrained vision models

In [None]:
resnets = {
        "resnet-18": (lambda: tv.models.resnet18(pretrained=True, progress=True)),
        "resnet-34": (lambda: tv.models.resnet34(pretrained=True, progress=True)),
        "resnet-50": (lambda: tv.models.resnet50(pretrained=True, progress=True)),
        "resnet-101": (lambda: tv.models.resnet101(pretrained=True, progress=True)),
        "resnet-152": (lambda: tv.models.resnet152(pretrained=True, progress=True)),
    }

efficientnets = {
    "efficientnet-b0": (lambda: tv.models.efficientnet_b0(pretrained=True, progress=True)),
    "efficientnet-b1": (lambda: tv.models.efficientnet_b1(pretrained=True, progress=True)),
    "efficientnet-b2": (lambda: tv.models.efficientnet_b2(pretrained=True, progress=True)),
    "efficientnet-b3": (lambda: tv.models.efficientnet_b3(pretrained=True, progress=True)),
    "efficientnet-b4": (lambda: tv.models.efficientnet_b4(pretrained=True, progress=True)),
    "efficientnet-b5": (lambda: tv.models.efficientnet_b5(pretrained=True, progress=True)),
    "efficientnet-b6": (lambda: tv.models.efficientnet_b6(pretrained=True, progress=True)),
    "efficientnet-b7": (lambda: tv.models.efficientnet_b7(pretrained=True, progress=True)),
}

def get_vision_module(version: str):
    """
    Loads ResNet & EfficientNet encoders from torchvision along with the final features number
    """

    os.environ['TORCH_HOME'] = PRETRAINED_MODELS_DIR # Used in order to specify where to save the pretrained models, so as not to load them again in the future

    if (version not in resnets) and (version not in efficientnets):
        raise KeyError(f"{version} is not a valid ResNet / EfficientNet version")

    models_library = {**resnets, **efficientnets}

    model = models_library[version]()
    
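    # ResNets expose their last layer as `fc`; torchvision's EfficientNets use a `classifier`
    # Sequential whose last module is the Linear layer
    last_layer = model.fc if hasattr(model, 'fc') else model.classifier[-1]
    features_dim = last_layer.out_features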

    return model, features_dim

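# Point torchvision's cache at the project folder so that pretrained weights are downloaded only once :
# the weights themselves are fetched lazily, on the first call to get_vision_module(...) for each version
os.environ['TORCH_HOME'] = PRETRAINED_MODELS_DIR
print(f"Pretrained vision models will be cached in : {PRETRAINED_MODELS_DIR}")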

### Instantiating our model

In [None]:
"""
model_name = "resnet-18"

# vision, dim_vision_out = get_vision_module(model_name)
dim_vision_out = 500
class_prediction = PretrainNet(vision, dim_vision_out=dim_vision_out)
optimizer = core.build_optimizer(class_prediction.parameters()) #  uses command-line parameters we passed to core.init
class_prediction = class_prediction.to(device)
"""

## C. 🖌 Finetuning our vision models on the specified dataset

Please refer to this specific notebook for further details : https://colab.research.google.com/drive/12m-SOfFKWQzys5h-z6p8caJ8-6Esgu1e?usp=sharing

---
# 2. 🗺 Defining the **agents**' internal models
---

## A. 🗣 **Sender** agent

#### Helper function

In [None]:
intertwine = lambda a, b: [x for pair in zip(a, b) for x in pair]
print(intertwine([1, 3, 5], [2, 4]))

def get_mlp(
    input_size=400, 
    output_size=400, 
    hidden_sizes=[],
):
    n_h = len(hidden_sizes)

    linears = [nn.Linear(input_size, hidden_sizes[0])]
    for i in range(n_h):
        if i == n_h - 1:
            linears.append(nn.Linear(hidden_sizes[i], output_size))
        else:
            linears.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
    activations = [nn.ReLU() for _ in range(n_h)]

    mix = intertwine(linears, activations) + [linears[-1]]
    keys = [f"linear_{i//2}" if (i%2 == 0) else f"relu_{i//2}" for i in range(len(mix))]

    mlp = nn.Sequential(OrderedDict(zip(keys, mix)))

    return mlp
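`get_mlp` interleaves `Linear` and `ReLU` layers and ends with a bare `Linear`, naming them `linear_i` / `relu_i`. A minimal pure-Python sketch listing the layer names and `(in_features, out_features)` shapes it builds for given sizes (the helper `mlp_plan` is hypothetical, written only to make the layout visible):

```python
def mlp_plan(input_size, output_size, hidden_sizes):
    # Names and (in, out) shapes of the layers built by get_mlp
    dims = [input_size] + list(hidden_sizes) + [output_size]
    plan = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        plan.append((f"linear_{i}", (d_in, d_out)))
        if i < len(hidden_sizes):
            plan.append((f"relu_{i}", None))  # no ReLU after the output layer
    return plan

for layer_name, shape in mlp_plan(400, 784, [600, 800, 600]):
    print(layer_name, shape)
```

With an empty `hidden_sizes`, the plan reduces to a single linear layer; `get_mlp` itself would fail on `hidden_sizes[0]` in that case, which is why `Sender` and `Receiver` fall back to `nn.Linear` when no hidden sizes are given.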

In [None]:
class Sender(nn.Module):
    def __init__(self, vision, input_size=400, output_size=400, hidden_sizes=[]):
        super(Sender, self).__init__()

        n_h = len(hidden_sizes)
        if n_h >= 1:
            self.fc = get_mlp(input_size, output_size, hidden_sizes)
        else:
            self.fc = nn.Linear(input_size, output_size)

        self.vision = vision
        
    def forward(self, x, aux_input=None):
        with torch.no_grad():
            x = self.vision(x)
        x = self.fc(x)

        return x

__sender__ = Sender(vision, input_size=400, output_size=400)
__sender__ = Sender(vision, input_size=400, output_size=400, hidden_sizes=[600, 800, 600])
del __sender__

## B. 👂 **Receiver** agent

In [None]:
class Receiver(nn.Module):
    def __init__(
        self, 
        input_size=400, 
        output_size=784, 
        hidden_sizes=[],
    ):
        super(Receiver, self).__init__()

        n_h = len(hidden_sizes)
        if n_h >= 1:
            self.fc = get_mlp(input_size, output_size, hidden_sizes)
        else:
            self.fc = nn.Linear(input_size, output_size)

    def forward(
        self, 
        channel_input, 
        receiver_input=None, 
        aux_input=None,
    ):
        x = self.fc(channel_input)

        return torch.sigmoid(x)

__receiver__ = Receiver(input_size=400, output_size=784)
__receiver__ = Receiver(input_size=400, output_size=784, hidden_sizes=[600, 800, 600])
del __receiver__

---
# 3. 🗺 Defining the game **mechanics** and the **environment** 🗺
---

## A. 🎮 Game wrapper

#### Accuracy loss  (*non-differentiable*) : **Discrimination Game**


In [None]:
def loss_accuracy_discrimination(
    _sender_input, 
    _message, 
    _receiver_input, 
    receiver_output, 
    labels, 
    _aux_input,
):
    """
    Accuracy loss - non-differentiable, hence cannot be used with GS
    """
    acc = (labels == receiver_output).float()
    return -acc, {"acc": acc}

#### Negative log-likelihood loss (*differentiable*) : **Discrimination Game**

In [None]:
def loss_nll_discrimination(
    _sender_input, 
    _message, 
    _receiver_input, 
    receiver_output, 
    labels, 
    _aux_input,
):
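    """
    NLL loss - differentiable and can be used with both GS and Reinforce.
    Expects receiver_output to contain log-probabilities (e.g. the output of F.log_softmax).
    """
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    # Per-sample accuracy (no .mean()), consistent with the per-sample loss returned alongside it
    acc = (labels == receiver_output.argmax(dim=1)).float()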
    return nll, {"acc": acc}
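`F.nll_loss` expects log-probabilities rather than raw scores, so a receiver used with this loss should end with a `log_softmax`; with `reduction="none"` it returns, for each sample, minus the log-probability assigned to the correct candidate. A minimal pure-Python sketch of that per-sample computation and of the argmax accuracy, on hypothetical receiver outputs (the function name is ours):

```python
import math

def nll_and_accuracy(log_probs, labels):
    # log_probs : one list of log-probabilities per sample; labels : index of the correct candidate
    losses = [-row[y] for row, y in zip(log_probs, labels)]
    accs = [float(max(range(len(row)), key=row.__getitem__) == y) for row, y in zip(log_probs, labels)]
    return losses, accs

# Two samples, three candidate images each
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]
losses, accs = nll_and_accuracy(log_probs, labels=[0, 1])
print([round(l, 3) for l in losses], accs)
```

The second sample shows that the loss still rewards a partially correct belief (0.3 on the right image) even though its accuracy is 0.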

#### BCE loss (*differentiable*) : **Reconstruction Game**

In [None]:
def loss_bce_reconstruction(
    sender_input, 
    _message, 
    _receiver_input, 
    receiver_output, 
    _labels, 
    _aux_input=None,
):
    """
    Per-sample binary cross-entropy between the receiver's reconstruction (values in [0, 1])
    and the sender's input image. Assumes square images and reads the global NUM_CHANNELS,
    which must be set before training.
    """
    if NUM_CHANNELS == 1:
        # Only the first channel is reconstructed (grayscale images may have been broadcast to 3 channels)
        target = sender_input[:, 0:1, :, :]
    else:
        target = sender_input
    reconstruction_dim = int(NUM_CHANNELS * (sender_input.shape[2] ** 2))
    loss = F.binary_cross_entropy(
        receiver_output, target.reshape(-1, reconstruction_dim), reduction='none'
    ).mean(dim=1)

    return loss, {}
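`F.binary_cross_entropy` with `reduction='none'` returns one value per pixel, and `.mean(dim=1)` averages them into one loss per sample. This is only a proper loss if the targets lie in [0, 1], which `T.ToTensor()` guarantees for 8-bit images; adding `T.Normalize` would push them outside that range. A minimal pure-Python sketch of the per-sample value on hypothetical pixels; note that PyTorch clamps each log term at -100, whereas this sketch clamps the probabilities with a small `eps`:

```python
import math

def bce_per_sample(reconstruction, target, eps=1e-12):
    # Mean binary cross-entropy over the pixels of one image
    # reconstruction : receiver outputs in (0, 1); target : pixel intensities in [0, 1]
    total = 0.0
    for p, t in zip(reconstruction, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(target)

value = bce_per_sample([0.9, 0.1, 0.5], [1.0, 0.0, 1.0])
print(round(value, 4))
```

A maximally uncertain output of 0.5 costs log 2 per pixel regardless of the target, which is why a receiver that has not yet learned anything plateaus around that value.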

## B. 🏅 Communication **channel** and **reward** function

We propose to model realistic communication between two agents over a channel in two ways. First, we model the **confusion between similar phonemes** (an *embedded* perspective on language production) through various cost functions, which are more or less faithful to linguistic models and which we compare in our experiments :


1.   **Uniform** : 
2.   **1-dimensional similarity** : 
3.   **Chromatic similarity** : 
4.   **2-dimensional similarity** : 
5.   **Phoneme-accurate similarity** : 
6.   **Keyboard-accurate similarity** : 
7.   **Tree-hierarchical similarity** : 
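The uniform, 1-dimensional and 2-dimensional models above are geometric : symbols are either all equally distant, placed on a line, or placed on a grid. A minimal sketch of the corresponding distance between two symbol indices, following the index arithmetic of `Vocabulary.__compute_distance__` below (grid of width `int(sqrt(vocab_size))`, filled row by row). The function name is hypothetical, the Euclidean combination of the two grid offsets is our assumption, and the conversion of distances into costs through `law_exp` is left out:

```python
import math

def symbol_distance(a: int, b: int, model: str = "1d", vocab_size: int = 26) -> float:
    # Hypothetical distance between symbol indices a and b under three of the listed models
    if model == "uniform":
        return 1.0  # every pair of symbols gets the same value
    if model == "1d":
        return float(abs(a - b))  # symbols laid out on a line
    if model == "2d":
        width = int(math.sqrt(vocab_size))  # symbols laid out row by row on a grid
        ax, ay = a % width, a // width
        bx, by = b % width, b // width
        return math.hypot(ax - bx, ay - by)  # assumed Euclidean distance on the grid
    raise ValueError(f"unknown model : {model}")

print(symbol_distance(3, 10, "1d"), symbol_distance(7, 13, "2d"))
```

With 26 symbols the grid has width 5, so symbols 7 and 13 sit at (2, 1) and (3, 2) and are diagonal neighbours, although they are 6 apart on the line.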



Second, we propose to model the stochasticity of the communication channel through random corruption operators transforming the message sent by the **sender** agent.
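One simple instance of such a corruption operator is symbol substitution noise : each symbol before the end-of-sequence token is replaced, with probability p, by a different symbol drawn uniformly from the vocabulary. The sketch below is a hypothetical operator written for illustration, not the one used in the experiments; it assumes, as in the `Vocabulary` class below, that index 0 is the EOS symbol:

```python
import random

def corrupt_message(message, vocab_size, p=0.1, eos=0, rng=None):
    """Replace each pre-EOS symbol, with probability p, by a different random non-EOS symbol.

    Hypothetical operator : the EOS symbol and everything after it are left untouched.
    """
    if rng is None:
        rng = random.Random()
    corrupted = list(message)
    for i, symbol in enumerate(corrupted):
        if symbol == eos:
            break  # the message ends here
        if rng.random() < p:
            candidates = [s for s in range(vocab_size) if s not in (eos, symbol)]
            corrupted[i] = rng.choice(candidates)
    return corrupted

rng = random.Random(0)
print(corrupt_message([3, 7, 12, 0, 0], vocab_size=26, p=0.5, rng=rng))
```

Leaving EOS untouched keeps the message length unchanged, so this operator only corrupts content; deletion or insertion noise would be natural variants.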

#### *Sender* & *Receiver*'s **vocabulary** :

In [None]:
law_exp = lambda mu, x: (1.0 / mu) * np.exp(-mu / x)

class Vocabulary():
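    """
    Vocabulary shared by the sender and the receiver, together with a model-dependent cost
    ('uniform', '1d', 'chromatic', '2d', ...) computed between consecutive symbols of a message.
    """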
    def __init__(
        self,
        vocab_model='uniform',
        vocab_size=26,
        vocab_loss_c=1.0,
        vocab_loss_mute=10.0,
        vocab_mu_dist=1.0,
        vocab_temperature=0.5,
    ):
        self.model = vocab_model
        self.temperature = vocab_temperature
        self.size = vocab_size
        # Note : the first symbol of index 0 is used to denote the EOS token. It is not taken into account
        # when computing the cost function across the communication channel

        self.loss_c = vocab_loss_c
        self.loss_mute = vocab_loss_mute
        self.mu_dist = vocab_mu_dist
        
    def __compute_distance__(
        self,
        x_in: torch.Tensor,
    ) -> torch.Tensor:
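        # x_in   : bs x max_len
        # x_dist : bs x max_len --> for the '1d' and '2d' models, x_dist[i, j] is the cost of the transition
        #          between x_in[i, j] and x_in[i, j + 1]

        # Costs are real numbers: use a float tensor so that they are not truncated to integers
        x_dist = torch.zeros_like(x_in, dtype=torch.float)
        if (self.model == 'uniform'):
            x_dist = torch.ones_like(x_in, dtype=torch.float)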
        elif (self.model == '1d'):
            for i in range(x_in.shape[0]):
                if x_in[i, 0] == 0:
                    # Empty message (EOS first): fixed "mute" cost, the symbols after EOS are ignored
                    x_dist[i, 0] = self.loss_mute
                    continue
                for j in range(x_in.shape[1] - 1):
                    if x_in[i, j + 1] == 0:
                        break
                    d = np.abs(x_in[i, j] - x_in[i, j + 1])
                    x_dist[i, j] = law_exp(self.mu_dist, d)
                    
        elif (self.model == 'chromatic'):
            pass

        elif (self.model == '2d'):
            size_x = int(np.sqrt(self.size))
            size_y = self.size // size_x
            for i in range(x_in.shape[0]):
                if x_in[i, 0] == 0:
                    # Empty message (EOS first): fixed "mute" cost, the symbols after EOS are ignored
                    x_dist[i, 0] = self.loss_mute
                    continue
                for j in range(x_in.shape[1] - 1):
                    if x_in[i, j + 1] == 0:
                        break
                    coor_x = (x_in[i, j] % size_x, x_in[i, j + 1] % size_x)
                    # Row index of both symbols on a grid with size_x columns
                    coor_y = (x_in[i, j] // size_x, x_in[i, j + 1] // size_x)

                    d_x_2 = (coor_x[0] - coor_x[1])**2
                    d_y_2 = (coor_y[0] - coor_y[1])**2

                    d = np.sqrt(d_x_2 + d_y_2)

                    x_dist[i, j] = law_exp(self.mu_dist, d)

        elif (self.model == 'phoneme_accurate'):
            pass

        elif (self.model == 'keyboard_accurate'):
            pass

        elif (self.model == 'tree_hierarchical'):
            pass

        return x_dist
        
    def compute_cost(
        self,
        x_in: torch.Tensor,
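    ) -> tuple:
        # x_in        : bs x max_len
        # sender_loss : bs           --> total cost of each message
        # x_dist      : bs x max_len --> per-position costs

        x_dist = self.__compute_distance__(x_in)

        # We sum-reduce along the message (sequence) dimension: one cost per sample of the batch
        sender_loss = torch.einsum('ij->i', x_dist)
        return sender_loss, x_dist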

vocabulary = Vocabulary(vocab_model='uniform', vocab_size=10)
__x_in__ = torch.randint(0, vocabulary.size, (4, 10,))  # 4 random messages of length 10, symbols in [0, vocab_size)
__loss__, _ = vocabulary.compute_cost(__x_in__)
# print(__loss__)
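For long messages or large batches, the Python loops of `__compute_distance__` can become slow. Below is a vectorized sketch of the `'1d'` transition distances `|x[i, j] - x[i, j + 1]|`, zeroed from the first transition whose next symbol is EOS (0), which mirrors the `break` in the loop. It only computes the raw distances (shape `bs x (max_len - 1)`): mapping them to costs with `law_exp` and the mute cost of empty messages are left out, and `vectorized_1d_distances` is a hypothetical helper.

```python
import torch


def vectorized_1d_distances(x_in: torch.Tensor) -> torch.Tensor:
    """bs x (max_len - 1) tensor of |x[:, j] - x[:, j + 1]|, set to 0 from the first EOS in x[:, 1:] onwards."""
    diffs = (x_in[:, 1:] - x_in[:, :-1]).abs().float()
    # Transition j is kept only as long as no EOS occurs in x[:, 1 .. j + 1]
    eos_seen = (x_in[:, 1:] == 0).long().cumsum(dim=1) > 0
    return diffs.masked_fill(eos_seen, 0.0)


x = torch.tensor([[3, 5, 4, 0, 7],
                  [2, 2, 9, 1, 0]])
print(vectorized_1d_distances(x))
# tensor([[2., 1., 0., 0.],
#         [0., 7., 8., 0.]])
```

The cumulative sum turns "an EOS has already appeared" into a boolean mask, so no per-sample loop is needed and the computation runs on GPU as well as on CPU.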

#### Helper function to visualize the induced language in a more readable format, across an entire batch along with labels :

In [None]:
def make_readable_communication(
        x, 
        string,
        bs,
        max_len,
        lower=True,
    ):
        shift_voc = 97 if lower else 65
        
        for i in range(bs):
            string.append('')
            for j in range(max_len):
                if x[i, j] == 0:
                    string[i] += '.'
                    break
                else:
                    string[i] += chr(x[i, j] + shift_voc - 1)

In [None]:
def visualize_batch_communication(
    x_in, 
    x_out=None, 
    label_in=None, 
    label_out=None, 
    lower=True,
):
    bs, max_len_in = x_in.shape
    string_in, string_out = [], []
    
    make_readable_communication(x_in, string_in, bs, max_len_in, lower)
    if (x_out is not None):
        max_len_out = x_out.shape[1]
        make_readable_communication(x_out, string_out, bs, max_len_out, lower)

    for i in range(bs):
        if (label_in is not None) and (label_out is not None):  
            ground_truth_comparison = f'[I = {label_in[i]} / O = {label_out[i]}]'
        else:
            ground_truth_comparison = ''
        if (x_out is not None):
            received_message = '      -> ‖CHANNEL‖ ->      ⟥' + string_out[i].ljust(max_len_out) + '⟤      -> R      '
        else:
            received_message = ''
        print(f'[{i+1}/{bs}]'.ljust(10) + 'S ->      ' + 
                '⟥' + string_in[i].ljust(max_len_in) + '⟤' + 
                received_message + 
                ground_truth_comparison)

visualize_batch_communication(__x_in__, __x_in__)
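The readable format maps symbol `k >= 1` to the `k`-th letter of the alphabet (`chr(k + 96)` in lower case, `chr(k + 64)` in upper case) and the first EOS to a final `.`. A minimal standalone sketch of that mapping for a single message given as a list of ints (`decode_message` is a hypothetical name):

```python
def decode_message(message, lower=True):
    """Letters for the symbols before the first EOS (0), followed by '.' if an EOS is present."""
    shift = ord('a') if lower else ord('A')
    text = ''
    for token in message:
        if token == 0:
            return text + '.'
        text += chr(token + shift - 1)
    return text


print(decode_message([3, 1, 2, 0, 5]))  # cab.
```

Note that this mapping only stays within the alphabet for vocabularies of at most 27 symbols (EOS plus 26 letters); larger vocabularies produce punctuation and other characters after `z`.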

#### **Noisy** communication **channel** and **cost function** :

In [None]:
class Channel():
    def __init__(
        self,
        max_len=3,
        channel_top_k=10,
        channel_temperature=0.8,
        vocabulary=None,
        p_random_insertion=0.05,
        p_random_deletion=0.05,
        p_random_permutation=0.05,
        p_random_corruption=0.05,
        corruption_function=['insertion', 'deletion', 'permutation', 'corruption'],
    ):
        self.max_len = max_len

        # Corruption sampling parameters
        self.top_k = channel_top_k
        self.temperature = channel_temperature
        # Default to a fresh Vocabulary() per channel rather than one instance shared by every Channel
        self.vocabulary = vocabulary if vocabulary is not None else Vocabulary()

        self.p_random_insertion = p_random_insertion
        self.p_random_deletion = p_random_deletion
        self.p_random_permutation = p_random_permutation
        self.p_random_corruption = p_random_corruption

        self.corruption_function = corruption_function

    def __random_insert__(
        self,
        x_in: torch.Tensor,
    ) -> torch.Tensor:
        # Randomly inserts tokens in the input sentence
        # x_in : bs x max_len

        if self.p_random_insertion == 0.:
            return x_in

        if (self.vocabulary.model == 'uniform'):
            x_out = torch.zeros_like(x_in, device=x_in.device)

            # Extra zero (EOS) column, with the same dtype as the message, appended when the message grows
            x_0 = torch.zeros((x_in.shape[0], 1), dtype=x_in.dtype, device=x_in.device)
            
            for b in range(x_in.shape[0]):
                i, j = 0, 0
                while (j < x_in.shape[1]) and (i < x_in.shape[1]):
                    if random.uniform(0, 1) < self.p_random_insertion:
                        total_supplement = x_out.shape[1] - x_in.shape[1]
                        current_supplement = j - i
                        if (current_supplement == total_supplement) or (i + 1 == x_out.shape[1]):
                            x_out = torch.cat((x_out, x_0,), dim=1)
                        # Insert a random non-EOS symbol (random.randint is inclusive: valid symbols are 1 .. size - 1)
                        x_out[b, i] = random.randint(1, self.vocabulary.size - 1)
                        i += 1
                    else:
                        x_out[b, i] = x_in[b, j]
                        i += 1
                        j += 1

            

        else:
            # The '1d', 'chromatic', '2d', 'phoneme_accurate', 'keyboard_accurate' and 'tree_hierarchical'
            # vocabulary models are not supported by this operator yet
            raise NotImplementedError(f"Random insertion is not implemented for the '{self.vocabulary.model}' vocabulary model")

        return x_out

    def __random_delete__(
        self,
        x_in: torch.Tensor,
    ) -> torch.Tensor:
        # Randomly deletes tokens in the input sentence
        # x_in : bs x max_len

        if self.p_random_deletion == 0.:
            return x_in

        if (self.vocabulary.model == 'uniform'):
            x_out = torch.zeros_like(x_in)
            
            for b in range(x_in.shape[0]):
                i = 0  # write position in the output message
                for j in range(x_in.shape[1]):
                    # Stop at the first EOS token: the remaining output positions stay 0 (EOS)
                    if x_in[b, j] == 0:
                        break
                    # Drop the current token with probability p_random_deletion
                    if random.uniform(0, 1) < self.p_random_deletion:
                        continue
                    x_out[b, i] = x_in[b, j]
                    i += 1

        else:
            # The '1d', 'chromatic', '2d', 'phoneme_accurate', 'keyboard_accurate' and 'tree_hierarchical'
            # vocabulary models are not supported by this operator yet
            raise NotImplementedError(f"Random deletion is not implemented for the '{self.vocabulary.model}' vocabulary model")

        return x_out

    def __random_permute__(
        self,
        x_in: torch.Tensor,
    ) -> torch.Tensor:
        # Randomly permutes arbitrary pairs of tokens in the input sentence, according to the vocabulary distance function defined
        # in the Vocabulary class instance
        # x_in : bs x max_len

        if self.p_random_permutation == 0.:
            return x_in

        if (self.vocabulary.model == 'uniform'):
            x_out = x_in.clone()
            
            for b in range(x_out.shape[0]):
                # Position of the first EOS token (0): only the symbols before it are permuted
                eos_positions = (x_out[b] == 0).nonzero(as_tuple=True)[0]
                index_eos = eos_positions[0].item() if len(eos_positions) > 0 else x_out.shape[1]

                for i in range(index_eos):
                    if random.uniform(0, 1) < self.p_random_permutation:
                        # random.randint is inclusive on both ends: swap with any position before the EOS
                        swap_index = random.randint(0, index_eos - 1)
                        x_out[b, i], x_out[b, swap_index] = x_out[b, swap_index].item(), x_out[b, i].item()

        else:
            # The '1d', 'chromatic', '2d', 'phoneme_accurate', 'keyboard_accurate' and 'tree_hierarchical'
            # vocabulary models are not supported by this operator yet
            raise NotImplementedError(f"Random permutation is not implemented for the '{self.vocabulary.model}' vocabulary model")

        return x_out

    def __random_corrupt__(
        self,
        x_in: torch.Tensor,
    ) -> torch.Tensor:
        # Randomly replaces input tokens by arbitrary tokens sampled either uniformly from the input vocabulary or according to
        # the vocabulary distance function
        # x_in : bs x max_len

        if self.p_random_corruption == 0.:
            return x_in

        if (self.vocabulary.model == 'uniform'):
            x_out = x_in.clone()
            
            for (b, i) in product(range(x_in.shape[0]), range(x_in.shape[1])):
                if random.uniform(0, 1) < self.p_random_corruption:
                    # Replace with a random non-EOS symbol (random.randint is inclusive: valid symbols are 1 .. size - 1)
                    x_out[b, i] = random.randint(1, self.vocabulary.size - 1)

        else:
            # The '1d', 'chromatic', '2d', 'phoneme_accurate', 'keyboard_accurate' and 'tree_hierarchical'
            # vocabulary models are not supported by this operator yet
            raise NotImplementedError(f"Random corruption is not implemented for the '{self.vocabulary.model}' vocabulary model")

        return x_out

    def forward(
        self,
        x_in: torch.Tensor,
        order=None,
    ) -> torch.Tensor:
        # By default, apply the operators listed in `corruption_function`, in that order
        if order is None:
            order = self.corruption_function

        methods = {
            'insertion': self.__random_insert__,
            'deletion': self.__random_delete__,
            'permutation': self.__random_permute__,
            'corruption': self.__random_corrupt__,
        }

        for op in order:
            x_in = methods[op](x_in)

        # Append a final EOS column so that every received message is terminated, even if its last symbol was corrupted
        x_0 = torch.zeros((x_in.shape[0], 1), dtype=x_in.dtype, device=x_in.device)
        x_out = torch.cat((x_in, x_0,), dim=1)

        return x_out

channel = Channel(
    vocabulary=vocabulary,  # same vocabulary as the one used to sample __x_in__
    p_random_insertion=0.05,
    p_random_deletion=0.05,
    p_random_permutation=0.05,
    p_random_corruption=0.05,
)

__x_out_insert__ = channel.__random_insert__(__x_in__)
print("\nVisualizing the effect of channel noisiness wrt. random insertions :")
visualize_batch_communication(__x_in__, __x_out_insert__)

__x_out_delete__ = channel.__random_delete__(__x_in__)
print("\n\nVisualizing the effect of channel noisiness wrt. random deletions :")
visualize_batch_communication(__x_in__, __x_out_delete__)

__x_out_permute__ = channel.__random_permute__(__x_in__)
print("\n\nVisualizing the effect of channel noisiness wrt. random permutations :")
visualize_batch_communication(__x_in__, __x_out_permute__)

__x_out_corrupt__ = channel.__random_corrupt__(__x_in__)
print("\n\nVisualizing the effect of channel noisiness wrt. random token-level corruptions :")
visualize_batch_communication(__x_in__, __x_out_corrupt__)

__x_out_full__ = channel.forward(
    __x_in__,
    order=['insertion', 'deletion', 'permutation', 'corruption'],
)
print("\n\nVisualizing the effect of full channel corruption (insertion -> deletion -> permutation -> corruption in that order) :")
visualize_batch_communication(__x_in__, __x_out_full__)
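The per-token Python loops above are easy to read but slow on large batches, especially on GPU. As a sketch only (not the notebook's implementation), uniform corruption can be vectorized with a Bernoulli mask and `torch.where`. Like the loop version, it can also hit EOS positions, and replacement symbols are drawn in `1 .. vocab_size - 1`; `corrupt_vectorized` is a hypothetical name.

```python
import torch


def corrupt_vectorized(x_in: torch.Tensor, p: float, vocab_size: int, generator=None) -> torch.Tensor:
    """Replace each token independently with probability p by a uniform random non-EOS symbol."""
    mask = torch.rand(x_in.shape, generator=generator, device=x_in.device) < p
    # torch.randint excludes its upper bound: symbols are drawn in 1 .. vocab_size - 1
    random_tokens = torch.randint(
        1, vocab_size, x_in.shape, generator=generator, device=x_in.device, dtype=x_in.dtype
    )
    return torch.where(mask, random_tokens, x_in)


g = torch.Generator().manual_seed(0)
x = torch.randint(0, 10, (4, 6), generator=g)
print(corrupt_vectorized(x, 0.05, 10, generator=g))
```

If a generator is passed, it must live on the same device as `x_in`; without one, PyTorch's global random state is used.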

---
# 4. 🏋 Training routine 🏋
---

## A. 🔦 Defining the hyperparameters

#### **Vision model** related parameters

In [None]:
mp = {
    'model_name': 'resnet-50',
    'dataset': 'mnist', # 'tiny-imagenet', etc...
}

#### **General-purpose** parameters

In [None]:
gp = {
    'random_seed': 42,
    'num_epochs': 5,
}

#### **Sender** parameters

In [None]:
sp = {
    's_hidden_sizes': [400],

    's_hidden_size': 20, 
    's_emb_size': 10,
    's_cell': 'gru', # 'lstm', 'rnn', 'gru'
    's_entropy_coeff': 0.015,
    's_num_layers': 1,
    's_lr': 1e-3, # as a rule of thumb, should be lower than the receiver's learning rate (rp['r_lr'])

    's_lr_scheduler_mult': 2,
    's_lr_scheduler_eta_min': 0.01,
}

#### **Vocabulary**-level parameters

In [None]:
vp = {
    'vocab_size': 10,
    'vocab_model': 'uniform',
    'vocab_loss_c': None,
    'vocab_loss_mute': None,
    'vocab_mu_dist': None,
    'vocab_temperature': None,
}

#### **Channel**-level parameters

In [None]:
cp = {
    'max_len': 5,
    'channel_top_k': None,
    'channel_temperature': None,
    'p_random_insertion': 0.05,
    'p_random_deletion': 0.05,
    'p_random_permutation': 0.05,
    'p_random_corruption': 0.05,
    'corruption_function': ['insertion', 'deletion', 'permutation', 'corruption'],
}

#### **Receiver** parameters

In [None]:
rp = {
    'r_hidden_sizes': [400],

    'r_hidden_size': 20, 
    'r_emb_size': 10,
    'r_cell': 'gru', # 'lstm', 'rnn', 'gru'
    'r_entropy_coeff': 0.0,
    'r_num_layers': 1,
    'r_lr': 1e-2,

    'r_lr_scheduler_mult': 3,
    'r_lr_scheduler_eta_min': 0.01,
}

#### **Optimization**-related parameters

In [None]:
op = {
    'n_epochs': 5,
    'grad_norm': 0,
    'temperature_decay': 0.75,
    'temperature_minimum': 0.01,
    'temperature_update_freq': 1,
}

#### **Task**-related parameters

In [None]:
tp = {
    'task_loss': ('reconstruction', loss_bce_reconstruction, 'bce'),
    # ('discrimination', loss_nll_discrimination, 'nll')
    # ('discrimination', loss_accuracy_discrimination, 'accuracy')
    # ('discrimination', loss_bce_reconstruction, 'bce')
}

#### **Callbacks**-related parameters

In [None]:
cap = {
    'checkpoint_freq': 1,
    'max_checkpoints': 5,
    'early_stopping_field': 'acc',
    'early_stopping_threshold': 0.2,
    'all_distances_topsim': True,
}

In [None]:
# We will later push the experiment's hyperparameters to Weights & Biases in order to keep track of them
params = {**mp, **gp, **sp, **vp, **cp, **rp, **op, **tp, **cap}

In [None]:
if params['task_loss'][0] == 'reconstruction':
    WANDB_EXPERIMENT_GROUP = 'reconstruction-game'
elif params['task_loss'][0] == 'discrimination':
    WANDB_EXPERIMENT_GROUP = 'discrimination-game'
else:
    WANDB_EXPERIMENT_GROUP = 'ablation-study'

## B. ⏭ Instantiating the agents

#### 🔍 Utility function to load the best checkpoint of a vision module fine-tuned in another notebook

In [None]:
def retrieve_finetuned_vision_model(
    model_name, 
    dataset, 
    add_name=None,
):
    if add_name is None:
        NAME_RUN = f'Vision model pretraining : M[{model_name}], D[{dataset}]'
        NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}'
    else:
        NAME_RUN = f'Vision model pretraining : M[{model_name}], D[{dataset}], D[{add_name}]'
        NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}_cond={add_name}'

    DIR_CHECKPOINT = FINETUNED_MODELS_DIR + NAME_CHECKPOINT + '/'

    print(f'Loading the model {model_name} finetuned on the dataset {dataset} ...')

    def find_best_epoch(
        ckpt_folder,
    ):
        """
        Find the checkpoint with the highest validation accuracy in a checkpoint folder.
        :param ckpt_folder: dir where the checkpoints are being saved (filenames containing 'val_accuracy=<value>.pt').
        :return: filename of the checkpoint with the highest validation accuracy.
        """
        checkpoint_files = os.listdir(ckpt_folder)  # list of strings
        accuracies = [re.search(r'val_accuracy=(.*)\.pt', filename) for filename in checkpoint_files]
        checkpoint_files = list(filter(lambda x: x[1] is not None, list(zip(checkpoint_files, accuracies))))
        # Compare the accuracies as numbers, not as (lexicographically ordered) strings
        checkpoint_files = map(lambda x : (x[0], float(x[1].group(1)),), checkpoint_files)
        best_checkpoint_filename, best_accuracy = max(checkpoint_files, key=itemgetter(1))
        print(f'Found the following checkpoint : {best_checkpoint_filename} with the following validation accuracy : {best_accuracy}')

        return best_checkpoint_filename

    best_checkpoint_filepath = find_best_epoch(DIR_CHECKPOINT)
    full_checkpoint_filepath = DIR_CHECKPOINT + best_checkpoint_filepath

    # Retrieving the general vision model architecture (without the finetuned weights)
    __model_backbone__, dim_vision_out = get_vision_module(version=model_name)

    # Loading the finetuned weights of the vision model into the backbone architecture (mapped to the CPU when no GPU is available)
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    __model_backbone__.load_state_dict(torch.load(
        full_checkpoint_filepath,
        map_location=map_location,
    ))

    # __model_backbone__.fc = nn.Identity()

    print(f'Successfully loaded the model {model_name} finetuned on the dataset {dataset} ...')

    return __model_backbone__, dim_vision_out
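Why compare the accuracies as floats rather than as the raw strings captured by the regex? String comparison is lexicographic, so for instance `'9.5' > '10.2'`. The sketch below reproduces the selection logic on a plain list of filenames; the naming pattern `...val_accuracy=<value>.pt` follows the regex used in `find_best_epoch`, while the filenames themselves are made up for the example.

```python
import re


def best_checkpoint(filenames):
    """Return (filename, accuracy) with the highest numeric validation accuracy."""
    candidates = []
    for name in filenames:
        match = re.search(r'val_accuracy=(.*)\.pt', name)
        if match is not None:
            candidates.append((name, float(match.group(1))))
    return max(candidates, key=lambda c: c[1])


files = ['epoch=3_val_accuracy=9.5.pt', 'epoch=7_val_accuracy=10.2.pt', 'notes.txt']
print(best_checkpoint(files))  # ('epoch=7_val_accuracy=10.2.pt', 10.2)
print(max(['9.5', '10.2']))    # 9.5 (lexicographic order picks the worse accuracy)
```

When all accuracies share the same `0.xxxx` format the two orders happen to agree, but the numeric comparison stays correct whatever the formatting.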

#### Instantiating the general architecture of the vision model's optimizer + prediction layer

In [None]:
# model, dim_vision_out = get_vision_module(version=params['model_name'])

# model, dim_vision_out = retrieve_finetuned_vision_model(params['model_name'], params['dataset'], add_name=None)

In [None]:
dim_vision_out = 500

In [None]:
# !ls "/content/drive/My Drive/Projects/nlp_emergent_languages/finetuned_models/"

#### Creating the stochastic image transformation operators (*data augmentation*)

In [None]:
# multi_channel = not params['dataset'] in ['mnist', 'k-mnist', 'q-mnist', 'fashion-mnist']
# transform_train = TransformsAugment(size=sizes[params['dataset']], multi_channel=multi_channel)
# transform_val_test = TransformsAugment(size=sizes[params['dataset']], multi_channel=multi_channel)
# transform_train = T.Compose([T.ToTensor(), T.Lambda(lambda x: x.repeat(3,1,1))])
# transform_val_test = T.Compose([T.ToTensor(), T.Lambda(lambda x: x.repeat(3,1,1))])
transform_train, transform_val_test = T.ToTensor(), T.ToTensor()

#### Loading the dataset on which the model was finetuned (and which will be used in order to train the agents on a reconstruction task)

In [None]:
split = image_datasets[params['dataset']](
    transform_train, 
    transform_val_test,
)
train_loader, val_loader, test_loader = split.values()
NUM_CHANNELS = n_channels[params['dataset']]

#### Creating the **vocabulary model** and the **noisy communication channel**

In [None]:
VOCAB = Vocabulary(
    vocab_model=params['vocab_model'],
    vocab_size=params['vocab_size'],
    vocab_loss_c=params['vocab_loss_c'],
    vocab_loss_mute=params['vocab_loss_mute'],
    vocab_mu_dist=params['vocab_mu_dist'],
    vocab_temperature=params['vocab_temperature'],
)

CHANNEL = Channel(
    max_len=params['max_len'],
    channel_top_k=params['channel_top_k'],
    channel_temperature=params['channel_temperature'],
    vocabulary=VOCAB,
    p_random_insertion=params['p_random_insertion'],
    p_random_deletion=params['p_random_deletion'],
    p_random_permutation=params['p_random_permutation'],
    p_random_corruption=params['p_random_corruption'],
    corruption_function=params['corruption_function'],   
)

#### Instantiating a new communication wrapper that takes our custom **Vocabulary** and **Channel** into account

In [None]:
import egg

In [None]:
import math
from collections import defaultdict
from typing import Callable

from egg.core.baselines import Baseline, MeanBaseline
from egg.core.interaction import LoggingStrategy
from egg.core.rnn import RnnEncoder
from egg.core.transformer import TransformerDecoder, TransformerEncoder
from egg.core.util import find_lengths

class SenderReceiverRnnReinforce(nn.Module):
    """
    Implements Sender/Receiver game with training done via Reinforce. Both agents are supposed to
    return 3-tuples of (output, log-prob of the output, entropy).
    The game implementation is responsible for handling the end-of-sequence term, so that the optimized loss
    corresponds either to the position of the eos term (assumed to be 0) or the end of sequence.
    Sender and Receiver can be obtained by applying the corresponding wrappers.
    `SenderReceiverRnnReinforce` also applies the mean baseline to the loss function to reduce
    the variance of the gradient estimate.
    >>> class Sender(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc = nn.Linear(3, 10)
    ...     def forward(self, rnn_output, _input=None, _aux_input=None):
    ...         return self.fc(rnn_output)
    >>> sender = Sender()
    >>> sender = RnnSenderReinforce(sender, vocab_size=15, embed_dim=5, hidden_size=10, max_len=10, cell='lstm')
    >>> class Receiver(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc = nn.Linear(5, 3)
    ...     def forward(self, rnn_output, _input=None, _aux_input=None):
    ...         return self.fc(rnn_output)
    >>> receiver = RnnReceiverDeterministic(Receiver(), vocab_size=15, embed_dim=10, hidden_size=5)
    >>> def loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input):
    ...     loss = F.mse_loss(sender_input, receiver_output, reduction='none').mean(dim=1)
    ...     aux = {'aux': torch.ones(sender_input.size(0))}
    ...     return loss, aux
    >>> game = SenderReceiverRnnReinforce(sender, receiver, loss, sender_entropy_coeff=0.0, receiver_entropy_coeff=0.0,
    ...                                   length_cost=1e-2)
    >>> input = torch.zeros((5, 3)).normal_()
    >>> optimized_loss, interaction = game(input, labels=None, aux_input=None)
    >>> sorted(list(interaction.aux.keys()))  # returns debug info such as entropies of the agents, message length etc
    ['aux', 'length', 'receiver_entropy', 'sender_entropy']
    >>> interaction.aux['aux'], interaction.aux['aux'].sum()
    (tensor([1., 1., 1., 1., 1.]), tensor(5.))
    """

    def __init__(
        self,
        sender: nn.Module,
        receiver: nn.Module,
        loss: Callable,
        sender_entropy_coeff: float = 0.0,
        receiver_entropy_coeff: float = 0.0,
        length_cost: float = 0.0,
        baseline_type: Baseline = MeanBaseline,
        train_logging_strategy: LoggingStrategy = None,
        test_logging_strategy: LoggingStrategy = None,
    ):
        """
        :param sender: sender agent
        :param receiver: receiver agent
        :param loss:  the optimized loss that accepts
            sender_input: input of Sender
            message: the message sent by Sender
            receiver_input: input of Receiver from the dataset
            receiver_output: output of Receiver
            labels: labels assigned to Sender's input data
          and outputs a tuple of (1) a loss tensor of shape (batch size, 1) (2) the dict with auxiliary information
          of the same shape. The loss will be minimized during training, and the auxiliary information aggregated over
          all batches in the dataset.
        :param sender_entropy_coeff: entropy regularization coeff for sender
        :param receiver_entropy_coeff: entropy regularization coeff for receiver
        :param length_cost: the penalty applied to Sender for each symbol produced
        :param baseline_type: Callable, returns a baseline instance (eg a class specializing core.baselines.Baseline)
        :param train_logging_strategy, test_logging_strategy: specify what parts of interactions to persist for
            later analysis in callbacks
        """
        super(SenderReceiverRnnReinforce, self).__init__()
        self.sender = sender
        self.receiver = receiver
        self.loss = loss

        self.mechanics = CommunicationRnnReinforce(
            sender_entropy_coeff,
            receiver_entropy_coeff,
            length_cost,
            baseline_type,
            train_logging_strategy,
            test_logging_strategy,
        )

    def forward(self, sender_input, labels, receiver_input=None, aux_input=None):
        return self.mechanics(
            self.sender,
            self.receiver,
            self.loss,
            sender_input,
            labels,
            receiver_input,
            aux_input,
        )


class CommunicationRnnReinforce(nn.Module):
    def __init__(
        self,
        sender_entropy_coeff: float,
        receiver_entropy_coeff: float,
        length_cost: float = 0.0,
        baseline_type: Baseline = MeanBaseline,
        train_logging_strategy: LoggingStrategy = None,
        test_logging_strategy: LoggingStrategy = None,
    ):
        """
        :param sender_entropy_coeff: entropy regularization coeff for sender
        :param receiver_entropy_coeff: entropy regularization coeff for receiver
        :param length_cost: the penalty applied to Sender for each symbol produced
        :param baseline_type: Callable, returns a baseline instance (eg a class specializing core.baselines.Baseline)
        :param train_logging_strategy, test_logging_strategy: specify what parts of interactions to persist for
            later analysis in callbacks
        """

        global CHANNEL

        super().__init__()

        self.sender_entropy_coeff = sender_entropy_coeff
        self.receiver_entropy_coeff = receiver_entropy_coeff
        self.length_cost = length_cost

        self.baselines = defaultdict(baseline_type)
        self.train_logging_strategy = (
            LoggingStrategy()
            if train_logging_strategy is None
            else train_logging_strategy
        )
        self.test_logging_strategy = (
            LoggingStrategy()
            if test_logging_strategy is None
            else test_logging_strategy
        )

        self.said = False


    def forward(
        self,
        sender,
        receiver,
        loss,
        sender_input,
        labels,
        receiver_input=None,
        aux_input=None,
    ):
        global CHANNEL

        if not self.said:
            print("==============================================")
            print("PROOF THAT CHANNEL HYPERPARAMETERS ARE UPDATED")
            print(f"P Random Insertion : {CHANNEL.p_random_insertion}")
            print(f"P Random Deletion : {CHANNEL.p_random_deletion}")
            print(f"P Random Permutation : {CHANNEL.p_random_permutation}")
            print(f"P Random Corruption : {CHANNEL.p_random_corruption}")
            print(f"Corruption function : {CHANNEL.corruption_function}")
            print(f"Max Len : {CHANNEL.max_len}")
            print("==============================================")
            self.said = True
        
        message, log_prob_s, entropy_s = sender(sender_input, aux_input)

        # The sender's message goes through the noisy channel: the receiver only sees its corrupted version
        new_message = CHANNEL.forward(message).long()

        # Length (up to and including the first EOS) of the sent message, used for the sender's loss terms,
        # and of the received message, used by the receiver
        message_length = find_lengths(message)
        new_message_length = find_lengths(new_message)
        receiver_output, log_prob_r, entropy_r = receiver(
            new_message, receiver_input, aux_input, new_message_length
        )

        loss, aux_info = loss(
            sender_input, message, receiver_input, receiver_output, labels, aux_input
        )

        # the entropy of the outputs of S before and including the eos symbol - as we don't care about what's after
        effective_entropy_s = torch.zeros_like(entropy_r)

        # the log prob of the choices made by S before and including the eos symbol - again, we don't
        # care about the rest
        effective_log_prob_s = torch.zeros_like(log_prob_r)

        for i in range(message.size(1)):
            not_eosed = (i < message_length).float()
            effective_entropy_s += entropy_s[:, i] * not_eosed
            effective_log_prob_s += log_prob_s[:, i] * not_eosed
        effective_entropy_s = effective_entropy_s / message_length.float()

        weighted_entropy = (
            effective_entropy_s.mean() * self.sender_entropy_coeff
            + entropy_r.mean() * self.receiver_entropy_coeff
        )

        log_prob = effective_log_prob_s + log_prob_r

        length_loss = message_length.float() * self.length_cost

        policy_length_loss = (
            (length_loss - self.baselines["length"].predict(length_loss))
            * effective_log_prob_s
        ).mean()
        policy_loss = (
            (loss.detach() - self.baselines["loss"].predict(loss.detach())) * log_prob
        ).mean()

        optimized_loss = policy_length_loss + policy_loss - weighted_entropy
        # if the receiver is deterministic/differentiable, we apply the actual loss
        optimized_loss += loss.mean()

        if self.training:
            self.baselines["loss"].update(loss)
            self.baselines["length"].update(length_loss)

        aux_info["sender_entropy"] = entropy_s.detach()
        aux_info["receiver_entropy"] = entropy_r.detach()
        aux_info["length"] = message_length.float()  # will be averaged

        logging_strategy = (
            self.train_logging_strategy if self.training else self.test_logging_strategy
        )
        interaction = logging_strategy.filtered_interaction(
            sender_input=sender_input,
            labels=labels,
            receiver_input=receiver_input,
            aux_input=aux_input,
            message=message.detach(),
            receiver_output=receiver_output.detach(),
            message_length=message_length,
            aux=aux_info,
        )

        return optimized_loss, interaction
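The loop over message positions in `forward` keeps only the sender's entropies and log-probabilities up to and including the first EOS. The sketch below shows an equivalent vectorized masking, together with a small re-implementation of the length computation that follows the convention of EGG's `find_lengths` as we understand it (the length counts the first EOS and is capped at the message width). Both helpers are hypothetical re-implementations for illustration; the notebook itself calls `egg.core.util.find_lengths`.

```python
import torch


def lengths_with_eos(messages: torch.Tensor) -> torch.Tensor:
    """Number of symbols up to and including the first EOS (0), or the full width if there is none."""
    max_len = messages.size(1)
    eos_seen = (messages == 0).long().cumsum(dim=1) > 0  # True from the first EOS onwards
    return (max_len - eos_seen.sum(dim=1) + 1).clamp(max=max_len)


def masked_sum(values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Sum of values[:, i] over the positions i < lengths (vectorized version of the not_eosed loop)."""
    positions = torch.arange(values.size(1), device=values.device)
    not_eosed = (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()
    return (values * not_eosed).sum(dim=1)


messages = torch.tensor([[4, 2, 0, 3],
                         [5, 0, 0, 0],
                         [1, 2, 3, 4]])
log_probs = torch.full(messages.shape, -0.5)
lengths = lengths_with_eos(messages)
print(lengths)                         # tensor([3, 2, 4])
print(masked_sum(log_probs, lengths))  # tensor([-1.5000, -1.0000, -2.0000])
```

Broadcasting a position index against the per-sample lengths builds the whole mask at once, which avoids a Python loop over `max_len`.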

#### Instantiating the agents' modules, built on top of the vision module defined above

In [None]:
# Sender / receiver models responsible respectively for producing / processing the initial / final RNN hidden state
sender = Sender(
    vision, 
    input_size=dim_vision_out, 
    output_size=params['s_hidden_size'], 
    hidden_sizes=params['s_hidden_sizes'],
)
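receiver = Receiver(
    input_size=params['r_hidden_size'], 
    # The receiver reconstructs the flattened image: NUM_CHANNELS x side x side values (as expected by the BCE loss)
    output_size=NUM_CHANNELS * int(sizes[params['dataset']] ** 2),
    hidden_sizes=params['r_hidden_sizes'],
)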

# Wrapping the sender / receiver models into a RNN module in order to handle variable-length messages
sender_rnn = core.RnnSenderReinforce(
    agent=sender, 
    vocab_size=params['vocab_size'], 
    embed_dim=params['s_emb_size'], 
    hidden_size=params['s_hidden_size'], 
    cell=params['s_cell'], 
    max_len=params['max_len'], 
    num_layers=params['s_num_layers'],
)
receiver_rnn = core.RnnReceiverDeterministic(
    agent=receiver, 
    vocab_size=params['vocab_size'], 
    embed_dim=params['r_emb_size'], 
    hidden_size=params['r_hidden_size'], 
    cell=params['r_cell'], 
    num_layers=params['r_num_layers'],
)

# Initializing the game wrapper module
game_rnn = SenderReceiverRnnReinforce(
    sender_rnn, 
    receiver_rnn, 
    loss=params['task_loss'][1], 
    sender_entropy_coeff=params['s_entropy_coeff'], 
    receiver_entropy_coeff=params['r_entropy_coeff'],
)

## C. ⚓ Helper functions and callbacks

### Optimization

In [None]:
if False:
    temperature_updater_callback = core.TemperatureUpdater(
        agent=sender, 
        decay=params['temperature_decay'], 
        minimum=params['temperature_minimum'],
        update_frequency=params['temperature_update_freq'],
    )

early_stopping_callback = core.EarlyStopperAccuracy(
    threshold=params['early_stopping_threshold'],
    field_name=params['early_stopping_field'],
    validation=True,
)

### Checkpointing

In [None]:
NAME_CHECKPOINT = f"game={params['task_loss'][0]}_model={params['model_name']}_dataset={params['dataset']}_"
SUB_DIRECTORY = WANDB_EXPERIMENT_GROUP + '/'
CHECKPOINT_DIR = CHECKPOINTS_DIR + SUB_DIRECTORY + NAME_CHECKPOINT + '/'

INTERACTION_DIR = INTERACTIONS_DIR + WANDB_EXPERIMENT_GROUP + '/' + NAME_CHECKPOINT + '/'

In [None]:
checkpoint_saver_callback = core.CheckpointSaver(
    checkpoint_path=CHECKPOINT_DIR,
    checkpoint_freq=params['checkpoint_freq'],
    prefix=NAME_CHECKPOINT,
    max_checkpoints=params['max_checkpoints'],
)

interaction_saver_callback = core.InteractionSaver(
    train_epochs=list(range(1, params['num_epochs'] + 1, 2)),
    test_epochs=list(range(2, params['num_epochs'] + 1, 2)),
    checkpoint_dir=INTERACTION_DIR,
)
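The `train_epochs` / `test_epochs` arguments above interleave the dumps. Training interactions are saved on odd epochs and evaluation interactions on even epochs. The hypothetical helper below (not part of EGG) makes this schedule explicit:

```python
def interaction_dump_schedule(num_epochs):
    # Odd epochs (1, 3, 5, ...) dump training interactions,
    # even epochs (2, 4, 6, ...) dump evaluation interactions
    train_epochs = list(range(1, num_epochs + 1, 2))
    test_epochs = list(range(2, num_epochs + 1, 2))
    return train_epochs, test_epochs
```

With a single epoch, the evaluation list is empty, so no evaluation interactions are written at all.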

### Display and ML experiment management

In [None]:
def wandb_connect():
    wandb_conx = wandb.login(key = WANDB_API_KEY)
    print(f"Connected to Wandb online interface : {wandb_conx}")

LOG_WANDB = True

if LOG_WANDB:
    wandb_connect()

    WANDB_RUN = wandb.init(
        project=WANDB_PROJECT, 
        entity=WANDB_ENTITY,
        notes=WANDB_NOTES,
        name=WANDB_EXPERIMENT_NAME,
        group=WANDB_EXPERIMENT_GROUP,
        save_code=True,
    )

    WANDB_RUN_ID = WANDB_RUN.id

In [None]:
if LOG_WANDB:
    wandb_logger = core.callbacks.WandbLogger(
        opts=params,
        project=WANDB_PROJECT,
        run_id=WANDB_RUN_ID,
    )

    console_logger_callback = core.callbacks.ConsoleLogger(
        as_json=False, 
        print_train_loss=True,
    )

    progress_bar_callback = core.callbacks.ProgressBarLogger(
        n_epochs=params['n_epochs'], 
        use_info_table=False,
    )

### Linguistic analysis metrics

In [None]:
distances = ['edit', 'cosine', 'hamming', 'jaccard', 'euclidean']
if params['all_distances_topsim']:
    distances_topsim = product(distances, distances)
else:
    distances_topsim = [('hamming', 'edit')]
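As we understand EGG's `TopographicSimilarity` callback, topographic similarity is the Spearman correlation between two sets of pairwise distances: those between inputs and those between the corresponding messages. The sketch below reimplements the default `('hamming', 'edit')` pair in plain Python for small toy data. All function names are ours. The sketch assumes equal-length inputs for the Hamming distance and at least two distinct distance values on each side; otherwise the correlation is undefined.

```python
from itertools import combinations


def hamming(a, b):
    # Number of positions where two equal-length sequences differ
    return sum(x != y for x, y in zip(a, b))


def edit_distance(a, b):
    # Levenshtein distance computed with a rolling one-row dynamic program
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return prev[-1]


def average_ranks(values):
    # 1-based ranks; tied values share the mean of their positions
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def spearman(xs, ys):
    # Spearman correlation = Pearson correlation of the rank vectors
    rx, ry = average_ranks(xs), average_ranks(ys)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)


def topographic_similarity(inputs, messages):
    # Compare all unordered pairs of items in input space and message space
    pairs = list(combinations(range(len(inputs)), 2))
    input_d = [hamming(inputs[i], inputs[j]) for i, j in pairs]
    message_d = [edit_distance(messages[i], messages[j]) for i, j in pairs]
    return spearman(input_d, message_d)
```

A perfectly compositional toy language, where each message symbol encodes one input attribute, reaches a topographic similarity of 1.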

In [None]:
disentanglement_metric_callback = core.Disent(
    is_gumbel=False,
    compute_posdis=True,
    compute_bosdis=True,
    vocab_size=params['vocab_size'],
    print_train=True,
    print_test=True,
)

topographic_similarity_metric_callbacks = [
    core.TopographicSimilarity(
        sender_input_distance_fn=input_dist,
        message_distance_fn=message_dist,
        compute_topsim_train_set=True,
        compute_topsim_test_set=True,
        is_gumbel=False,
    )
    for (input_dist, message_dist) in distances_topsim
]

message_entropy_metric_callback = core.MessageEntropy(
    print_train=True,
    is_gumbel=False,
)

### Gathering active callbacks

In [None]:
callbacks=[
           # temperature_updater_callback,
           early_stopping_callback,
           
           checkpoint_saver_callback,
           interaction_saver_callback,
           
           # wandb_logger,
           # console_logger_callback,
           # progress_bar_callback,

           # disentanglement_metric_callback,
           message_entropy_metric_callback,
          ]
          
if LOG_WANDB:
    callbacks.append(wandb_logger)
    callbacks.append(console_logger_callback)
    callbacks.append(progress_bar_callback)
# callbacks += topographic_similarity_metric_callbacks

## D. 🚠 Defining the training routine and optimizers

In [None]:
optimizer = torch.optim.Adam([
        {'params': game_rnn.sender.parameters(), # Sender-side optimization hyperparameters
         'lr': params['s_lr'],
         # 'lr_scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=params['s_lr_scheduler_mult'], eta_min=params['s_lr_scheduler_eta_min'], last_epoch=-1),
         },
        {'params': game_rnn.receiver.parameters(), # Receiver-side optimization hyperparameters
         'lr': params['r_lr'],
         # 'lr_scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=params['r_lr_scheduler_mult'], eta_min=params['r_lr_scheduler_eta_min'], last_epoch=-1),
         },
    ])

trainer = core.Trainer(
    game=game_rnn, 
    optimizer=optimizer, 
    train_data=train_loader, 
    validation_data=val_loader,
    # grad_norm=params['grad_norm'],
    device=device,
    callbacks=callbacks,
)
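The commented-out `'lr_scheduler'` entries above would not schedule anything. A scheduler wraps an optimizer rather than living inside one of its parameter groups, and those lines also refer to `optimizer` while it is still being constructed. For reference, the pure-Python sketch below traces the learning-rate curve that `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts` follows according to its documentation. The function name and argument names are hypothetical.

```python
import math


def cosine_warm_restarts_lr(epoch, eta_max, eta_min=0.0, t_0=1, t_mult=1):
    """Learning rate at a (possibly fractional) epoch under SGDR-style restarts.

    Cycle i lasts t_0 * t_mult**i epochs; within a cycle the rate decays from
    eta_max to eta_min along half a cosine period, then jumps back to eta_max.
    """
    if t_0 <= 0 or t_mult < 1:
        raise ValueError("expected t_0 > 0 and t_mult >= 1")
    t_cur, t_i = epoch, t_0
    # Skip completed cycles to find the position inside the current one
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2
```

One consequence is that with `T_0=1`, `T_mult=1` and one scheduler step per epoch, every epoch starts a new cycle, so the rate stays at its initial value. The schedule only has a visible effect with `T_mult > 1` or with fractional (per-batch) steps.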

In [None]:
# ---------------------------
# ---------
# ---
# -
# trainer.train(
#     n_epochs=5,
# )
# core.close()
# -
# ---
# ---------
# ---------------------------

In [None]:
wandb.finish()

## E. 📺 Instantiating a test dataset and visualization functions

In [None]:
def init_vocab_channel(
    noisy=False,
):
    global VOCAB, CHANNEL

    print(f"Initializing the vocabulary model and the communication channel ... noisy=[{noisy}]")

    if noisy:
        VOCAB = Vocabulary(
            vocab_model=params['vocab_model'],
            vocab_size=params['vocab_size'],
            vocab_loss_c=params['vocab_loss_c'],
            vocab_loss_mute=params['vocab_loss_mute'],
            vocab_mu_dist=params['vocab_mu_dist'],
            vocab_temperature=params['vocab_temperature'],
        )

        CHANNEL = Channel(
            max_len=params['max_len'],
            channel_top_k=params['channel_top_k'],
            channel_temperature=params['channel_temperature'],
            vocabulary=VOCAB,
            p_random_insertion=params['p_random_insertion'],
            p_random_deletion=params['p_random_deletion'],
            p_random_permutation=params['p_random_permutation'],
            p_random_corruption=params['p_random_corruption'],
            corruption_function=params['corruption_function'],   
        )

    else:
        VOCAB = Vocabulary(
            vocab_model='uniform',
            vocab_size=params['vocab_size'],
            vocab_loss_c=None,
            vocab_loss_mute=None,
            vocab_mu_dist=None,
            vocab_temperature=None,
        )

        CHANNEL = Channel(
            max_len=params['max_len'],
            channel_top_k=None,
            channel_temperature=None,
            vocabulary=VOCAB,
            p_random_insertion=0.,
            p_random_deletion=0.,
            p_random_permutation=0.,
            p_random_corruption=0.,
            corruption_function=0.,   
        )
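The noisy branch above delegates the actual corruption to the project's `Channel` class, whose implementation is not shown in this section. To illustrate what the `p_random_*` probabilities control, here is a hypothetical per-token noise model. It ignores `channel_top_k`, `channel_temperature` and the vocabulary model. It also assumes that symbol 0 is reserved for end-of-sequence, as in EGG. It should not be read as the notebook's actual channel.

```python
import random


def noisy_channel(message, vocab_size, max_len, p_insertion=0.0, p_deletion=0.0,
                  p_permutation=0.0, p_corruption=0.0, rng=None):
    """Hypothetical per-token noise model (not the notebook's Channel class).

    Random symbols are drawn from 1..vocab_size-1 because 0 is assumed to be EOS.
    """
    rng = rng if rng is not None else random.Random()
    out = []
    for token in message:
        if rng.random() < p_deletion:
            continue  # the token is lost in transit
        if rng.random() < p_corruption:
            token = rng.randrange(1, vocab_size)  # replaced by a random non-EOS symbol
        out.append(token)
        if rng.random() < p_insertion:
            out.append(rng.randrange(1, vocab_size))  # spurious symbol inserted after it
    # Swap each adjacent pair with probability p_permutation
    for i in range(len(out) - 1):
        if rng.random() < p_permutation:
            out[i], out[i + 1] = out[i + 1], out[i]
    return out[:max_len]  # never exceed the maximum message length
```

With all probabilities at zero, as in the clean branch above, a message passes through unchanged.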

In [None]:
def plot(game, 
         dataloader, 
         name_split='train', 
         noisy=False, 
         ablation=None,
    ):
    global INTERACTIONS_DIR

    init_vocab_channel(noisy)
    
    print(f"Currently generating reconstructed (i.e. autoencoded) images for all class representatives of the following split : [{name_split}]")
    x = next(iter(dataloader))

    if ablation is None:
        filename = INTERACTION_DIR + f'{name_split}_split_noisy_{noisy}_reconstruction'
    else:
        filename = ablation + f'{name_split}_split_noisy_{noisy}_reconstruction'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    test_inputs = []
    for z in range(10):
        # Keep the first sample of class z; classes absent from the batch are skipped
        matches = (x[1] == z).nonzero()
        if matches.shape[0] > 0:
            index = matches[0, 0]
            test_inputs.append(x[0][index].unsqueeze(0))

    test_inputs = torch.cat(test_inputs)

    by_class_dataset = [[test_inputs, None]]

    interaction = core.dump_interactions(game, by_class_dataset, False, variable_length=True)

    for z in range(len(test_inputs)):
        src = interaction.sender_input[z].squeeze(0)
        side = src.shape[1]
        dst = interaction.receiver_output[z].view(side, side)
        # Plot two images side by side: the original (left) and its reconstruction (right)
        image = torch.cat([src, dst], dim=1).cpu().numpy()

        plt.title(f"Input: digit {z}, channel message {interaction.message[z]}")
        plt.imshow(image, cmap='gray')
        plt.savefig(filename + f'_image_{z}.png')  # f-string, so each image gets its own file
        plt.show()
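`plot()` keeps one representative image per class by taking the first batch position whose label equals `z`. Applying the same selection to a plain list of labels, with a hypothetical helper used only for illustration, shows that classes missing from the batch are simply dropped:

```python
def first_index_per_class(labels, num_classes=10):
    """Map each class id to the position of its first occurrence in the batch.

    Classes that do not occur in `labels` are absent from the result.
    """
    first = {}
    for i, label in enumerate(labels):
        if label < num_classes and label not in first:
            first[label] = i
    return dict(sorted(first.items()))
```

Because absent classes are skipped, the `z`-th reconstructed image in `plot()` corresponds to digit `z` only when every class appears in the sampled batch.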

In [None]:
# plot(game_rnn, train_loader, name_split='train', noisy=False, ablation=None)
# plot(game_rnn, train_loader, name_split='train', noisy=True, ablation=None)

In [None]:
# plot(game_rnn, val_loader, name_split='val', noisy=False, ablation=None)
# plot(game_rnn, val_loader, name_split='val', noisy=True, ablation=None)

In [None]:
# plot(game_rnn, test_loader, name_split='test', noisy=False, ablation=None)
# plot(game_rnn, test_loader, name_split='test', noisy=True, ablation=None)

In [None]:
def generate_interactions(
    game, 
    dataloader, 
    name_split='train', 
    noisy=False, 
    ablation=None,
    until=None
):
    global INTERACTIONS_DIR

    init_vocab_channel(noisy)

    def dump_to_csv(data, name_data='messages_tokenized'):
        # One file per representation, so that messages_char does not overwrite messages_tokenized
        if ablation is None:
            filename = INTERACTION_DIR + f'{name_split}_split_noisy_{noisy}_{name_data}.csv'
        else:
            filename = ablation + f'{name_split}_split_noisy_{noisy}_{name_data}.csv'
        print(filename)
        # Create the target directory first if it does not exist yet
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, 'w', newline='') as file:
            csv.writer(file).writerows(data)

    print(f"Currently generating batches of textual interactions for the following split : [{name_split}]")

    __messages__ = []

    if (until is None) or (until > len(dataloader)):
        until = len(dataloader)

    # Stop after `until` batches without materializing the whole dataloader in memory
    for i, x in enumerate(dataloader):
        if i >= until:
            break
        if i % 10 == 0:
            print(f"Loading batch n°{i}/{until} ...")
        interaction = core.dump_interactions(game, [[x[0], None]], False, variable_length=True)

        __messages__.append(interaction.message)

    print("Successfully generated batches of textual interactions")

    __messages__ = torch.cat(__messages__)[:, :-1]

    messages_tokenized = []
    messages_char = []

    for i in range(__messages__.shape[0]):
        messages_tokenized.append(__messages__[i, :].tolist())

    if params['vocab_size'] <= 26:
        __messages_char_shifted__ = []
        make_readable_communication(__messages__, __messages_char_shifted__, __messages__.shape[0], __messages__.shape[1], lower=True)
        for i, message in enumerate(__messages_char_shifted__):
            messages_char.append(list(message))

        res = {
            'messages_tokenized': messages_tokenized,
            'messages_char': messages_char,
        }
    else:
        print("We are sorry, but the vocabulary size is too large to map the arbitrary symbols to natural characters (for further post-processing)")
        res = {
            'messages_tokenized': messages_tokenized,
        }

    for (k, v) in res.items():
        dump_to_csv(data=v, name_data=k)
        print(f"Successfully saved {k} to a CSV file ...")

    return res
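When `vocab_size <= 26`, `generate_interactions()` relies on `make_readable_communication`, defined elsewhere in the notebook, to turn token ids into letters. The sketch below assumes the simplest such mapping, where token `t` becomes the `t`-th lowercase letter; the notebook's helper may differ. It also shows the CSV rows that `csv.writer` produces:

```python
import csv
import io
import string


def tokens_to_letters(message):
    # Hypothetical mapping: token t -> t-th lowercase letter (requires t < 26)
    return [string.ascii_lowercase[t] for t in message]


def rows_to_csv(rows):
    # Serialize one message per row, as dump_to_csv does with csv.writer
    buffer = io.StringIO()
    csv.writer(buffer).writerows(rows)
    return buffer.getvalue()
```

`csv.writer` terminates rows with `\r\n` by default. This is why the Python `csv` documentation recommends opening output files with `newline=''`.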

In [None]:
# res = generate_interactions(game_rnn, train_loader, name_split='train', noisy=False, ablation=None)
# res = generate_interactions(game_rnn, train_loader, name_split='train', noisy=True, ablation=None, until=100)

In [None]:
# res = generate_interactions(game_rnn, val_loader, name_split='val', noisy=False, ablation=None)
# res = generate_interactions(game_rnn, val_loader, name_split='val', noisy=True, ablation=None, until=100)

In [None]:
# res = generate_interactions(game_rnn, test_loader, name_split='test', noisy=False, ablation=None)
# res = generate_interactions(game_rnn, test_loader, name_split='test', noisy=True, ablation=None, until=100)

In [None]:
def test2by2(
    game, 
    dataloader, 
    name_split='train', 
    noisy=False, 
    ablation=None,
):
    global INTERACTIONS_DIR

    init_vocab_channel(noisy)

    print(f"Currently generating the grid of receiver outputs for all possible length-2 messages for the following split : [{name_split}]")

    if ablation is None:
        filename = INTERACTION_DIR + f'{name_split}_split_noisy_{noisy}_2by2'
    else:
        filename = ablation + f'{name_split}_split_noisy_{noisy}_2by2'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    if params['vocab_size'] <= 15:
        f, ax = plt.subplots(params['vocab_size'], params['vocab_size'], sharex=True, sharey=True)

        for x in range(params['vocab_size']):
            for y in range(params['vocab_size']):
                    
                t = torch.zeros((1, 2)).to(device).long()
                t[0, 0] = x
                t[0, 1] = y

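                with torch.no_grad():
                    # Query the receiver of the game passed as argument (not the global game_rnn)
                    sample = game.receiver(t)[0].float().cpu()
                    side = sizes[params['dataset']]
                    sample = sample[0, :].view(side, side)
                    ax[x][y].imshow(sample, cmap='gray')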
                    
                    if y == 0:
                        ax[x][y].set_ylabel(f'x={x}')
                    if x == 0:
                        ax[x][y].set_title(f'y={y}')
                    
                    ax[x][y].set_yticklabels([])
                    ax[x][y].set_xticklabels([])

        # Save before showing: plt.show() releases the figure in notebook backends
        if ablation is not None:
            plt.savefig(filename + '.png')
        plt.show()

    else:
        print("We are sorry, but the vocabulary size is too big to perform a matrix study of interactions")
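`test2by2()` feeds every length-2 message `(x, y)` to the receiver, so the figure has `vocab_size ** 2` panels. This is why it is restricted to `vocab_size <= 15` (at most 225 subplots). The helpers below, whose names are hypothetical, enumerate that message set and count the full message space up to a given length:

```python
from itertools import product


def all_messages(vocab_size, length):
    # Every sequence of `length` symbols over {0, ..., vocab_size - 1}, in grid order
    return [list(m) for m in product(range(vocab_size), repeat=length)]


def message_space_size(vocab_size, max_len):
    # Number of distinct sequences of length 1..max_len (ignoring EOS semantics)
    return sum(vocab_size ** k for k in range(1, max_len + 1))
```

If, as in EGG, symbol 0 acts as end-of-sequence, every message starting with 0 is read only up to that first symbol. We would therefore expect the whole `x=0` row of the grid to show the same image.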

In [None]:
# test2by2(game_rnn, train_loader, name_split='train', noisy=False, ablation=None)
# test2by2(game_rnn, train_loader, name_split='train', noisy=True, ablation=None)

In [None]:
# test2by2(game_rnn, val_loader, name_split='val', noisy=False, ablation=None)
# test2by2(game_rnn, val_loader, name_split='val', noisy=True, ablation=None)

In [None]:
# test2by2(game_rnn, test_loader, name_split='test', noisy=False, ablation=None)
# test2by2(game_rnn, test_loader, name_split='test', noisy=True, ablation=None)

## F. ⛵ Training the agents ! 

In [None]:
# ---------------------------
# ---------
# ---
# -
# trainer.train(
#     n_epochs=5,
# )
# core.close()
# -
# ---
# ---------
# ---------------------------

## G. 🏃 *Optional : running hyperparameter optimization on the game parameters*

#### Note : to run HPO on the communication game studied here, use the NEST module shipped with the EGG framework.
*(Since NEST is cumbersome to use from Jupyter notebooks, the following section instead defines our own HPO routine, which can run on Google Colab's cloud GPUs)*

In [None]:
"""
%%writefile example.json
{
  "vocab_size": [10, 15],
  "n_epochs": [15],
  "random_seed": [0, 1, 2, 3],
  "batch_size": [256]
}
"""
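A NEST sweep file such as the one above describes a grid. Each key lists candidate values, and one run is launched per element of their Cartesian product. Assuming these grid semantics, the expansion can be sketched as follows (the function name is hypothetical). The result could drive a hand-written HPO loop in Colab:

```python
import json
from itertools import product


def expand_grid(grid):
    """Yield one configuration dict per point of a NEST-style sweep grid."""
    keys = sorted(grid)  # a fixed key order makes the enumeration reproducible
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


sweep = json.loads("""
{
  "vocab_size": [10, 15],
  "n_epochs": [15],
  "random_seed": [0, 1, 2, 3],
  "batch_size": [256]
}
""")
configs = list(expand_grid(sweep))  # 2 * 1 * 4 * 1 = 8 configurations
```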

## H. 🕸 Ablation study : Running multiple experiments on **vision models**, **datasets** and **reward functions**

Throughout the ablation study below, we use the following default hyperparameters (in addition to the optimization and architectural hyperparameters, whose default values are defined above) :

1. Default *model* : **ResNet-50** (for the *ResNet* family) and **EfficientNet-B3** (for the *EfficientNet* family)

2. Default *dataset* : **CIFAR-100** (for the *natural images* family) and **MNIST** (for the *numbers* family)

3. Default *reward function* : **Sequence length cost**
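The sequence length cost adds a penalty proportional to the number of emitted symbols. Among messages that solve the task about equally well, shorter ones are therefore cheaper. The framework-free sketch below assumes a linear penalty `length_cost * message_length`, like the `length_loss` term of EGG-style REINFORCE wrappers. The function names are hypothetical.

```python
def length_penalized_cost(task_loss, message_length, length_cost):
    # Cost of one message: task loss plus a linear penalty per emitted symbol
    return task_loss + length_cost * message_length


def cheapest_message(candidates, length_cost):
    # candidates: list of (message, task_loss) pairs; returns the lowest-cost message
    return min(
        candidates,
        key=lambda c: length_penalized_cost(c[1], len(c[0]), length_cost),
    )[0]
```

Without a length cost, the slightly more accurate long message wins. With a small penalty, the shorter message becomes preferable.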

#### *Helper function to run the ablation on a given parameter*

In [None]:
#@title Decide whether to test ablation on a single batch or to run the full ablation study { run: "auto", vertical-output: true, form-width: "65%", display-mode: "code" }

option = "Full ablation" #@param ["Single batch", "Full ablation"]

if option == "Single batch":
    TEST_ONE_BATCH = True
    LOG_TO_WANDB = False
else:
    TEST_ONE_BATCH = False
    LOG_TO_WANDB = True

In [None]:
# Use "raise StopExecution" in order to interrupt a cell
class StopExecution(Exception):
    def _render_traceback_(self):
        pass

ablation = {}

WANDB_EXPERIMENT_GROUP = 'ablation-study'

In [None]:
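def run_ablation_study(params, parameter='model_name', model=None, add_name=None, **kwargs):
    # We copy the default hyperparameters so that the ablation never mutates them
    params = copy.deepcopy(params)
    params.update(kwargs)
    if model is not None:
        params['model_name'] = model  # honour the `model` argument used by the cells below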

    values = ablation[parameter]

    print(f"\nPerforming an ablation study on the following parameter : [{parameter}]\n")
    print(f"The following experiments will be carried out :")
    for (i, v) in enumerate(values):
        print(f"- [{i + 1}/{len(values)}] {parameter}='{v}'")

    def single_run(parameter, value, dataset):
        global NUM_CHANNELS, VOCAB, CHANNEL, INTERACTIONS_DIR

        print(params)

        NUM_CHANNELS = n_channels[params['dataset']]

        value = params[parameter]
        model_name = params['model_name']
        dataset = params['dataset']

        """
        if not (parameter in ['model_name', 'dataset']):
            if add_name is None:
                NAME_RUN = f'Ablation study : P[{parameter}], V[{value}], M[{model_name}], D[{dataset}]'
                NAME_CHECKPOINT = f'param={parameter}_value={value}_model={model_name}_dataset={dataset}'
            else:
                NAME_RUN = f'Ablation study : P[{parameter}], V[{value}], M[{model_name}], D[{dataset}], A[{add_name}]'
                NAME_CHECKPOINT = f'param={parameter}_value={value}_model={model_name}_dataset={dataset}_cond={add_name}'
        else:
            if add_name is None:
                NAME_RUN = f'Ablation study : M[{model_name}], D[{dataset}]'
                NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}'
            else:
                NAME_RUN = f'Ablation study : M[{model_name}], D[{dataset}], A[{add_name}]'
                NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}_cond={add_name}'
        """

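        if not (parameter in ['model_name', 'dataset']):
            if add_name is None:
                NAME_RUN = f'Ablation study : P[{parameter}], V[{value}]'
                NAME_CHECKPOINT = f'param={parameter}_value={value}'
            else:
                NAME_RUN = f'Ablation study : P[{parameter}], V[{value}], A[{add_name}]'
                NAME_CHECKPOINT = f'param={parameter}_value={value}_cond={add_name}'
        else:
            # Model / dataset ablations are named after the model and dataset themselves
            if add_name is None:
                NAME_RUN = f'Ablation study : M[{model_name}], D[{dataset}]'
                NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}'
            else:
                NAME_RUN = f'Ablation study : M[{model_name}], D[{dataset}], A[{add_name}]'
                NAME_CHECKPOINT = f'model={model_name}_dataset={dataset}_cond={add_name}'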

        SUB_DIRECTORY = WANDB_EXPERIMENT_GROUP + '/'
        DIR_CHECKPOINT = CHECKPOINTS_DIR + SUB_DIRECTORY + NAME_CHECKPOINT + '/'
        WANDB_EXPERIMENT_NAME = NAME_RUN

        print(f"NAME_RUN={NAME_RUN}")
        print(f"NAME_CHECKPOINT={NAME_CHECKPOINT}")
        print(f"DIR_CHECKPOINT={DIR_CHECKPOINT}")

        # -------------------------------------------------------- #
        # 1. Instantiating the Model using the finetuned version. #
        # ------------------------------------------------------ #
        """
        model, dim_vision_out = retrieve_finetuned_vision_model(
            params['model_name'], 
            params['dataset'], 
            add_name=add_name)
        
        model = model.to(device) # Moving the model to the GPU, if it is available
        """

        # ------------------------------- #
        # 2. Instantiating the Datasets. #
        # ----------------------------- #
        # a. Creating the appropriate data augmentation scheme (image-level transform operators)
        multi_channel = not params['dataset'] in ['mnist', 'k-mnist', 'q-mnist', 'fashion-mnist']
        transform_train = TransformsAugment(size=sizes[params['dataset']], multi_channel=multi_channel)
        transform_val_test = TransformsAugment(size=sizes[params['dataset']], multi_channel=multi_channel)

        # b. Loading the right dataset split
        """
        split = image_datasets[params['dataset']](
            transform_train, 
            transform_val_test,
        )
        train_loader, val_loader, test_loader = split.values()
        """

        # ----------------------------- #
        # 3. Instantiating the Agents. #
        # --------------------------- #
        # Sender / receiver models responsible respectively for producing / processing the initial / final RNN hidden state
        sender = Sender(
            vision, 
            input_size=dim_vision_out, 
            output_size=params['s_hidden_size'], 
            hidden_sizes=params['s_hidden_sizes'],
        )
        receiver = Receiver(
            input_size=params['r_hidden_size'],
            output_size=int(sizes[params['dataset']] ** 2),
            hidden_sizes=params['r_hidden_sizes'],
        )

        # Wrapping the sender / receiver models into a RNN module in order to handle variable-length messages
        sender_rnn = core.RnnSenderReinforce(
            agent=sender, 
            vocab_size=params['vocab_size'], 
            embed_dim=params['s_emb_size'], 
            hidden_size=params['s_hidden_size'], 
            cell=params['s_cell'], 
            max_len=params['max_len'], 
            num_layers=params['s_num_layers'],
        )
        receiver_rnn = core.RnnReceiverDeterministic(
            agent=receiver, 
            vocab_size=params['vocab_size'], 
            embed_dim=params['r_emb_size'], 
            hidden_size=params['r_hidden_size'], 
            cell=params['r_cell'], 
            num_layers=params['r_num_layers'],
        )

        # --------------------------------- #
        # 4. Instantiating the Vocabulary. #
        # ------------------------------- #
        VOCAB = Vocabulary(
            vocab_model=params['vocab_model'],
            vocab_size=params['vocab_size'],
            vocab_loss_c=params['vocab_loss_c'],
            vocab_loss_mute=params['vocab_loss_mute'],
            vocab_mu_dist=params['vocab_mu_dist'],
            vocab_temperature=params['vocab_temperature'],
        )



        # -------------------------------------------------- #
        # 5. Instantiating the Noisy Communication Channel. #
        # ------------------------------------------------ #
        CHANNEL = Channel(
            max_len=params['max_len'],
            channel_top_k=params['channel_top_k'],
            channel_temperature=params['channel_temperature'],
            vocabulary=VOCAB,
            p_random_insertion=params['p_random_insertion'],
            p_random_deletion=params['p_random_deletion'],
            p_random_permutation=params['p_random_permutation'],
            p_random_corruption=params['p_random_corruption'],
            corruption_function=params['corruption_function'],   
        )

        # --------------------------------------- #
        # 6. Instantiating the Game Environment. #
        # ------------------------------------- #
        game_rnn = SenderReceiverRnnReinforce(
            sender_rnn, 
            receiver_rnn, 
            loss=params['task_loss'][1], 
            sender_entropy_coeff=params['s_entropy_coeff'], 
            receiver_entropy_coeff=params['r_entropy_coeff'],
        )

        # ----------------------------- #
        # 7. Instantiating a WandB run #
        # --------------------------- #
        WANDB_RUN = wandb.init(
            project=WANDB_PROJECT, 
            entity=WANDB_ENTITY,
            notes=WANDB_NOTES,
            name=WANDB_EXPERIMENT_NAME,
            group=WANDB_EXPERIMENT_GROUP,
            save_code=True,
        )
        WANDB_RUN_ID = WANDB_RUN.id

        # -------------------------------- #
        # 7. Instantiating the Callbacks. #
        # ------------------------------ #
        checkpoint_saver_callback = core.CheckpointSaver(
            checkpoint_path=DIR_CHECKPOINT,
            checkpoint_freq=params['checkpoint_freq'],
            prefix=NAME_CHECKPOINT,
            max_checkpoints=params['max_checkpoints'],
        )
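        interaction_saver_callback = core.InteractionSaver(
            train_epochs=list(range(1, params['num_epochs'] + 1, 2)),
            test_epochs=list(range(2, params['num_epochs'] + 1, 2)),
            # Per-run directory, so that successive ablation runs do not overwrite each other
            checkpoint_dir=INTERACTIONS_DIR + SUB_DIRECTORY + NAME_CHECKPOINT + '/',
        )
        # Attach the logger to the run created above rather than to the notebook-level run
        wandb_logger = core.callbacks.WandbLogger(
            opts=params,
            project=WANDB_PROJECT,
            run_id=WANDB_RUN_ID,
        )
        callbacks=[
           # early_stopping_callback, # Optimization-related callbacks
           
           checkpoint_saver_callback, # Checkpointing callbacks
           interaction_saver_callback,
           
           wandb_logger, # ML experiment logging and management callbacks
           console_logger_callback,
           progress_bar_callback,

           # disentanglement_metric_callback, # Linguistic analysis of the interactions
           message_entropy_metric_callback,
        ]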
          
        # callbacks += topographic_similarity_metric_callbacks

        # -------------------------------- #
        # 8. Instantiating the Optimizer. #
        # ------------------------------ #
        optimizer = torch.optim.Adam([
            {'params': game_rnn.sender.parameters(), # Sender-side optimization hyperparameters
             'lr': params['s_lr'],
            },
            {'params': game_rnn.receiver.parameters(), # Receiver-side optimization hyperparameters
             'lr': params['r_lr'],
            },
        ])

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer, 
            T_0=1, 
            T_mult=params['s_lr_scheduler_mult'], 
            eta_min=params['s_lr_scheduler_eta_min'], 
            last_epoch=-1,
        )  # no trailing comma here, otherwise `scheduler` would be a 1-tuple

        # ------------------------------ #
        # 9. Instantiating the Trainer. #
        # ---------------------------- #
        trainer = core.Trainer(
            game=game_rnn, 
            optimizer=optimizer,
            optimizer_scheduler=scheduler,  # stepped by core.Trainer at the end of each training epoch
            train_data=train_loader, 
            validation_data=val_loader,
            # grad_norm=params['grad_norm'],
            device=device,
            callbacks=callbacks,
        )

        # ----------------------------------------------------- #
        # 10. Instantiating the Predictions logging functions. #
        # --------------------------------------------------- #

        # --------------------------- #
        # 11. Training the agents.   #
        # ------------------------- #
        trainer.train(
            n_epochs=params['num_epochs'],
        )
        core.close()

        # ---------------------------------------------- #
        # 12. Post-hoc linguistic analysis subroutines. #
        # -------------------------------------------- #

        ablation = INTERACTIONS_DIR + 'ablation-study' + '/' + WANDB_EXPERIMENT_NAME + '/'

        print(ablation)



        # A. Evaluating the final performance of the communicating agents on the training set
        init_vocab_channel(noisy=True)
        train_loss_noisy, _ = trainer.eval(data=train_loader)
        print(f"Train loss noisy : {train_loss_noisy}")
        init_vocab_channel(noisy=False)
        train_loss_clean, _ = trainer.eval(data=train_loader)
        print(f"Train loss clean : {train_loss_clean}")

        wandb.log({
            'train_loss_noisy': train_loss_noisy,
            'train_loss_clean': train_loss_clean,
        })

        plot(game_rnn, train_loader, name_split='train', noisy=False, ablation=ablation)
        plot(game_rnn, train_loader, name_split='train', noisy=True, ablation=ablation)  

        res = generate_interactions(game_rnn, train_loader, name_split='train', noisy=False, ablation=ablation)
        res = generate_interactions(game_rnn, train_loader, name_split='train', noisy=True, ablation=ablation, until=200)  

        test2by2(game_rnn, train_loader, name_split='train', noisy=False, ablation=ablation)
        test2by2(game_rnn, train_loader, name_split='train', noisy=True, ablation=ablation)



        # B. Evaluating the final performance of the communicating agents on the validation set
        init_vocab_channel(noisy=True)
        val_loss_noisy, _ = trainer.eval(data=val_loader)
        print(f"Validation loss noisy : {val_loss_noisy}")
        init_vocab_channel(noisy=False)
        val_loss_clean, _ = trainer.eval(data=val_loader)
        print(f"Validation loss clean : {val_loss_clean}")

        wandb.log({
            'val_loss_noisy': val_loss_noisy,
            'val_loss_clean': val_loss_clean,
        })

        plot(game_rnn, val_loader, name_split='val', noisy=False, ablation=ablation)
        plot(game_rnn, val_loader, name_split='val', noisy=True, ablation=ablation) 

        res = generate_interactions(game_rnn, val_loader, name_split='val', noisy=False, ablation=ablation)
        res = generate_interactions(game_rnn, val_loader, name_split='val', noisy=True, ablation=ablation, until=200)

        test2by2(game_rnn, val_loader, name_split='val', noisy=False, ablation=ablation)
        test2by2(game_rnn, val_loader, name_split='val', noisy=True, ablation=ablation)



        # C. Evaluating the final performance of the communicating agents on the test set
        init_vocab_channel(noisy=True)
        test_loss_noisy, _ = trainer.eval(data=test_loader)
        print(f"Test loss noisy : {test_loss_noisy}")
        init_vocab_channel(noisy=False)
        test_loss_clean, _ = trainer.eval(data=test_loader)
        print(f"Test loss clean : {test_loss_clean}")

        wandb.log({
            'test_loss_noisy': test_loss_noisy,
            'test_loss_clean': test_loss_clean,
        })

        plot(game_rnn, test_loader, name_split='test', noisy=False, ablation=ablation)
        plot(game_rnn, test_loader, name_split='test', noisy=True, ablation=ablation)

        res = generate_interactions(game_rnn, test_loader, name_split='test', noisy=False, ablation=ablation)
        res = generate_interactions(game_rnn, test_loader, name_split='test', noisy=True, ablation=ablation, until=200)

        test2by2(game_rnn, test_loader, name_split='test', noisy=False, ablation=ablation)
        test2by2(game_rnn, test_loader, name_split='test', noisy=True, ablation=ablation)

        # Close this run so that the next ablation value starts a fresh WandB run
        wandb.finish()

    for (i, v) in enumerate(values):
        print('\n\n# ' + '-' * 70 + ' #')
        print(f"  Running ablation experiment [{i + 1}/{len(values)}]  : {parameter}='{v}'")
        print('# ' + '-' * 70 + ' #')

        params[parameter] = v

        t_start = time.time()
        single_run(parameter=parameter, value=v, dataset='mnist')
        t_end = time.time()
        print(f"Time elapsed : {(t_end - t_start):.2f}s !\n\n") 


params['num_epochs'] = 1
s_hidden_sizes = [
                [400],
                [400, 500],
                [400, 600, 500],
                [400, 600, 800, 500],
                [400, 800],
]
ablation['s_hidden_sizes'] = s_hidden_sizes
run_ablation_study(params, parameter='s_hidden_sizes')
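Each ablation run starts from a deep copy of the default hyperparameters, overrides a single key, and is named after that key/value pair. Stripped of the training code, the pattern reduces to the sketch below. The helper names are hypothetical; `run_name` reproduces the `NAME_RUN` format used in `single_run`.

```python
import copy


def ablation_configs(defaults, parameter, values):
    """One config per value: a deep copy of the defaults with a single key overridden."""
    configs = []
    for value in values:
        cfg = copy.deepcopy(defaults)  # never mutate the shared defaults
        cfg[parameter] = value
        configs.append(cfg)
    return configs


def run_name(parameter, value, cond=None):
    # Same naming pattern as NAME_RUN in single_run()
    name = f"Ablation study : P[{parameter}], V[{value}]"
    return name if cond is None else f"{name}, A[{cond}]"
```

The deep copy matters because several values, such as `s_hidden_sizes`, are lists. A shallow copy would share those lists with the defaults.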


In [None]:
!ls "drive/My Drive/Projects/nlp_emergent_languages/interactions/ablation-study/"

In [None]:
# INTERACTIONS_DIR

In [None]:
# !ls "drive/My Drive/Projects/nlp_emergent_languages/interactions/ablation-study/Ablation study : P[s_hidden_sizes], V[[400]]"

#### **Vision models**

💿 `param = model_name`

In [None]:
ablation['model_name'] = list(resnets.keys()) + list(efficientnets.keys())

run_ablation_study(params, parameter='model_name')

#### **Datasets**

💿 `param = dataset`

In [None]:
ablation['dataset'] = image_datasets.keys()

run_ablation_study(params, parameter='dataset', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='dataset', model='efficientnet-b3')

#### **Random seed**

💿 `param = random_seed`

In [None]:
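# Derive one reproducible integer seed per word (random.seed() itself returns None)
get_seed = lambda s: random.Random(s).randrange(2 ** 31)

# A list rather than a lazy map object, since run_ablation_study calls len() on it
ablation['random_seed'] = list(map(get_seed, [
                                         'emergent',
                                         'languages',
                                         'are',
                                         'very',
                                         'cool',
                                         ',',
                                         'yes',
                                         'indeed',
                                         '!'
]))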

run_ablation_study(params, parameter='random_seed', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='random_seed', model='efficientnet-b3')

#### *Stochastic grid search* : **Sender** & **Receiver**-level architectural parameters

In [None]:
s_hidden_sizes = [
                [400],
                [400, 500],
                [400, 600, 500],
                [400, 600, 800, 500],
                [400, 800],
]
r_hidden_sizes = s_hidden_sizes

s_hidden_size = [10, 20, 40, 80]
r_hidden_size = s_hidden_size

s_emb_size = [5, 10, 20, 40]
r_emb_size = s_emb_size

s_cell = ['lstm', 'rnn', 'gru']
r_cell = s_cell

s_num_layers = [1, 2, 3]
r_num_layers = s_num_layers

ablation['s_hidden_sizes'] = s_hidden_sizes
ablation['r_hidden_sizes'] = r_hidden_sizes
ablation['s_hidden_size'] = s_hidden_size
ablation['r_hidden_size'] = r_hidden_size
ablation['s_emb_size'] = s_emb_size
ablation['r_emb_size'] = r_emb_size
ablation['s_cell'] = s_cell
ablation['r_cell'] = r_cell
ablation['s_num_layers'] = s_num_layers
ablation['r_num_layers'] = r_num_layers
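The cells below sweep one parameter at a time. The full joint space defined above has 5 × 5 × 4 × 4 × 4 × 4 × 3 × 3 × 3 × 3 = 518,400 configurations, so an exhaustive grid is out of reach. A stochastic search would instead draw a fixed number of distinct configurations at random. Below is a minimal sketch. `sample_configurations` is a hypothetical helper that is not part of the notebook, and it assumes the search space is a dict of value lists like `ablation`.

```python
import random
from typing import Any, Dict, List


def sample_configurations(space: Dict[str, List[Any]], n_samples: int, seed: int = 0) -> List[Dict[str, Any]]:
    """Draw `n_samples` distinct configurations uniformly at random from the Cartesian product of `space`."""
    rng = random.Random(seed)
    total = 1
    for choices in space.values():
        total *= len(choices)
    n_samples = min(n_samples, total)  # cannot draw more distinct configurations than exist

    seen, configs = set(), []
    while len(configs) < n_samples:
        # Pick one index per parameter; reject index tuples that were already drawn
        idx = tuple(rng.randrange(len(choices)) for choices in space.values())
        if idx in seen:
            continue
        seen.add(idx)
        configs.append({k: choices[i] for (k, choices), i in zip(space.items(), idx)})
    return configs


# Usage on a small sub-space (3 x 3 x 3 = 27 combinations)
space = {'s_cell': ['lstm', 'rnn', 'gru'], 'r_cell': ['lstm', 'rnn', 'gru'], 's_num_layers': [1, 2, 3]}
for config in sample_configurations(space, n_samples=5, seed=42):
    print(config)
```

Sampling index tuples rather than values keeps the rejection check hashable even when a value is itself a list (as with `s_hidden_sizes`).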

💿 `param = hidden_sizes`

In [None]:
params['num_epochs'] = 10
run_ablation_study(params, parameter='s_hidden_sizes', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_hidden_sizes', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_hidden_sizes', model='efficientnet-b3')

In [None]:
run_ablation_study(params, parameter='r_hidden_sizes', model='efficientnet-b3')

💿 `param = hidden_size`

In [None]:
run_ablation_study(params, parameter='s_hidden_size', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_hidden_size', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_hidden_size', model='efficientnet-b3')

In [None]:
run_ablation_study(params, parameter='r_hidden_size', model='efficientnet-b3')

💿 `param = emb_size`

In [None]:
run_ablation_study(params, parameter='s_emb_size', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_emb_size', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_emb_size', model='efficientnet-b3')

In [None]:
run_ablation_study(params, parameter='r_emb_size', model='efficientnet-b3')

💿 `param = cell`

In [None]:
run_ablation_study(params, parameter='s_cell', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_cell', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_cell', model='efficientnet-b3')

In [None]:
run_ablation_study(params, parameter='r_cell', model='efficientnet-b3')

💿 `param = num_layers`

In [None]:
run_ablation_study(params, parameter='s_num_layers', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_num_layers', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_num_layers', model='efficientnet-b3')

In [None]:
run_ablation_study(params, parameter='r_num_layers', model='efficientnet-b3')

In [None]:
# Sender-specific parameters
s_entropy_coeff = [0., 0.0015, 0.015, 0.15]
s_lr = [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
ablation['s_entropy_coeff'] = s_entropy_coeff
ablation['s_lr'] = s_lr

In [None]:
# Receiver-specific parameters
r_entropy_coeff = [0., 0.0015, 0.015, 0.15]
r_lr = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1]
ablation['r_entropy_coeff'] = r_entropy_coeff
ablation['r_lr'] = r_lr

💿 `param = s_lr`

In [None]:
run_ablation_study(params, parameter='s_lr', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_lr', model='efficientnet-b3')

💿 `param = r_lr`

In [None]:
run_ablation_study(params, parameter='r_lr', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_lr', model='efficientnet-b3')

💿 `param = s_entropy_coeff`

In [None]:
run_ablation_study(params, parameter='s_entropy_coeff', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='s_entropy_coeff', model='efficientnet-b3')

💿 `param = r_entropy_coeff`

In [None]:
run_ablation_study(params, parameter='r_entropy_coeff', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='r_entropy_coeff', model='efficientnet-b3')

#### **Reward functions** (*i.e.* communication channel **stochasticity levels** & token-level **similarity mappings**)

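Here, we investigate the extent to which the choice and ordering of the corruption functions applied by our communication channel's transfer function impact the learning performance of our agents. Each configuration below lists the corruption functions in the order in which they are applied; the empty list disables all of them.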
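To see why the order can matter, here is a minimal sketch of a channel that composes per-token corruption functions left to right. The exact semantics are assumptions, not the project's implementation. Insertion adds a random token after each position, deletion drops tokens, permutation swaps adjacent tokens, and corruption replaces tokens, each with its own probability (playing the role of the notebook's `p_random_*` values). End-of-sequence handling and `max_len` truncation are ignored. All function names are hypothetical.

```python
import random
from typing import Callable, Dict, List, Sequence

Message = List[int]


def random_insertion(msg: Message, p: float, vocab_size: int, rng: random.Random) -> Message:
    # After each token, insert a uniformly random token with probability p
    out: Message = []
    for tok in msg:
        out.append(tok)
        if rng.random() < p:
            out.append(rng.randrange(vocab_size))
    return out


def random_deletion(msg: Message, p: float, vocab_size: int, rng: random.Random) -> Message:
    # Drop each token independently with probability p (vocab_size is unused)
    return [tok for tok in msg if rng.random() >= p]


def random_permutation(msg: Message, p: float, vocab_size: int, rng: random.Random) -> Message:
    # Left to right, swap each position with its right neighbour with probability p
    out = list(msg)
    for i in range(len(out) - 1):
        if rng.random() < p:
            out[i], out[i + 1] = out[i + 1], out[i]
    return out


def random_corruption(msg: Message, p: float, vocab_size: int, rng: random.Random) -> Message:
    # Replace each token by a uniformly random token with probability p
    return [rng.randrange(vocab_size) if rng.random() < p else tok for tok in msg]


CORRUPTIONS: Dict[str, Callable[..., Message]] = {
    'insertion': random_insertion,
    'deletion': random_deletion,
    'permutation': random_permutation,
    'corruption': random_corruption,
}


def noisy_channel(msg: Message, order: Sequence[str], p: Dict[str, float],
                  vocab_size: int, seed: int = 0) -> Message:
    """Apply the named corruption functions left to right; each step sees the previous step's output."""
    rng = random.Random(seed)
    for name in order:
        msg = CORRUPTIONS[name](msg, p[name], vocab_size, rng)
    return msg


# With p = 1 and vocab_size = 1 the channel is deterministic (every inserted token is 0)
always = {'insertion': 1.0, 'permutation': 1.0}
print(noisy_channel([1, 2, 3], ['insertion', 'permutation'], always, vocab_size=1))  # [0, 2, 0, 3, 0, 1]
print(noisy_channel([1, 2, 3], ['permutation', 'insertion'], always, vocab_size=1))  # [2, 0, 3, 0, 1, 0]
```

The same two functions applied in opposite orders yield different messages. Permuting after insertion also moves the inserted noise tokens, which is why the orderings are treated as distinct ablation settings.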

💿 `param = corruption_function`

In [None]:
params['p_random_insertion'] = 0.05
params['p_random_deletion'] = 0.05
params['p_random_permutation'] = 0.05
params['p_random_corruption'] = 0.05
ablation['corruption_function'] = [
    ['corruption', 'permutation', 'deletion', 'insertion'],
    ['corruption', 'permutation', 'deletion'],
    ['corruption', 'permutation'],
    ['insertion', 'deletion', 'permutation', 'corruption'],
    ['insertion', 'deletion', 'permutation'],
    ['insertion', 'deletion'],
    ['insertion'],
    ['deletion'],
    ['permutation'],
    []
]

run_ablation_study(params, parameter='corruption_function', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='corruption_function', model='efficientnet-b3')

#### **Vocabulary**-level parameters

💿 `param = vocab_size`

In [None]:
ablation['vocab_size'] = [5, 10, 20, 40, 80, 160, 320, 640, 1_280, 2_560, 5_120, 10_240, 20_480, 40_960]

run_ablation_study(params, parameter='vocab_size', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='vocab_size', model='efficientnet-b3')

💿 `param = vocab_model`

In [None]:
ablation['vocab_model'] = ['uniform',
                           '1d',
                           'chromatic',
                           '2d',
                           'phoneme_accurate',
                           'keyboard_accurate',
                           'tree_hierarchical']

run_ablation_study(params, parameter='vocab_model', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='vocab_model', model='efficientnet-b3')

💿 `param = loss_c`

In [None]:
ablation['loss_c'] = [1]

run_ablation_study(params, parameter='loss_c', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='loss_c', model='efficientnet-b3')

💿 `param = loss_mute`

In [None]:
ablation['loss_mute'] = [1]

run_ablation_study(params, parameter='loss_mute', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='loss_mute', model='efficientnet-b3')

💿 `param = mu_dist`

In [None]:
ablation['mu_dist'] = [1]

run_ablation_study(params, parameter='mu_dist', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='mu_dist', model='efficientnet-b3')

💿 `param = vocab_temperature`

In [None]:
ablation['vocab_temperature'] = [0.1, 0.2, 0.5, 1, 2, 3, 5]

run_ablation_study(params, parameter='vocab_temperature', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='vocab_temperature', model='efficientnet-b3')
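The `vocab_model` options and `vocab_temperature` are implemented elsewhere in the project, and their exact definitions are not shown here. Purely as an illustration of how a temperature can shape a token-level similarity mapping, the sketch below assumes a hypothetical '1d' model. Tokens sit on a line, and the weight given to token j when token i is intended is proportional to exp(-|i - j| / T). `similarity_row_1d` is not a notebook function.

```python
import math
from typing import List


def similarity_row_1d(target: int, vocab_size: int, temperature: float) -> List[float]:
    """Normalized weights over tokens for an intended `target`, assuming a hypothetical 1-D layout
    where similarity decays as exp(-|target - j| / temperature)."""
    logits = [-abs(target - j) / temperature for j in range(vocab_size)]
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(l - m) for l in logits]
    z = sum(weights)
    return [w / z for w in weights]


for t in [0.1, 1, 5]:
    row = similarity_row_1d(target=2, vocab_size=5, temperature=t)
    print(t, [round(x, 3) for x in row])
```

Under this assumption, a low temperature makes the mapping close to the identity (nearly all mass on the intended token). A high temperature spreads mass over neighbouring tokens, approaching the 'uniform' model.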

#### **Channel**-level parameters

💿 `param = max_len`

In [None]:
ablation['max_len'] = [1, 2, 3, 4, 5, 7, 9, 12, 15, 20, 30, 50, 100]

run_ablation_study(params, parameter='max_len', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='max_len', model='efficientnet-b3')
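Whatever the exact channel implementation, `vocab_size` and `max_len` jointly bound the expressiveness of a single message. Assume a message is a sequence of at most `max_len` symbols drawn from `vocab_size` symbols. Then there are at most `vocab_size ** max_len` distinct messages, i.e. at most `max_len * log2(vocab_size)` bits per message. For instance, telling apart the 10 MNIST classes needs at least log2(10) ≈ 3.32 bits. Below is a small helper for reading these sweeps. It is hypothetical and not part of the notebook.

```python
import math


def max_message_bits(vocab_size: int, max_len: int) -> float:
    """Upper bound on the information one message can carry, in bits."""
    return max_len * math.log2(vocab_size)


def min_max_len(vocab_size: int, n_classes: int) -> int:
    """Smallest max_len whose message space can give each of `n_classes` inputs its own message."""
    length, capacity = 0, 1
    while capacity < n_classes:  # integer arithmetic avoids float rounding in log2
        length += 1
        capacity *= vocab_size
    return length


print(max_message_bits(10, 1))  # about 3.32 bits
print(min_max_len(5, 10))       # 2, since 5 < 10 <= 25
```

Settings below this bound (e.g. `max_len=1` with `vocab_size=5` on a 10-class task) cannot, even in principle, give every class a unique message.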

💿 `param = top_k`

In [None]:
ablation['top_k'] = [3, 5, 10, 15, 20]

run_ablation_study(params, parameter='top_k', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='top_k', model='efficientnet-b3')

💿 `param = channel_temperature`

In [None]:
ablation['channel_temperature'] = [0.1, 0.2, 0.5, 1, 2, 3, 5]

run_ablation_study(params, parameter='channel_temperature', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='channel_temperature', model='efficientnet-b3')
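The roles of `top_k` and `channel_temperature` are defined in the project's channel code, which is not shown here. Assume they follow the usual convention: only the k highest-scoring tokens are kept, and the remaining logits are divided by the temperature before sampling. The sketch below illustrates that convention. `sample_top_k` is a hypothetical helper.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], top_k: int, temperature: float,
                 rng: Optional[random.Random] = None) -> int:
    """Sample a token index among the `top_k` largest logits, after dividing them by `temperature` (> 0)."""
    rng = rng or random.Random()
    top_k = max(1, min(top_k, len(logits)))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)  # stabilize the softmax
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
print([sample_top_k([0.1, 2.0, 1.5, -1.0], top_k=2, temperature=1.0, rng=rng) for _ in range(10)])
```

Under this reading, a smaller `top_k` or a lower temperature both make the channel more deterministic. A larger `top_k` combined with a higher temperature lets less likely tokens through more often.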

💿 `param = p_random_insertion`

In [None]:
params['num_epochs'] = 1

params['p_random_insertion'] = 0.05
params['p_random_deletion'] = 0.05
params['p_random_permutation'] = 0.05
params['p_random_corruption'] = 0.05

ablation['p_random_insertion'] = [0, 0.001, 0.002, 0.006, 0.01, 0.02, 0.06, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]

run_ablation_study(params, parameter='p_random_insertion', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='p_random_insertion', model='efficientnet-b3')
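When interpreting these sweeps, note that per-token probabilities compound over the message length. If each of L tokens is affected independently with probability p (an assumption about the channel), the expected number of affected tokens is pL. The probability that a message is affected at least once is 1 - (1 - p)^L. The other three noise probabilities stay at 0.05 during each sweep, as set in the cells above. A short computation over the sweep values, assuming a hypothetical message length of 10:

```python
def p_message_touched(p: float, length: int) -> float:
    """Probability that at least one of `length` tokens is hit, each independently with probability p."""
    return 1.0 - (1.0 - p) ** length


sweep = [0, 0.001, 0.002, 0.006, 0.01, 0.02, 0.06, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]
length = 10  # hypothetical message length
for p in sweep:
    print(f"p={p:<6} expected hits={p * length:.2f}  P(>=1 hit)={p_message_touched(p, length):.3f}")
```

Even p = 0.1 already touches roughly two thirds of length-10 messages, which may help explain where performance starts to drop along the sweep.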

💿 `param = p_random_deletion`

In [None]:
params['num_epochs'] = 1

params['p_random_insertion'] = 0.05
params['p_random_deletion'] = 0.05
params['p_random_permutation'] = 0.05
params['p_random_corruption'] = 0.05

ablation['p_random_deletion'] = [0, 0.001, 0.002, 0.006, 0.01, 0.02, 0.06, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]

run_ablation_study(params, parameter='p_random_deletion', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='p_random_deletion', model='efficientnet-b3')

💿 `param = p_random_permutation`

In [None]:
params['num_epochs'] = 1

params['p_random_insertion'] = 0.05
params['p_random_deletion'] = 0.05
params['p_random_permutation'] = 0.05
params['p_random_corruption'] = 0.05

ablation['p_random_permutation'] = [0, 0.001, 0.002, 0.006, 0.01, 0.02, 0.06, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]

run_ablation_study(params, parameter='p_random_permutation', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='p_random_permutation', model='efficientnet-b3')

💿 `param = p_random_corruption`

In [None]:
params['num_epochs'] = 1

params['p_random_insertion'] = 0.05
params['p_random_deletion'] = 0.05
params['p_random_permutation'] = 0.05
params['p_random_corruption'] = 0.05

ablation['p_random_corruption'] = [0, 0.001, 0.002, 0.006, 0.01, 0.02, 0.06, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]

run_ablation_study(params, parameter='p_random_corruption', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='p_random_corruption', model='efficientnet-b3')

#### **Optimization**-level parameters

💿 `param = grad_norm`

In [None]:
ablation['grad_norm'] = [0.1, 0.5, 1, 2, 4]

run_ablation_study(params, parameter='grad_norm', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='grad_norm', model='efficientnet-b3')
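The `grad_norm` parameter is presumably a maximum gradient norm for clipping. This is an assumption, since the training loop is not shown here. In PyTorch, this kind of clipping is typically done with `torch.nn.utils.clip_grad_norm_(parameters, max_norm)`, which rescales all gradients jointly. The pure-Python sketch below reproduces that global-norm rule on plain lists so that the effect of each swept value is easy to see. `clip_by_global_norm` is a hypothetical helper.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Rescale all gradients together so that their joint L2 norm is at most `max_norm`."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))  # small epsilon avoids division by zero
    return [[g * scale for g in tensor] for tensor in grads]


grads = [[3.0], [4.0]]  # joint norm = 5
for max_norm in [0.1, 0.5, 1, 2, 4]:
    print(max_norm, clip_by_global_norm(grads, max_norm))
```

Because the scale is shared, clipping shrinks the step size but preserves the gradient's direction across the sender and receiver parameters.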

#### **Task**-level parameters

💿 `param = task_loss`

In [None]:
ablation['task_loss'] = [
                         ('reconstruction', loss_bce_reconstruction, 'bce'),
                         ('discrimination', loss_nll_discrimination, 'nll'),
                         ('discrimination', loss_accuracy_discrimination, 'accuracy'),
]

run_ablation_study(params, parameter='task_loss', model='resnet-50')

In [None]:
run_ablation_study(params, parameter='task_loss', model='efficientnet-b3')
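The notebook's `loss_nll_discrimination` and `loss_accuracy_discrimination` are defined earlier and are not reproduced here. As a generic, single-sample sketch of what the two discrimination objectives measure, assume the receiver outputs one score per candidate image. The NLL is then the cross-entropy of the softmax over those scores at the target index, and the accuracy checks whether the highest-scoring candidate is the target. `discrimination_nll_and_acc` is a hypothetical helper.

```python
import math
from typing import List, Tuple


def discrimination_nll_and_acc(scores: List[float], target: int) -> Tuple[float, float]:
    """Cross-entropy of the receiver's scores over candidates, plus 0/1 accuracy of its argmax."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # log-sum-exp, numerically stable
    nll = log_z - scores[target]
    acc = 1.0 if max(range(len(scores)), key=scores.__getitem__) == target else 0.0
    return nll, acc


print(discrimination_nll_and_acc([1.0, 3.0, 0.5], target=1))  # low NLL, accuracy 1.0
print(discrimination_nll_and_acc([1.0, 3.0, 0.5], target=0))  # high NLL, accuracy 0.0
```

Accuracy has zero gradient almost everywhere, so the 'accuracy' variant can only act as a reward signal (e.g. in a REINFORCE-style setup), not as a directly differentiable loss. The NLL variant provides a gradient for every candidate.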

## I. 🚪 Closing the experiment and freeing the models loaded in memory

In [None]:
# - ⚠️ - # - ⚠️ - # - ⚠️ -
# Important in order to prevent the 'WandB backend process has shutdown error' !
wandb.finish()
# - ⚠️ - # - ⚠️ - # - ⚠️ -

---
# 5. 🔍 **Evaluating** and investigating the induced language 🔎
---

### A. 🤙🏼 Helper function to visualize interactions for any input model

---
# 6. 🚧 To-do list and future work 🚧
---



1.   Resolve the issue with **FakeData** exceeding GPU memory
2.   Resolve the issue with the shape mismatch on **Caltech-101 and -256**
3.   Download **Places365**, **iNaturalist** and **ImageNet** on a dedicated Google Cloud Platform (GCP) instance
4.   Optimize the finetuning pipeline in the "Finetuning Vision Models" section (LR scheduler, larger batch size, etc.) to improve classification performance
5.   Resolve the issue where the F1 score (`F1_score`) evaluates to NaN