# Train a CNN for Multi-Class Classification

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Preprocessing applied to every image: convert it to a tensor, then
# normalize the single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training and test splits, stored under ./data and downloaded there
# if they are not already present.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Mini-batch iterators (64 samples per batch): the training data is shuffled,
# the test data is read in its stored order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define a simple Fully Connected Neural Network (Normal NN)
class SimpleNN(nn.Module):
    """Fully connected classifier for 28x28 inputs with 10 output classes.

    The input is flattened to rows of 784 values and passed through linear
    layers of width 128 and 64, each followed by ReLU, and then through a
    final linear layer with 10 outputs. No softmax is applied in the forward
    pass; the raw output of the last linear layer is returned.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=128)  # hidden layer 1
        self.relu = nn.ReLU()  # shared by both hidden layers
        self.fc2 = nn.Linear(in_features=128, out_features=64)  # hidden layer 2
        self.fc3 = nn.Linear(in_features=64, out_features=10)  # one output per class

    def forward(self, x):
        """Return the 10 class scores for every 784-value row in ``x``."""
        # Collapse everything into rows of 28*28 values; the row count is inferred.
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Define a Convolutional Neural Network (CNN)
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/pool stages and two dense layers.

    Each stage applies a 3x3 convolution (padding 1), ReLU, and 2x2 max
    pooling with stride 2. The result is reshaped into rows of 32 * 7 * 7
    values, passed through a 128-unit ReLU layer, and mapped to 10 outputs.
    No softmax is applied in the forward pass.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # reused after each convolution
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the 10 class scores computed from image batch ``x``."""
        # Stage 1: convolution -> ReLU -> max pooling.
        features = self.conv1(x)
        features = self.relu(features)
        features = self.pool(features)

        # Stage 2: same pattern with the second convolution.
        features = self.conv2(features)
        features = self.relu(features)
        features = self.pool(features)

        # Reshape to rows of 32*7*7 values for the dense layers.
        flat = features.view(-1, 32 * 7 * 7)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# Build each model together with its own Adam optimizer (learning rate 1e-3).
# The models are constructed in the same order as before (SimpleNN first).
simple_nn = SimpleNN()
optimizer_nn = optim.Adam(params=simple_nn.parameters(), lr=1e-3)

cnn = CNN()
optimizer_cnn = optim.Adam(params=cnn.parameters(), lr=1e-3)

# One cross-entropy criterion is shared by both models.
criterion = nn.CrossEntropyLoss()

# Function to train a model
def train_model(model, optimizer, train_loader, epochs=5, loss_fn=None):
    """Train ``model`` in place and report the mean batch loss of each epoch.

    Args:
        model: Module called as ``model(images)`` to produce outputs.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            around every backward pass.
        train_loader: Iterable of ``(images, labels)`` batches; it is
            iterated once per epoch.
        epochs: Number of passes over ``train_loader``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            loss tensor. When omitted, the module-level ``criterion`` is
            used, which keeps existing call sites working unchanged.

    Returns:
        A list with the mean batch loss of every epoch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Count batches explicitly so loaders without ``__len__`` also work.
        num_batches = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("train_loader yielded no batches; cannot compute a mean loss")

        mean_loss = running_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f'Epoch {epoch+1}, Loss: {mean_loss}')
    return epoch_losses

# Train both models on the same training loader, each with its own optimizer.
# No epoch count is passed, so train_model's default (5 epochs) applies.
print("Training Simple Neural Network...")
train_model(simple_nn, optimizer_nn, train_loader)

print("Training Convolutional Neural Network...")
train_model(cnn, optimizer_cnn, train_loader)

# Function to evaluate a model
def evaluate_model(model, test_loader):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode and every batch is scored with
    gradient tracking disabled. The predicted class of a sample is the index
    of its largest output along dimension 1. Nothing is returned; the
    accuracy is printed as a percentage.

    Args:
        model: Module called as ``model(images)`` to produce class scores.
        test_loader: Iterable of ``(images, labels)`` batches.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            predicted = model(images).max(dim=1).indices
            total += labels.shape[0]
            correct += predicted.eq(labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

# Evaluate both trained models on the same test loader; each call prints
# that model's accuracy as a percentage.
print("Evaluating Simple Neural Network...")
evaluate_model(simple_nn, test_loader)

print("Evaluating Convolutional Neural Network...")
evaluate_model(cnn, test_loader)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 26344387.49it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1652353.59it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 6110410.92it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1916745.02it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training Simple Neural Network...
Epoch 1, Loss: 0.40448390214300867
Epoch 2, Loss: 0.19974939586288892
Epoch 3, Loss: 0.14145979561856878
Epoch 4, Loss: 0.11732239898806537
Epoch 5, Loss: 0.09789502675652996
Training Convolutional Neural Network...
Epoch 1, Loss: 0.19106504687024697
Epoch 2, Loss: 0.05199949782570578
Epoch 3, Loss: 0.035507859311812225
Epoch 4, Loss: 0.02790432198934117
Epoch 5, Loss: 0.02117122623074151
Evaluating Simple Neural Network...
Accuracy: 96.48%
Evaluating Convolutional Neural Network...
Accuracy: 98.79%


# Model Training and Evaluation Process

## Convolutional Neural Network (CNN)
### Architecture & Layers
A **CNN** uses specialized layers to process images:  
* **Convolutional Layer:** Detects patterns like edges, corners, and textures.
* **Activation Function (ReLU):** Adds non-linearity to help learn complex patterns.
* **Pooling Layer (Max Pooling/Average Pooling):** Reduces dimensions while preserving important features.
* **Fully Connected Layer (Dense Layer):** Final layer to make classification decisions.

### 1️⃣ Convolutional Layer (nn.Conv2d)
📌 **What does it do?**
* Purpose: Detects features (edges, textures) in an image using small filters (kernels).
* A filter (kernel) slides over the input image to detect patterns like edges, shapes, and textures.
* Each filter learns different features (e.g., one detects horizontal edges, another detects curves, etc.).
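* How it works: a small filter (3×3 in this model; 5×5 is another common size) slides over the image. At each position it multiplies its weights element-wise with the pixels underneath and sums the products into a single output value.
* The output is a feature map, highlighting where the filter's pattern appears in the image.

🔹 **Code in CNN Model**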
```python
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
```
* 1: Input channels (grayscale image has 1 channel).
* 16: Number of filters (kernels), meaning 16 different feature detectors.
* kernel_size=3: A 3×3 filter slides over the image.
* padding=1: Adds a one-pixel border of zeros so that, with a 3×3 kernel and the default stride of 1, the output keeps the input's 28×28 size.

🖼 **Visual Example**
* 🔹 Original Image (28×28)
* 🔹 After Conv Layer → 16 Feature Maps (28×28 each)

✅ Effect: Extracts low-level features (edges, lines).

### 2️⃣ ReLU Activation Function (nn.ReLU)
📌 **What does it do?**
* Introduces non-linearity by replacing negative values with zero.
* Helps the CNN learn complex patterns rather than only linear combinations of its inputs.

🔹 **Code in CNN Model**
```python
self.relu = nn.ReLU()
```

✅ Effect: Keeps positive values and zeros out negative values, making it easier for the model to learn.

### 3️⃣ Pooling Layer (nn.MaxPool2d)
📌 **What does it do?**
* Purpose: Reduces the spatial dimensions of the feature maps, by selecting only the most important features.
* Makes the model more robust to minor shifts in the image.
* Types of Pooling:
  * Max Pooling: Retains the maximum value in a region (best for feature detection).
  * Average Pooling: Takes the average value (used less frequently).
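* In our case, each 2×2 pooling step halves the height and width: the first pooling reduces every 28×28 feature map to 14×14, and the second (after `conv2`) reduces 14×14 to 7×7.

🔹 **Code in CNN Model**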
```python
self.pool = nn.MaxPool2d(2, 2)
```
* kernel_size=2 → Takes 2x2 patches.
* stride=2 → Moves the window by 2 pixels at a time.
* `2,2`: Kernel size & stride, meaning a 2×2 window moves in steps of 2. 

🖼 **Visual Example**
* 🔹 Feature Map (28×28) → After Pooling (14×14)  
Having fewer values to process reduces computation and can help limit overfitting.

✅ Effect: Downsamples the image, making computations faster while retaining important features.


### 4️⃣ Fully Connected Layers (nn.Linear)
📌 **What does it do?**
* Purpose: After convolution and pooling, the feature maps are flattened and passed through dense layers for final classification.
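* In this model, the 32 feature maps of size 7×7 produced by the second pooling step are flattened into a 1D vector of 32 × 7 × 7 = 1568 values, which is then passed through the dense layers for classification.
* The dense layers learn to associate these extracted patterns with class labels.

🔹 **Code in CNN Model**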
```python
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
```
* Takes flattened features from CNN layers.
* Learns complex representations before making predictions.
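* `32 * 7 * 7`: Size of the flattened input — 32 feature maps (the output channels of `conv2`), each 7×7 after two rounds of 2×2 pooling (28 → 14 → 7), giving 1568 values.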
* `128`: Hidden neurons for feature learning.
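* `10`: Output neurons, one per digit class (0–9). The model returns raw scores (logits); `nn.CrossEntropyLoss` applies log-softmax internally when computing the loss.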

✅ Effect: Converts extracted image features into class predictions.