In [None]:
# Digit Classifier using USPS dataset and Skorch Library
# Author : Kopal Rastogi
# Created : Mar-23-2024
# Keywords : CNN, Skorch, USPS dataset, Early stopping, Dropout
# Assumptions : None

In [None]:
# Installing Skorch library. Required only once. Uncomment and run if library is not installed.
!pip install skorch

In [None]:
print('Importing Libraries... ',end='')
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from skorch import NeuralNetClassifier
from skorch.callbacks import EarlyStopping
from skorch.dataset import Dataset
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
print('Done')

In [None]:
# Loading the data
# end='' keeps 'Done' on the same line (same pattern as the import cell);
# flush=True makes the message visible before the blocking download starts.
print('Loading data... ', end='', flush=True)
# as_frame=True makes the return type explicit (pandas DataFrame / Series)
# instead of relying on the library default.
X, y = fetch_openml('usps', return_X_y=True, as_frame=True)
print('Done')

# Get dataset statistics
print('Dataset statistics... ')
print(X.shape, y.shape)

In [None]:
# Preprocessing
# Reshape the flat 256-value rows into single-channel 16x16 images for the CNN.
# np.asarray accepts both a DataFrame and an already converted array, so this
# cell can be re-run without failing.
X = np.asarray(X, dtype=np.float32).reshape(-1, 1, 16, 16)

# Scale the input to the [0, 1] range using the observed min/max rather than a
# hard-coded divisor of 16.0, which is only correct if pixel values span 0..16.
# NOTE(review): the OpenML copy of USPS appears to be pre-scaled to [-1, 1] --
# confirm; min-max scaling yields [0, 1] either way and is a no-op on re-run.
x_min, x_max = X.min(), X.max()
if x_max > x_min:  # guard against division by zero on constant input
    X = (X - x_min) / (x_max - x_min)

# Class targets must be 64-bit integers for PyTorch; 'int' is platform-dependent.
y = y.astype(np.int64)
# Shift labels so the smallest one is 0 (the previous code subtracted 1, i.e. it
# assumed labels start at 1). Unlike a fixed "-1", this is safe to re-run.
y = y - y.min()

In [None]:
# Show a handful of randomly chosen digits, each titled with its label
num_samples = 5
random_indices = random.sample(range(len(X)), num_samples)

fig, axes = plt.subplots(1, num_samples, figsize=(12, 3))
for ax, sample_idx in zip(axes, random_indices):
    # Each sample is stored as (1, 16, 16); drop the channel axis for imshow.
    ax.imshow(X[sample_idx].reshape(16, 16), cmap='gray')
    ax.set_title(f'Label: {y[sample_idx]}')
    ax.axis('off')

plt.show()

In [None]:
# Split train-test data in 70:30
# stratify=y keeps the per-digit class proportions the same in both partitions,
# so the test metrics are not skewed by an unlucky draw of rare digits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=11, stratify=y
)

In [None]:
# Define CNN model
class DigitClassifier(nn.Module):
    """Small CNN that maps 1x16x16 images to per-class scores.

    Three 3x3 convolutions (32, 64 and 128 output channels, padding=1 so the
    spatial size is preserved) with 2x2 max-pooling after the first two,
    followed by a 256-unit hidden layer with dropout and a linear output layer.

    Args:
        num_classes: Number of output units. Defaults to 10.
        dropout_rate: Drop probability applied after the hidden layer.
            Defaults to 0.2.
    """

    def __init__(self, num_classes=10, dropout_rate=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Two 2x2 poolings reduce 16x16 inputs to 4x4 feature maps.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Compute raw (un-normalised) class scores.

        Args:
            x: Float tensor of shape (batch, 1, 16, 16).

        Returns:
            Tensor of shape (batch, num_classes).

        Raises:
            RuntimeError: If the flattened feature size does not match the
                first fully connected layer (e.g. input is not 16x16).
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        # Flatten per sample. Using view(-1, 128*4*4) would silently fold extra
        # spatial positions into the batch dimension for wrongly sized inputs;
        # keeping the batch axis fixed makes such inputs fail at fc1 instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Hyperparameters
max_epochs = 25   # upper bound on training epochs
lr = 0.001        # optimizer learning rate
batch_size = 32   # mini-batch size for the train and validation iterators
patience = 5      # early-stopping patience, in epochs
seed = 11         # RNG seed for PyTorch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's global RNG so that runs started from a fresh kernel are
# repeatable. Some GPU kernels may still be non-deterministic.
torch.manual_seed(seed)

In [None]:
# Wrap the model in Skorch NeuralNetClassifier
digit_classifier = NeuralNetClassifier(
    module = DigitClassifier,
    max_epochs = max_epochs,
    lr = lr,
    iterator_train__batch_size = batch_size,
    iterator_train__shuffle = True,
    iterator_valid__batch_size = batch_size,
    iterator_valid__shuffle = False,
    criterion = nn.CrossEntropyLoss,
    optimizer = torch.optim.Adam,
    # Use the `patience` hyperparameter instead of a second hard-coded 5, so
    # changing it in the hyperparameter cell actually takes effect.
    callbacks = [EarlyStopping(patience=patience)],
    device = device
)

In [None]:
# Train the model
print('Using...', device)
print("Training started...")
digit_classifier.fit(X_train, y_train)
print("Training completed!")

# Evaluate on test data
# Predict once and derive the accuracy from those predictions; calling
# .score(X_test, y_test) as well would repeat the full inference pass.
y_pred = digit_classifier.predict(X_test)
accuracy = float(np.mean(y_pred == np.asarray(y_test)))
print(f'Test accuracy: {accuracy:.4f}')

In [None]:
# Predict labels for the test set; y_pred feeds the confusion matrix and F1 cells.
# NOTE(review): the training/evaluation cell already assigns y_pred with this
# same call -- looks redundant unless this cell is meant to refresh the
# predictions on its own; confirm before removing.
y_pred = digit_classifier.predict(X_test)

In [None]:
# Plot training loss and validation accuracy per epoch
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: training loss recorded in the skorch history.
loss_ax.plot(digit_classifier.history[:, 'train_loss'], label='Train Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.legend()

# Right panel: this curve is the history's 'valid_acc' column, so the legend
# must say validation (it was previously mislabelled 'Train Accuracy').
acc_ax.plot(digit_classifier.history[:, 'valid_acc'], label='Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Validation Accuracy')
acc_ax.legend()
plt.show()

In [None]:
from sklearn.metrics import confusion_matrix, f1_score
import seaborn as sns

# Confusion matrix of true vs. predicted test labels, drawn as an annotated heatmap
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(10, 8))
heatmap_ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
heatmap_ax.set(
    title="Confusion Matrix",
    xlabel="Predicted Label",
    ylabel="True Label",
)
plt.show()

In [None]:
# F1 score
# Weighted-average F1 computed from the test labels and the predictions in y_pred.
# f1_score is imported in the confusion-matrix cell, which must run before this one.
f1 = f1_score(y_test, y_pred, average='weighted')
print(f"F1 Score: {f1:.4f}")