Supervised Contrastive Loss
  - runs on Google Colab
  - uses CIFAR-10 (a reduced-scale version of the paper's experiments)
  - written to be close to a tutorial

Instructions
  - create a folder named state_dict in your Google Drive ("My Drive"); checkpoints are saved to and loaded from it
    - you can use a different location by changing the paths in the code
    
  

Feel free to use!!



In [36]:
# Mount Google Drive at /content/gdrive; the checkpoint paths used later in this
# notebook live under '/content/gdrive/My Drive/state_dict'.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as tfs

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from tqdm import trange

In [21]:
# Hyper-parameters for both training stages:
#   *_embed -> stage 1 (encoder + MLP head, supervised contrastive loss)
#   *_proj  -> stage 2 (linear classifier head, cross-entropy)
config = {
    'batch_size': 128,
    'lr_embed': 1e-3,
    'lr_proj': 1e-4,
    'epochs_embed': 700,
    'epochs_proj': 10,
    'T': 0.1,
    'mean': 0.5,
    'std': 0.5,
    'dataset': 'CIFAR-10',
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Number of target classes for each supported dataset name.
_classes_per_dataset = {'CIFAR-10': 10, 'CIFAR-100': 100}
if config['dataset'] not in _classes_per_dataset:
  raise NameError('wrong dataset name : dataset name should be CIFAR-10 or CIFAR-100')
output_size = _classes_per_dataset[config['dataset']]

cuda


In [0]:
class Aug2:
  """Callable wrapper that applies a transform twice to the same input.

  Calling it returns a two-element list ``[tfs(x), tfs(x)]``. When the
  wrapped transform is random, the two entries are independent augmented
  views of ``x``.
  """

  def __init__(self, tfs):
    self._tfs = tfs

  def __call__(self, x):
    # two separate calls -> two independent random draws
    return [self._tfs(x) for _ in range(2)]

# Build a CIFAR dataset whose samples come as a pair of augmented views.
def AugmentDataset(dataset='CIFAR-10', split='train', download=True, size=32, mean=0.5, std=0.5):
  """Return a torchvision CIFAR dataset whose items are ([view1, view2], label).

  Both views come from separate calls of the same random augmentation
  pipeline (wrapped in Aug2), so they generally differ.

  Args:
    dataset: 'CIFAR-10' or 'CIFAR-100'.
    split: 'train' selects the training set; 'val' and 'test' both select
      the CIFAR test set.
    download: forwarded to the torchvision dataset constructor.
    size: output side length of RandomResizedCrop.
    mean, std: scalar Normalize parameters applied to every channel.

  Returns:
    The dataset object (its length is also printed).

  Raises:
    NameError: for an unknown split or dataset name.
  """
  if split == 'train':
    train = True
  elif split in ('val', 'test'):
    train = False
  else:
    raise NameError('split should be train, test, or val')

  dataset_classes = {'CIFAR-10': torchvision.datasets.CIFAR10,
                     'CIFAR-100': torchvision.datasets.CIFAR100}
  if dataset not in dataset_classes:
    # Without this check an unknown name failed later with UnboundLocalError.
    raise NameError('dataset should be CIFAR-10 or CIFAR-100')

  _transforms = tfs.Compose([
      # NOTE(review): a scale upper bound above 1.0 requests crops larger than
      # the image; torchvision then falls back to a (near) whole-image crop for
      # those draws -- confirm this is intended.
      tfs.RandomResizedCrop(size, scale=(0.75, 1.25)),
      tfs.RandomGrayscale(0.2),
      tfs.RandomHorizontalFlip(0.3),
      tfs.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2)),
      # `fill` replaces the deprecated `fillcolor`, which newer torchvision rejects.
      tfs.RandomAffine(degrees=(-10, 10), translate=(0.01, 0.05), scale=(0.9, 1.1), shear=(0, 5), fill=0),
      tfs.ToTensor(),
      # single-element mean/std are broadcast over all image channels
      tfs.Normalize(mean=(mean,), std=(std,)),
  ])

  _data = dataset_classes[dataset](root='./data',
                                   train=train,
                                   transform=Aug2(_transforms),
                                   download=download)

  print(len(_data))
  return _data

In [0]:
# Paired-view datasets for training and validation (built in that order).
train_dataset, val_dataset = (
    AugmentDataset(dataset=config['dataset'], split=split, download=True, size=32,
                   mean=config['mean'], std=config['std'])
    for split in ('train', 'val'))

In [0]:

# Train / validation loaders: shuffled mini-batches, final partial batch kept.
train_loader, val_loader = (
    DataLoader(ds, batch_size=config['batch_size'], shuffle=True, drop_last=False)
    for ds in (train_dataset, val_dataset))

def imshow(x):
  """Display a normalized image tensor with matplotlib.

  The tensor is converted to NumPy, the Normalize(mean, std) from ``config``
  is undone, and the channel axis is moved last (C, H, W -> H, W, C).
  """
  img = x.numpy() * config['std'] + config['mean']  # undo normalization
  plt.imshow(img.transpose(1, 2, 0))                # channels-last for matplotlib
  plt.show()

# Sanity check that the data is loaded properly.
def checkdata(dataloader):
  """Inspect the first batch of *dataloader*.

  Prints the shapes of both augmented views and of the labels, shows the
  first image of each view, then prints that image's label. Does nothing if
  the loader yields no batches.
  """
  first_batch = next(iter(dataloader), None)
  if first_batch is None:
    return
  views, labels = first_batch

  for view in (views[0], views[1]):
    print(view.shape)
  print(labels.shape)

  for view in (views[0], views[1]):
    imshow(view[0])
  print(labels[0])

checkdata(train_loader)
checkdata(val_loader)

In [0]:
# The paper uses a ResNet-50 encoder; this notebook uses a much smaller
# custom CNN for simplicity and to fit limited compute.
def conv3(input_channels, output_channels):
  """Return a 3x3 Conv2d with padding 1 (stride 1, so spatial size is kept)."""
  return nn.Conv2d(in_channels=input_channels, out_channels=output_channels,
                   kernel_size=3, padding=1)

class Encoder(nn.Module):
  """Five-stage CNN mapping each image to an L2-normalized feature map.

  Every stage is conv3 -> ReLU -> 2x2 max-pool, so each stage halves the
  spatial size (a 32x32 input ends at 1x1). The result is normalized along
  the channel dimension.
  """

  def __init__(self, output_channels):
    super(Encoder, self).__init__()
    # The attribute names cnn1..cnn5 are the state_dict keys of saved
    # checkpoints, so they must stay stable.
    self.cnn1 = conv3(3, 8)
    self.cnn2 = conv3(8, 12)
    self.cnn3 = conv3(12, 16)
    self.cnn4 = conv3(16, 20)
    self.cnn5 = conv3(20, output_channels)

  def forward(self, x1, x2):
    """Encode both views with shared weights; return a (feat1, feat2) tuple."""
    return tuple(self._forward_single(view) for view in (x1, x2))

  def _forward_single(self, x):
    for conv in (self.cnn1, self.cnn2, self.cnn3, self.cnn4, self.cnn5):
      x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
    return F.normalize(x, p=2, dim=1)


class Projection(nn.Module):
  """Head applied to both views' features; outputs are stacked along dim 0.

  Args:
    input_size: feature size of each input row.
    output_size: size of each output row.
    type_: 'perceptron' (Linear(input_size, input_size) -> ReLU -> Linear)
      or 'linear' (a single Linear layer).
    normalize: whether to L2-normalize each output row. The default None
      means True for 'perceptron' (the contrastive projection head) and
      False for 'linear' (the classifier head, whose outputs are logits).

  Raises:
    NameError: if type_ is neither 'perceptron' nor 'linear'.
  """

  def __init__(self, input_size, output_size, type_='perceptron', normalize=None):
    super(Projection, self).__init__()

    if type_=='perceptron':
      self._layers = nn.Sequential(
          nn.Linear(input_size, input_size),
          nn.ReLU(),
          nn.Linear(input_size, output_size)
      )
    elif type_=='linear':
      self._layers = nn.Linear(input_size, output_size)
    else:
      raise NameError('type_ should be a string type of perceptron or linear')

    # Unit-norm logits cap the achievable softmax confidence, so cross-entropy
    # on a normalized classifier head cannot be driven down; only the
    # contrastive head is normalized by default.
    self._normalize = (type_ == 'perceptron') if normalize is None else bool(normalize)

  def forward(self, x1, x2):
    """Return cat([head(x1), head(x2)], dim=0)."""
    return torch.cat([self._forward_single(x1), self._forward_single(x2)], dim=0)

  def _forward_single(self, x):
    out = self._layers(x)
    return F.normalize(out, p=2, dim=1) if self._normalize else out

class SupConNet(nn.Module):
  """Encoder followed by a head, applied to two augmented views of a batch.

  Args:
    enet: module called as enet(x1, x2), returning a pair of feature tensors.
    pnet: module called as pnet(r1, r2) on the flattened features.
  """

  def __init__(self, enet, pnet):
    super(SupConNet, self).__init__()
    self._enet = enet
    self._pnet = pnet

  def forward(self, x1, x2):
    """Return pnet's output for the two views' flattened encoder features."""
    r1, r2 = self._enet(x1, x2)
    # flatten all but the batch dimension before the head
    r1, r2 = r1.view(r1.size(0), -1), r2.view(r2.size(0), -1)
    return self._pnet(r1, r2)

  def train_embedding(self):
    """Stage 1: train mode for encoder and head, both receiving gradients."""
    self._enet.requires_grad_(True)
    self._enet.train()
    self._pnet.train()

  def train_projection(self):
    """Stage 2: freeze the encoder and train only the head.

    eval() alone only switches layer behavior; it does not stop an optimizer
    that holds the encoder's parameters from updating them, so the encoder's
    gradients are disabled as well.
    """
    self._enet.requires_grad_(False)
    self._enet.eval()
    self._pnet.train()
  

class SupContrastLoss(nn.Module):
  """Supervised contrastive loss (Khosla et al., 2020, L^sup_out), summed over anchors.

  For each anchor i among the 2N rows of z:
    L_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p / T) / sum_{k != i} exp(z_i.z_k / T) )
  where P(i) are the other rows with the same label as i (|P(i)| = 2N_y - 1).

  Args:
    device: kept for backward compatibility; masks are built on z.device.
    T: temperature dividing the pairwise dot products.
    EPS: kept for backward compatibility; the log-softmax form needs no epsilon.
  """

  def __init__(self, device, T=1.0, EPS=1e-9):
    super(SupContrastLoss, self).__init__()
    self._T = T
    self._EPS = EPS
    self._device = device

  def forward(self, z, y):
    """Return the loss as a scalar tensor.

    Args:
      z: (2 * batch_size, d) tensor; rows [0, batch_size) and
        [batch_size, 2 * batch_size) are the two views of the same samples,
        in the same order.
      y: (batch_size,) labels.
    """
    n = z.size(0)
    labels = torch.cat([y, y], dim=0)
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    # positives: same label as the anchor, excluding the anchor itself
    pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask

    logits = torch.mm(z, z.T) / self._T
    # the anchor itself is excluded from the softmax denominator (k != i)
    logits = logits.masked_fill(self_mask, float('-inf'))
    log_prob = F.log_softmax(logits, dim=1)

    # keep only positive pairs; this also drops the -inf diagonal entries
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0)
    # the other view of the same sample is always a positive, so this is >= 1
    num_pos = pos_mask.sum(dim=1).clamp(min=1)
    loss = -torch.sum(pos_log_prob.sum(dim=1) / num_pos)
    return loss

class SupCrossEntropyLoss(nn.Module):
  """Cross-entropy over predictions for two views that share each sample's label.

  Args:
    device: stored on the module; not otherwise used by this loss.
  """

  def __init__(self, device):
    super(SupCrossEntropyLoss, self).__init__()
    self._cross_entropy = nn.CrossEntropyLoss()
    self._device = device

  def forward(self, y_, y):
    """Return mean cross-entropy of y_ against the labels repeated for both views.

    Args:
      y_: (2 * batch_size, num_classes) logits, first view then second view.
      y: (batch_size,) class indices.
    """
    targets = torch.cat([y, y], dim=0)  # one label per row of y_
    return self._cross_entropy(y_, targets)

In [0]:
def backprop(loss, optimizer):
  """Run one optimization step for *loss*: clear gradients, backpropagate, update."""
  # Order matters: stale gradients are cleared before backward() accumulates
  # new ones, and step() must read the fresh gradients.
  for stage in (optimizer.zero_grad, loss.backward, optimizer.step):
    stage()

def _mean(values):
  """Return the arithmetic mean of *values*, or NaN for an empty list."""
  return sum(values) / len(values) if values else float('nan')

def _batch_loss(scnet, loss_fn, batch, device):
  """Move one (views, labels) batch to *device* and return scnet's loss on it."""
  x, y = batch
  x1, x2, y = x[0].to(device), x[1].to(device), y.to(device)
  return loss_fn(scnet(x1, x2), y)

def Train(enet, pnet, trainloader, valloader, epochs, lr, T, start_epoch, device, state,
          checkpoint_dir='/content/gdrive/My Drive/state_dict'):
  """Train SupConNet(enet, pnet), checkpointing whenever the val loss improves.

  Args:
    enet: encoder module, called as enet(x1, x2).
    pnet: head module, called as pnet(r1, r2).
    trainloader, valloader: yield (views, labels) batches, where views[0]
      and views[1] are the two augmented views.
    epochs: number of epochs to run.
    lr: Adam learning rate.
    T: SupContrastLoss temperature (used only when state == 'embedding').
    start_epoch: offset added to the epoch index in logs and checkpoint names.
    device: device the networks and batches are moved to.
    state: 'embedding' trains enet and pnet with SupContrastLoss and saves
      enet checkpoints; 'projection' freezes enet, trains only pnet with
      cross-entropy and saves whole-network checkpoints.
    checkpoint_dir: directory the checkpoints are written to.

  Raises:
    NameError: if state is neither 'embedding' nor 'projection'.
  """
  if state == 'embedding':
    loss_fn = SupContrastLoss(device, T=T).to(device)
  elif state == 'projection':
    loss_fn = SupCrossEntropyLoss(device).to(device)
  else:
    raise NameError('state arg should be embedding or projection')
  embedding = state == 'embedding'

  scnet = SupConNet(enet, pnet)
  scnet.to(device)

  # In the projection stage only the head may change: eval() alone would not
  # stop Adam from updating encoder parameters it holds.
  enet.requires_grad_(embedding)
  trainable = scnet.parameters() if embedding else pnet.parameters()
  optimizer = optim.Adam(trainable, lr=lr)
  min_val_loss = float('inf')

  for epoch in trange(epochs):
    if embedding:
      scnet.train_embedding()
    else:
      scnet.train_projection()

    train_losses = []
    for batch in trainloader:
      loss = _batch_loss(scnet, loss_fn, batch, device)
      backprop(loss, optimizer)
      train_losses.append(loss.item())
    avg_train_loss = _mean(train_losses)

    # Separate list so training losses do not leak into the validation
    # average; no autograd graph is needed for evaluation.
    scnet.eval()
    with torch.no_grad():
      val_losses = [_batch_loss(scnet, loss_fn, batch, device).item() for batch in valloader]
    avg_val_loss = _mean(val_losses)

    if avg_val_loss < min_val_loss:
      min_val_loss = avg_val_loss
      if embedding:
        torch.save(enet.state_dict(), f'{checkpoint_dir}/enet_checkpoint{epoch + start_epoch}.pt')
      else:
        torch.save(scnet.state_dict(), f'{checkpoint_dir}/scnet_checkpoint{epoch + start_epoch}.pt')
    print(f'\n epoch {epoch + start_epoch} finished ====> train loss : {avg_train_loss}, val loss : {avg_val_loss}')

In [0]:
# Stage 1: resume training encoder + MLP head with the supervised contrastive loss.
start_epoch = 25
enet = Encoder(40)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime too.
enet.load_state_dict(torch.load(f'/content/gdrive/My Drive/state_dict/enet_checkpoint{start_epoch}.pt',
                                map_location=device))
pnet1 = Projection(40, output_size, type_='perceptron')
from torch import autograd
# To debug NaN losses, run the Train call below inside `with autograd.detect_anomaly():`.
Train(enet, pnet1, train_loader, val_loader, config['epochs_embed'], config['lr_embed'], config['T'], start_epoch, device, 'embedding')

In [0]:
# Stage 2: load an encoder checkpoint from stage 1 and train a linear classifier on it.
# NOTE(review): the checkpoint epoch (108) is hard-coded; update it to the best stage-1 checkpoint.
enet.load_state_dict(torch.load('/content/gdrive/My Drive/state_dict/enet_checkpoint108.pt',
                                map_location=device))
pnet2 = Projection(40, output_size, type_='linear')
Train(enet, pnet2, train_loader, val_loader, config['epochs_proj'], config['lr_proj'], config['T'], 0, device, 'projection')