In [2]:
import sys
import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

# Fashion-MNIST train and test splits, stored under ./data (fetched if absent).
# ToTensor() is applied to every image so the loaders yield tensors.
mnist_train = datasets.FashionMNIST(root='./data', download=True, train=True, transform=ToTensor())
mnist_test = datasets.FashionMNIST(root='./data', download=True, train=False, transform=ToTensor())

# One mini-batch size shared by both loaders.
BATCH_SIZE = 32

# Only the training data is shuffled: a fresh order each epoch helps SGD-style
# optimisation. Evaluation metrics do not depend on sample order, so the test
# loader keeps a fixed order, which avoids needless work and keeps per-batch
# results reproducible.
train_dataloader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Network: one 3x3 convolution followed by a two-layer fully connected head.
# With kernel size 3, stride 1 and reflect padding of 1 the convolution keeps
# the spatial size of its input, so the flattened feature count is
# channels * height * width (assumes 28x28 single-channel input images).
conv_channels = 3
flat_features = conv_channels * 28 * 28  # = 2352 inputs to the first Linear

layers = [
    nn.Conv2d(
        in_channels=1,
        out_channels=conv_channels,
        kernel_size=3,
        padding=1,
        padding_mode="reflect",
    ),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(in_features=flat_features, out_features=100),
    nn.ReLU(),
    nn.Linear(in_features=100, out_features=10),  # one output per class
]
model = nn.Sequential(*layers)

# Objective for multi-class classification; the model outputs raw logits.
loss_fn = nn.CrossEntropyLoss()

# Adam over all parameters of the model with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop.
#   - each epoch puts the model in train() mode and makes one full pass over
#     train_dataloader
#   - per batch: zero gradients, forward pass, loss, backward(), optimizer.step()
#   - the integer labels are passed to the loss as-is: CrossEntropyLoss accepts
#     class indices directly, which gives the same loss as a float one-hot
#     target for hard labels without building a dense (batch, 10) tensor each
#     step
#   - the value printed per epoch is the sum of the per-batch losses
#     (each batch loss is already averaged over the batch by the default
#     'mean' reduction), not a per-sample average
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    loss_sum = 0.0
    for X, y in train_dataloader:
        # Gradients accumulate across backward() calls unless cleared.
        optimizer.zero_grad()
        outputs = model(X)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float so no graph is retained.
        loss_sum += loss.item()
    print(loss_sum)

# Evaluation on the test split.
#   - the model is switched to eval() mode and gradient tracking is disabled,
#     since no backward pass happens here
#   - per batch: logits -> softmax over the class dimension -> predicted class
#     (index of the largest probability) -> compare with the labels
#   - the final figure is the fraction of correct predictions, in [0, 1]
accurate = 0
total = 0
model.eval()
with torch.no_grad():
    for X, y in test_dataloader:
        probs = F.softmax(model(X), dim=1)
        predicted = probs.argmax(dim=1)
        total += y.size(0)
        # Summing the boolean match mask counts the correct predictions.
        accurate += (predicted == y).sum().item()
print("Accuracy on validation data:", accurate / total)

923.8479181304574
617.6442737728357
535.0794975254685
486.1533539183438
441.2763391789049
402.9044731967151
372.1597451120615
341.3199079912156
316.3318266160786
287.95779645536095
Accuracy on validation data: 0.8967
