In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

import torch

from torch.utils.data import  DataLoader
from torchvision import datasets, transforms

from dataset import ContrastiveDataset
from optimizer import LARS
from loss import NT_Xent
from model import BackboneModel


import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
seed = 42 
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1117856b0>

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Dataloader for self-supervised case

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

cifar_train = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
cifar_test = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

train_img_array = np.array([np.array(image) for image, _ in cifar_train])
test_img_array = np.array([np.array(image) for image, _ in cifar_test])

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_dataset = ContrastiveDataset("train", train_img_array[:40000])
val_dataset = ContrastiveDataset("val", train_img_array[40000:])
test_dataset = ContrastiveDataset("test", train_img_array)

In [7]:
ssl_batch_size = 128
num_workers = 0 # means no sub-processes, needed for debugging
train_dataloader = DataLoader(
    train_dataset, batch_size=ssl_batch_size, shuffle=True, num_workers=num_workers
)
val_dataloader = DataLoader(
    val_dataset, batch_size=ssl_batch_size, shuffle=False, num_workers=num_workers
)
test_dataloader = DataLoader(
    test_dataset, batch_size=ssl_batch_size, shuffle=False, num_workers=num_workers
)

In [8]:
model = BackboneModel().to(device)
optimizer = LARS(
    [params for params in model.parameters() if params.requires_grad],
    lr=0.2,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)
criterion = NT_Xent(batch_size=ssl_batch_size, temperature=0.5)

In [10]:
num_epochs = 1
for epoch in range(num_epochs): 

  model.train()
  training_loss = 0
  for (x_i, x_j) in train_dataloader: 
    optimizer.zero_grad()
    x_i, x_j = x_i.to(device), x_j.to(device)

    z_i = model(x_i)
    z_j = model(x_j)

    loss = criterion(z_i, z_j)
    loss.backward()

    optimizer.step()
    training_loss += loss.item()
    break # TODO: Remove this when actually training
  
  training_loss /= len(train_dataloader)
  
  model.eval()
  with torch.no_grad(): 
    validation_loss = 0
    for (x_i, x_j) in val_dataloader: 
      x_i, x_j = x_i.to(device), x_j.to(device)

      z_i = model(x_i)
      z_j = model(x_j)

      loss = criterion(z_i, z_j)
      validation_loss += loss.item()
      break # TODO: Remove this when actually training

    validation_loss /= len(val_dataloader)
    
    print(f"Epoch #{epoch}, training loss: {training_loss}, validation loss: {validation_loss}")

print("Finished")



Epoch #0, training loss: 0.01732151043681672, validation loss: 0.0688595590712149
Finished


In [11]:
# Saving the model to file
model_path = "models/simclr.pth"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    model_path,
)