In [18]:
from datasets import load_dataset
from torchvision import transforms
import torch
import sys
import numpy as np
import os
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader

In [24]:
# Mini-batch size used by the training loaders.
batch_size = 32
# Seed torch's global RNG so DataLoader shuffling and the fc-head initialisation
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Tiny-ImageNet from the Hugging Face Hub: the "train" split is used for fitting,
# the "valid" split is held out for evaluation.
train_dataset, test_dataset = (
    load_dataset("zh-plus/tiny-imagenet", split=split_name)
    for split_name in ("train", "valid")
)

In [4]:
class CustomDataset(Dataset):
    """Adapt an indexable dataset of ``{"image": ..., ...}`` dicts for a DataLoader.

    Each item is returned as a (shallow-copied) dict, so the default collate
    function produces dict batches. PIL images are converted to RGB before
    ``transform`` runs, so grayscale/RGBA inputs come out with 3 channels; a
    3-D tensor with a single channel is additionally repeated to 3 channels.

    Args:
        dataset: Sized, indexable collection whose items are dicts containing
            an ``"image"`` entry.
        transform: Optional callable applied to the image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Copy so we never mutate the item owned by the wrapped dataset.
        sample = dict(self.dataset[idx])
        image = sample["image"]
        # PIL images expose .convert(); forcing RGB normalises "L"/"RGBA" modes.
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Fallback for single-channel (C, H, W) tensors; objects without .ndim
        # (e.g. an untransformed PIL image) are left alone.
        if getattr(image, "ndim", 0) == 3 and image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        sample["image"] = image
        return sample


# Wrap the raw Hugging Face splits under new names instead of rebinding
# train_dataset/test_dataset, so re-running this cell never double-wraps
# (a double wrap would apply ToTensor to an already-converted tensor).
to_tensor = transforms.ToTensor()
tiny_train_dataset = CustomDataset(train_dataset, transform=to_tensor)
tiny_test_dataset = CustomDataset(test_dataset, transform=to_tensor)
train_loader = DataLoader(tiny_train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(tiny_test_dataset, batch_size=batch_size, shuffle=False)

In [21]:
from torchvision import datasets

# Directory holding the already-downloaded STL10 files (download=False below).
# Set the STL10_ROOT environment variable to point elsewhere instead of editing
# this absolute path.
STL10_ROOT = os.environ.get("STL10_ROOT", "/Users/siharini/github/DL-Project/src/data")

train_dataset = datasets.STL10(
    STL10_ROOT,
    split="train",
    download=False,
    transform=transforms.ToTensor(),
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=0, drop_last=False, shuffle=True
)

test_dataset = datasets.STL10(
    STL10_ROOT,
    split="test",
    download=False,
    transform=transforms.ToTensor(),
)

# Evaluation order does not affect accuracy, so the test loader is not shuffled.
test_loader = DataLoader(
    test_dataset,
    batch_size=2 * batch_size,
    num_workers=10,
    drop_last=False,
    shuffle=False,
)



In [None]:
# Inspect the first training sample; for the STL10 dataset it is an
# (image, label) pair, so element 0 is the image tensor.
first_sample = train_dataset[0]
first_sample[0]

In [17]:
# Peek at one training batch to check its structure: loaders built on
# CustomDataset yield dict batches, while the STL10 loaders yield
# (images, labels) pairs (which have no .keys()).
first_batch = next(iter(train_loader))
first_batch.keys() if isinstance(first_batch, dict) else [t.shape for t in first_batch]

dict_keys(['image', 'label'])

In [25]:
device = "cpu"
# NOTE(review): num_classes=200 matches Tiny-ImageNet, but the most recently
# defined loaders are STL10 (10 classes) — confirm the intended target dataset.
model = torchvision.models.resnet18(pretrained=False, num_classes=200).to(device)
# torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from a trusted source.
checkpoint = torch.load("checkpoint_0100.pth.tar", map_location=device)

# Keep the pre-trained backbone weights except its own fc head, and strip the
# "backbone." prefix so keys line up with torchvision's resnet18 names.
# Building a fresh dict (rather than renaming keys in place while iterating)
# avoids order-dependent deletes when a stripped name collides with an
# existing key.
prefix = "backbone."
state_dict = {
    key[len(prefix):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(prefix) and not key.startswith(prefix + "fc")
}

log = model.load_state_dict(state_dict, strict=False)
# Only the freshly initialised fc head should be missing from the checkpoint.
assert log.missing_keys == ["fc.weight", "fc.bias"]

# Freeze all layers but the last fc.
head_param_names = {"fc.weight", "fc.bias"}
for name, param in model.named_parameters():
    if name not in head_param_names:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias



In [26]:
# Only the fc head is trainable (everything else was frozen above), so hand Adam
# just those parameters rather than relying on it skipping params without grads.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=0.0008
)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [27]:
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: Scores indexed as ``output[sample, class]`` (top-k is taken over dim 1).
        target: Ground-truth class index for each sample.
        topk: The k values to evaluate.

    Returns:
        A list parallel to ``topk``; each entry is a one-element tensor holding
        the percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
        # Row r of ranked_t holds every sample's (r+1)-th best guess.
        ranked_t = ranked.t()
        hits = ranked_t.eq(target.view(1, -1).expand_as(ranked_t))
        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

In [None]:
def _as_xy(batch):
    """Split a DataLoader batch into (inputs, targets).

    Accepts dict batches from CustomDataset ("image"/"label" keys) as well as
    (images, labels) pairs such as those produced by the STL10 loaders.
    """
    if isinstance(batch, dict):
        return batch["image"], batch["label"]
    x, y = batch
    return x, y


epochs = 10
for epoch in range(epochs):
    # ---- train the fc head ----
    # NOTE(review): in train mode the frozen backbone's BatchNorm running
    # statistics still update; for a strict linear probe consider keeping the
    # backbone in eval mode here.
    model.train()
    train_top1_sum, train_seen = 0.0, 0
    for batch in train_loader:
        x_batch, y_batch = _as_xy(batch)
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        # accuracy() returns a per-batch percentage; weight it by batch size so a
        # short final batch (drop_last=False) doesn't skew the epoch average.
        n = y_batch.size(0)
        train_top1_sum += top1[0].item() * n
        train_seen += n

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy = train_top1_sum / max(train_seen, 1)

    # ---- evaluate: eval-mode BatchNorm, no autograd graph ----
    model.eval()
    top1_sum = top5_sum = 0.0
    test_seen = 0
    with torch.no_grad():
        for batch in test_loader:
            x_batch, y_batch = _as_xy(batch)
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)

            top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
            n = y_batch.size(0)
            top1_sum += top1[0].item() * n
            top5_sum += top5[0].item() * n
            test_seen += n

    top1_accuracy = top1_sum / max(test_seen, 1)
    top5_accuracy = top5_sum / max(test_seen, 1)
    print(
        f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy}\tTop1 Test accuracy: {top1_accuracy}\tTop5 test acc: {top5_accuracy}"
    )