## Simple Self-supervised Learning

Self-supervised learning is a 2 phases learning technique:  
- first phase is destined to train the model on a task where we can implement an automatic way to labeled the dataset  
  
- second phase is intended to accomplish the real task. Thanks to the first phase, we start with pre-trained weights that might be closer to the values we'll obtained while during training. Moreover theoritically, we might need less labeled data to train our model  

To simplify, we can see the technique as a form of fine-tuning.

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchsummary import summary

import torchvision.datasets as datasets 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset

import random

import matplotlib.pyplot as plt 

## Model

In [2]:
class SelfSupervisedMNIST(nn.Module):

    def __init__(self):
        super(SelfSupervisedMNIST, self).__init__()
        self.encoder = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=1),
                                      nn.MaxPool2d(2),
                                      nn.ReLU(),
                                      nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=0),
                                      nn.MaxPool2d(2, padding=1),
                                      nn.ReLU(),
                                      nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, padding=0),
                                      nn.ReLU()
                                    )
        self.classification = nn.Sequential( nn.Linear(in_features=256, out_features=64),
                                             nn.ReLU(),
                                             nn.Linear(64, 4),
                                             nn.Softmax(dim=1)
                                           )

    def forward(self, x):

        x = self.encoder(x)
        x = x.flatten(1)
        x = self.classification(x)

        return x


class Trainer:

    def __init__(self, model, train_dataloader, test_dataloader, optimizer, loss_fn):#, metric):

        self.model            = model
        self.train_dataloader = train_dataloader
        self.test_dataloader  = test_dataloader
        self.optimizer        = optimizer
        self.loss_fn          = loss_fn
        #self.metric           = metric

    def train_step(self, device):

        epoch_loss = 0.0
        self.model.train()

        for i, (images, targets) in enumerate(self.train_dataloader, 0):
        
            #Data send to device + requires_grad=True
            images, targets = images.requires_grad_().to(device), targets.to(device)
            #Zero the gradient 
            self.optimizer.zero_grad()
            #Predictions 
            outputs = self.model(images)
            #Loss
            epoch_loss = self.loss_fn(outputs, targets)
            #Upgrade the gradients (backpropagate) and the optimizer
            epoch_loss.backward()
            self.optimizer.step()

            #self.check_layers_values()

        return epoch_loss

    def test_step(self, device):

        list_loss       = []
        nof_predictions = 0.0
        epoch_accuracy  = 0.0

        self.model.eval()

        with torch.no_grad():
            for i, (images, targets) in enumerate(self.test_dataloader, 0):

                images, targets = images.to(device), targets.to(device)
                outputs = self.model(images)#.squeeze(1)
                epoch_loss = self.loss_fn(outputs, targets)

                _, predictions = torch.max(outputs.data, 1)
                nof_predictions += targets.size(0)

                epoch_accuracy += (predictions == targets).sum().item()
                
            #Compute the accuracy over the test set
            epoch_accuracy = (100*epoch_accuracy/nof_predictions)

        return epoch_accuracy

    def train_model(self, 
                    nof_epochs, batch_size, learning_rate, 
                    file_path_save_model, save_epoch_path,
                    train_loss_name, accuracy_name, test_loss_name,
                    best_accuracy_is_maximal = False,
                    device='cuda:0'):

        print("Starting training...\n")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("The model will be running on", device, "device.\n")
        
        self.model.to(device)
        best_accuracy = 0.0

        for epoch in range(1, nof_epochs+1):

            epoch_accuracy   = 0.0
            train_epoch_loss = 0.0
            test_epoch_loss  = 0.0

            #Training
            train_epoch_loss = self.train_step(device)
            #Validation
            epoch_accuracy   = self.test_step(device)
        
            print(f'Epoch: {epoch}, {train_loss_name}: {train_epoch_loss}, {accuracy_name}: {epoch_accuracy}%')
            
            #Save model when best accuracy is beaten
            if best_accuracy_is_maximal:
                if epoch_accuracy > best_accuracy:
                    save_epoch_path = str(epoch) + "best_accuracy.pth"
                    self.save_model(self.model, save_epoch_path)
                    best_accuracy = epoch_accuracy
            else:
                if epoch_accuracy < best_accuracy:
                    save_epoch_path = str(epoch) + "best_accuracy.pth"
                    self.save_model(save_epoch_path)
                    best_accuracy = epoch_accuracy                

        # Saving the model
        print('Saving the model...\n')
        self.model = self.model.to('cpu')
        self.save_model(file_path_save_model)

        print("Training finish.\n") 

        return self.model

    def save_model(self, file_path_save_model):
        torch.save(self.model.state_dict(), file_path_save_model)

    def load_model(self, file_path_to_model, device):
        state_params = torch.load(file_path_to_model)
        self.model.load_state_dict(state_params)

    def check_layers_values(self):
        for name, param in self.model.named_parameters():
            print(name, param.grad)


## Data

In [3]:
class DataLoaderBuilderFromList:

    def __init__(self, X_train, y_train, X_test, y_test):

        if not all(list_ for list_ in [X_train, y_train, X_test, y_test]):
            raise ValueError("X or y mustn't be empty")

        self.X_train = X_train
        self.X_test  = X_test
        self.y_train = y_train
        self.y_test  = y_test


    def create_dataloaders(self, transform=None, batch_size=32, shuffle=True, type=torch.float32):

        # Convert to tensor
        X_train = torch.tensor(self.X_train, dtype=torch.float32)
        y_train = torch.tensor(self.y_train, dtype=torch.float32)

        X_test = torch.tensor(self.X_test, dtype=torch.float32)
        y_test = torch.tensor(self.y_test, dtype=torch.float32)

        # Apply transforms if present
        if transform is not None:
            X_train = self.transform(X_train)
            X_test  = self.transform(X_test)

        # Create dataloader
        train_dataset = TensorDataset(transform(self.X_train), self.y_train)
        test_dataset  = TensorDataset(transform(self.X_test), self.y_test)
        
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
        test_dataloader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

        return train_dataloader, test_dataloader


class DataLoaderBuilderFromMNIST:

    def __init__(self, dataloader):
        self.X  = dataloader.data
        self._y = dataloader.targets

    @property
    def X(self):
        return self._X

    @X.setter
    def X(self, new_X_tensor):
        self._X = new_X_tensor

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, angles):
        self._y = angles

    def random_angle(self, angles_list):
        """
            Generate a random degrees from angles list and return a tensor of size _size that contains all of them
        """
        random_idx_angles = torch.randint(low=0, high=4, size=(self.X.data.shape[0],), dtype=torch.long)
        
        return random_idx_angles

    def rotator_images(self, angles):

        batch_size = self.X.shape[0]
        rotated_images = torch.zeros_like(self.X.unsqueeze(dim=1))

        for i in range(batch_size):

            image = self.X[i]
            angle = angles[i].item()

            if image.ndimension() == 2:
                image = image.unsqueeze(0)

            rotated_images[i] = transforms.functional.rotate(image, angle)

        return rotated_images

    def create_dataloaders(self, transform=None, batch_size=32, shuffle=True, _type=torch.float32):

        self._X = torch.tensor(self._X, dtype=torch.float32)
        self._y = torch.tensor(self._y, dtype=torch.long)

        # Apply transforms if present
        if transform is not None:
            self._X = self.transform(self._X)

        # Create dataloader
        dataset = TensorDataset(self._X, self._y)
        
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

        return dataloader

## Utils

In [4]:
def displayConvFilers(model, 
                      layer_name,
                      optimizer, epoch,
                      figsize=(2,2),
                      suptitle=None, savefig=True
                      ):

    layer = model.state_dict()[layer_name]
    n_filters, n_channels, height, width = layer.shape
    total_filters = n_filters * n_channels

    fig = plt.figure(figsize=figsize)

    # Loop through each filter in the layer
    for batch_idx in range(n_filters):

        for channel_idx in range(n_channels):

            filter = layer[batch_idx][channel_idx].cpu()
            subplot_index = batch_idx * n_channels + channel_idx + 1
            
            ax = plt.subplot(n_filters, n_channels, subplot_index)
            plt.imshow(filter, cmap='gray')
            
            ax.set_yticks([])
            ax.set_xticks([])
            
            # Label the y-axis with the batch number
            if (channel_idx==0):
                ax.set_ylabel("Batch #{}".format(batch_idx+1), fontsize=20)
                
            # Label the x-axis with the channel number
            if batch_idx == (n_filters - 1):
                ax.set_xlabel("Channel #{}".format(channel_idx+1), fontsize=20)
    
    #Layout of the figure
    plt.tight_layout(rect=[0, 0.03, 1, 0.92])

    if suptitle == None:
        filename = f"{model.__class__.__name__}_{layer_name}_filters_{n_filters}_channels_{n_channels}_height_{height}_width_{width}.pdf"
        fig.suptitle(f"Visualization of layer's filters on {model.__class__.__name__} model (unpruned)\n"
                     f"(Model characteristics - optimizer: {optimizer.__class__.__name__}, learning rate: {optimizer.state_dict()['param_groups'][0]['lr']}, number of epochs: {epoch})\n\n"
                     f"Layer name: {layer_name}, Filters' size: ({height}x{width})",
                     fontsize=25)
        
    else:
        filename = f"{suptitle}.pdf"
        fig.suptitle(suptitle, fontsize=25)

    #Save figure
    if savefig and suptitle == None:
        plt.savefig(str(model.__class__.__name__) + "_" + layer_name + "_" 
                    + "batch_" + str(n_filters) + "_" 
                    + "channels_" + str(n_channels)  + "_" 
                    + "height" + str(height) + "_" 
                    + "width_" + str(width)  
                    + ".pdf"
                    )
    else:
        plt.savefig(suptitle + ".pdf")
    
    #Show figure
    plt.show()


## Main

In [8]:
def main():

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    #If model need to been trained turn train to True and load to False
    train = True
    #If model need to be loaded from a pth file, turn train to False and load to True
    load = False

    # Model
    sslModel = SelfSupervisedMNIST()
    #summary(sslModel, (1,28,28))

    #Data
    transform = transforms.Compose([transforms.CenterCrop(28),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=0, std=0.2)
                                    ])
                                    
    # Don't forget to take a pourcentage of data
    mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
    mnist_testset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)  

    # Phase 1 : we train our model on a pre-task

    #### Data
                                    
    # Train dataset
    # Applying the rotation and creating new y tensors
    train_data_loader_builder = DataLoaderBuilderFromMNIST(mnist_trainset)
    # y
    angles                = train_data_loader_builder.random_angle([0, 90, 180, 270])
    train_data_loader_builder.y = angles
    # X
    rotated_images        = train_data_loader_builder.rotator_images(angles)
    train_data_loader_builder.X = rotated_images

    train_dataloader = train_data_loader_builder.create_dataloaders(transform=None, batch_size=32, shuffle=True, _type=torch.float32)
    
    #Test dataset
    # Applying the rotation and creating new y tensors
    test_data_loader_builder = DataLoaderBuilderFromMNIST(mnist_testset)
    # y
    angles                     = test_data_loader_builder.random_angle([0, 90, 180, 270])
    test_data_loader_builder.y = angles
    # X
    rotated_images             = test_data_loader_builder.rotator_images(angles)
    test_data_loader_builder.X = rotated_images

    test_dataloader = test_data_loader_builder.create_dataloaders(transform=None, batch_size=32, shuffle=True, _type=torch.float32)
    

    ### Trainer
    model     = sslModel
    lr        = 1e-6 #1e-5 
    optimizer = torch.optim.Adam(sslModel.parameters(), lr)
    loss_fn   = nn.CrossEntropyLoss()
    trainer_phase1 = Trainer(model, train_dataloader, test_dataloader, optimizer, loss_fn)

    nof_epochs = 100
    batch_size = test_dataloader.batch_size
    file_path_save_model = "./saving/train1_trained_model.pth"
    save_epoch_path = "./saving/train1_best_accuracy.pth"    
    train_loss_name = "Cross Entropy Loss"
    accuracy_name = "accuracy"
    test_loss_name = train_loss_name

    sslModel = trainer_phase1.train_model(nof_epochs, batch_size, lr, 
                                            file_path_save_model, save_epoch_path,
                                            train_loss_name, accuracy_name, test_loss_name,
                                            best_accuracy_is_maximal = True,
                                            device='cuda:0')

    ### Plot intermediate result


    # Phase 2 : we train our model on a real task with few data 

    ### Data
    #pourcentage_of_data

    ### Trainer
    #trainer_phase2 = Trainer()

    ### Plot result


In [9]:
if __name__=="__main__":

    main()

  self._X = torch.tensor(self._X, dtype=torch.float32)
  self._y = torch.tensor(self._y, dtype=torch.long)


Starting training...

The model will be running on cuda:0 device.

Epoch: 1, Cross Entropy Loss: 1.4061832427978516, accuracy: 25.06%
Epoch: 2, Cross Entropy Loss: 1.3651574850082397, accuracy: 26.06%
Epoch: 3, Cross Entropy Loss: 1.4202924966812134, accuracy: 27.12%
Epoch: 4, Cross Entropy Loss: 1.3680226802825928, accuracy: 28.2%
Epoch: 5, Cross Entropy Loss: 1.3648046255111694, accuracy: 29.16%
Epoch: 6, Cross Entropy Loss: 1.3575501441955566, accuracy: 29.7%
Epoch: 7, Cross Entropy Loss: 1.3402092456817627, accuracy: 29.88%
Epoch: 8, Cross Entropy Loss: 1.3597891330718994, accuracy: 30.26%
Epoch: 9, Cross Entropy Loss: 1.3571747541427612, accuracy: 31.18%
Epoch: 10, Cross Entropy Loss: 1.3575960397720337, accuracy: 31.27%
Epoch: 11, Cross Entropy Loss: 1.4070438146591187, accuracy: 31.21%
Epoch: 12, Cross Entropy Loss: 1.4271094799041748, accuracy: 31.96%
Epoch: 13, Cross Entropy Loss: 1.3798645734786987, accuracy: 32.15%
Epoch: 14, Cross Entropy Loss: 1.3105170726776123, accuracy:

## Conclusion

During training, we noted that to much angles (aka classes) to predict provokes a vanishing gradient. The extrem limit is when we try to predict 360 classes.
When we test with four classes, with learning rate set to 1e-6, we achieve with 50 epochs an accuracy of 39%.