In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision
import numpy as np
from datasets import load_dataset
import matplotlib.pyplot as plt



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define hyperparameters
batch_size = 32
epochs = 10
learning_rate = 3e-4
weight_decay = 0.0008

# Seed torch's global RNG so DataLoader shuffling and the freshly initialised
# classifier head are reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:


# Data preprocessing and loading
def prepare_datasets(name, data_dir="/Users/siharini/github/DL-Project/src/data"):
    """Build ``(train_loader, test_loader)`` for one of the supported datasets.

    Args:
        name: ``"tiny-imagenet"`` (loaded from the Hugging Face hub as
            ``"zh-plus/tiny-imagenet"``, splits ``train``/``valid``) or
            ``"stl10"`` (read from ``data_dir``, no download).
        data_dir: Root directory of the STL10 files; ignored for tiny-imagenet.
            The default keeps the original machine-specific path — pass your
            own location instead of editing this function.

    Returns:
        A ``(train_loader, test_loader)`` pair. Batch sizes derive from the
        global ``batch_size``; training batches are shuffled, test batches are not.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == "tiny-imagenet":
        # Train and test use the same preprocessing, so build it once.
        to_tensor = transforms.Compose([transforms.ToTensor()])

        train_dataset = CustomDataset(
            load_dataset("zh-plus/tiny-imagenet", split="train"), transform=to_tensor
        )
        test_dataset = CustomDataset(
            load_dataset("zh-plus/tiny-imagenet", split="valid"), transform=to_tensor
        )

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    elif name == "stl10":
        train_dataset = datasets.STL10(data_dir, split='train', download=False,
                                       transform=transforms.ToTensor())
        train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                  num_workers=0, drop_last=False, shuffle=True)

        test_dataset = datasets.STL10(data_dir, split='test', download=False,
                                      transform=transforms.ToTensor())
        # Evaluation metrics do not depend on sample order, so keep the test
        # set in a fixed order (makes runs comparable and debuggable).
        test_loader = DataLoader(test_dataset, batch_size=2 * batch_size,
                                 num_workers=10, drop_last=False, shuffle=False)

    else:
        # Without this branch an unknown name fell through to `return` and
        # raised an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown dataset {name!r}; expected 'tiny-imagenet' or 'stl10'"
        )

    return train_loader, test_loader
    

class CustomDataset(Dataset):
    """Adapt a dict-style dataset (e.g. a Hugging Face split) to ``(image, label)`` pairs.

    Each item of the wrapped ``dataset`` is expected to be a mapping with
    ``"image"`` and ``"label"`` keys. Items are returned as tuples so they batch
    like torchvision datasets (e.g. STL10): the training/evaluation loops
    unpack each batch as ``(x_batch, y_batch)``. Returning the raw dict made the
    default collate produce a dict, and unpacking that yielded the key strings.
    """

    def __init__(self, dataset, transform=None):
        # dataset: indexable, sized collection of {"image": ..., "label": ...} items.
        self.dataset = dataset
        # transform: optional callable applied to the image only.
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx`` without mutating the source item."""
        sample = self.dataset[idx]
        image = sample["image"]
        if self.transform is not None:
            image = self.transform(image)
        # Single-channel (grayscale) tensors are replicated to 3 channels so all
        # items share the RGB layout. Non-tensors (no transform) are left as-is
        # instead of failing on `.shape`.
        if isinstance(image, torch.Tensor) and image.dim() == 3 and image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        return image, sample["label"]






In [4]:
# Model definition and setup
def initialize_model(num_classes):
    """Build a randomly initialised ResNet-18 with a ``num_classes``-way ``fc`` head on ``device``.

    Uses ``weights=None`` rather than the ``pretrained=False`` flag, which
    torchvision deprecated (0.13+) in favour of ``weights``. Pretrained backbone
    weights are loaded separately by ``load_pretrained_weights``.
    """
    model = torchvision.models.resnet18(weights=None, num_classes=num_classes)
    return model.to(device)

def load_pretrained_weights(model, checkpoint_path, map_location=None):
    """Load backbone weights from a checkpoint into ``model`` and freeze everything but ``fc``.

    Keys of ``checkpoint["state_dict"]`` that start with ``"backbone."`` (but not
    ``"backbone.fc"``) are loaded with that prefix stripped. Every other key,
    including the checkpoint's own ``fc`` head, is discarded, so the model
    keeps its freshly initialised classifier.

    Args:
        model: Module whose non-``fc`` parameter names match the stripped keys.
        checkpoint_path: Path to a ``torch.save``-d dict with a ``"state_dict"`` entry.
        map_location: Passed to ``torch.load``; defaults to the global ``device``.

    Returns:
        The same ``model``, with only ``fc.weight`` and ``fc.bias`` left trainable.

    Raises:
        RuntimeError: if, after loading, the missing keys are anything other than
            exactly ``fc.weight``/``fc.bias`` (i.e. checkpoint and model disagree).
    """
    if map_location is None:
        map_location = device
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    prefix = "backbone."
    # Build a fresh dict rather than renaming keys in place while iterating.
    backbone_state = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix) and not key.startswith(prefix + "fc")
    }

    log = model.load_state_dict(backbone_state, strict=False)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if set(log.missing_keys) != {"fc.weight", "fc.bias"}:
        raise RuntimeError(
            f"Checkpoint does not match the model backbone; missing keys: {log.missing_keys}"
        )

    # Freeze all layers but the last fully connected layer.
    head_params = ("fc.weight", "fc.bias")
    for name, param in model.named_parameters():
        param.requires_grad = name in head_params

    return model

In [5]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import seaborn as sns

def train(model, train_loader, test_loader, criterion, optimizer, epochs):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each.

    Prints one summary line per epoch and, when done, plots the training-loss and
    accuracy histories. Top-1 accuracy is computed inline, so the loop does not
    depend on an externally defined ``accuracy`` helper.

    Args:
        model: Classifier expected to return logits of shape (batch, num_classes).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Passed to ``evaluate`` once per epoch.
        criterion: Loss called as ``criterion(logits, labels)``; the reported epoch
            loss is the mean of its per-batch values.
        optimizer: Optimizer over the parameters to update.
        epochs: Number of passes over ``train_loader``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Send batches to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    train_losses = []
    train_top1_accuracies = []
    test_top1_accuracies = []
    test_top5_accuracies = []

    for epoch in range(epochs):
        # NOTE(review): if the backbone is frozen only via requires_grad, train
        # mode still lets any BatchNorm layers (resnet18 has them) update their
        # running statistics, so the "frozen" features drift. Confirm intended.
        model.train()
        loss_sum = 0.0
        n_batches = 0
        n_correct = 0
        n_seen = 0

        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(model_device)
            y_batch = y_batch.to(model_device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            # Count correct samples (not a per-batch percentage) so the smaller
            # final batch is not over-weighted in the epoch accuracy.
            n_correct += (logits.argmax(dim=1) == y_batch).sum().item()
            n_seen += y_batch.size(0)

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        epoch_train_loss = loss_sum / n_batches
        top1_train_accuracy = 100.0 * n_correct / n_seen
        train_losses.append(epoch_train_loss)
        train_top1_accuracies.append(top1_train_accuracy)

        top1_accuracy, top5_accuracy = evaluate(model, test_loader)
        # float() also accepts 0-d tensors (incl. CUDA), keeping histories plottable.
        top1_accuracy, top5_accuracy = float(top1_accuracy), float(top5_accuracy)
        test_top1_accuracies.append(top1_accuracy)
        test_top5_accuracies.append(top5_accuracy)

        print(f"Epoch {epoch}\tTrain Loss: {epoch_train_loss:.4f}\tTop1 Train Accuracy: {top1_train_accuracy:.2f}%\tTop1 Test Accuracy: {top1_accuracy:.2f}%\tTop5 Test Accuracy: {top5_accuracy:.2f}%")

    plot_loss_curves(train_losses)
    plot_accuracy_curves(train_top1_accuracies, test_top1_accuracies, test_top5_accuracies)

def plot_loss_curves(train_losses):
    """Plot the per-epoch training loss on its own, freshly created figure.

    An explicit figure/axes pair is used so the curve never lands on a figure
    left current by an earlier plot. Values go through ``float`` so 0-d tensors
    plot the same as Python floats.
    """
    _, ax = plt.subplots()
    ax.plot([float(v) for v in train_losses], label='Training Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Curve')
    ax.legend()
    plt.show()

def plot_accuracy_curves(train_top1_accuracies, test_top1_accuracies, test_top5_accuracies):
    """Plot train top-1, test top-1 and test top-5 accuracy (%) per epoch on one fresh figure.

    Each value is passed through ``float`` so histories holding 0-d tensors —
    including CUDA tensors, which matplotlib cannot convert to NumPy — plot the
    same as histories of Python floats.
    """
    _, ax = plt.subplots()
    series = (
        (train_top1_accuracies, 'Training Top-1 Accuracy'),
        (test_top1_accuracies, 'Testing Top-1 Accuracy'),
        (test_top5_accuracies, 'Testing Top-5 Accuracy'),
    )
    for values, label in series:
        ax.plot([float(v) for v in values], label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy Curve')
    ax.legend()
    plt.show()

def evaluate(model, test_loader, show_report=True):
    """Compute top-1 and top-5 accuracy of ``model`` over ``test_loader``.

    Accuracies are percentages over all samples, so each sample counts once and a
    smaller final batch is not over-weighted. Top-k hits are computed inline, so
    no external ``accuracy`` helper is needed. When ``show_report`` is true, also
    draws a confusion-matrix heatmap and prints weighted precision/recall/F1.

    Args:
        model: Classifier expected to return logits of shape (batch, num_classes).
        test_loader: Iterable of ``(inputs, labels)`` batches.
        show_report: Set to False to skip the heatmap and printed metrics
            (e.g. when evaluating every epoch).

    Returns:
        ``(top1_accuracy, top5_accuracy)`` as Python floats in [0, 100].

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    model_device = next(model.parameters()).device

    n_seen = 0
    n_top1 = 0
    n_top5 = 0
    num_classes = None
    predictions = []
    true_labels = []

    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(model_device)
            y_batch = y_batch.to(model_device)

            logits = model(x_batch)
            num_classes = logits.size(1)
            # With fewer than 5 classes, "top-5" means "any class" (always a hit).
            top_idx = logits.topk(min(5, num_classes), dim=1).indices
            hits = top_idx.eq(y_batch.unsqueeze(1))
            n_top1 += hits[:, 0].sum().item()
            n_top5 += hits.any(dim=1).sum().item()
            n_seen += y_batch.size(0)

            if show_report:
                # Use the top-1 index as the prediction so the confusion-matrix
                # diagonal agrees with the reported top-1 accuracy.
                predictions.extend(top_idx[:, 0].cpu().numpy())
                true_labels.extend(y_batch.cpu().numpy())

    if n_seen == 0:
        raise ValueError("test_loader yielded no batches")

    top1_accuracy = 100.0 * n_top1 / n_seen
    top5_accuracy = 100.0 * n_top5 / n_seen

    if show_report:
        # Pass every class id so never-seen classes still get a row/column.
        cm = confusion_matrix(true_labels, predictions, labels=list(range(num_classes)))
        _, ax = plt.subplots(figsize=(10, 8))
        # Per-cell counts are unreadable for large label sets (e.g. 200 classes).
        sns.heatmap(cm, annot=num_classes <= 20, fmt="d", ax=ax)
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')
        ax.set_title('Confusion Matrix')
        plt.show()

        # zero_division=0 matches the default numeric result without the warning.
        precision, recall, f1_score, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1_score:.2f}")

    return top1_accuracy, top5_accuracy

In [6]:

# Prepare datasets
dataset_name = "stl10"
# Unknown names fail here with a KeyError instead of silently defaulting to 10 classes.
num_classes = {"tiny-imagenet": 200, "stl10": 10}[dataset_name]
train_loader, test_loader = prepare_datasets(dataset_name)

# Initialize model
model = initialize_model(num_classes)

# Load pretrained weights (this also freezes everything except the fc head)
checkpoint_path = "checkpoint_0100.pth.tar"
model = load_pretrained_weights(model, checkpoint_path)

# Set up criterion and optimizer. Only the linear head is trained, so give the
# optimizer just the trainable parameters instead of the whole frozen backbone.
criterion = torch.nn.CrossEntropyLoss().to(device)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)

# Train the model
train(model, train_loader, test_loader, criterion, optimizer, epochs)



Epoch 0	Top1 Train accuracy 41.30175018310547	Top1 Test accuracy: 55.037498474121094	Top5 test acc: 96.38749694824219
Epoch 1	Top1 Train accuracy 58.47929763793945	Top1 Test accuracy: 61.087501525878906	Top5 test acc: 97.38749694824219
Epoch 2	Top1 Train accuracy 61.92277145385742	Top1 Test accuracy: 62.712501525878906	Top5 test acc: 97.38749694824219
Epoch 3	Top1 Train accuracy 62.57961654663086	Top1 Test accuracy: 63.962501525878906	Top5 test acc: 97.73750305175781
Epoch 4	Top1 Train accuracy 64.13216400146484	Top1 Test accuracy: 64.4749984741211	Top5 test acc: 97.7750015258789
Epoch 5	Top1 Train accuracy 64.9283447265625	Top1 Test accuracy: 64.76249694824219	Top5 test acc: 97.80000305175781
Epoch 6	Top1 Train accuracy 65.30652618408203	Top1 Test accuracy: 65.19999694824219	Top5 test acc: 97.9000015258789
Epoch 7	Top1 Train accuracy 65.84394836425781	Top1 Test accuracy: 65.61250305175781	Top5 test acc: 97.9000015258789
Epoch 8	Top1 Train accuracy 66.12261199951172	Top1 Test accuracy:

KeyboardInterrupt: 

In [None]:
# Final held-out evaluation of the trained classifier; the (top-1, top-5)
# accuracy pair is shown as the cell's output.
final_top1, final_top5 = evaluate(model, test_loader)
final_top1, final_top5