In [1]:
# -----------------------------
# MNIST DIGIT CLASSIFIER (PyTorch)
# -----------------------------

import torch
import torch.nn as nn #neural network superclass
from torchvision import datasets, transforms #gives ready to use datasets and preprocessing tools
import torch.optim as optim #imports optimization algorithms from torch
from torch.utils.data import DataLoader #streamlines process of loading data

In [2]:
# -----------------------------
# 1. LOAD DATA
# A transform is a preprocessing recipe that is applied automatically to every
# image loaded from a dataset: "each time you hand me an image, do X, then Y,
# then Z to it."
#
# MNIST images arrive as PIL images (Python Imaging Library), but the neural
# network expects tensors, so the only step in this recipe is ToTensor().
# -----------------------------
to_tensor_step = transforms.ToTensor()  # PIL image -> PyTorch tensor
transform = transforms.Compose([to_tensor_step])

In [3]:
# Training split of MNIST: stored under ./data (downloaded if requested files
# are missing) and passed through `transform` on every access.
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)


In [4]:
# Held-out test split of MNIST: same storage folder and same preprocessing as
# the training split, selected with train=False.
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)


In [5]:
# Make DataLoaders (create loaders that give batches of data).
# The training loader reshuffles the samples; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # increasing batch size speeds up training process
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Access and print the unique labels in the training data set using the train_loader object.
labels = set()  # a set stores unique items and drops duplicates

for _, targets in train_loader:  # the image half of each batch is not needed here
    # tolist() converts the whole label tensor to plain Python ints in one call,
    # avoiding a Python-level int() conversion per element.
    labels.update(targets.tolist())

print(labels)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}


In [None]:
# -----------------------------
# 2. DEFINE NEURAL NETWORK
# TODO: Neural Network with 2 hidden layers of 128 neurons
# -----------------------------
class SimpleNN(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    The input is flattened to ``input_size`` features per sample, passed
    through two hidden layers of ``hidden_size`` units, and mapped to
    ``num_classes`` raw scores (logits; no softmax is applied).

    Args:
        input_size: Number of features per sample after flattening
            (default 784 = 28 * 28).
        hidden_size: Number of units in each of the two hidden layers
            (default 128).
        num_classes: Number of output scores (default 10, digits 0-9).
    """

    def __init__(self, input_size: int = 28 * 28, hidden_size: int = 128, num_classes: int = 10):
        super().__init__()
        # Remembered so forward() flattens to the width fc1 expects.
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, hidden_size)   # 1st hidden layer
        self.fc2 = nn.Linear(hidden_size, hidden_size)  # 2nd hidden layer
        self.fc3 = nn.Linear(hidden_size, num_classes)  # output layer

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a batch of inputs.

        With the default sizes, an input of shape (batch, 1, 28, 28) is
        flattened to (batch, 784).
        """
        # reshape (unlike view) also accepts non-contiguous tensors, copying
        # only when a plain view is impossible.
        x = x.reshape(-1, self.input_size)

        # ReLU passes positive values through and zeroes out negatives.
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        # Output layer: raw scores, one per class.
        return self.fc3(x)


In [7]:
# Create the model.
# Seed PyTorch's global random number generator first so the random weight
# initialisation (and any later draws from the same global generator) repeat
# from run to run after a kernel restart.
torch.manual_seed(42)

model = SimpleNN()

In [8]:
# -----------------------------
# 3. LOSS FUNCTION + OPTIMIZER
# -----------------------------
# Loss function: cross-entropy, the usual loss for classification tasks.
criterion = nn.CrossEntropyLoss()

# Gradient descent setup: stochastic gradient descent with momentum.
#   lr       -> size of each update step (try different values)
#   momentum -> adds inertia, which smooths out oscillations
# An alternative worth trying is optim.Adam(model.parameters(), lr=0.001).
optimizer = optim.SGD(params=model.parameters(), lr=0.03, momentum=0.9)

In [9]:
# -----------------------------
# 4. TRAINING LOOP
# -----------------------------

# Number of full passes over the training data.
epochs = 50

for epoch in range(epochs):
    model.train()      # put the model in training mode
    total_loss = 0.0   # sum of the per-batch losses seen during this epoch

    for batch_images, batch_targets in train_loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss for this batch.
        batch_loss = criterion(model(batch_images), batch_targets)

        # Backpropagate to compute gradients, then let the optimizer
        # update the model's parameters from them.
        batch_loss.backward()
        optimizer.step()

        # .item() turns the one-element loss tensor into a plain Python number.
        total_loss += batch_loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


Epoch 1, Loss: 321.4670
Epoch 2, Loss: 109.2142
Epoch 3, Loss: 74.2942
Epoch 4, Loss: 53.7998
Epoch 5, Loss: 44.4435
Epoch 6, Loss: 35.0250
Epoch 7, Loss: 27.9654
Epoch 8, Loss: 23.1732
Epoch 9, Loss: 19.1767
Epoch 10, Loss: 16.5698
Epoch 11, Loss: 11.3132
Epoch 12, Loss: 7.5828
Epoch 13, Loss: 7.0057
Epoch 14, Loss: 5.8395
Epoch 15, Loss: 6.8501
Epoch 16, Loss: 2.3408
Epoch 17, Loss: 1.0345
Epoch 18, Loss: 0.5207
Epoch 19, Loss: 0.3677
Epoch 20, Loss: 0.2868
Epoch 21, Loss: 0.2483
Epoch 22, Loss: 0.2177
Epoch 23, Loss: 0.1953
Epoch 24, Loss: 0.1814
Epoch 25, Loss: 0.1656
Epoch 26, Loss: 0.1535
Epoch 27, Loss: 0.1451
Epoch 28, Loss: 0.1361
Epoch 29, Loss: 0.1296
Epoch 30, Loss: 0.1230
Epoch 31, Loss: 0.1160
Epoch 32, Loss: 0.1111
Epoch 33, Loss: 0.1065
Epoch 34, Loss: 0.1022
Epoch 35, Loss: 0.0980
Epoch 36, Loss: 0.0942
Epoch 37, Loss: 0.0906
Epoch 38, Loss: 0.0875
Epoch 39, Loss: 0.0842
Epoch 40, Loss: 0.0818
Epoch 41, Loss: 0.0790
Epoch 42, Loss: 0.0765
Epoch 43, Loss: 0.0743
Epoch 4

In [10]:
# -----------------------------
# 5. EVALUATION
# -----------------------------
correct = 0  # number of test images classified correctly
total = 0    # number of test images seen
model.eval()  # switch the model to evaluation mode

with torch.no_grad():  # gradients are not needed for evaluation, so skip tracking them
    for images, labels in test_loader:
        # Forward pass
        outputs = model(images)

        # Predicted class = index of the largest logit in each row.
        # argmax returns just the indices; going through `.data` is unnecessary
        # because autograd tracking is already disabled in this block.
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)  # counts how many images were tested
        correct += (predicted == labels).sum().item()  # number of correct predictions in batch

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 98.48%


In [11]:
import gradio as gr #interactive web-demo library

In [12]:
# 6. TEST SINGLE PREDICTION
# ------------------------------
# Gradio Sketchpad gives you:

# * a full-color NumPy array

# * black digit on white background

# * large resolution

# * no consistent scale
#
# Hence the preprocessing
# ------------------------------

def preprocess_image(image):
    """Convert a sketchpad drawing into a model-ready tensor.

    Args:
        image: Either a NumPy image array or a dict whose 'composite' entry
            holds that array (Gradio Sketchpad sometimes passes the dict form).

    Returns:
        A tensor with a leading batch dimension added, i.e. shape
        (1, 1, 28, 28) with values in [0, 1].
    """
    # Imported here so the function does not depend on an import cell that
    # appears later in the notebook having been run first.
    from PIL import ImageOps

    # Unwrap the dict form before anything else touches the image.
    if isinstance(image, dict):
        image = image['composite']   # this is a NumPy array

    sketch_transform = transforms.Compose([
        transforms.ToPILImage(),              # NumPy -> PIL
        transforms.Grayscale(),               # ensure 1 channel
        transforms.Resize((28, 28)),          # 28x28 like MNIST
        transforms.Lambda(ImageOps.invert),   # invert colors (dark-on-light -> light-on-dark)
        transforms.ToTensor(),                # -> tensor, shape (1, 28, 28), values in [0, 1]
    ])

    # Apply the preprocessing, then add the batch dimension -> (1, 1, 28, 28).
    return sketch_transform(image).unsqueeze(0)

def predict_digit(image):
    """Classify a drawn digit and return the predicted class as a string.

    Returns the prompt "Draw something!" when no image was supplied.
    """
    # Guard clause: nothing has been drawn yet.
    if image is None:
        return "Draw something!"

    # Turn the raw drawing into the tensor the model consumes.
    model_input = preprocess_image(image)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model(model_input)

    # The position of the highest score is the predicted digit.
    best_index = int(scores.argmax())
    return str(best_index)

In [13]:
from PIL import Image, ImageOps #Pillow library that allows you to open, manipulate, and process images

In [14]:
# Build the demo UI: a sketchpad (delivered to the callback as NumPy data)
# feeds predict_digit, and the result is shown in a label component.
sketch_input = gr.Sketchpad(type="numpy", label="Draw a digit")
prediction_output = gr.Label(num_top_classes=1)

interface = gr.Interface(
    fn=predict_digit,
    inputs=sketch_input,
    outputs=prediction_output,
    live=False,
)
interface.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


