In [None]:
# Mount Google Drive at /content/drive (Colab-only); force_remount=True
# re-mounts even if a mount already exists, so this cell is safe to re-run.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# Drive folder used as the base location for this experiment's checkpoints.
# NOTE: this name shadows the builtin `dir()`, and the value has no trailing
# path separator, so plain string concatenation with a file name does not
# place the file inside this folder.
dir = '/content/drive/MyDrive/SIMSIAM/LinEval'

# Defining and Loading Classifier w/Backbone

In [None]:
class LinearClassifier(nn.Module):
    """Backbone encoder followed by a three-layer projector head.

    Args:
        encoder: module producing features with `last_dim` values per sample.
        dim: size of the head output (number of classes).
        last_dim: feature width produced by `encoder`.
    """

    def __init__(self, encoder, dim, last_dim):
        super().__init__()
        # Remember the feature width so forward() can flatten encoder output.
        self.ldim = last_dim

        # Backbone encoder.
        self.encoder = encoder

        # Projector (classifier): two hidden Linear-BN-ReLU stages, then a
        # Linear-BN output stage. Layer order fixes the Sequential indices
        # (0..7) and therefore the state_dict keys.
        head_layers = []
        for _ in range(2):
            head_layers.append(nn.Linear(last_dim, last_dim, bias=False))
            head_layers.append(nn.BatchNorm1d(last_dim))
            head_layers.append(nn.ReLU(inplace=True))
        head_layers.append(nn.Linear(last_dim, dim, bias=False))
        head_layers.append(nn.BatchNorm1d(dim))
        self.projector = nn.Sequential(*head_layers)

    def forward(self, x):
        """Encode `x`, flatten to (-1, last_dim), and apply the projector."""
        features = self.encoder(x)
        flat_features = features.view(-1, self.ldim)
        return self.projector(flat_features)


In [None]:
# Experiment configuration.
n_classes = 100
pred_dim = 512
method = 'byol'#'simsiam'

# Both methods use a ResNet-18 backbone; the encoder is the backbone without
# its final fully-connected layer, and last_dim is that layer's input width.
backbone = models.resnet18()
last_dim = list(backbone.children())[-1].in_features

if method=='simsiam':
  path = '/content/drive/MyDrive/SIMSIAM/LinEval/pred_lr_checkpoint_99.pth'
  # SimSiam checkpoint: weights are loaded directly into the truncated
  # Sequential encoder, from the 'model_state_dict' entry.
  encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))
  # map_location='cpu' lets a checkpoint saved from GPU load on any machine;
  # load_state_dict copies the values into the existing parameters.
  checkpoint = torch.load(path, map_location='cpu')
  encoder.load_state_dict(checkpoint['model_state_dict'])
elif method=='byol':
  path = '/content/drive/MyDrive/SIMSIAM/LinEval/improved-net-epoch-99.pt'
  # BYOL checkpoint: the file is loaded as a state dict for the full
  # backbone, so load first and truncate afterwards.
  checkpoint = torch.load(path, map_location='cpu')
  backbone.load_state_dict(checkpoint)
  encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))
else:
  # Fail here with a clear message instead of a NameError in a later cell.
  raise ValueError(f"Unknown method {method!r}; expected 'simsiam' or 'byol'")

#settings:
batch_size = 256#128#512#1024#2048#4096
loss_fun = nn.CrossEntropyLoss()

In [None]:
# Wrap the pre-trained encoder with the classifier head. The construction is
# the same for every pre-training method, so no branching is needed.
model = LinearClassifier(encoder, n_classes, last_dim)

# Freeze the backbone: only the projector head receives gradient updates.
for p in model.encoder.parameters():
  p.requires_grad = False
param_groups = [{'params': model.projector.parameters(), 'name': 'projector'}]

#training settings:
num_epochs = 90
lr = 30.0#0.05*(batch_size/256)
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=0)
# Alternative BYOL-style settings (disabled):
# if method=='byol':
#   lr=0.4
#   optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=0)

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
#cosine schedule for LR
import math
def update_lr(optimizer, current_e, total_e, max_lr, min_lr=0, pred_lr=False):
  """Set the learning rate of the optimizer's param groups by cosine schedule.

  The rate is `min_lr + 0.5*(max_lr-min_lr)*(1 + cos(pi*current_e/total_e))`,
  i.e. `max_lr` at `current_e == 0` and `min_lr` at `current_e == total_e`.

  Args:
    optimizer: object exposing `param_groups`, a list of dicts with an 'lr'
      entry and, optionally, a 'name' entry.
    current_e: current epoch index.
    total_e: total number of epochs (must be non-zero).
    max_lr: learning rate at the start of the schedule.
    min_lr: learning rate at the end of the schedule.
    pred_lr: when True, groups named 'predictor' keep their current rate.

  Returns:
    None; the param groups are modified in place.
  """
  # The scheduled value is the same for every group, so compute it once.
  scheduled_lr = min_lr + 0.5*(max_lr-min_lr)*(1  + math.cos((math.pi*current_e)/total_e))
  for g in optimizer.param_groups:
    # .get() tolerates groups created without a 'name' key (e.g. an optimizer
    # built from a plain parameter iterable) instead of raising KeyError.
    if pred_lr and g.get('name') == 'predictor':
      continue
    g['lr'] = scheduled_lr

# Defining Data Loaders

In [None]:
from PIL import ImageFilter
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

traindir = "data"

# Per-channel normalisation applied after ToTensor() in both pipelines.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random 224-pixel crop with random horizontal flip.
train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
train_transform = transforms.Compose(train_steps)
train_dataset = datasets.CIFAR100(
    root='./data', train=True, download=True, transform=train_transform)

# Validation pipeline: deterministic resize to 256 then centre crop to 224.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
val_transform = transforms.Compose(val_steps)
val_dataset = datasets.CIFAR100(
    root='./data', train=False, download=True, transform=val_transform)

# Options shared by both loaders; batch_size is read from the settings
# defined earlier in the notebook.
# NOTE(review): drop_last=True is also applied to the validation loader, so
# the final partial batch of validation samples is never evaluated — confirm
# this is intended.
loader_options = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
    sampler=None,
    drop_last=True,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, **loader_options)

val_loader = torch.utils.data.DataLoader(
    val_dataset, shuffle=False, **loader_options)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
#helper function from MoCo code
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k listed in `topk`.

    `output` holds per-class scores with one row per sample and `target` the
    matching class indices. The result is a list with one single-element
    tensor per requested k, in the same order as `topk`.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the largest_k highest-scoring classes per sample, best
        # first; transposed so row j holds every sample's (j+1)-th guess.
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

# Train Loop

In [None]:
# Train the projector head on top of the frozen encoder, validating after
# every epoch. A checkpoint of the head is written after the final epoch, and
# the head with the best validation top-1 accuracy is kept as the BEST file.
model = model.to(device)
best_acc = 0.0
# Print roughly ten progress lines per epoch; max(1, ...) avoids a modulo by
# zero when the loader has fewer than ten batches.
log_every = max(1, len(train_loader) // 10)
for epoch in range(num_epochs):
    model.train()
    # The encoder's parameters are frozen, but BatchNorm layers in train mode
    # would still update their running statistics; keep the backbone in eval
    # mode so the pre-trained features stay fixed.
    model.encoder.eval()
    #update_lr(optimizer=optimizer, current_e=epoch,
    #          total_e=num_epochs, max_lr=lr)
    epoch_loss, epoch_acc1, epoch_acc5, epoch_count = 0.0, 0.0, 0.0, 0
    for i, (x, y) in enumerate(train_loader):
        start = time.time()
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fun(pred, y)
        acc1, acc5 = accuracy(pred, y, topk=(1, 5))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate plain Python floats (not tensors) so no autograd history
        # is kept alive across steps; weight by batch size so the epoch
        # figures are per-sample means.
        n_batch = y.size(0)
        loss_value = loss.item()
        epoch_loss += loss_value * n_batch
        epoch_acc1 += acc1.item() * n_batch
        epoch_acc5 += acc5.item() * n_batch
        epoch_count += n_batch
        if i % log_every == 0:
            message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss_value} time={time.time()-start} secs"
            print(message)
    epoch_count = max(epoch_count, 1)
    epoch_loss /= epoch_count
    epoch_acc1 /= epoch_count
    epoch_acc5 /= epoch_count
    print("train loss: {} top-1 acc: {} top-5 acc: {}".format(epoch_loss, epoch_acc1, epoch_acc5))
    if epoch>=num_epochs-1:
        # os.path.join places the file inside `dir`; plain concatenation would
        # fuse the folder name and the file name.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.projector.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            }, os.path.join(dir, f"linear_classifier_{method}_100eps_checkpoint_{epoch}.pth"))
    model.eval()
    epoch_loss, epoch_acc1, epoch_acc5, epoch_count = 0.0, 0.0, 0.0, 0
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fun(pred, y)
            acc1, acc5 = accuracy(pred, y, topk=(1, 5))
            n_batch = y.size(0)
            epoch_loss += loss.item() * n_batch
            epoch_acc1 += acc1.item() * n_batch
            epoch_acc5 += acc5.item() * n_batch
            epoch_count += n_batch
        epoch_count = max(epoch_count, 1)
        epoch_loss /= epoch_count
        epoch_acc1 /= epoch_count
        epoch_acc5 /= epoch_count
        # Only a strict improvement replaces the BEST checkpoint; an
        # unconditional periodic overwrite would replace it with a worse
        # model and reset best_acc downwards.
        if epoch_acc1 > best_acc:
            best_acc = epoch_acc1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.projector.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                }, os.path.join(dir, f"linear_classifier_{method}_100eps_BEST.pth"))
        print("val loss: {} top-1 acc: {} ({}) top-5 acc: {}".format(epoch_loss, epoch_acc1, best_acc, epoch_acc5))

epoch=0/90 step=0/195 loss=4.391273498535156 time=0.5170595645904541 secs
epoch=0/90 step=19/195 loss=3.4498300552368164 time=0.9318602085113525 secs
epoch=0/90 step=38/195 loss=4.781686305999756 time=0.4518144130706787 secs
epoch=0/90 step=57/195 loss=3.723480224609375 time=0.6138923168182373 secs
epoch=0/90 step=76/195 loss=3.7594730854034424 time=0.45255541801452637 secs
epoch=0/90 step=95/195 loss=4.98921012878418 time=0.8530383110046387 secs
epoch=0/90 step=114/195 loss=5.545228004455566 time=0.4568047523498535 secs
epoch=0/90 step=133/195 loss=4.321620941162109 time=0.8435957431793213 secs
epoch=0/90 step=152/195 loss=4.302101135253906 time=0.43692922592163086 secs
epoch=0/90 step=171/195 loss=4.620236873626709 time=0.8650197982788086 secs
epoch=0/90 step=190/195 loss=3.9206271171569824 time=0.4454622268676758 secs
train loss: 4.167858600616455 top-1 acc: tensor([18.0549], device='cuda:0') top-5 acc: tensor([43.6719], device='cuda:0')
val loss: 3.6538639068603516 top-1 acc: tenso

# Evaluate Pre-trained Classifier

In [None]:
#LOAD:
# Resolve the pre-trained classifier checkpoint. By default this is the
# best-validation file for the current `method`; assign `path` by hand
# instead to evaluate a different checkpoint.
# Two locations are tried: inside the `dir` folder, and the legacy location
# produced by concatenating `dir` and the file name without a separator.
ckpt_name = f"linear_classifier_{method}_100eps_BEST.pth"
candidate_paths = [os.path.join(dir, ckpt_name), dir + ckpt_name]
path = next((p for p in candidate_paths if os.path.exists(p)), None)
if path is None:
    raise FileNotFoundError(
        "No classifier checkpoint found; looked for: " + ", ".join(candidate_paths))
# map_location='cpu' lets a checkpoint saved from GPU load on any machine;
# load_state_dict copies the values into the existing parameters.
checkpoint = torch.load(path, map_location='cpu')
model.projector.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
#settings:
# NOTE: loaders that were already constructed keep the batch size they were
# created with; this value only affects loaders built after this cell.
batch_size = 64#128#512#1024#2048#4096
loss_fun = nn.CrossEntropyLoss()
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Evaluate the loaded classifier on the validation loader and report the mean
# loss and top-1 / top-5 accuracy.
model = model.to(device)
model.eval()
epoch_loss, epoch_acc1, epoch_acc5, epoch_count = 0.0, 0.0, 0.0, 0
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fun(pred, y)
        acc1, acc5 = accuracy(pred, y, topk=(1, 5))
        # Accumulate Python floats weighted by batch size so the reported
        # figures are per-sample means and print as plain numbers.
        n_batch = y.size(0)
        epoch_loss += loss.item() * n_batch
        epoch_acc1 += acc1.item() * n_batch
        epoch_acc5 += acc5.item() * n_batch
        epoch_count += n_batch
    # Guard against an empty loader.
    epoch_count = max(epoch_count, 1)
    epoch_loss /= epoch_count
    epoch_acc1 /= epoch_count
    epoch_acc5 /= epoch_count
    print("val loss: {} top-1 acc: {} top-5 acc: {}".format(epoch_loss, epoch_acc1, epoch_acc5))

val loss: 2.8380870819091797 top-1 acc: tensor([32.4920], device='cuda:0') top-5 acc: tensor([62.9006], device='cuda:0')
