In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

# Ask MKL to use the GNU threading layer.
# NOTE(review): this is set after numpy/torch were imported above; confirm it
# still takes effect for the processes that need it.
os.environ['MKL_THREADING_LAYER'] = 'GNU'
# force=True keeps this cell re-runnable: without it, a second call in the same
# kernel raises RuntimeError because the start method has already been fixed.
# Presumably 'spawn' is needed because DataLoader workers later receive a
# CUDA device -- TODO confirm.
torch.multiprocessing.set_start_method('spawn', force=True)

In [2]:
!pip install gdown



In [3]:
def get_file_id_by_model(folder_name):
  """Return the file id registered for a pretrained-model folder name.

  Args:
    folder_name: one of the known checkpoint folder names listed below.

  Returns:
    The matching id string, or the literal "Model not found." when
    `folder_name` is not one of the known names.
  """
  known_ids = {
      'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
      'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
      'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu',
  }
  if folder_name in known_ids:
    return known_ids[folder_name]
  # Unknown names yield a sentinel string rather than an exception.
  return "Model not found."

In [4]:
# Choose which pretrained checkpoint folder to use and look up its file id.
folder_name = 'resnet18_100-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
# NOTE: file_id is the literal "Model not found." if folder_name is unknown,
# so check the printed value before using it in a download URL.
print(folder_name, file_id)

resnet18_100-epochs_stl10 14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF


In [5]:
# # download and extract model files
# os.system('gdown --continue https://drive.google.com/uc?id={}'.format(file_id))
# os.system('unzip {}'.format(folder_name))
# !ls

In [6]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [7]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [8]:
def get_stl10_data_loaders(download, batch_size=256):
  """Create train and test DataLoaders for STL10 stored under './data'.

  Args:
    download: forwarded to `datasets.STL10(download=...)` for both splits.
    batch_size: batch size of the train loader; the test loader uses twice
      this value.

  Returns:
    (train_loader, test_loader). The train loader shuffles, the test loader
    does not; both use 10 workers and keep the last partial batch.
  """
  train_loader = DataLoader(
      datasets.STL10('./data', split='train', download=download,
                     transform=transforms.ToTensor()),
      batch_size=batch_size,
      shuffle=True,
      num_workers=10,
      drop_last=False,
  )

  test_loader = DataLoader(
      datasets.STL10('./data', split='test', download=download,
                     transform=transforms.ToTensor()),
      batch_size=2 * batch_size,
      shuffle=False,
      num_workers=10,
      drop_last=False,
  )
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  """Create train and test DataLoaders for CIFAR10 stored under './data'.

  Args:
    download: forwarded to `datasets.CIFAR10(download=...)` for both splits.
    shuffle: accepted but not used by this function -- the train loader
      always shuffles and the test loader never does.
    batch_size: batch size of the train loader; the test loader uses twice
      this value.

  Returns:
    (train_loader, test_loader), both with 10 workers and keeping the last
    partial batch.
  """
  train_loader = DataLoader(
      datasets.CIFAR10('./data', train=True, download=download,
                       transform=transforms.ToTensor()),
      batch_size=batch_size,
      shuffle=True,
      num_workers=10,
      drop_last=False,
  )

  test_loader = DataLoader(
      datasets.CIFAR10('./data', train=False, download=download,
                       transform=transforms.ToTensor()),
      batch_size=2 * batch_size,
      shuffle=False,
      num_workers=10,
      drop_last=False,
  )
  return train_loader, test_loader

In [9]:
from yaml import BaseLoader, UnsafeLoader
# NOTE(review): UnsafeLoader can instantiate arbitrary Python objects while
# parsing, so only load a config.yml that comes from a trusted source.
# Presumably it is required because the config is a serialized Python object
# (later cells read `config.arch` as an attribute) -- TODO confirm.
with open('./config.yml') as file:
  config = yaml.load(file, Loader=UnsafeLoader)

In [10]:
# Instantiate the (randomly initialised) backbone named by the training config
# and move it to the selected device.
# NOTE(review): num_classes=10 is hard-coded; the OADS class count computed in
# a later cell (`output_channels`) is never used -- confirm the two match.
if config.arch == 'resnet18':
  model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)
else:
  # Fail here instead of with a NameError on `model` in a later cell.
  raise ValueError(f"Unsupported architecture in config: {config.arch!r}")



In [11]:
# NOTE(review): depending on the torch version/options, torch.load may
# unpickle the file -- only load checkpoints from a trusted source.
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

# Rewrite the dict in place: every original key is removed, and entries that
# start with 'backbone.' (except 'backbone.fc*') are re-inserted without the
# prefix. Iterating over a snapshot makes it safe to mutate inside the loop.
prefix = 'backbone.'
for k in tuple(state_dict):
  weights = state_dict.pop(k)
  if k.startswith(prefix) and not k.startswith('backbone.fc'):
    # remove prefix
    state_dict[k[len(prefix):]] = weights

In [12]:
# strict=False lets the load succeed even though the rewritten state_dict
# carries no fc.* entries.
log = model.load_state_dict(state_dict, strict=False)
# Sanity check: the only parameters the checkpoint did not provide must be the
# final fully-connected layer's weight and bias.
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [13]:
# Interactive matplotlib backend (provided by ipympl -- TODO confirm it is installed).
%matplotlib widget

# Put the parent directory on the import path so the oads_access package resolves.
sys.path.append('..')
# NOTE(review): star import hides where names come from; the names used below
# appear to be OADS_Access and OADSImageDataset.
from oads_access.oads_access import *

# Dataset root, relative to the notebook's working directory.
home = '../../data/oads/mini_oads_2/'
oads = OADS_Access(home)

In [14]:
# Get the train/val/test split from object crops.
# min_size and max_size are both `size`; presumably this restricts the split
# to crops of exactly 200x200 -- TODO confirm against OADS_Access.
size = (200, 200)
train_data, val_data, test_data = oads.get_train_val_test_split(use_crops=True, min_size=size, max_size=size, exclude_oversized_crops=True, file_formats=['.jpg'])
# Last axis of the first training image (assumes channels-last data -- TODO confirm).
input_channels = np.array(train_data[0][0]).shape[-1]

# Number of classes, and a class key -> 0-based index lookup in the order
# returned by get_class_mapping().keys().
output_channels = len(oads.get_class_mapping())
class_index_mapping = {key: index for index, key in enumerate(list(oads.get_class_mapping().keys()))}

100%|██████████| 1/1 [00:09<00:00,  9.33s/it]


In [15]:
# Number of entries in the training split.
len(train_data)

1105

In [16]:
batch_size = 64
num_workers = 6

# Image preprocessing applied by every dataset split.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Arguments shared by all three splits.
dataset_kwargs = dict(class_index_mapping=class_index_mapping, transform=transform, device=device)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

traindataset = OADSImageDataset(data=train_data, **dataset_kwargs)
valdataset = OADSImageDataset(data=val_data, **dataset_kwargs)
testdataset = OADSImageDataset(data=test_data, **dataset_kwargs)

# Only the training loader shuffles.
train_loader = DataLoader(traindataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(valdataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(testdataset, shuffle=False, **loader_kwargs)

In [17]:
# Freeze everything except the final fully-connected layer.
head_param_names = ('fc.weight', 'fc.bias')
for name, param in model.named_parameters():
    if name in head_param_names:
        continue
    param.requires_grad = False

# Exactly the two head tensors must remain trainable.
parameters = [p for p in model.parameters() if p.requires_grad]
assert len(parameters) == 2  # fc.weight, fc.bias

In [18]:
# Give the optimizer only the parameters that still require gradients, so
# frozen parameters are not registered with it.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.0003, weight_decay=0.0008)
# Classification loss over the model's logits.
criterion = torch.nn.CrossEntropyLoss().to(device)

In [19]:
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one tensor of shape (1,) per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the `largest_k` best classes per sample, laid out (k, batch).
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        percentages = []
        for k in topk:
            # A sample counts once if its target is in any of the first k rows.
            n_hits = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            percentages.append(n_hits.mul_(100.0 / n_samples))
        return percentages

In [20]:
# Train the classifier head and report train/test accuracy after each epoch.
epochs = 1
for epoch in range(epochs):
  # ---- training pass ----
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)

    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Mean of the per-batch accuracies (counter is the index of the last batch).
  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0

  # ---- evaluation pass ----
  # No gradients are needed here, so skip building the autograd graph.
  # NOTE(review): no model.eval() call is visible in this notebook, so the
  # model presumably stays in training mode during evaluation -- confirm
  # whether that is intended.
  with torch.no_grad():
    for counter, (x_batch, y_batch) in enumerate(test_loader):
      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      logits = model(x_batch)

      top1, top5 = accuracy(logits, y_batch, topk=(1,5))
      top1_accuracy += top1[0]
      top5_accuracy += top5[0]

  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

KeyboardInterrupt: 

: 

In [None]:
# Evaluate the current model on the test loader (top-1 and top-5 accuracy).
top1_accuracy = 0
top5_accuracy = 0
# No gradients are needed for evaluation, so skip building the autograd graph.
# NOTE(review): no model.eval() call is visible in this notebook, so the model
# presumably stays in training mode here -- confirm whether that is intended.
with torch.no_grad():
    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top5 = accuracy(logits, y_batch, topk=(1,5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]

# Mean of the per-batch accuracies (counter is the index of the last batch).
top1_accuracy /= (counter + 1)
top5_accuracy /= (counter + 1)
print(f"Top1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

Top1 Test accuracy: 7.638888835906982	Top5 test acc: 31.11111068725586
