# Import Libraries and Set Device
# Import necessary libraries and set the random seed and device (GPU if available)


In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import os

# Seed PyTorch's global random number generator so repeated runs start from the same state.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


# Define CNN Model for MNIST Classification
# A simple CNN with two convolutional layers, dropout, and two fully connected layers


In [15]:
class MNISTClassifier(nn.Module):
    """Small CNN that maps single-channel 28x28 images to 10 class logits.

    Architecture: two 3x3 convolutions (32 then 64 channels), each followed by
    ReLU and 2x2 max-pooling, then dropout, a 128-unit hidden layer, dropout
    again, and a final 10-way linear layer. The returned values are raw logits
    (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order define the state_dict keys, so
        # they are kept stable for compatibility with saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(p=0.25)
        # Two 2x2 poolings reduce 28x28 to 7x7 with 64 channels.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input of shape (batch, 1, 28, 28)."""
        features = self.pool(torch.relu(self.conv1(x)))
        features = self.pool(torch.relu(self.conv2(features)))
        features = self.dropout1(features)
        # Keep the batch dimension, flatten everything else for the linear layers.
        flat = features.flatten(1)
        hidden = self.dropout2(torch.relu(self.fc1(flat)))
        return self.fc2(hidden)

# Load MNIST Dataset with Data Augmentation
# Apply augmentation only on training set, normalize both training and validation data


In [16]:
def load_mnist_data(batch_size=64):
    """Build training and validation DataLoaders from the MNIST training set.

    The 60k-image MNIST training split is divided 80/20 into training and
    validation subsets. Training samples are randomly rotated/translated before
    normalisation; validation samples are only normalised.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles,
        the validation loader does not.
    """
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Both subsets returned by random_split wrap the SAME underlying dataset
    # object, so overwriting `subset.dataset.transform` would change the
    # transform for the training subset as well (silently disabling
    # augmentation). Use two dataset instances, one per transform, and share
    # only the index split between them.
    train_base = datasets.MNIST('data', train=True, download=True, transform=train_transform)
    val_base = datasets.MNIST('data', train=True, download=True, transform=val_transform)

    train_size = int(0.8 * len(train_base))
    val_size = len(train_base) - train_size

    # random_split draws the permutation; reuse its validation indices on the
    # un-augmented dataset so the two subsets stay disjoint.
    train_dataset, val_split = random_split(train_base, [train_size, val_size])
    val_dataset = torch.utils.data.Subset(val_base, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# Train Model with Early Stopping
# Save the best model based on validation loss and stop training if no improvement for 'patience' epochs


In [17]:
def train_model(model, train_loader, val_loader, epochs=30, learning_rate=0.001, patience=5,
                checkpoint_path='best_model.pth'):
    """Train ``model`` with Adam and early stopping on validation loss.

    After each epoch the model is evaluated on ``val_loader``. Whenever the
    validation loss improves, the weights are written to ``checkpoint_path``.
    Training stops early once ``patience`` consecutive epochs bring no
    improvement. The model is moved to the module-level ``device``.

    Args:
        model: Network to train; must return class logits.
        train_loader: DataLoader yielding ``(images, labels)`` training batches.
        val_loader: DataLoader yielding ``(images, labels)`` validation batches.
        epochs: Maximum number of epochs to run.
        learning_rate: Learning rate passed to Adam.
        patience: Number of non-improving epochs tolerated before stopping.
        checkpoint_path: File the best weights are saved to and restored from.

    Returns:
        The same ``model`` instance, carrying the best weights saved during
        this call (or its current weights if no checkpoint was written).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    # Tracks whether THIS call wrote a checkpoint, so a stale or missing file
    # is never loaded at the end (e.g. epochs=0 or a loss that is always NaN).
    checkpoint_saved = False

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # criterion returns the batch mean; weight by batch size so the
            # epoch average is exact even when the last batch is smaller.
            train_loss += loss.item() * images.size(0)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader.dataset)
        val_loss /= len(val_loader.dataset)
        accuracy = 100 * correct / total

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Val Accuracy: {accuracy:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)  # Save best model weights
            checkpoint_saved = True
            print("Saved best model.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= patience:
            print(f"Early stopping after {epoch+1} epochs.")
            break

    if checkpoint_saved:
        # Restore the best weights; map_location keeps the tensors on the
        # device the model currently lives on.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

# Preprocess Input Image for Prediction
# Convert input from Gradio Sketchpad to normalized tensor suitable for the model


In [18]:
def preprocess_image(input_data):
    """Convert a sketch into a normalised ``(1, 1, 28, 28)`` float32 tensor.

    Accepts either an array-like image or a dict carrying the image under the
    ``'image'`` or ``'composite'`` key. Multi-channel images are reduced to
    grayscale by averaging the colour channels (an alpha channel, if present,
    is dropped first). Intensities are then inverted so the background becomes
    dark and the strokes bright, scaled to [0, 1], resized to 28x28 and
    normalised with the MNIST mean/std.

    NOTE(review): pixel values are assumed to be in the 0-255 range — confirm
    against the Gradio component configuration in use.

    Args:
        input_data: Array-like image or dict containing one.

    Returns:
        A float32 tensor of shape ``(1, 1, 28, 28)``.

    Raises:
        ValueError: If no image can be extracted or it has an unsupported
            type or number of dimensions.
    """
    if isinstance(input_data, dict):
        if 'image' in input_data:
            image = input_data['image']
        elif 'composite' in input_data:
            image = input_data['composite']
        else:
            raise ValueError("Dictionary input missing image data")
    else:
        image = input_data

    if image is None:
        raise ValueError("No image data provided")

    try:
        # Work in float32 throughout: the model's weights are float32, and
        # to_tensor() only rescales uint8 input, so a float array produced by
        # channel averaging would otherwise stay in 0-255 and in float64.
        image = np.asarray(image, dtype=np.float32)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"Unsupported input type: {type(input_data)}") from exc

    if image.ndim == 3:
        if image.shape[2] == 4:
            image = image[..., :3]
        image = image.mean(axis=2)  # Convert to grayscale by averaging channels
    if image.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {image.shape}")

    # Invert colors (background black, digit white) and scale to [0, 1].
    image = (255.0 - image) / 255.0

    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    image_tensor = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).unsqueeze(0)
    image_tensor = transforms.functional.resize(image_tensor, (28, 28))
    image_tensor = transforms.functional.normalize(image_tensor, (0.1307,), (0.3081,))
    return image_tensor

# Plot Prediction Probabilities Bar Chart
# Visualize the model's output probabilities for digits 0-9


In [19]:
def plot_probabilities(probabilities):
    """Draw a bar chart of the ten per-digit probabilities.

    Args:
        probabilities: Sequence of 10 values, one per digit 0-9.

    Returns:
        The matplotlib ``Figure`` containing the chart. The figure is detached
        from pyplot's global registry before being returned, so repeated calls
        do not accumulate open figures.
    """
    fig, ax = plt.subplots()
    bars = ax.bar(range(10), probabilities)
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Prediction Probabilities')
    ax.set_xticks(range(10))
    ax.set_ylim(0, 1)
    # Annotate each bar with its value, centred just above the bar top.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}',
                ha='center', va='bottom')
    # This function runs on every sketch change; without closing, pyplot keeps
    # every figure alive. The Figure object itself remains usable after close.
    plt.close(fig)
    return fig

# Prediction Function for Gradio Interface
# Processes input, predicts digit, and returns predicted digit and probability plot

In [20]:
def predict_digit(input_data):
    """Classify a sketched digit and chart the class probabilities.

    Uses the module-level ``model`` and ``device``. Any exception raised while
    preprocessing, running the model or plotting is printed and converted
    into an ``("Error", None)`` result so the UI callback never raises.

    Args:
        input_data: Raw sketch input, passed unchanged to ``preprocess_image``.

    Returns:
        ``(predicted_digit, figure)`` on success, ``("Error", None)`` on failure.
    """
    try:
        batch = preprocess_image(input_data).to(device)
        model.eval()
        with torch.no_grad():
            logits = model(batch)
            # Softmax over the class dimension; take the single batch entry.
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        return int(np.argmax(probs)), plot_probabilities(probs)
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        return "Error", None

# Load Data and Prepare DataLoaders
# Load MNIST dataset with augmentation and split into training and validation loaders


In [21]:
# Build the training and validation loaders with 64 samples per batch.
print("Loading data...")
train_loader, val_loader = load_mnist_data(64)

Loading data...


# Create or Load Model
# Initialize the CNN model, load saved weights if available, otherwise train the model


In [22]:
print("Creating model...")
model = MNISTClassifier()

# Reuse a previously saved checkpoint when one exists; otherwise train from scratch.
if os.path.exists('best_model.pth'):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (and vice versa) instead of failing while deserialising CUDA tensors.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.to(device)
    print("Loaded saved model.")
else:
    print("Training model...")
    model = train_model(model, train_loader, val_loader, epochs=30, patience=5)

Creating model...
Loaded saved model.


# Setup Gradio Interface
# Create a simple UI for drawing digits and displaying predictions and probability plots


In [23]:
# Build the Gradio UI. Components are created inside nested context managers,
# so their creation order below determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# MNIST Digit Recognition")
    gr.Markdown("Draw a digit (0-9) in the box below and see the model's prediction.")

    # One row: the drawing canvas on the left, the outputs stacked in a column on the right.
    with gr.Row():
        # Canvas requested as a single-channel ("L") numpy image.
        sketchpad = gr.Sketchpad(label="Draw Digit", image_mode="L", type="numpy")
        with gr.Column():
            label = gr.Label(label="Predicted Digit")
            plot = gr.Plot(label="Prediction Probabilities")

    # Re-run the prediction whenever the sketch changes; predict_digit returns
    # (digit, figure), which map to the label and plot components in that order.
    sketchpad.change(fn=predict_digit, inputs=sketchpad, outputs=[label, plot])

# Launch only when executed as the main module (this includes running the cell
# in a notebook kernel, as the captured output below this cell shows).
if __name__ == "__main__":
    demo.launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://c0c09b0cd564dc7c8e.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
