Using Flask is the easiest way to serve a PyTorch model, but it is not suitable for use cases with high performance requirements. For those:

- See the *Loading a TorchScript Model in C++* tutorial.

### Define API endpoint
- Endpoint: `/predict`
- Request: HTTP POST with a `file` form field containing the image
- Response: JSON containing the prediction

### Simple Web Server

In [None]:
from flask import Flask

# The single Flask application object; every route in this notebook is
# registered on it.
app = Flask(__name__)


@app.route('/')
def hello():
    """Answer GET / with a fixed plain-text greeting (a liveness check)."""
    greeting = 'Hello World!'
    return greeting

In [None]:
@app.route('/dummy_predict', methods=['POST'])
def dummy_predict():
    """Placeholder POST endpoint: ignore the request and reply with text."""
    reply = 'Hello World!'
    return reply

In [None]:
# `jsonify` is not imported by the earlier cells (only `Flask` is), so it must
# be brought into scope here or the view raises NameError on first request.
from flask import jsonify


@app.route('/dummy_predict_2', methods=['POST'])
def dummy_predict_2():
    """Placeholder POST endpoint returning a hard-coded JSON prediction.

    The response has the same ``class_id`` / ``class_name`` keys that the
    real ``/predict`` endpoint returns, so a client can be written against
    it before the model is wired in.
    """
    return jsonify({
        'class_id': 'IMAGE_NET_XXX',
        'class_name': 'Cat'
    })

### Build an image transformation pipeline

In [None]:
import io

import torchvision.transforms as transforms
from PIL import Image

# Preprocessing pipeline, built once and reused for every image instead of
# being reconstructed on each call. The mean/std values are the standard
# per-channel ImageNet statistics used by torchvision's pretrained models.
_IMAGENET_TRANSFORMS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized, batched model input.

    Args:
        image_bytes: raw contents of an image file (any format PIL can open).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image resized,
        center-cropped to 224x224, normalized, and given a leading batch
        dimension.
    """
    image = Image.open(io.BytesIO(image_bytes))
    # Grayscale ('L'), palette ('P') or RGBA images would produce 1 or 4
    # channels and make the 3-channel Normalize step raise; force RGB.
    image = image.convert('RGB')
    return _IMAGENET_TRANSFORMS(image).unsqueeze(0)

In [None]:
# test transform_image()
with open('cat.jpg', 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)
    print(tensor.shape)

### Prediction

In [None]:
import json

import torch
from torchvision import models

# Lookup from str(class index) to a [class_id, class_name] pair. Read with a
# context manager so the file handle is closed instead of leaked.
with open('imagenet_class_index.json') as class_index_file:
    imagenet_class_index = json.load(class_index_file)

# Keep the model loaded in memory just once before serving the requests.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.DenseNet121_Weights.IMAGENET1K_V1`.
model = models.densenet121(pretrained=True)

# Put BatchNorm/Dropout layers into inference behaviour.
model.eval()


def get_prediction(image_bytes):
    """Classify one encoded image and return its top-1 class entry.

    Args:
        image_bytes: raw contents of an image file.

    Returns:
        The ``imagenet_class_index`` entry for the highest-scoring class,
        a ``[class_id, class_name]`` pair such as
        ``['n02123159', 'tiger_cat']``.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # No gradients are needed at inference time; skipping autograd saves
    # memory and time. Calling the module (not .forward) runs its hooks.
    with torch.no_grad():
        outputs = model(tensor)
    # Arg-max over the class dimension of the single-image batch.
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

In [None]:
# test get_prediction()
with open('cat.jpg', 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

### Update the predict method to read the file from the request

In [None]:
# `jsonify` is also needed here; no earlier notebook cell imports it.
from flask import jsonify, request


# Completed predict endpoint
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``file`` form field.

    The route accepts only POST, so no method check is needed. A request
    without a ``file`` field makes Flask answer 400 Bad Request.

    Returns:
        A JSON response ``{"class_id": ..., "class_name": ...}``.
    """
    file = request.files['file']
    img_bytes = file.read()
    # Classify the bytes uploaded in *this* request, not the notebook-level
    # `image_bytes` global left over from the test cells.
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({
        'class_id': class_id,
        'class_name': class_name
    })

### Putting everything together

To test, put the following code into `app.py` and run:

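`FLASK_ENV=development FLASK_APP=app.py flask run -p 5002`

On Flask 2.2 and later, `FLASK_ENV` is deprecated (it was removed in 2.3); use `flask --app app.py run --debug -p 5002` instead, or simply run `python app.py`.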

In [None]:
"""Minimal Flask server exposing a DenseNet-121 ImageNet classifier.

POST an image as the ``file`` form field to ``/predict`` and receive
``{"class_id": ..., "class_name": ...}`` as JSON.
"""
import io
import json

import torch
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)

# Lookup from str(class index) to a [class_id, class_name] pair; the file
# handle is closed as soon as it has been parsed.
with open('imagenet_class_index.json') as class_index_file:
    imagenet_class_index = json.load(class_index_file)

# Load the model once at start-up rather than per request, and put its
# BatchNorm/Dropout layers into inference behaviour.
model = models.densenet121(pretrained=True)
model.eval()

# Preprocessing pipeline, built once and shared by all requests. The
# mean/std values are the standard per-channel ImageNet statistics.
_IMAGENET_TRANSFORMS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])


def transform_image(image_bytes):
    """Decode image bytes into a normalized ``(1, 3, 224, 224)`` tensor."""
    image = Image.open(io.BytesIO(image_bytes))
    # Grayscale/palette/RGBA uploads would not have 3 channels and would
    # make Normalize raise; force RGB.
    image = image.convert('RGB')
    return _IMAGENET_TRANSFORMS(image).unsqueeze(0)


def get_prediction(image_bytes):
    """Return the ``[class_id, class_name]`` entry of the top-1 class."""
    tensor = transform_image(image_bytes)
    # Inference only: skip autograd bookkeeping. Calling the module (not
    # .forward) also runs any registered hooks.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the uploaded ``file`` and return the result as JSON.

    Only POST reaches this view, so every path returns a response; a
    missing ``file`` field makes Flask answer 400 Bad Request.
    """
    file = request.files['file']
    img_bytes = file.read()
    class_id, class_name = get_prediction(img_bytes)
    return jsonify({
        'class_id': class_id,
        'class_name': class_name
    })


if __name__ == '__main__':
    # debug=True enables the interactive Werkzeug debugger, which allows
    # arbitrary code execution; keep the server bound to localhost.
    app.run(
        host='localhost',
        port=5002,
        debug=True
    )

### Send a POST request to the deployed API endpoint

Note: start the Flask app from the previous cell in a terminal first.

In [1]:
import requests

# Upload the sample image as the multipart form field `file`. The `with`
# block closes the file handle once the request has been sent, and the
# timeout keeps the cell from hanging forever if the server is not running.
with open('cat.jpg', 'rb') as image_file:
    resp = requests.post(
        'http://localhost:5002/predict',
        files={
            'file': image_file
        },
        timeout=60,
    )

In [3]:
# Decode the JSON body returned by /predict; as the cell's last expression
# its value is displayed below.
resp.json()

{'class_id': 'n02123159', 'class_name': 'tiger_cat'}