# Develop Model Driver

In this notebook, we develop the API that calls our model. The module initializes the model, transforms the input into the format the model expects, and defines the scoring method that produces the predictions. The API expects its input in JSON format: once a request is received, the API decodes the base64-encoded images in the JSON request body back into images. There are two main functions in the API. The first loads the model and returns a scoring function. The second processes the images and uses the scoring function to score them.
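The decoding step can be sketched with the standard library alone. The function name `decode_request_body` is hypothetical; the `{"input": {...}}` layout and the handling of a `"b'...'"` wrapper follow `run` and `_base64img_to_pil_image` in score.py below. This sketch stops at raw bytes, before PIL turns them into an image.

```python
import base64
import json


def decode_request_body(raw_data):
    """Map each key under "input" to the raw image bytes it encodes."""
    images = json.loads(raw_data)["input"]
    decoded = {}
    for key, b64_string in images.items():
        # A client that sent str(some_bytes) produces "b'...'"; strip that
        # wrapper, as score.py does, before decoding.
        if b64_string.startswith("b'"):
            b64_string = b64_string[2:-1]
        decoded[key] = base64.b64decode(b64_string.encode("utf-8"))
    return decoded


# Usage: fake "image" bytes are enough to show the round trip.
payload_bytes = b"\x89PNG fake image bytes"
body = json.dumps({"input": {"image": base64.b64encode(payload_bytes).decode("utf-8")}})
print(decode_request_body(body)["image"] == payload_bytes)  # True
```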

In [49]:
import logging
from testing_utilities import img_url_to_json
from pprint import pprint
from azureml.core.webservice import Webservice, AksWebservice

In [50]:
logging.basicConfig(level=logging.DEBUG)

We use the writefile magic to write the contents of the cell below to score.py, which contains the driver functions. The file must define two functions, ```init``` and ```run```, which together form the contract with the Flask web application: ```init``` is called once to load the model, and ```run``` is called for each request. For another example, see
https://docs.microsoft.com/en-us/azure/machine-learning/service/tutorial-deploy-models-with-aml
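The shape of that contract can be sketched with a dummy scorer in place of the ResNet model: `init` does the one-time, expensive setup and stores the result in a module-level global, and `run` only parses the request and calls it. The `_scorer` name and the length-based "scoring" are purely illustrative assumptions.

```python
import json

# Module-level state filled in once by init(); hypothetical name.
_scorer = None


def init():
    """Called once at startup: load expensive resources here."""
    global _scorer

    # Stand-in for a real model: "scores" each input by its string length.
    def _score(images_dict):
        return {key: len(value) for key, value in images_dict.items()}

    _scorer = _score


def run(raw_data):
    """Called for each request with the raw JSON body as a string."""
    return _scorer(json.loads(raw_data)["input"])


init()
print(run(json.dumps({"input": {"image": "abcd"}})))  # {'image': 4}
```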

In [62]:
%%writefile score.py 
import base64
import json
import logging
import os
import timeit as t
from io import BytesIO

import PIL
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import models, transforms
import sys
from azureml.core.model import Model
from glob import glob
import warnings

logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) # TODO: remove


_LABEL_FILE = os.getenv("LABEL_FILE", "synset.txt")
_MODEL_NAME = os.getenv("MODEL_NAME", "pytorch_resnet152")
_NUMBER_RESULTS = 3


class ModelFileNotFoundError(Exception):
    pass

def _create_label_lookup(label_path):
    with open(label_path, "r") as f:
        label_list = [l.rstrip() for l in f]

    def _label_lookup(*label_indices):
        return [label_list[i] for i in label_indices]

    return _label_lookup


def _load_model():
    logger = logging.getLogger("model_driver")
    # Load the model
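    # Resolve where the registered model lives (a local folder when run outside a deployment)
    model_path = Model.get_model_path(_MODEL_NAME)

    # Search the resolved path, not a folder assumed to be named after the model
    file_list = glob(os.path.join(model_path,'*.pth'))
    if len(file_list)==0:
        raise ModelFileNotFoundError(f'Appropriate model not found in {model_path}')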
    elif len(file_list)>1:
        warnings.warn("More than one model found. Selecting first model")

    filename = file_list[0]
    logger.debug(f'Loading {filename}')
    
    # ResNet 152
    model = models.ResNet(models.resnet.Bottleneck, [3, 8, 36, 3])
    model.load_state_dict(torch.load(filename))
    
    model = model.cuda()
    softmax = nn.Softmax(dim=1).cuda()
    model = model.eval()

    preprocess_input = transforms.Compose(
        [
            torchvision.transforms.Resize((224, 224), interpolation=PIL.Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    def predict_for(image):
        image = preprocess_input(image)
        with torch.no_grad():
            image = image.unsqueeze(0)
            image_gpu = image.type(torch.float).cuda()
            outputs = model(image_gpu)
            pred_proba = softmax(outputs)
        return pred_proba.cpu().numpy().squeeze()

    return predict_for


def _base64img_to_pil_image(base64_img_string):
    if base64_img_string.startswith("b'"):
        base64_img_string = base64_img_string[2:-1]
    base64Img = base64_img_string.encode("utf-8")

    # Preprocess the input data
    decoded_img = base64.b64decode(base64Img)
    img_buffer = BytesIO(decoded_img)

    # Load image with PIL (RGB)
    pil_img = Image.open(img_buffer).convert("RGB")
    return pil_img


def create_scoring_func(label_path=_LABEL_FILE):
    logger = logging.getLogger("model_driver")

    start = t.default_timer()
    labels_for = _create_label_lookup(label_path)
    predict_for = _load_model()
    end = t.default_timer()

    loadTimeMsg = "Model loading time: {0} ms".format(round((end - start) * 1000, 2))
    logger.info(loadTimeMsg)

    def _call_model(image, number_results=_NUMBER_RESULTS):
        pred_proba = predict_for(image).squeeze()
        selected_results = np.flip(np.argsort(pred_proba), 0)[:number_results]
        labels = labels_for(*selected_results)
        return list(zip(labels, pred_proba[selected_results].astype(np.float64)))

    return _call_model


def get_model_api():
    logger = logging.getLogger("model_driver")
    scoring_func = create_scoring_func()

    def _process_and_score(images_dict, number_results=_NUMBER_RESULTS):
        start = t.default_timer()

        results = {}
        for key, base64_img_string in images_dict.items():
            rgb_image = _base64img_to_pil_image(base64_img_string)
            results[key] = scoring_func(rgb_image, number_results=number_results)

        end = t.default_timer()

        logger.debug("Predictions: {0}".format(results))
        logger.info("Predictions took {0} ms".format(round((end - start) * 1000, 2)))
        return (results, "Computed in {0} ms".format(round((end - start) * 1000, 2)))

    return _process_and_score

def version():
    return torch.__version__

def init():
    """ Initialise the model and scoring function
    """
    global process_and_score
    process_and_score = get_model_api()

def run(raw_data):
    """ Make a prediction based on the data passed in using the preloaded model
    """
    return process_and_score(json.loads(raw_data)['input'])

Overwriting score.py
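The top-N step in `_call_model` relies on `np.argsort` returning indices in ascending order of probability; reversing that array puts the most likely class first. A standalone sketch of the same selection, with made-up probabilities and labels rather than real model output (`top_k` is a hypothetical name):

```python
import numpy as np


def top_k(pred_proba, labels, k=3):
    """Return the k (label, probability) pairs with the highest probabilities."""
    # np.argsort sorts ascending; flipping gives indices from most to least likely.
    top_idx = np.flip(np.argsort(pred_proba), 0)[:k]
    top_labels = [labels[i] for i in top_idx]
    # Mirror score.py: cast the selected probabilities to float64.
    return list(zip(top_labels, pred_proba[top_idx].astype(np.float64)))


# Made-up probabilities and labels, not real model output.
probs = np.array([0.10, 0.60, 0.05, 0.25], dtype=np.float32)
labels = ["cat", "lynx", "dog", "leopard"]
result = top_k(probs, labels, k=2)
print([name for name, _ in result])  # ['lynx', 'leopard']
```

For the 1,000 ImageNet classes a full sort is cheap; for much larger label sets, `np.argpartition` could select the top k without sorting everything.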


Let's test the module.

We run score.py, which brings its functions into the notebook's namespace.

In [64]:
%run score.py

We will use the same lynx image we used earlier to check that our driver works as expected.

In [65]:
IMAGEURL = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Lynx_lynx_poing.jpg/220px-Lynx_lynx_poing.jpg"

In [66]:
init()

DEBUG:azureml.core.model:RunEnvironmentException: Failed to load a submitted run, if outside of an execution context, use project.start_run to initialize an azureml.core.Run.
DEBUG:azureml.core.model:Checking root for pytorch_resnet152 because candidate dir azureml-models had 0 nodes: 
DEBUG:model_driver:Loading pytorch_resnet152/resnet152-b121ed2d.pth
INFO:model_driver:Model loading time: 1950.18 ms
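The DEBUG lines above suggest that, outside a deployed service, `Model.get_model_path` falls back to a local `pytorch_resnet152` directory, and the driver then loads the `.pth` file it finds there. The file-selection step can be sketched with the standard library alone. `find_model_file` is a hypothetical helper, and unlike score.py it sorts the matches so the choice is deterministic when several files exist.

```python
import glob
import os
import tempfile
import warnings


def find_model_file(model_dir, pattern="*.pth"):
    """Return the first file in model_dir matching pattern, like _load_model."""
    # Sorted (an addition over score.py) so "first" is well defined.
    file_list = sorted(glob.glob(os.path.join(model_dir, pattern)))
    if not file_list:
        raise FileNotFoundError(f"No {pattern} file found in {model_dir}")
    if len(file_list) > 1:
        warnings.warn("More than one model found. Selecting first model")
    return file_list[0]


# Usage: an empty placeholder file stands in for real weights.
with tempfile.TemporaryDirectory() as model_dir:
    open(os.path.join(model_dir, "resnet152.pth"), "wb").close()
    found = os.path.basename(find_model_file(model_dir))
print(found)  # resnet152.pth
```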


In [67]:
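# get_model_api() can also be called directly, but this loads the model a second time; run() does not use this result
predict_for = get_model_api()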

DEBUG:azureml.core.model:RunEnvironmentException: Failed to load a submitted run, if outside of an execution context, use project.start_run to initialize an azureml.core.Run.
DEBUG:azureml.core.model:Checking root for pytorch_resnet152 because candidate dir azureml-models had 0 nodes: 
DEBUG:model_driver:Loading pytorch_resnet152/resnet152-b121ed2d.pth
INFO:model_driver:Model loading time: 1833.64 ms


In [68]:
jsonimg = img_url_to_json(IMAGEURL)
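`img_url_to_json` comes from testing_utilities, which is not shown here. Judging from how `run` reads the body (`json.loads(raw_data)['input']`) and from the `'image'` key in the predictions, the payload has the shape `{"input": {"image": "<base64 string>"}}`. A hypothetical helper that builds the same kind of body from image bytes already in memory (for example, read from a local file) might look like this:

```python
import base64
import json


def images_to_json(named_images):
    """Build a request body {"input": {name: base64_string}} from raw bytes."""
    return json.dumps(
        {
            "input": {
                name: base64.b64encode(data).decode("utf-8")
                for name, data in named_images.items()
            }
        }
    )


body = images_to_json({"image": b"first image bytes", "second": b"other bytes"})
print(sorted(json.loads(body)["input"]))  # ['image', 'second']
```

Because `_process_and_score` loops over every key under `"input"`, one request of this shape can carry several images, and the response is keyed by the same names.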

In [69]:
resp = run(jsonimg)

DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'iCCP' 41 292
DEBUG:PIL.PngImagePlugin:iCCP profile name b'ICC Profile'
DEBUG:PIL.PngImagePlugin:Compression method 0
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 345 65536
DEBUG:model_driver:Predictions: {'image': [('n02127052 lynx, catamount', 0.9965722560882568), ('n02128757 snow leopard, ounce, Panthera uncia', 0.0013256857637315989), ('n02128385 leopard, Panthera pardus', 0.0009192763245664537)]}
INFO:model_driver:Predictions took 36.73 ms


In [70]:
pprint(resp[0])

{'image': [('n02127052 lynx, catamount', 0.9965722560882568),
           ('n02128757 snow leopard, ounce, Panthera uncia',
            0.0013256857637315989),
           ('n02128385 leopard, Panthera pardus', 0.0009192763245664537)]}
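The probabilities in `resp` are cast with `astype(np.float64)` in `_call_model`. That cast is presumably there so the result can be serialized as JSON: the network computes in float32, `np.float64` subclasses Python's `float` and is accepted by the `json` module, while `np.float32` is not. A minimal check, using a made-up probability:

```python
import json

import numpy as np

proba32 = np.float32(0.25)  # softmax output of a float32 network

# np.float64 subclasses Python's float, so json.dumps accepts it directly.
encoded = json.dumps({"lynx": proba32.astype(np.float64)})
print(encoded)  # {"lynx": 0.25}

# np.float32 is not a float subclass, so json.dumps raises TypeError.
try:
    json.dumps({"lynx": proba32})
    float32_serializable = True
except TypeError:
    float32_serializable = False
print(float32_serializable)  # False
```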


Next, we can move on to [building our docker image](02_BuildImage.ipynb).