<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [2]:
normalize = [(0.12, 0.12, 0.12), (0.19, 0.19, 0.19)]

def get_oct_test_simclr_pipeline_transform():
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    data_transforms = transforms.Compose(
        [
            # transforms.Resize(size=(224, 224)),
            transforms.Resize(size=(256, 256)),
            transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ]
    )
    return data_transforms


def get_oct_simclr_pipeline_transform():
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    data_transforms = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ]
    )
    return data_transforms


def get_oct_data_loaders(root_path, batch_size=32):
    train_dataset = datasets.ImageFolder(f"{root_path}/train", transform=get_oct_simclr_pipeline_transform())

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, drop_last=False, shuffle=True)

    test_dataset = datasets.ImageFolder(f"{root_path}/test", transform=get_oct_test_simclr_pipeline_transform())

    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8, drop_last=False, shuffle=True)

    val_dataset = datasets.ImageFolder(f"{root_path}/val", transform=get_oct_test_simclr_pipeline_transform())

    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8, drop_last=False, shuffle=True)
    return train_loader, test_loader, val_loader

In [3]:
with open("./config.yml") as file:
    config = yaml.load(file, Loader=yaml.UnsafeLoader)

In [4]:
if config.arch == "resnet18":
    model = torchvision.models.resnet18(pretrained=False, num_classes=4).to(device)
elif config.arch == "resnet50":
    model = torchvision.models.resnet50(pretrained=False, num_classes=4).to(device)



In [5]:
checkpoint = torch.load(config.checkpoint_path, map_location=device)
state_dict = checkpoint["state_dict"]

for k in list(state_dict.keys()):
    if k.startswith("backbone."):
        if k.startswith("backbone") and not k.startswith("backbone.fc"):
            # remove prefix
            state_dict[k[len("backbone.") :]] = state_dict[k]
    del state_dict[k]

In [6]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [7]:
if config.dataset_name == "oct":
    train_loader, test_loader, val_loader = get_oct_data_loaders(config.dataset_path, config.batch_size)
print("Dataset:", config.dataset_name)

Dataset: oct


In [8]:
# # freeze all layers but the last fc
# for name, param in model.named_parameters():
#     if name not in ['fc.weight', 'fc.bias']:
#         param.requires_grad = False

# parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
# assert len(parameters) == 2  # fc.weight, fc.bias

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [10]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [11]:
train_loss_arr, train_acc_arr, test_acc_arr, val_acc_arr = [], [], [], []
from tqdm import tqdm
import csv

epochs = 150
for epoch in range(epochs):
    top1_train_accuracy = 0
    train_loss = 0
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0].item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss_arr.append(train_loss / len(train_loader))
    train_acc_arr.append(top1_train_accuracy / len(train_loader))


    top1_accuracy = 0
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits = model(x_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_accuracy += top1[0].item()
    test_acc_arr.append(top1_accuracy / len(test_loader))

    top1_accuracy = 0
    for x_batch, y_batch in val_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits = model(x_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_accuracy += top1[0].item()
    val_acc_arr.append(top1_accuracy / len(val_loader))

    print(f"Epoch: {epoch}, train_loss: {train_loss_arr[-1]}, train_acc: {train_acc_arr[-1]}, test_acc: {test_acc_arr[-1]}, val_acc: {val_acc_arr[-1]}")

root_path = "./csv"
name = config.name
if not os.path.exists(root_path):
    os.makedirs(root_path)
    
with open(f"{root_path}/{name}.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["train_loss", "train_acc", "test_acc", "val_acc"])
    for i in range(len(train_loss_arr)):
        writer.writerow([train_loss_arr[i], train_acc_arr[i], test_acc_arr[i], val_acc_arr[i]])

Epoch: 0, train_loss: 1.2629941701889038, train_acc: 42.0386905670166, test_acc: 48.979591369628906, val_acc: 38.095237731933594
