<a href="https://colab.research.google.com/github/kwb425/class-2025-spring/blob/main/class-2025-spring_0509-5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries
!pip install tensorflow gradio

# Import necessary libraries
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import numpy as np
import gradio as gr
from PIL import Image

# Fetch MNIST as (train, test) pairs of images and labels
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values into [0, 1] and append a trailing channel axis
# (single-channel grayscale), in one step per split
x_train = np.expand_dims(x_train / 255.0, axis=-1)
x_test = np.expand_dims(x_test / 255.0, axis=-1)

# Build the CNN model: two conv/pool stages followed by a small dense head
model = models.Sequential([
    # Declare the input explicitly. Passing `input_shape=` to the first
    # layer is deprecated in Keras 3 and emits a warning there; an Input
    # object builds the same network without it.
    layers.Input(shape=(28, 28, 1)),
    # Convolutional feature extractor
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    # Dense classifier head: one softmax output per digit class (0-9)
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Configure training: Adam optimizer, cross-entropy over integer labels,
# and accuracy as the reported metric
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit for 5 epochs; the test split is passed as validation data so its
# loss/accuracy are reported after every epoch
model.fit(
    x=x_train,
    y=y_train,
    epochs=5,
    validation_data=(x_test, y_test),
)

# Score the trained model on the test split and report its accuracy
test_loss, test_acc = model.evaluate(x=x_test, y=y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")

# Lookup table from a predicted class index (0-9) to its English word.
# The position of each word in the list is the digit it names.
digit_to_word = dict(enumerate([
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]))

# Define the Gradio interface function
def classify_digit(data):
    """Classify a digit drawn on the Gradio Sketchpad and name it in words.

    Args:
        data: Value delivered by the Sketchpad input. Expected to be a
            mapping whose "composite" entry holds the drawing as an image
            array with at least four channels (the fourth is read as
            alpha). May be None, or carry a None composite, when nothing
            has been drawn yet.

    Returns:
        str: "<word> (<confidence>)", e.g. "seven (0.98)", on success;
        a prompt to draw when there is no drawing; or "Error: <message>"
        if preprocessing or prediction raises.
    """
    no_drawing_message = "Please draw a digit first."

    # An untouched canvas would otherwise surface as the opaque
    # "'NoneType' object is not subscriptable" error.
    if data is None:
        return no_drawing_message

    try:
        # Extract composite image from Gradio Sketchpad data
        composite = data["composite"]
        if composite is None:
            return no_drawing_message

        image = Image.fromarray(np.array(composite))
        image = image.resize((28, 28))  # Resize to match MNIST format
        image_array = np.array(image) / 255.0  # Scale channels to [0, 1]

        # Only the alpha channel (index 3) is fed to the model.
        # NOTE(review): presumably the canvas is transparent where nothing
        # is drawn, so alpha yields a bright-digit-on-dark image like
        # MNIST; confirm against the Gradio version in use.
        alpha = image_array[:, :, 3]

        # Make a prediction; verbose=0 keeps Keras from printing a
        # progress bar for every single-image request.
        prediction = model.predict(
            alpha.reshape(1, 28, 28, 1), verbose=0
        ).flatten()
        class_idx = int(np.argmax(prediction))  # plain int for dict lookup
        confidence = prediction[class_idx]

        # Return the word label and confidence
        return f"{digit_to_word[class_idx]} ({confidence:.2f})"
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the
        # request, so the user sees what went wrong.
        return f"Error: {str(e)}"



# Drawing canvas that serves as the app's single input
sketchpad_input = gr.Sketchpad(label="Draw a digit")

# Wire the canvas to the classifier; the result is shown as plain text
demo = gr.Interface(
    title="Real-time Digit Classifier",
    description="Draw a digit (0-9) on the canvas and see the prediction in words.",
    fn=classify_digit,
    inputs=sketchpad_input,
    outputs="text",
)

# Start the app with debug mode enabled
demo.launch(debug=True)
