<a href="https://colab.research.google.com/github/basetenlabs/demos/blob/main/Deploying_a_FastAI_Model_with_Baseten.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

![Baseten](https://assets.website-files.com/624535121db2930bcd043f5d/62453d9bddc3de287134cb76_baseten-logo.svg)

# Install dependencies

In [None]:
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab
!pip install -Uqq fastai fastbook
!pip install torchvision
!pip install baseten

# Write the Baseten-compatible model class (with `load` and `predict` methods) and its requirements file

In [None]:
from fastai.data.external import URLs, untar_data
from fastai.vision.data import ImageDataLoaders
from fastai.vision.learner import cnn_learner, error_rate
from fastai.vision.augment import Resize
from fastai.data.transforms import get_image_files
from torchvision.models import resnet34

# Source of the serving module, written to ./fai_model.py below. It is handed to
# Baseten via `model_files` with `model_class="FastAiModel"`; the training cell
# also imports `label_func` from it, so the pickled model refers to that function
# as `fai_model.label_func`.
FASTAI_MODEL_CODE = '''import io

import joblib
import requests
from numpy import asarray
from PIL import Image

# Seconds to wait for an image host before failing the request.
REQUEST_TIMEOUT_SECONDS = 10


def label_func(f):
    """Return True when the file name `f` starts with an uppercase letter.

    Used as the labelling function during training. NOTE(review): in the
    Oxford-IIIT Pets data this is assumed to separate cat breeds (capitalised
    file names) from dog breeds -- confirm against the dataset docs.
    Keep the name and module stable: the pickled model refers to it.
    """
    return f[0].isupper()


class FastAiModel:
    """Baseten custom model wrapping a pickled fastai Learner."""

    def __init__(self):
        self._model = None

    def load(self):
        """Unpickle the Learner from model/model.joblib.

        NOTE(review): presumably where Baseten places the uploaded
        `model.joblib` -- confirm.
        """
        self._model = joblib.load("model/model.joblib")

    def predict(self, inputs):
        """Classify the images referenced by `inputs`.

        Args:
            inputs: list of dicts, each with a "url" key pointing at an image.

        Returns:
            One [label, index, probabilities] list per input, in input order
            (presumably the decoded label, class index and class probabilities
            returned by fastai's Learner.predict -- confirm).
        """
        image_urls = [inp["url"] for inp in inputs]
        images = [self._fetch_image_url(url) for url in image_urls]
        predictions = [self._model.predict(image) for image in images]
        return [self._clean_prediction(pred) for pred in predictions]

    def _fetch_image_url(self, url):
        """Download the image at `url` and return it as an RGB numpy array."""
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        # Surface HTTP errors directly instead of failing later in Image.open.
        response.raise_for_status()
        # Decode from `response.content`: the raw stream bypasses requests'
        # Content-Encoding handling (e.g. gzip), and reading the full body also
        # releases the connection back to the pool.
        with Image.open(io.BytesIO(response.content)) as img:
            # Normalise palette/greyscale/RGBA images to 3 channels, matching
            # the RGB input the resnet34 backbone expects.
            return asarray(img.convert("RGB"))

    def _clean_prediction(self, prediction):
        """Convert the tensor fields to plain Python lists.

        Presumably so the response can be JSON-encoded by the serving layer.
        """
        return [prediction[0], prediction[1].tolist(), prediction[2].tolist()]
'''

# Pinned packages for the serving environment, one requirement per line, with a
# leading and trailing newline.
# NOTE(review): fai_model.py also imports `requests` and `numpy`; they are not
# listed here and presumably arrive as fastai dependencies -- confirm or pin them.
REQUIREMENTS_TXT = "\n".join([
    "",
    "fastai==2.5.2",
    "joblib==1.0.1",
    "Pillow==8.3.2",
    "",
])

# Materialise the serving module and its requirements next to the notebook, so
# the training cell can import `fai_model` and the deploy cell can upload both.
for target_path, file_text in (
    ("./fai_model.py", FASTAI_MODEL_CODE),
    ("./requirements.txt", REQUIREMENTS_TXT),
):
    with open(target_path, "w") as out_file:
        out_file.write(file_text)

# Train and store the FastAI model

In [None]:
import joblib

from fastai.data.external import URLs, untar_data
from fastai.data.transforms import get_image_files
from fastai.learner import load_learner
from fastai.vision.augment import Resize
from fastai.vision.data import ImageDataLoaders
from fastai.vision.learner import cnn_learner, error_rate
from torchvision.models import resnet34

from fai_model import label_func

def create_model(epochs=1, seed=42):
    """Build and fine-tune a resnet34 classifier on the Oxford-IIIT Pets images.

    Images are labelled with `label_func` (True when the file name starts with
    an uppercase letter) and resized with `Resize(224)`.

    Args:
        epochs: number of `fine_tune` epochs. Without training, the Learner's
            newly created classification head keeps its random initialisation,
            so its predictions are meaningless. Pass 0 to skip training.
        seed: seed for the random train/validation split, so re-runs use the
            same split.

    Returns:
        The (trained) fastai Learner.
    """
    path = untar_data(URLs.PETS)
    files = get_image_files(path / "images")
    dls = ImageDataLoaders.from_name_func(
        path, files, label_func, valid_pct=0.2, seed=seed, item_tfms=Resize(224)
    )
    learn = cnn_learner(dls, resnet34, metrics=error_rate)
    if epochs:
        learn.fine_tune(epochs)
    return learn

pet_model = create_model()

# Pickle an inference-only, CPU-resident copy of the Learner. If training ran on
# a GPU, pickling `pet_model` directly would embed CUDA tensors (and a
# CUDA-bound DataLoaders), which cannot be unpickled on a machine without a GPU.
# Per fastai's docs, `export` drops the training items and optimizer state, and
# `load_learner(..., cpu=True)` maps the tensors and DataLoaders to the CPU.
# NOTE(review): assumes the serving environment may be CPU-only.
pet_model.export("export.pkl")  # written to pet_model.path / "export.pkl"
inference_model = load_learner(pet_model.path / "export.pkl", cpu=True)

with open("model.joblib", "wb") as model_file:
    joblib.dump(inference_model, model_file)

# Deploy the model to Baseten

In [None]:
import os
from getpass import getpass

import baseten

# Read the API key from the environment, or prompt for it, instead of pasting it
# into the notebook, where it would be saved with the file.
# Keys: https://docs.baseten.co/applications/overview/api-keys
baseten.login(os.environ.get("BASETEN_API_KEY") or getpass("Baseten API key: "))
baseten.deploy_custom(
    model_name="FastAI demo",
    model_class="FastAiModel",
    model_files=["fai_model.py", "model.joblib"],
    requirements_file="requirements.txt",
)