# MNIST Handwritten Digit Recognition with FastAI

This notebook demonstrates how to fine-tune a pretrained convolutional neural network (ResNet18) on the MNIST dataset with FastAI for handwritten digit recognition. The model is trained to classify handwritten digits (0-9) and is then saved for future use.

## Overview
- Load and explore the MNIST dataset
- Prepare data loaders with augmentations
- Create a CNN model using FastAI's vision learner
- Train and fine-tune the model
- Evaluate performance
- Save the trained model

## 1. Import Required Libraries

Let's start by importing all the necessary libraries for our handwritten digit recognition project.

In [None]:
# Import essential libraries
import torch
import torch.nn as nn
from fastai.vision.all import *
from fastai.data.external import *
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import time
from sklearn.metrics import classification_report

# Set up matplotlib for better plots
plt.style.use('default')
plt.rcParams['figure.figsize'] = (10, 6)

# Device setup - simplified approach for maximum compatibility.
# NOTE: torch.set_default_device() is deliberately NOT used. fastai already places
# batches and the model on `dls.device`; a process-wide CUDA default would make every
# tensor factory call allocate on the GPU (including calls inside DataLoader worker
# processes, where CUDA can fail to initialise after fork), and the API needs PyTorch >= 2.0.
print("Configuring device...")

if torch.cuda.is_available():
    # CUDA (NVIDIA GPU)
    device = torch.device("cuda")
    device_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"✅ CUDA GPU detected: {device_name}")
    print(f"📊 GPU Memory: {gpu_memory:.1f} GB")
    print(f"🔧 Using device: {device}")

    # cuDNN autotuning speeds up convolutions for fixed input sizes, at the cost of
    # bit-exact run-to-run reproducibility (the seeds below still fix init/data order).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    # Apple Silicon GPU: train on CPU, optionally run inference on MPS later
    print("✅ Apple Metal Performance Shaders (MPS) detected")
    print("⚠️  Note: MPS has compatibility issues with FastAI training")
    print("🔧 Using CPU for training, MPS for inference")

    # Use CPU for training due to FastAI compatibility issues
    device = torch.device("cpu")
    print(f"🔧 Training device: {device}")

    torch.set_default_dtype(torch.float32)
    print("💡 After training, models can be moved to MPS for faster inference")

else:
    # Fallback to CPU
    device = torch.device("cpu")
    print("⚠️  No GPU detected, using CPU")
    print(f"🔧 Using device: {device}")

print(f"🧠 PyTorch version: {torch.__version__}")
print()

# Set random seeds for reproducibility (PyTorch directly, plus fastai's seeding helper)
torch.manual_seed(42)
set_seed(42)

## 2. Load and Explore MNIST Dataset

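The MNIST dataset contains 70,000 grayscale images of handwritten digits (0-9), each 28x28 pixels. They are split into 60,000 training images and 10,000 test images. We'll use FastAI's built-in `untar_data` helper to download and extract the dataset. The test split serves as our validation set.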

In [None]:
# Download (if not cached) and extract MNIST; `path` is the local dataset root
print("Downloading MNIST dataset...")
path = untar_data(URLs.MNIST)
print(f"Dataset downloaded to: {path}")

# Explore the dataset structure
print("\nDataset structure:")
print(f"Path contents: {list(path.ls())}")

# Each split folder holds one sub-folder per digit class
train_path = path / "training"
test_path = path / "testing"


def print_class_counts(split_dir):
    """Print how many files each class sub-folder of `split_dir` contains, in sorted order."""
    for class_dir in split_dir.ls().sorted():
        n_images = len(list(class_dir.ls()))
        print(f"Class {class_dir.name}: {n_images} images")


print(f"\nTraining classes: {[f.name for f in train_path.ls().sorted()]}")
print(f"Testing classes: {[f.name for f in test_path.ls().sorted()]}")

print("\nNumber of training images per class:")
print_class_counts(train_path)

print("\nNumber of testing images per class:")
print_class_counts(test_path)

## 3. Prepare Data Loaders

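We'll create data loaders that apply random augmentations (small rotations, scaling and lighting changes) to the training images. Every batch is normalized with ImageNet statistics, because the ResNet18 backbone used below was pretrained on ImageNet. Augmentation improves generalization by showing the model slightly different versions of each digit.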

In [None]:
# Create data loaders with transformations.
# aug_transforms() enables random horizontal flips by default; a mirrored digit is
# usually not a valid digit (or looks like a different one), so flips are disabled.
dls = ImageDataLoaders.from_folder(
    path,
    train="training",
    valid="testing",           # the official MNIST test split is used for validation
    item_tfms=Resize(28),      # Ensure all images are 28x28
    batch_tfms=[
        *aug_transforms(
            size=28,
            do_flip=False,
            min_scale=0.8,
            max_rotate=10.0,
            max_lighting=0.2,
        ),
        # ImageNet statistics, matching the pretrained ResNet18 backbone used below
        Normalize.from_stats(*imagenet_stats)
    ],
    bs=64,  # batch size
    device=device,
    num_workers=2
)

print(f"Training samples: {len(dls.train_ds)}")
print(f"Validation samples: {len(dls.valid_ds)}")
print(f"Classes: {dls.vocab}")
print(f"Number of classes: {dls.c}")

# `device` is always a torch.device here (never falsy), so no guard is needed
print(f"🔧 Data loaders configured for device: {device}")

# Display sample images from the dataset
print("\nSample training images:")
dls.show_batch(max_n=12, figsize=(10, 8))

## 4. Create the Model Architecture

We'll use a pre-trained ResNet18 model and fine-tune it for our digit classification task. FastAI makes this process very straightforward.

In [None]:
# Create a CNN learner with ResNet18 architecture
learn = vision_learner(
    dls,
    resnet18,
    metrics=[accuracy, error_rate],
    loss_func=CrossEntropyLossFlat()
)

print("Model created successfully!")
print(f"Model architecture: {learn.model.__class__.__name__}")
print(f"Number of parameters: {sum(p.numel() for p in learn.model.parameters()):,}")

# Device-specific reporting (`device` is always a torch.device, so no truthiness guard)
print(f"🔧 Model will use device: {device}")

if device.type == "cuda":
    # NOTE(review): these figures may be near zero if the model has not yet been
    # moved to the GPU at this point (that may only happen when fitting starts).
    torch.cuda.empty_cache()
    print(f"💾 GPU Memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.1f} MB")
    print(f"💾 GPU Memory cached: {torch.cuda.memory_reserved(device) / 1024**2:.1f} MB")

# Display model summary
learn.model

## 5. Train the Model

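We'll use FastAI's learning rate finder to explore suitable learning rates, and the one-cycle training policy (via `fine_tune`) to train the model efficiently. The finder's suggestion is printed for reference. The training cell sets `learning_rate` explicitly, so adjust it if the suggestion differs substantially.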

In [None]:
# Find a good learning rate. With show_plot=True, lr_find() itself draws the
# loss-vs-learning-rate curve with the suggested point marked, so a separate
# recorder.plot_lr_find() call would only render a duplicate figure.
print("Finding optimal learning rate...")
lr_find_result = learn.lr_find(show_plot=True)
print(f"Suggested learning rate: {lr_find_result.valley}")

In [None]:
# Train the model with simplified approach
epochs = 5
learning_rate = 1e-3  # fixed value; compare with the lr_find() suggestion above and adjust if needed
print(f"Training model for {epochs} epochs...")

start_time = time.time()

# No try/except wrapper: a failing fit already shows its full traceback in the
# notebook, and printing it in a handler before re-raising only showed it twice.
print("? Starting training...")
learn.fine_tune(epochs, base_lr=learning_rate)

training_time = time.time() - start_time
print(f"Training completed in {training_time:.2f} seconds")


def get_inference_device():
    """Return the preferred inference device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# For Apple Silicon users: move the model to MPS for faster inference
inference_device = get_inference_device()
if inference_device != device:
    print(f"🔧 Moving model to {inference_device} for faster inference...")
    learn.model = learn.model.to(inference_device)
    # Move the DataLoaders too: fastai feeds evaluation batches from `learn.dls`,
    # so moving only the model would make validate()/get_preds() fail with a
    # device mismatch between inputs and weights.
    learn.dls.to(inference_device)
    print(f"✅ Model moved to {inference_device} for inference")

## 6. Evaluate Model Performance

Let's assess our model's performance using various metrics and visualizations.

In [None]:
# Get validation metrics. validate() returns the validation loss followed by each
# metric given to the Learner (here: accuracy, error_rate) — three values, not two.
# Distinct names avoid shadowing fastai's `accuracy` metric function, which the
# model-creation cell passes to vision_learner().
valid_loss, valid_accuracy, valid_error_rate = learn.validate()
print(f"Final Validation Accuracy: {valid_accuracy:.4f}")
print(f"Final Validation Loss: {valid_loss:.4f}")
print(f"Final Validation Error Rate: {valid_error_rate:.4f}")

# Create classification interpretation
interp = ClassificationInterpretation.from_learner(learn)

# Show confusion matrix
print("\nConfusion Matrix:")
interp.plot_confusion_matrix(figsize=(10, 10))
plt.title("MNIST Digit Recognition - Confusion Matrix")
plt.show()

# Show most confused (actual, predicted, count) pairs occurring at least twice
print("\nMost confused pairs:")
for actual, predicted, count in interp.most_confused(min_val=2):
    print(f"Confused {actual} with {predicted}: {count} times")

In [None]:
# Show worst predictions (top losses)
print("Analyzing worst predictions...")
interp.plot_top_losses(9, nrows=3, figsize=(12, 8))
plt.suptitle("MNIST Digit Recognition - Worst Predictions", fontsize=16)
plt.show()

# Get predictions for classification report
preds, targets = learn.get_preds()
y_pred = torch.argmax(preds, dim=1)

# Class names and their index order come from the DataLoaders' vocab instead of
# assuming range(10); passing `labels` keeps names aligned with indices even if a
# class happens to be absent from both targets and predictions.
class_names = [str(c) for c in learn.dls.vocab]
print("\nClassification Report:")
print(classification_report(
    targets,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

## 7. Save the Trained Model

Now let's save our trained model for future use and inference.

In [None]:
# Create models directory if it doesn't exist
models_dir = Path("models")
models_dir.mkdir(exist_ok=True)

model_name = "mnist_digit_recognizer"
# Absolute path on purpose: Learner.export() joins its argument onto `learn.path`
# (the MNIST dataset folder here), so a relative "models/..." path would be written
# inside the dataset directory and load_learner() below would not find it.
model_path = (models_dir / f"{model_name}.pkl").resolve()

# Save the complete learner using FastAI's export method
print(f"Saving model to: {model_path}")
learn.export(model_path)

# Also save just the model weights for plain-PyTorch use. Copies are moved to CPU
# so the file loads on machines without the GPU/MPS device the model lives on.
torch_model_path = models_dir / f"{model_name}_state_dict.pth"
cpu_state_dict = {name: tensor.detach().cpu() for name, tensor in learn.model.state_dict().items()}
torch.save(cpu_state_dict, torch_model_path)

print(f"✅ FastAI model saved to: {model_path}")
print(f"✅ PyTorch state dict saved to: {torch_model_path}")

# Verify the saved model by loading it.
# load_learner() unpickles the file — only load exports you created or trust.
print("\nVerifying saved model...")
loaded_learn = load_learner(model_path)
print("✅ Model loaded successfully!")

In [None]:
# Test the saved model with some sample predictions
print("Testing model predictions on sample data...")

# Get a batch of validation data (it lives on the DataLoaders' device)
x, y = dls.valid.one_batch()

# The loaded model may sit on a different device than the batch (load_learner maps
# to CPU by default), so send the batch to the model's device. eval() makes
# BatchNorm use running statistics, as at inference time.
loaded_model = loaded_learn.model.eval()
model_device = next(loaded_model.parameters()).device
with torch.no_grad():
    preds = loaded_model(x.to(model_device))
    pred_classes = torch.argmax(preds, dim=1).cpu()

x = x.cpu()
vocab = dls.vocab
true_ids = [int(t) for t in y]
pred_ids = [int(p) for p in pred_classes]

# ImageNet normalization statistics shaped (3, 1, 1) to broadcast over H and W
mean = torch.tensor(imagenet_stats[0], device=x.device).view(3, 1, 1)
std = torch.tensor(imagenet_stats[1], device=x.device).view(3, 1, 1)

# Visualize some predictions
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()
n_show = min(len(axes), len(x))

for i in range(n_show):
    # Denormalize the image for display
    img = torch.clamp(x[i] * std + mean, 0, 1)

    # Convert to grayscale for display
    if img.shape[0] == 3:
        img = img.mean(dim=0)

    true_label, pred_label = vocab[true_ids[i]], vocab[pred_ids[i]]
    correct = "✅" if true_ids[i] == pred_ids[i] else "❌"
    axes[i].imshow(img.numpy(), cmap='gray')
    axes[i].set_title(f"True: {true_label}, Pred: {pred_label} {correct}")
    axes[i].axis('off')

# Hide any unused panels (only relevant if the batch has fewer than 8 images)
for ax in axes[n_show:]:
    ax.axis('off')

plt.suptitle("Sample Predictions from Saved Model", fontsize=16)
plt.tight_layout()
plt.show()

# Last-epoch validation accuracy as logged by the Recorder; reading it here keeps this
# cell independent of variables from other cells. `values` rows omit the leading
# 'epoch' column of `metric_names`, hence the -1.
metric_names = list(learn.recorder.metric_names)
final_accuracy = float(learn.recorder.values[-1][metric_names.index("accuracy") - 1])
sample_accuracy = sum(t == p for t, p in zip(true_ids, pred_ids)) / len(true_ids)

print("\n🎉 MNIST digit recognition model training completed successfully!")
print(f"📊 Final accuracy: {final_accuracy:.4f}")
print(f"📊 Saved-model accuracy on this sample batch: {sample_accuracy:.4f}")
print("💾 Model saved and ready for use!")