In [1]:
import json
import numpy as np
import cv2
import torch
from catalyst import utils
from albumentations import Compose, Normalize, LongestMaxSize, PadIfNeeded
from albumentations.torch import ToTensor

# serving load
# Location of the exported serving artifacts and of the sample image to classify;
# change these to point at another export run or another image.
SERVING_DIR = "./logs/serving-190906-174504-3N"
IMAGE_PATH = "./logs/dataset/images/ants/132478121_2a430adea2.jpg"

# map_location="cpu": the input tensor built below lives on CPU, so the model must too.
# This also lets a model exported from a GPU machine load on a CPU-only host.
model = torch.jit.load(f"{SERVING_DIR}/model.pth", map_location="cpu")
image: np.ndarray = utils.imread(IMAGE_PATH)
with open(f"{SERVING_DIR}/tag2class.json") as fin:
    tag2class = json.load(fin)  # tag name -> class index, e.g. {"ants": 0, "bees": 1}
# Invert the mapping so the model's argmax index can be decoded back to a tag name.
class2tag = {v: k for k, v in tag2class.items()}

print("image: ", image.shape)
print("class2tag: ", class2tag)

# serving preprocessing
# Resize so the longest side is IMAGE_SIZE, pad to an IMAGE_SIZE x IMAGE_SIZE square
# with a constant border, normalize with albumentations' default statistics,
# and convert to a torch tensor (channels-first, per the printed input shape).
IMAGE_SIZE = 224
transform = Compose([
    LongestMaxSize(max_size=IMAGE_SIZE),
    PadIfNeeded(IMAGE_SIZE, IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT),
    Normalize(),
    ToTensor(),
])

# serving prediction
input_t = transform(image=image)["image"].unsqueeze_(0)  # add batch dim -> (1, C, H, W)
print("input_t: ", input_t.shape)

# Inference only: eval() guards against a module exported in training mode, and
# no_grad() avoids building an autograd graph that serving never uses.
model.eval()
with torch.no_grad():
    output_t = model(input_t).squeeze_(0)  # drop batch dim -> one score per class
print("output_t: ", output_t.shape)

# serving result
# Decode the highest-scoring class index back into its human-readable tag name.
tag = class2tag[torch.argmax(output_t).item()]
print("result: ", tag)

image:  (96, 128, 3)
class2tag:  {0: 'ants', 1: 'bees'}
input_t:  torch.Size([1, 3, 224, 224])
output_t:  torch.Size([2])
result:  ants
