In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import tqdm

In [None]:
# Data and target transformations
class ReshapeTransform:
    """Callable transform that reshapes a tensor to a fixed target shape.

    In this notebook it is used with ``(-1,)`` to flatten each image tensor into
    a vector (784 = 28 * 28 values, matching the network's input layer).
    """

    def __init__(self, new_size):
        # Target shape passed to torch.reshape (may contain a single -1).
        self.new_size = new_size

    def __call__(self, img):
        target_shape = self.new_size
        return torch.reshape(img, target_shape)
        
        
class ReshapeTransformTarget:
    """Callable target transform: integer class label -> one-hot float vector.

    Example: ``ReshapeTransformTarget(10)(3)`` returns a float tensor of shape
    ``(10,)`` that is zero everywhere except for a 1 at index 3.
    """

    def __init__(self, number_classes):
        # Length of the produced one-hot vector.
        self.number_classes = number_classes

    def __call__(self, target):
        # Label as a (1, 1) index tensor, so it can be scattered along dim 1.
        index = torch.tensor(target)[None, None]
        # (1, number_classes) row of zeros that receives the single 1.
        canvas = torch.zeros((1, self.number_classes))
        return canvas.scatter_(1, index, 1)[0]

    
class ContrastiveTransformations:
    """Produce several independently transformed views of the same input.

    Each call applies ``base_transforms`` ``n_views`` times, in order, and
    returns the results as a list. When the base transform is random, the
    views generally differ from each other.
    """

    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms  # (possibly random) transform applied once per view
        self.n_views = n_views  # number of views returned for each input

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views

# Train an MLP with backprop first, then try to fit the parameters of the VICReg objective

In [None]:
# Reproducibility: seed torch's global RNG, which weight initialisation,
# DataLoader shuffling and the random torchvision augmentations draw from.
# GPU kernels may still introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)

# True: train on augmented MNIST; False: train on the raw images.
# Later cells also use this flag to choose the epoch count and checkpoint file name.
data_aug = True

In [None]:
# Dataloaders for the supervised (backprop) training.
# NOTE: the training batch size was previously assigned to a misspelled name
# (`abatch_size`), leaving `batch_size` undefined on a fresh kernel.
batch_size = 256
batch_size_test = 512

transforms_supervised_train_data_aug = transforms.Compose([
    transforms.RandomRotation(degrees=5, fill=0),     # random rotation
    transforms.RandomCrop((28, 28), padding=2),       # random crop of the padded image (small shift)
    transforms.RandomAffine(degrees=(0, 0), translate=(0.0, 0.0), scale=(0.9, 1.1)),  # random rescaling only
    transforms.ToTensor(),
    ReshapeTransform((-1,)),                          # flatten to a vector for the MLP
])

transforms_supervised_test = transforms.Compose([
    transforms.ToTensor(),
    ReshapeTransform((-1,)),
])

# Augmentations apply to the training set only when `data_aug` is set;
# the test set always uses the deterministic transform.
transforms_supervised_train = (
    transforms_supervised_train_data_aug if data_aug else transforms_supervised_test
)

train_loader_supervised = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms_supervised_train,
                               target_transform=ReshapeTransformTarget(10)),
    batch_size=batch_size, shuffle=True)

test_loader_supervised = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(root='./data', train=False, download=True,
                               transform=transforms_supervised_test,
                               target_transform=ReshapeTransformTarget(10)),
    batch_size=batch_size_test, shuffle=False)

In [None]:
class Network(nn.Module):
    '''
    Fully connected ReLU network (MLP) that carries its own Adam optimizer.

    Architecture: ``n_layers`` bias-free linear layers, each followed by ReLU.
    The first maps ``n_inputs`` -> ``n_neurons``; the rest are ``n_neurons`` wide.
    A bias-free linear classifier producing ``n_classes`` logits sits on top.

    Parameters
    ----------
    run_gpu : int
        CUDA device index. A negative value, or no available CUDA, selects the CPU.
    n_neurons : int, default 1000
        Width of every hidden layer.
    n_layers : int, default 4
        Number of hidden layers (>= 1).
    n_inputs : int, default 784
        Size of the flattened input vector.
    n_classes : int, default 10
        Number of output logits.
    lr : float, default 1e-4
        Learning rate of the Adam optimizer stored in ``self.optimizer``.
    '''
    def __init__(self, run_gpu, n_neurons=1000, n_layers=4, n_inputs=784, n_classes=10, lr=0.0001):

        super().__init__()

        self.n_neurons = n_neurons
        self.n_layers = n_layers

        # Layer creation order (hidden layers first, then classifier) matches the
        # original code, so seeded initialisation and state_dict keys are unchanged.
        widths = [n_inputs] + [n_neurons] * n_layers
        self.layers = nn.Sequential(*[
            nn.Linear(w_in, w_out, bias=False) for w_in, w_out in zip(widths[:-1], widths[1:])
        ])

        self.classifier = nn.Linear(n_neurons, n_classes, bias=False)

        self.f = nn.ReLU()

        # The flag used to be stored as `self.cuda`, which shadowed the
        # `nn.Module.cuda()` method on the instance. It is now `use_cuda`.
        self.use_cuda = run_gpu >= 0 and torch.cuda.is_available()
        self.device = torch.device(run_gpu) if self.use_cuda else torch.device("cpu")
        self.to(self.device)

        # Build the optimizer after moving the model, as the torch.optim docs recommend.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.999))


    def forward(self, x):
        '''Return the classifier logits for a batch ``x`` of flattened inputs.'''
        for fc in self.layers:
            x = self.f(fc(x))
        return self.classifier(x)


    def forward_simple(self, x):
        '''
        Forward pass that stores the state of each layer during inference,
        discarding the linear classifier.

        Runs under ``torch.no_grad()``. Returns a list with the post-ReLU
        activation of every hidden layer, from first to last.
        '''
        states = []
        with torch.no_grad():
            for fc in self.layers:
                x = self.f(fc(x))
                states.append(x)

        return states

In [None]:
def train(net, train_loader):
    '''
    Train the network for one epoch with cross-entropy loss.

    Parameters
    ----------
    net : nn.Module
        Model exposing ``device`` and ``optimizer`` attributes (e.g. ``Network``).
    train_loader : torch.utils.data.DataLoader
        Yields ``(data, target)`` batches. ``target`` is one-hot encoded along
        dim 1 (as produced by ``ReshapeTransformTarget``).

    Returns
    -------
    net : the same model, after one epoch of optimizer updates
    error : 0-d tensor, percentage of misclassified training samples
    loss : float, mean cross-entropy per training sample
    '''
    net.train()
    criterion = nn.CrossEntropyLoss()
    n_errors, loss_sum = 0, 0.0

    for data, target in tqdm(train_loader):
        data, target = data.to(net.device), target.to(net.device)
        labels = torch.argmax(target, dim=1)  # one-hot -> class indices

        net.optimizer.zero_grad()
        y = net(data)
        loss = criterion(y, labels)
        loss.backward()
        net.optimizer.step()

        # The criterion returns the batch *mean*. Weight it by the batch size so
        # that dividing by the dataset size below gives the per-sample mean.
        loss_sum += loss.item() * data.size(0)
        n_errors += (torch.argmax(y, dim=1) != labels).sum()

    n_samples = len(train_loader.dataset)
    return net, (n_errors / n_samples) * 100, loss_sum / n_samples


def test(net, test_loader):
    '''
    Train the network for 1 epoch
    '''
    net.eval()
    criterion = nn.CrossEntropyLoss()
    error, loss_tot = 0, 0
    
    for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
        net.optimizer.zero_grad()
        data, target  = data.to(net.device), target.to(net.device)
        
        y = net(data) 
        loss = criterion(y, torch.argmax(target, dim = 1))
        loss_tot += loss.item()
        del loss
        
        error += (torch.argmax(y, dim = 1) != torch.argmax(target, dim = 1)).sum()

    return net,(error/len(test_loader.dataset))*100, loss_tot/len(test_loader.dataset)

In [None]:
# Build the model on CUDA device 0 when available; Network falls back to the CPU otherwise.
net = Network(run_gpu=0)

In [None]:
# Number of training epochs: 100 with data augmentation, 50 without.
n_epochs = 100 if data_aug else 50

In [None]:
# Per-epoch history of the error rate (%) and loss on each split.
train_error, train_loss = [], []
test_error, test_loss = [], []

for epoch in range(n_epochs):
    # Training pass: the optimizer updates the weights of `net`.
    net, err, loss = train(net, train_loader_supervised)
    train_error.append(err.item())
    train_loss.append(loss)

    # Evaluation pass on the held-out test set.
    net, err, loss = test(net, test_loader_supervised)
    test_error.append(err.item())
    test_loss.append(loss)

In [None]:
# Save the trained weights; the file name records whether data augmentation was used.
checkpoint_path = "Models/checkpoint_data_augmentations.pt" if data_aug else "Models/checkpoint.pt"
# torch.save does not create missing directories, so make sure "Models/" exists first.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(net.state_dict(), checkpoint_path)