# MNIST Digit Classification with SGD

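This notebook demonstrates:
1. Loading the MNIST dataset from OpenML with scikit-learn's `fetch_openml`
2. Splitting the data into training, development, and test sets
3. Training an SGD classifier and estimating its accuracy with cross-validation
4. Generating confusion matrices and a classification report for error analysis
5. Evaluating the final model on the held-out test set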

## 1. Imports

In [None]:
# Data loading and manipulation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Sklearn imports
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict, train_test_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

# Seed NumPy's legacy global RNG once, up front, so NumPy-based randomness is
# reproducible. (The sklearn splitter and estimator below also receive
# random_state=42 explicitly.)
SEED = 42
np.random.seed(SEED)

print("All imports successful!")

## 2. Load MNIST Dataset

In [None]:
# Fetch MNIST from OpenML (the first call downloads it and may take a moment).
print("Loading MNIST dataset...")
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')

# Features live under "data" and labels under "target"; the labels are cast to
# compact unsigned 8-bit integers for downstream use.
X = mnist["data"]
y = mnist["target"].astype(np.uint8)

# Quick sanity summary of what was loaded.
print(f"Dataset shape: {X.shape}")
print(f"Labels shape: {y.shape}")
print(f"Feature range: {X.min()} to {X.max()}")
print(f"Unique labels: {np.unique(y)}")

In [None]:
# Preview the first ten samples: each row of X is a flattened 28x28 image.
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for ax, pixels, label in zip(axes.flat, X, y):
    ax.imshow(pixels.reshape(28, 28), cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')

fig.suptitle('Sample MNIST Digits', fontsize=16)
fig.tight_layout()
plt.show()

In [None]:
# Train/dev/test split.
# - Test: by MNIST convention, the last 10,000 samples are held out.
# - Train/dev: the first 60,000 samples are split 80%/20%, stratified by label
#   so both parts keep the same digit proportions.
TEST_START = 60_000

X_temp, X_test = X[:TEST_START], X[TEST_START:]
y_temp, y_test = y[:TEST_START], y[TEST_START:]

X_train, X_dev, y_train, y_dev = train_test_split(
    X_temp, y_temp, test_size=0.2, random_state=42, stratify=y_temp)

# Sanity checks: the three parts partition the full dataset, features/labels aligned.
assert len(X_train) + len(X_dev) + len(X_test) == len(X)
assert len(y_train) == len(X_train)
assert len(y_dev) == len(X_dev)
assert len(y_test) == len(X_test)

print(f"Training set: {X_train.shape} features, {y_train.shape} labels")
print(f"Development set: {X_dev.shape} features, {y_dev.shape} labels")
print(f"Test set: {X_test.shape} features, {y_test.shape} labels")

print("\nSplit percentages:")
total = len(X)
for split_name, split_X in [("Train", X_train), ("Dev", X_dev), ("Test", X_test)]:
    print(f"{split_name}: {len(split_X)/total*100:.1f}%")


def print_class_distribution(labels: np.ndarray, title: str) -> None:
    """Print the count and percentage of each distinct label in ``labels``.

    Args:
        labels: 1-D array of class labels for one data split.
        title: Name of the split, used in the heading line.
    """
    digits, counts = np.unique(labels, return_counts=True)
    print(f"\n{title} class distribution:")
    for digit, count in zip(digits, counts):
        print(f"Digit {digit}: {count} samples ({count/len(labels)*100:.1f}%)")


# Compare train vs dev: with stratify=y_temp the two distributions should match closely.
print_class_distribution(y_train, "Training set")
print_class_distribution(y_dev, "Development set")

## 3. Train SGD Classifier

In [None]:
# Fit a linear model trained by stochastic gradient descent on the training split.
# NOTE(review): X_train holds raw, unscaled pixel values (no scaling step is
# applied anywhere in this notebook); standardizing the features is worth trying.
print("Training SGD Classifier...")
sgd_clf = SGDClassifier(
    random_state=42,  # fixed seed so repeated runs give the same model
    max_iter=1000,
    tol=1e-3,
)
sgd_clf.fit(X_train, y_train)

print("Training completed!")
print(f"Model classes: {sgd_clf.classes_}")

In [None]:
# Estimate accuracy with 3-fold cross-validation on the training split only,
# leaving the dev and test sets untouched.
print("Performing 3-fold cross-validation...")
cv_scores = cross_val_score(
    sgd_clf,
    X_train,
    y_train,
    scoring="accuracy",
    cv=3,
)

# Per-fold accuracies, then the mean with a spread of two standard deviations.
print(f"Cross-validation scores: {cv_scores}")
print(f"Mean CV accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

## 4. Generate Confusion Matrices

In [None]:
# Out-of-fold predictions: each training sample is predicted by a model fitted
# on the other folds, so the confusion matrix reflects performance on unseen
# data rather than on samples the model was trained on.
print("Getting cross-validated predictions...")
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train, cv=3)

# Rows index the true digit, columns the predicted digit.
conf_mx = confusion_matrix(y_true=y_train, y_pred=y_train_pred)
print("Confusion matrix shape:", conf_mx.shape)
print("\nConfusion Matrix:")
print(conf_mx)

In [None]:
# Raw-count heatmap drawn on an explicit Axes, with every cell annotated.
fig, ax = plt.subplots(figsize=(10, 8))
heatmap = ax.imshow(conf_mx, interpolation='nearest', cmap='Blues')
ax.set_title('Confusion Matrix - MNIST SGD Classifier', fontsize=16)
fig.colorbar(heatmap, ax=ax)

digit_ticks = np.arange(10)
ax.set_xticks(digit_ticks)
ax.set_xticklabels(range(10))
ax.set_yticks(digit_ticks)
ax.set_yticklabels(range(10))
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)

# Use white text on dark cells (above half the largest count) so it stays legible.
half_max = conf_mx.max() / 2.
for true_idx in range(10):
    for pred_idx in range(10):
        count = conf_mx[true_idx, pred_idx]
        ax.text(pred_idx, true_idx, format(count, 'd'),
                ha="center", va="center",
                color="white" if count > half_max else "black")

fig.tight_layout()
plt.show()

In [None]:
# Alternative: sklearn's ConfusionMatrixDisplay renders the same heatmap with less code.
# Create the figure/Axes first and pass `ax=` explicitly: ConfusionMatrixDisplay.plot()
# draws on a brand-new figure when no Axes is given, which would leave an empty
# 10x8 figure behind and render the matrix at the default size.
fig, ax = plt.subplots(figsize=(10, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=conf_mx, display_labels=range(10))
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title('Confusion Matrix - MNIST SGD Classifier (sklearn version)', fontsize=14)
plt.show()

In [None]:
# Where does the model go wrong? Divide each row by the number of samples of
# that true digit so cells hold rates rather than counts, then zero the
# diagonal (correct predictions) so only the confusions remain visible.
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
np.fill_diagonal(norm_conf_mx, 0)

fig, ax = plt.subplots(figsize=(10, 8))
heatmap = ax.imshow(norm_conf_mx, interpolation='nearest', cmap='Reds')
ax.set_title('Confusion Matrix - Error Rates Only (Normalized)', fontsize=16)
fig.colorbar(heatmap, ax=ax)
ax.set_xticks(range(10))
ax.set_yticks(range(10))
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
plt.show()

print("\nMost common misclassifications:")
# Flat indices of the five largest entries (diagonal already zeroed), largest first.
top5_flat = np.argsort(norm_conf_mx.ravel())[-5:][::-1]
top5_true, top5_pred = np.unravel_index(top5_flat, norm_conf_mx.shape)
for true_digit, pred_digit in zip(top5_true, top5_pred):
    error_rate = norm_conf_mx[true_digit, pred_digit]
    if error_rate > 0:  # Only show actual errors
        print(f"Digit {true_digit} classified as {pred_digit}: {error_rate:.3f} ({error_rate*100:.1f}%)")

In [None]:
# Per-class metrics for the cross-validated training predictions (y_train_pred).
print("Classification Report:")
report = classification_report(y_true=y_train, y_pred=y_train_pred)
print(report)

## 5. Test Set Evaluation

In [None]:
# Final, one-time evaluation on the held-out test set (the last 10,000 samples).
# NOTE(review): X_dev / y_dev are created in the split cell but never evaluated in
# this notebook; report dev accuracy (and tune against it) before using the test set.
print("Evaluating on test set...")
y_test_pred = sgd_clf.predict(X_test)
# Accuracy computed from the predictions above (equivalent to sgd_clf.score), so the
# test set is only run through the model once.
test_score = np.mean(y_test_pred == y_test)
print(f"Test set accuracy: {test_score:.4f}")

test_conf_mx = confusion_matrix(y_test, y_test_pred)

# Pass `ax=` explicitly: ConfusionMatrixDisplay.plot() otherwise draws on a new
# default-sized figure and the requested figsize is lost on an empty figure.
fig, ax = plt.subplots(figsize=(10, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=test_conf_mx, display_labels=range(10))
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title('Test Set Confusion Matrix - MNIST SGD Classifier', fontsize=14)
plt.show()