In [66]:
import os
import torch
from torch import nn
from torchvision import transforms
from mlrun.artifacts import get_model
import json
import base64
from io import BytesIO
from PIL import Image

class ModelServer:
    """Serve a fine-tuned ResNet-50 cat/dog classifier.

    Configuration comes from environment variables:

    * ``device`` -- torch device string used for ``.to()`` and as the
      ``map_location`` when restoring the checkpoint (``None`` if unset,
      which leaves tensors where they are / where they were saved).
    * ``img_dimensions`` -- integer side length; inputs are resized to a
      square of this size.
    * ``model_url`` -- MLRun model artifact URL whose ``.pth`` file holds the
      state dict and whose spec carries the ``layer_size`` parameter.
    """

    # Output index -> label. Index 0 = cat, 1 = dog.
    LABELS = ("cat", "dog")

    def __init__(self):
        self.device = os.getenv("device")
        self.model = self.load_model(device=self.device)
        # Parse once; the value is used for both height and width.
        img_dim = int(os.getenv("img_dimensions"))
        self.img_transforms = transforms.Compose([
            transforms.Resize((img_dim, img_dim)),
            transforms.ToTensor(),
            # Per-channel RGB mean/std (the ImageNet statistics used with
            # torchvision's pretrained backbones).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def load_model(self, device):
        """Build the network and restore its trained weights from MLRun.

        Args:
            device: Torch device string (or ``None``) to place the model on
                and to map checkpoint tensors to.

        Returns:
            The model in ``eval()`` mode.
        """
        model_file, model_spec, _ = get_model(os.getenv("model_url"), suffix='.pth')
        layer_size = model_spec.parameters['layer_size']
        model = self.prep_model(device=device, layer_size=layer_size)
        # `with` closes the checkpoint file (the original leaked the handle).
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        with open(model_file, "rb") as f:
            state_dict = torch.load(f, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def encode_image(self, image_path):
        """Read the image at *image_path*, re-encode it as JPEG, return base64 text."""
        # Context manager closes the underlying file once the image is saved.
        with Image.open(image_path) as img:
            im_file = BytesIO()
            img.save(im_file, format="JPEG")
        im_b64 = base64.b64encode(im_file.getvalue())
        return im_b64.decode("utf-8")

    def decode_image(self, byte_stream):
        """Decode base64 image data into a PIL image (inverse of ``encode_image``)."""
        im_bytes = base64.b64decode(byte_stream)
        return Image.open(BytesIO(im_bytes))

    def prep_model(self, device, layer_size, num_classes=2):
        """Create a ResNet-50 with a custom classification head.

        Loads ``resnet50`` with ``pretrained=True`` from torch.hub
        (``pytorch/vision:v0.6.0``). It then freezes every parameter whose
        name does not contain ``"bn"``. Finally it replaces ``fc`` with
        ``Linear -> ReLU -> Dropout -> Linear(num_classes)``. The new head is
        created after freezing, so its parameters stay trainable.

        Args:
            device: Torch device string (or ``None``) to move the model to.
            layer_size: Width of the hidden layer in the new head.
            num_classes: Number of output logits.

        Returns:
            The modified model, moved to *device*.
        """
        model_resnet50 = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)

        for name, param in model_resnet50.named_parameters():
            if "bn" not in name:
                param.requires_grad = False

        model_resnet50.fc = nn.Sequential(nn.Linear(model_resnet50.fc.in_features, layer_size),
                                          nn.ReLU(),
                                          nn.Dropout(),
                                          nn.Linear(layer_size, num_classes))

        return model_resnet50.to(device)

    def predict(self, context, data):
        """Classify every image in ``data['data']``.

        Args:
            context: Nuclio context; its ``Response`` class builds the reply.
            data: Mapping whose ``'data'`` entry is a list of two-field dicts.
                The first value is used as the image's path/key and the second
                as its base64 payload.

        Returns:
            ``context.Response`` whose body is JSON mapping each path to
            ``{'pred': <label>, 'encoded': <original base64 payload>}``.
        """
        response = {}
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            for image in data['data']:
                # NOTE(review): relies on the client's key order within each
                # item dict; consider reading explicit keys once the schema is known.
                path, encoded = image.values()
                # Force 3 channels so grayscale/RGBA/palette inputs match Normalize.
                img = self.decode_image(encoded).convert("RGB")
                batch = self.img_transforms(img).unsqueeze(0).to(self.device)
                logits = self.model(batch)
                label_idx = int(logits.argmax())
                response[path] = {'pred': self.LABELS[label_idx],
                                  'encoded': encoded}
        return context.Response(body=json.dumps(response))
            
def init_context(context):
    """Nuclio init hook: create the ModelServer and attach it to *context*.

    ``handler`` later reads it back as ``context.model_server``.
    """
    setattr(context, "model_server", ModelServer())
    
def handler(context, event):
    """Nuclio entry point: classify the images carried in *event*'s body.

    The body must hold the payload expected by ``ModelServer.predict`` (a
    mapping with a ``"data"`` list). Raw ``bytes``/``bytearray``/``str`` bodies
    are parsed as JSON. An already-decoded mapping is passed through unchanged.

    Args:
        context: Nuclio context carrying ``model_server`` (set by ``init_context``).
        event: Nuclio event whose ``body`` holds the request payload.

    Returns:
        Whatever ``context.model_server.predict`` returns.
    """
    body = event.body
    # NOTE(review): nuclio's Python runtime may deliver an already-parsed dict
    # for application/json requests; json.loads on a dict raises TypeError.
    if isinstance(body, (bytes, bytearray, str)):
        body = json.loads(body)
    return context.model_server.predict(context, body)