# Layer-wise self-supervised training of a 4-layer MLP on MNIST

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torchvision
import optuna
import os
import argparse
from tqdm import tqdm

from utils_mlp_endtoend import*

# Hyperparameters of the simulation

In [None]:
# Command-line interface for the experiment. Every option has a default, so the
# notebook runs without any argument; `parse_known_args` is used at the end so
# that extra arguments injected by the Jupyter kernel are ignored instead of
# aborting the run.
parser = argparse.ArgumentParser(description='Train a MLP with layer-wise SSL on MNIST')
# general params
parser.add_argument(
    '--device',
    type=int,
    default=0,
    help='Index of the CUDA GPU to use; a negative value forces the CPU (default=0)')
parser.add_argument(
    '--dataset',
    type=str,
    default='mnist',
    help='Dataset we use for training (default=mnist, others: None for this specific notebook)')
# Architecture params
parser.add_argument(
    '--nlayers',
    type=int,
    default=4,
    help='Number of layers for the main MLP (default: 4)')
parser.add_argument(
    '--nneurons',
    type=int,
    default=1000,
    help='Number of neuron per layer of the main MLP (default: 1000)')
parser.add_argument(
    '--nlayers_proj',
    type=int,
    default=3,
    help='Number of layers for each non-linear projector (default: 3)')
parser.add_argument(
    '--nneurons_proj',
    type=int,
    default=256,
    help='Number of neuron per layer of each non-linear projector (default: 256)')
# Optimization params
parser.add_argument(
    '--epochs',
    type=int,
    default=1000,
    help='Number of epochs to train each layer (default: 1000)')
parser.add_argument(
    '--epochs_classifier',
    type=int,
    default=20,
    help='Number of epochs to train a linear classifier on top of each layer (default: 20)')
parser.add_argument(
    '--batchSize_pretrain',
    type=int,
    default=256,
    help='Batch size for pre-training (default=256)')
parser.add_argument(
    '--batchSize_classifier',
    type=int,
    default=64,
    help='Batch size for training the linear classifier (default=64)')
parser.add_argument(
    '--test_batchSize',
    type=int,
    default=512,
    help='Testing Batch size (default=512)')
parser.add_argument(
    '--lr',
    type=float,
    default=1e-4,
    help='Learning rate for Adam optimizer for pre-training (default=1e-4)')
parser.add_argument(
    '--lr_classifier',
    type=float,
    default=1e-3,
    help='Learning rate for Adam optimizer for training the linear classifier (default=1e-3)')
# SSL settings
parser.add_argument(
    '--nviews',
    type=int,
    default=2,
    help='Number of views for computing the SSL objective (default=2)')
parser.add_argument(
    '--deg',
    type=int,
    default=15,
    help='Maximum angle for the RandomRotation transform applied to the input image (default=15)')
parser.add_argument(
    '--pad',
    type=int,
    default=2,
    help='Padding applied before RandomCrop (default=2)')
parser.add_argument(
    '--contrast',
    type=float,
    default=0.5,
    help='Contrast value for ColorJittering (default=0.5)')
parser.add_argument(
    '--hue',
    type=float,
    default=0.5,
    help='Hue value for ColorJittering (default=0.5)')
parser.add_argument(
    '--scaleaffine',
    nargs='+',
    type=float,
    default=[0.7, 1.3],
    help='Scale parameters for the RandomAffine transform applied (default=(0.7, 1.3))')
# SSL objective params (for the sigmoidal parametrization)
# NOTE: the four list-valued defaults below are kept exactly as they were; only
# the help strings were corrected so that they report the real default values.
parser.add_argument(
    '--scale',
    nargs='+',
    type=float,
    default=[40, 20, 1, 0],
    help='Scale value for the sigmoidal parametrization (default=(40, 20, 1, 0))')
parser.add_argument(
    '--slope',
    nargs='+',
    type=float,
    default=[1.5,0,2,0],
    help='Slope value for the sigmoidal parametrization (default=(1.5, 0, 2, 0))')
parser.add_argument(
    '--threshold',
    nargs='+',
    type=float,
    default=[4,3,2,2],
    help='Threshold value for the sigmoidal parametrization (default=(4, 3, 2, 2))')
parser.add_argument(
    '--bias',
    nargs='+',
    type=float,
    default=[10,15,0,0],
    help='Bias value for the sigmoidal parametrization (default=(10, 15, 0, 0))')


# Unknown arguments (e.g. the ones passed by the notebook kernel) are discarded.
args, _ = parser.parse_known_args()

# Dataset creation

In [None]:
# Data and target transformations
class ReshapeTransform:
    '''
    Callable transform that reshapes a tensor to a fixed target size.

    Args:
        new_size: shape forwarded to torch.reshape (e.g. (-1,) to flatten)
    '''
    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        # torch.reshape returns a view when possible, otherwise a copy
        reshaped = torch.reshape(img, self.new_size)
        return reshaped
        
        
class ReshapeTransformTarget:
    '''
    Callable target transform that turns a class index into a one-hot vector.

    Args:
        number_classes: length of the one-hot vector that is produced
    '''
    def __init__(self, number_classes):
        self.number_classes = number_classes

    def __call__(self, target):
        # scatter_ needs an index tensor with the same rank as the destination,
        # hence the two leading singleton dimensions -> shape (1, 1)
        index = torch.tensor(target)[None, None]
        onehot = torch.zeros((1, self.number_classes))
        onehot.scatter_(1, index, 1)
        # drop the temporary leading dimension -> shape (number_classes,)
        return onehot.squeeze(0)

    
class ContrastiveTransformations(object):
    '''
    Wrap a transform so that each call produces several views of one input.

    Args:
        base_transforms: callable applied once per view (random augmentations)
        n_views: number of views returned for every input (default: 2)
    '''
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        # The transform is called independently for each view, so a random
        # transform yields a different result every time.
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views

In [None]:
#Dataset creation
def get_dataloader(args):
    '''
    Build the four MNIST dataloaders used by the notebook.

    Args:
        args: namespace providing deg, pad, scaleaffine, contrast, hue, nviews,
              batchSize_pretrain, batchSize_classifier and test_batchSize

    Returns:
        (train_loader, train_loader_small, train_loader_classifier, test_loader)
        - train_loader: augmented multi-view loader on the MNIST train split
        - train_loader_small: augmented multi-view loader on the split selected
          by train=False (the smaller MNIST split)
        - train_loader_classifier: non-augmented loader on the train split
        - test_loader: non-augmented, non-shuffled loader on the train=False split
    '''
    # Random augmentations for the pre-training stage (applied in this order),
    # followed by conversion to a flattened tensor.
    augmentations = [
        transforms.RandomRotation(degrees=args.deg, fill=0),
        transforms.RandomCrop((28, 28), padding=args.pad),
        transforms.RandomAffine(degrees=(0, 0), translate=(0.0, 0.0),
                                scale=(args.scaleaffine[0], args.scaleaffine[1])),
        transforms.ColorJitter(brightness=0, contrast=args.contrast,
                               saturation=0, hue=args.hue),
        transforms.ToTensor(),
        ReshapeTransform((-1,)),
    ]
    contrast_transforms = transforms.Compose(augmentations)

    # Deterministic transform used to train / test the linear classifier
    transforms_test = transforms.Compose([transforms.ToTensor(),
                                          ReshapeTransform((-1,))])

    def mnist(train, transform):
        # Every dataset shares the same root and gets its own one-hot target transform
        return torchvision.datasets.MNIST(root='./data', train=train, download=True,
                                          transform=transform,
                                          target_transform=ReshapeTransformTarget(10))

    def multi_view():
        # Fresh wrapper returning args.nviews augmented copies of each image
        return ContrastiveTransformations(contrast_transforms, n_views=args.nviews)

    DataLoader = torch.utils.data.DataLoader

    # Pre-training loader on the train split
    train_loader = DataLoader(mnist(True, multi_view()),
                              batch_size=args.batchSize_pretrain, shuffle=True)

    # Pre-training loader on the train=False split
    train_loader_small = DataLoader(mnist(False, multi_view()),
                                    batch_size=args.batchSize_pretrain, shuffle=True)

    # Train loader for the linear classifier
    train_loader_classifier = DataLoader(mnist(True, transforms_test),
                                         batch_size=args.batchSize_classifier, shuffle=True)

    # Test loader for the linear classifier
    test_loader = DataLoader(mnist(False, transforms_test),
                             batch_size=args.test_batchSize, shuffle=False)

    return train_loader, train_loader_small, train_loader_classifier, test_loader


In [None]:
# Instantiate the four dataloaders from the parsed hyperparameters: two augmented
# multi-view loaders for pre-training and two plain loaders for the linear classifier.
train_loader_pretrain, train_loader_pretrain_small, train_loader_classifier, test_loader = get_dataloader(args)

In [None]:
# Visualize the random transformations applied to the MNIST data.
# A batch from the pre-training dataloader is a list with one tensor per view
# (2 elements with the default --nviews); the views can be stacked into a single
# mini-batch in the training loop.

data, target = next(iter(train_loader_pretrain))

# First view of the first image of the batch
plt.figure()
plt.imshow(data[0][0].view(28,28), cmap = "gray")

# Second view of the same image
plt.figure()
plt.imshow(data[1][0].view(28,28), cmap = "gray")

# Pixel-wise difference between the two views
plt.figure()
plt.imshow(data[0][0].view(28,28) - data[1][0].view(28,28), cmap = "gray")


plt.show()

# SSL (VICReg) objective

In [None]:
def off_diagonal(x):
    '''
    Return the off-diagonal entries of a square matrix as a 1-D tensor,
    in row-major order.
    '''
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last entry and viewing the rest as (n - 1, n + 1) puts every
    # remaining diagonal entry in column 0, which is then sliced away.
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

# New version of the VicRegLoss: add an intra-sample variance objective! We want the different views of the
# same input to be quite different initially, then they should tend to become more similar as we go deeper.
# The initial version is from the Meta GitHub repo.

class VicRegLoss(torch.nn.Module):
    '''
    VICReg loss between two batches of embeddings.

    The loss is the weighted sum of three terms:
        - an invariance term: mean squared error between the two embeddings
        - a variance term: hinge loss pushing the per-feature standard
          deviation of each batch towards at least 1
        - a covariance term: sum of the squared off-diagonal entries of each
          batch covariance matrix, divided by the number of features

    Args:
        device: device handle stored on the module (kept for callers that read it)
        sim_coeff: weight of the invariance term (default: 25)
        std_coeff: weight of the variance term (default: 25)
        cov_coeff: weight of the covariance term (default: 1)
    '''
    def __init__(self, device, sim_coeff = 25, std_coeff = 25, cov_coeff = 1):
        super().__init__()
        self.device = device
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff

        # Not used by forward(); kept so the module's attributes stay unchanged
        self.flatten = nn.Flatten()

    def forward(self, x, y):
        '''
        Args:
            x,y: embeddings vectors - are "flattened" so they have dimension (batch_size, num_features)

        Returns:
            scalar tensor: sim_coeff * invariance + std_coeff * variance + cov_coeff * covariance
        '''
        batch_size = x.size(0)
        num_features = x.size(1)

        # Invariance term, computed on the raw (non-centred) embeddings
        repr_loss = F.mse_loss(x, y)

        # Centre each feature over the batch
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # Variance term; the small constant keeps the sqrt differentiable at 0
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # Covariance matrices of the centred embeddings, shape (num_features, num_features)
        cov_x = (x.T @ x) / (batch_size - 1)
        cov_y = (y.T @ y) / (batch_size - 1)

        # Covariance term. The sum is divided by num_features (not by the number
        # of off-diagonal entries); the relative strength of this term can also
        # be tuned through cov_coeff.
        # Out-of-place pow() is used so the covariance matrices are never
        # modified through a view returned by off_diagonal.
        cov_loss = (
            off_diagonal(cov_x).pow(2).sum().div(num_features)
            + off_diagonal(cov_y).pow(2).sum().div(num_features)
        )

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss

# Model definition
- MLP with ReLU
- each layer is trained with a non-linear projector (MLP with ReLU)

In [None]:
class Network(nn.Module):
    '''
    Bias-free ReLU MLP followed by a bias-free ReLU MLP projector, trained with
    the VicReg loss computed on the projector output.

    The module owns its loss (self.loss) and its Adam optimizer (self.optimizer),
    and moves itself to the selected device at construction time.

    Args:
        args: namespace providing nneurons, nviews, nlayers, nneurons_proj,
              nlayers_proj, device and lr
    '''
    def __init__(self, args):

        super().__init__()

        self.n_neurons = args.nneurons
        self.n_views = args.nviews
        self.n_layers = args.nlayers
        self.n_neurons_proj = args.nneurons_proj
        self.n_layers_proj = args.nlayers_proj

        # Main MLP: 784 -> n_neurons -> ... -> n_neurons (n_layers linear layers)
        layers = [nn.Linear(784, self.n_neurons, bias = False)]
        layers += [nn.Linear(self.n_neurons, self.n_neurons, bias = False) for _ in range(self.n_layers-1)]
        self.layers = nn.Sequential(*layers)

        # Projector: n_neurons -> n_neurons_proj -> ... -> n_neurons_proj (n_layers_proj linear layers)
        projs = [nn.Linear(self.n_neurons, self.n_neurons_proj, bias = False)]
        projs += [nn.Linear(self.n_neurons_proj, self.n_neurons_proj, bias = False) for _ in range(self.n_layers_proj-1)]
        self.projs = nn.Sequential(*projs)

        self.f = nn.ReLU()
        # Not applied in forward() / forward_simple(); kept as a module attribute
        self.dropout = torch.nn.Dropout(p=0.5)

        # NOTE(review): the boolean `self.cuda` shadows nn.Module.cuda() on this
        # instance. It is kept because code outside this class may read the flag.
        if args.device >= 0 and torch.cuda.is_available():
            device = torch.device("cuda", args.device)
            self.cuda = True
        else:
            device = torch.device("cpu")
            self.cuda = False

        self.device = device
        # nn.Module.to() moves the parameters in place
        self.to(device)

        self.loss = VicRegLoss(self.device, sim_coeff = 25,
                                               std_coeff = 25,
                                               cov_coeff = 1)

        # A single optimizer covers both the main MLP and the projector
        self.optimizer = torch.optim.Adam(self.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay = 1e-6)


    def forward(self, x):
        '''
        Run the input through the main MLP and the projector, then return the
        VicReg loss between the first two views.

        No detach() is applied here, so the loss back-propagates through the
        projector and every layer of the main MLP.

        Args:
            x: tensor whose first dimension is n_views * batch_size
               (presumably produced by single_batch - verify against the caller)

        Returns:
            scalar loss tensor
        '''
        # ReLU after every layer of the main MLP
        for fc in self.layers:
            x = self.f(fc(x))

        # ReLU after every projector layer except the last one
        for fc in self.projs[:-1]:
            x = self.f(fc(x))

        x = self.projs[-1](x) #not applying Relu to the last layer

        # (n_views * batch_size, n_neurons_proj) -> (n_views, batch_size, n_neurons_proj)
        x = self.multi_views(x)

        # Only the first two views enter the loss
        return self.loss(x[0].float(), x[1].float())


    def forward_simple(self, x):
        '''
        Forward pass through the main MLP only, without computing the loss and
        without handling pairs of inputs. The returned activations are detached
        from the computational graph.
        '''
        for fc in self.layers:
            x = self.f(fc(x))
            x = x.detach() # detach to stop the computational graph here

        return x


    def single_batch(self, x):
        '''
        return a single big batch given the views of the input data

        Args:
            x: sequence of tensors, one per view, all with the same shape
        '''
        x = torch.stack(x) # new first dimension: the number of views of the same input image
        x = x.reshape(x.size()[0]*x.size()[1], -1) # first dim becomes n_views*batch_size

        return x


    def multi_views(self, x):
        '''
        return a view of the tensor that has first dim n_views, second dim batch_size and third dim the projector output dimension
        '''

        return x.reshape(self.n_views, -1, self.n_neurons_proj)

# Generate the simulation environment and data folders

In [None]:
# createPath comes from utils_mlp_endtoend; presumably it creates and returns
# the output folder for this run - TODO confirm against the helper.
BASE_PATH = createPath(archi = "MLP-End-to-end", dataset = "MNIST")

In [None]:
# Presumably stores the parsed hyperparameters under BASE_PATH - TODO confirm.
saveHyperparameters(args, BASE_PATH)

In [None]:
# Presumably initialises the 'pre_training.csv' results table under BASE_PATH - TODO confirm.
dataframe = initDataframe_pretraining(BASE_PATH, dataframe_to_init = 'pre_training.csv')

In [None]:
# Build the model; the constructor also creates its loss and its optimizer.
net = Network(args)

In [None]:
# Run the pre-training for args.epochs epochs; the helper returns the network
# and the pre-training loss. The "small" pre-training loader is not passed here.
net, pretraining_loss = pretraining_loop(BASE_PATH, args, net, train_loader_pretrain, train_loader_classifier, test_loader, epochs = args.epochs)