# Feedforward Neural Network for MNIST Classification

This notebook demonstrates how to build, train, and evaluate a feedforward neural network for classifying handwritten digits using the MNIST dataset.

## Overview:
1. Load and preprocess the MNIST dataset
2. Split data into train, validation, and test sets
3. Build a feedforward neural network
4. Train the model with hyperparameter tuning
5. Evaluate model performance
6. Visualize results

## 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import time

# Set random seeds for reproducibility (one constant, one call per library).
SEED = 42
np.random.seed(SEED)     # NumPy global RNG; the sklearn splits below also pass random_state=42
torch.manual_seed(SEED)  # PyTorch RNG (weight init, dropout, DataLoader shuffling)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Load and Explore MNIST Dataset

In [None]:
# Load MNIST from OpenML (cache=True stores the download for later runs).
# With scikit-learn's default `as_frame='auto'`, `data` comes back as a pandas
# DataFrame and `target` as a Series; later cells rely on this (they call `.values`).
mnist = fetch_openml('mnist_784', version=1, cache=True)
X = mnist.data.astype('float32')
y = mnist.target.astype('int64')

# Display dataset information. A DataFrame has no `.dtype` attribute and its
# `.min()` / `.max()` are per column, so summarise the underlying array instead.
pixels = np.asarray(X)
print(f"MNIST dataset shape: {X.shape}")
print(f"Target shape: {y.shape}")
print(f"Data type: {pixels.dtype}")
print(f"Target type: {y.dtype}")
print(f"Min value: {pixels.min()}")
print(f"Max value: {pixels.max()}")
del pixels  # only needed for the summary above

### Visualize Some Examples

In [None]:
# Show the first occurrence of each digit 0-9 (raw pixel values, before normalization).
# `X` / `y` are pandas objects: `X[idx]` would look up a *column* named `idx` and
# raise KeyError, so index positionally into plain arrays instead.
images = np.asarray(X)
labels = np.asarray(y)

fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for digit, ax in enumerate(axes.flat):
    idx = np.flatnonzero(labels == digit)[0]  # first sample carrying this label
    ax.imshow(images[idx].reshape(28, 28), cmap='gray')
    ax.set_title(f"Digit: {digit}")
    ax.axis('off')

fig.tight_layout()
plt.show()
del images, labels  # temporary array views used only for plotting

## 3. Preprocess Data and Split into Train, Validation, and Test Sets

In [None]:
# Normalize pixel values to [0, 1]. Assign to a new name so re-running this cell
# does not divide the already-scaled data by 255 a second time.
X_scaled = X / 255.0
scaled_pixels = np.asarray(X_scaled)
assert scaled_pixels.min() >= 0.0 and scaled_pixels.max() <= 1.0, "pixels should lie in [0, 1]"
del scaled_pixels

# Split the data: 70% train, 15% validation, 15% test (stratified by label).
# The second split takes 0.15/0.85 of the remaining 85%, i.e. 15% of the whole set.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X_scaled, y, test_size=0.15, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.15 / 0.85, random_state=42, stratify=y_train_val
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

### Convert to PyTorch Tensors and Create DataLoaders

In [None]:
# Convert to PyTorch tensors. `np.asarray` accepts pandas objects and NumPy arrays
# alike, so this works regardless of what `fetch_openml` returned.
X_train_tensor = torch.tensor(np.asarray(X_train), dtype=torch.float32)
y_train_tensor = torch.tensor(np.asarray(y_train), dtype=torch.long)

X_val_tensor = torch.tensor(np.asarray(X_val), dtype=torch.float32)
y_val_tensor = torch.tensor(np.asarray(y_val), dtype=torch.long)

X_test_tensor = torch.tensor(np.asarray(X_test), dtype=torch.float32)
y_test_tensor = torch.tensor(np.asarray(y_test), dtype=torch.long)

# Create TensorDatasets
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Create DataLoaders. Batch size 64 matches the hyperparameters listed in section 6.
# Only the training loader is shuffled; evaluation order does not matter.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## 4. Build the Feedforward Neural Network Model

Implement a feedforward neural network `FeedForwardNN(nn.Module)` with the following architecture:
- Input layer: 784 neurons
- Hidden layers: sizes to be specified
- Output layer: 10 neurons (corresponding to 10 classes)

In [None]:
class FeedForwardNN(nn.Module):
    """Multi-layer perceptron classifier for flattened images.

    Every hidden layer is ``Linear -> ReLU -> Dropout``. The output layer is a
    plain ``Linear`` returning raw logits; ``nn.CrossEntropyLoss`` applies
    log-softmax itself, so no softmax is added here.

    Args:
        input_size: number of input features (784 for 28x28 MNIST images).
        hidden_sizes: widths of the hidden layers, in order. An empty sequence
            yields a single linear layer.
        output_size: number of classes (10 for MNIST).
        dropout_rate: dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128), output_size=10, dropout_rate=0.3):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout_rate)]
            in_features = width
        layers.append(nn.Linear(in_features, output_size))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(batch, output_size)``.

        ``x`` is flattened from dim 1 onward, so ``(batch, 784)``,
        ``(batch, 28, 28)`` and ``(batch, 1, 28, 28)`` inputs are all accepted.
        """
        return self.network(torch.flatten(x, start_dim=1))

## 5. Define Training, Validation, and Testing Functions

Add the forward pass, backpropagation, and optimizer step to the training loop in `train_epoch`.

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: network to train; switched to training mode (enables dropout).
        train_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and accuracy over
        the epoch. Each batch is scored with the weights *before* its own update.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the parameter gradients left over from the previous batch.
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Statistics: weight the batch-mean loss by batch size so the epoch
        # average stays per-sample even when the final batch is smaller.
        n_batch = targets.size(0)
        running_loss += loss.item() * n_batch
        _, predicted = torch.max(outputs, 1)
        total += n_batch
        correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    """Score ``model`` on ``val_loader`` without tracking gradients.

    The model is put in eval mode (dropout disabled). Returns
    ``(val_loss, val_acc)``: the batch-size-weighted mean loss and the
    fraction of correctly classified samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            # Scale the batch-mean loss by batch size so a short final batch counts correctly.
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)

            preds = logits.max(dim=1).indices
            n_seen += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()

    return loss_sum / n_seen, n_correct / n_seen

def test_model(model, test_loader, device):
    """Predict every sample in ``test_loader`` and report metrics.

    Prints the overall accuracy and a scikit-learn classification report
    (4 decimal places). Returns ``(all_predictions, all_targets, cm, accuracy)``:
    flat lists of predicted and true labels, the confusion matrix, and the
    accuracy score.
    """
    model.eval()
    pred_labels, true_labels = [], []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            batch_pred = model(batch_x).max(dim=1).indices
            pred_labels.extend(batch_pred.cpu().numpy())
            true_labels.extend(batch_y.cpu().numpy())

    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"Test Accuracy: {accuracy:.4f}")

    cm = confusion_matrix(true_labels, pred_labels)

    report = classification_report(true_labels, pred_labels, digits=4)
    print("\nClassification Report:")
    print(report)

    return pred_labels, true_labels, cm, accuracy

## 6. Train the Model

Modify the function `train_model` to train a `FeedForwardNN` model.\
Use the validation set for early stopping, i.e. to choose the epoch whose model weights are kept.\
The function should return:
- the trained model
- the training loss history
- the training accuracy history
- the validation loss history
- the validation accuracy history
- best validation accuracy

We will use the following hyperparameters:
- Learning rate: 0.001
- Number of epochs: 30
- Batch size: 64
- Optimizer: Adam
- Loss function: CrossEntropyLoss
- Weight decay: 0.00001 (1e-5)
- Dropout rate: 0.3
- Hidden sizes: [512, 256, 128]



![Early stopping curves](early_stopping.png "Early Stopping Curves")


In [None]:
def train_model(input_size, hidden_sizes, output_size, dropout_rate, learning_rate, weight_decay, num_epochs,
                early_stopping_patience=None):
    """Train a ``FeedForwardNN`` and keep the weights from its best validation epoch.

    Uses the notebook-level ``train_loader``, ``val_loader`` and ``device``
    together with ``train_epoch`` / ``validate``. Optimizer is Adam with
    weight decay; the learning rate is halved when validation loss plateaus
    for 3 epochs.

    Early stopping: after every epoch the weights are snapshotted whenever
    validation accuracy improves, and the best snapshot is restored at the
    end. If ``early_stopping_patience`` is given, training also halts after
    that many consecutive epochs without improvement.

    Returns:
        ``(model, train_losses, train_accs, val_losses, val_accs, best_val_acc)``;
        the histories hold one value per epoch actually run.

    Raises:
        ValueError: if ``num_epochs`` is less than 1.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    # Create model
    model = FeedForwardNN(input_size, hidden_sizes, output_size, dropout_rate).to(device)

    # Loss function, optimizer and LR schedule. `verbose=` is deprecated in recent
    # PyTorch releases, so the current LR is printed in the epoch log instead.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    train_losses, train_accs, val_losses, val_accs = [], [], [], []
    best_val_acc = -1.0  # any real accuracy (>= 0) beats this, so epoch 1 is always kept
    best_epoch = 0
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(1, num_epochs + 1):
        start = time.time()
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc, best_epoch = val_acc, epoch
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite, so store detached copies.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch:2d}/{num_epochs} | "
              f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.4f} | "
              f"lr {current_lr:.2e} | {time.time() - start:.1f}s")

        if early_stopping_patience is not None and epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping: no validation improvement for {early_stopping_patience} epochs.")
            break

    # Roll back to the epoch with the best validation accuracy.
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (val acc {best_val_acc:.4f})")

    return model, train_losses, train_accs, val_losses, val_accs, best_val_acc


### Define and Train the Model with Optimal Hyperparameters

In [None]:
# Architecture
input_size = 784                # 28x28 pixels, flattened
hidden_sizes = [512, 256, 128]  # three hidden layers
output_size = 10                # one output per digit 0-9
dropout_rate = 0.3

# Optimization
learning_rate = 0.001
weight_decay = 1e-5
num_epochs = 30

# Train the model and unpack the model plus its per-epoch histories.
(model,
 train_losses, train_accs,
 val_losses, val_accs,
 best_val_acc) = train_model(
    input_size=input_size,
    hidden_sizes=hidden_sizes,
    output_size=output_size,
    dropout_rate=dropout_rate,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    num_epochs=num_epochs,
)

## 7. Evaluate the Model on the Test Set

In [None]:
# Evaluate on the held-out test split (prints accuracy and a per-class report).
predictions, targets, confusion_mat, test_accuracy = test_model(
    model=model, test_loader=test_loader, device=device
)

## 8. Visualize Training Progress and Results

Plot the training and validation loss and accuracy over epochs.\
Visualize the confusion matrix of the test set predictions.

## 9. Visualize Misclassified Examples

Plot some examples of misclassified images and their predicted labels.