In [1]:
import json
import logging
import os
import torch
import requests
from PIL import Image
from torchvision import transforms
from torchvision import models
import torch.nn as nn

In [6]:
# Module-level logger used by the model_fn / input_fn / output_fn / predict_fn
# handlers defined in the following cells.
logger = logging.getLogger(__name__)

In [8]:
class ResNet50(torch.nn.Module):
    """ResNet-50 backbone followed by a two-layer, 7-way classification head.

    The backbone is ``torchvision.models.resnet50`` with its last child module
    removed; the head is ``Linear(2048, 1000) -> ReLU -> Linear(1000, 7)``
    followed by a softmax over the class dimension.

    Attribute names (``feature_extract``, ``fc1``, ``fc2``) are part of the
    ``state_dict`` key layout and must not be renamed, otherwise existing
    checkpoints will no longer load.
    """

    def __init__(self):
        super(ResNet50, self).__init__()
        model = models.resnet50(pretrained=True)
        # Drop only the final child (the original classifier layer) and keep
        # everything before it as a feature extractor.
        modules = list(model.children())[:-1]
        self.feature_extract = nn.Sequential(*modules)
        self.fc1 = nn.Linear(2048, 1000)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(1000, 7)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for a batch of images.

        Args:
            x: Image batch accepted by the backbone; must produce 2048
                features per sample after flattening (required by ``fc1``).

        Returns:
            Tensor of shape ``(batch, 7)`` whose rows sum to 1 (softmax
            probabilities, NOT log-probabilities).
        """
        x = self.feature_extract(x)
        # Flatten everything except the batch axis. A bare torch.squeeze(x)
        # would also remove the batch axis when batch == 1, leaving a 1-D
        # tensor on which Softmax(dim=1) raises IndexError.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        out = self.fc2(x)
        out = self.softmax(out)
        return out

In [5]:
def model_fn(model_dir):
    """Build a ``ResNet50`` and load its weights from ``model_dir``.

    Args:
        model_dir: Directory that contains the checkpoint ``model_0.pth``
            (a ``state_dict`` compatible with ``ResNet50``).

    Returns:
        The model, placed on CPU and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_0.pth`` is missing from ``model_dir``.
        RuntimeError: If the checkpoint keys/shapes do not match the model.
    """
    device = torch.device('cpu')
    logger.info('Loading the model.')
    model = ResNet50()

    with open(os.path.join(model_dir, 'model_0.pth'), 'rb') as f:
        # map_location remaps tensors that were saved from a CUDA device onto
        # the CPU; without it, loading a GPU-trained checkpoint on a CPU-only
        # host fails during deserialization.
        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source.
        model.load_state_dict(torch.load(f, map_location=device))

    model.to(device).eval()
    logger.info('Done loading model')
    return model

In [7]:
def input_fn(request_body, content_type='application/json'):
    """Deserialize a request into a normalized image tensor.

    The request body is a JSON document with a ``url`` key pointing at the
    image to classify. The image is downloaded, converted to RGB, resized and
    center-cropped to 224x224, then normalized with the mean/std values below.

    Args:
        request_body: JSON string (or bytes) of the form ``{"url": "<image url>"}``.
        content_type: MIME type of ``request_body``; only
            ``'application/json'`` is supported.

    Returns:
        Float tensor of shape ``(3, 224, 224)``.

    Raises:
        Exception: If ``content_type`` is not ``'application/json'``.
        KeyError: If the JSON body has no ``url`` key.
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    logger.info('Deserializing the input data.')
    if content_type != 'application/json':
        raise Exception(f'Requested unsupported ContentType in content_type: {content_type}')

    input_data = json.loads(request_body)
    url = input_data['url']
    logger.info(f'Image url: {url}')

    # NOTE(review): the URL is caller-supplied; consider restricting allowed
    # hosts/schemes if this endpoint is reachable by untrusted clients.
    timeout_seconds = 10  # without a timeout a stalled server blocks the worker forever
    with requests.get(url, stream=True, timeout=timeout_seconds) as response:
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # The raw stream is not content-decoded (gzip/deflate) by default.
        response.raw.decode_content = True
        # Image.open is lazy; convert() forces the pixel data to be read while
        # the connection is still open. RGB conversion guarantees exactly three
        # channels for Normalize (RGBA, grayscale and palette images otherwise fail).
        image_data = Image.open(response.raw).convert('RGB')

    image_transform = transforms.Compose([
        transforms.Resize(size=256),
        # Resize only fixes the shorter edge; crop so the result is exactly
        # 224x224, which is the size predict_fn reshapes to.
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    return image_transform(image_data)

In [None]:
def output_fn(prediction_output, accept='application/json'):
    """Serialize the top-3 predictions as JSON.

    Args:
        prediction_output: Tensor of per-class scores, either 1-D
            ``(num_classes,)`` or batched ``(1, num_classes)``; for a batched
            tensor only the first row is reported.
        accept: Requested response MIME type; only ``'application/json'`` is
            supported.

    Returns:
        Tuple ``(body, accept)`` where ``body`` is a JSON list of
        ``{"prediction": <label>, "score": "<percent>%"}`` objects ordered
        from most to least likely (at most three entries).

    Raises:
        Exception: If ``accept`` is not ``'application/json'``.
    """
    logger.info('Serializing the generated output.')
    if accept != 'application/json':
        raise Exception(f'Requested unsupported ContentType in Accept: {accept}')

    # Each classifier has its own set of label values; downstream queries are
    # driven by these label values.
    label = ['35102', '55701', '65753', '66304', '35192', '45661', '35954']

    scores = prediction_output.detach()
    if scores.dim() > 1:
        # Collapse any leading batch axes and report the first sample.
        scores = scores.reshape(-1, scores.shape[-1])[0]

    # Previously topk(1) was taken and element [0] was appended three times,
    # yielding three copies of the top-1 prediction. Ask for the top three
    # and index each rank instead.
    top_n = min(3, scores.numel())
    top_scores, top_indices = scores.topk(top_n, dim=0)
    top_scores = top_scores.cpu().numpy()
    top_indices = top_indices.cpu().numpy()

    result = []
    for rank in range(top_n):
        pred = {'prediction': label[int(top_indices[rank])], 'score': f'{top_scores[rank] * 100}%'}
        logger.info(f'Adding pediction: {pred}')
        result.append(pred)

    return json.dumps(result), accept

In [None]:
def predict_fn(input_data, model):
    """Run the model on one preprocessed image and return its output.

    Args:
        input_data: Image tensor. A ``(3, H, W)`` tensor is given a leading
            batch axis; any other tensor is reshaped to ``(1, 3, 224, 224)``
            and must therefore contain exactly 3*224*224 elements.
        model: The ``torch.nn.Module`` to evaluate.

    Returns:
        The model's output for the single-image batch. For the ``ResNet50``
        defined in this file that is a ``(1, 7)`` tensor of softmax
        probabilities.
    """
    logger.info('Generating prediction based on input parameters.')

    if input_data.dim() == 3 and input_data.shape[0] == 3:
        batch = input_data.unsqueeze(0)
    else:
        batch = input_data.reshape(1, 3, 224, 224)

    # Run on whatever device the model lives on. Moving the input to CUDA
    # merely because a GPU exists breaks when the model was loaded on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    batch = batch.to(device)

    with torch.no_grad():
        model.eval()
        # The model's forward already ends in a softmax, so its output is the
        # probability vector. Applying torch.exp on top (as if the output were
        # log-probabilities) would map every score into [1, e].
        ps = model(batch)

    return ps