# NOTE
This notebook, like all others involving Docker, cannot be run on Colab. Run it either on your local machine or on your Vertex AI Workbench instance after cloning the [til-24-curriculum repository from GitHub](https://github.com/TIL-24/til-24-curriculum/).

# Deploying a Pre-Trained Model with FastAPI and Docker on GCP

This tutorial demonstrates how to expose a pre-trained PyTorch model as a REST API using FastAPI, containerize the application with Docker on Google Cloud, and test the endpoint using Python's `requests` library.

### Step 1: Prepare Your PyTorch Model

First, make sure you have your PyTorch model ready. For this example, we'll use ResNet-18, the 18-layer variant of ResNet that ships with `torchvision`; its pre-trained weights are downloaded automatically the first time the model is loaded. ResNet, short for Residual Network, is a convolutional neural network (CNN) architecture. The pre-trained weights were learned on ImageNet-1k (the ILSVRC 2012 classification subset of ImageNet), which has about 1.28 million training images across 1,000 classes; the full ImageNet database contains over 14 million images. Given an image, the model outputs one score for each of these 1,000 classes, and the class with the highest score is its prediction.
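Turning the model's 1,000 scores (logits) into a prediction is a simple argmax, which is what `output.argmax(1).item()` does in the app below. The pure-Python sketch here illustrates that step on a toy 4-class score list and also shows how a softmax would turn the scores into a confidence value. The helper names `softmax` and `predict` and the toy scores are illustrative only; they are not part of PyTorch or torchvision.

```python
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the max avoids overflow in exp() and does not change the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return (index of the highest score, its softmax probability)."""
    probs = softmax(logits)
    best = max(range(len(logits)), key=lambda i: logits[i])  # argmax
    return best, probs[best]


# Toy example with 4 classes instead of ResNet's 1,000.
scores = [1.0, 3.0, 0.5, 2.0]
index, confidence = predict(scores)
print(index, round(confidence, 3))  # prints: 1 0.631
```

Softmax never changes which class wins; it only rescales the scores, so the app can skip it when it returns just the index.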


### Step 2: Create the FastAPI Application

Create a new Python script named `app.py` that will serve as our FastAPI application. This application will handle API requests and return predictions from the PyTorch model.

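```python
import io

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models, transforms

app = FastAPI()

# Load the pre-trained ResNet-18 once, at startup (weights are downloaded on first use).
# On torchvision >= 0.13, `pretrained=True` is deprecated in favour of
# `weights=models.ResNet18_Weights.DEFAULT`.
model = models.resnet18(pretrained=True)
model.eval()  # Evaluation mode: BatchNorm uses its running statistics

# Standard ImageNet preprocessing: resize, center-crop to 224x224, scale to [0, 1], normalize
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    image_data = await file.read()
    # Convert to RGB so grayscale or RGBA uploads also yield 3 channels
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension: (3, 224, 224) -> (1, 3, 224, 224)

    with torch.no_grad():  # Inference only: skip gradient tracking
        output = model(image)  # Shape (1, 1000): one score per ImageNet class
        predicted_index = output.argmax(1).item()

    return JSONResponse(content={"predicted_class": predicted_index})
```
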
### Step 3: Dockerize the Application
Create a file named `Dockerfile` (no extension) in the same directory as your `app.py`:

```Dockerfile
# Use an official PyTorch runtime as a parent image
FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime

# Install pip packages
RUN pip install fastapi uvicorn Pillow torchvision python-multipart

# Set the working directory
WORKDIR /app

# Copy the local directory contents to the container
COPY . /app

# Command to run the app using uvicorn
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
```

### Step 4: Build and Run the Docker Container
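Open a terminal window in JupyterLab, change into the directory that contains `app.py` and the `Dockerfile`, and run the following command to build the Docker image. Wait for the build to finish: the classic builder ends with a `Successfully tagged demo-model:latest` message, while newer Docker versions that use BuildKit print a different final message.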

<img src="https://lh3.googleusercontent.com/d/18onL8nKzNOQ7yJwsXhtN-Azc2vudc7dJ" alt="drawing" width="600"/>

```bash
docker build -t demo-model .
```

<img src="https://lh3.googleusercontent.com/d/1NU9DAnhWysP9wh6-mOuj0F5ZzSdUl-Jj" alt="drawing" width="800"/>

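Then start a container from the image. The `-p 8000:8000` flag maps port 8000 on the host to port 8000 inside the container, where uvicorn is listening: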

```bash
docker run -p 8000:8000 demo-model
```
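Uvicorn typically logs a line such as `Uvicorn running on http://0.0.0.0:8000` once the API is up. Because the model weights are not baked into the image, each fresh container downloads them when `app.py` is imported, so there can be a short delay before requests succeed. If you script the test, a small polling helper like the hypothetical `wait_until_ready` below (standard library only, a sketch rather than part of the tutorial's code) can wait for the server. It assumes you poll FastAPI's auto-generated schema at `/openapi.json`, which FastAPI serves by default, and treats any connection error or non-200 response as "not ready yet".

```python
import http.client
import time
import urllib.request


def wait_until_ready(url, timeout=120.0, interval=2.0):
    """Poll `url` until it answers with HTTP 200; return False if `timeout` elapses first."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            # Connection refused/reset, timeout, or HTTPError (non-2xx): not ready yet.
            pass
        time.sleep(interval)
    return False
```

For example, `wait_until_ready("http://localhost:8000/openapi.json")` returns `True` as soon as the container answers. Polling an HTTP endpoint rather than just opening a TCP socket matters here, because Docker's port mapping may accept a connection on the host before the app inside the container is listening.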


<img src="https://lh3.googleusercontent.com/d/1_efTzPvazat3b1WU1Pmueplqj5NwGLRe" alt="drawing" width="1000"/>

### Step 5: Test the API Using Python Requests
With the Docker container running, create a Python notebook `test_api.ipynb` to send an image to the FastAPI application and receive a prediction:
```python
import os

import requests

# URL of the FastAPI endpoint
url = 'http://localhost:8000/predict/'

# Path to the image file
file_path = '/home/jupyter/imgs/cat.jpeg'

# Open the image file in binary mode; the request must be sent while it is still open
with open(file_path, 'rb') as f:
    # Multipart form field "file" (matches the `file` parameter of predict()):
    # (filename, file object, content type)
    files = {'file': (os.path.basename(file_path), f, 'image/jpeg')}

    # Send the POST request
    response = requests.post(url, files=files)

# Raise an error for 4xx/5xx responses, then print the JSON body
response.raise_for_status()
print(response.json())
```
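Passing `files=...` makes `requests` encode the upload as a `multipart/form-data` body, which is the format FastAPI's `UploadFile` parameter expects (parsed via `python-multipart`). If `requests` is not available, a rough standard-library equivalent is sketched below. The helpers `build_multipart` and `predict_file` are hypothetical and deliberately simplified: they send a single file field and do not escape quotes in the filename.

```python
import json
import os
import urllib.request
import uuid


def build_multipart(field_name, filename, data, content_type):
    """Encode one file as a multipart/form-data body; return (body, Content-Type header)."""
    boundary = uuid.uuid4().hex  # random delimiter that should not occur in the file data
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()  # closing delimiter ends the body
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def predict_file(url, file_path):
    """POST an image to the /predict/ endpoint and return the decoded JSON reply."""
    with open(file_path, "rb") as f:
        body, ctype = build_multipart("file", os.path.basename(file_path), f.read(), "image/jpeg")
    req = urllib.request.Request(url, data=body, headers={"Content-Type": ctype}, method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
```

With the container running, `predict_file("http://localhost:8000/predict/", "/home/jupyter/imgs/cat.jpeg")` should return the same dictionary as the `requests` version.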

<html>
<body>

<p>
  <img src="https://lh3.googleusercontent.com/d/1X_uWg8ADU3kEgxe-BKsEvAu3Oy54VJHU" alt="drawing" width="400"/>
  <img src="https://lh3.googleusercontent.com/d/19mLZnbfLTbGiCxZt_U5Vo8x3iIOIm9D-" alt="drawing" width="500"/>
</p>

</body>
</html>


To interpret the predicted class from a ResNet model trained on ImageNet, refer to the [ImageNet class index](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/), which maps each of the 1,000 class indices (0 to 999) to a human-readable label. The integer returned by the API is the position of the predicted class in that list.

For example, if the model outputs a prediction of `281`, looking up this index in the ImageNet class list shows that it corresponds to the class "tabby, tabby cat".
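The lookup can also be done in code. The sketch below assumes a hypothetical text file `imagenet_classes.txt` that you have saved yourself, containing one label per line in index order (line 0 is class 0, and so on); the helper names `load_labels` and `label_for` are illustrative. In recent torchvision releases (0.13 and later), the label list is also available as `models.ResNet18_Weights.DEFAULT.meta["categories"]`, although its wording may differ slightly from the linked class map.

```python
from pathlib import Path


def load_labels(path):
    """Read a labels file with one class name per line; line i is the label of class i."""
    # No blank-line filtering: dropping lines would shift every later index.
    return [line.strip() for line in Path(path).read_text(encoding="utf-8").splitlines()]


def label_for(index, labels):
    """Map the integer returned by /predict/ to its human-readable label."""
    if not 0 <= index < len(labels):
        raise IndexError(f"class index {index} outside 0..{len(labels) - 1}")
    return labels[index]
```

With the `response` from the test above, `label_for(response.json()["predicted_class"], load_labels("imagenet_classes.txt"))` would return the label for the predicted index.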

<img src="https://lh3.googleusercontent.com/d/1bhhPWBdxEvP-qKGJpA9m-iWFGOyhufcS" alt="drawing" width="550"/>

<html>
<body>

<p>
  <img src="https://lh3.googleusercontent.com/d/1VeGaNxgHgf65nSHHZ4FMLAvi16zEGCck" alt="drawing" width="400"/>
  <img src="https://lh3.googleusercontent.com/d/1YF8B1V6k4xaRm94N9u7r8L4YWshaa0Q_" alt="drawing" width="600"/>
  <img src="https://lh3.googleusercontent.com/d/1G8M9dmiXcX1U-rd7iA24T5wEQDmnsEkc" alt="drawing" width="400"/>
</p>

</body>
</html>
