<a href="https://colab.research.google.com/github/ispromadhka/MNIST-with-tensorflow-gradio/blob/main/MNIST_with_tensorflow_%2B_gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
!pip install tensorflow gradio

In [4]:
import inspect
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image

# Data preparation

In [None]:
# Keras returns MNIST as ((train_images, train_labels), (test_images, test_labels))
# and caches the download locally, so loading it again in the next cell is cheap.
data = tf.keras.datasets.mnist.load_data()
(raw_train_images, raw_train_labels), (raw_test_images, raw_test_labels) = data

# Quick look at what was downloaded.
{"train": raw_train_images.shape, "test": raw_test_images.shape}

In [6]:
# Load MNIST and rescale pixel intensities from [0, 255] to [0, 1] floats.
(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train_full, X_test = (images.astype("float32") / 255.0 for images in (X_train_full, X_test))

# The first 5,000 training examples become the validation set; the rest train.
X_valid, X_train = np.split(X_train_full, [5000])
y_valid, y_train = np.split(y_train_full, [5000])

In [7]:
# Append the single channel axis Conv2D expects: (N, 28, 28) -> (N, 28, 28, 1).
# The ndim guard makes the cell idempotent: re-running it must not add a
# second axis, which would produce (N, 28, 28, 1, 1) and break training.
X_train, X_valid, X_test = tuple(
    split if split.ndim == 4 else split[..., np.newaxis]
    for split in (X_train, X_valid, X_test)
)

In [8]:
# Random geometric augmentation for training batches: rotations up to 10
# degrees, shifts up to 10% of width/height, and zooms up to 10%.
#
# `datagen.fit(...)` is intentionally not called. It is only needed to compute
# statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, none of which are enabled here. Calling it would just copy the
# whole training array.
#
# NOTE(review): ImageDataGenerator is deprecated in recent TensorFlow
# releases; consider tf.keras.layers.RandomRotation / RandomTranslation /
# RandomZoom inside the model instead.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
)

# Model

In [None]:
# Seed Python's, NumPy's and TensorFlow's global RNGs (weight initialisation,
# dropout, and NumPy-driven augmentation/shuffling) so that
# Restart & Run All reproduces the same model.
tf.keras.utils.set_random_seed(42)

# Small CNN: conv -> batch-norm -> pool -> conv -> global average pooling ->
# dropout -> 10-way softmax (one output per digit class).
model = tf.keras.Sequential([
    # Explicit Input layer. Passing `input_shape=` to the first layer is
    # deprecated in recent Keras versions.
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(
        10,
        activation='softmax',
        kernel_regularizer=tf.keras.regularizers.l2(0.001),
    ),
])

In [10]:
# Integer labels (0-9) -> sparse categorical cross-entropy; AdamW adds
# decoupled weight decay on top of Adam.
# NOTE(review): the output layer also carries an L2 kernel regularizer, so the
# weights there are regularized twice — confirm this is intended.
model.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

## Callbacks

In [11]:
# All three callbacks watch the same metric, so "best" means the same thing
# for the weights EarlyStopping restores in memory and for the checkpoint file.

# Stop after 5 epochs without validation-loss improvement and keep the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Halve the learning rate after 2 stagnant epochs. The patience must be
# shorter than EarlyStopping's: with equal patience both fire on the same
# epoch, training stops, and the reduced learning rate is never used.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    verbose=1,
)

# Persist the best model seen so far.
# NOTE(review): `.h5` is the legacy format in Keras 3; `.keras` is preferred
# there. Kept as-is so existing consumers of best_model.h5 keep working.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',
    save_best_only=True,
)

## Training the model

In [None]:
# Train on augmented batches and validate on the untouched hold-out set.
# The batch size (32) is set by `datagen.flow`. `batch_size` is deliberately
# NOT passed to fit(): Keras documents that it must not be specified when the
# input is a generator / Sequence, since those already yield batches.
history = model.fit(
    datagen.flow(X_train, y_train, batch_size=32),
    validation_data=(X_valid, y_valid),
    epochs=10,
    callbacks=[
        early_stop,
        lr_scheduler,
        checkpoint,
    ],
)

## Plot

In [None]:
# Training curves. Loss and accuracy live on different scales: loss can exceed
# 1, and ReduceLROnPlateau may add a learning-rate column to the history. So
# each metric family gets its own axes, instead of one shared axis clipped to [0, 1].
history_df = pd.DataFrame(history.history)
history_df.index = history_df.index + 1  # 1-based epoch numbers on the x axis

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

history_df.filter(like="loss").plot(ax=ax_loss, grid=True)
ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")

history_df.filter(like="accuracy").plot(ax=ax_acc, grid=True)
ax_acc.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy", ylim=(0, 1))

plt.show()

# Interface

In [16]:
def _extract_image(image_input):
    """Return the drawable image contained in a Gradio Sketchpad value.

    A dict value is searched for the first non-None entry among
    'composite', 'image' and 'background' (in that order). Anything else,
    including a dict with none of those entries, is returned unchanged.
    """
    if isinstance(image_input, dict):
        for key in ('composite', 'image', 'background'):
            if image_input.get(key) is not None:
                return image_input[key]
    return image_input


def _to_grayscale(image):
    """Convert a PIL image or ndarray to a PIL 'L' image; return None if unsupported.

    Images with an alpha channel are first composited onto a white
    background. A plain ``convert('L')`` discards alpha, so a transparent
    sketch background would be read from its underlying RGB values (commonly
    black) instead of white. That would wipe out the contrast with dark strokes.
    """
    if hasattr(image, 'convert'):
        pil_image = image
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image.astype('uint8'))
    else:
        return None

    has_alpha = getattr(pil_image, 'mode', None) in ('RGBA', 'LA', 'PA')
    has_alpha = has_alpha or 'transparency' in getattr(pil_image, 'info', {})
    if has_alpha:
        rgba = pil_image.convert('RGBA')
        white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        pil_image = Image.alpha_composite(white, rgba)
    return pil_image.convert('L')


def _to_model_input(gray, size=28, box=20):
    """Turn a dark-on-light grayscale image into a (1, size, size, 1) float32 batch in [0, 1].

    The image is inverted to light-on-dark, the polarity MNIST uses. The ink's
    bounding box is then scaled so its longer side is ``box`` pixels (aspect
    ratio preserved) and pasted centered on a black ``size`` x ``size`` canvas.
    This approximates MNIST's size-normalized, centered digits; without it, a
    small digit drawn on a large canvas shrinks to a few pixels. A blank image
    is simply resized.
    """
    inverted = Image.fromarray(255 - np.asarray(gray, dtype=np.uint8))
    bbox = inverted.getbbox()  # None when every pixel is 0 (nothing drawn)
    if bbox is None:
        canvas = inverted.resize((size, size), Image.Resampling.LANCZOS)
    else:
        digit = inverted.crop(bbox)
        scale = box / max(digit.size)
        new_w = max(1, round(digit.width * scale))
        new_h = max(1, round(digit.height * scale))
        digit = digit.resize((new_w, new_h), Image.Resampling.LANCZOS)
        canvas = Image.new('L', (size, size), 0)
        canvas.paste(digit, ((size - new_w) // 2, (size - new_h) // 2))
    pixels = np.asarray(canvas, dtype=np.float32) / 255.0
    return pixels.reshape(1, size, size, 1)


def _format_result(probabilities):
    """Describe the top class, its confidence and a star bar per class (one star per 10%)."""
    predicted_digit = np.argmax(probabilities)
    confidence = np.max(probabilities)
    result = f"🎯 The recognized digit: {predicted_digit}\n"
    result += f"📊 Confidence: {confidence:.1%}\n\n"
    result += "📈 Probability:\n"
    for digit, prob in enumerate(probabilities):
        stars = "★" * int(prob * 10)
        result += f"  {digit}: {prob*100:.2f}% {stars}\n"
    return result


def predict_digit(image_input):
    """Recognize a hand-drawn digit with the global ``model`` and report it as text.

    Args:
        image_input: the Sketchpad value. This is a dict with
            'composite' / 'image' / 'background' entries, a PIL image, or a
            numpy array.

    Returns:
        A multi-line string with the predicted digit, its confidence and the
        per-class probabilities. Returns "Invalid input format" when no
        usable image is found. If any step raises, returns an error report
        (which is also printed).
    """
    try:
        gray = _to_grayscale(_extract_image(image_input))
        if gray is None:
            return "Invalid input format"
        predictions = model.predict(_to_model_input(gray), verbose=0)
        return _format_result(predictions[0])

    except Exception as e:
        # Show the full traceback both in the UI and in the notebook output,
        # so failures can be diagnosed from the demo itself.
        error_msg = f" Error:\n{str(e)}\n\n"
        error_msg += f" Full info about error:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg

In [None]:
# Newer Gradio releases renamed Interface's `allow_flagging` argument to
# `flagging_mode`. Pass whichever name the installed version accepts, so the
# unpinned `pip install gradio` above does not break this cell.
_flagging_param = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(type="numpy", label="Draw a digit from 0 to 9"),
    outputs=gr.Textbox(label="Result", lines=12),
    title="🔢 Handwritten digit recognition",
    description=(
        "Instructions:\n"
        "1. Draw a digit in the drawing field\n"
        "2. Click Submit\n"
        "3. Get the recognition result\n\n"
        "The model is trained on the MNIST dataset."
    ),
    **{_flagging_param: "never"},
)

print("🚀 Launching...")
# share=True creates a publicly shareable link (anyone with it can use the
# demo). debug=True keeps the cell running so errors are printed below it.
interface.launch(debug=True, share=True)