In [1]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn

### Import Fashion MNIST Dataset

In [2]:
# Download (if not already cached under ./data) and load the Fashion-MNIST
# train and test splits. ToTensor converts each image to a float tensor;
# a single transform instance is shared by both splits.
to_tensor = ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=to_tensor)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=to_tensor)

### Dataloaders

In [3]:
# Mini-batch size shared by both loaders.
batch_size = 64
# Shuffle only the training split so each epoch sees a new ordering;
# the test split keeps a fixed order so evaluation is repeatable.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

### Define base CNN architecture

In [4]:
# Define a simple CNN model
class BaseCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Architecture: conv(5x5, 32) -> ReLU -> max-pool(2) -> conv(5x5, 64) -> ReLU
    -> max-pool(2) -> flatten -> linear(1024 -> 128) -> ReLU -> linear(128 -> classes).

    ``fc1`` expects 64 * 4 * 4 flattened features, which is what a 28x28 input
    (the Fashion-MNIST image size) reduces to:
    28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4.
    Inputs whose spatial size does not reduce to 4x4 will fail in ``fc1``.

    Args:
        in_channels: Channels of the input images (default 1, i.e. grayscale).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Submodule names are kept as-is so state_dict keys stay unchanged.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images to unnormalised class logits.

        Assumes ``x`` is (batch, in_channels, 28, 28); returns (batch, num_classes).
        """
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.maxpool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch dimension; unlike .view this
        # also works for non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)

In [5]:
# CrossEntropyLoss takes raw logits (it applies log-softmax internally).
loss_fn = nn.CrossEntropyLoss()
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BaseCNN().to(device)
# The optimizer is bound to this model's parameters; train() uses it globally.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

### Train Function

In [6]:
def train(train_loader, model, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Uses the notebook-level ``optimizer``, ``loss_fn`` and ``device``; the
    ``optimizer`` must have been built over ``model``'s parameters for the
    updates to affect ``model``. Prints the mean per-batch training loss
    after every epoch.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches.
        model: The ``nn.Module`` to optimise.
        epochs: Number of full passes over ``train_loader``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    for epoch in range(epochs):
        # Re-enter training mode every epoch in case an evaluation switched to eval().
        model.train()
        train_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        # Average over the batches actually drawn from train_loader, not the
        # length of some other (global) loader.
        avg_train_loss = train_loss / num_batches
        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.4f}')

### Test Function

In [14]:
def test(test_loader, model):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss = loss_fn(output, labels)
            test_loss += loss.item()

            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy for the epoch
    avg_test_loss = test_loss / len(test_dataloader)
    accuracy = 100 * (correct / total)

    return avg_test_loss, accuracy

### Commence Training

In [8]:
# Fifty passes over the training split; per-epoch loss is printed below.
# NOTE(review): training loss keeps falling (~0.03 by epoch 31) while test
# accuracy ends near 91%, which may indicate overfitting — consider fewer
# epochs or validation-based early stopping.
train(train_loader=train_dataloader, model=model, epochs=50)

Epoch 1/50, Train Loss: 0.5080
Epoch 2/50, Train Loss: 0.3190
Epoch 3/50, Train Loss: 0.2695
Epoch 4/50, Train Loss: 0.2427
Epoch 5/50, Train Loss: 0.2163
Epoch 6/50, Train Loss: 0.1980
Epoch 7/50, Train Loss: 0.1814
Epoch 8/50, Train Loss: 0.1626
Epoch 9/50, Train Loss: 0.1482
Epoch 10/50, Train Loss: 0.1361
Epoch 11/50, Train Loss: 0.1246
Epoch 12/50, Train Loss: 0.1135
Epoch 13/50, Train Loss: 0.1017
Epoch 14/50, Train Loss: 0.0928
Epoch 15/50, Train Loss: 0.0847
Epoch 16/50, Train Loss: 0.0768
Epoch 17/50, Train Loss: 0.0734
Epoch 18/50, Train Loss: 0.0657
Epoch 19/50, Train Loss: 0.0610
Epoch 20/50, Train Loss: 0.0546
Epoch 21/50, Train Loss: 0.0517
Epoch 22/50, Train Loss: 0.0506
Epoch 23/50, Train Loss: 0.0444
Epoch 24/50, Train Loss: 0.0445
Epoch 25/50, Train Loss: 0.0420
Epoch 26/50, Train Loss: 0.0361
Epoch 27/50, Train Loss: 0.0383
Epoch 28/50, Train Loss: 0.0362
Epoch 29/50, Train Loss: 0.0338
Epoch 30/50, Train Loss: 0.0349
Epoch 31/50, Train Loss: 0.0278
Epoch 32/50, Trai

In [16]:
# Evaluate the trained model on the held-out test split.
avg_test_loss, acc = test(test_loader=test_dataloader, model=model)

In [18]:
# Fixed two-decimal formatting keeps float artefacts (e.g. 100 * 0.29 ==
# 28.999999999999996) out of the reported accuracy.
print(f"Accuracy on base model: {acc:.2f}%")

Accuracy on base model: 90.93%
