In [None]:
# Importing the full fastai vision module for image-related functionality.
from fastai.vision.all import *

# Download the MNIST sample: a subset containing only the digits 3 and 7,
# convenient for binary-classification experiments.
path = untar_data(URLs.MNIST_SAMPLE)

# Build DataLoaders that take labels from the folder names under train/ and valid/.
# - Resize(224) is applied per item so every image has the same size.
# - Normalize.from_stats(*imagenet_stats) standardizes each input batch with the
#   ImageNet channel statistics that pretrained ImageNet weights expect. This is
#   input normalization, not a batch-norm layer. A bare Normalize() would instead
#   estimate mean/std from a single batch of this dataset, and its presence stops
#   vision_learner from adding the ImageNet normalization on its own.
dls = ImageDataLoaders.from_folder(
    path,
    train='train',
    valid='valid',
    item_tfms=Resize(224),
    batch_tfms=Normalize.from_stats(*imagenet_stats)
)

# Show a sample batch of training images as a quick check that loading,
# labelling and the transforms behave as expected.
dls.show_batch(max_n=9, figsize=(6,6))

In [None]:
# Build a ResNet-18 learner on the 3-vs-7 DataLoaders and report accuracy
# after each epoch. A small architecture is plenty for this two-class task.
learn = vision_learner(dls, arch=resnet18, metrics=accuracy)

# Fine-tune for 3 epochs; starting from pretrained weights lets the model
# converge quickly.
learn.fine_tune(epochs=3)

In [None]:
# Build an interpretation object from the trained learner; it is reused below
# for the confusion matrix and the top-loss plot.
interp = ClassificationInterpretation.from_learner(learn)

# Plot the confusion matrix to see how often each class (3 / 7) is predicted
# correctly versus mistaken for the other.
interp.plot_confusion_matrix()

In [None]:
# Plot the 5 highest-loss examples in a single row. Useful for spotting
# confusing inputs or mislabeled data.
interp.plot_top_losses(5, nrows=1)

In [None]:
from pathlib import Path

# Point the learner at the current working directory so that the exported
# file is written there.
learn.path = Path('.')

# Export the trained learner as a .pkl file; later cells load it for inference.
learn.export('mnist_sample_classifier.pkl')

# Report the absolute location of the exported file.
saved_to = Path('mnist_sample_classifier.pkl').resolve()
print("Model saved to:", saved_to)

In [None]:
# Reload the exported learner from the current directory for prediction, so
# inference no longer depends on the in-memory training objects.
# NOTE(review): the .pkl is unpickled on load; only load files you produced or trust.
learn_inf = load_learner('./mnist_sample_classifier.pkl')  # explicitly local

In [None]:
# Randomly selects a sample image of a '3' to test the model. Simulates real-world prediction pipeline.

import random

# Collect every PNG of a '3' in the validation split.
threes_dir = path/'valid'/'3'
valid_3s = list(threes_dir.glob('*.png'))

# Pick one file at random and open it as a fastai image.
img_path = random.choice(valid_3s)
img = PILImage.create(img_path)

# Predict with the reloaded learner: this exercises the exported .pkl
# independently of the training learner.
pred, pred_idx, probs = learn_inf.predict(img)
pred_confidence = probs[pred_idx]

# Report which file was used, the predicted label and its probability.
print(f"File: {img_path.name}")
print(f"Prediction: {pred} ({pred_confidence:.2%} confidence)")

In [None]:
# Constructs a Gradio app to interactively classify hand-drawn digits.
import gradio as gr
from fastai.vision.all import *
from PIL import Image, ImageOps
import numpy as np

# Load the exported classifier used by the Gradio callback below.
# Note: this rebinds the module-level name `learn`, replacing the training
# learner created in the earlier cells.
learn = load_learner('mnist_sample_classifier.pkl')

def classify_digit(img_input, threshold):
    """Classify a hand-drawn digit as 3 or 7 and format the result as text.

    Args:
        img_input: Sketchpad value. Either a dict carrying the drawing under
            the "composite" key, or a raw numpy array. ``None`` (or a dict
            whose composite is ``None``) is treated as "nothing drawn yet".
        threshold: Confidence below which a low-confidence warning is
            appended to the result.

    Returns:
        A human-readable string: the prediction with its confidence, a prompt
        to draw something, or an error message. Failures are reported as text
        rather than raised because the result is shown directly in the UI.
    """
    if img_input is None:
        return "Please draw a digit before submitting."

    # Accept both Sketchpad shapes: dict with a composite image, or a raw ndarray.
    if isinstance(img_input, dict) and "composite" in img_input:
        img_array = img_input["composite"]
        if img_array is None:
            # The dict is present but holds no drawing: same as an empty input.
            return "Please draw a digit before submitting."
    elif isinstance(img_input, np.ndarray):
        img_array = img_input
    else:
        return "Unsupported input format."

    try:
        img = Image.fromarray(img_array).convert('L')      # Grayscale

        # MNIST digits are light strokes on a dark background. A sketch is
        # mostly background, so a bright median means the drawing is
        # dark-on-light and has to be inverted; an already MNIST-like drawing
        # is left untouched.
        if np.median(np.asarray(img)) > 127:
            img = ImageOps.invert(img)

        # Fit into 28x28 (MNIST resolution) keeping the aspect ratio; any
        # padding uses black so it blends with the dark background.
        img = ImageOps.pad(img, (28, 28), color=0)
        img = img.resize((224, 224))                       # Match model input size

        # The training images were loaded as 3-channel images, so hand the
        # model the same layout instead of a single-channel image.
        img = img.convert('RGB')

        # `pred` is already the decoded class label; `probs[pred_idx]` is the
        # probability assigned to that class.
        pred, pred_idx, probs = learn.predict(img)
        confidence = float(probs[pred_idx])
        label = pred

        if confidence < threshold:
            return f"Prediction: {label} ({confidence:.2%})\nLow confidence"
        return f"Prediction: {label} ({confidence:.2%})"

    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing the app.
        return f"Prediction failed: {str(e)}"

# Markdown text shown in the app: usage instructions, a short guide for
# telling a 3 from a 7, and an explanation of the confidence warning.
description = """
Draw a handwritten **digit** — either a **3** or a **7**.

The model predicts the digit and shows how confident it is.  
Use the dropdown to choose how cautious the warning system should be.

**Label Guide**:
- **3** → curved top/bottom
- **7** → flat angled top

If the confidence is below your selected threshold, a warning will appear.
"""

# Builds the Gradio UI:
# - Markdown title and description for clarity
# - Sketchpad for digit drawing
# - Dropdown for confidence threshold
# - Textbox for output
# - Button that sends the drawing and threshold to classify_digit
with gr.Blocks() as demo:
    gr.Markdown("# Digit Classifier: 3 vs 7")
    gr.Markdown(description)

    with gr.Row():
        sketch_input = gr.Sketchpad(label="Draw here", image_mode="L")

    threshold_dropdown = gr.Dropdown(
        choices=[0.50, 0.70, 0.85, 0.95],
        value=0.85,
        label="Confidence Threshold"
    )

    result_output = gr.Textbox(label="Model Prediction")

    submit = gr.Button("Classify")
    submit.click(fn=classify_digit, inputs=[sketch_input, threshold_dropdown], outputs=result_output)

# `demo.launch()` starts the app locally.
demo.launch()