In [None]:
# Mount Google Drive at /content/drive (Colab only); force_remount=True remounts
# even when a previous mount exists. Checkpoints are written under this mount.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import linalg as LA
# Directory that receives the per-epoch training checkpoints (on the mounted Drive).
# NOTE(review): `dir` shadows the Python builtin `dir()`; a name like CKPT_DIR would be safer.
dir = '/content/drive/MyDrive/SIMSIAM/Experiments'

# Defining the SimSiam model (`SiamSimModel`)

In [None]:
class SiamSimModel(nn.Module):
    """Siamese network: shared encoder -> projection MLP -> prediction MLP.

    Args:
        encoder: backbone module; its output for a batch is reshaped to
            ``(-1, last_dim)`` before the projector.
        dim: output width of the projector, and input/output width of the predictor.
        pred_dim: hidden (bottleneck) width of the predictor.
        last_dim: feature width produced by ``encoder``.
        spec_norm: when True, both predictor Linear layers are wrapped with
            ``nn.utils.parametrizations.spectral_norm``.
    """

    def __init__(self, encoder, dim, pred_dim, last_dim, spec_norm=False):
        super().__init__()
        # Feature width used to flatten the encoder output in forward().
        self.ldim = last_dim

        self.encoder = encoder

        # Projection MLP: two (Linear -> BN -> ReLU) layers, then Linear -> BN.
        self.projector = nn.Sequential(
            nn.Linear(last_dim, last_dim, bias=False),
            nn.BatchNorm1d(last_dim),
            nn.ReLU(inplace=True),
            nn.Linear(last_dim, last_dim, bias=False),
            nn.BatchNorm1d(last_dim),
            nn.ReLU(inplace=True),
            nn.Linear(last_dim, dim, bias=False),
            nn.BatchNorm1d(dim),
        )

        def wrap(layer):
            # Apply spectral normalization to a predictor Linear layer only when requested.
            if spec_norm:
                return nn.utils.parametrizations.spectral_norm(layer)
            return layer

        # Prediction MLP (bottleneck): Linear -> BN -> ReLU -> Linear.
        self.predictor = nn.Sequential(
            wrap(nn.Linear(dim, pred_dim, bias=False)),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            wrap(nn.Linear(pred_dim, dim)),
        )

    def forward(self, x1, x2):
        """Embed both views and return ``(p1, p2, z1_detached, z2_detached)``.

        ``z*`` are projector outputs and ``p*`` are predictor outputs. The
        projections are returned detached, so no gradient flows back through them.
        """
        # Views are processed in order x1 then x2.
        z1, z2 = (self.projector(self.encoder(view).view(-1, self.ldim))
                  for view in (x1, x2))
        p1, p2 = self.predictor(z1), self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()


In [None]:
#(dis)similarity functions
# Every function below takes two 2-D tensors laid out as (batch, features) and
# returns one value per row, i.e. shape (batch,), so the training loop can
# average it with .mean().
CosSim = nn.CosineSimilarity(dim=1)
SoftMax = nn.Softmax(dim=1)
LogSoftMax = nn.LogSoftmax(dim=1)
MSE = nn.MSELoss()
pd_2 = nn.PairwiseDistance(p=2)
pd_1 = nn.PairwiseDistance(p=1)
pd_inf = nn.PairwiseDistance(p=torch.inf)

def neg_cosine_sim(a, b):
  """Negative cosine similarity between matching rows of a and b."""
  return -CosSim(a, b)

def cross_entropy_sim(a, b):
  """Cross-entropy of softmax(a) against the target distribution softmax(b), per row.

  Computes -sum_k softmax(b)_k * log softmax(a)_k over dim 1. The sum over the
  feature dimension is required: averaging the elementwise product over all
  entries instead would shrink the loss (and its gradients) by a factor equal
  to the feature width.
  """
  return -(SoftMax(b) * LogSoftMax(a)).sum(dim=1)

def mse_sim(a, b):
  """Mean squared error per row (averaged over features).

  The batch mean of this equals nn.MSELoss()(a, b). A loss *value* is
  returned (not an nn.MSELoss module), so callers can call .mean() on it.
  """
  return ((a - b) ** 2).mean(dim=1)

def pdist_1(a, b):
  """L1 pairwise distance between matching rows."""
  return pd_1(a, b)

def pdist_2(a, b):
  """L2 pairwise distance between matching rows."""
  return pd_2(a, b)

def pdist_inf(a, b):
  """L-infinity pairwise distance between matching rows."""
  return pd_inf(a, b)

In [None]:
dim = 2048       # projector output width (= predictor input/output width)
pred_dim = 512   # predictor hidden (bottleneck) width
backbone = models.resnet18()
# Input width of the backbone's final layer, i.e. the encoder's feature width.
last_dim = list(backbone.children())[-1].in_features

# Encoder = backbone with its final layer removed.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#EXPERIMENTS
opt = 'default'           # 1. OPTIMIZERS: 'default' (SGD) | 'avg' (ASGD) | 'adam'
sim_fun_flag = 'default'  # 2. SIMILARITY FUNCTIONS: 'default' (cross-entropy) | 'cos' | 'p1' | 'p2' | 'pinf'
loss_agg = 'default'      # 3. LOSS AGGREGATIONS: 'default' (mean of both directions) | 'geo'
spec_norm = False         # 4. REGULARIZERS (SPEC NORM on the predictor)

#EXPERIMENT 2 - MORE SIM FUNS
sim_funs = {
    'default': cross_entropy_sim,
    'cos': neg_cosine_sim,
    'p1': pdist_1,
    'p2': pdist_2,
    'pinf': pdist_inf,
}
# Fail fast on a typo instead of leaving `sim_fun` undefined until training.
if sim_fun_flag not in sim_funs:
    raise ValueError(f"unknown sim_fun_flag={sim_fun_flag!r}; expected one of {sorted(sim_funs)}")
sim_fun = sim_funs[sim_fun_flag]

model = SiamSimModel(encoder, dim, pred_dim, last_dim, spec_norm=spec_norm)

# With spectral norm the predictor group is given weight_decay=0; every other
# group (and the predictor otherwise) uses the optimizer-level weight decay.
# The 'name' keys are read by update_lr to leave the predictor LR fixed.
predictor_group = {'params': model.predictor.parameters(), 'name': 'predictor'}
if spec_norm:
    predictor_group['weight_decay'] = 0
param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                {'params': model.projector.parameters(), 'name': 'projector'},
                predictor_group]

#training settings:
batch_size = 256
num_epochs = 20  # 100
lr = 0.05 * (batch_size / 256)  # linear scaling: base LR 0.05 per 256 samples
if opt == 'default':
    optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
elif opt == 'avg':
    # torch.optim.ASGD has no `momentum` argument (passing one raises TypeError).
    optimizer = torch.optim.ASGD(param_groups, lr=lr, weight_decay=1e-4)
elif opt == 'adam':
    optimizer = torch.optim.Adam(param_groups, lr=lr, betas=(0.9, 0.999), weight_decay=1e-4)
else:
    raise ValueError(f"unknown opt={opt!r}; expected 'default', 'avg' or 'adam'")

# Use the GPU when present, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#name of the experiment to perform (used as the checkpoint filename prefix)
setup = opt + '_' + sim_fun_flag + '_' + loss_agg
if spec_norm:
    setup = setup + '_specnorm'



In [None]:
#cosine schedule for LR
import math
def update_lr(optimizer, current_e, total_e, max_lr, min_lr=0, spec_norm=False):
  """Apply a cosine-annealed learning rate to every non-predictor param group.

  lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * current_e / total_e))

  Groups whose 'name' is 'predictor' are skipped, so they keep whatever LR
  they already have. Every group must carry a 'name' key. ``spec_norm`` is
  accepted for call compatibility but not used.
  """
  for group in optimizer.param_groups:
    if group['name'] == 'predictor':
      continue  # predictor LR is held fixed
    phase = math.pi * current_e / total_e
    group['lr'] = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(phase))

# Defining Data Loaders

In [None]:
from PIL import ImageFilter
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

# NOTE(review): `traindir` is not referenced elsewhere in this notebook; the
# dataset below is created with root='./data'.
traindir = "data"

# Per-channel normalization, applied as the last step after ToTensor.
# NOTE(review): these are the widely used ImageNet mean/std rather than
# CIFAR-100's own statistics — confirm that is intended.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call samples a blur radius uniformly from ``[sigma[0], sigma[1]]``
    and applies ``ImageFilter.GaussianBlur`` with that radius to the input,
    which must support PIL's ``.filter`` (it runs before ToTensor in the
    augmentation pipeline).

    Args:
        sigma: (low, high) bounds for the uniformly sampled blur radius.
    """

    def __init__(self, sigma=(.1, 2.)):
        # Immutable default: a list default would be a single object shared by
        # every instance, so mutating one transform's bounds would affect all.
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred with a freshly sampled radius."""
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


# Random augmentation pipeline; TwoCropsTransform applies it twice per image
# to build the two views.
# NOTE(review): CIFAR-100 images are 32x32, so a 224 crop upsamples them ~7x
# and makes each training step far more expensive — confirm this is intended.
augmentation = [
    # random crop covering 20%-100% of the image area, resized to 224
    transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
    # colour jitter applied with probability 0.8 (not strengthened)
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    # PIL image -> tensor, then per-channel normalization
    transforms.ToTensor(),
    normalize,
]

class TwoCropsTransform:
    """Produce two views of one input by applying ``base_transform`` twice.

    When ``base_transform`` is random, the two views generally differ.
    Calling the object returns ``[view_1, view_2]`` as a list.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Two independent applications, in order, collected into a list.
        return [self.base_transform(x) for _ in range(2)]

# Each training sample's image becomes [view_1, view_2]; the training loop
# ignores the labels.
tranform = TwoCropsTransform(transforms.Compose(augmentation))
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=tranform,
)

# drop_last=True discards the final incomplete batch, so every step sees
# exactly batch_size samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    sampler=None,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

Files already downloaded and verified


# Train Loop

In [None]:
model = model.to(device)
# Make sure the checkpoint directory exists before the first save.
os.makedirs(dir, exist_ok=True)

# Geometric-mean aggregation has no implementation yet; fail before training
# instead of hitting an undefined `loss` on the first step.
if loss_agg == 'geo':
    # TODO: use geometric mean here!
    raise NotImplementedError("loss_agg='geo' is not implemented yet")

for epoch in range(num_epochs):
    model.train()
    epoch_loss, epoch_count = 0.0, 0
    last_loss = float('nan')
    # Cosine-anneal the LR once per epoch (update_lr leaves the 'predictor' group alone).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr)
    for i, (images, _) in enumerate(train_loader):
        start = time.time()
        x1, x2 = images
        x1 = x1.to(device, non_blocking=True)
        x2 = x2.to(device, non_blocking=True)
        p1, p2, z1, z2 = model(x1, x2)
        # Symmetric loss: each view's prediction vs. the other view's detached projection.
        loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a Python float; summing the tensor itself would keep every
        # step's autograd graph reachable from epoch_loss for the whole epoch.
        last_loss = loss.item()
        epoch_loss += last_loss
        epoch_count += 1
        if i % 10 == 0:
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={last_loss} time={time.time()-start} secs"
            print(message)
    # Guard against an empty loader.
    epoch_loss /= max(epoch_count, 1)
    print("train loss: {}".format(epoch_loss))
    # Checkpoint every epoch. os.path.join puts the file *inside* `dir`; plain
    # string concatenation dropped the path separator.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.encoder.state_dict(),
        'proj_state_dict': model.projector.state_dict(),
        'pred_state_dict': model.predictor.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': last_loss,
        }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

epoch=0/100 step=0/781 loss=0.003763296641409397 time=1.0193977355957031 secs
epoch=0/100 step=10/781 loss=0.003764156252145767 time=1.14306640625 secs
epoch=0/100 step=20/781 loss=0.0037637464702129364 time=1.233649730682373 secs
epoch=0/100 step=30/781 loss=0.00376409525051713 time=1.0110406875610352 secs
epoch=0/100 step=40/781 loss=0.0037634908221662045 time=1.0809943675994873 secs
epoch=0/100 step=50/781 loss=0.003762795589864254 time=1.0188500881195068 secs
epoch=0/100 step=60/781 loss=0.0037636999040842056 time=1.0053744316101074 secs
epoch=0/100 step=70/781 loss=0.003763706423342228 time=1.0113224983215332 secs
epoch=0/100 step=80/781 loss=0.0037635217886418104 time=1.0095503330230713 secs
epoch=0/100 step=90/781 loss=0.003763469634577632 time=1.0095727443695068 secs
epoch=0/100 step=100/781 loss=0.003763496642932296 time=1.0090186595916748 secs
epoch=0/100 step=110/781 loss=0.0037637169007211924 time=1.02461576461792 secs
epoch=0/100 step=120/781 loss=0.0037635681219398975 tim

KeyboardInterrupt: ignored