In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from PIL import Image

import pandas as pd
import numpy as np

from metrics import Metrics

import matplotlib.pyplot as plt

In [None]:
class MNISTVanilla(Dataset):
    """Map-style dataset that serves images listed in a DataFrame.

    Every row of ``df`` must provide an ``"Image"`` entry (a path that
    ``PIL.Image.open`` can read) and a ``"Label"`` entry (the class id).
    """

    def __init__(self, df):
        """Keep a reference to ``df`` and build the preprocessing pipeline.

        Args:
            df: DataFrame with ``"Image"`` and ``"Label"`` columns.
        """
        self.df = df

        # ToTensor rescales pixels to [0, 1], so the statistics are divided by 255
        # as well. The same value is used for each of the three RGB channels.
        # NOTE(review): 33.79 / 79.17 are presumably the dataset's pixel mean / std
        # on the 0-255 scale - confirm against the data.
        channel_mean = torch.tensor([33.79 / 255.0] * 3)
        channel_std = torch.tensor([79.17 / 255.0] * 3)

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        )

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, convert and normalize the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` pair: the transformed 3-channel image tensor and
            a 1-element ``torch.long`` tensor holding the class id.
        """
        row = self.df.iloc[idx]

        # The context manager closes the underlying file deterministically;
        # convert() returns a fully loaded copy that outlives the `with` block.
        with Image.open(row["Image"]) as img:
            rgb = img.convert("RGB")

        # CrossEntropyLoss expects integer (long) class indices as targets.
        label = torch.tensor([row["Label"]], dtype=torch.long)
        return self.transform(rgb), label

In [None]:
# Read the image/label index and prefix every image path with the data directory,
# so the paths resolve relative to this notebook.
df = pd.read_csv("../data/mnist.csv")
df["Image"] = df["Image"].map("../data/{}".format)
df.head()

In [None]:
# Wrap the index DataFrame in the Dataset so samples are loaded lazily per item.
mnist_vanilla = MNISTVanilla(df=df)

In [None]:
# Number of samples per mini-batch handed to the DataLoader.
batch_size: int = 32

In [None]:
# Shuffled mini-batch iterator over the dataset; samples are loaded by 8 worker processes.
loader = DataLoader(
    dataset=mnist_vanilla,
    shuffle=True,
    batch_size=batch_size,
    num_workers=8,
)

In [None]:
# Flattened length of one transformed image tensor (taken from the first sample).
input_size = mnist_vanilla[0][0].numel()  # 3 * 28 * 28
# Widths of the three hidden layers, in order.
hidden_sizes = [128, 64, 32]
# Distinct label values in order of first appearance; one output unit per value.
classes = df["Label"].unique()
num_classes = len(classes)

In [None]:
# Fully-connected classifier: flatten the image, pass it through three ReLU hidden
# layers, and emit one raw score (logit) per class.
#
# There is deliberately NO final nn.Softmax: nn.CrossEntropyLoss expects
# unnormalized logits and applies log-softmax itself, so an explicit Softmax here
# would normalize twice and flatten the gradients. argmax over the logits gives the
# same predicted class; apply torch.softmax(outputs, dim=1) only where actual
# probabilities are needed.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(input_size, hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[1], hidden_sizes[2]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[2], num_classes),
)

In [None]:
# Show the module's repr to inspect the layer stack.
model

In [None]:
# Multi-class classification loss, plus Adam over all model parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Number of full passes over the training data.
num_epochs: int = 15

In [None]:
# Train for `num_epochs` passes over `loader`, recording the per-epoch loss reported
# by the Metrics helper. A fresh Metrics accumulator is created for every epoch.
losses = [None] * num_epochs
for epoch in range(num_epochs):
    metrics = Metrics(classes, len(mnist_vanilla))
    model.train()
    for X, y in loader:
        # (batch, 1) -> (batch,): CrossEntropyLoss takes class indices as a 1-D tensor.
        # - See https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
        targets = y.squeeze(1)

        outputs = model(X)
        loss = loss_fn(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # detach() is enough to leave the autograd graph (no copy needed), and
        # item() already returns a plain Python float.
        predictions = outputs.detach().argmax(dim=1)
        metrics.accumulate_batch_metrics(loss.item(), targets, predictions)

    train_metrics_dict = metrics.compute_epoch_metrics()
    print(f"Epoch: {epoch}")
    print(f"Loss: {train_metrics_dict['loss']}")
    losses[epoch] = train_metrics_dict["loss"]

In [None]:
# Training-loss curve: one point per epoch.
fig, ax = plt.subplots()
ax.plot(range(num_epochs), losses)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

In [None]:
# Checkpoint the trained weights, the optimizer state and the metrics of the last
# training epoch. The file name carries the zero-based index of that last epoch,
# derived from `num_epochs` so it cannot drift when the epoch count is changed.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # Copy the last-epoch metrics under a "train_" prefix.
        **{
            f"train_{name}": train_metrics_dict[name]
            for name in (
                "loss",
                "accuracy",
                "precision",
                "recall",
                "f1",
                "confusion_matrix",
            )
        },
    },
    f"basic-linear-epoch-{num_epochs - 1}.tar",
)

In [None]:
# Accuracy reported for the last training epoch. It is accumulated over the same
# batches the model was fit on, so it does not measure generalization.
train_metrics_dict["accuracy"]  # Not a meaningful score: the model is likely heavily overfit