In [1]:
! nvidia-smi

Mon Nov 28 16:42:27 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:81:00.0 Off |                  N/A |
| 35%   31C    P0    50W / 250W |      0MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# imports

In [1]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
import random
# Seed the three RNG sources used in this notebook (Python, NumPy, PyTorch)
# with the same fixed value.
random.seed(0)
np.random.seed(0)
# Last expression of the cell: the returned torch Generator is what the cell displays.
torch.manual_seed(0)

<torch._C.Generator at 0x7f12d03ac138>

In [3]:
# Hyperparameters shared by the cells below.
BATCH_SIZE = 128  # mini-batch size given to both DataLoaders and to the NT_Xent reference loss
EPOCHS = 200      # upper bound of the epoch loop in the training cell
LR = 3e-4         # learning rate; same value the Adam optimizer is created with
TEMP = 0.5        # temperature dividing the cosine similarities in the contrastive loss

# model

In [4]:
class ResNet_Block(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers and a shortcut.

    Args:
        in_chs: number of channels of the input feature map.
        out_chs: number of channels produced by both convolutions.
        strides: stride of the first convolution (and of the projection
            shortcut, when one is needed).

    The shortcut is the identity when the main branch preserves the input
    shape; otherwise it is a 1x1 conv + batch-norm projection.
    """

    def __init__(self, in_chs, out_chs, strides):
        super(ResNet_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                      stride=strides, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_chs, out_channels=out_chs,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs)
        )

        # A projection is required whenever the main branch changes the tensor
        # shape: either the channel count or (via the stride) the spatial size.
        # Checking only the channels would make `x_ + out` fail for a strided
        # block that keeps the channel count.
        if in_chs != out_chs or strides != 1:
            self.id_mapping = nn.Sequential(
                nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                          stride=strides, padding=0, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_chs))
        else:
            self.id_mapping = None
        self.final_activation = nn.ReLU(True)

    def forward(self, x):
        """Return ReLU(shortcut(x) + conv2(conv1(x)))."""
        out = self.conv1(x)
        out = self.conv2(out)
        if self.id_mapping is not None:
            x_ = self.id_mapping(x)
        else:
            x_ = x
        return self.final_activation(x_ + out)

class ResNet20Encoder(nn.Module):
    """Convolutional encoder: stem conv, stacked residual stages, global average pool.

    Args:
        num_layers: nominal network depth; each stage gets
            ``(num_layers - 2) // 6`` residual blocks.
        num_stem_conv: output channels of the stem convolution.
        config: output channel count of each stage, in order.
    """

    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64)):
        super().__init__()
        self.num_layers = num_layers

        # Stem: 3x3 conv on 3-channel input, followed by BN and ReLU.
        stem_layers = [
            nn.Conv2d(in_channels=3, out_channels=num_stem_conv,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(num_stem_conv),
            nn.ReLU(True),
        ]
        self.head_conv = nn.Sequential(*stem_layers)

        blocks_per_stage = (num_layers - 2) // 6
        blocks = []
        prev_channels = num_stem_conv
        for stage_idx, stage_channels in enumerate(config):
            for block_idx in range(blocks_per_stage):
                # Only the first block of every stage except the first one downsamples.
                downsample = block_idx == 0 and stage_idx != 0
                stride = 2 if downsample else 1
                blocks.append(ResNet_Block(prev_channels, stage_channels, stride))
                prev_channels = stage_channels
        self.body_op = nn.Sequential(*blocks)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Note: this layer is created (and therefore counted as parameters)
        # but is not called by forward().
        self.final_fc = nn.Linear(config[-1], 10)

    def forward(self, x):
        """Return the globally average-pooled feature map of ``x``."""
        stem_out = self.head_conv(x)
        body_out = self.body_op(stem_out)
        return self.avg_pool(body_out)

    
class SimCLR(nn.Module):
    """Encoder followed by a two-layer projection head, applied to two views.

    Args:
        num_layers, num_stem_conv, config: forwarded to ``ResNet20Encoder``.
        projection_dim: size of the projected embedding returned per view.
    """

    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64), projection_dim=20):
        super(SimCLR, self).__init__()
        self.encoder = ResNet20Encoder(num_layers=num_layers, num_stem_conv=num_stem_conv, config=config)
        self.linear1 = nn.Linear(config[-1], config[-1], bias=False)
        self.linear2 = nn.Linear(config[-1], projection_dim, bias=False)

    def _project(self, images):
        """Encode one batch of views and pass it through the projection head."""
        # flatten(start_dim=1) keeps the batch dimension. A bare squeeze()
        # would also drop it for a batch of one (or drop a channel dimension
        # of size one), giving the wrong output rank.
        features = torch.flatten(self.encoder(images), start_dim=1)
        hidden = torch.relu(self.linear1(features))
        return self.linear2(hidden)

    def forward(self, aug1, aug2):
        """Return the projected embeddings of both augmented views.

        Returns:
            Tuple ``(aug1_out, aug2_out)``, each of shape
            ``(batch, projection_dim)``.
        """
        # Views are processed one after the other, aug1 first.
        aug1_out = self._project(aug1)
        aug2_out = self._project(aug2)
        return aug1_out, aug2_out

# dataset

In [6]:
class SimCLRTransform:
    """Callable that turns one image into two independently augmented views."""

    def __init__(self):
        jitter = transforms.ColorJitter(
            brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
        )
        augmentations = [
            transforms.RandomResizedCrop(size=(32, 32)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]
        self.simclr_transform = transforms.Compose(augmentations)

    def __call__(self, x):
        # The same random pipeline is applied twice to the same input.
        first_view = self.simclr_transform(x)
        second_view = self.simclr_transform(x)
        return first_view, second_view


# Training samples become a pair of augmented views; test samples are only
# resized and converted to tensors.
train_transform = SimCLRTransform()
test_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])

In [7]:

# Training split: every sample is transformed into a pair of augmented views.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
trainloader = DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)

# Test split: deterministic transform, fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)
testloader = DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


# loss

In [65]:
def nt_xent(aug1_out, aug2_out, temperature=1.0):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Row ``i`` of ``aug1_out`` and row ``i`` of ``aug2_out`` form a positive
    pair; every other row of the concatenated batch acts as a negative.

    Args:
        aug1_out: tensor of shape ``(batch, dim)`` with the first view's embeddings.
        aug2_out: tensor of shape ``(batch, dim)`` with the second view's embeddings.
        temperature: positive scalar dividing the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2 * batch`` anchors.

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = aug1_out.shape[0]
    if batch_size == 0:
        raise ValueError("nt_xent requires a non-empty batch")
    num_anchors = 2 * batch_size

    # Cosine similarity of every pair = dot product of L2-normalised rows.
    augs = torch.nn.functional.normalize(torch.cat((aug1_out, aug2_out), dim=0), dim=1)
    logits = torch.matmul(augs, augs.T) / temperature

    # Exclude each anchor's similarity with itself from the softmax
    # denominator by sending it to -inf (exp(-inf) == 0).
    self_mask = torch.eye(num_anchors, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive of anchor i is the other view of the same sample:
    # i + batch_size for the first half, i - batch_size for the second half.
    positive_idx = (torch.arange(num_anchors, device=logits.device) + batch_size) % num_anchors

    # cross_entropy computes -log(exp(pos) / sum_j exp(logit_j)) through a
    # log-sum-exp, so small temperatures no longer overflow exp() to inf/NaN.
    return torch.nn.functional.cross_entropy(logits, positive_idx)

In [66]:
class NT_Xent(nn.Module):
    """Reference NT-Xent loss module.

    Args:
        batch_size: expected per-process batch size; used to pre-build the
            negative-sample mask.
        temperature: scalar dividing the cosine similarities.
        world_size: number of processes whose embeddings are gathered; with
            ``world_size == 1`` no gathering takes place.
    """

    def __init__(self, batch_size, temperature, world_size):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.world_size = world_size

        self.mask = self.mask_correlated_samples(batch_size, world_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, world_size):
        """Return an ``(N, N)`` boolean mask that is True only for negative pairs.

        ``N = 2 * batch_size * world_size``. The diagonal (self-similarity)
        and the two off-diagonals pairing sample ``i`` with ``i + N/2``
        (the positives) are False.
        """
        half = batch_size * world_size
        N = 2 * half
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(half)
        mask[idx, idx + half] = False
        mask[idx + half, idx] = False
        return mask

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss for paired embeddings ``z_i`` / ``z_j``.

        The sizes are derived from the actual input, so a batch whose size
        differs from the constructor's ``batch_size`` (e.g. the last,
        partial batch of an epoch) is handled as well.

        Raises:
            ValueError: if the batch is empty.
        """
        batch_size = z_i.shape[0]
        if batch_size == 0:
            raise ValueError("NT_Xent requires a non-empty batch")
        half = batch_size * self.world_size
        N = 2 * half

        z = torch.cat((z_i, z_j), dim=0)
        if self.world_size > 1:
            # NOTE(review): GatherLayer is not defined in the visible part of
            # this notebook; this branch needs it to be provided elsewhere.
            z = torch.cat(GatherLayer.apply(z), dim=0)

        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Positives sit on the diagonals offset by +/- half.
        sim_i_j = torch.diag(sim, half)
        sim_j_i = torch.diag(sim, -half)

        # Reuse the pre-built mask when the size matches, otherwise build one
        # for the current batch.
        mask = self.mask
        if mask.shape[0] != N:
            mask = self.mask_correlated_samples(batch_size, self.world_size)
        mask = mask.to(sim.device)

        # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[mask].reshape(N, -1)

        # The positive is placed in column 0, so the target class is always 0.
        labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
    
# Single-process (world_size=1) reference criterion; the training cell prints
# its value next to nt_xent's for the same batch.
official_nt_xent = NT_Xent(BATCH_SIZE, TEMP, 1)

In [5]:
# Build the model on the GPU and count its trainable parameters.
simclr_model = SimCLR(projection_dim=64).cuda()

# Sum of element counts over all parameters that require gradients.
pytorch_total_params = sum(p.numel() for p in simclr_model.parameters() if p.requires_grad)
# Trailing expression: displayed as the cell output.
pytorch_total_params

280666

# loop

In [67]:
# Sanity-check run: trains on the first few batches of one epoch and prints
# the hand-written nt_xent loss next to the reference NT_Xent module so the
# two implementations can be compared.
simclr_model = SimCLR(projection_dim=64).cuda()
# Use the shared LR constant instead of repeating the literal, so the
# optimizer and the checkpoint filename below cannot drift apart.
optimizer = torch.optim.Adam(simclr_model.parameters(), lr=LR)

best_loss = 9999999

# Number of batches processed before the inner loop stops.
limiter = 10

for epoch_idx in range(EPOCHS):
    epoch_losses = 0.0
    simclr_model.train()

    for batch_idx, data in enumerate(tqdm(trainloader)):
        if batch_idx >= limiter:
            break

        # The training transform yields a pair of augmented views per sample;
        # the class labels are not used.
        (aug1, aug2), _ = data
        aug1 = aug1.cuda()
        aug2 = aug2.cuda()

        simclr_model.zero_grad()
        aug1_out, aug2_out = simclr_model(aug1, aug2)
        print(aug1_out.shape, aug2_out.shape)

        loss = nt_xent(aug1_out, aug2_out, temperature=TEMP)
        official_loss = official_nt_xent(aug1_out, aug2_out)
        print(loss, official_loss)

        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the tensor itself would keep
        # every batch's autograd graph alive for the whole epoch.
        epoch_losses += loss.item()

    # Only a single (truncated) epoch is run for this check.
    break
#     epoch_losses /= len(trainloader)
#     print(epoch_losses)

#     if epoch_losses < best_loss:
#         best_loss = epoch_losses
#         torch.save(simclr_model.state_dict(), f'simclr_model_{EPOCHS}_{BATCH_SIZE}_{TEMP}_{LR}.pth')
    

  0%|          | 0/391 [00:00<?, ?it/s]

torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.5249, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.5249, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.4723, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.4723, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.4356, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.4356, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.4245, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.4245, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.3197, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.3197, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.4217, device='cuda:0', grad_fn=<DivBackward0>) tensor(5.4217, device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([128, 64]) torch.Size([128, 64])
tensor(5.3408, devic