# 🧠 Interactive CNN Image Classification on MNIST (PyTorch)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ishwar-git/Image-Classification/blob/main/mnist_cnn_classification.ipynb)


## Objective
The aim of this notebook is to use a **CNN** on MNIST to understand image classification — with visualizations, data analysis, training, evaluation, error analysis, and intuition.



## 📦 Step 0: Import Libraries & Set Device

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Optional: tqdm progress bar. Only a missing package should disable it;
# a bare `except:` would also swallow KeyboardInterrupt and genuine errors
# raised while importing tqdm.
try:
    from tqdm.auto import tqdm
    TQDM = True
except ImportError:
    TQDM = False

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

> ❓ **Mini Quiz**: Why do we normalize the input images?  
> 📝 Try to answer first, then check: normalization centers and scales the inputs, which makes training more stable and helps it converge faster.

## 📊 Step 1: Download, Load & Analyze MNIST

In [None]:
# Preprocessing: convert each image to a tensor, then standardize it with the
# MNIST statistics (mean=0.1307, std=0.3081).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])

# Both splits share the same root directory, download flag and preprocessing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Quick summary of split sizes and class names.
for caption, value in (
    ('Total Training Samples:', len(train_dataset)),
    ('Total Test Samples:', len(test_dataset)),
    ('Classes:', train_dataset.classes),
):
    print(caption, value)

## 🖼 Step 2: Visualize Sample Images

In [None]:
# Display the first six training digits in one row, titled with their labels.
n_show = 6
fig, axes = plt.subplots(1, n_show, figsize=(10, 3))
for idx, ax in zip(range(n_show), axes):
    digit, digit_label = train_dataset[idx]
    ax.imshow(digit.squeeze(), cmap='gray')  # drop the single channel axis
    ax.set_title(f'Label: {digit_label}')
    ax.axis('off')
plt.suptitle('Sample MNIST Images')
plt.show()

## 📊 Step 3: Class-wise Distribution (Train Split)

In [None]:
# Count how many training images each digit class has and plot the counts.
labels, counts = np.unique(train_dataset.targets.numpy(), return_counts=True)
plt.figure(figsize=(8, 3.5))
plt.bar(labels, counts)
# Put a tick under every bar; the default tick locator may skip some digits.
plt.xticks(labels)
plt.title('Number of Images per Digit (Train)')
plt.xlabel('Digit Class')
plt.ylabel('Count')
plt.show()

## ⚙️ Step 4: Data Augmentation & DataLoaders

In [None]:
# Training-only augmentation: a small random rotation (degrees=10), followed by
# the same tensor conversion and normalization the test set uses.
augment_steps = [
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
train_transform = transforms.Compose(augment_steps)

# A second view of the MNIST training split that applies the augmentation.
train_dataset_aug = datasets.MNIST(
    root='./data', train=True, download=True, transform=train_transform
)

batch_size = 64
# Training batches are reshuffled each epoch; the test order is fixed
# (shuffle=False), so the first test batch is the same on every run.
train_loader, test_loader = (
    DataLoader(train_dataset_aug, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

for caption, dl in (('Training batches:', train_loader), ('Testing batches:', test_loader)):
    print(caption, len(dl))

## 🏗 Step 5: Build a Simple CNN Model (PyTorch)

In [None]:
class SimpleCNN(nn.Module):
    """First draft of the MNIST CNN, completed into a working model.

    Architecture: 1x28x28 image -> 10 class logits.

    Why two pools: the original draft applied a single 2x2 max-pool
    (28 -> 14), so the flattened vector had 64*14*14 = 12544 entries. That
    does not match ``fc1``, which expects 64*7*7 = 3136. The draft's
    ``forward`` also stopped after flattening and never used ``fc1``/``fc2``.
    A second pool (14 -> 7) makes the shapes line up, and the classifier
    head is now applied.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        self.fc1   = nn.Linear(64*7*7, 128)  # requires two pools: 28 -> 14 -> 7
        self.fc2   = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))             # 1x28x28  -> 32x28x28
        x = self.pool(F.relu(self.conv2(x)))  # 32x28x28 -> 64x28x28 -> 64x14x14
        x = self.pool(x)                      # 64x14x14 -> 64x7x7
        x = torch.flatten(x, 1)               # -> 64*7*7 = 3136 features
        x = F.relu(self.fc1(x))
        return self.fc2(x)                    # raw logits, one per digit

In [None]:
class SimpleCNN(nn.Module):
    """Small two-convolution classifier for MNIST: 1x28x28 image -> 10 logits.

    Feature extractor: conv(1->32) + ReLU, then conv(32->64) + ReLU, then two
    successive 2x2 max-pools (28 -> 14 -> 7). Classifier: a 3136 -> 128
    hidden layer with ReLU, then a 128 -> 10 output layer. The outputs are
    raw logits; this notebook feeds them to ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys, and the
        # notebook reads ``model.conv1`` directly.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # keeps 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # keeps 28x28
        self.pool = nn.MaxPool2d(2, 2)                             # halves H and W
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = F.relu(self.conv2(F.relu(self.conv1(x))))  # N x 64 x 28 x 28
        for _ in range(2):                                    # 28 -> 14 -> 7
            features = self.pool(features)
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        return self.fc2(hidden)

# Build the network, move its parameters to the selected device, and show the layer summary.
net = SimpleCNN()
model = net.to(device)
print(model)

## ⚙️ Step 6: Loss Function & Optimizer

In [None]:
# Cross-entropy over the 10 digit logits, minimized with Adam.
learning_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

## 🏋️ Step 7: Training & Testing Functions

In [None]:
def train_one_epoch(model, loader, optimizer, loss_fn, device, show_progress=None):
    """Run one training pass over every batch in ``loader``.

    Args:
        model: network to train; switched to training mode here.
        loader: iterable of ``(data, target)`` batches (e.g. a ``DataLoader``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: criterion called as ``loss_fn(output, target)``. Assumed to
            return the batch *mean*, as ``nn.CrossEntropyLoss`` does by default.
        device: device each batch is moved to before the forward pass.
        show_progress: whether to wrap the loop in a tqdm bar. ``None`` (the
            default) falls back to the notebook-level ``TQDM`` flag, which
            matches the previous behavior.

    Returns:
        Mean training loss per *sample*: each batch's mean loss is weighted by
        its size, so a smaller final batch is not over-counted. Returns
        ``nan`` if ``loader`` yields no batches (previously a
        ZeroDivisionError).
    """
    if show_progress is None:
        show_progress = TQDM
    model.train()
    loss_sum = 0.0
    seen = 0
    batches = tqdm(loader, leave=False) if show_progress else loader
    for data, target in batches:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        batch_n = target.size(0)
        loss_sum += loss.item() * batch_n  # undo the per-batch mean
        seen += batch_n
    if seen == 0:
        return float('nan')
    return loss_sum / seen

def evaluate(model, loader, loss_fn, device):
    """Measure loss and accuracy of ``model`` on ``loader`` without updating it.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(data, target)`` batches (e.g. a ``DataLoader``).
        loss_fn: criterion called as ``loss_fn(output, target)``. Assumed to
            return the batch *mean*, as ``nn.CrossEntropyLoss`` does by default.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)``. ``mean_loss`` is averaged per
        sample (each batch weighted by its size), consistent with
        ``train_one_epoch``. ``accuracy_percent`` is the share of argmax
        predictions that equal the target, times 100. Both are ``nan`` if
        ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd bookkeeping during evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            batch_n = target.size(0)
            loss_sum += loss_fn(out, target).item() * batch_n  # undo the per-batch mean
            correct += (out.argmax(dim=1) == target).sum().item()
            total += batch_n
    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100.0 * correct / total

## 📈 Step 8: Train the Model & Track Progress

In [None]:
# Train for a fixed number of epochs. After each epoch, evaluate on the test
# split and record the metrics that the next step plots.
epochs = 5
train_losses, test_losses, test_accuracies = [], [], []

for epoch in range(1, epochs + 1):
    epoch_train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    epoch_test_loss, epoch_test_acc = evaluate(model, test_loader, loss_fn, device)

    for history, value in zip(
        (train_losses, test_losses, test_accuracies),
        (epoch_train_loss, epoch_test_loss, epoch_test_acc),
    ):
        history.append(value)

    print(
        f'Epoch {epoch}/{epochs} - Train Loss: {epoch_train_loss:.4f} '
        f'| Test Loss: {epoch_test_loss:.4f} | Test Acc: {epoch_test_acc:.2f}%'
    )

## 📊 Step 9: Plot Training Curves

In [None]:
# Number the epochs from 1 so the x-axis matches the "Epoch k/N" lines printed
# during training (plotting the bare lists would start the axis at 0).
epoch_axis = range(1, len(train_losses) + 1)

plt.figure(figsize=(10, 4))
plt.plot(epoch_axis, train_losses, marker='o', label='Train Loss')
plt.plot(epoch_axis, test_losses, marker='o', label='Test Loss')
plt.xticks(epoch_axis)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curves')
plt.legend()
plt.show()

plt.figure(figsize=(10, 4))
plt.plot(epoch_axis, test_accuracies, marker='o', label='Test Accuracy')
plt.xticks(epoch_axis)
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Accuracy Curve')
plt.legend()
plt.show()

## 🔍 Step 10: Sample Predictions (Correct & Wrong)

In [None]:
# Predict the first test batch and show up to 10 images with predicted vs true labels.
model.eval()
data_iter = iter(test_loader)
images, labels = next(data_iter)
with torch.no_grad():
    outputs = model(images.to(device))
preds = outputs.argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
axes = axes.flatten()
for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= len(images):
        continue  # batch smaller than the grid: leave the remaining cells blank
    pred_i, true_i = preds[i].item(), labels[i].item()
    ax.imshow(images[i][0], cmap='gray')
    # Green title = correct prediction, red = wrong, so mistakes stand out.
    ax.set_title(f'Pred: {pred_i} | True: {true_i}',
                 color='green' if pred_i == true_i else 'red')
plt.suptitle('Sample Predictions')
plt.show()

## 📉 Step 11: Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix

# Gather predictions and ground-truth labels for the whole test split.
all_preds = []
all_targets = []
model.eval()
with torch.no_grad():
    for data, target in test_loader:
        out = model(data.to(device))
        all_preds.extend(out.argmax(dim=1).cpu().numpy())
        all_targets.extend(target.numpy())

# Pin the label set to digits 0-9 so the matrix is always 10x10, even if a
# digit never appeared in either list. The plotting code below relies on that shape.
class_ids = np.arange(10)
cm = confusion_matrix(all_targets, all_preds, labels=class_ids)

# Plot using matplotlib to avoid a dependency on seaborn.
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, interpolation='nearest')
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=class_ids, yticks=class_ids, xlabel='Predicted', ylabel='True', title='Confusion Matrix')
# With matplotlib's default colormap (viridis), low counts are dark and high
# counts are bright. Pick the text color per cell so every count stays readable.
threshold = cm.max() / 2.0
for i in class_ids:
    for j in class_ids:
        ax.text(j, i, cm[i, j], ha='center', va='center', fontsize=8,
                color='white' if cm[i, j] < threshold else 'black')
plt.tight_layout()
plt.show()

## 🧠 Step 12: Visualize Feature Maps (What the CNN Sees)

In [None]:
# Pass one test image through the first convolution only (no ReLU applied)
# and draw each of its 32 output channels as a grayscale map in a 4x8 grid.
model.eval()
image, label = test_dataset[0]
image = image.unsqueeze(0).to(device)  # prepend a batch dimension of size 1
with torch.no_grad():
    fmap1 = model.conv1(image).cpu()

n_channels = fmap1.shape[1]
fig, axes = plt.subplots(4, 8, figsize=(10, 5))
axes = axes.ravel()
for ch, ax in enumerate(axes):
    if ch < n_channels:
        ax.imshow(fmap1[0, ch], cmap='gray')
    ax.axis('off')
plt.suptitle('Feature Maps after conv1')
plt.show()

## 🎓 Step 13: Student Experiment Zone
Try these:
- Add **Dropout** between layers
- Change **learning rate**
- Increase **filters** (e.g., 32→64, 64→128)
- Add **BatchNorm** after conv layers
- Train for **more epochs** and compare curves


## 💾 Step 14: Save & Load the Trained Model

In [None]:
# Save
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print('Model saved as mnist_cnn_model.pth')

# Load
loaded = SimpleCNN().to(device)
loaded.load_state_dict(torch.load('mnist_cnn_model.pth', map_location=device))
loaded.eval()
print('Model loaded back!')

## ✅ Conclusion & Next Steps
**You learned:** data prep, CNN basics, training, evaluation, confusion matrix, and feature visualization.  
**Next:** Try deeper CNNs (VGG/ResNet), switch to CIFAR-10, and play with augmentation/regularization.
