# Layer-wise self-supervised training of a multi-layer MLP on MNIST (SPSA, no backpropagation)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torchvision
import optuna
import os
import argparse
from tqdm import tqdm

from utils_mlp_spsa import*

# Hyperparameters of the simulation

In [None]:
# Command-line hyperparameters. parse_known_args() ignores arguments it does not know
# (such as the ones a Jupyter kernel passes), so this cell also runs inside a notebook.
# Every "(default ...)" in the help strings below matches the actual `default=` value.
parser = argparse.ArgumentParser(description='Train a MLP with layer-wise SSL on MNIST - no BP but SPSA')
# general params
parser.add_argument(
    '--device',
    type=int,
    default=2,
    help='Index of the CUDA device to use; a negative value selects the CPU (default=2)')
parser.add_argument(
    '--dataset',
    type=str,
    default='mnist',
    help='Dataset used for training (default=mnist; no other dataset is supported by this notebook)')
# Architecture params
parser.add_argument(
    '--nlayers',
    type=int,
    default=2,
    help='Number of layers of the main MLP (default: 2)')
parser.add_argument(
    '--nneurons',
    type=int,
    default=1000,
    help='Number of neurons per layer of the main MLP (default: 1000)')
parser.add_argument(
    '--nlayers_proj',
    type=int,
    default=3,
    help='Number of layers of each non-linear projector (default: 3)')
parser.add_argument(
    '--nneurons_proj',
    type=int,
    default=256,
    help='Number of neurons per layer of each non-linear projector (default: 256)')
# Optimization params
parser.add_argument(
    '--epochs',
    type=int,
    default=200,
    help='Number of epochs to train each layer (default: 200)')
parser.add_argument(
    '--epochs_classifier',
    type=int,
    default=20,
    help='Number of epochs to train a linear classifier on top of each layer (default: 20)')
parser.add_argument(
    '--batchSize_pretrain',
    type=int,
    default=128,
    help='Batch size for pre-training (default=128)')
parser.add_argument(
    '--n_average',
    type=int,
    default=10,
    help='Number of sub-mini-batches in a mini-batch (default=10)')
parser.add_argument(
    '--batchSize_classifier',
    type=int,
    default=64,
    help='Batch size for training the linear classifier (default=64)')
parser.add_argument(
    '--test_batchSize',
    type=int,
    default=512,
    help='Testing batch size (default=512)')
parser.add_argument(
    '--lr',
    type=float,
    default=1e-4,
    help='Learning rate of the Adam optimizer for pre-training (default=1e-4)')
parser.add_argument(
    '--lr_classifier',
    type=float,
    default=1e-3,
    help='Learning rate of the Adam optimizer for training the linear classifier (default=1e-3)')
parser.add_argument(
    '--initial_step',
    type=float,
    default=1e-2,
    help='Initial step of the random perturbation (default=1e-2)')
parser.add_argument(
    '--gamma_perturbation',
    type=float,
    default=0.95,
    help='Decay rate for the size of the random perturbation applied at each inference (default=0.95)')
parser.add_argument(
    '--gamma_optimizer',
    type=float,
    default=0.9,
    help='Decay rate for the learning rate for pre-training (default=0.9)')
# SSL settings
parser.add_argument(
    '--nviews',
    type=int,
    default=2,
    help='Number of views for computing the SSL objective (default=2)')
parser.add_argument(
    '--deg',
    type=int,
    default=15,
    help='Maximum angle for the RandomRotation transform applied to the input image (default=15)')
parser.add_argument(
    '--pad',
    type=int,
    default=2,
    help='Padding applied before RandomCrop (default=2)')
parser.add_argument(
    '--contrast',
    type=float,
    default=0.5,
    help='Contrast value for ColorJittering (default=0.5)')
parser.add_argument(
    '--hue',
    type=float,
    default=0.5,
    help='Hue value for ColorJittering (default=0.5)')
parser.add_argument(
    '--scaleaffine',
    nargs='+',
    type=float,
    default=[0.7, 1.3],
    help='Scale parameters for the RandomAffine transform applied (default=(0.7, 1.3))')
# SSL objective params
parser.add_argument(
    '--lambda_bt',
    type=float,
    default=5e-3,
    help='Weight of the off-diagonal term of the Barlow Twins loss (default=5e-3)')

args, _ = parser.parse_known_args()

# Dataset creation

In [None]:
# Data and target transformations
class ReshapeTransform:
    '''
    Callable transform that reshapes a tensor to a fixed shape,
    e.g. ReshapeTransform((-1,)) flattens a (1, 28, 28) image into 784 values.
    '''

    def __init__(self, new_size):
        # target shape, handed unchanged to torch.reshape
        self.new_size = new_size

    def __call__(self, img):
        target_shape = self.new_size
        return torch.reshape(img, target_shape)
        
        
class ReshapeTransformTarget:
    '''
    Target transform: integer class label -> one-hot float vector of length number_classes.
    '''

    def __init__(self, number_classes):
        self.number_classes = number_classes

    def __call__(self, target):
        # (1, 1) index tensor so scatter_ writes a single 1 in row 0
        index = torch.tensor(target)[None, None]
        onehot = torch.zeros((1, self.number_classes))
        onehot.scatter_(1, index, 1)
        return onehot[0]

    
class ContrastiveTransformations(object):
    '''
    Turn one input into n_average groups of n_views independently augmented copies.

    Calling the object returns a nested list `out` with len(out) == n_average and
    len(out[a]) == n_views; every entry is a separate call to base_transforms(x),
    made group by group, view by view.
    '''
    def __init__(self, base_transforms, n_views=2, n_average = 1):
        self.base_transforms = base_transforms  # random augmentation pipeline applied to every copy
        self.n_views = n_views  # number of differently augmented copies per group
        self.n_average = n_average  # number of groups (sub-mini-batches)

    def __call__(self, x):
        groups = []
        for _ in range(self.n_average):
            groups.append([self.base_transforms(x) for _ in range(self.n_views)])
        return groups

In [None]:
# Dataset creation
def get_dataloader(args):
    '''
    Build the MNIST dataloaders for the current hyperparameters.

    Returns a 4-tuple:
      - train_loader: MNIST train split; each sample becomes args.n_average groups of
        args.nviews randomly augmented views; batch size args.batchSize_pretrain, shuffled.
      - train_loader_small: same augmentation, but built on the MNIST *test* split.
      - train_loader_classifier: MNIST train split without augmentation;
        batch size args.batchSize_classifier, shuffled.
      - test_loader: MNIST test split without augmentation; batch size args.test_batchSize,
        not shuffled.
    Images are flattened to 784-vectors and labels are one-hot encoded over 10 classes.

    NOTE(review): train_loader_small pre-trains on the same split that test_loader uses for
    evaluation; confirm this is intended before using it for reported results.
    '''
    def mnist(train, transform):
        # all four datasets share root, download flag and one-hot target encoding
        return torchvision.datasets.MNIST(root='./data', train=train, download=True,
                                          transform=transform,
                                          target_transform=ReshapeTransformTarget(10))

    # random data augmentations for the pre-training stage
    # NOTE(review): MNIST images are single-channel; check whether the hue jitter has any effect.
    augment = transforms.Compose([
        torchvision.transforms.RandomRotation(degrees=args.deg, fill=0),
        torchvision.transforms.RandomCrop((28, 28), padding=args.pad),
        torchvision.transforms.RandomAffine(degrees=(0, 0), translate=(0.0, 0.0),
                                            scale=(args.scaleaffine[0], args.scaleaffine[1])),
        torchvision.transforms.ColorJitter(brightness=0, contrast=args.contrast,
                                           saturation=0, hue=args.hue),
        torchvision.transforms.ToTensor(),
        ReshapeTransform((-1,)),
    ])

    # deterministic transform used both to train and to test the linear classifier
    plain = transforms.Compose([torchvision.transforms.ToTensor(), ReshapeTransform((-1,))])

    def views():
        return ContrastiveTransformations(augment, n_views=args.nviews, n_average=args.n_average)

    train_loader = torch.utils.data.DataLoader(
        mnist(True, views()), batch_size=args.batchSize_pretrain, shuffle=True)
    train_loader_small = torch.utils.data.DataLoader(
        mnist(False, views()), batch_size=args.batchSize_pretrain, shuffle=True)
    train_loader_classifier = torch.utils.data.DataLoader(
        mnist(True, plain), batch_size=args.batchSize_classifier, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        mnist(False, plain), batch_size=args.test_batchSize, shuffle=False)

    return train_loader, train_loader_small, train_loader_classifier, test_loader
 

In [None]:
# Build every dataloader once, unpacked in the order returned by get_dataloader
(train_loader_pretrain,
 train_loader_pretrain_small,
 train_loader_classifier,
 test_loader) = get_dataloader(args)

In [None]:
# Visualize the random augmentations applied to MNIST.
# A pre-training batch `data` is a nested list: data[a][v] is the batch of view `v` of
# sub-mini-batch `a` (args.n_average groups of args.nviews views), each image flattened to 784.
# Below: the two views of the first sample, then their pixel-wise difference.
data, target = next(iter(train_loader_pretrain))
print(len(data))  # number of sub-mini-batches (args.n_average)

view_a = data[0][0][0].view(28, 28)
view_b = data[0][1][0].view(28, 28)
for image in (view_a, view_b, view_a - view_b):
    plt.figure()
    plt.imshow(image, cmap = "gray")

plt.show()

# SSL objective (Barlow Twins)

In [None]:
def off_diagonal(x):
    '''
    Return the off-diagonal entries of the square matrix x as a 1-D tensor,
    in row-major order.
    '''
    n, m = x.shape
    assert n == m
    keep = ~torch.eye(n, dtype=torch.bool, device=x.device)
    return x[keep]

# Barlow Twins loss; the initial version is from the Meta GitHub repo.
# Difference from the reference: off-diagonal terms are penalised with c_ij**4 instead of c_ij**2
# (deliberate experiment of this notebook, originally marked "test avec le carré" = "test with the square").

class BarlowTwinsLoss(torch.nn.Module):
    '''
    Barlow Twins objective between two batches of embeddings.

    Both inputs are standardised along the batch dimension, their D x D
    cross-correlation matrix c is computed, and the loss is
        sum_i (c_ii - 1)**2 + lambda_param * sum_{i != j} c_ij**4
    '''

    def __init__(self, device, lambda_param=5e-3):
        super().__init__()
        self.lambda_param = lambda_param  # weight of the off-diagonal (redundancy) term
        self.device = device  # device on which the identity matrix is created
        self.epsilon = 1e-6  # added to the std so constant features in small batches do not give NaN

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        '''
        Args:
            z_a, z_b: (N, D) embeddings of two views of the same N inputs.
        Returns:
            Scalar loss tensor.
        '''
        # normalise along the batch dimension ("batch norm" without affine parameters)
        z_a_norm = (z_a - z_a.mean(0)) / (z_a.std(0) + self.epsilon)  # N x D
        z_b_norm = (z_b - z_b.mean(0)) / (z_b.std(0) + self.epsilon)  # N x D

        N = z_a.size(0)
        D = z_a.size(1)

        # cross-correlation matrix
        c = torch.mm(z_a_norm.T, z_b_norm) / N  # D x D

        c_diff = (c - torch.eye(D, device=self.device)).pow(2)  # D x D

        # Build the off-diagonal mask once, on the same device as c_diff (it used to be
        # rebuilt four times on the CPU). Off-diagonal entries become lambda * c_ij**4.
        off_diag = ~torch.eye(D, dtype=torch.bool, device=c_diff.device)
        c_diff[off_diag] = c_diff[off_diag].pow(2) * self.lambda_param

        return c_diff.sum()

# Model definition
- MLP with ReLU activations (bias-free linear layers)
- each layer is trained with its own non-linear projector (an MLP with ReLU between its layers and no ReLU after the last one)

In [None]:
#Network class
class Network(nn.Module):
    '''
    MLP trained layer by layer with a self-supervised (Barlow Twins) objective, using SPSA
    gradient estimates instead of backpropagation.

    - self.layers: bias-free nn.Linear layers (784 -> nneurons -> ... -> nneurons), each
      followed by ReLU in forward().
    - self.projs[k] (also registered as attribute projs{k+1}): bias-free non-linear projector
      of main layer k, used only to compute that layer's SSL loss.
    - self.optimizer / self.scheduler: Adam over all parameters and an exponential LR decay.
      apply_spsa() writes the SPSA estimate into weight.grad so the optimizer can apply it.
    '''

    def __init__(self, args):

        super(Network, self).__init__()

        self.n_views = args.nviews

        self.n_layers = args.nlayers
        self.n_neurons = args.nneurons

        self.n_layers_proj = args.nlayers_proj
        self.n_neurons_proj = args.nneurons_proj

        self.initial_step = args.initial_step
        self.step = args.initial_step  # decayed by update_perturbation_step(): step <- step * gamma_perturbation
        self.gamma_perturbation = args.gamma_perturbation

        self.layers = [nn.Linear(784, self.n_neurons, bias = False)]
        self.layers += [nn.Linear(self.n_neurons, self.n_neurons, bias = False) for k in range(self.n_layers-1)]
        self.layers = nn.Sequential(*self.layers)

        # One projector per main layer. At least four are always built and registered as
        # projs1..projs4, so parameter names, state_dict keys and the random initialisation
        # match the fixed-four version. One more projector is added per extra layer when
        # nlayers > 4 (previously forward() raised an IndexError for train_layer >= 4).
        self.projs = []
        for k in range(max(4, self.n_layers)):
            projector = self._build_projector()
            setattr(self, 'projs' + str(k + 1), projector)
            self.projs.append(projector)

        self.f = nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.5)  # NOTE(review): not used by forward() or forward_simple()

        # put the model on the requested GPU if one is available, otherwise on the CPU
        if args.device >= 0 and torch.cuda.is_available():
            device = torch.device(args.device)
            # NOTE(review): this boolean shadows the nn.Module.cuda() method on this instance;
            # kept unchanged because code outside this class may read net.cuda as a flag.
            self.cuda = True
        else:
            device = torch.device("cpu")
            self.cuda = False

        self.device = device
        self.to(device)  # nn.Module.to moves the parameters in place

        # one SSL loss per main layer
        self.losses = [BarlowTwinsLoss(self.device, args.lambda_bt) for idx in range(self.n_layers)]

        self.optimizer = torch.optim.Adam(self.parameters(), lr=args.lr, betas=(0.9, 0.999))
        # `verbose` is omitted: False is already its default and the argument is deprecated in PyTorch
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=args.gamma_optimizer, last_epoch=-1)

    def _build_projector(self):
        '''
        Build one projector: Linear(n_neurons -> n_neurons_proj) followed by
        (n_layers_proj - 1) Linear(n_neurons_proj -> n_neurons_proj), all bias-free.
        The ReLUs between its layers are applied in forward(), not stored here.
        '''
        projector = [nn.Linear(self.n_neurons, self.n_neurons_proj, bias = False)]
        projector += [nn.Linear(self.n_neurons_proj, self.n_neurons_proj, bias = False) for _ in range(self.n_layers_proj - 1)]
        return nn.Sequential(*projector)

    def forward(self, x, train_layer, perturbation):
        '''
        Compute the SSL loss of layer `train_layer` with its weights shifted by `perturbation`.

        - Layers before `train_layer` run with their current weights (ReLU(fc(x))). Their
          output is detached, so no computational graph spans more than one layer.
        - Layer `train_layer` computes ReLU(x @ (W + perturbation)^T). Its projector then runs
          with a ReLU between projector layers and none after the last one.
        - The projector output is split into n_views views. The Barlow Twins loss between
          views 0 and 1 is returned.
        - Layers after `train_layer` are not evaluated.

        x must be laid out as produced by single_batch() (all views stacked along dim 0),
        because multi_views_proj() splits dim 0 into n_views chunks.
        Returns the integer 0 if train_layer is not the index of any layer.
        '''
        loss = 0
        for idx, fc in enumerate(self.layers):
            if idx != train_layer:
                # layer not being trained: plain forward pass, then cut the graph
                x = self.f(fc(x)).detach()
                continue

            # layer being trained: perturbed MVM + non-linearity
            x = self.f(torch.matmul(x, (fc.weight + perturbation).t()))

            # non-linear projector: ReLU after every layer except the last
            y = x
            for proj in self.projs[idx][:-1]:
                y = self.f(proj(y))
            y = self.projs[idx][-1](y)

            y = self.multi_views_proj(y)
            loss += self.losses[idx](y[0], y[1])
            break  # deeper layers are not needed for this layer's loss

        return loss

    def forward_simple(self, x):
        '''
        Plain forward pass: no perturbation, no loss.
        Returns the list of post-ReLU activations of every layer. Each activation is
        detached before being fed to the next layer.
        '''
        states = []
        for fc in self.layers:
            x = self.f(fc(x))
            states.append(x)
            x = x.detach()

        return states

    def single_batch(self, x):
        '''
        Merge a list of per-view batches into one big batch of shape (n_views * batch_size, -1),
        views stacked one after the other along dim 0.
        '''
        x = torch.stack(x)  # new first dimension: number of views of the same input image
        x = x.view(x.size()[0]*x.size()[1], -1)  # first dim becomes n_views*batch_size

        return x

    def multi_views(self, x):
        '''
        Inverse of single_batch for main-layer activations: view x as (n_views, batch_size, n_neurons).
        '''
        return x.view(self.n_views, -1, self.n_neurons)

    def multi_views_proj(self, x):
        '''
        Same as multi_views for projector outputs: view x as (n_views, batch_size, n_neurons_proj).
        '''
        return x.view(self.n_views, -1, self.n_neurons_proj)

    def reset_perturbation_step(self):
        '''
        Reset the perturbation step to initial_step.
        '''
        self.step = self.initial_step

        return 0

    def update_perturbation_step(self):
        '''
        Do one step of exponential decay of the perturbation step.
        '''
        self.step *= self.gamma_perturbation

        return 0

    def generate_perturbation(self, layer = 0):
        '''
        Generate a random perturbation matrix for the weights of `layer`.

        Every entry is +self.step or -self.step. The shape is that of
        self.layers[layer].weight. It is sampled with torch.randint on the CPU and then
        moved to self.device.
        '''
        perturb = torch.randint(low=0, high = 2, size = self.layers[layer].weight.size())  # 0/1 (high is exclusive)
        perturb = self.step * (2*(perturb-0.5))  # map to -1/+1 and scale by the current step

        return perturb.to(self.device)

    def compute_spsa(self, perturbation, pos_obj, neg_obj, layer = 0):
        '''
        Estimate the gradient of the objective w.r.t. the weights of `layer` with SPSA.

        Returns (pos_obj - neg_obj) / perturbation element-wise, as a tensor with the shape
        of the weight matrix. pos_obj and neg_obj are the objective values at
        theta + perturbation and theta - perturbation.

        NOTE(review): the standard central-difference SPSA estimate divides by
        2 * perturbation. This returns twice that value unless the caller compensates.
        Adam largely absorbs a constant scale, but confirm this is intended.
        '''
        grad = (pos_obj-neg_obj)*(torch.ones(perturbation.size()).to(self.device))  # numerator broadcast to the weight shape
        grad /= perturbation

        assert grad.size() == self.layers[layer].weight.size(), "the size of the gradient computed does not match the size of the weight matrix"

        return grad

    def apply_spsa(self, grads, layer = 0):
        '''
        Store the SPSA gradient estimate in the .grad of the weights of `layer`, so that
        self.optimizer.step() (with its scheduled learning rate) can apply it.

        optimizer.step() moves the weights against .grad. So `grads` must be the un-negated
        estimate of d(loss)/dW, for example the output of compute_spsa().
        '''
        self.layers[layer].weight.grad = grads

        return 0

# Generate the simulation environment and data folders

In [None]:
# Output path of this run, built by createPath (presumably from the utils_mlp_spsa star import)
BASE_PATH = createPath(
    archi="MLP-SPSA-BarlowTwins",
    dataset="MNIST",
)

In [None]:
# Save the run's hyperparameters (args) under BASE_PATH. saveHyperparameters presumably comes
# from the utils_mlp_spsa star import; the exact output format is defined there — verify.
saveHyperparameters(args, BASE_PATH)

In [None]:
# Initialise the pre-training log; dataframe_to_init names the CSV file (presumably created under BASE_PATH)
dataframe = initDataframe_pretraining(
    BASE_PATH,
    dataframe_to_init='pre_training.csv',
)

In [None]:
# Instantiate the model. Network.__init__ also moves it to the selected device (args.device)
# and creates its Adam optimizer and exponential learning-rate scheduler.
net = Network(args)

In [None]:
# Layer-wise pre-training. The number of epochs comes from --epochs, so the run matches the args
# passed to saveHyperparameters(args, BASE_PATH) (this call previously hard-coded 800 and ignored the flag).
net, pretraining_loss = pretraining_loop(BASE_PATH, args, net, train_loader_pretrain, train_loader_classifier, test_loader, epochs = args.epochs)