## AI photo-realistic Synthesis with OpenVINO™

This notebook demonstrates photo-realistic image synthesis from a semantic sketch and an exemplar image, using CoCosNet and OpenVINO. We use the CoCosNet model from Open Model Zoo. At the end of the notebook, you will see a demo in which users draw a sketch on an interactive canvas and get a realistic photo based on the semantic drawing.


### Imports

In [38]:
import sys
import os

import cv2
from pathlib import Path
import gradio as gr
import numpy as np
import logging as log
from openvino.runtime import Core, get_version

from utils.models import CocosnetModel, SegmentationModel
from utils.preprocessing import preprocess_for_seg_model, preprocess_image, preprocess_semantics
from utils.postprocessing import postprocess, save_result

log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.DEBUG, stream=sys.stdout)

## The Model

### Download the Model

The CoCosNet and segmentation models are downloaded to `base_model_dir` if they have not been downloaded already. This notebook also shows how to use the Model Downloader to get the OpenVINO Intermediate Representation (IR) with FP16 precision.

In [53]:
# directory where model will be downloaded
base_model_dir = "model"

# model name as named in Open Model Zoo
translation_model_name = "cocosnet"
segmentation_model_name = "hrnet-v2-c1-segmentation"
# selected precision (FP32, FP16)
precision = "FP16"

TRANSLATION_MODEL = f"{base_model_dir}/public/{translation_model_name}/{precision}/{translation_model_name}"
SEGMENTATION_MODEL = f"{base_model_dir}/public/{segmentation_model_name}/{precision}/{segmentation_model_name}"

# Paths to the OpenVINO IR (.xml) files
translation_model_path = Path(TRANSLATION_MODEL).with_suffix(".xml")
segmentation_model_path = Path(SEGMENTATION_MODEL).with_suffix(".xml")

if not translation_model_path.exists():
    download_command = f"omz_downloader --name {translation_model_name} --output_dir {base_model_dir}"
    ! $download_command
if not segmentation_model_path.exists():
    download_command = f"omz_downloader --name {segmentation_model_name} --output_dir {base_model_dir}"
    ! $download_command



### Convert Model to OpenVINO IR format
The selected models come from the public directory, which means they must be converted into OpenVINO Intermediate Representation (OpenVINO IR). We use `omz_converter` to convert them to the OpenVINO IR format; the cell below skips conversion when the ONNX file already exists.

In [55]:
TRANSLATION_ONNX = f"{base_model_dir}/public/{translation_model_name}/{translation_model_name}"
SEGMENTATION_ONNX = f"{base_model_dir}/public/{segmentation_model_name}/{segmentation_model_name}"
translation_onnx_path = Path(TRANSLATION_ONNX).with_suffix(".onnx")
segmentation_onnx_path = Path(SEGMENTATION_ONNX).with_suffix(".onnx")

if not translation_onnx_path.exists():
    convert_command = (
        f"omz_converter "
        f"--name {translation_model_name} "
        f"--precisions {precision} "
        f"--download_dir {base_model_dir} "
        f"--output_dir {base_model_dir}"
    )
    ! $convert_command

if not segmentation_onnx_path.exists():
    convert_command = (
        f"omz_converter "
        f"--name {segmentation_model_name} "
        f"--precisions {precision} "
        f"--download_dir {base_model_dir} "
        f"--output_dir {base_model_dir}"
    )
    ! $convert_command
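
The `! $command` shell magic only works inside IPython or Jupyter. As a minimal sketch for plain Python scripts, assuming the same `omz_downloader` and `omz_converter` flags used in the cells above, each command can be built as an argument list and handed to `subprocess.run`. The helper names `downloader_args`, `converter_args`, and `run` are hypothetical and not part of Open Model Zoo.

```python
import subprocess


def downloader_args(name, output_dir):
    """Argument list mirroring the omz_downloader call used above."""
    return ["omz_downloader", "--name", name, "--output_dir", output_dir]


def converter_args(name, precision, model_dir):
    """Argument list mirroring the omz_converter call used above."""
    return [
        "omz_converter",
        "--name", name,
        "--precisions", precision,
        "--download_dir", model_dir,
        "--output_dir", model_dir,
    ]


def run(args, dry_run=True):
    """Print the command; execute it only when dry_run is False."""
    print(" ".join(args))
    if not dry_run:
        # check=True raises CalledProcessError if the tool exits with an error.
        subprocess.run(args, check=True)
    return args


# Dry run: only prints the commands, nothing is downloaded or converted.
run(downloader_args("cocosnet", "model"))
run(converter_args("cocosnet", "FP16", "model"))
```

Passing `dry_run=False` would execute the tools, which requires them to be installed and on the `PATH`.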

## Model Initialization

We load the CoCosNet and segmentation models for image translation.

Converted models are stored in a fixed directory structure that indicates the vendor, model name, and precision.
First, initialize the inference engine, OpenVINO Runtime. Then, read the network architecture and model weights from the `.xml` and `.bin` files and compile them for the desired device. An inference request is then created to run inference on the compiled model.

In [56]:
# Initialize OpenVINO Runtime.
ie_core = Core()
device = "CPU"

# Initialize CocosnetModel
gan_model = CocosnetModel(ie_core, translation_model_path, device)

# Initialize SegmentationModel
seg_model = SegmentationModel(ie_core, segmentation_model_path,
                              device) if segmentation_model_path else None
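
The fixed directory structure mentioned above can be written as a small helper. This is a sketch that assumes the `<base>/public/<model>/<precision>/<model>.xml` layout used by `TRANSLATION_MODEL` and `SEGMENTATION_MODEL` earlier, with the weights in a `.bin` file next to the `.xml`; `ir_files` is a hypothetical name.

```python
from pathlib import Path


def ir_files(base_dir, model_name, precision, vendor="public"):
    """Return the (.xml, .bin) paths for a converted model."""
    xml_path = Path(base_dir) / vendor / model_name / precision / f"{model_name}.xml"
    # The weights file shares the stem and differs only in its suffix.
    return xml_path, xml_path.with_suffix(".bin")


xml_path, bin_path = ir_files("model", "cocosnet", "FP16")
print(xml_path.as_posix())
print(bin_path.as_posix())
```

Checking `xml_path.exists()` and `bin_path.exists()` before initializing a model gives a clearer error than a failed load.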

## Input preprocessing and Model inferencing

In this section, we preprocess the inputs: masks from the segmentation model are used to generate the input and reference semantics.

Inference is then run with the GAN model, which takes the input semantics, the reference image, and the reference semantics.

In [45]:
# Method to get mask from image
def get_mask_from_image(image, model):
    image = preprocess_for_seg_model(image, input_size=model.input_size)
    res = model.infer(image)
    mask = np.argmax(res, axis=1)
    mask = np.squeeze(mask, 0)
    return mask + 1

# Process the input and reference image
def gradioProcessing(input_image=None, reference_image=None, input_semantic=None, reference_semantic=None):
    # Set to True if no input and reference semantics are provided
    use_seg = True
    # Compare against None: bool() on a NumPy array with several elements raises an error
    has_semantics = input_semantic is not None and reference_semantic is not None
    assert use_seg ^ has_semantics, "Verify Gradio module to provide semantic inputs"

    if use_seg:
        if input_image is None:
            raise IOError('Image {} cannot be read'.format(input_image))
        input_semantic = get_mask_from_image(input_image, seg_model)
        if reference_image is None:
            raise IOError('Image {} cannot be read'.format(reference_image))
        reference_semantic = get_mask_from_image(reference_image, seg_model)
    else:
        # TODO: Remove this snippet if reference and input semantics are not required
        #   input_sem_file = input_semantic
        #   input_sem = cv2.imread(input_sem_file, cv2.IMREAD_GRAYSCALE)
        if input_semantic is None:
            raise IOError('Image {} cannot be read'.format(input_semantic))
        # TODO: Remove this snippet if reference and input semantics are not required
        #   ref_sem_file = ref_sem
        #   ref_sem = cv2.imread(ref_sem, cv2.IMREAD_GRAYSCALE)
        if reference_semantic is None:
            raise IOError('Image {} cannot be read'.format(reference_semantic))
    input_semantic = preprocess_semantics(input_semantic, input_size=gan_model.input_semantic_size)

    if reference_image is None:
        raise IOError('Image {} cannot be read'.format(reference_image))
    reference_image = preprocess_image(reference_image, input_size=gan_model.input_image_size)
    reference_semantic = preprocess_semantics(reference_semantic, input_size=gan_model.input_semantic_size)
    # Model Inference
    result = postprocess(gan_model.infer(input_semantic, reference_image, reference_semantic))
        
    return result
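
To make the mask step in `get_mask_from_image` concrete, the following sketch applies the same `argmax`, `squeeze`, and `+ 1` operations to a hand-made array. The `(1, classes, height, width)` output shape is an assumption about the segmentation model, and the scores are made up for illustration.

```python
import numpy as np

# Fake scores with shape (batch=1, classes=3, height=2, width=2).
scores = np.zeros((1, 3, 2, 2), dtype=np.float32)
scores[0, 0, 0, 0] = 5.0  # class 0 wins at pixel (0, 0)
scores[0, 1, 0, 1] = 5.0  # class 1 wins at pixel (0, 1)
scores[0, 2, 1, 0] = 5.0  # class 2 wins at pixel (1, 0)
scores[0, 1, 1, 1] = 5.0  # class 1 wins at pixel (1, 1)

mask = np.argmax(scores, axis=1)  # best class per pixel -> shape (1, 2, 2)
mask = np.squeeze(mask, 0)        # drop the batch axis  -> shape (2, 2)
mask = mask + 1                   # shift labels so they start at 1

print(mask.shape)
print(mask.tolist())
```

The final `+ 1` shifts the zero-based class indices so that the smallest label in the mask is 1.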

## Gradio interface 

A web-based GUI that synthesizes an image from the drawing on the Gradio canvas and the uploaded reference image.
Upload the reference image and draw on the Gradio canvas, using different colors to represent different objects.
Click the Translate button to synthesize the AI painting.

Path to reference images: `data/`


In [48]:
# gradio method to fetch user input and reference image
def generate(image_input, ref_image):
    result = gradioProcessing(input_image=image_input, reference_image=ref_image,
                           input_semantic=None, reference_semantic=None)
    return result

# Initialize gradio canvas
with gr.Blocks(css="#small-b {width: 24px}") as demo:
    with gr.Row().style(equal_height=True):
        with gr.Column():
            canvas_input = gr.Paint()
            submit = gr.Button("Translate")
        output_image = gr.Image(label='Synthesis')
        ref_image = gr.Image(label='Reference')
        submit.click(generate, inputs=[canvas_input, ref_image], outputs=output_image)

# Start the gradio interactive canvas
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7874
Running on public URL: https://75baaec7091cd1a80d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces
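
Because `gradioProcessing` raises `IOError` when either image is `None`, clicking the button before both inputs are supplied surfaces an exception. One possible guard, sketched here under the assumption that returning `None` leaves the output component empty, wraps the callback. `guarded` is a hypothetical helper, and the stand-in `fake_generate` replaces the real model call so the example runs on its own.

```python
def guarded(callback):
    """Wrap a two-input callback so it is skipped when an input is missing."""
    def wrapper(image_input, ref_image):
        if image_input is None or ref_image is None:
            # Nothing to synthesize yet; skip the model call entirely.
            return None
        return callback(image_input, ref_image)
    return wrapper


def fake_generate(image_input, ref_image):
    """Stand-in for the notebook's generate(); returns a description string."""
    return f"synthesized from {image_input} and {ref_image}"


safe_generate = guarded(fake_generate)
print(safe_generate(None, "ref.png"))
print(safe_generate("sketch.png", "ref.png"))
```

In the interface above, `guarded(generate)` could be passed to `submit.click` in place of `generate`.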
