# Solution: Build a Simple NN on Fashion-MNIST

This notebook implements the required Multilayer Perceptron (MLP) to classify Fashion-MNIST images. It follows the assignment instructions: two hidden layers (256, 128), ReLU activations, CrossEntropyLoss, Adam optimizer (lr=0.001), batch size 64, and 5-10 epochs.

**How to use:** Run each cell in order. The notebook will download the Fashion-MNIST dataset (via torchvision).

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tqdm.notebook import tqdm
# Use the CUDA device when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

In [None]:
# Data loading and preprocessing
batch_size = 64
transform = transforms.Compose([transforms.ToTensor()])  # scales to [0,1]

# Build the train split first, then the test split; both share root, download flag and transform.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle, num_workers=2)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

classes = train_dataset.classes
print('Classes:', classes)
print('Train size:', len(train_dataset), 'Test size:', len(test_dataset))

In [None]:
# Model definition: simple MLP as required
class SimpleMLP(nn.Module):
    """Fully connected classifier: Flatten, then Linear+ReLU per hidden layer, then Linear.

    Called with no arguments it builds the 784 -> 256 -> 128 -> 10 network,
    with the same layer order and parameter names (``net.1``, ``net.3``,
    ``net.5``) as a hand-written ``nn.Sequential``.

    Args:
        input_dim: Number of features per sample once the input is flattened.
        hidden_dims: Width of each hidden layer, in order. An empty sequence
            yields a single linear layer from ``input_dim`` to ``num_classes``.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=784, hidden_dims=(256, 128), num_classes=10):
        super().__init__()
        layers = [nn.Flatten()]  # keeps the batch dimension, flattens the rest
        in_features = input_dim
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # The last layer returns raw logits; no softmax is applied here.
        layers.append(nn.Linear(in_features, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits produced by the layer stack for batch ``x``."""
        return self.net(x)

# Instantiate the network, move it to the selected device, and print its layer structure.
model = SimpleMLP().to(device)
print(model)

In [None]:
# Training and evaluation helpers
# Cross-entropy loss on the model outputs, and Adam (learning rate 0.001) over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return average metrics.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over every sample seen, where
        ``accuracy`` is the fraction of samples whose argmax matches the label.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages would
            otherwise fail with an unexplained division by zero.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the
        # average (assumes criterion returns a per-batch mean, as
        # CrossEntropyLoss does by default).
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError('train_one_epoch received an empty loader; no samples to train on')
    return running_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy, predictions, labels)``. The first two
        are averaged over every sample seen; the last two are NumPy arrays
        holding the argmax prediction and the label of each sample, in loader
        order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            # Scale the batch loss by batch size so the final average is per sample.
            loss_sum += criterion(logits, targets).item() * inputs.size(0)
            guesses = logits.argmax(dim=1)
            n_correct += (guesses == targets).sum().item()
            n_seen += targets.size(0)
            pred_chunks.append(guesses.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())
    predictions = np.concatenate(pred_chunks)
    truths = np.concatenate(label_chunks)
    return loss_sum / n_seen, n_correct / n_seen, predictions, truths

In [None]:
# Training loop (set epochs between 5 and 10 as required)
epochs = 8  # any value from 5 to 10 may be used here
# Per-epoch history, consumed by the plotting cell.
train_losses, train_accs, test_losses, test_accs = [], [], [], []

for epoch in range(1, epochs+1):
    # One optimisation pass over the training data, then a gradient-free pass over the test data.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc, _, _ = evaluate(model, test_loader, criterion, device)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

    print(f"Epoch {epoch}/{epochs} - Train loss: {train_loss:.4f}, Train acc: {train_acc*100:.2f}% | Test loss: {test_loss:.4f}, Test acc: {test_acc*100:.2f}%")

In [None]:
# Plots: loss and accuracy over epochs
import matplotlib.pyplot as plt

# One figure with two side-by-side axes: losses on the left, accuracies on the right.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12,5))

ax_loss.plot(range(1, len(train_losses)+1), train_losses, marker='o', label='Train Loss')
ax_loss.plot(range(1, len(test_losses)+1), test_losses, marker='o', label='Test Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss per epoch')
ax_loss.legend()

# Accuracies are stored as fractions; scale them to percent for display.
ax_acc.plot(range(1, len(train_accs)+1), [a*100 for a in train_accs], marker='o', label='Train Acc')
ax_acc.plot(range(1, len(test_accs)+1), [a*100 for a in test_accs], marker='o', label='Test Acc')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Accuracy per epoch')
ax_acc.legend()

fig.tight_layout()
plt.show()

In [None]:
# Confusion matrix on test set
test_loss, test_acc, all_preds, all_labels = evaluate(model, test_loader, criterion, device)
# Pin the label set so the matrix always has one row/column per entry in
# `classes`, even if some class never occurs in the labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(classes))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
# Create the axes explicitly and hand them to disp.plot(): without `ax` it
# opens its own figure, leaving a blank 8x8 figure behind and drawing the
# matrix at the default size.
fig, ax = plt.subplots(figsize=(8,8))
disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation='vertical', values_format='d')
ax.set_title(f'Test Accuracy: {test_acc*100:.2f}%')
plt.show()

In [None]:
# Example images: correct and incorrect predictions
import random
per_grid = 8  # number of images shown in each grid
model.eval()
examples = []  # (image, true label, predicted label) for every sample inspected
n_right = 0
n_wrong = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        preds = model(imgs.to(device)).argmax(dim=1).cpu()
        labels = labels.cpu()
        # Move whole batches to the CPU once instead of once per sample.
        examples.extend(zip(imgs.cpu(), labels.tolist(), preds.tolist()))
        hits = int((preds == labels).sum())
        n_right += hits
        n_wrong += len(labels) - hits
        # Keep reading batches until both grids can be filled. Stopping after a
        # fixed number of samples usually leaves too few misclassified images
        # when the model is accurate.
        if n_right >= per_grid and n_wrong >= per_grid:
            break

correct_examples = [e for e in examples if e[1]==e[2]][:per_grid]
incorrect_examples = [e for e in examples if e[1]!=e[2]][:per_grid]

def show_grid(exs, title):
    """Draw the given examples side by side in a one-row figure.

    Args:
        exs: Sequence of ``(image, true_label, predicted_label)`` tuples; the
            labels are used as indices into the global ``classes``.
        title: Text placed above the whole figure.
    """
    plt.figure(figsize=(10,4))
    for slot, example in enumerate(exs, start=1):
        img, label, pred = example
        plt.subplot(1, len(exs), slot)
        # squeeze() removes size-1 dimensions before the image is handed to imshow.
        plt.imshow(img.squeeze(), cmap='gray')
        plt.title(f'T:{classes[label]}\nP:{classes[pred]}')
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Skip a grid entirely when its list is empty, so no blank figure is drawn.
if correct_examples:
    show_grid(correct_examples, 'Correct predictions')
if incorrect_examples:
    show_grid(incorrect_examples, 'Incorrect predictions')

## Conclusion

This notebook implements the required MLP and training setup. With `epochs=8` you should typically obtain a test accuracy of around 85% or higher, depending on random initialization and whether a GPU is used. If accuracy is below 85%, try increasing `epochs` to 10, or make small changes such as adding weight decay, a learning-rate scheduler, or simple data augmentation.