# Develop Model Driver

In this notebook, we develop the API that calls our model. The module initializes the model, transforms the input into the appropriate format, and defines the scoring method that produces the predictions. The API expects its input in JSON format. Once a request is received, the API converts the JSON-encoded request body into an image. There are two main functions in the API. The first, create_scoring_func, loads the model and returns a scoring function. The second, get_model_api, returns a function that decodes the images and uses the scoring function to score them.
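
As a rough illustration of the decoding step described above, the sketch below round-trips raw bytes through a base64-encoded JSON body using only the standard library. The "input" and "image" keys mirror what this notebook passes to the driver further down. The helper names encode_request and decode_request and the placeholder payload are assumptions made for illustration and are not part of driver.py.

```python
import base64
import json


def encode_request(raw_bytes, key="image"):
    """Wrap raw bytes in a JSON body of the form {"input": {key: <base64 text>}}."""
    b64_text = base64.b64encode(raw_bytes).decode("utf-8")
    return json.dumps({"input": {key: b64_text}})


def decode_request(json_body):
    """Return a dict mapping each key to the decoded raw bytes."""
    images = json.loads(json_body)["input"]
    decoded = {}
    for key, b64_text in images.items():
        # Some clients send str(bytes), e.g. "b'...'"; strip that wrapper first.
        if b64_text.startswith("b'"):
            b64_text = b64_text[2:-1]
        decoded[key] = base64.b64decode(b64_text)
    return decoded


# Hypothetical payload standing in for the bytes of an image file.
payload = b"placeholder bytes standing in for an image file"
body = encode_request(payload)
round_tripped = decode_request(body)
print(round_tripped["image"] == payload)  # True
```

In the real driver the decoded bytes are then opened with PIL rather than compared, but the JSON and base64 handling follows the same shape.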

In [1]:
import logging
from testing_utilities import img_url_to_json
from pprint import pprint

In [2]:
logging.basicConfig(level=logging.DEBUG)

We use the writefile magic to write the contents of the cell below to driver.py, which contains the driver methods.

In [3]:
%%writefile driver.py 
import base64
import json
import logging
import os
import timeit as t
from io import BytesIO

import PIL
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import models, transforms



_LABEL_FILE = os.getenv("LABEL_FILE", "synset.txt")
_NUMBER_RESULTS = 3


def _create_label_lookup(label_path):
    with open(label_path, "r") as f:
        label_list = [l.rstrip() for l in f]

    def _label_lookup(*label_locks):
        return [label_list[l] for l in label_locks]

    return _label_lookup


def _load_model():
    # Load the pretrained ResNet-152, move it to the GPU, and switch to inference mode
    model = models.resnet152(pretrained=True)
    model = model.cuda()
    softmax = nn.Softmax(dim=1).cuda()
    model = model.eval()

    preprocess_input = transforms.Compose(
        [
            torchvision.transforms.Resize((224, 224), interpolation=PIL.Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    def predict_for(image):
        image = preprocess_input(image)
        with torch.no_grad():
            image = image.unsqueeze(0)
            image_gpu = image.type(torch.float).cuda()
            outputs = model(image_gpu)
            pred_proba = softmax(outputs)
        return pred_proba.cpu().numpy().squeeze()

    return predict_for


def _base64img_to_pil_image(base64_img_string):
    # Strip the b'...' wrapper left behind when bytes were converted with str()
    if base64_img_string.startswith("b'"):
        base64_img_string = base64_img_string[2:-1]
    base64Img = base64_img_string.encode("utf-8")

    # Decode the base64 payload into raw image bytes
    decoded_img = base64.b64decode(base64Img)
    img_buffer = BytesIO(decoded_img)

    # Load image with PIL (RGB)
    pil_img = Image.open(img_buffer).convert("RGB")
    return pil_img


def create_scoring_func(label_path=_LABEL_FILE):
    logger = logging.getLogger("model_driver")

    start = t.default_timer()
    labels_for = _create_label_lookup(label_path)
    predict_for = _load_model()
    end = t.default_timer()

    loadTimeMsg = "Model loading time: {0} ms".format(round((end - start) * 1000, 2))
    logger.info(loadTimeMsg)

    def call_model(image, number_results=_NUMBER_RESULTS):
        pred_proba = predict_for(image).squeeze()
        selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]
        labels = labels_for(*selected_results)
        return list(zip(labels, pred_proba[selected_results].astype(np.float64)))

    return call_model


def get_model_api():
    logger = logging.getLogger("model_driver")
    scoring_func = create_scoring_func()

    def process_and_score(images_dict, number_results=_NUMBER_RESULTS):
        start = t.default_timer()

        results = {}
        for key, base64_img_string in images_dict.items():
            rgb_image = _base64img_to_pil_image(base64_img_string)
            results[key] = scoring_func(rgb_image, number_results=number_results)

        end = t.default_timer()

        logger.info("Predictions: {0}".format(results))
        logger.info("Predictions took {0} ms".format(round((end - start) * 1000, 2)))
        return (results, "Computed in {0} ms".format(round((end - start) * 1000, 2)))

    return process_and_score


def version():
    return torch.__version__

Overwriting driver.py
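
The driver builds its scoring function as a closure: the expensive setup runs once, and the returned function reuses it on every call. The sketch below shows that pattern and the top-N selection in pure Python, assuming a stub in place of the ResNet model. The names make_scorer and stub_predict and the three-label list are hypothetical and chosen only for illustration.

```python
def make_scorer(labels, predict):
    """Return a function that maps an input to its top-N (label, score) pairs."""

    def score(item, number_results=3):
        scores = predict(item)
        # Indices sorted by descending score, truncated to the top N.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        top = ranked[:number_results]
        return [(labels[i], float(scores[i])) for i in top]

    return score


def stub_predict(_item):
    # Stand-in for the network: a fixed probability vector, independent of the input.
    return [0.05, 0.70, 0.25]


score = make_scorer(["cat", "lynx", "leopard"], stub_predict)
print(score("any input", number_results=2))  # [('lynx', 0.7), ('leopard', 0.25)]
```

Because the labels and the prediction function are captured when make_scorer is called, the caller never has to pass them again, which is the same reason driver.py loads the model inside create_scoring_func rather than on each request.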


Let's test the module.

We run driver.py, which brings everything it defines into the notebook's namespace.

In [4]:
%run driver.py

We will use the same lynx image we used earlier to check that our driver works as expected.

In [5]:
IMAGEURL = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg"

In [6]:
predict_for = get_model_api()

INFO:model_driver:Model loading time: 3961.17 ms


In [7]:
jsonimg = img_url_to_json(IMAGEURL)
json_load_img = json.loads(jsonimg)
body = json_load_img["input"]
resp = predict_for(body)

DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292
DEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'
DEBUG:PIL.PngImagePlugin:Compression method 0
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536
INFO:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9965722560882568), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0013256857637315989), ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}
INFO:model_driver:Predictions took 93.72 ms


In [8]:
pprint(resp[0])

{'image': [('n02127052 lynx, catamount', 0.9965722560882568),
           ('n02128757 snow leopard, ounce, Panthera uncia',
            0.0013256857637315989),
           ('n02128385 leopard, Panthera pardus', 0.0009192737634293735)]}


Next, we can move on to [building our Docker image](02_BuildImage.ipynb).