In [None]:
# Mount Google Drive inside the Colab runtime (forcing a re-mount if it is
# already mounted); checkpoints are later written below /content/drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# Google Drive folder used as the checkpoint destination. The value carries no
# trailing path separator, and the name shadows Python's built-in dir().
dir = '/content/drive/MyDrive/SIMSIAM'

# Defining the SimSiam model

In [None]:
class SiamSimModel(nn.Module):
    """Siamese network with a shared encoder, a projection MLP and a prediction MLP.

    Both inputs go through the same encoder and projector; the predictor is
    applied on top unless ``ditch_pred`` is set.

    Args:
        encoder: backbone module; its output is reshaped to ``(-1, last_dim)``.
        dim: output width of the projector and of the predictor.
        pred_dim: hidden (bottleneck) width of the predictor.
        last_dim: number of features the encoder produces per sample.
        stop_grad: if True, ``forward`` returns detached projections ``z1``/``z2``.
        ditch_pred: if True, ``forward`` bypasses the predictor (``p`` is ``z``).
        bn_config: BatchNorm placement inside the two MLPs:
            ``'default'`` - hidden layers and projector output,
            ``'none'``    - no BatchNorm at all,
            ``'hidden'``  - hidden layers only,
            ``'all'``     - hidden layers, projector output and predictor output.

    Raises:
        ValueError: if ``bn_config`` is not one of the four values above.
    """

    # bn_config -> (BN on hidden layers, BN on projector output, BN on predictor output)
    _BN_LAYOUTS = {
        'default': (True, True, False),
        'none': (False, False, False),
        'hidden': (True, False, False),
        'all': (True, True, True),
    }

    def __init__(self, encoder, dim, pred_dim, last_dim, stop_grad=True, ditch_pred=False, bn_config='default'):
        super().__init__()
        if bn_config not in self._BN_LAYOUTS:
            # The original `raise("...")` raised TypeError because a str is not an exception.
            raise ValueError(
                f"Choose valid bn_config (default, none, hidden, all), got {bn_config!r}")
        hidden_bn, proj_out_bn, pred_out_bn = self._BN_LAYOUTS[bn_config]

        #conditions:
        self.ldim = last_dim
        self.stop_grad = stop_grad
        self.ditch_pred = ditch_pred

        #encoder
        self.encoder = encoder

        #projectionMLP: two hidden layers followed by an output layer without ReLU
        self.projector = nn.Sequential(
            *self._linear_block(last_dim, last_dim, use_bn=hidden_bn, activate=True),
            *self._linear_block(last_dim, last_dim, use_bn=hidden_bn, activate=True),
            *self._linear_block(last_dim, dim, use_bn=proj_out_bn, activate=False))

        #predictionMLP: one bottleneck layer followed by an output layer without ReLU
        self.predictor = nn.Sequential(
            *self._linear_block(dim, pred_dim, use_bn=hidden_bn, activate=True),
            *self._linear_block(pred_dim, dim, use_bn=pred_out_bn, activate=False))

    @staticmethod
    def _linear_block(in_features, out_features, use_bn, activate):
        """Return [Linear, (BatchNorm1d), (ReLU)]; the Linear bias is dropped when BN follows."""
        layers = [nn.Linear(in_features, out_features, bias=not use_bn)]
        if use_bn:
            layers.append(nn.BatchNorm1d(out_features))
        if activate:
            layers.append(nn.ReLU(inplace=True))
        return layers

    def _project(self, x):
        """Encode ``x``, reshape the features to ``(-1, last_dim)`` and project them."""
        return self.projector(self.encoder(x).view(-1, self.ldim))

    def forward(self, x1, x2):
        """Return ``(p1, p2, z1, z2)`` for the two inputs.

        ``z*`` are the projections (detached when ``stop_grad`` is set) and
        ``p*`` the predictions, or the undetached projections themselves when
        ``ditch_pred`` is set.
        """
        z1 = self._project(x1)
        z2 = self._project(x2)
        if self.ditch_pred:
            p1, p2 = z1, z2
        else:
            p1 = self.predictor(z1)
            p2 = self.predictor(z2)

        if self.stop_grad:
            return p1, p2, z1.detach(), z2.detach()
        return p1, p2, z1, z2


In [None]:
#(dis)similarity functions
# Shared module instances; all of them operate along dim=1 (the feature axis of an (N, D) batch).
CosSim = nn.CosineSimilarity(dim=1)
SoftMax = nn.Softmax(dim=1)
LogSoftMax = nn.LogSoftmax(dim=1)
def neg_cosine_sim(a,b):
  """Return the negative cosine similarity of ``a`` and ``b`` along dim=1, one value per row."""
  return -CosSim(a, b)

def cross_entropy_sim(a,b):
  """Return the cross-entropy ``-sum_d softmax(b)_d * log_softmax(a)_d``, one value per row.

  The product is summed over dim=1 so the result has the same per-sample shape
  as ``neg_cosine_sim``. Without the sum, a following ``.mean()`` would average
  over every element and shrink the loss by a factor of the feature width.
  """
  return -(SoftMax(b)*LogSoftMax(a)).sum(dim=1)

In [None]:
# Experiment configuration: builds the backbone/model, the optimizer and the
# experiment name (`setup`) from the ablation switches below.
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # ResNet-18 built with torchvision's default arguments
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = True  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = False  # 4.2c: True exempts the predictor lr from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 256  # 4.3: batch-size variation (64, 1024-4096, whatever we get working)
bn_config = 'default'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]

#training settings:
num_epochs = 1  # later configuration cells do not reassign this value
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
elif pred_lr:
  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
#elif batch_size != 256:
#  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
#cosine schedule for LR
import math
def update_lr(optimizer, current_e, total_e, max_lr, min_lr=0, pred_lr=False):
  """Set each param group's 'lr' to a half-cosine decay from ``max_lr`` to ``min_lr``.

  The rate is ``min_lr + 0.5*(max_lr-min_lr)*(1 + cos(pi*current_e/total_e))``.
  When ``pred_lr`` is true, the group whose 'name' is 'predictor' is left
  untouched, so it keeps whatever learning rate it already has.
  """
  for group in optimizer.param_groups:
    keep_fixed = group['name'] == 'predictor' and pred_lr
    if keep_fixed:
      continue
    cosine_factor = 1  + math.cos((math.pi*current_e)/total_e)
    group['lr'] = min_lr + 0.5*(max_lr-min_lr)*cosine_factor

# Defining Data Loaders

In [None]:
from PIL import ImageFilter
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

# NOTE(review): traindir is not referenced by any other visible cell.
traindir = "data"
# Per-channel normalisation applied after ToTensor(); the constants are the
# widely used ImageNet RGB mean/std.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

class GaussianBlur(object):
    """Random Gaussian blur, the augmentation from SimCLR https://arxiv.org/abs/2002.05709

    ``sigma`` is a ``[low, high]`` pair; every call draws the blur radius
    uniformly from that interval and applies it through ``x.filter``.
    """

    def __init__(self, sigma=[.1, 2.]):
        # Stored as given; only elements 0 and 1 are read later.
        self.sigma = sigma

    def __call__(self, x):
        low, high = self.sigma[0], self.sigma[1]
        radius = random.uniform(low, high)
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


# Augmentation pipeline applied independently to produce each of the two views.
# The transforms run in list order once wrapped in transforms.Compose.
augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),  # random crop resized to 224x224
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL image -> tensor; must precede normalize
        normalize
    ]

class TwoCropsTransform:
    """Apply one transform twice to the same input and return both results as a list."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Two separate calls: with a random transform this yields two different
        # views of the same input, returned in call order.
        return [self.base_transform(x) for _ in range(2)]

# Each dataset item becomes a pair of independently augmented views.
# (`tranform` is the spelling used by the code below.)
tranform = TwoCropsTransform(transforms.Compose(augmentation))
# CIFAR-100 training split, downloaded to ./data if it is not there yet.
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=tranform)

# The loader captures the value batch_size has when this cell runs; assigning
# a new batch_size later does not change this loader. drop_last=True discards
# the final incomplete batch.
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size, shuffle=True,
        num_workers=2, pin_memory=True, sampler=None, drop_last=True)

# Train Loop

In [None]:
# Training loop for the experiment configured above.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)


for epoch in range(num_epochs):
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(train_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

# More experiments 1 (smallest batch size)

In [None]:
# Experiment configuration (batch size 64): builds a fresh backbone/model, the
# optimizer and the experiment name (`setup`).
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # new ResNet-18 instance for this experiment
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = False  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = True  # 4.2c: predictor lr is exempt from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 64  # 4.3: batch-size variation; this only sets the variable, no DataLoader is created in this cell
bn_config = 'default'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]


#training settings:
#num_epochs = 20   (commented out: num_epochs keeps the value assigned in an earlier cell)
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
#elif pred_lr:
#  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
elif batch_size != 256:
  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
# Training loop for the experiment configured above.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)

# train_loader was created with the batch size that was current when the data
# cell ran. This experiment varies batch_size (and lr is scaled for it), so
# iterate over a loader built with the current value. A separate name is used
# so the shared train_loader stays untouched.
exp_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)

for epoch in range(num_epochs):
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(exp_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(exp_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

# More experiments 2 (largest batch size)

In [None]:
# Experiment configuration (batch size 2048): builds a fresh backbone/model,
# the optimizer and the experiment name (`setup`).
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # new ResNet-18 instance for this experiment
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = False  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = True  # 4.2c: predictor lr is exempt from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 2048  # 4.3: batch-size variation; this only sets the variable, no DataLoader is created in this cell
bn_config = 'default'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]

#training settings:
#num_epochs = 20   (commented out: num_epochs keeps the value assigned in an earlier cell)
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
#elif pred_lr:
#  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
elif batch_size != 256:
  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
# Training loop for the experiment configured above.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)

# train_loader was created with the batch size that was current when the data
# cell ran. This experiment varies batch_size (and lr is scaled for it), so
# iterate over a loader built with the current value. A separate name is used
# so the shared train_loader stays untouched.
exp_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)

for epoch in range(num_epochs):
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(exp_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(exp_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

# More experiments 3 (no batch norm)

In [None]:
# Experiment configuration (no batch norm in the MLPs): builds a fresh
# backbone/model, the optimizer and the experiment name (`setup`).
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # new ResNet-18 instance for this experiment
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = False  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = True  # 4.2c: predictor lr is exempt from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 256  # 4.3: batch-size variation (1024-4096, whatever we get working)
bn_config = 'none'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]

#training settings:
#num_epochs = 20   (commented out: num_epochs keeps the value assigned in an earlier cell)
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
#elif pred_lr:
#  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
elif batch_size != 256:
  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
# Training loop for the experiment configured above; additionally records the
# mean loss of every epoch in `losses`.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)
losses = []
for epoch in range(num_epochs):
    epoch_loss, epoch_count = 0.0, 0
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(train_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the tensor itself would chain every
        # step's autograd history into epoch_loss.
        epoch_loss += loss.item()
        epoch_count += 1
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    losses.append(epoch_loss/epoch_count)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

In [None]:
# Report which experiment just ran, followed by its per-epoch mean losses.
print("here we would like to learn something about " + setup + "!!!!!!!!!")
print(losses)

# More experiments 4 (hidden batch norm)

In [None]:
# Experiment configuration (batch norm on hidden layers only): builds a fresh
# backbone/model, the optimizer and the experiment name (`setup`).
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # new ResNet-18 instance for this experiment
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = False  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = True  # 4.2c: predictor lr is exempt from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 256  # 4.3: batch-size variation (1024-4096, whatever we get working)
bn_config = 'hidden'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]

#training settings:
#num_epochs = 20   (commented out: num_epochs keeps the value assigned in an earlier cell)
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
#elif pred_lr:
#  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
elif batch_size != 256:
  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
# Training loop for the experiment configured above; additionally records the
# mean loss of every epoch in `losses`.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)
losses = []
for epoch in range(num_epochs):
    epoch_loss, epoch_count = 0.0, 0
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(train_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the tensor itself would chain every
        # step's autograd history into epoch_loss.
        epoch_loss += loss.item()
        epoch_count += 1
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    losses.append(epoch_loss/epoch_count)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

In [None]:
# Report which experiment just ran, followed by its per-epoch mean losses.
print("here we would like to learn something about " + setup + "!!!!!!!!!")
print(losses)

# More experiments 5 (batch norm everywhere)

In [None]:
# Experiment configuration (batch norm on every MLP layer): builds a fresh
# backbone/model, the optimizer and the experiment name (`setup`).
dim=2048  # projector/predictor output width handed to SiamSimModel
pred_dim = 512  # hidden width of the predictor
backbone = models.resnet18()  # new ResNet-18 instance for this experiment
last_dim=list(backbone.children())[-1].in_features  # in_features of the backbone's last child

# Encoder = every child of the backbone except the last one.
encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
stop_grad = True  # 4.1: False removes stop-grad (done)
ditch_pred = False  # 4.2a: True removes the predictor
fix_pred = False  # 4.2b: True keeps the predictor fixed at its random init
pred_lr = True  # 4.2c: predictor lr is exempt from the schedule in update_lr (authors keep using it after!?) (done)
batch_size = 256  # 4.3: batch-size variation (1024-4096, whatever we get working)
bn_config = 'all'  # 4.4: batch-norm configuration: 'default', 'none', 'hidden' or 'all'
# NOTE: seems like bn=none collapses? and bn=all isn't unstable?
sim_fun_flag = False  # 4.5: True switches the similarity function to cross_entropy_sim
asymmetric = False  # 4.6: True uses only the one-sided loss (no symmetrization)

model = SiamSimModel(encoder, dim, pred_dim, last_dim, stop_grad=stop_grad, 
                     ditch_pred = ditch_pred, bn_config=bn_config)
# A removed or fixed predictor is frozen and left out of the optimizer.
# The 'name' entry of each group is read back by update_lr.
if ditch_pred or fix_pred:
  for p in model.predictor.parameters():
    p.requires_grad = False
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'}]
else:
  param_groups = [{'params': model.encoder.parameters(), 'name': 'encoder'},
                  {'params': model.projector.parameters(), 'name': 'projector'},
                  {'params': model.predictor.parameters(), 'name': 'predictor'}]

#training settings:
#num_epochs = 20   (commented out: num_epochs keeps the value assigned in an earlier cell)
lr = 0.05*(batch_size/256)  # 0.05 scaled linearly with batch_size/256
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=1e-4)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

if sim_fun_flag:
  sim_fun = cross_entropy_sim
else:
  sim_fun = neg_cosine_sim
#name of the experiment to perform (the first matching switch wins)
if not stop_grad:
  setup = 'no_stopgrad'#4.1
elif ditch_pred:
  setup = 'no_pred'#4.2a
elif fix_pred:
  setup = 'fixed_pred'#4.2b
#elif pred_lr:
#  setup = 'pred_lr'#4.2c
elif bn_config != 'default':
  setup = 'bn_' + bn_config#4.4
elif batch_size != 256:
  setup = 'batchsize_'+str(batch_size)#4.3
elif sim_fun_flag:
  setup = 'sim_fun'#4.5
elif asymmetric:
  setup = 'asymmetric'#4.6
else:
  setup = 'base'

In [None]:
# Training loop for the experiment configured above; additionally records the
# mean loss of every epoch in `losses`.
# Wrap once for multi-GPU execution. After wrapping, the SiamSimModel itself is
# reachable as model.module; the guard keeps a re-run of this cell from
# wrapping the model twice.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(device)
losses = []
for epoch in range(num_epochs):
    epoch_loss, epoch_count = 0.0, 0
    model.train()
    # The learning rate is updated once per epoch (see update_lr).
    update_lr(optimizer=optimizer, current_e=epoch,
              total_e=num_epochs, max_lr=lr, pred_lr=pred_lr)
    for i, (images, _) in enumerate(train_loader):
        start = time.time()
        x1, x2 = images  # the two views produced by TwoCropsTransform
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, z1, z2 = model(x1, x2)
        if asymmetric:
            loss = sim_fun(p1, z2).mean()
        else:
            loss = (sim_fun(p1, z2).mean() + sim_fun(p2, z1).mean()) * 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the tensor itself would chain every
        # step's autograd history into epoch_loss.
        epoch_loss += loss.item()
        epoch_count += 1
        if(i%10==0):
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss} time={time.time()-start} secs"
            print(message)
    losses.append(epoch_loss/epoch_count)
    # Checkpoint every epoch for 'base'/'no_stopgrad'; otherwise on every 20th
    # epoch and on the last one. (The original tested the undefined name `epochs`.)
    if setup=='base' or setup=='no_stopgrad' or epoch>=num_epochs-1 or epoch%20==0:
        os.makedirs(dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            # DataParallel does not expose submodules directly, hence .module.
            'model_state_dict': model.module.encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint does not carry the autograd graph.
            'loss': loss.detach(),
            }, os.path.join(dir, f"{setup}_checkpoint_{epoch}.pth"))

In [None]:
# Report which experiment just ran, followed by its per-epoch mean losses.
print("here we would like to learn something about " + setup + "!!!!!!!!!")
print(losses)

# Plots and stuff

In [None]:
#import torch
#import matplotlib.pyplot as plt

In [None]:
#checkpoint_19 = torch.load('checkpoint_19.pth')

In [None]:
#checkpoint_19.keys()

In [None]:
#losses = [0.0321]
#print(f"Epoch 0 Loss {losses[0]}")
#for i in range(20):
#    checkpoint = torch.load(f'checkpoint_{i}.pth')
#    loss = checkpoint["loss"]
#    print(f"Epoch {i+1} Loss {loss}")
#    losses.append(loss.cpu().detach().numpy().item())

In [None]:
#plt.plot(losses)