# OpenVINO Model Demos

In [None]:
from pathlib import Path
from typing import Any, NamedTuple

import cv2
import ipywidgets
import matplotlib.pyplot as plt
from custom_segmentation import (
    MonodepthModel,
    PaddleAnimeModel,
    PaddleSuperResolutionModel,
    U2NetModel,
)
from openvino.inference_engine import IECore
from many_utils import create_superresolution_comparison_video

In [None]:
class ModelConfig(NamedTuple):
    """Describe one demo model: display name, IR file, and wrapper class.

    ``show_results`` builds the wrapper as
    ``model_class(ie, xml_file, **arguments)`` and then uses its
    ``preprocess``, ``postprocess``, ``input_layer`` and ``net`` members.
    """

    # Title shown above the result panel; also used in output file names.
    name: str
    # Path to the model's .xml file.
    xml_file: str
    # Wrapper class constructed with (ie, xml_file, **arguments).
    model_class: Any
    # Extra keyword arguments forwarded to ``model_class``. The default dict is
    # shared by every instance that omits it, so treat it as read-only.
    arguments: dict = {}

In [None]:
# Demo model definitions. Each entry points at an .xml model file inside a
# sibling notebook directory (paths are relative to this notebook).
monodepth = ModelConfig(
    name="Monodepth",
    xml_file="../201-vision-monodepth/model/MiDaS_small.xml",
    model_class=MonodepthModel,
)
u2net = ModelConfig(
    name="Salient Object Detection",
    xml_file="../205-vision-background-removal/model/u2net/u2net.xml",
    model_class=U2NetModel,
)
anime = ModelConfig(
    name="Anime",
    xml_file="../206-vision-paddlegan-anime/model/paddlegan_anime.xml",
    model_class=PaddleAnimeModel,
)
superres = ModelConfig(
    name="Super Resolution",
    xml_file="../207-vision-paddlegan-superresolution/model/paddlegan_sr.xml",
    model_class=PaddleSuperResolutionModel,
    # Optional override, disabled in this notebook:
    # arguments={"resize_shape": [1, 3, 450, 600]},
)

# Models run by the "Go" button vs. the separate super-resolution demo.
small_models = [monodepth, u2net, anime]
superres_models = [superres]

## Upload and Select an Image

In [None]:
def on_upload_change(change):
    """Save uploaded images into ``data/`` and select the last one in ``dropdown``.

    Intended as an observer callback for ``upload_widget``. ``change.owner`` is
    the upload widget; its ``value`` is read as a mapping of
    ``filename -> {"content": bytes, ...}``. NOTE(review): this is the
    ipywidgets 7.x FileUpload format; confirm against the installed version.

    Each file is written to ``data/<basename>``. It is added to the
    module-level ``dropdown`` options unless already listed, and the last
    uploaded file becomes the selected entry.

    Args:
        change: traitlets change notification with ``new`` and ``owner``.
    """
    if not change.new:
        return
    widget = change.owner
    # Reset the widget's private upload counter (ipywidgets internals, kept
    # from the original notebook). NOTE(review): confirm this is still needed.
    widget._counter = 0

    data_dir = Path("data")
    data_dir.mkdir(parents=True, exist_ok=True)
    options = list(dropdown.options)
    last_uploaded = None
    for filename, data in widget.value.items():
        # Keep only the final path component so a name such as "../x.png"
        # cannot write outside data/.
        safe_name = Path(filename).name
        if not safe_name:
            continue
        image_fn = data_dir / safe_name
        image_fn.write_bytes(data["content"])
        # Re-uploading a file overwrites it on disk; don't list it twice.
        if image_fn not in options:
            options.append(image_fn)
        last_uploaded = image_fn

    if last_uploaded is None:
        return
    # Update the dropdown once, after all files are saved.
    dropdown.options = options
    dropdown.index = options.index(last_uploaded)

def _grid_layout(num_models):
    """Return ``(nrows, ncols, figsize)`` for the input image plus *num_models* results.

    Up to one model fits a single 1x2 row (20x8). More panels wrap into two
    columns with 7.5 inches of height per row, so three models give the
    original 2x2, 20x15 grid.
    """
    num_panels = num_models + 1
    if num_panels <= 2:
        return 1, 2, (20, 8)
    nrows = (num_panels + 1) // 2
    return nrows, 2, (20, 7.5 * nrows)


def _result_image_path(output_dir, base_image_name, model_name, suffix):
    """Build the path for one model's result, e.g. ``cat_super_resolution.jpg``."""
    model_slug = model_name.lower().replace(" ", "_")
    return output_dir / f"{base_image_name}_{model_slug}{suffix}"


def show_results(image_fn, models):
    """Run each model on an image, save the results and plot them in a grid.

    Each model is loaded on CPU through ``IECore``. Its result is written to
    ``output/<stem>/<stem>_<model_name>.<ext>``, and the input image plus all
    results are shown with matplotlib. For the "Super Resolution" model,
    ``create_superresolution_comparison_video`` is also called (presumably
    writing a comparison video into the same output directory).

    Args:
        image_fn: path of the image to process.
        models: sized sequence of ``ModelConfig``-like objects with ``name``,
            ``xml_file``, ``model_class`` and ``arguments``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_fn``.
    """
    image_bgr = cv2.imread(str(image_fn))
    if image_bgr is None:
        # cv2.imread signals failure with None instead of raising.
        raise FileNotFoundError(f"Could not read image file: {image_fn}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    ie = IECore()

    nrows, ncols, figsize = _grid_layout(len(models))
    _, ax = plt.subplots(nrows, ncols, figsize=figsize)
    axr = ax.ravel()

    source_path = Path(image_fn)
    base_image_name = source_path.stem
    output_dir = Path(f"output/{base_image_name}")
    output_dir.mkdir(exist_ok=True, parents=True)

    axr[0].imshow(image)
    axr[0].axis("off")
    axr[0].set_title("Image")
    for panel, model in zip(axr[1:], models):
        segmodel = model.model_class(ie, model.xml_file, **model.arguments)
        inputs, meta = segmodel.preprocess({segmodel.input_layer: image})
        exec_net = ie.load_network(segmodel.net, "CPU")
        raw_result = exec_net.infer(inputs)
        # NOTE(review): assumes postprocess returns an RGB image array.
        result = segmodel.postprocess(raw_result, meta)
        image_path = _result_image_path(
            output_dir, base_image_name, model.name, source_path.suffix
        )
        if model.name == "Super Resolution":
            create_superresolution_comparison_video(image, result, output_dir, base_image_name)
        cv2.imwrite(str(image_path), cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
        panel.imshow(result)
        panel.set_title(model.name)
        panel.axis("off")
    # Hide grid cells left over when the panel count is odd.
    for unused in axr[len(models) + 1:]:
        unused.axis("off")
    plt.show()

In [None]:
# Images already in data/ (any file whose name ends in "g", e.g. .jpg/.png),
# sorted so the dropdown order is stable across runs.
all_files = sorted(Path("data").glob("*g"))
upload_widget = ipywidgets.FileUpload()
# Only react to changes of the uploaded-file mapping. Without `names`,
# observe() fires for every trait of the widget, which may re-run the handler
# for the same upload.
upload_widget.observe(on_upload_change, names="value")

In [None]:
# Image picker, pre-populated with the images found in data/.
dropdown = ipywidgets.Dropdown(options=all_files)

# Separate output areas so each demo clears and redraws only its own results.
output = ipywidgets.Output()
superres_output = ipywidgets.Output()

# "Go" triggers the small models; "Superres" triggers super resolution.
small_button = ipywidgets.Button(description="Go")
superres_button = ipywidgets.Button(description="Superres")

In [None]:
def _render_into(output_widget, models):
    """Clear *output_widget* and draw ``show_results`` for the selected image."""
    output_widget.clear_output()
    selected_image = dropdown.value
    with output_widget:
        show_results(selected_image, models)


def show_output(btn):
    """Button callback: run the small models on the selected image."""
    _render_into(output, small_models)


def show_superres_output(btn):
    """Button callback: run the super-resolution model on the selected image."""
    _render_into(superres_output, superres_models)


small_button.on_click(show_output)
superres_button.on_click(show_superres_output)

In [None]:
# Controls for the small-model demo: upload, pick an image, run.
ui = ipywidgets.HBox(children=[upload_widget, dropdown, small_button])
display(ui, output)

In [None]:
# Controls for the super-resolution demo; shares the upload and dropdown widgets.
superres_ui = ipywidgets.HBox(children=[upload_widget, dropdown, superres_button])
display(superres_ui, superres_output)