In [24]:
#from google.colab import drive
#drive.mount('/content/drive', force_remount=True)

In [25]:
import torch
import torchvision.models as models
import torch.nn.functional as F
import torch.nn as nn
import time
dir = './LinEval'

# Defining and Loading Classifier w/Backbone

In [26]:
class LinearClassifier(nn.Module):
    def __init__(self, encoder, dim, last_dim):
        super().__init__()
        #conditions:
        self.ldim = last_dim

        #backbone encoder
        self.encoder = encoder
        #projector (classifier)
        self.projector = nn.Sequential(nn.Linear(last_dim, last_dim, bias=False),
                                      nn.BatchNorm1d(last_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(last_dim, last_dim, bias=False),
                                      nn.BatchNorm1d(last_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(last_dim, dim, bias=False),
                                      nn.BatchNorm1d(dim))        

    def forward(self, x):
        return self.projector(self.encoder(x).view(-1, self.ldim))


In [27]:
!nvidia-smi

Wed Apr 20 20:40:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 511.23       Driver Version: 511.23       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:02:00.0 Off |                  N/A |
| N/A   36C    P8     2W /  N/A |   3493MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [28]:
import torch
torch.cuda.is_available()

True

In [29]:
n_classes = 100
pred_dim = 512
method = 'byol'#'simsiam'

if method=='simsiam':
  path = './asymmetric_checkpoint_19.pth'
  #load simsiam
  backbone = models.resnet18()
  encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))
  last_dim=list(backbone.children())[-1].in_features
  checkpoint = torch.load(path)
  encoder.load_state_dict(checkpoint['model_state_dict'])
elif method=='byol':
  path = './improved-net-epoch-99.pt'
  #load BYOL
  backbone = models.resnet18()
  last_dim = list(backbone.children())[-1].in_features
  #backbone_alt_w = torch.load()['state_dict']
  checkpoint = torch.load(path)
  backbone.load_state_dict(checkpoint)
  encoder = torch.nn.Sequential(*(list(backbone.children())[:-1]))

#settings:
batch_size = 256#128#512#1024#2048#4096
loss_fun = nn.CrossEntropyLoss()

In [30]:
if method=='simsiam':
  model = LinearClassifier(encoder, n_classes, last_dim)#simsiam
elif method=='byol':
  model = LinearClassifier(encoder, n_classes, last_dim)#byol

for p in model.encoder.parameters():
  p.requires_grad = False
param_groups = [{'params': model.projector.parameters(), 'name': 'projector'}]

#training settings:
num_epochs = 90
lr = 30.0#0.05*(batch_size/256)
optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=0.9, weight_decay=0)
device = torch.device('cuda')# if torch.cuda.is_available() else 'cpu')

In [31]:
#cosine schedule for LR
import math
def update_lr(optimizer, current_e, total_e, max_lr, min_lr=0, pred_lr=False):
  for g in optimizer.param_groups:
    if g['name'] == 'predictor' and pred_lr:
      pass
    else:
      g['lr'] = min_lr + 0.5*(max_lr-min_lr)*(1  + math.cos((math.pi*current_e)/total_e))

# Defining Data Loaders

In [32]:
from PIL import ImageFilter
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

traindir = "data"
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
train_dataset = datasets.CIFAR100(root='./data', train=True,
                                    download=True, transform=train_transform)

val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
val_dataset = datasets.CIFAR100(root='./data', train=False,
                                    download=True, transform=val_transform)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size, shuffle=True,
        num_workers=1, pin_memory=True, sampler=None, drop_last=True)

val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size, shuffle=False,
        num_workers=1, pin_memory=True, sampler=None, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [33]:
#helper function from MoCo code
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

# Train Loop

In [None]:
model = model.to(device)
best_acc = 0
for epoch in range(num_epochs):
    model.train()
    update_lr(optimizer=optimizer, current_e=epoch, 
              total_e=num_epochs, max_lr=lr)
    epoch_loss, epoch_acc1, epoch_acc5, epoch_count = 0, 0, 0, 0
    for i, (x, y) in enumerate(train_loader):
        start = time.time()
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fun(pred, y)
        acc1, acc5 = accuracy(pred, y, topk=(1, 5))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss
        epoch_acc1 += acc1
        epoch_acc5 += acc5
        epoch_count += 1
        if(i%int(len(train_loader)/10)==0):
          message = f"epoch={epoch}/{num_epochs} step={i}/{len(train_loader)} loss={loss} time={time.time()-start} secs"
          print(message)
    epoch_loss /= epoch_count
    epoch_acc1 /= epoch_count
    epoch_acc5 /= epoch_count
    print("train loss: {} top-1 acc: {} top-5 acc: {}".format(epoch_loss, epoch_acc1, epoch_acc5))
    if True:
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.projector.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'loss': loss,
          }, dir + f"linear_classifier_{method}_checkpoint_last.pth")
    model.eval()
    epoch_loss, epoch_acc1, epoch_acc5, epoch_count = 0, 0, 0, 0
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fun(pred, y)
            acc1, acc5 = accuracy(pred, y, topk=(1, 5))
            epoch_loss += loss
            epoch_acc1 += acc1
            epoch_acc5 += acc5
            epoch_count += 1
        epoch_loss /= epoch_count
        epoch_acc1 /= epoch_count
        epoch_acc5 /= epoch_count
        if epoch_acc1 > best_acc:
          best_acc = epoch_acc1
          #torch.save({
          #  'epoch': epoch,
          #  'model_state_dict': model.projector.state_dict(),
          #  'optimizer_state_dict': optimizer.state_dict(),
          #  'loss': loss,
          #  }, dir + f"linear_classifier_{method}_bigbatch_BEST.pth")
        print("val loss: {} top-1 acc: {} ({}) top-5 acc: {}".format(epoch_loss, epoch_acc1, best_acc, epoch_acc5))

epoch=0/90 step=0/195 loss=4.993008136749268 time=0.7230675220489502 secs
epoch=0/90 step=19/195 loss=6.210032939910889 time=0.8497641086578369 secs
epoch=0/90 step=38/195 loss=6.261281967163086 time=0.8508288860321045 secs
epoch=0/90 step=57/195 loss=5.891622066497803 time=0.85172438621521 secs
epoch=0/90 step=76/195 loss=5.5557780265808105 time=0.8537867069244385 secs
epoch=0/90 step=95/195 loss=5.594282627105713 time=0.8539254665374756 secs
epoch=0/90 step=114/195 loss=6.081092834472656 time=0.8570160865783691 secs
epoch=0/90 step=133/195 loss=5.373440742492676 time=0.8549749851226807 secs
epoch=0/90 step=152/195 loss=5.33833122253418 time=0.8567118644714355 secs
epoch=0/90 step=171/195 loss=5.3870368003845215 time=0.8585684299468994 secs
epoch=0/90 step=190/195 loss=5.61759090423584 time=0.8598084449768066 secs
train loss: 5.912209510803223 top-1 acc: tensor([1.1398], device='cuda:0') top-5 acc: tensor([5.3606], device='cuda:0')
val loss: 5.275754928588867 top-1 acc: tensor([1.0016

epoch=7/90 step=133/195 loss=5.23885440826416 time=0.8707032203674316 secs
epoch=7/90 step=152/195 loss=7.916139602661133 time=0.8701667785644531 secs
epoch=7/90 step=171/195 loss=5.738955020904541 time=0.8717031478881836 secs
epoch=7/90 step=190/195 loss=6.08655309677124 time=0.8710854053497314 secs
train loss: 5.805164813995361 top-1 acc: tensor([1.0036], device='cuda:0') top-5 acc: tensor([5.2925], device='cuda:0')
val loss: 5.438546180725098 top-1 acc: tensor([1.0917], device='cuda:0') (tensor([1.3622], device='cuda:0')) top-5 acc: tensor([5.1783], device='cuda:0')
epoch=8/90 step=0/195 loss=5.314542293548584 time=0.538527250289917 secs
epoch=8/90 step=19/195 loss=5.34110689163208 time=0.8717730045318604 secs
epoch=8/90 step=38/195 loss=5.520364284515381 time=0.8706746101379395 secs
epoch=8/90 step=57/195 loss=5.722893238067627 time=0.8716723918914795 secs
epoch=8/90 step=76/195 loss=6.718998908996582 time=0.8696436882019043 secs
epoch=8/90 step=95/195 loss=6.213150501251221 time=0

epoch=15/90 step=0/195 loss=6.501781463623047 time=0.5783774852752686 secs
epoch=15/90 step=19/195 loss=6.442302703857422 time=0.8714077472686768 secs
epoch=15/90 step=38/195 loss=6.058253765106201 time=0.8729758262634277 secs
epoch=15/90 step=57/195 loss=5.133721828460693 time=0.8716025352478027 secs
epoch=15/90 step=76/195 loss=5.31650972366333 time=0.8722262382507324 secs
epoch=15/90 step=95/195 loss=5.432584762573242 time=0.8717193603515625 secs
epoch=15/90 step=114/195 loss=5.261455059051514 time=0.8718867301940918 secs
epoch=15/90 step=133/195 loss=5.795462131500244 time=0.8722867965698242 secs
epoch=15/90 step=152/195 loss=5.489991664886475 time=0.8722326755523682 secs
epoch=15/90 step=171/195 loss=5.3505401611328125 time=0.8723506927490234 secs
epoch=15/90 step=190/195 loss=6.055126190185547 time=0.872105598449707 secs
train loss: 5.703893184661865 top-1 acc: tensor([1.0737], device='cuda:0') top-5 acc: tensor([5.4207], device='cuda:0')
val loss: 5.405588150024414 top-1 acc: te

epoch=22/90 step=114/195 loss=6.369888782501221 time=0.8741693496704102 secs
epoch=22/90 step=133/195 loss=5.622199058532715 time=0.8726699352264404 secs
epoch=22/90 step=152/195 loss=5.303056716918945 time=0.8721344470977783 secs
epoch=22/90 step=171/195 loss=6.60441255569458 time=0.8735108375549316 secs
epoch=22/90 step=190/195 loss=5.133540630340576 time=0.8745827674865723 secs
train loss: 6.605125427246094 top-1 acc: tensor([1.1759], device='cuda:0') top-5 acc: tensor([5.2344], device='cuda:0')
val loss: 6.321922779083252 top-1 acc: tensor([1.0517], device='cuda:0') (tensor([1.4022], device='cuda:0')) top-5 acc: tensor([6.0998], device='cuda:0')
epoch=23/90 step=0/195 loss=5.822803497314453 time=0.5846962928771973 secs
epoch=23/90 step=19/195 loss=6.497515678405762 time=0.8698959350585938 secs
epoch=23/90 step=38/195 loss=5.521313190460205 time=0.8718104362487793 secs
epoch=23/90 step=57/195 loss=5.272286891937256 time=0.8723552227020264 secs
epoch=23/90 step=76/195 loss=5.17648887

val loss: 5.24670934677124 top-1 acc: tensor([1.0116], device='cuda:0') (tensor([1.4323], device='cuda:0')) top-5 acc: tensor([4.9479], device='cuda:0')
epoch=30/90 step=0/195 loss=5.068747043609619 time=0.634746789932251 secs
epoch=30/90 step=19/195 loss=5.104639530181885 time=0.8729250431060791 secs
epoch=30/90 step=38/195 loss=5.36788272857666 time=0.8712005615234375 secs
epoch=30/90 step=57/195 loss=5.871337890625 time=0.8728659152984619 secs
epoch=30/90 step=76/195 loss=5.194910049438477 time=0.8718984127044678 secs
epoch=30/90 step=95/195 loss=5.120091915130615 time=0.8870820999145508 secs
epoch=30/90 step=114/195 loss=4.7782158851623535 time=0.8741679191589355 secs
epoch=30/90 step=133/195 loss=5.309590816497803 time=0.8706991672515869 secs
epoch=30/90 step=152/195 loss=5.568417549133301 time=0.8731338977813721 secs
epoch=30/90 step=171/195 loss=5.734321117401123 time=0.8725054264068604 secs
epoch=30/90 step=190/195 loss=4.988663196563721 time=0.873666524887085 secs
train loss: 

epoch=37/90 step=76/195 loss=5.724727153778076 time=0.8725199699401855 secs
epoch=37/90 step=95/195 loss=5.43369722366333 time=0.8749537467956543 secs
epoch=37/90 step=114/195 loss=5.57492733001709 time=0.8841750621795654 secs
epoch=37/90 step=133/195 loss=5.0946269035339355 time=0.8731412887573242 secs
epoch=37/90 step=152/195 loss=5.593571662902832 time=0.8741493225097656 secs
epoch=37/90 step=171/195 loss=6.061875820159912 time=0.8735949993133545 secs
epoch=37/90 step=190/195 loss=5.696097373962402 time=0.8727121353149414 secs
train loss: 5.515650272369385 top-1 acc: tensor([1.1719], device='cuda:0') top-5 acc: tensor([5.4447], device='cuda:0')
val loss: 5.395240783691406 top-1 acc: tensor([1.1418], device='cuda:0') (tensor([1.4523], device='cuda:0')) top-5 acc: tensor([5.5489], device='cuda:0')
epoch=38/90 step=0/195 loss=5.630026340484619 time=0.6119594573974609 secs
epoch=38/90 step=19/195 loss=5.478271961212158 time=0.8735888004302979 secs
epoch=38/90 step=38/195 loss=5.07688713

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
checkpoint_19 = torch.load('checkpoint_19.pth')

In [None]:
checkpoint_19.keys()

In [None]:
losses = [0.0321]
print(f"Epoch 0 Loss {losses[0]}")
for i in range(20):
    checkpoint = torch.load(f'checkpoint_{i}.pth')
    loss = checkpoint["loss"]
    print(f"Epoch {i+1} Loss {loss}")
    losses.append(loss.cpu().detach().numpy().item())

In [None]:
plt.plot(losses)