In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

from torchvision import datasets
import torchvision.transforms as transforms

from PIL import Image
import matplotlib.pyplot as plt

# Load the Fashion-MNIST datasets

In [None]:
# Fetch (downloading into ./data on first run) the train and test splits of
# FashionMNIST; ToTensor turns each PIL image into a float tensor.
train_data, test_data = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

## Define the training set DataLoader

In [None]:
# Training loader: reshuffled on every pass, one sample per batch.
# NOTE(review): batch_size=1 means one optimizer step per training image, which
# makes each epoch slow; the single-image visualisation cells rely on it today.
dataloader_train = DataLoader(train_data, shuffle=True, batch_size=1)

In [None]:
# Pull a single batch from the shuffled training loader to inspect its tensor layout.
(image, label) = next(iter(dataloader_train))
print(image.size())


In [None]:
# Bring the spatial axes to the front: (B, C, H, W) -> (H, W, C, B), assuming the
# loader yields NCHW batches. NOTE: this reassigns `image`, so running the cell a
# second time permutes the already-permuted tensor.
image = torch.permute(image, (2, 3, 1, 0))
print(image.size())

In [None]:
# After the permute above the layout is (H, W, C, B). Take channel 0 of the first
# sample so imshow receives a plain 2-D (H, W) array; slicing only the channel axis
# would leave a trailing batch axis that breaks imshow whenever batch_size != 1.
image = image[:, :, 0, 0]
print(image.shape)

In [None]:
# Render the sample in grayscale (the data has a single channel) and label the
# figure with the class name of the first sample in the batch.
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
ax.set_title(train_data.classes[label[0].item()])
ax.axis('off')
plt.show()

## Define the test set DataLoader

In [None]:
# Evaluation loader: fixed order (no shuffling), so predictions collected while
# iterating follow test_data's index order.
dataloader_test = DataLoader(test_data, shuffle=False, batch_size=10)

# Create the Convolutional Neural Network

## Get the number of classes

In [None]:
# Class names come straight from the dataset; their count sizes the classifier's output layer.
classes = train_data.classes
num_classes = len(classes)

print(classes, num_classes, sep='\n')

## Define some relevant variables

In [None]:
# Hyper-parameters read by MultiClassImageClassifier when it is constructed.
num_input_channels = 1    # single-channel input images
num_output_channels = 16  # number of filters in the convolution layer
# Side length of a sample, read from dimension 1 (height) of the first training
# example's (C, H, W) tensor; the model assumes square images.
image_size = train_data[0][0].size(1)

print(image_size)

## Define CNN

In [None]:
class MultiClassImageClassifier(nn.Module):
    """Small CNN: one 3x3 conv layer, ReLU, 2x2 max-pool, then a linear classifier.

    The convolution keeps the spatial size (stride 1, padding 1) and the pool
    halves it (rounding down), so the linear layer receives
    ``out_channels * (input_size // 2) ** 2`` features.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images. ``None`` (default) falls back
            to the notebook-level ``num_input_channels``.
        out_channels: Number of convolution filters. ``None`` (default) falls back
            to the notebook-level ``num_output_channels``.
        input_size: Side length of the square input images. ``None`` (default)
            falls back to the notebook-level ``image_size``.
    """

    def __init__(self, num_classes, in_channels=None, out_channels=None, input_size=None):
        super().__init__()
        # Resolve defaults lazily: the class no longer *requires* the notebook
        # globals, yet MultiClassImageClassifier(num_classes) behaves exactly as before.
        if in_channels is None:
            in_channels = num_input_channels
        if out_channels is None:
            out_channels = num_output_channels
        if input_size is None:
            input_size = image_size

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(out_channels * (input_size // 2) ** 2, num_classes)

    def forward(self, x):
        """Map a batch expected as (B, in_channels, input_size, input_size) to (B, num_classes) logits."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Training the model

## Define training function

In [None]:
def train_model(optimizer, net, num_epochs, dataloader=None):
    """Train ``net`` with cross-entropy loss, printing the mean loss of each epoch.

    Args:
        optimizer: Optimizer already bound to ``net``'s parameters.
        net: Module mapping a batch of features to class logits.
        num_epochs: Number of full passes over ``dataloader``.
        dataloader: Iterable of ``(features, labels)`` batches. ``None`` (default)
            falls back to the notebook-level ``dataloader_train``.

    Returns:
        Mean per-sample training loss of the last epoch, or ``nan`` when no epoch
        ran or the loader yielded no samples.
    """
    if dataloader is None:
        dataloader = dataloader_train
    criterion = nn.CrossEntropyLoss()
    # Ensure training mode even if an earlier evaluation cell called net.eval().
    net.train()

    epoch_loss = float('nan')
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_processed = 0
        for features, labels in dataloader:
            optimizer.zero_grad()
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the batch; re-weight by batch size so
            # the epoch value is a true per-sample mean even with a short last batch.
            batch_size = len(labels)
            running_loss += loss.item() * batch_size
            num_processed += batch_size
        epoch_loss = running_loss / num_processed if num_processed else float('nan')
        print(f'epoch {epoch}, loss: {epoch_loss}')

    return epoch_loss

## Train for 1 epoch

In [65]:
# Seed torch's global RNG so the weight initialisation (and the shuffle order of
# dataloader_train, which draws from that RNG when no generator is given) is
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

net = MultiClassImageClassifier(num_classes)
optimizer = optim.Adam(net.parameters(), lr=0.001)

train_model(
    optimizer=optimizer,
    net=net,
    num_epochs=1,
)

# Testing the model

## Define the metrics

In [None]:
# Accuracy is a single overall score; precision and recall keep one value per
# class (average=None) so weak classes stay visible.
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric, recall_metric = (
    metric_cls(task='multiclass', num_classes=num_classes, average=None)
    for metric_cls in (Precision, Recall)
)

## Run model on test set

In [None]:
# Run the trained network over the test split and feed every batch to the metrics.
test_metrics = (accuracy_metric, precision_metric, recall_metric)
# Clear any state from a previous run of this cell so results are not double-counted.
for metric in test_metrics:
    metric.reset()

net.eval()
predicted = []
with torch.no_grad():  # inference only: no autograd graph needed
    for features, labels in dataloader_test:
        # Call the module (not .forward) so any registered hooks still run.
        output = net(features.reshape(-1, num_input_channels, image_size, image_size))
        cat = torch.argmax(output, dim=-1)
        predicted.extend(cat.tolist())
        # update() only accumulates state; calling the metric would also compute a per-batch value.
        for metric in test_metrics:
            metric.update(cat, labels)

## Compute the metrics

In [None]:
# Collapse each metric's accumulated state into plain Python values and report them.
accuracy = accuracy_metric.compute().item()
precision, recall = (m.compute().tolist() for m in (precision_metric, recall_metric))

print('Accuracy:', accuracy)
print('Precision (per class):', precision)
print('Recall (per class):', recall)