In [1]:
import os
import pathlib

import numpy as np
import torch
import torchvision.transforms as transforms
from flash.core.optimizers import LARS
from simclr import SimCLR
from simclr.modules import NT_Xent
from simclr.modules.transformations import TransformsSimCLR
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from tqdm import tqdm as tqdm


In [2]:
# Configuration shared by the cells below.
image_size = 224  # passed to TransformsSimCLR(size=...)
batch_size = 128  # also feeds the LARS learning rate and the NT_Xent criterion

# Pick the directory that holds the CIFAR-10 files so the code works both as a
# standalone script and inside Jupyter.
if "__file__" in globals():
    root = pathlib.Path(__file__).parent.resolve()
else:
    # Probably running interactively. In Jupyter the working directory is
    # typically the notebook's own directory; even if it is not, it is where
    # the CIFAR data will be stored.
    root = pathlib.Path(os.getcwd())

# Augmented dataset used for contrastive training; the training function
# unpacks each sample from this dataset as a pair of views plus a label.
dataset = CIFAR10(root=root, download=True, transform=TransformsSimCLR(size=image_size))

# Seed torch's global RNG before building the loaders.
torch.manual_seed(43)

# The training loader must shuffle: with shuffle=False every epoch would see
# exactly the same batches in the same order, so the contrastive loss would
# always compare each image against the same set of negatives.
train_loader = DataLoader(dataset,
                          batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=2,
                          sampler=None)

# Same images with only ToTensor applied (no SimCLR augmentation). The files
# were already fetched by the download=True call above.
train_dataset = CIFAR10(root=root, transform=transforms.ToTensor())

# Un-augmented loader; kept in dataset order.
loader = DataLoader(train_dataset,
                    batch_size,
                    shuffle=False,
                    drop_last=True,
                    num_workers=2,
                    sampler=None)

Files already downloaded and verified


In [3]:
# Training-loop bookkeeping consumed by the training cell.
epochs = 50
global_step = 0

# ResNet-18 backbone fetched through torch.hub, without pretrained weights.
encoder = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'resnet18',
    pretrained=False,
)

# Input width of the backbone's last fully-connected layer.
n_features = encoder.fc.in_features
projection_dim = 64

model = SimCLR(encoder, projection_dim, n_features)

Using cache found in C:\Users\Shulu/.cache\torch\hub\pytorch_vision_v0.10.0


In [4]:
# Contrastive loss over batches of `batch_size` pairs (single process).
criterion = NT_Xent(batch_size, temperature=0.2, world_size=1)

# LARS optimiser; the learning rate grows with the square root of the batch size.
optimizer = LARS(
    model.parameters(),
    lr=0.075 * np.sqrt(batch_size),
    weight_decay=1e-6,
)

# Cosine learning-rate decay down to 0 over `epochs` scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    epochs,
    eta_min=0,
    last_epoch=-1,
)

In [5]:
# TensorBoard writer that the training cell uses to log loss scalars via add_scalar.
writer = SummaryWriter()

In [9]:
def train(global_step, loader, model, criterion, optimizer, writer):
    """Run one epoch of contrastive training and return the summed batch loss.

    Args:
        global_step: Step index under which the first batch's loss is logged.
            It is incremented locally once per batch; because integers are
            passed by value, the caller's own counter is NOT advanced, so the
            caller must supply the correct starting step for every epoch.
        loader: Iterable yielding ``((view_i, view_j), _)`` batches. It must
            support ``len()``, which is used only in the progress message.
        model: Module called as ``model(view_i, view_j)`` that returns four
            values; the last two are handed to ``criterion``.
        criterion: Callable mapping those two outputs to a scalar loss tensor.
        optimizer: Optimizer that updates the model's parameters.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        The sum (not the mean) of the per-batch loss values over the epoch;
        ``0`` when ``loader`` yields nothing.
    """
    # Ensure layers such as batch-norm are in training mode even if an earlier
    # cell switched the model to eval().
    model.train()

    loss_epoch = 0
    for steps, ((x_i, x_j), _) in enumerate(loader):
        optimizer.zero_grad()
        # Only the last two outputs are needed for the loss.
        _, _, z_i, z_j = model(x_i, x_j)
        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for printing, logging and the sum.
        loss_value = loss.item()

        if steps % 50 == 0:
            print(f"Step [{steps}/{len(loader)}]\t Loss: {loss_value}")

        # Per-step loss, logged under the running step index.
        writer.add_scalar("Loss/train_epoch", loss_value, global_step)
        loss_epoch += loss_value
        global_step += 1
    return loss_epoch

# Pre-train for `epochs` epochs, logging per-step and per-epoch losses.
steps_per_epoch = len(train_loader)  # number of batches train() consumes per epoch

for epoch in tqdm(range(epochs)):
    # train() takes the step counter by value and cannot advance it for us,
    # so compute each epoch's starting step here. Passing the bare
    # `global_step` every time would log every epoch's per-step losses over
    # the same step indices and overwrite the curve.
    start_step = global_step + epoch * steps_per_epoch
    loss_epoch = train(start_step, train_loader, model, criterion, optimizer, writer)
    scheduler.step()

    # Average over the loader that was actually trained on.
    writer.add_scalar("Loss/train", loss_epoch / steps_per_epoch, epoch)
    print(
        f"Epoch [{epoch}/{epochs}]\t Loss: {loss_epoch / len(train_loader)}\t"
    )

  0%|                                                                                           | 0/50 [00:00<?, ?it/s]

Step [0/390]	 Loss: 5.498638153076172


  0%|                                                                                           | 0/50 [11:25<?, ?it/s]


KeyboardInterrupt: 