<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn as nn

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [15]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)

  test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)

  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

In [16]:
def modify_resnet18_for_cifar10(model):
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()

    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)

    return model


In [17]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [18]:
checkpoint_filenames = ['checkpoint_0000.pth.tar']
# checkpoint_filenames = ['checkpoint_0050.pth.tar', 'checkpoint_0100.pth.tar', 'checkpoint_0150.pth.tar' ]
# checkpoint_filenames = ['checkpoint_0050.pth.tar', 'checkpoint_0100.pth.tar', 'checkpoint_0150.pth.tar', 'checkpoint_0200.pth.tar', 'checkpoint_0250.pth.tar', 'checkpoint_0300.pth.tar' ]

for checkpoint_filename in checkpoint_filenames:

  print(f"\n\n\n\n\n########################checkpoint {checkpoint_filename}##########################\n")

  modified_resnet18 = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
  model = modify_resnet18_for_cifar10(modified_resnet18).to(device)


  checkpoint = torch.load(f'{checkpoint_filename}', map_location=device)
  # checkpoint = torch.load('runs/Apr23_11-52-29_Samme/checkpoint_0100.pth.tar', map_location=device)
  state_dict = checkpoint['state_dict']

  for k in list(state_dict.keys()):

    if k.startswith('backbone.'):
      if k.startswith('backbone') and not k.startswith('backbone.fc'):
        # remove prefix
        state_dict[k[len("backbone."):]] = state_dict[k]
    del state_dict[k]

  log = model.load_state_dict(state_dict, strict=False)
  assert log.missing_keys == ['fc.weight', 'fc.bias']

  train_loader, test_loader = get_cifar10_data_loaders(download=True)

  # freeze all layers but the last fc
  for name, param in model.named_parameters():
      if name not in ['fc.weight', 'fc.bias']:
          param.requires_grad = False

  parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
  assert len(parameters) == 2  # fc.weight, fc.bias

  optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
  criterion = torch.nn.CrossEntropyLoss().to(device)

  epochs = 100
  for epoch in range(epochs):
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      logits = model(x_batch)
      loss = criterion(logits, y_batch)
      top1 = accuracy(logits, y_batch, topk=(1,))
      top1_train_accuracy += top1[0]

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    top1_train_accuracy /= (counter + 1)
    top1_accuracy = 0
    top5_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(test_loader):
      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      logits = model(x_batch)

      top1, top5 = accuracy(logits, y_batch, topk=(1,5))
      top1_accuracy += top1[0]
      top5_accuracy += top5[0]

    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)
    print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")






########################checkpoint checkpoint_0000.pth.tar##########################

Files already downloaded and verified
Files already downloaded and verified
Epoch 0	Top1 Train accuracy 32.97552490234375	Top1 Test accuracy: 37.60397720336914	Top5 test acc: 87.5608901977539
Epoch 1	Top1 Train accuracy 38.27248001098633	Top1 Test accuracy: 39.330196380615234	Top5 test acc: 88.32548522949219
Epoch 2	Top1 Train accuracy 39.52367401123047	Top1 Test accuracy: 40.341224670410156	Top5 test acc: 88.73448944091797
Epoch 3	Top1 Train accuracy 40.45280456542969	Top1 Test accuracy: 40.79159164428711	Top5 test acc: 89.03722381591797
Epoch 4	Top1 Train accuracy 41.1224479675293	Top1 Test accuracy: 41.298255920410156	Top5 test acc: 89.42670440673828
Epoch 5	Top1 Train accuracy 41.660552978515625	Top1 Test accuracy: 41.794002532958984	Top5 test acc: 89.72081756591797
Epoch 6	Top1 Train accuracy 42.226959228515625	Top1 Test accuracy: 42.15303421020508	Top5 test acc: 90.02240753173828
Epoch 7	To

In [19]:
# Do Not Use RUN ALL to run following code. They are designed for "random initialized" weights to compare with SimCLR

class StopExecution(Exception):
    def _render_traceback_(self):
        pass

raise StopExecution("Intentionally stopping the notebook here.")

StopExecution: Intentionally stopping the notebook here.

In [21]:
print(f"\n\n\n\n\n######################## random initialized ##########################\n")

modified_resnet18 = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
model = modify_resnet18_for_cifar10(modified_resnet18).to(device)

# Initialize the weights
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, std=0.01)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)

train_loader, test_loader = get_cifar10_data_loaders(download=True)

# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)

    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]

  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")






######################## random initialized ##########################





Files already downloaded and verified
Files already downloaded and verified


  self.pid = os.fork()


Epoch 0	Top1 Train accuracy 17.312658309936523	Top1 Test accuracy: 21.968061447143555	Top5 test acc: 73.88844299316406
Epoch 1	Top1 Train accuracy 23.674665451049805	Top1 Test accuracy: 25.62729835510254	Top5 test acc: 77.49885559082031
Epoch 2	Top1 Train accuracy 26.316564559936523	Top1 Test accuracy: 27.527572631835938	Top5 test acc: 79.31468963623047
Epoch 3	Top1 Train accuracy 28.10746192932129	Top1 Test accuracy: 28.795957565307617	Top5 test acc: 80.49172973632812
Epoch 4	Top1 Train accuracy 29.380977630615234	Top1 Test accuracy: 29.636947631835938	Top5 test acc: 81.0862808227539
Epoch 5	Top1 Train accuracy 30.26227569580078	Top1 Test accuracy: 30.388900756835938	Top5 test acc: 81.591796875
Epoch 6	Top1 Train accuracy 31.033960342407227	Top1 Test accuracy: 30.868566513061523	Top5 test acc: 81.9215316772461
Epoch 7	Top1 Train accuracy 31.548946380615234	Top1 Test accuracy: 31.403379440307617	Top5 test acc: 82.32939147949219
Epoch 8	Top1 Train accuracy 31.979032516479492	Top1 Test a