# Bharath Gunasekarn
# SimCLR pytorch

# This code was written with reference to https://medium.com/analytics-vidhya/understanding-simclr-a-simple-framework-for-contrastive-learning-of-visual-representations-d544a9003f3c

In [None]:
import numpy as np
import torch
from torchvision import transforms
from torchsummary import summary
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18

import os
from PIL import Image
from collections import OrderedDict

import random

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
# t-SNE reducer created with scikit-learn's default settings.
tsne = TSNE()

# Get Data
This notebook uses a subset of the CIFAR-10 dataset.

In [None]:
# Download Data
%%capture
!rm -rf ./data
!mkdir -p data
!cd data
!wget https://raw.githubusercontent.com/bharathGuna/CMPE-297-Special-Topics/main/assignment1/data/cifar.zip -P ./data/
!unzip data/cifar.zip -d data/

In [None]:
# Sorted listings make the seeded shuffles below reproducible regardless of
# the order in which the filesystem returns directory entries.
train_files = sorted(os.listdir('data/train'))
test_files = sorted(os.listdir('data/test'))

random.seed(1)

# Sampling len(...) items without replacement is a seeded shuffle.
train = random.sample(train_files, len(train_files))
# A tenth of the training files, drawn from the training set itself.
# NOTE: `eval` shadows the builtin; the name is kept because the dataset
# construction cell refers to it.
eval = random.sample(train, len(train_files) // 10)
test = random.sample(test_files, len(test_files))


def _label_of(name):
    """Return the class label encoded as the filename prefix before the first '_'."""
    return name.split('_')[0]


train_labels = [_label_of(name) for name in train]
eval_labels = [_label_of(name) for name in eval]
test_labels = [_label_of(name) for name in test]

# Distinct labels seen in the training split.
label_set = set(train_labels)

# Iteration order of a set of strings varies between interpreter runs, so the
# labels are sorted before being numbered; every label gets an index instead
# of only the first five.
label_map = {label: value for value, label in enumerate(sorted(label_set))}

label_map

# Data Augmentation

In [None]:
def get_color_distortion(s=1.0):
    """Build the random colour-distortion transform used for augmentation.

    Args:
        s: Strength multiplier applied to every colour-jitter magnitude.

    Returns:
        A ``transforms.Compose`` that applies colour jitter with probability
        0.8 and then converts to grayscale with probability 0.2.
    """
    # ColorJitter takes (brightness, contrast, saturation, hue).
    strength = 0.8 * s
    jitter = transforms.ColorJitter(strength, strength, strength, 0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])


# Image DataLoader

In [None]:
from torchvision import transforms

class ImageDataset(torch.utils.data.Dataset):
    """Dataset of image files that yields either two augmented views or one image.

    Args:
        datapath: Directory that contains the image files.
        filenames: File names (relative to ``datapath``), one per sample.
        labels: Label for each entry of ``filenames``, in the same order.
        mutation: When true, each item holds two independently augmented
            views (``'image1'``, ``'image2'``); otherwise a single randomly
            cropped view (``'image'``).
        image_size: Side length, in pixels, of the square images produced.
    """

    def __init__(self, datapath, filenames, labels, mutation, image_size=224):
        self.datapath = datapath
        self.filenames = filenames
        self.labels = labels
        self.mutation = mutation
        self.image_size = image_size

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def tensorify(self, img):
        """Convert a PIL image to a tensor normalised with mean 0.5 and std 0.5 per channel."""
        return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(
            transforms.ToTensor()(img)
            )

    def augmented_image(self, img):
        """Return a random resized crop of ``img`` followed by colour distortion."""
        return get_color_distortion(1)(
            transforms.RandomResizedCrop(self.image_size)(img)
            )

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict of tensors plus its label."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        size = self.image_size
        path = os.path.join(self.datapath, self.filenames[idx])
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the block exits.
        with Image.open(path) as raw:
            img = transforms.Resize((size, size))(raw.convert('RGB'))
        if self.mutation:
            # Two separate calls give two independently sampled augmentations.
            return {
                'image1': self.tensorify(self.augmented_image(img)),
                'image2': self.tensorify(self.augmented_image(img)),
                'label': self.labels[idx],
            }
        # The crop size matches the augmented branch (the original used 244
        # here, inconsistent with the 224 used everywhere else).
        return {
            'image': self.tensorify(transforms.RandomResizedCrop(size)(img)),
            'label': self.labels[idx],
        }


In [None]:
# Contrastive training set: each item carries two augmented views of an image.
training_dataset_mutated = ImageDataset(
    datapath='data/train', filenames=train, labels=train_labels, mutation=True
)
# Single-view dataset over the `eval` subset of the training files.
training_dataset = ImageDataset(
    datapath='data/train', filenames=eval, labels=eval_labels, mutation=False
)
# Single-view dataset over the test files.
testing_dataset = ImageDataset(
    datapath='data/test', filenames=test, labels=test_labels, mutation=False
)

In [None]:
# Shuffled loaders, each backed by two worker processes.
dataloader_training_dataset_mutated = DataLoader(
    dataset=training_dataset_mutated,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
dataloader_training_dataset = DataLoader(
    dataset=training_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
dataloader_testing_dataset = DataLoader(
    dataset=testing_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Encoder: a ResNet-18 with randomly initialised (non-pretrained) weights.
resnet = resnet18(pretrained=False)

# The stock classification layer is swapped for a three-layer MLP that maps
# the encoder features to a 25-dimensional output. The module names are part
# of the saved state dict, so they are kept as-is.
classifier = nn.Sequential()
classifier.add_module('fc1', nn.Linear(resnet.fc.in_features, 100))
classifier.add_module('added_relu1', nn.ReLU(inplace=True))
classifier.add_module('fc2', nn.Linear(100, 50))
classifier.add_module('added_relu2', nn.ReLU(inplace=True))
classifier.add_module('fc3', nn.Linear(50, 25))

resnet.fc = classifier

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# Move the model onto the selected device before training.
resnet.to(device)


In [None]:
# Default softmax temperature of the contrastive loss.
tau = 0.05

def loss_function(a, b, temperature=None):
    """Compute the NT-Xent (normalised temperature-scaled cross-entropy) loss.

    Row ``i`` of ``a`` and row ``i`` of ``b`` are the two views of the same
    sample (a positive pair); every other row of the concatenated batch acts
    as a negative.

    Args:
        a: 2-D tensor of embeddings, one row per sample.
        b: 2-D tensor with the same shape as ``a``.
        temperature: Softmax temperature; defaults to the module-level ``tau``.

    Returns:
        Scalar tensor: the loss averaged over all ``2 * len(a)`` rows.

    Raises:
        ValueError: If ``a`` and ``b`` are not 2-D tensors of equal shape.
    """
    if temperature is None:
        temperature = tau
    if a.dim() != 2 or a.shape != b.shape:
        raise ValueError(
            "a and b must be 2-D tensors of identical shape, got {} and {}".format(
                tuple(a.shape), tuple(b.shape)
            )
        )

    n = a.shape[0]
    # normalize() clamps the norm from below, so an all-zero row gives a zero
    # vector instead of NaN.
    z = torch.nn.functional.normalize(torch.cat([a, b], dim=0), dim=1)

    # Cosine similarity of every row with every other row, scaled by 1/tau.
    logits = torch.mm(z, z.t()) / temperature

    # Exclude self-similarity by masking the diagonal. Subtracting exp(1/tau)
    # from the row sum instead cancels catastrophically in float32 and can
    # leave a zero denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The positive for row i is row i + n (and for row i + n it is row i).
    targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)

    # cross_entropy evaluates -log(exp(pos) / sum(exp(others))) with a stable
    # log-sum-exp and averages over the rows.
    return torch.nn.functional.cross_entropy(logits, targets)

In [None]:

# Bookkeeping for the training run.

losses_train = []  # mean loss of every completed epoch
num_epochs = 20

# SGD with momentum over all parameters of the network
optimizer = optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)

# Output directory for checkpoints; exist_ok avoids the check-then-create race
# and makes re-running the cell harmless.
os.makedirs('results', exist_ok=True)


In [None]:
# Flag that decides whether the training loop guarded by `if TRAINING:` runs.
# Note that this training is unsupervised; it uses the NT-Xent loss function.

TRAINING = True

def get_mean_of_list(L):
    """Return the arithmetic mean of the numbers in ``L``.

    Raises:
        ZeroDivisionError: If ``L`` is empty.
    """
    total = sum(L)
    count = len(L)
    return total / count

if TRAINING:
    # Put the network in training mode.
    resnet.train()

    for epoch in range(num_epochs):

        # Per-batch loss values of the current epoch.
        epoch_losses_train = []

        for sample_batched in dataloader_training_dataset_mutated:

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # The two augmented views of the batch, moved to the device.
            x1 = sample_batched['image1'].to(device)
            x2 = sample_batched['image2'].to(device)

            # Embed both views with the same network.
            y1 = resnet(x1)
            y2 = resnet(x2)

            # Contrastive loss between the two sets of embeddings.
            loss = loss_function(y1, y2)

            # item() already returns a plain Python float, so no .cpu()/.data
            # detour is needed.
            epoch_losses_train.append(loss.item())

            # Backpropagate and update the weights.
            loss.backward()
            optimizer.step()

        # Mean batch loss of this epoch.
        loss = get_mean_of_list(epoch_losses_train)
        losses_train.append(loss)
        print("Epoch: {} Loss: {}".format(epoch,loss))

        # Redraw the loss curve over all epochs so far and save it.
        plt.figure(figsize=(10, 10))
        sns.set_style('darkgrid')
        plt.plot(losses_train)
        plt.legend(['Training Losses'])
        plt.savefig('losses.png')
        plt.close()

        # Overwrite the checkpoint files after every epoch.
        torch.save(resnet.state_dict(), 'results/model.pth')
        torch.save(optimizer.state_dict(), 'results/optimizer.pth')
        np.savez("results/lossesfile", np.array(losses_train))

In [None]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd

# t-SNE reducer with scikit-learn's default settings, used below to project
# the network outputs for plotting.
tsne = TSNE()

def plot_vecs_n_labels(df,fname):
    """Save a scatter plot of 2-D points coloured by label.

    Args:
        df: DataFrame with columns ``'x'``, ``'y'`` and ``'label'``.
        fname: Path of the image file to write.
    """
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    sns.set_style("darkgrid")
    # seaborn builds the legend from the actual values in the 'label' column.
    # It is left untouched: replacing it with a hard-coded list of class
    # names would mislabel the points.
    sns.scatterplot(data=df, x='x', y='y', hue="label", legend='full')
    plt.savefig(fname)
    plt.close()

# Embed the whole evaluation subset first and run t-SNE once on all of it.
# Fitting t-SNE per batch would overwrite the same image on every iteration
# and leave only the last (possibly tiny) batch in the saved plot.
resnet.eval()  # inference mode: batch-norm uses its running statistics

embeddings = []
labels = []
with torch.no_grad():  # no gradients are needed for visualisation
    for sample_batched in dataloader_training_dataset:
        x = sample_batched['image'].to(device)
        y = resnet(x)
        embeddings.append(y.cpu())
        batch_labels = sample_batched['label']
        if torch.is_tensor(batch_labels):
            batch_labels = batch_labels.tolist()
        labels.extend(batch_labels)

y_tsne = tsne.fit_transform(torch.cat(embeddings, dim=0).numpy())
df = pd.DataFrame(y_tsne, columns=['x', 'y'])
df['label'] = np.array(labels)
plot_vecs_n_labels(df, 'tsne_train_last_layer.png')