In [7]:
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
from PIL import Image

In [21]:
model = load_model("mnist_dense_model.h5")



In [30]:
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
from PIL import Image

# Load model
model = load_model("mnist_dense_model.h5")

def predict_digit(img):
    if img is None:
        return "Please draw a digit."
    
    if isinstance(img, dict) and 'composite' in img:
        img = img['composite']
    
    # Convert array to PIL image
    pil_img = Image.fromarray(np.array(img).astype("uint8")).convert("L")
    
    # Resize to 28x28 (MNIST dimensions)
    pil_img = pil_img.resize((28, 28))
    
    # Invert colors (MNIST has white digits on black background)
    pil_img = Image.eval(pil_img, lambda x: 255 - x)
    
    # Apply thresholding to make it more binary-like (like MNIST)
    threshold = 100  # Adjust this value if needed
    pil_img = pil_img.point(lambda x: 255 if x > threshold else 0)
    
    # Center the digit
    # First, find bounding box of non-zero pixels
    img_array = np.array(pil_img)
    non_zero = np.where(img_array > 0)
    if len(non_zero[0]) > 0:  # Only process if there are non-zero pixels
        top, bottom = np.min(non_zero[0]), np.max(non_zero[0])
        left, right = np.min(non_zero[1]), np.max(non_zero[1])
        
        # Calculate center of mass and shift needed
        center_y, center_x = (top + bottom) // 2, (left + right) // 2
        shift_y = 14 - center_y  # 14 is center of 28x28 image
        shift_x = 14 - center_x
        
        # Create a new blank image
        centered = np.zeros_like(img_array)
        
        # Define bounds for the region to copy, ensuring we stay within bounds
        src_top = max(0, -shift_y)
        src_bottom = min(28, 28 - shift_y)
        src_left = max(0, -shift_x)
        src_right = min(28, 28 - shift_x)
        
        dst_top = max(0, shift_y)
        dst_bottom = min(28, 28 + shift_y)
        dst_left = max(0, shift_x)
        dst_right = min(28, 28 + shift_x)
        
        # Copy the shifted region
        try:
            centered[dst_top:dst_bottom, dst_left:dst_right] = img_array[src_top:src_bottom, src_left:src_right]
            img_array = centered
        except:
            # If centering fails, just use the original image
            pass
    
    # Convert back to PIL Image for saving debug image
    pil_img = Image.fromarray(img_array.astype(np.uint8))
    
    # Save processed image for debugging
    debug_path = "processed_digit.png"
    pil_img.save(debug_path)
    print(f"Saved processed image to {debug_path} for debugging")
    
    # Normalize the array
    img_array = img_array / 255.0
    
    # Reshape for the model (flattening for Dense model)
    img_array = img_array.reshape(1, 784)
    
    # Predict
    prediction = model.predict(img_array, verbose=0)
    predicted_class = np.argmax(prediction)
    
    # Show confidence for top 3 predictions
    top_3_indices = np.argsort(prediction[0])[-3:][::-1]
    result = f"Predicted Digit: {predicted_class}\n\nTop 3 predictions:\n"
    
    for i, idx in enumerate(top_3_indices):
        result += f"{i+1}. Digit {idx}: {prediction[0][idx]*100:.2f}%\n"
    
    return result

# Interface
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(),
    outputs="text",
    title="MNIST Digit Recognizer",
    description="Draw a digit (0-9) and the model will predict which digit it is.",
    examples=None,
    theme="default"
)


iface.launch()




* Running on local URL:  http://127.0.0.1:7876
* To create a public link, set `share=True` in `launch()`.




Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
Saved processed image to processed_digit.png for debugging
