[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biodatlab/deep-learning-skooldio/blob/master/02_load_and_pred.ipynb)


# Load trained model and predict on a sample image

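In this notebook, we load the trained model parameters from `02_handwritten_recognition.ipynb` and use them to predict on a sample image.
To do that, we need to:
- Create the model (with the same architecture used for training)
- Upload the model weights to Colab (or clone the course repository, which already includes `saved_model/thai_digit_net.pth`)
- Use the model to predict on a sample image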
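
The `.pth` file holds a `state_dict`, which maps parameter names to tensors. It does not store the architecture itself, so the model class has to be re-created in code before the weights can be loaded into it. Below is a minimal round-trip sketch. It uses a small hypothetical `TinyNet` in place of the real model and a temporary file in place of `thai_digit_net.pth`:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model architecture."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# 1) Training side: save only the parameters (the state_dict), not the whole object
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pth")
torch.save(trained.state_dict(), path)

# 2) Prediction side: rebuild the same architecture, then load the saved parameters
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()

# Both models now hold identical weights, so they give identical outputs
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

If the rebuilt class has different layer names or shapes, `load_state_dict` raises an error. This is why the notebook redefines `DropoutThaiDigit` exactly as it was trained.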

In [None]:
!git clone https://github.com/biodatlab/deep-learning-skooldio

In [None]:
import torch
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# Define DropoutThaiDigit, the architecture the saved weights were trained with (layer names and shapes must match)
class DropoutThaiDigit(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, 10)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc4(x)
        return x

In [None]:
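# Create the model instance and load the trained parameters from `thai_digit_net.pth`
# (cloned above; if you uploaded the file to Colab instead, change the path accordingly)
model = DropoutThaiDigit()
# map_location="cpu" lets weights saved on a GPU load on a CPU-only runtime
model.load_state_dict(torch.load("deep-learning-skooldio/saved_model/thai_digit_net.pth", map_location="cpu"))
model.eval()  # switch to evaluation mode (disables dropout)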
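
`model.eval()` matters here because `DropoutThaiDigit` contains `nn.Dropout`. In training mode, dropout randomly zeroes activations and rescales the survivors by 1/(1 - p), so repeated predictions on the same input would differ. In evaluation mode it becomes the identity. A small sketch with a standalone dropout layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(0.1)  # same rate as in DropoutThaiDigit
x = torch.ones(1, 1000)

dropout.train()  # training mode: random zeros, survivors scaled by 1 / (1 - 0.1)
a, b = dropout(x), dropout(x)
print(torch.equal(a, b))           # two independent random masks, so almost surely False

dropout.eval()   # evaluation mode: dropout passes the input through unchanged
print(torch.equal(dropout(x), x))  # True
```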

In [None]:
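from glob import glob

# Pick one image from the dataset; sorting makes the choice reproducible across runs
sample_path = sorted(glob("deep-learning-skooldio/thai-handwritten-dataset/*/*.png"))[50]
img = Image.open(sample_path)
y_true = Path(sample_path).parent.name  # the folder name is the class label

img = 1 - transform(img)  # invert pixel intensities (0 becomes 1 and vice versa)
with torch.no_grad():     # inference only: no gradients needed
    y_pred = model(img)
y_pred = y_pred.argmax(dim=1).item()  # index of the highest-scoring class as a Python int

plt.title("Predicted class = {}, True class = {}".format(y_pred, y_true))
plt.imshow(img.squeeze(0), cmap="gray")
plt.show()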
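
For a single image, the model returns a `(1, 10)` tensor of logits. `argmax(dim=1)` therefore still yields a one-element tensor, and `.item()` turns it into a plain Python `int` that prints cleanly in the plot title. `torch.no_grad()` skips gradient tracking during inference. The sketch below uses a hypothetical single-layer stand-in for the trained network and a random image:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for the trained network: 784 inputs -> 10 class logits
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
stand_in.eval()

img = torch.rand(1, 28, 28)  # one grayscale image, shaped like transform's output

with torch.no_grad():        # inference only: no computation graph is built
    logits = stand_in(img)   # shape (1, 10)

pred = logits.argmax(dim=1)  # shape (1,), printed as tensor([k])
label = pred.item()          # plain Python int in range(10)
print(logits.shape, pred.shape, type(label))
```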

## Using Gradio to predict on a hand-drawn image

[Gradio](https://gradio.app/) is a Python library for building web apps around machine learning models. Here, we create a prediction app for our model with Gradio. A Gradio application needs:

- A predict function
- An input component: a sketchpad to draw a digit on
- An output component: a label showing the top predictions

In [None]:
!pip install "gradio<4"  # the Sketchpad arguments used below (shape, brush_radius) follow the Gradio 3.x API

In [2]:
import numpy as np

labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label

def predict(img):
    """
    Take an image and return the top 5 predictions
    as a dictionary:

        {label: confidence, label: confidence, ...}
    """
    if img is None:
        return None
    img = transform(img)  # no need for 1 - transform(img): the Gradio sketchpad already inverts the image
    with torch.no_grad():  # inference only: no gradients needed
        probs = model(img).softmax(dim=1).ravel()
    probs, indices = torch.topk(probs, 5)  # select top 5
    probs, indices = probs.tolist(), indices.tolist()  # convert to Python lists
    confidences = {LABELS[i]: v for i, v in zip(indices, probs)}
    return confidences
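
The core of `predict` is a softmax followed by `torch.topk`. The sketch below runs that step on made-up logits (not real model output) with the same `LABELS` mapping. It shows that the resulting dictionary is ordered by confidence. Its values sum to at most 1, because only 5 of the 10 probabilities are kept.

```python
import torch

labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)",
          "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
LABELS = dict(enumerate(labels))  # {0: "๐ (ศูนย์)", 1: "๑ (หนึ่ง)", ...}

# Made-up logits for one image, shape (1, 10); class 3 has the largest score
logits = torch.tensor([[0.1, 0.5, -1.0, 4.0, 0.2, 1.5, -0.3, 0.0, 2.5, -2.0]])

probs = logits.softmax(dim=1).ravel()       # shape (10,), sums to 1
top_probs, top_idx = torch.topk(probs, 5)   # 5 largest values, in descending order
confidences = {LABELS[i]: p for i, p in zip(top_idx.tolist(), top_probs.tolist())}

for name, p in confidences.items():
    print(f"{name}: {p:.3f}")
```

`gr.Label` accepts this kind of `{label: confidence}` dictionary and displays the classes as a ranked list.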

In [None]:
import gradio as gr

gr.Interface(
    fn=predict, 
    inputs=gr.Sketchpad(label="Draw Here", brush_radius=5, type="pil", shape=(120, 120)), 
    outputs=gr.Label(label="Guess"), 
    live=True
).launch()