# Lab 7, Module 4: Training a CNN on MNIST

**Estimated time:** 15-20 minutes

---

## **Opening: From Pretrained to Training Your Own**

In Modules 1-3, you:
- Applied hand-designed filters (Sobel, blur, sharpen)
- Explored a pretrained CNN (MobileNetV2)
- Learned about hierarchical feature extraction

**Now it's time to see the magic happen:**

You'll **train your own CNN from scratch** and watch it learn to:
1. Detect edges in handwritten digits
2. Combine edges into shapes
3. Recognize digits 0-9

**Best part?** This takes only **2-3 minutes** on a regular CPU!

### **Why This Matters**

Training a CNN demonstrates:
- **CNNs aren't magic:** their filters are learned by ordinary gradient descent
- **Learned filters** emerge naturally (similar to Sobel, but optimized for digits)
- **Fast training** shows CNNs are practical for real applications
- **High accuracy** (>98%) shows hierarchical features work!

### **Connection to Lab 4**

Remember Lab 4, where you trained neural networks on Iris and Breast Cancer datasets?

| Aspect | Lab 4 (Dense Networks) | Lab 7 (CNNs) |
|--------|------------------------|---------------|
| **Data type** | Tabular (features in columns) | Images (28×28 pixels) |
| **Architecture** | Fully-connected layers | Convolutional + dense layers |
| **Training** | Gradient descent + backprop | Same! |
| **Loss function** | Cross-entropy | Same! |
| **Optimizer** | Adam | Same! |

**Key insight:** CNNs use the same training process as dense networks; only the architecture differs!

---

## 📊 **About the MNIST Dataset**

**MNIST** = Modified National Institute of Standards and Technology

### **Dataset Details:**
- **60,000 training images** (handwritten digits 0-9)
- **10,000 test images**
- **28×28 pixels**, grayscale (1 channel)
- **10 classes** (digits 0, 1, 2, ..., 9)

### **Why MNIST?**
- Classic benchmark dataset (since 1998)
- Small enough to train quickly (2-3 minutes on CPU)
- Complex enough to reward a CNN: a small CNN typically passes 98% test accuracy within a few epochs, while a linear classifier stays around 92%
- Real-world applications: check reading, postal mail sorting

### **Historical Context:**
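- 1998: LeNet-5, an early CNN trained with backpropagation, achieved roughly 99% accuracy on MNIST
- 2012: Deeper CNNs, helped by techniques such as dropout and model ensembles, pushed accuracy past 99.5%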
- Today: State-of-the-art reaches 99.8% (only 20 errors out of 10,000!)

---

In [None]:
# Setup: Import libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix, classification_report

print(f"✅ TensorFlow version: {tf.__version__}")
print("✅ Libraries imported successfully!")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

---

## 📥 **Load and Explore MNIST Data**

Let's load the dataset and see what handwritten digits look like!

---

In [None]:
# Load MNIST dataset (built into Keras)
print("Loading MNIST dataset...\n")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print("✅ Dataset loaded!\n")
print(f"Training set: {x_train.shape[0]} images")
print(f"Test set: {x_test.shape[0]} images")
print(f"Image shape: {x_train.shape[1]} × {x_train.shape[2]} pixels")
print(f"\nClass distribution (training):")
for digit in range(10):
    count = np.sum(y_train == digit)
    print(f"  Digit {digit}: {count} images ({count/len(y_train)*100:.1f}%)")

In [None]:
# Visualize sample images from each class
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
axes = axes.flatten()

for digit in range(10):
    # Find first example of this digit
    idx = np.where(y_train == digit)[0][0]
    
    # Display
    axes[digit].imshow(x_train[idx], cmap='gray')
    axes[digit].set_title(f'Digit: {digit}', fontsize=12, fontweight='bold')
    axes[digit].axis('off')

plt.tight_layout()
plt.suptitle('Sample MNIST Digits (One Example per Class)', fontsize=14, fontweight='bold', y=1.02)
plt.show()

print("\nNotice: Handwriting varies widely!")
print("  - Different stroke widths")
print("  - Different slants and orientations")
print("  - Different styles (loopy vs. straight)")
print("\nThis is why we need a CNN: simple hand-written rules won't work!")

---

## 🛠️ **Preprocess the Data**

Before training, we need to:
1. **Reshape** images to add channel dimension (28, 28, 1)
2. **Normalize** pixel values to [0, 1] (currently [0, 255])
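Both steps can be tried on a small stand-in batch first. The sketch below assumes randomly generated `uint8` arrays in place of real MNIST images and applies the same `reshape(-1, 28, 28, 1)`, `astype('float32')`, and `/ 255.0` operations used in this section's preprocessing cell.

```python
import numpy as np

# Stand-in for MNIST (assumption): 4 grayscale 28x28 images, uint8 values in [0, 255]
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Step 1: add a trailing channel axis -> (samples, height, width, channels)
x = fake_batch.reshape(-1, 28, 28, 1).astype("float32")

# Step 2: scale pixel intensities from [0, 255] down to [0, 1]
x = x / 255.0

print(x.shape, x.dtype)   # (4, 28, 28, 1) float32
print(float(x.min()) >= 0.0, float(x.max()) <= 1.0)
```

The channel axis is what Keras `Conv2D` expects for single-channel images; the scaling keeps input values small, which generally makes gradient descent better behaved.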

---

In [None]:
# Reshape to add channel dimension
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32')

# Normalize to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

print("✅ Data preprocessed!\n")
print(f"Training data shape: {x_train.shape}")
print(f"  (samples, height, width, channels) = {x_train.shape}")
print(f"\nTest data shape: {x_test.shape}")
print(f"\nPixel value range: [{x_train.min():.2f}, {x_train.max():.2f}]")

---

## 🏗️ **Build the CNN Architecture**

We'll create a simple but effective CNN:

### **Architecture:**
```
Input: (28, 28, 1)
    ↓
Conv2D: 32 filters, 3×3, ReLU
    ↓
MaxPooling: 2×2
    ↓
Conv2D: 64 filters, 3×3, ReLU
    ↓
MaxPooling: 2×2
    ↓
Flatten
    ↓
Dense: 128 units, ReLU
    ↓
Dense: 10 units, Softmax
    ↓
Output: Class probabilities (0-9)
```
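The shapes and parameter counts in this diagram can be worked out by hand. This sketch assumes the Keras defaults used by the model in this section: `Conv2D` with `padding='valid'` and stride 1, and `MaxPooling2D` with stride equal to its pool size. A conv layer has (kernel height × kernel width × input channels + 1) × filters parameters, where the +1 is each filter's bias; a dense layer has (inputs + 1) × units.

```python
# Hand computation of output shapes and parameter counts for the architecture above.
# Assumes Keras defaults: Conv2D uses padding='valid' and stride 1;
# MaxPooling2D uses stride equal to its pool size (2x2 pooling halves, rounding down).

def conv_out(size, kernel):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """k*k*in_channels weights per filter, plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    """One weight per (input, unit) pair, plus one bias per unit."""
    return (inputs + 1) * units

size = 28
size = conv_out(size, 3)    # conv1 -> 26x26x32
p_conv1 = conv_params(3, 1, 32)
size = pool_out(size, 2)    # pool1 -> 13x13x32
size = conv_out(size, 3)    # conv2 -> 11x11x64
p_conv2 = conv_params(3, 32, 64)
size = pool_out(size, 2)    # pool2 -> 5x5x64
flat = size * size * 64     # flatten -> 1600 values
p_dense1 = dense_params(flat, 128)
p_output = dense_params(128, 10)

total = p_conv1 + p_conv2 + p_dense1 + p_output
print(f"conv1={p_conv1}, conv2={p_conv2}, dense1={p_dense1}, output={p_output}")
print(f"total={total:,}")   # total=225,034
```

If those defaults hold, this total should match the one `model.summary()` prints for the model in this section. Notice that `dense1` alone accounts for over 90% of the parameters.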

### **What Each Layer Does:**
- **Conv2D:** Applies convolution filters to detect patterns
- **ReLU:** Activation function (introduces nonlinearity)
- **MaxPooling:** Downsamples by keeping the maximum value in each 2×2 window (halving height and width)
- **Flatten:** Converts 2D feature maps to 1D vector
- **Dense:** Fully-connected layer (like Lab 4!)
- **Softmax:** Converts outputs to probabilities (sum to 1)
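To make these definitions concrete, here is a minimal NumPy sketch of one pass through Conv2D, ReLU, MaxPooling, Flatten, Dense, and Softmax on a tiny input. It is an illustration, not how Keras implements these layers: the filter is a hand-picked Sobel-style kernel, the dense weights are random, and, like Keras `Conv2D`, the convolution is computed as cross-correlation (the kernel is not flipped).

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide the kernel over the image (no padding, stride 1): multiply-and-add at each position."""
    kh, kw = kernel.shape
    out_h, out_w = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def relu(x):
    return np.maximum(x, 0.0)

def max_pool_2x2(x):
    """Keep the maximum of each non-overlapping 2x2 window (an odd last row/column is dropped)."""
    h, w = x.shape[0] // 2 * 2, x.shape[1] // 2 * 2
    x = x[:h, :w]
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

# Toy 8x8 "image": dark left half, bright right half (a vertical edge)
image = np.zeros((8, 8))
image[:, 4:] = 1.0

# Sobel-style vertical-edge kernel (hand-picked for illustration)
kernel = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]], dtype=float)

feature_map = relu(conv2d_valid(image, kernel))   # 6x6, large values along the edge
pooled = max_pool_2x2(feature_map)                # 3x3
flat = pooled.reshape(-1)                          # 9 values

rng = np.random.default_rng(0)
W, b = rng.normal(size=(flat.size, 10)), np.zeros(10)   # random dense layer with 10 outputs
probs = softmax(flat @ W + b)

print(feature_map.shape, pooled.shape, flat.shape)
print(probs.round(3), probs.sum())
```

The feature map is nonzero only in the columns where the 3×3 window straddles the dark-to-bright boundary, which is what an edge detector should do; the softmax output sums to 1 regardless of the random weights.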

---

In [None]:
# Build the model
model = keras.Sequential([
    # Explicit input: 28x28 grayscale images with 1 channel
    keras.Input(shape=(28, 28, 1)),
    
    # First convolutional block
    layers.Conv2D(32, (3, 3), activation='relu', name='conv1'),
    layers.MaxPooling2D((2, 2), name='pool1'),
    
    # Second convolutional block
    layers.Conv2D(64, (3, 3), activation='relu', name='conv2'),
    layers.MaxPooling2D((2, 2), name='pool2'),
    
    # Flatten and dense layers
    layers.Flatten(name='flatten'),
    layers.Dense(128, activation='relu', name='dense1'),
    layers.Dense(10, activation='softmax', name='output')
])

# Print model summary
print("📋 Model Architecture:\n")
model.summary()

# Count parameters
total_params = model.count_params()
print(f"\n📊 Total parameters: {total_params:,}")
print("\n(Most of these belong to 'dense1'. The two conv layers need fewer than 19,000")
print(" combined, because each filter's weights are reused at every image position.)")

---

## 📝 **Question Q20 (Prediction)**

### **Q20. Before training, predict: What accuracy do you expect on MNIST?**

*Consider these options:*
- **10%** (random guessing: 1 out of 10 classes)
- **50%** (better than random, but not great)
- **90%** (very good)
- **99%** (near-perfect)

*What do you think is realistic for a simple CNN after just 3 epochs of training?*

**Record your prediction in the Answer Sheet BEFORE continuing!**

---

## 🎯 **Compile the Model**

We need to specify:
1. **Optimizer:** Adam (adaptive learning rate)
2. **Loss function:** Sparse categorical cross-entropy (for multi-class classification with integer labels 0-9, so no one-hot encoding is needed)
3. **Metrics:** Accuracy (% of correct predictions)
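"Sparse" refers only to the label format: `sparse_categorical_crossentropy` takes integer labels such as `3`, while `categorical_crossentropy` expects one-hot vectors. Both compute the same quantity, the average of −log(probability assigned to the true class). The NumPy sketch below checks this on made-up predictions; it re-implements the formula rather than calling Keras.

```python
import numpy as np

# Made-up softmax outputs for 3 images over 10 classes (each row sums to 1)
rng = np.random.default_rng(1)
logits = rng.normal(size=(3, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

y_int = np.array([3, 0, 7])    # integer labels, the format of MNIST's y_train
y_onehot = np.eye(10)[y_int]   # the same labels, one-hot encoded

# Sparse form: pick out the probability of the true class directly
sparse_ce = -np.mean(np.log(probs[np.arange(len(y_int)), y_int]))

# One-hot form: sum over classes, where only the true class has weight 1
onehot_ce = -np.mean(np.sum(y_onehot * np.log(probs), axis=1))

print(sparse_ce, onehot_ce)
print(np.isclose(sparse_ce, onehot_ce))   # True
```

Because MNIST labels already arrive as integers, the sparse variant lets us skip the one-hot conversion entirely.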

---

In [None]:
# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("✅ Model compiled and ready to train!")
print("\nTraining configuration:")
print("  - Optimizer: Adam")
print("  - Loss: Sparse categorical cross-entropy")
print("  - Metric: Accuracy")

---

## 🚀 **Train the Model!**

Time to train! This will take **2-3 minutes** on CPU.

Watch the accuracy improve with each epoch!
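The training cell passes `validation_split=0.1` and `batch_size=128` to `model.fit`. Keras holds out the last 10% of the training arrays (taken before any shuffling) for validation, so one epoch works out as in this sketch. The numbers are computed here by hand, not read from Keras.

```python
import math

n_train_total = 60_000      # MNIST training images
validation_split = 0.1      # as passed to model.fit
batch_size = 128            # as passed to model.fit

n_val = round(n_train_total * validation_split)   # images held out for validation
n_fit = n_train_total - n_val                     # images used for gradient updates
steps_per_epoch = math.ceil(n_fit / batch_size)   # the last batch is smaller

print(n_fit, n_val, steps_per_epoch)   # 54000 6000 422
```

So each epoch should show 422 steps in the progress bar, with a final, smaller batch of 112 images.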

---

In [None]:
# Train the model
print("🚀 Starting training...\n")
print("This will take ~2-3 minutes on CPU.")
print("Watch the accuracy improve with each epoch!\n")
print("="*70)

history = model.fit(
    x_train, y_train,
    epochs=3,
    batch_size=128,
    validation_split=0.1,
    verbose=1
)

print("\n" + "="*70)
print("✅ Training complete!")

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy plot
axes[0].plot(history.history['accuracy'], label='Training Accuracy', marker='o', linewidth=2)
axes[0].plot(history.history['val_accuracy'], label='Validation Accuracy', marker='s', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('Training & Validation Accuracy', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim([0.9, 1.0])

# Loss plot
axes[1].plot(history.history['loss'], label='Training Loss', marker='o', linewidth=2)
axes[1].plot(history.history['val_loss'], label='Validation Loss', marker='s', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nWhat to check:")
print("  ✓ Did accuracy increase from epoch to epoch?")
print("  ✓ Did loss decrease from epoch to epoch?")
print("  ✓ Is validation accuracy close to training accuracy? (A large gap suggests overfitting.)")

---

## 📊 **Evaluate on Test Set**

Now let's see how well the model performs on completely unseen data!

---

In [None]:
# Evaluate on test set
print("Evaluating on test set...\n")
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)

print("\n" + "="*70)
print("📊 TEST SET RESULTS")
print("="*70)
print(f"\nTest Accuracy: {test_acc*100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")
# Round instead of truncating, and derive 'incorrect' from 'correct' so both sum to the total
n_correct = int(round(test_acc * len(y_test)))
print(f"\nOut of {len(y_test)} test images:")
print(f"  ✓ Correct: {n_correct}")
print(f"  ✗ Incorrect: {len(y_test) - n_correct}")
print("="*70)

print("\n💡 Context:")
print("  - Random guessing: 10% accuracy")
print("  - Linear classifier (softmax regression): ~92% accuracy")
print("  - Dense network with hidden layers: typically ~97-98% accuracy")
print(f"  - Our simple CNN: {test_acc*100:.2f}% accuracy")
print("  - State-of-the-art: ~99.8% accuracy")

---

## 📝 **Question Q21 (Observation)**

### **Q21. After training for 3 epochs, what test accuracy did you achieve? Was this higher or lower than your prediction from Q20?**

*Look at the test accuracy above. Were you surprised by the result? Why or why not?*

**Record your answer in the Answer Sheet.**

---

## 🔍 **Confusion Matrix**

Remember confusion matrices from Lab 4? Let's see which digits are confused with each other!
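As a refresher, a confusion matrix is just a table of counts indexed by [true label, predicted label]. This NumPy sketch builds one for made-up labels over 3 classes, using the same convention as scikit-learn's `confusion_matrix` (rows = true class, columns = predicted class), and reads per-class recall and overall accuracy off the diagonal.

```python
import numpy as np

# Made-up true and predicted labels for 8 examples, 3 classes
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
y_hat = np.array([0, 1, 1, 1, 2, 2, 2, 0])

n_classes = 3
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_hat), 1)   # cm[t, p] += 1 for every (true, predicted) pair
print(cm)

# Diagonal = correct predictions; each row sum = number of examples of that true class
recall = np.diag(cm) / cm.sum(axis=1)
accuracy = np.trace(cm) / cm.sum()
print(recall, accuracy)
```

Here row 1 holds a 1 in column 2: one example of class 1 was predicted as class 2. The MNIST matrix below is read the same way, just with 10 rows and columns.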

---

In [None]:
# Make predictions on test set
y_pred_probs = model.predict(x_test, verbose=0)
y_pred = np.argmax(y_pred_probs, axis=1)

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)

# Visualize
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True,
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel('Predicted Digit', fontsize=12, fontweight='bold')
plt.ylabel('True Digit', fontsize=12, fontweight='bold')
plt.title('Confusion Matrix: MNIST Test Set', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("HOW TO READ THIS MATRIX:")
print("="*70)
print("\n- Diagonal (dark blue): Correct predictions")
print("- Off-diagonal (lighter): Misclassifications\n")
print("Example: If row 5, column 3 has value 8, that means:")
print("  → 8 images of digit 5 were incorrectly classified as digit 3\n")
print("="*70)

In [None]:
# Identify most common confusions
print("\n📊 MOST COMMON CONFUSIONS:\n")

# Zero out the diagonal so only misclassifications remain
cm_off_diag = cm.copy()
np.fill_diagonal(cm_off_diag, 0)

# Collect (count, true_digit, pred_digit) for every nonzero off-diagonal cell
top_confusions = [
    (cm_off_diag[true_digit, pred_digit], true_digit, pred_digit)
    for true_digit in range(10)
    for pred_digit in range(10)
    if cm_off_diag[true_digit, pred_digit] > 0
]

# Largest counts first
top_confusions.sort(reverse=True)

for i, (count, true_digit, pred_digit) in enumerate(top_confusions[:5], 1):
    print(f"{i}. {count} times: Digit '{true_digit}' misclassified as '{pred_digit}'")

print("\n💡 Why might these confusions happen?")
print("  - Similar shapes (e.g., 4 and 9, 3 and 8, 5 and 6)")
print("  - Handwriting variations")
print("  - Ambiguous examples (even humans would struggle!)")

---

## 📝 **Question Q22 (Analysis)**

### **Q22. Looking at the confusion matrix, which digits are most commonly confused with each other? Why might this be?**

*Hint: Look at the off-diagonal cells with the highest values. Do the confused digits look similar? Think about their shapes.*

**Record your answer in the Answer Sheet.**

---

## 🔬 **Visualize Learned Filters**

Remember Module 2, where you saw pretrained filters from MobileNetV2?

**Now let's see what YOUR CNN learned!**
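One way to make "does this look like Sobel?" less subjective is cosine similarity between a 3×3 filter and the Sobel kernels. The helper `sobel_similarity` below is a hypothetical name introduced here (it is not part of Keras), and it assumes Module 1 used the standard Sobel kernels. It takes the absolute value so that a sign-flipped edge detector also scores highly. It is demonstrated on hand-made arrays; to try it on your model you could pass slices such as `conv1_weights[:, :, 0, i]` from the filter cell in this section.

```python
import numpy as np

# Standard Sobel kernels (assumed to match Module 1): SOBEL_X responds to vertical edges,
# its transpose SOBEL_Y to horizontal edges
SOBEL_X = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)
SOBEL_Y = SOBEL_X.T

def sobel_similarity(filt):
    """Hypothetical helper: |cosine similarity| between a 3x3 filter and each Sobel kernel.

    1.0 means the filter is a scaled (possibly sign-flipped) copy of that kernel;
    values near 0 mean no resemblance.
    """
    f = filt.ravel()
    scores = {}
    for name, k in (("vertical", SOBEL_X), ("horizontal", SOBEL_Y)):
        kv = k.ravel()
        scores[name] = abs(f @ kv) / (np.linalg.norm(f) * np.linalg.norm(kv) + 1e-12)
    return scores

# Demo on hand-made stand-ins for learned weights
edge_like = -0.3 * SOBEL_X + 0.01   # scaled, sign-flipped Sobel plus a small offset
blob_like = np.ones((3, 3))         # uniform blur: no edge structure

print(sobel_similarity(edge_like))
print(sobel_similarity(blob_like))
```

Learned filters rarely match Sobel exactly, so treat this as a rough heuristic: a score far above what a uniform filter gets suggests edge-like structure.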

---

In [None]:
# Extract learned filters from first convolutional layer
conv1_weights = model.get_layer('conv1').get_weights()[0]  # Shape: (3, 3, 1, 32)

print(f"First convolutional layer learned {conv1_weights.shape[-1]} filters")
print(f"Each filter is {conv1_weights.shape[0]}×{conv1_weights.shape[1]} pixels\n")

# Visualize all 32 filters
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
axes = axes.flatten()

for i in range(32):
    filter_img = conv1_weights[:, :, 0, i]  # Extract filter i
    
    # Symmetric color limits so a zero weight maps to white in the diverging
    # RdBu colormap (red = negative weights, blue = positive weights)
    vmax = np.abs(filter_img).max()
    
    axes[i].imshow(filter_img, cmap='RdBu', vmin=-vmax, vmax=vmax)
    axes[i].set_title(f'Filter {i+1}', fontsize=9)
    axes[i].axis('off')

plt.tight_layout()
plt.suptitle('Learned Filters from First Convolutional Layer (32 filters, 3×3 each)', 
             fontsize=14, fontweight='bold', y=1.01)
plt.show()

print("\n" + "="*70)
print("WHAT TO LOOK FOR:")
print("="*70)
print("\n- Edge detectors (similar to Sobel filters from Module 1!)")
print("- Diagonal patterns")
print("- Corner detectors")
print("- Gradient patterns\n")
print("Notice: The network learned these filters automatically through")
print("gradient descent, with no human design required!")
print("="*70)

---

## 📝 **Question Q23 (Observation)**

### **Q23. Examine the learned filters from the first convolutional layer. Do they look like edge detectors (similar to Sobel filters from Module 1)?**

*Hint: Compare these learned filters to the Sobel vertical and horizontal edge detectors you saw in Module 1. What patterns do you recognize?*

**Record your answer in the Answer Sheet.**

---

## 🔗 **Connection to Lab 4: Dense Networks vs. CNNs**

Let's compare what you did in Lab 4 to what you just did now!

---

## 📊 **Lab 4 vs. Lab 7 Comparison**

| Aspect | Lab 4 (Breast Cancer) | Lab 7 (MNIST) |
|--------|-----------------------|---------------|
| **Data type** | Tabular (30 features) | Images (28×28 pixels) |
| **Input shape** | (30,) | (28, 28, 1) |
| **Architecture** | Dense → Dense → Output | Conv → Pool → Conv → Pool → Dense → Output |
| **Parameters** | ~4,000 | ~225,000 |
| **Training time** | <1 minute | ~2-3 minutes |
| **Training method** | Gradient descent | Same! |
| **Loss function** | Binary cross-entropy | Sparse categorical cross-entropy |
| **Optimizer** | Adam | Adam |
| **Visualization** | ROC curve | Confusion matrix |
| **Key concept** | Hidden layers learn representations | Conv layers learn spatial features |

### **Key Similarities:**
- Both use gradient descent + backpropagation
- Both use Adam optimizer
- Both use cross-entropy loss
- Both achieved high accuracy (>95%)

### **Key Differences:**
- **Data structure:** Tabular vs. spatial (images)
- **Architecture:** Fully-connected vs. convolutional
- **Parameters:** A dense layer needs a separate weight for every input pixel, while a conv filter reuses the same few weights at every position (parameter sharing)
- **Inductive bias:** CNNs assume local spatial patterns matter
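The parameter difference is easy to quantify. A rough comparison, assuming a first layer with 32 outputs in both cases: a dense layer on the raw 28×28 image needs one weight per (pixel, unit) pair, while a `Conv2D` layer with 32 filters of size 3×3 reuses each filter's 9 weights at every image position.

```python
# Parameters in a first layer with 32 outputs, applied to a 28x28x1 image
pixels = 28 * 28 * 1

dense_32 = pixels * 32 + 32       # one weight per (pixel, unit) pair + 32 biases
conv_32 = (3 * 3 * 1) * 32 + 32   # one 3x3 kernel per filter + 32 biases, shared across positions

print(f"Dense(32) on raw pixels: {dense_32:,} parameters")   # 25,120
print(f"Conv2D(32, 3x3): {conv_32:,} parameters")            # 320
print(f"Ratio: {dense_32 / conv_32:.1f}x")                   # 78.5x
```

The conv layer's output is much larger (26×26×32 activations versus 32 numbers), so the saving is in stored weights, not in computation.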

### **When to Use Each:**
- **Dense networks (Lab 4):** Tabular data, sensor readings, feature vectors
- **CNNs (Lab 7):** Images, videos, spatial data, anything with local patterns

---

## 📝 **Question Q24 (Synthesis)**

### **Q24. Compare this CNN training to the Breast Cancer classifier from Lab 4. What's similar? What's different?**

*Think about: training process, architecture, data type, accuracy, time to train, etc.*

**Record your answer in the Answer Sheet.**

---

## 🔍 **Analyze Misclassified Examples**

Let's look at examples the model got wrong. Sometimes these are genuinely ambiguous!
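The cell that follows shows the first 12 mistakes in test-set order. A variation worth trying is to sort mistakes by the model's confidence, since confidently wrong predictions are often the most revealing. The helper `most_confident_mistakes` is a hypothetical name introduced for this sketch, which runs on made-up probabilities; with the real data you could pass `y_pred_probs` and `y_test`.

```python
import numpy as np

def most_confident_mistakes(probs, y_true, k=5):
    """Hypothetical helper: indices of misclassified examples, highest predicted-class probability first."""
    y_pred = np.argmax(probs, axis=1)
    wrong = np.where(y_pred != y_true)[0]
    confidence = probs[wrong, y_pred[wrong]]
    order = np.argsort(-confidence)   # descending confidence
    return wrong[order][:k], confidence[order][:k]

# Made-up predictions for 4 examples over 3 classes
probs = np.array([[0.90, 0.05, 0.05],   # predicted 0
                  [0.20, 0.70, 0.10],   # predicted 1
                  [0.10, 0.30, 0.60],   # predicted 2
                  [0.05, 0.90, 0.05]])  # predicted 1
y_true = np.array([0, 2, 2, 0])         # examples 1 and 3 are wrong

idx, conf = most_confident_mistakes(probs, y_true, k=5)
print(idx, conf)
```

Example 3 comes first because the model gave its wrong answer 90% probability, versus 70% for example 1.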

---

In [None]:
# Find misclassified examples
incorrect_indices = np.where(y_pred != y_test)[0]

print(f"Found {len(incorrect_indices)} misclassified examples out of {len(y_test)}")
print(f"Error rate: {len(incorrect_indices)/len(y_test)*100:.2f}%\n")

# Show first 12 misclassified examples
fig, axes = plt.subplots(3, 4, figsize=(14, 10))
axes = axes.flatten()

for i in range(min(12, len(incorrect_indices))):
    idx = incorrect_indices[i]
    
    # Get image, true label, predicted label
    img = x_test[idx].reshape(28, 28)
    true_label = y_test[idx]
    pred_label = y_pred[idx]
    confidence = y_pred_probs[idx][pred_label] * 100
    
    # Display
    axes[i].imshow(img, cmap='gray')
    axes[i].set_title(f'True: {true_label}, Predicted: {pred_label}\nConfidence: {confidence:.1f}%',
                     fontsize=10, color='red', fontweight='bold')
    axes[i].axis('off')

plt.tight_layout()
plt.suptitle('Misclassified Examples (What the Model Got Wrong)', 
             fontsize=14, fontweight='bold', y=1.01)
plt.show()

print("\n💡 OBSERVATIONS:")
print("  - Some examples are genuinely ambiguous (even humans might struggle!)")
print("  - Poor handwriting quality")
print("  - Unusual writing styles")
print("  - Digits that look similar (4 vs. 9, 3 vs. 8, etc.)")

---

## 📝 **Question Q25 (Critical Thinking)**

### **Q25. Find 2-3 misclassified examples above. Can you understand why the model got them wrong? Are they ambiguous even to you?**

*Look at the images carefully. Record:*
1. True label and predicted label
2. Why you think the model made the mistake
3. Whether you would have classified it correctly

**Record your answer in the Answer Sheet.**

---

## 🌉 **Bridge to Lab 8: From Analysis to Synthesis**

### **What You've Learned in Lab 7:**
- CNNs **extract features** from images (edges → textures → shapes → objects)
- Hierarchical architecture builds complex patterns from simple ones
- Learned filters emerge automatically through training
- **Analysis task:** Given an image, identify what's in it
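"Hierarchical" can be made precise with receptive fields: the patch of input pixels that can influence one unit. This sketch applies the standard recurrence (the receptive field grows by (kernel − 1) × current jump, and the jump multiplies by the stride) to this module's MNIST CNN, assuming stride 1 for the convolutions and stride 2 for the 2×2 pools, as in the Keras defaults used there.

```python
# Receptive-field growth through the MNIST CNN from this module
layers_spec = [            # (name, kernel size, stride)
    ("conv1", 3, 1),
    ("pool1", 2, 2),
    ("conv2", 3, 1),
    ("pool2", 2, 2),
]

rf, jump = 1, 1            # an input pixel sees only itself; neighboring units are 1 pixel apart
receptive_fields = {}
for name, kernel, stride in layers_spec:
    rf = rf + (kernel - 1) * jump   # widen by the kernel's extra span, measured in input pixels
    jump = jump * stride            # spacing between neighboring units, in input pixels
    receptive_fields[name] = rf
    print(f"{name}: each unit sees a {rf}x{rf} pixel patch")
```

So a `conv1` filter sees only a 3×3 patch (about one stroke edge), while each `conv2` unit combines an 8×8 region, large enough to cover parts of a digit such as a corner or part of a loop.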

### **What's Coming in Lab 8: Diffusion Models**
- **Synthesis task:** Given a description (or noise), **generate** an image
- Diffusion models **create images from noise**
- Similar architecture (U-Net uses convolution!), opposite direction
- Powers image generators such as Stable Diffusion, DALL-E 2 and 3, and Midjourney

### **The Connection:**
```
Lab 7 (CNNs): Image → Features → Classification
              (Analysis: "What is this?")

Lab 8 (Diffusion): Noise → Features → Image
                   (Synthesis: "Create this!")
```

**Both use convolutional architectures!**
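- CNN classifiers: Downsampling (image → features)
- U-Net: Downsampling encoder + upsampling decoder, so the output has the same size as the input
- Diffusion: Runs a U-Net repeatedly, removing a little noise at each step (noise → image)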

### **Multimodal AI:**
- **Lab 5:** Text embeddings (sentences ‚Üí vectors)
- **Lab 7:** Image embeddings (images ‚Üí vectors)
- **Lab 8:** Text → Image generation ("A dog wearing a hat" → actual image!)

**You're building toward understanding modern AI systems like GPT-4 with vision, DALL-E, and beyond!**

---

## ✅ Module 4 Complete!

You now understand:
- **How to build and train a CNN from scratch**
- **What CNNs learn** (automatic feature extraction via gradient descent)
- **How to evaluate CNN performance** (accuracy, confusion matrix)
- **What learned filters look like** (edge detectors, patterns)
- **Connection to Lab 4** (same training process, different architecture)
- **When CNNs fail** (ambiguous examples, similar-looking digits)

**Key insight:**
> CNNs learn hierarchical feature extractors automatically, with no human filter design required. The same training process from Lab 4 (gradient descent) produces sophisticated pattern detectors for images!

**Congratulations!** You've completed all 5 modules of Lab 7!

---

## 📚 **Lab 7 Complete: Review**

### **What You've Accomplished:**

**Module 0:** Learned what convolution is (sliding window + multiply-and-add)

**Module 1:** Applied classic filters to real images (Sobel, blur, sharpen)

**Module 2:** Visualized feature maps from pretrained CNN (MobileNetV2)

**Module 3:** Understood hierarchical feature extraction principles

**Module 4:** Trained your own CNN from scratch on MNIST (this module!)

### **Core Concepts Mastered:**
- Convolution operation
- Filters and feature maps
- Hierarchical learning (edges → textures → shapes → objects)
- Parameter sharing and translation equivariance (plus some translation invariance from pooling)
- CNN training pipeline
- Model evaluation and error analysis

### **Connections Made:**
- Lab 3: Activation functions (ReLU after convolution)
- Lab 4: Hidden layers (conv layers are spatially-structured hidden layers)
- Lab 5: Embeddings (CNN final layers create image embeddings)
- Lab 6: Saliency (feature maps show WHAT, saliency shows WHERE)

### **What's Next:**
Lab 8 will explore **diffusion models**, the technology behind DALL-E, Stable Diffusion, and Midjourney. You'll learn how to **generate images from text descriptions**, completing your understanding of multimodal AI!

---

**🎉 Great work!**

---