### 0. Setting

In [56]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn as nn

In [57]:
!pip install gdown



In [58]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

Using device: cuda


In [59]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):

    train_dataset = datasets.STL10('./data', split='train', download=download,
                                   transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=0, drop_last=False, shuffle=shuffle)
    test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
    
    return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    
    train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=0, drop_last=False, shuffle=shuffle)
  
    test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                    transform=transforms.ToTensor())

    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                             num_workers=0, drop_last=False, shuffle=shuffle)
    
    return train_loader, test_loader

In [60]:
with open(os.path.join('../runs/Jul27_LARS_lr1.5_Momentum0.9/config.yml')) as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

In [61]:
if config['arch'] == 'resnet18':
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config['arch'] == 'resnet50':
    model = torchvision.models.resnet50(pretrained=False, num_classes=10)
    if config['dataset_name'] == 'cifar10':
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1)
        model.maxpool = nn.Identity()
        model.to(device)

### Remove FC weights and bias

In [62]:
checkpoint = torch.load('../runs/Jul27_LARS_lr1.5_Momentum0.9/checkpoint_0500.pth.tar')
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):

    if k.startswith('backbone.'):
        if k.startswith('backbone') and not k.startswith('backbone.fc'):
            # remove prefix
            state_dict[k[len("backbone."):]] = state_dict[k]
        del state_dict[k]

In [63]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [64]:
if config['dataset_name'] == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config['dataset_name'] == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config['dataset_name'])

Files already downloaded and verified
Files already downloaded and verified
Dataset: cifar10


### Freeze all layers except for the last FC layer

In [65]:
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
# fc.weight, fc.bias
assert len(parameters) == 2

In [66]:
from torchlars import LARS

base_optimizer = torch.optim.SGD(model.parameters(), lr=1.5, weight_decay=1e-6, momentum=0.9)
optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
criterion = torch.nn.CrossEntropyLoss().to(device)

### Define evaluation metric: Top-k accuracy

In [67]:
def accuracy(output, target, topk=(1, )):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [68]:
epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

Epoch 0	Top1 Train accuracy 79.49537658691406	Top1 Test accuracy: 86.05066680908203	Top5 test acc: 99.49333953857422
Epoch 1	Top1 Train accuracy 88.22504425048828	Top1 Test accuracy: 87.50115203857422	Top5 test acc: 99.56169891357422
Epoch 2	Top1 Train accuracy 89.39971160888672	Top1 Test accuracy: 88.03596496582031	Top5 test acc: 99.64958953857422
Epoch 3	Top1 Train accuracy 89.90074920654297	Top1 Test accuracy: 88.38005828857422	Top5 test acc: 99.64958953857422
Epoch 4	Top1 Train accuracy 90.32366180419922	Top1 Test accuracy: 88.69026184082031	Top5 test acc: 99.658203125
Epoch 5	Top1 Train accuracy 90.70830535888672	Top1 Test accuracy: 88.779296875	Top5 test acc: 99.6875
Epoch 6	Top1 Train accuracy 90.92753601074219	Top1 Test accuracy: 88.974609375	Top5 test acc: 99.677734375
Epoch 7	Top1 Train accuracy 91.16071319580078	Top1 Test accuracy: 89.1285629272461	Top5 test acc: 99.6875
Epoch 8	Top1 Train accuracy 91.38392639160156	Top1 Test accuracy: 89.21760559082031	Top5 test acc: 99.677