<a href="https://colab.research.google.com/github/matthewchung74/blogs/blob/dev/inference_params.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [79]:
import os
import subprocess
import sys
from enum import Enum
import traceback

import cv2
import PIL
import json
from pathlib import Path

from functools import wraps
import time
import pdb

class FieldType(str, Enum):
    """Kinds of value an app input/output field can carry.

    Because ``str`` is mixed in, each member compares equal to its own value
    (``FieldType.Text == "Text"``) and ``json`` serializes it as that plain
    string.
    """

    Text = "Text"      # plain text
    OpenCV = "OpenCV"  # image loaded with OpenCV (cv2)
    PIL = "PIL"        # image loaded with Pillow

class EnumEncoder(json.JSONEncoder):
    """JSON encoder that tags Enum members as ``{"__enum__": "Class.member"}``.

    ``as_enum`` is the matching ``object_hook`` that parses the tag back.
    ``json`` encodes ``str``/``int`` subclasses natively, so members of
    mixed-in enums such as ``FieldType(str, Enum)`` are written as their plain
    value and never reach ``default``.
    """

    def default(self, obj):
        # json only calls default() for objects it cannot serialize itself.
        if isinstance(obj, Enum):
            # Build the tag explicitly rather than via str(obj), whose format
            # for enums differs across Python versions.
            return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
        # The base implementation raises TypeError for unsupported objects.
        return super().default(obj)

def as_enum(d, enum_types=None):
    """``json`` object_hook that turns ``{"__enum__": "Class.member"}`` into an Enum member.

    Args:
        d: A decoded JSON object (dict).
        enum_types: Optional mapping of enum class name -> enum class used to
            resolve tags. Defaults to ``{"FieldType": FieldType}``.

    Returns:
        The enum member for a tagged object; otherwise ``d`` unchanged.

    Raises:
        KeyError: If the class name or the member name is unknown.
        ValueError: If the tag is not of the form ``"Class.member"``.
    """
    if "__enum__" not in d:
        return d
    if enum_types is None:
        enum_types = {"FieldType": FieldType}
    class_name, member_name = d["__enum__"].rsplit(".", 1)
    # Resolve the class by name first, then index the member by name.
    # Indexing (rather than getattr) only accepts real members.
    return enum_types[class_name][member_name]

def in_colab():
    """Return True when running inside Google Colab.

    Detection works by probing for the ``google.colab`` module, which is
    presumably only importable inside a Colab runtime.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Narrow catch: a bare except would also swallow KeyboardInterrupt.
        return False
    return True

def grab_image(image_type: "FieldType"):
    """Return the COCO sample image, downloading it on first use.

    The image is cached as ``test_download_image.jpg`` in the working
    directory and reused on later calls.

    Args:
        image_type: ``FieldType.OpenCV`` to load with ``cv2.imread``; any
            other value loads it with ``PIL.Image.open``.

    Returns:
        The image as loaded by OpenCV or Pillow.

    Raises:
        urllib.error.URLError: If the download fails. HTTP errors raise
            instead of having their error page saved as the image.
    """
    # Local imports: urllib (stdlib) performs the download without an extra
    # dependency, and ``import PIL`` alone does not guarantee that the
    # ``PIL.Image`` submodule is loaded.
    import urllib.request
    import PIL.Image

    image_path = Path("test_download_image.jpg")
    if not image_path.exists():
        url = "http://images.cocodataset.org/val2017/000000439715.jpg"
        with urllib.request.urlopen(url) as response:
            data = response.read()
        # Write only after the whole body has been read, so a failed download
        # never leaves a truncated file that would be mistaken for the cache.
        image_path.write_bytes(data)

    image_name = str(image_path)
    if image_type.value == FieldType.OpenCV.value:
        return cv2.imread(image_name)
    return PIL.Image.open(image_name)

def inference_predict(input: dict, output: dict):
    """Build a decorator that records the I/O schema and times a predict function.

    At decoration time the ``input`` and ``output`` schemas are serialized
    with ``EnumEncoder`` and written to ``input.json`` / ``output.json``.
    Existing files are left untouched, except inside Colab, where they are
    always rewritten.

    Args:
        input: Mapping of input-field label -> FieldType. (The name shadows the
            ``input`` builtin but is kept because callers pass it by keyword.)
        output: Mapping of output-field label -> FieldType.

    Returns:
        A decorator. The wrapped function returns ``(result, seconds)``, where
        ``seconds`` is the elapsed time of the call.
    """
    # Serialize both schemas up front so an unserializable schema fails at
    # decoration time regardless of whether the files already exist.
    input_json = json.dumps(input, cls=EnumEncoder)
    output_json = json.dumps(output, cls=EnumEncoder)

    overwrite = in_colab()  # same answer for both files; evaluate once
    for path, text in ((Path("input.json"), input_json),
                       (Path("output.json"), output_json)):
        if overwrite or not path.exists():
            path.write_text(text)

    def inference_predict_decorator(func):
        @wraps(func)
        def inference_predict_wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution, unlike time.time,
            # so clock adjustments cannot skew the measured duration.
            begin = time.perf_counter()
            result = func(*args, **kwargs)
            return (result, time.perf_counter() - begin)
        return inference_predict_wrapper
    return inference_predict_decorator

def inference_test(params: dict):
    """Run ``predict`` on ``params`` and save its result to ``result.json``.

    This is a notebook smoke test: any failure is reported by printing its
    traceback instead of propagating, so the calling cell does not abort.

    Args:
        params: Mapping of input-field label -> value, passed to ``predict``.
    """
    try:
        if not Path("input.json").exists():
            raise Exception('missing input.json')
        # Loading checks that the recorded input schema is readable JSON.
        with open("input.json") as f:
            json.load(f)

        if not isinstance(params, dict):
            raise Exception('params needs to be a dictionary')

        print(params.keys())

        # ``predict`` is expected to be wrapped by @inference_predict, which
        # returns (result, elapsed_seconds); only the result is saved.
        result, _duration = predict(params)
        with open('result.json', 'w') as f:
            json.dump(result, f)
    except Exception:
        traceback.print_exc()

In [80]:
# Labels for the app's input fields, mapped to the type of value each expects.
# NOTE(review): this rebinds the builtin ``input`` for the rest of the
# notebook; the name is kept because other cells may refer to it.
input = {"text input": FieldType.Text}
# Labels for the app's output fields, mapped to the type of value each carries.
output = {"result": FieldType.Text}


@inference_predict(input=input, output=output)
def predict(params):
    """Placeholder model: ignore ``params`` and return a fixed prediction.

    The returned dict is keyed by the output label declared in ``output``.
    """
    return {"result": "positive and 100% accurate"}

In [81]:
# Smoke-test the decorated predict() with a sample review.
inference_test(params={'text input': 'loved that movie'})