In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from PIL import Image

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

In [None]:
class Metrics:
    """Accumulate loss and classification counts over one epoch.

    Classification metrics are derived from a confusion matrix accumulated
    across *all* batches, so they are exact epoch-level values. Averaging
    per-batch scores instead is biased: a class absent from a batch scores 0
    for that batch, and a short final batch is weighted like a full one.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Running sum of per-batch losses; divided by the batch count at the end.
        self.loss = 0.0
        # confusion[t, p] = number of samples with true label t predicted as p.
        self.confusion = np.zeros((num_classes, num_classes), dtype=np.int64)

    def update_loss(self, loss):
        """Add one batch's (already reduced) scalar loss to the running total."""
        self.loss += loss

    def update_metrics(self, targets, predictions):
        """Add one batch of integer class targets/predictions to the confusion matrix.

        Accepts anything ``np.asarray`` can convert (lists, numpy arrays, CPU
        tensors); inputs are flattened.

        Raises:
            ValueError: if the two inputs differ in length or contain a label
                outside ``[0, num_classes)``.
        """
        true = np.asarray(targets).reshape(-1).astype(np.int64)
        pred = np.asarray(predictions).reshape(-1).astype(np.int64)
        if true.shape != pred.shape:
            raise ValueError(
                f"targets and predictions differ in length: {true.size} vs {pred.size}"
            )
        if true.size > 0:
            low = min(true.min(), pred.min())
            high = max(true.max(), pred.max())
            if low < 0 or high >= self.num_classes:
                raise ValueError(
                    f"labels must lie in [0, {self.num_classes}); got range [{low}, {high}]"
                )

        # Encode each (true, pred) pair as one index and count all pairs at once.
        pair_index = true * self.num_classes + pred
        counts = np.bincount(pair_index, minlength=self.num_classes ** 2)
        self.confusion += counts.reshape(self.num_classes, self.num_classes)

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Element-wise numerator / denominator, yielding 0.0 where denominator is 0."""
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        return np.divide(
            numerator, denominator, out=np.zeros_like(numerator), where=denominator > 0
        )

    def compute_metrics(self, loader_size):
        """Return ``(mean_loss, accuracy, precision, recall, f1)`` for the epoch.

        Args:
            loader_size: number of batches; the summed loss is divided by it.

        Returns:
            mean_loss and accuracy are floats. precision, recall and f1 are
            per-class arrays of length ``num_classes``. A class whose score has
            a zero denominator gets 0.0, as does accuracy when nothing was
            recorded.
        """
        tp = np.diag(self.confusion).astype(float)
        predicted_per_class = self.confusion.sum(axis=0)  # column sums: tp + fp
        actual_per_class = self.confusion.sum(axis=1)     # row sums: tp + fn
        total = self.confusion.sum()

        accuracy = float(tp.sum() / total) if total > 0 else 0.0
        precision = self._safe_divide(tp, predicted_per_class)
        recall = self._safe_divide(tp, actual_per_class)
        # F1 = 2tp / (2tp + fp + fn), i.e. the harmonic mean of precision and recall.
        f1 = self._safe_divide(2 * tp, predicted_per_class + actual_per_class)

        return (
            self.loss / loader_size,
            accuracy,
            precision,
            recall,
            f1,
        )

In [None]:
class MNISTVanilla(Dataset):
    """Image-classification dataset backed by a DataFrame of image paths and labels.

    Each row must provide an ``"Image"`` file path and an integer ``"Label"``.
    Images are opened with PIL, converted to 3-channel RGB and passed through
    ``self.transform``.
    """

    def __init__(self, df, transform=None):
        """
        Args:
            df: DataFrame with "Image" (path) and "Label" columns.
            transform: optional callable applied to each RGB PIL image. When
                omitted, the image is resized to 224x224, converted to a tensor
                and normalized with the same mean/std on every channel.
        """
        self.df = df

        if transform is None:
            # NOTE(review): 33.79 / 79.17 look like the dataset's grayscale pixel
            # mean / std on the 0-255 scale -- confirm provenance. They are
            # rescaled to [0, 1] to match ToTensor's output and repeated for the
            # 3 RGB channels produced in __getitem__.
            transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        torch.tensor([33.79 / 255.0 for _ in range(3)]),
                        torch.tensor([79.17 / 255.0 for _ in range(3)]),
                    ),
                ]
            )
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(transformed image, label tensor of shape [1])`` for row ``idx``."""
        row = self.df.iloc[idx]

        # The context manager closes the source file handle even if decoding
        # fails; convert() returns a separate RGB copy that outlives the block.
        with Image.open(row["Image"]) as raw:
            img = raw.convert("RGB")

        return self.transform(img), torch.tensor([row["Label"]])

In [None]:
# Load the image index, prefix every relative image path with the data
# directory, and keep only the first 100 rows for a quick experiment.
df = (
    pd.read_csv("../data/mnist.csv")
    .assign(Image=lambda frame: frame["Image"].map("../data/{}".format))
    .head(100)
)
df.head()

In [None]:
# Wrap the (subsetted) index frame in the Dataset defined above.
mnist_vanilla = MNISTVanilla(df=df)

In [None]:
batch_size: int = 32  # samples per optimizer step

In [None]:
# Reshuffled mini-batches each epoch, loaded by 8 worker processes.
loader = DataLoader(
    dataset=mnist_vanilla,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

In [None]:
# CrossEntropyLoss requires every target to lie in [0, num_classes). Counting
# distinct labels undercounts when the 100-row subset misses a class, leaving
# some label >= num_classes. Sizing the head by the largest label avoids that.
assert df["Label"].min() >= 0, "labels must be non-negative class indices"
num_classes = int(df["Label"].max()) + 1

In [None]:
# ImageNet-pretrained ResNet-18 takes 3-channel images; MNISTVanilla converts
# each image to RGB and resizes it to 224x224 before it reaches the model.
if hasattr(models, "ResNet18_Weights"):
    # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads.
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)
# Replace the ImageNet classification head with one sized for this dataset.
model.fc = nn.Linear(model.fc.in_features, num_classes)

In [None]:
# Adam over every model parameter (nothing is frozen, so the whole network is
# fine-tuned); CrossEntropyLoss scores raw logits against class-index targets.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
num_epochs: int = 5  # full passes over the training loader

In [None]:
# Train for num_epochs epochs and record each epoch's mean per-batch loss.
losses = []
for epoch in range(num_epochs):
    metrics = Metrics(num_classes)
    model.train()
    for X, y in loader:
        # Size([B, 1]) -> Size([B]): CrossEntropyLoss expects class-index
        # targets of shape (N,). See
        # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        targets = y.squeeze(1)

        optimizer.zero_grad()
        outputs = model(X)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() already yields a detached Python float; no clone needed.
        metrics.update_loss(loss.item())
        # Predicted class = argmax over the class dimension; detach from autograd.
        predictions = outputs.detach().argmax(dim=1)
        metrics.update_metrics(targets, predictions)

    t_loss, t_acc, t_prec, t_rec, t_f1 = metrics.compute_metrics(len(loader))
    print(f"Epoch: {epoch}")
    print(f"Loss: {t_loss}")
    print(f"Accuracy: {t_acc}")
    losses.append(t_loss)

In [None]:
# Mean training loss per epoch; epochs are 0-indexed to match the printed log.
fig, ax = plt.subplots()
epochs_run = range(len(losses))
ax.plot(epochs_run, losses, marker="o")
ax.set_xticks(epochs_run)
ax.set(title="Training loss per epoch", xlabel="Epochs", ylabel="Loss");

In [None]:
# Checkpoint the model, optimizer and final-epoch training metrics. The file
# name and stored epoch derive from num_epochs (0-indexed, as in the training
# log), so they stay correct when num_epochs changes.
last_epoch = num_epochs - 1
torch.save(
    {
        "epoch": last_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": t_loss,
        "train_accuracy": t_acc,
        "train_precision": t_prec,
        "train_recall": t_rec,
        "train_f1": t_f1,
    },
    f"resnet18-transfer-epoch-{last_epoch}.tar",
)

In [None]:
# Training accuracy from the last epoch of the training loop; kept as the cell's
# final bare expression so Jupyter displays it.
t_acc