# Handwritten Digit Classification with scikit-learn

This short notebook shows how to train a simple neural network on the built-in scikit-learn digits dataset.

In [None]:
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

# Load the digits dataset (8x8 grayscale images of handwritten digits)
digits = load_digits()
X, y = digits.data, digits.target

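# Convert labels to plain Python ints so they print as 0-9 rather than NumPy scalar reprs
print(f'Dataset shape: {X.shape}, Labels: {sorted(set(y.tolist()))}')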

The dataset contains 1,797 8x8 grayscale images of handwritten digits (0-9). Each image is flattened into a vector of 64 features, with pixel intensities ranging from 0 to 16.
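As a quick check of the flattening described above, the following sketch compares `digits.images` with `digits.data`. It relies only on the `load_digits` loader already used in this notebook; the variable name `flattened` is introduced here purely for illustration.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

# digits.images holds the 8x8 arrays; digits.data holds the same pixels as rows of 64 values
print(digits.images.shape, digits.data.shape)

# Flatten each image row by row and compare with the feature matrix
flattened = digits.images.reshape(len(digits.images), -1)
print(np.array_equal(flattened, digits.data))

# Smallest and largest pixel intensity in the dataset
print(digits.data.min(), digits.data.max())
```

Because the two arrays hold the same values, a row of `X` can be turned back into an image with `reshape(8, 8)` whenever it needs to be displayed.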

In [None]:
# Show a grid of sample digit images
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for ax, image, label in zip(axes.ravel(), digits.images[:10], y[:10]):
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')
plt.tight_layout();

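Hold out 20% of the data as a test set so the model is trained on one portion and evaluated on unseen images. The split is stratified by label, so each digit appears in roughly the same proportion in both sets.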

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
print(f'Training samples: {X_train.shape[0]}, Test samples: {X_test.shape[0]}')
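One way to see what `stratify=y` does is to compare the class proportions in the two splits. The sketch below repeats the same split and counts labels with `np.bincount`; the names `train_frac`, `test_frac`, and `max_gap` are illustrative and not part of the original notebook.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Fraction of each digit class (0-9) in the two splits
train_frac = np.bincount(y_train, minlength=10) / len(y_train)
test_frac = np.bincount(y_test, minlength=10) / len(y_test)

for digit, (tr, te) in enumerate(zip(train_frac, test_frac)):
    print(f'Digit {digit}: train {tr:.3f}, test {te:.3f}')

# Largest gap between train and test class fractions
max_gap = np.abs(train_frac - test_frac).max()
print(f'Largest difference: {max_gap:.4f}')
```

With stratification the two columns should be nearly identical; without it, the gaps depend on the random shuffle and can be larger.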

Train a small neural network with one hidden layer of 32 units using scikit-learn's `MLPClassifier`.

In [None]:
# Define and train the model
mlp = MLPClassifier(hidden_layer_sizes=(32,), max_iter=300, random_state=42)
mlp.fit(X_train, y_train)

# Evaluate accuracy on the test set
y_pred = mlp.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test accuracy: {accuracy:.3f}')
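A single accuracy number hides which digits get confused with which. As a possible follow-up, the sketch below rebuilds the same model and summarizes its test predictions with `confusion_matrix` from `sklearn.metrics`. The per-digit recall calculation and the names `cm` and `per_class_recall` are additions for illustration, not part of the original notebook.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Same model configuration as in the cell above
mlp = MLPClassifier(hidden_layer_sizes=(32,), max_iter=300, random_state=42)
mlp.fit(X_train, y_train)
y_pred = mlp.predict(X_test)

# Rows are true digits, columns are predicted digits
cm = confusion_matrix(y_test, y_pred, labels=np.arange(10))
print(cm)

# Per-digit recall: correct predictions divided by the number of true examples
per_class_recall = cm.diagonal() / cm.sum(axis=1)
for digit, recall in enumerate(per_class_recall):
    print(f'Digit {digit}: recall {recall:.3f}')

# Overall accuracy equals the diagonal sum over the total count
print(f'Accuracy from matrix: {cm.trace() / cm.sum():.3f}')
```

Off-diagonal entries show specific confusions, for example how often one digit is predicted as another.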

Look at a few test images alongside the model's predictions.

In [None]:
# Display several test images with their true and predicted labels
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for ax, image, true_label, pred_label in zip(axes.ravel(), X_test[:10], y_test[:10], y_pred[:10]):
    ax.imshow(image.reshape(8, 8), cmap='gray')
    ax.set_title(f'True: {true_label}\nPred: {pred_label}')
    ax.axis('off')
plt.tight_layout();
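The first ten test images may all be classified correctly, so it can be more informative to look only at the errors. The sketch below, which retrains the same model so that it runs on its own, locates the misclassified test samples with `np.flatnonzero`; the name `wrong_idx` is illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

mlp = MLPClassifier(hidden_layer_sizes=(32,), max_iter=300, random_state=42)
mlp.fit(X_train, y_train)
y_pred = mlp.predict(X_test)

# Positions in the test set where the prediction and the true label disagree
wrong_idx = np.flatnonzero(y_pred != y_test)
print(f'Misclassified: {len(wrong_idx)} of {len(y_test)}')

# List up to ten of the errors
for i in wrong_idx[:10]:
    print(f'Test index {i}: true {y_test[i]}, predicted {y_pred[i]}')
```

The rows `X_test[wrong_idx]` can then be reshaped to 8x8 and passed to the same plotting loop used above to view the misclassified digits.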