reference: https://github.com/Spijkervet/SimCLR

### Import dependencies

In [None]:
import torchvision.transforms as T
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
!pip install wandb
import wandb

### Define device and label names for CIFAR10

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Class names, looked up by label index when plotting a sample.
label_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

## Phase 1: Pre-training

### Define transforms for training and testing

In [None]:
class TransformsSimCLR:
    """Augmentation pipelines for contrastive pre-training and for evaluation.

    Calling an instance on an image applies the stochastic training pipeline
    twice and returns the two resulting views as a tuple.

    Attributes:
        train_transform: random resized crop, horizontal flip, colour jitter
            (applied with probability 0.8), random grayscale (probability 0.2),
            then conversion to a tensor.
        test_transform: resize followed by conversion to a tensor.
    """

    def __init__(self, size):
        """Build both pipelines for images of the given target ``size``."""
        tf = torchvision.transforms

        # Jitter strength; scales all four ColorJitter arguments.
        strength = 0.5
        jitter = tf.ColorJitter(
            brightness=0.8 * strength,
            contrast=0.8 * strength,
            saturation=0.8 * strength,
            hue=0.2 * strength,
        )

        augmentations = [
            tf.RandomResizedCrop(size=size, scale=(0.5, 1)),
            tf.RandomHorizontalFlip(),  # flips with the default probability of 0.5
            tf.RandomApply([jitter], p=0.8),
            tf.RandomGrayscale(p=0.2),
            tf.ToTensor(),
        ]
        self.train_transform = tf.Compose(augmentations)

        self.test_transform = tf.Compose([tf.Resize(size=size), tf.ToTensor()])

    def __call__(self, x):
        """Return two independently augmented views of ``x``."""
        view_a = self.train_transform(x)
        view_b = self.train_transform(x)
        return view_a, view_b

### Define dataset and loader

In [None]:
# Directory the CIFAR10 files are downloaded into and read from.
dataset_dir = './dataset'
# Target size handed to the augmentation pipeline.
img_size = 32

# Every item of this dataset is passed through TransformsSimCLR.__call__,
# so the image part of each item is a pair of augmented views.
simclr_dataset = torchvision.datasets.CIFAR10(
    root=dataset_dir,
    transform=TransformsSimCLR(size=img_size),
    download=True,
)

In [None]:
batch_size = 128
# drop_last=True keeps every batch at exactly `batch_size` samples.
simclr_dataloader = DataLoader(
    dataset=simclr_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

### Plot an example pair of augmented views

In [None]:
# Pull a single batch: element 0 holds the two augmented views, element 1 the labels.
sample = next(iter(simclr_dataloader))
(img_1, img_2), label = sample

In [None]:
# Pick one random position in the batch and show its two views side by side.
idx = torch.randint(low=0, high=batch_size, size=(1,)).item()
fig, ax = plt.subplots(1, 2, figsize=(3, 3))
for axis, view in zip(ax, (img_1, img_2)):
    # Move the first axis to the end before handing the image to imshow.
    axis.imshow(view[idx].permute(1, 2, 0))
    axis.set_xticks([])
    axis.set_yticks([])
fig.suptitle('label: {}'.format(label_names[label[idx]]))

In [None]:
# Build a ResNet-18 as the cell's last expression; the result is not stored.
# NOTE(review): presumably here to inspect the layer layout before it is wrapped in SimCLR.
torchvision.models.resnet18()

### Define SimCLR model

In [None]:
class SimCLR(nn.Module):
    """ResNet-18 encoder followed by a two-layer MLP projection head.

    Attributes:
        encoder: ResNet-18 whose final ``fc`` layer is replaced by an identity,
            so it outputs the features that originally fed ``fc``.
        n_features: input width of the removed ``fc`` layer.
        projection_layer: Linear -> GELU -> Linear, all of width ``n_features``.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18()
        # Record the feature width before discarding the classification layer.
        self.n_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.encoder = backbone

        width = self.n_features
        head = [nn.Linear(width, width), nn.GELU(), nn.Linear(width, width)]
        self.projection_layer = nn.Sequential(*head)

    def encode(self, x1, x2):
        """Run both inputs through the encoder and return ``(h1, h2)``."""
        return tuple(self.encoder(view) for view in (x1, x2))

    def project(self, h1, h2):
        """Run both feature tensors through the projection head and return ``(z1, z2)``."""
        return tuple(self.projection_layer(feat) for feat in (h1, h2))

    def forward(self, x1, x2):
        """Return ``(h1, h2, z1, z2)``: encoder features and their projections."""
        features = self.encode(x1, x2)
        projections = self.project(*features)
        return (*features, *projections)

In [None]:
# Fresh, randomly initialised SimCLR network on the selected device.
simclr_model = SimCLR().to(device)
simclr_optimizer = torch.optim.Adam(
    params=simclr_model.parameters(),
    lr=0.0003,
    weight_decay=1e-6,
)

In [None]:
# Toy walk-through of the positive / negative split used by the loss below.
B, D = 4, 10
z1 = torch.randn(B, D)
z2 = torch.randn(B, D)

# Unit-normalise each row so the matrix product yields cosine similarities.
nz_1 = F.normalize(z1, dim=1)
nz_2 = F.normalize(z2, dim=1)
similarity_matrix = nz_1 @ nz_2.T

# Diagonal entries pair row i of z1 with row i of z2; everything else is off-diagonal.
pos_mask = torch.eye(B, dtype=torch.bool)
neg_mask = ~pos_mask
positives = similarity_matrix[pos_mask].view(B, -1)
negatives = similarity_matrix[neg_mask].view(B, -1)
print(positives.shape, negatives.shape)

### Define loss function

In [None]:
def info_nce_loss(z1, z2, temperature=0.5):
    """InfoNCE loss between two batches of projections.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair; every
    other row of ``z2`` is a negative for row ``i`` of ``z1``.

    The batch size and device are taken from the inputs rather than from
    notebook-level globals, so the function works for any batch size
    (including a final, smaller batch) and on whichever device the inputs
    live on.

    Args:
        z1: tensor of shape (N, D).
        z2: tensor of shape (N, D), row-aligned with ``z1``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy of picking the positive among the
        N candidates of each row.

    Raises:
        ValueError: if ``z1`` and ``z2`` have different batch sizes.
    """
    n = z1.shape[0]
    if z2.shape[0] != n:
        raise ValueError(
            'z1 and z2 must have the same batch size, got {} and {}'.format(n, z2.shape[0])
        )

    criterion = nn.CrossEntropyLoss()

    # Cosine similarity between every row of z1 and every row of z2.
    nz_1 = F.normalize(z1, dim=1)
    nz_2 = F.normalize(z2, dim=1)
    similarity_matrix = torch.matmul(nz_1, nz_2.T)

    pos_mask = torch.eye(n, dtype=torch.bool, device=similarity_matrix.device)
    # Explicit shapes (rather than -1) keep the n == 1 case well defined.
    positives = similarity_matrix[pos_mask].view(n, 1)
    negatives = similarity_matrix[~pos_mask].view(n, n - 1)

    # The positive sits in column 0, so the target class is 0 for every row.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(n, dtype=torch.long, device=logits.device)

    return criterion(logits, labels)

### Define the SimCLR training step

In [None]:
def simclr_train(simclr_model, simclr_optimizer, sample):
    """Run one contrastive optimisation step and return the loss as a float.

    Args:
        simclr_model: model whose forward takes two image batches and returns
            ``(h1, h2, z1, z2)``.
        simclr_optimizer: optimiser over ``simclr_model``'s parameters.
        sample: one batch from the SimCLR dataloader; ``sample[0][0]`` and
            ``sample[0][1]`` are the two augmented views.

    Returns:
        The scalar loss value for this batch.
    """
    # Make the mode explicit, as supervise_train does, so the step does not
    # depend on whatever mode the model was last left in.
    simclr_model.train()

    img_1 = sample[0][0].to(device)
    img_2 = sample[0][1].to(device)

    # Only the projections feed the contrastive loss; encoder features are unused here.
    _, _, z1, z2 = simclr_model(img_1, img_2)

    loss = info_nce_loss(z1, z2)

    simclr_optimizer.zero_grad()
    loss.backward()
    simclr_optimizer.step()

    return loss.item()

### Train the SimCLR model

In [None]:
epoch = 5
wandb.init(project='simclr', entity='cotton-ahn')
# Number of batches per epoch; used to turn the running sum into a mean.
n_batches = len(simclr_dataloader)
for e in range(epoch):
    total_loss = 0.0
    for sample in tqdm(simclr_dataloader):
        loss = simclr_train(simclr_model, simclr_optimizer, sample)
        # Log the per-batch loss and fold it into the epoch average.
        wandb.log({'loss': loss})
        total_loss += loss / n_batches
    print('[EPOCH {}] loss : {:.03f}'.format(e+1, total_loss))
    # Overwrite the checkpoint after every epoch.
    torch.save(simclr_model.state_dict(), './checkpoint.pth')

## Phase 2: Supervised learning

### Define datasets and loaders

In [None]:
# Both splits use the deterministic evaluation pipeline (resize + to-tensor).
eval_transform = TransformsSimCLR(size=img_size).test_transform

train_dataset = torchvision.datasets.CIFAR10(
    dataset_dir,
    download=True,
    train=True,
    transform=eval_transform
)

test_dataset = torchvision.datasets.CIFAR10(
    dataset_dir,
    download=True,
    train=False,
    transform=eval_transform
)

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
# Evaluate on every test image: dropping the last partial batch would leave
# samples out while accuracy is still divided by the full test-set length,
# and shuffling serves no purpose at evaluation time.
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, drop_last=False)

### Define an MLP-based classifier that uses the SimCLR feature extractor

In [None]:
class MLP_Classifier(nn.Module):
    """MLP classification head on top of a frozen SimCLR encoder.

    The wrapped model's parameters are frozen and it is kept in eval mode at
    all times, including after ``.train()`` is called on this module, so only
    ``self.mlp`` is trained.

    Args:
        simclr_model: module exposing an ``encoder`` sub-module that maps an
            image batch to features.
        feat_dim: width of the encoder features fed to the MLP.
        n_classes: number of output logits.
    """

    def __init__(self, simclr_model, feat_dim=512, n_classes=10):
        super().__init__()

        self.simclr_model = simclr_model
        self.feat_dim = feat_dim

        # Freeze the backbone: no gradients and inference-mode layers.
        for p in self.simclr_model.parameters():
            p.requires_grad = False
        self.simclr_model.eval()

        self.mlp = nn.Sequential(nn.Linear(feat_dim, feat_dim),
                                 nn.GELU(),
                                 nn.Linear(feat_dim, n_classes))

    def train(self, mode=True):
        """Switch the head's mode while keeping the frozen backbone in eval mode."""
        super().train(mode)
        # super().train() recurses into the backbone; undo that immediately so
        # the backbone is never left in train mode between forward passes.
        self.simclr_model.eval()
        return self

    def forward(self, img):
        """Return class logits of shape (batch, n_classes) for an image batch."""
        B = img.shape[0]

        # Guard against the shared backbone having been switched to train
        # mode directly by code outside this module.
        if self.simclr_model.training:
            self.simclr_model.eval()
        with torch.no_grad():
            feature = self.simclr_model.encoder(img)
        return self.mlp(feature.reshape(B, -1))

In [None]:
simclr_model = SimCLR().to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# NOTE(review): the pre-training loop writes './checkpoint.pth', but this cell
# reads './checkpoint_ver1.pth' — confirm which file is intended.
simclr_model.load_state_dict(torch.load('./checkpoint_ver1.pth', map_location=device))
supervise_model = MLP_Classifier(simclr_model).to(device)
supervise_optimizer = torch.optim.Adam(supervise_model.parameters(), lr=0.0003, weight_decay=1e-6)

In [None]:
def supervise_train(model, optimizer, sample):
    """Run one supervised optimisation step.

    Args:
        model: classifier mapping an image batch to class logits.
        optimizer: optimiser over the model's trainable parameters.
        sample: ``(images, labels)`` batch.

    Returns:
        Tuple ``(loss, n_correct)``: the batch loss as a float and the number
        of correctly classified samples in the batch as an int.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    img = sample[0].to(device)
    label = sample[1].to(device)

    logit = model(img)

    optimizer.zero_grad()
    loss = criterion(logit, label)
    loss.backward()
    optimizer.step()

    # Reduce on the tensor itself; the builtin sum() would iterate the batch
    # element by element in Python.
    n_correct = (torch.argmax(logit, dim=1) == label).sum().item()

    return loss.item(), n_correct

In [None]:
def test(model, sample):
    """Count correct predictions of ``model`` on one ``(images, labels)`` batch.

    Args:
        model: classifier mapping an image batch to class logits.
        sample: ``(images, labels)`` batch.

    Returns:
        Number of samples in the batch whose arg-max logit equals the label.
    """
    model.eval()
    img = sample[0].to(device)
    label = sample[1].to(device)

    with torch.no_grad():
        logit = model(img)
    # Reduce on the tensor itself; the builtin sum() would iterate the batch
    # element by element in Python.
    n_correct = (torch.argmax(logit, dim=1) == label).sum().item()

    return n_correct

In [None]:
epoch = 100
for e in range(epoch):
    # Mean training loss over the epoch.
    total_loss = 0.0
    for sample in tqdm(train_dataloader):
        loss, n_correct = supervise_train(supervise_model, supervise_optimizer, sample)
        total_loss += loss / len(train_dataloader)

    # Accuracy over the samples actually evaluated. Dividing by the dataset
    # length instead would under-report whenever the loader drops a partial batch.
    n_test_correct = 0
    n_test_seen = 0
    for sample in tqdm(test_dataloader):
        n_correct = test(supervise_model, sample)
        n_test_correct += n_correct
        n_test_seen += len(sample[1])
    # max(..., 1) avoids a division by zero on an empty loader.
    total_n_correct = n_test_correct / max(n_test_seen, 1)

    print('[EPOCH {}] loss: {}, n_correct: {}%'.format(e+1, total_loss, total_n_correct*100))

In [None]:
# Shows the value last assigned to n_correct by the loop above: the correct
# count of the final test batch only, not the epoch total.
n_correct