### **Image Classification with a Web UI**

In this notebook, I create a customizable web-based GUI with Gradio in which we can draw digits (0-9) to test a simple image classification model trained on the MNIST dataset.

In [27]:
!pip install -q gradio

In [28]:
import tensorflow as tf
import gradio as gr

In [29]:
# I load the MNIST dataset.
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [30]:
print('x_train shape: ', x_train.shape)
print('y_train shape: ', y_train.shape)
print('x_test shape: ', x_test.shape)
print('y_test shape: ', y_test.shape)

x_train shape:  (60000, 28, 28)
y_train shape:  (60000,)
x_test shape:  (10000, 28, 28)
y_test shape:  (10000,)


In [31]:
# I rescale the pixel values of the training and test sets from [0, 255] to [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0
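A trailing comma at the end of an assignment turns the right-hand side into a one-element tuple, so `x_train = x_train / 255.0,` wraps the array in a tuple instead of just rescaling it. This small sketch uses a toy `uint8` array as a stand-in for an MNIST image to show both the pitfall and the intended result.

```python
import numpy as np

# Toy stand-in for an MNIST image: uint8 pixels in [0, 255]
pixels = np.array([[0, 128, 255]], dtype=np.uint8)

# With a trailing comma, the result is a one-element tuple, not an array
wrapped = pixels / 255.0,
# Without it, dividing by a float gives a float64 array in [0, 1]
scaled = pixels / 255.0

print(type(wrapped).__name__)                     # tuple
print(scaled.dtype, scaled.min(), scaled.max())   # float64 0.0 1.0
```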

In [32]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense

# I define the architecture of the NN model.
model = Sequential([
  Flatten(input_shape=(28, 28)),
  Dense(128, activation='elu'),
  Dense(128, activation='elu'),
  Dense(10, activation='softmax')
])

# I configure the model for training.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# I train the model.
model.fit(x_train, y_train,
          validation_data=(x_test, y_test),
          epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f4d52571ef0>
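To make the architecture concrete, here is a minimal NumPy sketch of the same forward pass (Flatten, two 128-unit ELU layers, a 10-way softmax). It uses hypothetical random weights rather than the trained ones and assumes ELU's default `alpha=1.0`. It also counts the parameters: 784*128 + 128 for the first Dense layer, 128*128 + 128 for the second, and 128*10 + 10 for the output layer, for 118,282 in total.

```python
import numpy as np

rng = np.random.default_rng(0)

def elu(x, alpha=1.0):
    # ELU: x for x > 0, alpha * (exp(x) - 1) otherwise; np.minimum keeps exp from overflowing
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))

def softmax(z):
    # Subtract each row's max before exponentiating, for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical random weights with the same shapes as the Keras layers
layer_sizes = [28 * 28, 128, 128, 10]
weights = [rng.normal(0.0, 0.05, size=(n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

def forward(images):
    # Flatten: (batch, 28, 28) -> (batch, 784)
    h = images.reshape(len(images), -1)
    for i, (w, b) in enumerate(zip(weights, biases)):
        h = h @ w + b
        # Hidden Dense layers use ELU; the last layer uses softmax
        h = elu(h) if i < len(weights) - 1 else softmax(h)
    return h

n_params = sum(w.size + b.size for w, b in zip(weights, biases))
probs = forward(rng.random((4, 28, 28)))
print(n_params)      # 118282
print(probs.shape)   # (4, 10)
```

Each output row is a probability distribution over the ten digits, which is exactly what `recognize_digit` later turns into a label dictionary.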
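`y_train` holds integer class ids (0-9) rather than one-hot vectors, which is why the model is compiled with `sparse_categorical_crossentropy`. This loss averages -log of the probability the model assigns to each sample's true class. Below is a NumPy sketch with toy probabilities; the small clipping epsilon is my own choice to avoid log(0), not necessarily Keras's exact internals.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    # labels: integer class ids, shape (batch,); probs: softmax outputs, shape (batch, classes)
    # Pick each sample's probability for its true class, clip away 0, then average -log(p)
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(picked, eps, 1.0))))

# Toy softmax outputs for two samples over three classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])

loss = sparse_categorical_crossentropy(labels, probs)
print(loss)  # mean of -log(0.7) and -log(0.8)
```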

In [33]:
# I evaluate the model on the test set; evaluate() returns [loss, accuracy].
loss, accuracy = model.evaluate(x_test, y_test)

print(f'Test set accuracy: {accuracy * 100:.2f}%')

Test set accuracy: 97.79%
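For a model compiled with `metrics=['accuracy']`, `model.evaluate` returns a list `[loss, accuracy]`. Multiplying a Python list by 100 repeats its elements rather than scaling them, so the result has to be unpacked first. This sketch reuses the loss and accuracy values this notebook's model reported; it does not call Keras.

```python
# Loss and accuracy as reported by model.evaluate for this notebook's model
result = [0.09157952666282654, 0.9779000282287598]

# List repetition: 200 elements, none of them scaled
repeated = result * 100
print(len(repeated))  # 200

# Unpack first, then scale the scalar accuracy
loss, accuracy = result
print(f'Test set accuracy: {accuracy * 100:.2f}%')  # Test set accuracy: 97.79%
```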

In [None]:
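# To create the web-based GUI, we need to define the prediction function, the input UI, and the output UI.
# Note: gr.inputs, gr.outputs, and capture_session belong to the legacy Gradio API, which later releases removed.
def recognize_digit(img):
    # Add a batch dimension and rescale to [0, 1] to match the training data
    # (assumes the sketchpad delivers a 28x28 array of 0-255 pixel values).
    img = img.reshape(1, 28, 28) / 255.0
    prediction = model.predict(img).tolist()[0]
    return {str(i): prediction[i] for i in range(10)}

sketchpad = gr.inputs.Sketchpad()
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(recognize_digit,
             inputs=sketchpad,
             outputs=label,
             live=True,
             capture_session=True).launch(share=True, debug=True)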

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
This share link will expire in 24 hours. If you need a permanent link, email support@gradio.app
Running on External URL: https://21412.gradio.app
Interface loading below...
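`recognize_digit` returns a dictionary that maps each class name to its confidence, which is the format a `Label` output displays; with `num_top_classes=3`, only the three most confident classes are shown. The sketch below mimics that step with made-up softmax probabilities in place of `model.predict`. `probabilities_to_label_dict` and `top_k` are hypothetical helpers, not Gradio functions.

```python
import numpy as np

def probabilities_to_label_dict(probs):
    # Same mapping as recognize_digit: class name (as a string) -> confidence
    return {str(i): float(p) for i, p in enumerate(probs)}

def top_k(label_dict, k=3):
    # Roughly what a Label with num_top_classes=k shows: the k most confident classes
    return sorted(label_dict.items(), key=lambda kv: kv[1], reverse=True)[:k]

# Made-up softmax output for a drawing that looks most like a 3
probs = np.array([0.01, 0.02, 0.05, 0.70, 0.01, 0.15, 0.01, 0.02, 0.02, 0.01])
labels = probabilities_to_label_dict(probs)
print(top_k(labels))  # [('3', 0.7), ('5', 0.15), ('2', 0.05)]
```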
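The model only accepts a batch of 28x28 arrays scaled like the training data, so a drawing has to be converted before `model.predict` can use it. Here is a hypothetical helper, `to_mnist_batch`, sketching that conversion. It assumes a grayscale canvas with dark strokes on a light background whose sides are multiples of 28. It uses block averaging as a simple stand-in for proper image resizing, and it inverts the colors because MNIST digits are light strokes on a dark background.

```python
import numpy as np

def to_mnist_batch(canvas):
    # canvas: hypothetical 2-D uint8 drawing, dark strokes on a light background,
    # with height and width both multiples of 28
    h, w = canvas.shape
    # Block-average down to 28x28 (a simple stand-in for proper image resizing)
    small = canvas.reshape(28, h // 28, 28, w // 28).mean(axis=(1, 3))
    # Invert so strokes are light on dark like MNIST, then scale to [0, 1]
    small = (255.0 - small) / 255.0
    # Add the batch dimension the model expects: (1, 28, 28)
    return small.reshape(1, 28, 28)

canvas = np.full((280, 280), 255, dtype=np.uint8)  # blank white canvas
canvas[60:220, 130:150] = 0                        # a crude vertical "1"
batch = to_mnist_batch(canvas)
print(batch.shape, batch.min(), batch.max())
```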
