**Notebook permanent link** : 
If you wish to access this notebook in the future, please use this link : https://colab.research.google.com/drive/12m-SOfFKWQzys5h-z6p8caJ8-6Esgu1e?usp=sharing

*Note* : Run the following JavaScript snippet in your browser's console in order to avoid being disconnected from Google Colab's VM :
```
function KeepClicking(){
console.log("Clicking");
document.querySelector("colab-connect-button").click()
}
setInterval(KeepClicking,60000)
```



## 🧱 A. Installing useful dependencies

In [None]:
# Installs the third-party packages used by this notebook, one `pip` call per package.
# No version is pinned, so each run picks up whatever release pip resolves at that time.
# EGG : Emergence of lanGuage in Games environment (see https://github.com/facebookresearch/EGG for more information)
# (installed straight from the GitHub repository rather than from a package index)
!pip install --quiet git+https://github.com/facebookresearch/EGG.git
!pip install --quiet torchvision
!pip install --quiet wandb
!pip install --quiet pytorch_lightning
!pip install --quiet h5py
!pip install --quiet pytorch-ignite
!pip install --quiet tensorboardX
!pip install --quiet opendatasets
!pip install --quiet efficientnet-pytorch
!pip install --quiet timm

## 🛣 B. Defining useful hyperparameters used throughout the notebook

In [None]:
# Adjust the directories below so that they match your own file tree.
# Due to storage limits imposed by Torchvision, ImageNet, Tiny-Imagenet, Places365 and
# iNaturalist have to be downloaded independently.
PROJECT_DIR = "drive/My Drive/Projects/nlp_emergent_languages/"

# Project sub-folders, expressed relative to PROJECT_DIR
CHECKPOINTS_DIR = f"{PROJECT_DIR}checkpoints/"
INTERACTIONS_DIR = f"{PROJECT_DIR}interactions/"
DATASETS_DIR = f"{PROJECT_DIR}datasets/"

# Model folders are anchored under '/content/' (absolute), unlike the folders above
PRETRAINED_MODELS_DIR = f"/content/{PROJECT_DIR}pretrained_models/"
FINETUNED_MODELS_DIR = f"/content/{PROJECT_DIR}finetuned_models/"

# Locations of the manually downloaded datasets; they are appended to DATASETS_DIR by the loaders
DATASET_IMAGENET_DIR = 'imagenet/imagenet/'
DATASET_TINY_IMAGENET_DIR = 'tiny-imagenet/tiny-imagenet-200/'

In [None]:
import os  # local to this cell: it runs before the notebook's main import cell

# Weights & Biases settings.
# The API key is a credential: it is read from the WANDB_API_KEY environment variable so that it
# never has to be saved inside the notebook. The placeholder is only used when the variable is unset.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "ENTER_YOUR_OWN_API_KEY_HERE")
WANDB_PROJECT = "ENTER_YOUR_OWN_WANDB_PROJECT_HERE"
WANDB_ENTITY = "ENTER_YOUR_OWN_WANDB_PROJECT_GROUP_HERE"
# Free-text description of the experiments (the backslash continues the string on the next line)
WANDB_NOTES = "We assess the robustness and generalization / compositionality capabilities of emergent languages \
in a two-agent signaling game under channel noisyness constriants"
WANDB_EXPERIMENT_GROUP="vision-model-pretraining"  # 'reconstruction-game', 'discrimination-game', 'vision-model-pretraining', 'ablation-study'

## 📖 C. Importing useful libraries

In [None]:
# Import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime as dt

import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision.utils import make_grid
from torchvision import transforms as T
import torchvision as tv
from efficientnet_pytorch import EfficientNet

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, Precision, Recall, ClassificationReport
from ignite.handlers import LRScheduler, ModelCheckpoint, global_step_from_engine
from ignite.contrib.handlers import ProgressBar, TensorboardLogger
from torchvision import datasets
from torchvision import transforms as T
import ignite.contrib.engines.common as common
from ignite.contrib.handlers.wandb_logger import *

import opendatasets as od
import os
from random import randint
import urllib
import zipfile
import wandb
import random
from tqdm import tqdm
from PIL import ImageFilter
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Random **seed initialization** (for *reproducibility*) :

In [None]:
import hashlib  # local to this cell: only needed to derive the seed

hashed_sentence = 'emergent languages are very cool, yes indeed !'


def get_seed(s):
    """
    Maps a string to an integer seed in [0, 2**32 - 1).

    The built-in hash() of a str is salted per interpreter process (see PYTHONHASHSEED), so it
    yields a different value after every kernel restart. A SHA-256 digest does not depend on the
    process, which makes the derived seed identical from one run to the next.
    """
    digest = hashlib.sha256(s.encode('utf-8')).digest()
    # The first four bytes give a 32-bit integer; the modulo keeps the historical value range
    return int.from_bytes(digest[:4], 'big') % (2**32 - 1)


SEED = get_seed(hashed_sentence)

# Setting the random seeds of Numpy, PyTorch and Random
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

def seed_worker(worker_id):
    """Seeds Numpy and Random inside a dataloader worker from that worker's PyTorch seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated generator handed to the dataloaders (seeded with a fixed value, independent of SEED)
g = torch.Generator()
g.manual_seed(0)

# Initialize the following parameters in the dataloader :
worker_init_fn = seed_worker
generator = g

# Avoid the following insofar as possible for efficiency purposes :
torch.use_deterministic_algorithms(False)

#### Connecting to WandB :

In [None]:
def wandb_connect():
    """Logs in to Weights & Biases with WANDB_API_KEY and prints the value returned by the login call."""
    login_result = wandb.login(key=WANDB_API_KEY)
    print("Connected to Wandb online interface : {}".format(login_result))

wandb_connect()

[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Connected to Wandb online interface : True


## D. Legacy code (*for debugging purposes only, please ignore otherwise*):

In [None]:
"""
preprocess_transform_pretrain = T.Compose([
                T.Resize(256), # Resize images to 256 x 256
                T.CenterCrop(224), # Center crop image
                T.RandomHorizontalFlip(),
                T.ToTensor(),  # Converting cropped images to tensors
                T.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
])

model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=200)

# Move model to designated device (Use GPU when on Colab)
model = model.to(device)

# Define hyperparameters and settings
lr = 0.001  # Learning rate
num_epochs = 3  # Number of epochs
log_interval = 300  # Number of iterations before logging

# Set loss function (categorical Cross Entropy Loss)
loss_func = nn.CrossEntropyLoss()

# Set optimizer (using Adam as default)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Setup pytorch-ignite trainer engine
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)

# Add progress bar to monitor model training
ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {"Batch Loss": x})

# Define evaluation metrics

precision = Precision(average=False)
recall = Recall(average=False)

metrics = {
    "accuracy": Accuracy(), 
    "loss": Loss(loss_func),
    "precision": precision,
    "recall": recall,
    'F1': (precision * recall * 2 / (precision + recall)).mean(),
}

classification_report = ClassificationReport()
classification_report.attach(trainer, "classification_report")
# res = engine.state.metrics["classification_report"]

# Evaluator for training data
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

# Evaluator for validation data
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

# Display message to indicate start of training
@trainer.on(Events.STARTED)
def start_message():
    print("Begin training")

# Log results from every batch
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_batch(trainer):
    batch = (trainer.state.iteration - 1) % trainer.state.epoch_length + 1
    print(f"Epoch {trainer.state.epoch} / {num_epochs}, "
          f"Batch {batch} / {trainer.state.epoch_length}: "
          f"Loss: {trainer.state.output:.3f}")

# Evaluate and print training set metrics
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(trainer):
    print(f"Epoch [{trainer.state.epoch}] - Loss: {trainer.state.output:.2f}")
    train_evaluator.run(train_loader_pretrain)
    epoch = trainer.state.epoch
    metrics = train_evaluator.state.metrics
    print(f"Train - Loss: {metrics['loss']:.3f}, "
          f"Accuracy: {metrics['accuracy']:.3f} ")

# Evaluate and print validation set metrics
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_loss(trainer):
    evaluator.run(val_loader_pretrain)
    epoch = trainer.state.epoch
    metrics = evaluator.state.metrics
    print(f"Validation - Loss: {metrics['loss']:.3f}, "
          f"Accuracy: {metrics['accuracy']:.3f}")

# Sets up checkpoint handler to save best n model(s) based on validation accuracy metric
common.save_best_model_by_val_score(
          output_path="best_models",
          evaluator=evaluator, model=model,
          metric_name="accuracy", n_saved=1,
          trainer=trainer, tag="val")

train_loader, val_loader, test_loader = split.values()

def run(train_batch_size, val_batch_size, epochs, lr, log_interval):
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )
    #WandBlogger Object Creation
    wandb_logger = WandBLogger(
        project=WANDB_PROJECT,
        name="cnn-mnist",
        config={"max_epochs": epochs, "batch_size":train_batch_size},
        tags=["pytorch-ignite", "minst"]
    )

    wandb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss}
    )

    wandb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["nll", "accuracy"],
        global_step_transform=lambda *_: trainer.state.iteration,
    )

    wandb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_STARTED,
        optimizer=optimizer,
        param_name='lr'  # optional
    )

    wandb_logger.watch(model)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=num_epochs)
    
    pbar.close()

next(model.parameters()).is_cuda

# Start training
run(8, 8, num_epochs, lr, log_interval)

print(evaluator.state.metrics)

# 1. How to get the number of classes :
len(image_datasets['train'].classes)

# 2. How to get the typical sizes of the images :
for b in train_loader:
    print(b[0].size())
"""

## 👟 E. Parallel runs (training various models on various datasets)

#### Utility functions for **datasets handling** :

In [None]:
def generate_dataloaders_from_local(path, num_workers=2, pin_memory=True, val_test_ratio=(0.75, 0.25)):
    """
    Builds the train / val / test dataloaders of a dataset stored on disk in ImageFolder layout.

    `path` must contain a 'train' and a 'val' sub-folder. The 'val' folder is randomly split into
    a validation part and a test part.

    Relies on module-level state : params['batch_size'], transform_train, transform_val_test,
    worker_init_fn and generator.

    Args:
        path: root folder of the dataset, with or without a trailing separator.
        num_workers: number of dataloader workers (only forwarded when CUDA is available).
        pin_memory: dataloader pin_memory flag (only forwarded when CUDA is available).
        val_test_ratio: only its first entry is used; it is the fraction of the 'val' folder kept
            for validation, the remainder goes to the test split.

    Returns:
        dict with the keys 'train', 'val' and 'test', each mapped to a DataLoader.
    """
    kwargs = {'num_workers': num_workers, # In order to parallelize dataset loading
            'pin_memory': pin_memory} if torch.cuda.is_available() else {}

    batch_size = params['batch_size']

    def _make_loader(split_dataset):
        # Single place where the loader options are defined, so the three splits cannot drift apart
        return torch.utils.data.DataLoader(dataset=split_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           worker_init_fn=worker_init_fn,
                                           generator=generator,
                                           **kwargs)

    # os.path.join works whether or not `path` ends with a separator; plain string concatenation
    # would silently point at '<path>train/' when the trailing separator is missing
    train_dataset = datasets.ImageFolder(root=os.path.join(path, 'train'), transform=transform_train)
    train_loader = _make_loader(train_dataset)

    held_out_dataset = datasets.ImageFolder(os.path.join(path, 'val'), transform=transform_val_test)

    len_val = int(val_test_ratio[0] * len(held_out_dataset))
    len_test = len(held_out_dataset) - len_val
    print("Number of validation samples : {}".format(len_val))
    print("Number of test samples : {}".format(len_test))

    val_dataset, test_dataset = torch.utils.data.random_split(held_out_dataset, [len_val, len_test])

    val_loader = _make_loader(val_dataset)
    test_loader = _make_loader(test_dataset)

    print("\nLength of training dataloader : {}".format(len(train_loader)))
    print("Length of validation dataloader : {}".format(len(val_loader)))
    print("Length of test dataloader : {}".format(len(test_loader)))

    split_dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    return split_dataloaders

In [None]:
def generate_dataloaders_from_remote(dataset, name='mnist', num_workers=1, pin_memory=True, val_test_ratio=(0.75, 0.25)):
    """
    Builds the train / val / test dataloaders of a downloadable dataset.

    The data is stored under DATASETS_DIR + name + '/'. The constructor arguments depend on `name`
    because the dataset classes do not share one signature (split= / version= / train=).
    Caltech datasets ship as a single set, so 85 % of it is used for training and the remaining
    15 % is held out. In every case the held-out data is then randomly split into validation and test.

    Relies on module-level state : params['batch_size'], transform_train, transform_val_test,
    worker_init_fn, generator and DATASETS_DIR.

    Args:
        dataset: dataset class (or factory) called with the arguments matching `name`.
        name: dataset identifier; selects the constructor arguments and the storage sub-folder.
        num_workers: number of dataloader workers (only forwarded when CUDA is available).
        pin_memory: dataloader pin_memory flag (only forwarded when CUDA is available).
        val_test_ratio: only its first entry is used; it is the fraction of the held-out data kept
            for validation, the remainder goes to the test split.

    Returns:
        dict with the keys 'train', 'val' and 'test', each mapped to a DataLoader.
    """
    kwargs = {'num_workers': num_workers, # In order to parallelize dataset loading
            'pin_memory': pin_memory} if torch.cuda.is_available() else {}

    batch_size = params['batch_size']
    root = DATASETS_DIR + name + '/'
    is_caltech = name in ['caltech-101', 'caltech-256']

    def _make_loader(split_dataset):
        # Single place where the loader options are defined, so the three splits cannot drift apart
        return torch.utils.data.DataLoader(dataset=split_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           worker_init_fn=worker_init_fn,
                                           generator=generator,
                                           **kwargs)

    if name == 'places-365':
        train_dataset = dataset(root, split='train-standard', download=True, transform=transform_train)
    elif name == 'svhn':
        train_dataset = dataset(root, split='train', download=True, transform=transform_train)
    elif name == 'inaturalist':
        train_dataset = dataset(root, version='2021_train', download=True, transform=transform_train)
    elif name == 'fake-data':
        train_dataset = dataset(size=1_000, image_size=(3, 26, 26), num_classes=10, transform=transform_train)
    elif is_caltech:
        caltech_kwargs = {'target_type': 'category'} if name == 'caltech-101' else {}
        full_dataset = dataset(root, download=True, transform=transform_train, **caltech_kwargs)
        # Same files, but with the evaluation transform : a subset of `full_dataset` would keep
        # applying `transform_train` to the validation / test images
        full_dataset_eval = dataset(root, download=True, transform=transform_val_test, **caltech_kwargs)

        len_train = int(0.85 * len(full_dataset))
        len_val_test = len(full_dataset) - len_train
        print("Caltech dataset : len_train={}, len_val_test={}".format(len_train, len_val_test))
        train_dataset, held_out_subset = torch.utils.data.random_split(full_dataset, lengths=[len_train, len_val_test])
        # Re-use the held-out indices on the evaluation view of the dataset
        held_out_dataset = torch.utils.data.Subset(full_dataset_eval, held_out_subset.indices)
    else:
        train_dataset = dataset(root, train=True, download=True, transform=transform_train)

    train_loader = _make_loader(train_dataset)

    len_train = len(train_dataset)
    print("Loading dataset : [{}]".format(name))
    print("\nNumber of training samples : {}".format(len_train))

    if name == 'places-365':
        held_out_dataset = dataset(root, split='val', download=True, transform=transform_val_test)
    elif name == 'svhn':
        held_out_dataset = dataset(root, split='test', download=True, transform=transform_val_test)
    elif name == 'inaturalist':
        held_out_dataset = dataset(root, version='2021_valid', download=True, transform=transform_val_test)
    elif name == 'fake-data':
        # FakeData derives every sample from (index + random_offset) : without an offset the 100
        # evaluation samples would be exact copies of the first 100 training samples
        held_out_dataset = dataset(size=100, image_size=(3, 26, 26), num_classes=10,
                                   transform=transform_val_test, random_offset=1_000)
    elif not is_caltech:
        held_out_dataset = dataset(root, train=False, download=True, transform=transform_val_test)
    # (Caltech : `held_out_dataset` was already built above, together with the training split)

    len_val = int(val_test_ratio[0] * len(held_out_dataset))
    len_test = len(held_out_dataset) - len_val
    print("Number of validation samples : {}".format(len_val))
    print("Number of test samples : {}".format(len_test))

    val_dataset, test_dataset = torch.utils.data.random_split(held_out_dataset, [len_val, len_test])

    val_loader = _make_loader(val_dataset)
    test_loader = _make_loader(test_dataset)

    print("\nBatch size : {}".format(batch_size))
    print("\nLength of training dataloader (in batches) : {}".format(len(train_loader)))
    print("Length of validation dataloader (in batches) : {}".format(len(val_loader)))
    print("Length of test dataloader (in batches) : {}".format(len(test_loader)))

    split_dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
    }

    return split_dataloaders

#### Data **augmentation** (image **transforms**)

In [None]:
class GaussianBlur():
    """
    Image transform that blurs its input with a Gaussian filter of random radius.

    Args:
        sigma: two-element sequence (lower bound, upper bound) from which the blur radius is
            drawn uniformly at every call.
    """
    def __init__(self, sigma=(0.1, 2.0)):
        # The default is immutable and the bounds are copied, so instances never share (or see later
        # changes to) the same list object
        self.sigma = list(sigma)

    def __call__(self, x):
        # `x` must expose a PIL-style .filter() method
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

class TransformsAugment():
    """
    Augmentation pipeline : random resized crop, colour jitter, random grayscale, random Gaussian
    blur, random horizontal flip, conversion to tensor and three-channel normalization.

    Args:
        size: output size handed to the random resized crop.
        multi_channel: when False, the tensor is repeated three times along its first dimension
            before normalization (used to "pseudo-colorize" one-channel images).
    """
    def __init__(self, size, multi_channel=True):
        print("Transforms Augment")
        strength = 1
        jitter = T.ColorJitter(0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength)

        augmentations = [
            T.RandomResizedCrop(size=size),
            T.RandomApply([jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
            T.RandomHorizontalFlip(),  # applied with probability 0.5
            T.ToTensor(),
        ]

        channel_fix = []
        if not multi_channel:
            # The pixel information is simply broadcast over the three channels; this is known to
            # be suboptimal. Alternative that was considered : T.Lambda(lambda x: x[0:1, :, :])
            channel_fix = [T.Lambda(lambda x: x.repeat(3,1,1))]
            print("Not multi-channel")

        normalization = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]

        self.transform = T.Compose(augmentations + channel_fix + normalization)

    def __call__(self, x):
        return self.transform(x)

#### Utility functions for **models handling** :

In [None]:
def get_vision_module(version: str):
    """
    Instantiates a pretrained ResNet or EfficientNet from torchvision.

    Accepted versions : 'resnet-18', 'resnet-34', 'resnet-50', 'resnet-101', 'resnet-152' and
    'efficientnet-b0' to 'efficientnet-b7'.

    Side effect : sets the TORCH_HOME environment variable to PRETRAINED_MODELS_DIR, so that the
    pretrained weights are stored there and do not have to be fetched again later.

    Returns:
        the instantiated model (its classification head is left untouched).

    Raises:
        KeyError: if `version` is not one of the accepted names.
    """
    os.environ['TORCH_HOME'] = PRETRAINED_MODELS_DIR

    # Public name -> name of the matching constructor in torchvision.models
    constructor_names = {f"resnet-{depth}": f"resnet{depth}" for depth in (18, 34, 50, 101, 152)}
    constructor_names.update({f"efficientnet-b{idx}": f"efficientnet_b{idx}" for idx in range(8)})

    if version not in constructor_names:
        raise KeyError(f"{version} is not a valid ResNet / EfficientNet version")

    # The constructor is only looked up (and called) for the requested version
    build_model = getattr(tv.models, constructor_names[version])
    return build_model(pretrained=True, progress=True)

# 'downloading' is not a valid version : this call sets TORCH_HOME and then raises the KeyError
# of get_vision_module. Only that expected KeyError is caught, so that any other failure (for
# instance an undefined PRETRAINED_MODELS_DIR) is reported instead of being silently hidden.
# NOTE(review): no model constructor runs on this path, so no weights are fetched here despite
# the message below - confirm whether an actual pre-download was intended.
try:
    get_vision_module('downloading')
except KeyError:
    print("Pretrained vision models successfully downloaded")

Pretrained vision models successfully downloaded


#### Utility functions for **memory management** :

In [None]:
def free_memory(model, *args):
    """
    Reclaims as much memory as possible before launching another training experiment.

    A function cannot remove the references held by its caller : rebinding or deleting a local
    name (as the previous `x = None; del x` loop did) has no effect on the objects themselves.
    Callers therefore still have to drop their own references (`del model`, ...) to the objects
    they want released. What this function does is run a garbage collection first, so that
    unreachable objects (including reference cycles) are destroyed, and only then return the
    cached CUDA blocks to the driver.

    Args:
        model: accepted for backward compatibility; not used.
        *args: accepted for backward compatibility; not used.
    """
    import gc  # local import : only needed here

    print("Freeing memory ...")

    # Collect first : empty_cache() can only release blocks whose tensors have been destroyed
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("Successfully freed some memory !")

### Utility function to adapt a three-channel model so that it accepts images from a one-channel dataset, such as MNIST :

In [None]:
"""
def adapt_to_3_channels(model_constructor):
    class ModelTo3Channel(ResNet):
        def __init__(self):
            super(ModelTo3Channel, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
            self.conv1 = torch.nn.Conv2d(1, 64, 
                kernel_size=(7, 7), 
                stride=(2, 2), 
                padding=(3, 3), bias=False)
"""

'\ndef adapt_to_3_channels(model_constructor):\n    class ModelTo3Channel(ResNet):\n        def __init__(self):\n            super(ModelTo3Channel, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)\n            self.conv1 = torch.nn.Conv2d(1, 64, \n                kernel_size=(7, 7), \n                stride=(2, 2), \n                padding=(3, 3), bias=False)\n'

#### Instantiating dataloaders into **'train' / 'val' / 'test' splits** :

In [None]:
# Placeholder for one dataset : no dataloader is attached to any split yet
split_dataloaders = dict.fromkeys(('train', 'val', 'test'))

# Registry of the supported datasets; every entry points to the same placeholder object
image_datasets = dict.fromkeys(
    (
        'cifar-10',
        'cifar-100',
        'mnist',
        'tiny-imagenet',
        'fashion-mnist',
        'q-mnist',
        'k-mnist',
        'svhn',
        'caltech-101',
        'caltech-256',
    ),
    split_dataloaders,
)

In [None]:
# Larger datasets (plus the synthetic one), registered with the same shared placeholder
image_datasets.update(dict.fromkeys(
    ('imagenet', 'places-365', 'inaturalist', 'fake-data'),
    split_dataloaders,
))

#### **Loading** them either **from memory** (if already downloaded or manually downloaded) or **from the datasets hub**

In [None]:
# Each entry becomes a zero-argument factory : nothing is downloaded or read until it is called
image_datasets.update({
    'cifar-10': lambda: generate_dataloaders_from_remote(datasets.CIFAR10, name='cifar-10'),
    'cifar-100': lambda: generate_dataloaders_from_remote(datasets.CIFAR100, name='cifar-100'),
    'mnist': lambda: generate_dataloaders_from_remote(datasets.MNIST, name='mnist'),
    'tiny-imagenet': lambda: generate_dataloaders_from_local(DATASETS_DIR + DATASET_TINY_IMAGENET_DIR),
    'fashion-mnist': lambda: generate_dataloaders_from_remote(datasets.FashionMNIST, name='fashion-mnist'),
    'q-mnist': lambda: generate_dataloaders_from_remote(datasets.QMNIST, name='q-mnist'),
    'k-mnist': lambda: generate_dataloaders_from_remote(datasets.KMNIST, name='k-mnist'),
    'svhn': lambda: generate_dataloaders_from_remote(datasets.SVHN, name='svhn'),
    'caltech-101': lambda: generate_dataloaders_from_remote(datasets.Caltech101, name='caltech-101'),
    'caltech-256': lambda: generate_dataloaders_from_remote(datasets.Caltech256, name='caltech-256'),
})

In [None]:
# Zero-argument factories for the larger datasets and for the synthetic one
image_datasets.update({
    'inaturalist': lambda: generate_dataloaders_from_remote(datasets.INaturalist, name='inaturalist'),
    'places-365': lambda: generate_dataloaders_from_remote(datasets.Places365, name='places-365'),
    'imagenet': lambda: generate_dataloaders_from_local(DATASETS_DIR + DATASET_IMAGENET_DIR),
    'fake-data': lambda: generate_dataloaders_from_remote(datasets.FakeData, name='fake-data'),
})

#### Standard image sizes for each dataset

In [None]:
# Image size associated with each dataset (used as the crop size of the augmentation pipeline)
sizes = dict([
    ('cifar-10', 32),
    ('cifar-100', 32),
    ('mnist', 28),
    ('imagenet', 256),
    ('tiny-imagenet', 64),
    ('fashion-mnist', 28),
    ('places-365', 256),
    # the dimension to resize the images to, but they may actually be higher-def (up to 2,048 px)
    ('inaturalist', 256),
    ('fake-data', 224),
    ('q-mnist', 28),
    ('k-mnist', 28),
    ('svhn', 32),
    # image sizes vary within the Caltech datasets - a single representative value is used instead
    ('caltech-101', 64),
    ('caltech-256', 64),
])

### 1. Training all models on **CIFAR-100** and **MNIST**

In [None]:
# Default training hyperparameters, read by the dataloader builders and by the fine-tuning routine
params = dict(
    batch_size=64,
    num_epochs=5,
    log_interval=100,
    lr=0.001,
)

# Global switches of the experiments
TEST_ONE_BATCH, LOG_TO_WANDB = False, True

class StopExecution(Exception):
    """
    Exception raised to interrupt the execution of a cell.

    NOTE(review): `_render_traceback_` presumably is the hook the notebook front-end calls to
    display the error; it returns nothing here, which would keep the output empty - confirm.
    """
    def _render_traceback_(self):
        return None

# raise StopExecution

In [None]:
def _build_classification_metrics(loss_func):
    """Return a fresh dictionary of ignite metrics for one evaluator engine.

    A new set is created for every evaluator so that the train, validation and
    test engines never share metric objects.

    Args:
        loss_func: criterion wrapped by the ``Loss`` metric.

    Returns:
        dict with the keys "accuracy", "loss", "precision", "recall" and "F1".
    """
    per_class_precision = Precision(average=False)
    per_class_recall = Recall(average=False)
    # Macro-averaged F1. The epsilon keeps a class with zero precision and zero
    # recall from producing 0/0 = NaN, which would turn the whole mean into NaN.
    macro_f1 = (
        per_class_precision * per_class_recall * 2
        / (per_class_precision + per_class_recall + 1e-15)
    ).mean()

    return {
        "accuracy": Accuracy(),
        "loss": Loss(loss_func),
        "precision": Precision(average=True),
        "recall": Recall(average=True),
        'F1': macro_f1,
    }


def finetune(model='resnet-50', dataset='cifar-10', num_epochs=None, add_name=None):
    """Fine-tune a vision model on an image-classification dataset.

    Builds the model with `get_vision_module`, trains it with Adam and
    cross-entropy through a pytorch-ignite trainer, evaluates on the training
    and validation splits after every epoch, checkpoints into
    `FINETUNED_MODELS_DIR`, optionally logs to Weights & Biases, and finally
    reports the metrics obtained on the test split.

    Args:
        model: name of the architecture, forwarded to `get_vision_module(version=...)`.
        dataset: key present in both `image_datasets` and `sizes`.
        num_epochs: number of epochs for this run only; when None the value of
            `params['num_epochs']` is used. The module-level `params` is never modified.
        add_name: optional suffix appended to the run name and checkpoint directory.

    Returns:
        None. Results are printed, checkpointed and (optionally) sent to W&B.

    Raises:
        KeyError: if `dataset` is not registered in `image_datasets` and `sizes`.
        StopExecution: at the first progress report when `TEST_ONE_BATCH` is True.
    """
    # NOTE(review): the transforms are published as module globals, presumably because
    # the dataloader factories in `image_datasets` read them — confirm.
    global transform_train, transform_val_test

    # Fail before the (expensive) model instantiation if the dataset is unknown.
    if dataset not in image_datasets or dataset not in sizes:
        raise KeyError(
            f"Unknown dataset {dataset!r}: it must be registered in both `image_datasets` and `sizes`"
        )

    if add_name is None:
        NAME_RUN = f'Vision model pretraining : M[{model}], D[{dataset}]'
        NAME_CHECKPOINT = f'model={model}_dataset={dataset}'
    else:
        NAME_RUN = f'Vision model pretraining : M[{model}], D[{dataset}], D[{add_name}]'
        NAME_CHECKPOINT = f'model={model}_dataset={dataset}_cond={add_name}'

    DIR_CHECKPOINT = FINETUNED_MODELS_DIR + NAME_CHECKPOINT + '/'

    model_name = model

    # Work on a private copy: a per-call `num_epochs` override must not leak into
    # the module-level defaults used by subsequent calls.
    run_params = copy.deepcopy(params)
    if num_epochs is not None:
        run_params['num_epochs'] = num_epochs

    print("Training parameters :", run_params)

    # -------------------------------------------------------
    # 1. Instantiating the model and the optimization routine
    # -------------------------------------------------------
    # Move model to designated device (Use GPU when on Colab)
    network = get_vision_module(version=model)
    network = network.to(device)

    # Categorical cross-entropy loss, Adam optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(network.parameters(), lr=run_params['lr'])

    # ----------------------------------------
    # 2. Defining the data augmentation scheme
    # ----------------------------------------
    multi_channel = dataset not in ['mnist', 'k-mnist', 'q-mnist', 'fashion-mnist']
    transform_train = TransformsAugment(size=sizes[dataset], multi_channel=multi_channel)
    # NOTE(review): validation/test use the same TransformsAugment as training —
    # confirm that evaluating on augmented inputs is intended.
    transform_val_test = TransformsAugment(size=sizes[dataset], multi_channel=multi_channel)

    # -----------------------
    # 3. Creating the trainer
    # -----------------------
    trainer = create_supervised_trainer(network, optimizer, loss_func, device=device)

    # Add progress bar to monitor model training
    ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {"Batch Loss": x})

    # ---------------------------------------------
    # 4. Defining the evaluators (one per data split)
    # ---------------------------------------------
    # Only the validation evaluator drives checkpointing and early stopping, so it
    # must never be run on the training or test split.
    train_evaluator = create_supervised_evaluator(
        network, metrics=_build_classification_metrics(loss_func), device=device)
    evaluator = create_supervised_evaluator(
        network, metrics=_build_classification_metrics(loss_func), device=device)
    test_evaluator = create_supervised_evaluator(
        network, metrics=_build_classification_metrics(loss_func), device=device)

    # -------------------------------
    # 5. Instantiating the dataloader
    # -------------------------------
    split = image_datasets[dataset]()
    train_loader, val_loader, test_loader = split.values()

    # -------------------------------------------
    # 6. Utility functions for experiment logging
    # -------------------------------------------
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    @trainer.on(Events.STARTED)
    def start_message():
        print("Begin training")

    # Log the running loss every `log_interval` iterations
    @trainer.on(Events.ITERATION_COMPLETED(every=run_params['log_interval']))
    def log_batch(engine):
        batch = (engine.state.iteration - 1) % engine.state.epoch_length + 1
        print(f"Epoch {engine.state.epoch} / {run_params['num_epochs']}, "
            f"Batch {batch} / {engine.state.epoch_length}: "
            f"Loss: {engine.state.output:.3f}")

    # Evaluate and print training set metrics (single evaluation pass per epoch)
    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate_training_set(engine):
        print(f"Epoch [{engine.state.epoch}] - Loss: {engine.state.output:.2f}")
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        print(metrics)
        print(f"[E] Train - Loss: {metrics['loss']:.3f}, ",
              f"[E] Train - Accuracy: {metrics['accuracy']:.3f}, ",
              f"[E] Train - Precision: {metrics['precision']:.3f}, ",
              f"[E] Train - Recall: {metrics['recall']:.3f}, ",
              f"[E] Train - F1: {metrics['F1']:.3f}, ",)

    # Evaluate and print validation set metrics (single evaluation pass per epoch)
    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate_validation_set(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print(f"[E] Val - Loss: {metrics['loss']:.3f}, ",
              f"[E] Val - Accuracy: {metrics['accuracy']:.3f}, ",
              f"[E] Val - Precision: {metrics['precision']:.3f}, ",
              f"[E] Val - Recall: {metrics['recall']:.3f}, ",
              f"[E] Val - F1: {metrics['F1']:.3f}, ",)

    # Sets up checkpoint handler to save best n model(s) based on validation accuracy metric
    common.save_best_model_by_val_score(
            output_path=DIR_CHECKPOINT,
            evaluator=evaluator, model=network,
            metric_name="accuracy", n_saved=5,
            trainer=trainer, tag="val")

    # Sets up Early Stopping
    common.add_early_stopping_by_val_score(
            patience=100,
            evaluator=evaluator,
            trainer=trainer,
            metric_name='accuracy',
            )

    @trainer.on(Events.ITERATION_COMPLETED(every=run_params['log_interval']))
    def update_progress_bar(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(run_params['log_interval'])

        if TEST_ONE_BATCH:
            raise StopExecution

    # The two summaries below reuse the metrics computed by the handlers registered
    # above (handlers fire in registration order) instead of re-running an evaluator.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        metrics = train_evaluator.state.metrics
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f}  Avg loss: {:.2f}  Avg precision: {:.2f}  Avg recall: {:.2f}  Avg F1: {:.2f}"
            .format(engine.state.epoch, metrics['accuracy'], metrics['loss'],
                    metrics['precision'], metrics['recall'], metrics['F1'])
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = evaluator.state.metrics
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f}  Avg loss: {:.2f}  Avg precision: {:.2f}  Avg recall: {:.2f}  Avg F1: {:.2f}"
            .format(engine.state.epoch, metrics['accuracy'], metrics['loss'],
                    metrics['precision'], metrics['recall'], metrics['F1']))

        # Rewind the bar for the next epoch
        pbar.n = pbar.last_print_n = 0

    wandb_logger = None
    try:
        # -----------------------------------------
        # 7. Other helper functions - WandB Logging
        # -----------------------------------------
        if LOG_TO_WANDB:
            # NOTE(review): WANDB_EXPERIMENT_GROUP must be defined at module level before this runs.
            wandb_logger = WandBLogger(
                project=WANDB_PROJECT,
                name=NAME_RUN,
                config={"max_epochs": run_params['num_epochs'], "batch_size": run_params['batch_size']},
                tags=[model_name, dataset],
                group=WANDB_EXPERIMENT_GROUP,
            )

            # Per-iteration training loss
            wandb_logger.attach_output_handler(
                trainer,
                event_name=Events.ITERATION_COMPLETED,
                tag="training",
                output_transform=lambda loss: {"loss": loss}
            )

            # End-of-epoch metrics on the training split. The epoch loss is left out so
            # that it does not collide with the per-iteration "training/loss" series.
            wandb_logger.attach_output_handler(
                train_evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag="training",
                metric_names=["accuracy", "precision", "recall", "F1"],
                global_step_transform=lambda *_: trainer.state.iteration,
            )

            # End-of-epoch metrics on the validation split
            wandb_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag="validation",
                metric_names=["loss", "accuracy", "precision", "recall", "F1"],
                global_step_transform=lambda *_: trainer.state.iteration,
            )

            wandb_logger.attach_opt_params_handler(
                trainer,
                event_name=Events.ITERATION_STARTED,
                optimizer=optimizer,
                param_name='lr'  # optional
            )

            wandb_logger.watch(network)

        # Keep the most recent weights. `require_empty=False` lets a re-run reuse a
        # directory that already holds checkpoints instead of raising.
        handler = ModelCheckpoint(DIR_CHECKPOINT, 'last_models', n_saved=1,
                                  create_dir=True, require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), handler, {'model': network})

        trainer.run(train_loader, max_epochs=run_params['num_epochs'])

        pbar.close()

        print("Displaying the final metrics : ")
        print(evaluator.state.metrics)

        # Final report on the held-out test split, through a dedicated evaluator so that
        # the best-checkpoint and early-stopping handlers never see test data.
        test_evaluator.run(test_loader)
        metrics = test_evaluator.state.metrics
        print(metrics)
        print(f"Test - Loss: {metrics['loss']:.3f}, ",
                f"Accuracy: {metrics['accuracy']:.3f}, ",
                f"Precision: {metrics['precision']:.3f}, ",
                f"Recall: {metrics['recall']:.3f}, ",
                f"F1: {metrics['F1']:.3f}, ",)
    finally:
        # Release resources even when the run is interrupted (e.g. by StopExecution).
        pbar.close()
        if wandb_logger is not None:
            wandb_logger.close()
        free_memory([network, train_loader, val_loader])

In [None]:
# Presumably a quick end-to-end check of the training pipeline before the full study.
# NOTE(review): this exact call is repeated in the Resnet-18 section further down — confirm both runs are wanted.
finetune(model='resnet-18', dataset='mnist')

#### To run only a quick smoke test of each model (checking that training goes smoothly) instead of the full study, select "Single batch" in the form below:


In [None]:
#@title Decide whether to test ablation on a single batch or to run the full ablation study { run: "auto", vertical-output: true, form-width: "65%", display-mode: "code" }

option = "Full ablation" #@param ["Single batch", "Full ablation"]

# "Single batch": stop each run at its first progress report, skip W&B and report every 10 iterations.
# "Full ablation": complete runs, W&B logging enabled, report every 100 iterations.
TEST_ONE_BATCH = (option == "Single batch")
LOG_TO_WANDB = not TEST_ONE_BATCH
params['log_interval'] = 10 if TEST_ONE_BATCH else 100


#### 🔸 **Resnet-18**

In [None]:
finetune(model='resnet-18', dataset='cifar-100')

In [None]:
finetune(model='resnet-18', dataset='mnist')

#### 🔸🔸 **Resnet-34**

In [None]:
finetune(model='resnet-34', dataset='cifar-100')

In [None]:
finetune(model='resnet-34', dataset='mnist')

#### 🔸🔸🔸 **Resnet-50** (default *Resnet* model, which we will train longer)

In [None]:
finetune(model='resnet-50', dataset='cifar-100')

In [None]:
finetune(model='resnet-50', dataset='mnist')

#### 🔸🔸🔸🔸 **Resnet-101**

In [None]:
finetune(model='resnet-101', dataset='cifar-100') 

In [None]:
finetune(model='resnet-101', dataset='mnist')

#### 🔸🔸🔸🔸🔸 **Resnet-152**

In [None]:
finetune(model='resnet-152', dataset='cifar-100')

In [None]:
finetune(model='resnet-152', dataset='mnist')

#### 🔹 **EfficientNet-B0** 

In [None]:
finetune(model='efficientnet-b0', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b0', dataset='mnist')

#### 🔹🔹 **EfficientNet-B1** 

In [None]:
finetune(model='efficientnet-b1', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b1', dataset='mnist')

#### 🔹🔹🔹 **EfficientNet-B2** 

In [None]:
finetune(model='efficientnet-b2', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b2', dataset='mnist')

#### 🔹🔹🔹🔹 **EfficientNet-B3** (default *EfficientNet* model, which we will train longer)

In [None]:
finetune(model='efficientnet-b3', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b3', dataset='mnist')

#### 🔹🔹🔹🔹🔹 **EfficientNet-B4** 

In [None]:
finetune(model='efficientnet-b4', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b4', dataset='mnist')

#### 🔹🔹🔹🔹🔹🔹 **EfficientNet-B5** 

In [None]:
finetune(model='efficientnet-b5', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b5', dataset='mnist')

#### 🔹🔹🔹🔹🔹🔹🔹 **EfficientNet-B6** 

In [None]:
finetune(model='efficientnet-b6', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b6', dataset='mnist')

#### 🔹🔹🔹🔹🔹🔹🔹🔹 **EfficientNet-B7** 

In [None]:
finetune(model='efficientnet-b7', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b7', dataset='mnist')

### 2. Training **Resnet-50** and **EfficientNet-B3** on all datasets

#### 🏞 CIFAR-10 🏞 (Privileged dataset n°1)

In [None]:
finetune(model='resnet-50', dataset='cifar-10')

In [None]:
finetune(model='efficientnet-b3', dataset='cifar-10')

#### Privileged runs (models trained longer) :

In [None]:
finetune(model='resnet-50', dataset='cifar-10', num_epochs=50, add_name='longer')

In [None]:
finetune(model='efficientnet-b3', dataset='cifar-10', num_epochs=50, add_name='longer')

#### 🏞 CIFAR-100 🏞

In [None]:
finetune(model='resnet-50', dataset='cifar-100')

In [None]:
finetune(model='efficientnet-b3', dataset='cifar-100')

#### 🔢 MNIST 🔢 (Privileged dataset n°2)

In [None]:
finetune(model='resnet-50', dataset='mnist')

In [None]:
finetune(model='efficientnet-b3', dataset='mnist')

#### Privileged runs (models trained longer) :

In [None]:
finetune(model='resnet-50', dataset='mnist', num_epochs=50, add_name='longer')

In [None]:
finetune(model='efficientnet-b3', dataset='mnist', num_epochs=50, add_name='longer')

#### 🪐 ImageNet 🪐

In [None]:
finetune(model='resnet-50', dataset='imagenet', num_epochs=2)

In [None]:
finetune(model='efficientnet-b3', dataset='imagenet', num_epochs=2)

#### 🪐 TinyImageNet 🪐

In [None]:
finetune(model='resnet-50', dataset='tiny-imagenet', num_epochs=4)

In [None]:
finetune(model='efficientnet-b3', dataset='tiny-imagenet', num_epochs=4)

#### 🥋 FashionMNIST 🥋

In [None]:
finetune(model='resnet-50', dataset='fashion-mnist')

In [None]:
finetune(model='efficientnet-b3', dataset='fashion-mnist')

#### 🌉 Places365 🌉 

In [None]:
finetune(model='resnet-50', dataset='places-365', num_epochs=5)

In [None]:
finetune(model='efficientnet-b3', dataset='places-365', num_epochs=5)

#### 🌿 iNaturalist 🌿

In [None]:
finetune(model='resnet-50', dataset='inaturalist', num_epochs=5)

In [None]:
finetune(model='efficientnet-b3', dataset='inaturalist', num_epochs=5)

#### 🤡 FakeData 🤡

In [None]:
# finetune(model='resnet-50', dataset='fake-data')

In [None]:
# finetune(model='efficientnet-b3', dataset='fake-data')

#### 🧾 QMNIST 🧾

In [None]:
finetune(model='resnet-50', dataset='q-mnist')

In [None]:
finetune(model='efficientnet-b3', dataset='q-mnist')

#### 🔖 KMNIST 🔖

In [None]:
finetune(model='resnet-50', dataset='k-mnist')

In [None]:
finetune(model='efficientnet-b3', dataset='k-mnist')

#### 🗻 SVHN 🗻

In [None]:
finetune(model='resnet-50', dataset='svhn')

In [None]:
finetune(model='efficientnet-b3', dataset='svhn')

#### 🪑 Caltech-101 🪑

In [None]:
finetune(model='resnet-50', dataset='caltech-101')

In [None]:
finetune(model='efficientnet-b3', dataset='caltech-101')

#### 💺 Caltech-256 💺

In [None]:
finetune(model='resnet-50', dataset='caltech-256')

In [None]:
finetune(model='efficientnet-b3', dataset='caltech-256')