# Train your image classifier

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hieultp/holistics-nlp-workshop/blob/main/01-image-classification.ipynb)

Let's recap. To train your ~~dragon~~ AI model, you will need to:
- Prepare data
- Define the model
- Define the loss function
- Train it!

Okay, let's get started.

In [None]:
!git clone https://github.com/hieultp/holistics-nlp-workshop
%cd holistics-nlp-workshop

## 1. Prepare dataset

Let's prepare the dataset. We will use the MNIST dataset in this exercise.

The MNIST dataset contains images of handwritten digits ranging from 0 to 9.

We will have two separate image sets. The first one is used for training the model, while the second one, which the model never sees during training, is used as the test set.

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
train_data = MNIST(root="./", train=True, download=True)
test_data = MNIST(root="./", train=False, download=True)

Here `train_data` and `test_data` are dataset objects that give us a (data, label) pair when we index them or iterate through them.

Let's see what's inside these datasets.

In [None]:
train_data[0]

You can see that it is a `tuple` which contains two elements: the former is the input image, the latter is the ground truth corresponding to the input.
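
To make the "(data, label) pair" idea concrete, here is a minimal sketch of an indexable dataset. `ToyDigits` and its string "images" are hypothetical stand-ins made up for illustration; this is not how `MNIST` is implemented, it only mimics the same indexing behaviour.

```python
class ToyDigits:
    """Hypothetical stand-in for a dataset such as MNIST (not the real class)."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Indexing returns one (data, label) tuple, just like train_data[0].
        return self.samples[index], self.labels[index]


toy = ToyDigits(samples=["image-a", "image-b", "image-c"], labels=[5, 0, 4])

first_sample, first_label = toy[0]   # unpack one pair
pairs = [pair for pair in toy]       # iterating walks through index 0, 1, 2, ...
print(first_sample, first_label)
print(len(toy), pairs)
```

Anything that supports `len(...)` and `[...]` like this can be unpacked with `img, label = dataset[i]`, which is exactly what the next cells do.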

Let's check the image then.

In [None]:
img, label = train_data[0]

In [None]:
img

It's kinda small. Let's use a specialized visualization library to get a better look.

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(img)

Great, that digit sure looks like the number five!

This looks good, but unfortunately it is not the format that PyTorch requires. We need a small modification to convert the image above to the `Tensor` format so that we can use it to train the model.

In [None]:
train_data = MNIST(root="./", train=True, download=True, transform=ToTensor())
test_data = MNIST(root="./", train=False, download=True, transform=ToTensor())

In [None]:
img, label = train_data[0]

In [None]:
img

This is the `Tensor` version of the image above, in a format PyTorch accepts, so we can use it to train our AI model.

## 2. Define the model

There are two ways to approach this problem. We could make our model predict a **single** real number ranging from 0 to 9 (regression), or make it predict the probability that the picture belongs to each class (classification).
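
The difference shows up in the shape of the output. Here is a rough sketch with two hypothetical one-layer models, `toy_regressor` and `toy_classifier`. They are made up for illustration and are most likely much simpler than the models in `src.models`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = torch.rand(4, 1, 28, 28)  # four fake MNIST-sized images

# Hypothetical toy models; the workshop's own models may look different.
toy_regressor = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
toy_classifier = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

with torch.no_grad():
    regression_out = toy_regressor(batch)            # one real number per image
    logits = toy_classifier(batch)                   # one raw score per digit class
    probabilities = logits.softmax(dim=1)            # scores turned into probabilities
    predicted_digits = probabilities.argmax(dim=1)   # most likely class per image

print(regression_out.shape, logits.shape)
print(probabilities.sum(dim=1))
```

The regression output is a single value you would round to the nearest digit, while the classification output is ten probabilities that sum to 1 for each image.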

There's no need to rush. I have already defined the model for you ;) Let's try to import it.

Try to fill in the missing code and initialize the model. You could try different models and play around with them.

In [None]:
from src.models import ClassificationModel, RegressionModel

model = ...

Let's check our model. Run the line below.

In [None]:
model

Cool, let's define the loss function and we're ready to go.

## 3. Define the loss function

In [None]:
import torch.nn as nn

Depending on the type of model, we have to choose an appropriate loss function for it. Of course, you could go wild and try something new.

By convention, though, a *regression model* is usually paired with the *mean squared error* loss function, and a *classification model* with the *cross-entropy* loss.

In [None]:
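# Uncomment the line that matches the model you chose above.
# loss_fn = nn.CrossEntropyLoss()  # for ClassificationModel
# loss_fn = nn.MSELoss()  # for RegressionModel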
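
To get a feel for what the two losses expect, here is a small sketch on hand-made numbers (all tensors below are toy values, not outputs of the workshop models). Note that `nn.MSELoss` compares real numbers of the same shape, while `nn.CrossEntropyLoss` takes raw class scores plus integer class indices.

```python
import math

import torch
import torch.nn as nn

# Regression: predictions and targets are real numbers with the same shape.
predicted_values = torch.tensor([[4.0], [7.0]])
true_values = torch.tensor([[5.0], [7.0]])
mse = nn.MSELoss()(predicted_values, true_values)  # mean of (4-5)^2 and (7-7)^2

# Classification: raw scores (logits) per class, targets are class indices.
uniform_logits = torch.zeros(2, 10)   # the model has no preference at all
true_classes = torch.tensor([5, 7])
cross_entropy = nn.CrossEntropyLoss()(uniform_logits, true_classes)

print(mse.item())                           # 0.5
print(cross_entropy.item(), math.log(10))   # both about 2.3026
```

A model that spreads its belief evenly over ten classes gets a cross-entropy of ln(10), so anything clearly below that means the model has learned something.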

## 4. Train it

It's timeeeee.

I have abstracted the training step for you, but feel free to check the source code for more details. In the meantime, just run the cell below and then check the performance of the model.

<div>
<img src="https://i.kym-cdn.com/entries/icons/original/000/018/147/Shia_LaBeouf__Just_Do_It__Motivational_Speech_(Original_Video_by_LaBeouf__R%C3%B6nkk%C3%B6___Turner)_0-4_screenshot.png" width="600"/>
</div>

In [None]:
from src.train import train

model = train(model, loss_fn, train_data, test_data, num_epochs=10)
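
If you wonder what a training function usually does under the hood, here is a generic sketch of a PyTorch training loop. It is **not** the code in `src.train`; the `run_epoch` helper, the toy model, the optimizer settings, and the random fake data are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Fake data standing in for MNIST: 256 random "images" with random labels.
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)


def run_epoch(model, loss_fn, loader, optimizer):
    """Hypothetical helper: one pass over the data, returns the average loss."""
    model.train()
    total = 0.0
    for batch_images, batch_labels in loader:
        optimizer.zero_grad()                                # clear old gradients
        loss = loss_fn(model(batch_images), batch_labels)    # forward pass
        loss.backward()                                      # compute gradients
        optimizer.step()                                     # update the weights
        total += loss.item() * len(batch_labels)
    return total / len(loader.dataset)


epoch_losses = [run_epoch(toy_model, toy_loss_fn, loader, optimizer) for _ in range(5)]
print(epoch_losses)
```

Because the labels here are random, the model can only memorize them, but the training loss still goes down from epoch to epoch, which is all this sketch is meant to show.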

## 5. Test it

Great, now you have your model trained (and hopefully its performance on the test set is not bad).

Let's test it. Run the three cells below, draw any digit, and see whether your model can guess it. Note that the demo reads ten class probabilities from the model, so it assumes you trained the classification model.

Then feel free to go back and play around with the loss function and the model to see whether you can improve the performance.
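
When comparing different setups, accuracy is a handy number to look at. Below is a minimal sketch of how it can be computed for a classifier, using hand-made scores for a toy 3-class problem (the tensors are made up; how `src.train` reports performance may differ).

```python
import torch

# Toy scores for 4 samples and 3 classes, plus the true class of each sample.
logits = torch.tensor([
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.0, 0.1, 3.0],
    [0.9, 0.8, 0.2],
])
labels = torch.tensor([1, 0, 2, 1])

predictions = logits.argmax(dim=1)  # class with the highest score per sample
accuracy = (predictions == labels).float().mean().item()

print(predictions.tolist())  # [1, 0, 2, 0]
print(accuracy)              # 0.75, since 3 of 4 predictions are correct
```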

In [None]:
!pip install -q gradio

import torch
import gradio as gr

In [None]:
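device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make sure the model sits on the same device as the inputs created in `inference`.
model = model.to(device)
model.eval()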

In [None]:
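def inference(image):
    # Convert the sketchpad drawing to a (1, 1, 28, 28) float tensor scaled to [0, 1].
    try:
        image = torch.from_numpy(image).reshape(1, 1, 28, 28).to(device, dtype=torch.float32) / 255
    except Exception:
        # Fall back to a blank image if the input is missing or has an unexpected shape.
        image = torch.zeros(1, 1, 28, 28, dtype=torch.float32, device=device)
    # No gradients are needed for inference; softmax turns the 10 scores into probabilities.
    with torch.no_grad():
        prediction = model(image).softmax(1).squeeze()
    return {str(i): prediction[i].item() for i in range(10)}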

gr.Interface(
    fn=inference,
    inputs="sketchpad",
    outputs=gr.outputs.Label(num_top_classes=3),
    live=True,
    css=".footer {display:none !important}",
    title="MNIST Sketchpad",
    description="Draw a number 0 through 9 on the sketchpad, and see predictions in real time.",
    thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png"
).launch()
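
If the sketchpad does not cooperate, you can still try the same preprocessing by hand. The sketch below repeats the steps of `inference` outside Gradio, with two assumptions: `toy_model` is an untrained hypothetical stand-in for your trained model, and `drawing` is a random 28x28 array of 0-255 values standing in for a real drawing.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical untrained stand-in for the model trained in this notebook.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_model.eval()

# Assumption: the drawing arrives as a 28x28 array of 0-255 pixel values.
drawing = np.random.default_rng(0).integers(0, 256, size=(28, 28), dtype=np.uint8)

# Same steps as in `inference`: add batch and channel dimensions, scale to [0, 1].
batch = torch.from_numpy(drawing).reshape(1, 1, 28, 28).to(torch.float32) / 255
with torch.no_grad():
    probabilities = toy_model(batch).softmax(dim=1).squeeze()

scores = {str(digit): probabilities[digit].item() for digit in range(10)}
top_three = sorted(scores, key=scores.get, reverse=True)[:3]
print(top_three)
```

Swap `toy_model` for your own `model` and `drawing` for a real image array to check a prediction without the web UI.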


That's all for this notebook. Thank you for staying until the end ;)