In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall
import matplotlib.pyplot as plt


In [3]:
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

100%|██████████| 26.4M/26.4M [00:00<00:00, 113MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 3.98MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 59.4MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 10.9MB/s]


In [4]:
len(train_data)


60000

In [5]:
len(test_data)

10000

In [6]:
sample_image, sample_label = train_data[0]
print(f" Sample image shape: {sample_image.shape}")
print(f" Sample label type: {type(sample_label)}")

 Sample image shape: torch.Size([1, 28, 28])
 Sample label type: <class 'int'>


In [7]:
print(f"Train data classes {train_data.classes}")
print(f"Train data length {len(train_data.classes)}")
classes = train_data.classes
num_classes = len(classes)

Train data classes ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
Train data length 10


In [14]:
num_input_channels = 1
num_output_channels = 16
image_size = train_data[0][0].shape[1]
print(image_size)


28


In [17]:
class ClothesClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ClothesClassifier, self).__init__()
        self.conv1 = nn.Conv2d(num_input_channels, num_output_channels, kernel_size = 3, stride = 1, padding =1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(num_output_channels * (image_size // 2) ** 2, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

dataloader_train = DataLoader(
    train_data,
    batch_size = 10,
    shuffle = True,
)

def train_model(optimizer, net, num_epochs):
    num_processed = 0
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        running_loss = 0
        num_processed = 0

        for features, labels in dataloader_train:
            optimizer.zero_grad()
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_processed += len(labels)
        print(f"epoch {epoch}, loss: {running_loss / num_processed}")

        train_loss = running_loss / len(dataloader_train)

In [18]:
net = ClothesClassifier(num_classes)
optimizer = optim.Adam(net.parameters(), lr = 0.001)

train_model(
    optimizer = optimizer,
    net = net,
    num_epochs = 1,
    )

dataloader_test = DataLoader(
    test_data,
    batch_size = 10,
    shuffle = False
    )

accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)

net.eval()
predictions = []
for i, (features, labels) in enumerate(dataloader_test):
    output = net.forward(features.reshape(-1, 1, image_size, image_size))
    cat = torch.argmax(output, dim=-1)
    predictions.extend(cat.tolist())
    accuracy_metric(cat, labels)
    precision_metric(cat, labels)
    recall_metric(cat, labels)

accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()
print('Accuracy:', accuracy)
print('Precision (per class):', precision)
print('Recall (per class):', recall)

epoch 0, loss: 0.039598041016639524
Accuracy: 0.8823999762535095
Precision (per class): [0.8662131428718567, 0.9937888383865356, 0.8171091675758362, 0.820244312286377, 0.8344017267227173, 0.961614191532135, 0.6610487103462219, 0.9317963719367981, 0.983589768409729, 0.9821615815162659]
Recall (per class): [0.7639999985694885, 0.9599999785423279, 0.8309999704360962, 0.9399999976158142, 0.781000018119812, 0.9769999980926514, 0.7059999704360962, 0.9700000286102295, 0.9589999914169312, 0.9359999895095825]
