In [None]:
%cd '/content/drive/MyDrive/ecehw/project'

/content/drive/MyDrive/ecehw/project


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random

from resnet import ResNet18, MLP, Block

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

device: cuda


In [2]:
BATCH_SIZE = 512

trainset = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size = BATCH_SIZE, shuffle=True, )

Files already downloaded and verified


In [3]:
def get_color_distortion(s:float=0.5):
    """
    Function from the paper that create color distortion 
    s: float, the strength of color distortion, for CIFAR 10, the paper use 0.5
    """
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

In [4]:
train_transform = transforms.Compose([
            # make sure we're using PIL instead of tensor when doing other transform 
            transforms.ToPILImage(),
            #transforms.GaussianBlur(23, sigma=(0.1, 2.0)), # CIFAR 10 doesn't use gaussian blur
            transforms.RandomResizedCrop(size=32,scale=(0.08,0.1),ratio=(0.75,1.33)),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(),
            transforms.ToTensor(),
            # the normalize numbers are from previous assignment
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])

In [5]:
LR = 0.1

net_f = ResNet18(3, Block)
net_g = MLP(512)
net = nn.Sequential(net_f, net_g)
optimizer = optim.Adam(net.parameters(), lr=0.0003*(BATCH_SIZE / 256), weight_decay=1e-6)
# optimizer = optim.SGD(net.parameters(), lr=0.3*(BATCH_SIZE / 256), weight_decay=1e-6)
# optimizer = optim.SGD(net.parameters(), lr=LR, weight_decay=1e-6)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
net.to(device)
net.train()

Sequential(
  (0): ResNet18(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): Block(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (identity): Sequential()
      )
      (1): Block(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

In [6]:
def compute_loss(yhat, t):
    """
    Computing the contrastive loss based on cosine similarity
    input:
        yhat: [tensor] latent embedding features size: BATCH_SIZE * 128
        t: [float] temperature range: (0.0, 1.0)
    output:
        loss: [tensor] 1D
    """
    
    # testing code
    #from collections import defaultdict
    #cache = defaultdict(int)
    #for i in range(yhat.shape[0]):
    #    for j in range(i+1, yhat.shape[0]):
    #        val = torch.nn.functional.cosine_similarity(yhat[i], yhat[j], dim=0, eps=1e-8)
    #        cache[(i, j)] = val
    #        cache[(j, i)] = val
    
    N = yhat.shape[0]

    # Calculate the pair-wise consine similarity
    cache = torch.nn.functional.cosine_similarity(yhat.unsqueeze(0), yhat.unsqueeze(1), dim=-1, eps=1e-8)
    cache = cache / t
    
    # Delete the diagonal entries
    mask = torch.eye(N, dtype=bool)
    cache = cache[~mask].view((N, N-1))

    # Make pesudo-labels
    label = []
    for i in range(N // 2):
        label.append(int(2 *i))
        label.append(int(2 *i))
    label = torch.tensor(label)
    label = label.to(device)

    # Calculate the cross entropy loss
    loss = F.cross_entropy(cache, label)

    return loss
    

    """
    # Calculate the pair-wise consine similarity
    cache = torch.nn.functional.cosine_similarity(yhat.unsqueeze(0), yhat.unsqueeze(1), dim=-1, eps=1e-8)
    # print(cache)
    cache = torch.exp(cache / t)
    cache_sum = torch.sum(cache, dim=1)

    
    # Compute the contrastive loss
    loss = 0
    for n in range(yhat.shape[0]):
        # Get the index of positive pairs, based on 2 tensors of the same pair are adjacent in terms of index
        i = n
        if i % 2 == 0:
            j = i + 1
        else:
            j = i - 1

        # The numerator is between the positive pair
        # The denominator is between one tensor with all OTHER tensors
        numerator = cache[i, j]
        denominator = cache_sum[i] - cache[i, i]
        #denominator = 0
        #for k in range(yhat.shape[0]):
        #    if k != i:
        #        denominator += cache[i, k]
        cur_loss = (-1) * torch.log(numerator / denominator)
        # Add up the loss and take the average
        loss += (1 / (2 * yhat.shape[0])) * cur_loss
    loss = torch.tensor(loss.item())
    loss.requires_grad = True
    # loss = torch.mean(loss)
    
    return loss
    """

In [None]:
# loop through each batch in trainset 
LOSSES = []
EPOCHS = 10
OPTIM_LOSS = float('inf')

for epoch in range(EPOCHS):
    cost = 0
    for data, label in trainset:
        for ind_img in range(len(data)): # loop through each image in batch 
            par_tensor = data[ind_img]
            cur_tensor_0 = train_transform(par_tensor) # first aug
            cur_tensor_1 = train_transform(par_tensor) # second aug 
            # resize 
            cur_tensor_0, cur_tensor_1 = torch.unsqueeze(cur_tensor_0, 0), torch.unsqueeze(cur_tensor_1, 0) 
            # if this is the first image in the batch, we just concat the 2 data aug 
            if ind_img == 0:
                total_tensor = torch.cat((cur_tensor_0, cur_tensor_1), dim=0)
            # else append to the previous augmented pair in the batch 
            else:
                total_tensor = torch.cat((total_tensor, cur_tensor_0, cur_tensor_1), dim=0)
            
        total_tensor = total_tensor.to(device)
        # pass <total_tensor> into the model 
        yhat = net(total_tensor)
        # calculate loss 
        loss = compute_loss(yhat, 0.5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        cost += loss.item()
        avg_loss = cost / len(trainset)

    LOSSES.append(avg_loss)
    if avg_loss < OPTIM_LOSS:
        OPTIM_LOSS = avg_loss
        torch.save(net.state_dict(), 'simclr.pt')
        # torch.save(net_f.state_dict(), 'simclr_netf.pt')
    print(LOSSES)
        

[6.643378209094612]
[6.643378209094612, 6.498396435562445]


In [None]:
plt.plot([_ for _ in range(len(LOSSES))], LOSSES)

NameError: ignored

In [None]:
# loop through each batch in trainset 
LOSSES = []
#net_f.eval()
temp = []

for data, label in trainset:
    pass

In [None]:
# decide the learning rate, whether to implement lr decay
# linear eval (discard net_g, connect net_f with logistic regressor, freeze net_f, train the whole net)
# supervised counterpart, for simCLR epoch=90
# write the report