<a href="https://colab.research.google.com/github/leot13/BarlowTwins/blob/main/Barlow_Twins_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installing required libraries


In [None]:
!pip install -U albumentations -q
!pip install wandb --upgrade -q

## Importing dependencies and dataset

In [None]:
# Log in to Weights & Biases (wandb) through its command-line tool.
# The API key requested by the prompt is available in your wandb settings.
# NOTE(review): --relogin presumably forces a fresh login even when credentials
# are already stored on this machine - confirm against the wandb CLI docs.
!wandb login --relogin

In [None]:
import wandb
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import config 
from training import train_BT, train_FT
from model import (BarlowTwins, loss_fun, BarlowTwins_FT)
from utils import (CIFARDataset, makeTransforms, makeTransforms_Fine_Tuning,
                   compute_accuracy, save_checkpoint, load_checkpoint)

In [None]:
# Download CIFAR-10 into ./data (if needed) and load both splits.
# The official train split becomes `trainset`; the test split (train=False)
# is used as the validation set. No transform is attached at this point.
trainset, valset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                 download=True, transform=None)
    for is_train in (True, False)
)

## Self-supervised learning set-up and training

In [None]:
# Build the two transform pipelines used for self-supervised training
train_transform1, train_transform2 = makeTransforms(config.IMG_HEIGHT, config.IMG_WIDTH)

# Wrap the raw CIFAR-10 splits with both transforms.
# The validation split deliberately uses the same pair of transforms as the
# training split (presumably so the same loss can be evaluated on it).
train_data = CIFARDataset(trainset, transform1=train_transform1, transform2=train_transform2)
val_data = CIFARDataset(valset, transform1=train_transform1, transform2=train_transform2)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=config.BATCH_SIZE, shuffle=True)

# Validation batches keep a fixed order: the validation loss drives the LR
# scheduler and checkpoint selection, so it should not additionally vary with
# a random batch composition from one epoch to the next.
val_loader = DataLoader(val_data, batch_size=config.BATCH_SIZE, shuffle=False)

In [None]:
# Instantiate the Barlow Twins network and place it on the configured device
model = BarlowTwins(config.IN_FEATURES, config.Z_DIM)
model = model.to(config.DEVICE)

# Adam optimizer plus a plateau scheduler that reacts to a metric which
# should decrease ('min').
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases - confirm against the installed version.
optimizer = optim.Adam(params=model.parameters(), lr=config.LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)

# Gradient scaler for CUDA automatic mixed precision
# NOTE(review): newer PyTorch versions expose this as torch.amp.GradScaler - confirm.
scaler = torch.cuda.amp.GradScaler()

# Optionally resume from a previously saved checkpoint, then make sure the
# model is (still) on the configured device
if config.LOAD_CHECKPOINT:
  load_checkpoint(
      config.CHECKPOINT_FILENAME,
      model,
      optimizer,
      lr=config.LR,
  )
  model = model.to(config.DEVICE)

# Open the wandb run and let wandb track the model.
# The entity below must be replaced by your own wandb user name.
wandb.init(
    project=config.PROJECT_NAME,
    entity="tronchonleo",
)
wandb.watch(model, loss_fun, log="all", log_freq=10)

In [None]:
# Lowest validation loss observed so far; decides when a checkpoint is written
best_val_loss = float('inf')

for epoch in range(config.NUM_EPOCHS):

  # Train and get back the training loss and the average validation loss
  loss, avg_val_loss = train_BT(train_loader, val_loader, model, optimizer, config.DEVICE, scaler, config.LAMBDA)

  # Convert both losses to Python floats exactly once per epoch instead of
  # calling .item() again for every print/log/compare below
  train_loss_value = loss.item()
  val_loss_value = avg_val_loss.item()

  # Display and log the results; the epoch index is logged as well, matching
  # what the fine-tuning loop of this notebook records
  print(f"Epoch: {epoch}, Loss: {train_loss_value} Val Loss: {val_loss_value}")
  wandb.log({"loss": train_loss_value,
             "epoch": epoch,
             "val_loss": val_loss_value,
             "lr": optimizer.param_groups[0]['lr']})

  # Let the plateau scheduler react to the validation loss
  scheduler.step(val_loss_value)

  # When the model improves, save a checkpoint and update the best validation loss
  if config.SAVE_CHECKPOINT and best_val_loss > val_loss_value:
    best_val_loss = val_loss_value
    save_checkpoint(model, optimizer, filename=config.CHECKPOINT_FILENAME)

## Evaluation set-up and training


In [None]:
# Build the transform pipelines used for the fine-tuning stage
train_transform1, train_transform2 = makeTransforms_Fine_Tuning(config.IMG_HEIGHT, config.IMG_WIDTH)

# Wrap the raw CIFAR-10 splits with the fine-tuning transforms
# (both splits receive the same pair of transforms)
train_data = CIFARDataset(trainset, transform1=train_transform1, transform2=train_transform2)
val_data = CIFARDataset(valset, transform1=train_transform1, transform2=train_transform2)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=config.BATCH_SIZE, shuffle=True)

# Evaluation data is iterated in a fixed order: shuffling brings no benefit
# when measuring validation loss/accuracy and only makes runs harder to compare.
val_loader = DataLoader(val_data, batch_size=config.BATCH_SIZE, shuffle=False)

In [None]:
# Rebuild the self-supervised BarlowTwins network and restore it from the
# checkpoint. load_checkpoint is called with an optimizer argument, so a
# dedicated one is created for the backbone as well.
barlow_twins = BarlowTwins(config.IN_FEATURES, config.Z_DIM)
bt_optimizer = optim.Adam(params=barlow_twins.parameters(), lr=config.LR)
load_checkpoint(
    config.CHECKPOINT_FILENAME,
    barlow_twins,
    bt_optimizer,
    lr=config.LR,
)

# Fine-tuning model built on top of the restored network, on the configured device
ft_model = BarlowTwins_FT(barlow_twins, config.Z_DIM, num_cat=config.NUM_CAT)
ft_model = ft_model.to(config.DEVICE)

# Training objects for the fine-tuning stage
optimizer = optim.Adam(params=ft_model.parameters(), lr=config.LR)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

# Freeze the BarlowTwins part; only the linear layer keeps gradients enabled
ft_model.bt.requires_grad_(False)
ft_model.linear.requires_grad_(True)

# Open the wandb run for fine-tuning and let wandb track the model.
# The entity below must be replaced by your own wandb user name.
# NOTE(review): a run opened earlier in this notebook may still be active
# here - confirm whether wandb.finish() should be called before this init.
wandb.init(
    project=config.PROJECT_NAME,
    entity="tronchonleo",
)
wandb.watch(ft_model, criterion, log="all", log_freq=10)

In [None]:
# Highest validation accuracy observed so far; decides when a checkpoint is written
best_accuracy = 0

for epoch in range(config.FT_NUM_EPOCHS):

  # Train the fine-tuning model; get back training loss, validation loss and
  # validation accuracy
  loss, val_loss, val_accuracy = train_FT(
      train_loader, val_loader, ft_model, optimizer, criterion,
      config.DEVICE, scaler, config.LAMBDA,
  )

  # Scalar versions of the two losses, shared by the console output and wandb
  train_loss_value = loss.item()
  val_loss_value = val_loss.item()

  # Console report
  print(f"Loss epoch {epoch}: ", train_loss_value)
  print(f"Validation Loss epoch {epoch}: ", val_loss_value)
  print(f"Validation Accuracy epoch {epoch}: ", val_accuracy)

  # wandb report
  epoch_metrics = {
      "loss": train_loss_value,
      "epoch": epoch,
      "val_loss": val_loss_value,
      "val_accuracy": val_accuracy,
  }
  wandb.log(epoch_metrics)

  # Keep the checkpoint of the most accurate fine-tuned model
  if config.SAVE_CHECKPOINT and val_accuracy > best_accuracy:
    best_accuracy = val_accuracy
    save_checkpoint(ft_model, optimizer, filename=config.CHECKPOINT_FT_FILENAME)