In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from dataset import ContrastiveDataset
from loss import NT_Xent
from model import ContrastiveLearningModel
from optimizer import LARS

%matplotlib inline

In [12]:
# Seed every global RNG this notebook can draw from, so that a
# Restart Kernel -> Run All starts from the same random state each time.
seed = 42
random.seed(seed)     # Python's built-in RNG (was previously left unseeded)
np.random.seed(seed)  # NumPy's legacy global RNG
# The trailing semicolon keeps IPython from echoing the returned Generator
# object, whose repr embeds a memory address.
torch.manual_seed(seed);

<torch._C.Generator at 0x148b553e22b0>

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Dataloaders for the self-supervised case

In [14]:
# Preprocessing applied to every CIFAR-10 image: conversion to a tensor,
# then per-channel normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Arguments shared by both splits; only the `train` flag differs.
cifar_kwargs = dict(root="./data", download=True, transform=transform)
cifar_train = datasets.CIFAR10(train=True, **cifar_kwargs)
cifar_test = datasets.CIFAR10(train=False, **cifar_kwargs)


def _collect_images(dataset):
    """Stack every (transformed) image of `dataset` into one NumPy array.

    The dataset is iterated sample by sample, so `transform` is applied to
    each image; the labels are discarded.
    """
    return np.array([np.array(image) for image, _label in dataset])


train_img_array = _collect_images(cifar_train)
test_img_array = _collect_images(cifar_test)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Number of leading images of the CIFAR-10 training array used for
# self-supervised training; everything after them is the validation split.
# Keeping the boundary in one place prevents the two slices from drifting apart.
num_train_images = 40000
assert 0 < num_train_images < len(train_img_array), (
    "the train/val boundary must leave both splits non-empty"
)

train_dataset = ContrastiveDataset("train", train_img_array[:num_train_images])
val_dataset = ContrastiveDataset("val", train_img_array[num_train_images:])
test_dataset = ContrastiveDataset("test", test_img_array)

In [16]:
ssl_batch_size = 200
num_workers = 0 # means no sub-processes, needed for debugging

# The NT_Xent criterion is constructed with batch_size=ssl_batch_size, and
# both the training and the validation batches are passed through it.
# drop_last=True guarantees those loaders only ever yield full batches.
# With the current split sizes (40000 / 10000) and a batch size of 200 nothing
# is dropped; the flag only matters once the sizes stop dividing evenly.
# NOTE(review): presumably NT_Xent cannot handle a smaller final batch —
# confirm against loss.py.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=ssl_batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=ssl_batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=True,
)
# The consumer of the test loader is not visible in this part of the
# notebook, so it keeps yielding every sample (including a possibly smaller
# final batch). NOTE(review): add drop_last=True here as well if it is fed to
# the same criterion.
test_dataloader = DataLoader(
    test_dataset, batch_size=ssl_batch_size, shuffle=False, num_workers=num_workers
)

In [17]:
# Build the network and move it to the selected device.
model = ContrastiveLearningModel()
model = model.to(device)

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))

# NOTE(review): the optimizer receives bare tensors rather than
# (name, tensor) pairs — confirm in optimizer.py how LARS matches the
# name patterns listed in exclude_from_weight_decay.
optimizer = LARS(
    trainable_params,
    lr=0.2,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

# Contrastive loss; it is configured with the same batch size as the loaders.
criterion = NT_Xent(batch_size=ssl_batch_size, temperature=0.5)

In [18]:
import time

# Self-supervised training: every epoch runs one optimisation pass over the
# training loader and one gradient-free pass over the validation loader, then
# reports the mean per-batch loss of each pass and the wall-clock time.
num_epochs = 6
for epoch in range(num_epochs):
    start = time.time()

    # ---- training pass ----
    model.train()
    batch_losses = []
    for x_i, x_j in train_dataloader:
        # x_i / x_j are the two inputs of each pair yielded by the loader
        x_i = x_i.to(device)
        x_j = x_j.to(device)

        optimizer.zero_grad()
        z_i, z_j = model(x_i), model(x_j)
        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    training_loss = sum(batch_losses) / len(train_dataloader)

    # ---- validation pass (eval mode, no gradients) ----
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for x_i, x_j in val_dataloader:
            x_i = x_i.to(device)
            x_j = x_j.to(device)

            z_i, z_j = model(x_i), model(x_j)
            loss = criterion(z_i, z_j)
            batch_losses.append(loss.item())
    validation_loss = sum(batch_losses) / len(val_dataloader)

    end = time.time()

    print(f"Epoch #{epoch+1}, training loss: {training_loss}, validation loss: {validation_loss}, time: {end - start:.2f}")

Epoch #1, training loss: 4.979559516906738, validation loss: 4.334469223022461, time: 187.00
Epoch #2, training loss: 4.431811802387237, validation loss: 4.177019290924072, time: 187.85
Epoch #3, training loss: 4.333452017307281, validation loss: 4.150139713287354, time: 190.48
Epoch #4, training loss: 4.296066000461578, validation loss: 4.118095951080322, time: 188.43
Epoch #5, training loss: 4.27168850183487, validation loss: 4.108700761795044, time: 187.46
Epoch #6, training loss: 4.254152755737305, validation loss: 4.105760011672974, time: 187.84


In [19]:
# Saving the model to file
model_path = "models/encoder.pth"
# torch.save does not create missing parent directories, so make sure
# "models/" exists before writing (no-op when it is already there).
os.makedirs(os.path.dirname(model_path), exist_ok=True)
# The whole encoder module object is serialised here (not just a state_dict),
# so whatever loads this file needs the encoder's class to be importable.
torch.save(model.encoder, model_path)