<a href="https://colab.research.google.com/github/yhatpub/yhatpub/blob/main/notebooks/fastai/lesson6_pose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fastai Lesson 6 Regression on YHat.pub

This notebook picks up from [Fastai Fastbook 6 Regression](https://github.com/fastai/fastbook/blob/master/06_multicat.ipynb) and shows how to take the resulting model to [YHat.pub](https://yhat.pub).

To save your model, you'll need to save just the weights and biases of the model, the `pth` file for your learner. A really nice, easy-to-follow tutorial on `pth` files is [inference-with-fastai](https://benjaminwarner.dev/2021/10/01/inference-with-fastai).

This is because `load_learner` from lesson 6 relies on the serialized `get_ctr` method, which, when deserialized, needs to be on the `__main__` module. If that doesn't make sense, don't worry about it. Just follow the steps below and you'll be fine.
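
For the curious, the `__main__` point can be illustrated with a minimal standard-library sketch. It makes no claims about fastai's internals; `json.dumps` simply stands in for a function such as `get_ctr`, and the idea is only that pickle records *where* a function lives, not its code.

```python
import json
import pickle

# Pickle stores a function as a reference (module name + function name),
# not as the function's code. json.dumps stands in for get_ctr here.
payload = pickle.dumps(json.dumps)

# The payload only carries the names needed to look the function up again.
print(b"json" in payload, b"dumps" in payload)

# Unpickling imports the module and fetches the attribute by name, so it
# fails if that name is not defined where the pickle expects to find it.
restored = pickle.loads(payload)
print(restored is json.dumps)
```

A function defined in a notebook cell lives in `__main__`, so a pickle that refers to it can only be loaded where `__main__` defines the same name. Saving only the weights sidesteps this.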

In your lesson 6 notebook, after fine-tuning your learner, run the following to save and download your `pth` file:
```
learn.save('lesson_6_pose', with_opt=False)
from google.colab import files
files.download('models/lesson_6_pose.pth') 
```

### Installs
The following cell installs fastai (which pulls in PyTorch as a dependency) and yhat_params, which is used to decorate your `predict` function.

In [None]:
!pip install -q --upgrade --no-cache-dir fastai
!pip install -q --no-cache-dir git+https://github.com/yhatpub/yhat_params.git@main

Add the following, since matplotlib needs to know where to write its temp files.

In [None]:
import os
import tempfile
os.environ["MPLCONFIGDIR"] = tempfile.gettempdir()

### Imports
**Warning** Don't place `pip install` commands and `import` statements in the same cell. The imports might not work correctly if done that way.

In [None]:
from fastai.vision.all import *
from yhat_params.yhat_tools import FieldType, inference_predict

### Download Model
Google Drive does not allow direct downloads for files over 100MB, so you'll need to follow the snippet below to get the download URL.

In [None]:
# clean up from a previous download (-f avoids an error when there is nothing to remove)
!rm -f uc*

#file copied from google drive
google_drive_url = "https://drive.google.com/file/d/10tkEH4-e9mEsxlZlfA1Ta-ILxwtzzHFO/view?usp=sharing"
import os
os.environ['GOOGLE_FILE_ID'] = google_drive_url.split('/')[5]
os.environ['GDRIVE_URL'] = f'https://docs.google.com/uc?export=download&id={os.environ["GOOGLE_FILE_ID"]}'
!echo "This is the Google drive download url $GDRIVE_URL"
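
The cell above pulls the file ID out of the share link by position (`split('/')[5]`). As a slightly more defensive sketch, the same conversion can be written as a small function that fails loudly when the link has an unexpected shape. The helper name `drive_download_url` is made up for this illustration and is not part of any library.

```python
import re

def drive_download_url(share_url: str) -> str:
    """Turn a Google Drive share link into the direct-download form used above."""
    # Share links look like https://drive.google.com/file/d/<FILE_ID>/view?...
    match = re.search(r"/file/d/([^/?#]+)", share_url)
    if match is None:
        raise ValueError(f"No file ID found in {share_url!r}")
    return f"https://docs.google.com/uc?export=download&id={match.group(1)}"

share = "https://drive.google.com/file/d/10tkEH4-e9mEsxlZlfA1Ta-ILxwtzzHFO/view?usp=sharing"
print(drive_download_url(share))
```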

`wget` it from Google Drive. This script places the model in a `models` folder.

In [None]:
!wget -q --no-check-certificate $GDRIVE_URL -r -A 'uc*' -e robots=off -nd
!mkdir -p models
!mv $(ls -S uc* | head -1) ./models/export.pth

Verify the model exists. **Warning** YHat is pretty finicky about where you place your models. Make sure you create a `models` directory and download your model(s) there.

In [None]:
!ls -l models
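
If you prefer to check from Python rather than reading the `ls` output, a minimal sketch could look like the following. The helper `model_is_present` and its `min_bytes` threshold are hypothetical, and the path assumes fastai's default `models` directory relative to the notebook.

```python
from pathlib import Path

def model_is_present(path: str, min_bytes: int = 1) -> bool:
    """Return True if `path` is a regular file of at least `min_bytes` bytes."""
    p = Path(path)
    return p.is_file() and p.stat().st_size >= min_bytes

# Assumed location: the file written by the wget cell above
print(model_is_present("models/export.pth"))
```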

### Recreate dataloader and learner

Let's start by creating a dummy image for regression. This is going to be used for our dataloader. 

In [None]:
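from PIL import Image, ImageDraw
import os

# Create the data directory if needed, then a 1x1 placeholder image for the dataloader
os.makedirs('data', exist_ok=True)
if not os.path.exists('data/dummyimage.jpg'):
    img = Image.new('RGB', (1, 1))
    img.save('data/dummyimage.jpg')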

And now, we can make a lightweight `DataBlock`, passing in the single image and dummy regression value. 

In [None]:
dblock = DataBlock(
    blocks=(ImageBlock, PointBlock),
    get_x=ColReader('image'), 
    get_y=ColReader('point'),    
    batch_tfms=[*aug_transforms(size=(240,320)), 
                Normalize.from_stats(*imagenet_stats)])

df = pd.DataFrame(
     {
        'image': [
                  "data/dummyimage.jpg", 
                  ], 
        'point': [
                  np.array([1,1]), 
                  ], 
     },
    )
dls = dblock.dataloaders(df, bs=64)

### Load your learner
The following is the equivalent of PyTorch's `torch.load` or TensorFlow's `model.load_weights`.

In [None]:
learn_inf = cnn_learner(dls, resnet18, y_range=(-1,1), pretrained=False)
learn_inf.load('export')
learn_inf.model.eval();
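
The `y_range=(-1,1)` argument needs to match the one used in training because it changes the model itself: as far as we understand fastai, it appends a scaled sigmoid to the head so that outputs stay inside the given range. A rough plain-Python sketch of that final step, meant as an illustration of the formula rather than fastai's actual code:

```python
import math

def sigmoid_range(x: float, lo: float, hi: float) -> float:
    """Squash an unbounded activation into the open interval (lo, hi)."""
    return 1.0 / (1.0 + math.exp(-x)) * (hi - lo) + lo

# Large negative and positive activations approach -1 and 1 without reaching them
for raw in (-10.0, 0.0, 10.0):
    print(raw, round(sigmoid_range(raw, -1, 1), 4))
```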

Now write your predict function. Note that you will need to decorate your function with <a href="https://github.com/yhatpub/yhat_params">inference_predict</a>, which takes two parameters: a `dict` for input and a `dict` for output.

**Info** These parameters are how YHat.pub maps your predict function's input and output to the web interface. Each `dict` key is how you access the variable, and the value is its type. You can use autocomplete to see all the input/output types, and more documentation on `inference_predict` is available at the link.

In [None]:
input = {"image": FieldType.PIL}
output = {"text": FieldType.Text, "image": FieldType.PIL}

@inference_predict(input=input, output=output)
def predict(params):
    img = PILImage.create(np.array(params["image"].convert("RGB")))
    result = learn_inf.predict(img)
    x = float(result[0][0][0])
    y = float(result[0][0][1])

    input_image = params["image"]
    input_image = input_image.resize((320, 240))
    draw = ImageDraw.Draw(input_image)
    radius = 2
    draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=(0, 255, 0), outline=(0, 0, 0))

    return {"text": f"Positions ({x},{y})", "image":input_image}
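
The function above draws on a copy of the input resized to 320x240, treating the predicted coordinates as pixels in that space. If you would rather mark the original, full-resolution image, one option is to scale the point instead of the image. The helper below is a hypothetical sketch under that same assumption, not part of fastai or yhat_params.

```python
def rescale_point(x, y, from_size, to_size):
    """Map (x, y) from an image of from_size=(w, h) to one of to_size=(w, h)."""
    from_w, from_h = from_size
    to_w, to_h = to_size
    return x * to_w / from_w, y * to_h / from_h

# Hypothetical values: a point on the 320x240 image mapped onto a 1280x960 original
print(rescale_point(160.0, 120.0, (320, 240), (1280, 960)))
```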

### Test
First, import `in_colab` since you only want to run this test in colab. YHat will use this colab in a callable API, so you don't want your test to run every time `predict` is called. Next, import `inference_test` which is a function to make sure your `predict` will run with YHat.

Now, inside an `if in_colab():` block, first get whatever test data you'll need, in this case an image. Then call your predict function, wrapped inside `inference_test`, passing in the same params you defined above. If something is missing, you should see an informative error. Otherwise, you'll see something like
`Please take a look and verify the results`

In [None]:
from yhat_params.yhat_tools import in_colab, inference_test

if in_colab():
    import urllib.request
    from PIL import Image
    urllib.request.urlretrieve("https://c4.wallpaperflare.com/wallpaper/96/207/799/look-face-pose-background-wallpaper-preview.jpg", "input_image.jpg")
    img = Image.open("input_image.jpg")
    inference_test(predict_func=predict, params={'image': img})

### That's it

If you run into errors, feel free to hop into Discord.

Otherwise, you'll now want to clear your outputs and save the notebook to a public repo on GitHub.