## AI photo-realistic Synthesis with OpenVINO™

This notebook demonstrates photo-realistic image synthesis from an exemplar image and a semantic sketch, using CoCosNet and OpenVINO. We use the CoCosNet model from Open Model Zoo. At the end of the notebook, you will see a demo in which users draw a sketch on an interactive canvas and get a realistic photo based on the semantic drawing.


### Imports

In [1]:
import sys

import cv2
from pathlib import Path
import gradio as gr
import numpy as np
import logging as log
from openvino.runtime import Core

from utils.models import CocosnetModel, SegmentationModel
from utils.preprocessing import preprocess_for_seg_model, preprocess_image, preprocess_semantics
from utils.postprocessing import postprocess

log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.DEBUG, stream=sys.stdout)

## The Model

### Download the Model

The CoCosNet model will be downloaded to `base_model_dir` if you have not already downloaded it. This notebook also shows how to use the Open Model Zoo tools (`omz_downloader` and `omz_converter`) to get the OpenVINO Intermediate Representation (IR) with FP16 precision.

In [2]:
# directory where model will be downloaded
base_model_dir = "model"

# model name as named in Open Model Zoo
translation_model_name = "cocosnet"
segmentation_model_name = "hrnet-v2-c1-segmentation"
# selected precision (FP32, FP16)
precision = "FP16"

TRANSLATION_MODEL = f"{base_model_dir}/public/{translation_model_name}/{precision}/{translation_model_name}"
SEGMENTATION_MODEL = f"{base_model_dir}/public/{segmentation_model_name}/{precision}/{segmentation_model_name}"

# Paths to the OpenVINO IR (.xml) files
translation_model_path = Path(TRANSLATION_MODEL).with_suffix(".xml")
segmentation_model_path = Path(SEGMENTATION_MODEL).with_suffix(".xml")

if not translation_model_path.exists():
    download_command = (
        f"omz_downloader --name {translation_model_name} --output_dir {base_model_dir}"
    )
    ! $download_command
if not segmentation_model_path.exists():
    download_command = (
        f"omz_downloader --name {segmentation_model_name} --output_dir {base_model_dir}"
    )
    ! $download_command
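
The `! $download_command` syntax is IPython shell magic, so it only works inside a notebook. In a plain Python script, the same `omz_downloader` call could be issued through `subprocess`. The following is a minimal sketch, assuming `omz_downloader` is available on the `PATH`; the helper names `build_download_command` and `download_if_missing`, and the injectable `runner` parameter, are hypothetical and not part of this notebook.

```python
import subprocess
from pathlib import Path


def build_download_command(model_name, output_dir):
    """Return the omz_downloader invocation as an argument list."""
    return ["omz_downloader", "--name", model_name, "--output_dir", str(output_dir)]


def download_if_missing(ir_path, model_name, output_dir, runner=subprocess.run):
    """Run omz_downloader only when `ir_path` is not on disk yet.

    `runner` is injectable (hypothetical convenience) so the function can be
    exercised without network access. Returns True when a download was
    started and False when it was skipped.
    """
    if Path(ir_path).exists():
        return False
    # check=True makes subprocess.run raise if the tool exits with an error
    runner(build_download_command(model_name, output_dir), check=True)
    return True
```

Passing the command as a list avoids shell quoting issues when a directory name contains spaces.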

### Convert Model to OpenVINO IR format
The selected model comes from the public directory, which means it must be converted into OpenVINO Intermediate Representation (OpenVINO IR). We use `omz_converter` to convert the ONNX format model to the OpenVINO IR format.

In [3]:
TRANSLATION_ONNX = f"{base_model_dir}/public/{translation_model_name}/{translation_model_name}"
SEGMENTATION_ONNX = f"{base_model_dir}/public/{segmentation_model_name}/{segmentation_model_name}"
translation_onnx_path = Path(TRANSLATION_ONNX).with_suffix(".onnx")
segmentation_onnx_path = Path(SEGMENTATION_ONNX).with_suffix(".onnx")

if not translation_onnx_path.exists():
    convert_command = (
        f"omz_converter "
        f"--name {translation_model_name} "
        f"--precisions {precision} "
        f"--download_dir {base_model_dir} "
        f"--output_dir {base_model_dir}"
    )
    ! $convert_command

if not segmentation_onnx_path.exists():
    convert_command = (
        f"omz_converter "
        f"--name {segmentation_model_name} "
        f"--precisions {precision} "
        f"--download_dir {base_model_dir} "
        f"--output_dir {base_model_dir}"
    )
    ! $convert_command
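
The checks above test for the intermediate `.onnx` file before running `omz_converter`. An alternative is to test for the converted IR itself, that is, the `.xml` file together with its `.bin` weights file. The sketch below assumes the `public/<model name>/<precision>/` layout used earlier in this notebook; `ir_paths`, `needs_conversion` and `build_convert_command` are hypothetical helper names.

```python
from pathlib import Path


def ir_paths(base_dir, model_name, precision):
    """Return the (.xml, .bin) pair under the layout this notebook uses."""
    ir_dir = Path(base_dir) / "public" / model_name / precision
    return ir_dir / f"{model_name}.xml", ir_dir / f"{model_name}.bin"


def needs_conversion(base_dir, model_name, precision):
    """True unless both IR files are already present."""
    xml_path, bin_path = ir_paths(base_dir, model_name, precision)
    return not (xml_path.exists() and bin_path.exists())


def build_convert_command(base_dir, model_name, precision):
    """Same omz_converter flags as the cell above, as an argument list."""
    return [
        "omz_converter",
        "--name", model_name,
        "--precisions", precision,
        "--download_dir", str(base_dir),
        "--output_dir", str(base_dir),
    ]
```

Checking for both files guards against a conversion that was interrupted after writing only one of them.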

In [4]:

class SegmentationVisualizer:
    def __init__(self, colors_path=None):
        if colors_path:
            self.color_palette = self.get_palette_from_file(colors_path)
            log.debug('The palette is loaded from {}'.format(colors_path))
        else:
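            # NOTE: __file__ is undefined inside a notebook cell, so this fallback only
            # works when the class lives in a .py module; pass colors_path explicitly here.
            pascal_palette_path = Path(__file__).resolve().parents[3] /\
                'data/palettes/pascal_voc_21cl_colors.txt'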
            self.color_palette = self.get_palette_from_file(pascal_palette_path)
            log.debug('The PASCAL VOC palette is used')
        log.debug('Get {} colors'.format(len(self.color_palette)))
        self.color_map = self.create_color_map()

    def get_palette_from_file(self, colors_path):
        with open(colors_path, 'r') as file:
            colors = []
            for line in file.readlines():
                values = line[line.index('(') + 1:line.index(')')].split(',')
                colors.append([int(v.strip()) for v in values])
            return colors

    def create_color_map(self):
        classes = np.array(self.color_palette, dtype=np.uint8)[:, ::-1]  # RGB to BGR
        color_map = np.zeros((256, 1, 3), dtype=np.uint8)
        classes_num = len(classes)
        color_map[:classes_num, 0, :] = classes
        color_map[classes_num:, 0, :] = np.random.uniform(0, 255, size=(256 - classes_num, 3))
        return color_map

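    def apply_color_map(self, class_map):
        # class_map: 2-D uint8 array of class indices (renamed from `input`,
        # which shadowed the Python builtin). It is replicated to 3 channels
        # so that cv2.LUT can apply the 3-channel color map.
        class_map_3d = cv2.merge([class_map, class_map, class_map])
        return cv2.LUT(class_map_3d, self.color_map)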
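
`get_palette_from_file` expects every line of the palette file to contain one parenthesised, comma-separated color triple. The sketch below applies the same slicing to in-memory lines. The sample lines are made up for illustration and are not copied from `colors.txt`, and unlike the method above, this version skips blank lines instead of raising.

```python
def parse_palette_line(line):
    """Extract the integers between the first '(' and the next ')'."""
    start = line.index("(") + 1
    end = line.index(")", start)
    return [int(value.strip()) for value in line[start:end].split(",")]


def parse_palette(lines):
    """Parse every non-blank line into an [r, g, b] list."""
    return [parse_palette_line(line) for line in lines if line.strip()]


# Hypothetical sample lines, not taken from the notebook's colors.txt.
sample = ["wall (120, 120, 120)", "sky (6, 230, 230)", ""]
palette = parse_palette(sample)
```

Each entry's position in the resulting list is the class index that `create_color_map` assigns to that color.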


In [5]:
def render_segmentation(frame, masks, visualiser, only_masks=False):
    output = visualiser.apply_color_map(masks)
    if not only_masks:
        output = cv2.addWeighted(frame, 0.5, output, 0.5, 0)
    return output

In [6]:
visualizer = SegmentationVisualizer("./data/colors.txt")

[ DEBUG ] The palette is loaded from ../colors.txt
[ DEBUG ] Get 150 colors


## Model Initialization

We load the CoCosNet and segmentation models for image translation.

Converted models are stored in a fixed directory structure that encodes the vendor, model name, and precision.
First, initialize the inference engine, OpenVINO Runtime. Then, read the network architecture from the `.xml` file and the model weights from the `.bin` file, and compile the model for the desired device. An inference request is then created to run inference on the compiled model.

In [10]:
# Initialize OpenVINO Runtime.
ie_core = Core()
device = "CPU"

# Initialize CocosnetModel
gan_model = CocosnetModel(ie_core, translation_model_path, device)

# Initialize SegmentationModel
seg_model = SegmentationModel(ie_core, segmentation_model_path,
                              device) if segmentation_model_path else None
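
The wrapper classes in `utils.models` are not shown in this notebook. As a rough sketch of the read, compile, and infer-request steps described above, the following function uses the OpenVINO Runtime methods `read_model`, `compile_model` and `create_infer_request`. The function `load_model` is a hypothetical name, and whether `CocosnetModel` and `SegmentationModel` do exactly this is an assumption.

```python
from pathlib import Path


def load_model(core, xml_path, device="CPU"):
    """Read an IR model, compile it for `device`, and create an infer request.

    `core` is expected to be an `openvino.runtime.Core` instance.
    """
    xml_path = Path(xml_path)
    bin_path = xml_path.with_suffix(".bin")
    if not xml_path.exists() or not bin_path.exists():
        raise FileNotFoundError(f"IR files not found: {xml_path}, {bin_path}")
    # The topology comes from the .xml file; the weights are picked up
    # from the .bin file with the same name next to it.
    model = core.read_model(model=str(xml_path))
    compiled = core.compile_model(model=model, device_name=device)
    return compiled, compiled.create_infer_request()
```

With the variables from this notebook, a call could look like `load_model(ie_core, translation_model_path, device)`.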

## Input preprocessing and Model inferencing

In this section, we preprocess the input: masks from the segmentation model are used to generate the input and reference semantics.

Inference is then run with the GAN model, which takes the input semantics, the reference image, and the reference semantics.

In [11]:
# Method to get mask from image
def get_mask_from_image(image, model):
    image = preprocess_for_seg_model(image, input_size=model.input_size)
    res = model.infer(image)
    print(res.shape)
    mask = np.argmax(res, axis=1)
    mask = np.squeeze(mask, 0)
    return mask


# Process the input and reference image
def gradioProcessing(input_image=None, reference_image=None):
    # Validate before touching .shape; an empty Gradio image component yields None
    if input_image is None:
        raise IOError('Input image cannot be read')
    if reference_image is None:
        raise IOError('Reference image cannot be read')
    print("input image", input_image.shape)
    print("reference image", reference_image.shape)

    # Segment both images to obtain their semantic masks
    input_semantic = get_mask_from_image(input_image, seg_model)
    reference_semantic = get_mask_from_image(reference_image, seg_model)

    # Preprocess everything to the GAN model's input sizes
    input_semantic = preprocess_semantics(input_semantic, input_size=gan_model.input_semantic_size)
    reference_image = preprocess_image(reference_image, input_size=gan_model.input_image_size)
    reference_semantic = preprocess_semantics(reference_semantic, input_size=gan_model.input_semantic_size)

    # Model Inference
    result = postprocess(gan_model.infer(input_semantic, reference_image, reference_semantic))

    return result
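
The shape printed by `get_mask_from_image`, `(1, 150, 320, 320)` in the output of this notebook, appears to be (batch, classes, height, width). The function reduces it to a 2-D map of class indices by taking `argmax` over the class axis and dropping the batch axis. The toy tensor below, with made-up values and only three classes, illustrates the same two steps.

```python
import numpy as np

# Toy scores shaped (batch=1, classes=3, height=2, width=2); values are made up.
scores = np.zeros((1, 3, 2, 2), dtype=np.float32)
scores[0, 2, 0, 0] = 1.0  # class 2 has the highest score at pixel (0, 0)
scores[0, 1, 1, 1] = 1.0  # class 1 has the highest score at pixel (1, 1)

# One class index per pixel: shape (1, 2, 2)
mask = np.argmax(scores, axis=1)
# Drop the batch axis: shape (2, 2)
mask = np.squeeze(mask, 0)
```

Where all scores tie, `argmax` returns the first index, so the untouched pixels map to class 0. The result has an integer dtype, which is why `createCanvas` casts it to `np.uint8` before applying the color map.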

## Gradio interface 

A web-based GUI that synthesizes an image from the drawing on the Gradio canvas and the uploaded reference image.
Upload a reference image, then draw on the Gradio canvas using different colors to represent different objects. The Load Painting button fills the canvas with the segmentation mask of the reference image.
Click the Translate button to synthesize the AI painting.

Path to reference images: [data](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/data/image)


In [14]:
# gradio method to fetch user input and reference image
def generate(image_input, ref_image):
    result = gradioProcessing(input_image=image_input, reference_image=ref_image)
    return result

def createCanvas(reference_image=None):
    canvas = get_mask_from_image(reference_image, seg_model)
    canvas = canvas.astype(np.uint8)      
    canvas = render_segmentation(reference_image, canvas, visualizer, only_masks=True)
    return canvas

# Initialize gradio canvas
with gr.Blocks(css="#small-b {width: 24px}") as demo:
    with gr.Row().style(equal_height=True):
        with gr.Column():
            canvas_input = gr.ImagePaint(source="upload")
            
            load_mask = gr.Button(" Load Painting ")
            submit = gr.Button("Translate")
        output_image = gr.Image(label='Synthesis')
        ref_image = gr.Image(label='Reference')
        load_mask.click(createCanvas,inputs=[ref_image],outputs=[canvas_input])
        submit.click(generate, inputs=[canvas_input, ref_image], outputs=output_image)

# Start the gradio interactive canvas
demo.launch(share=True)

[ DEBUG ] Starting new HTTPS connection (1): api.gradio.app:443
[ DEBUG ] Using selector: EpollSelector
[ DEBUG ] Starting new HTTP connection (1): 127.0.0.1:7911
[ DEBUG ] http://127.0.0.1:7911 "GET /startup-events HTTP/1.1" 200 5
[ DEBUG ] Starting new HTTP connection (1): 127.0.0.1:7911
[ DEBUG ] http://127.0.0.1:7911 "HEAD / HTTP/1.1" 200 0
Running on local URL:  http://127.0.0.1:7911
[ DEBUG ] Starting new HTTPS connection (1): api.gradio.app:443
[ DEBUG ] https://api.gradio.app:443 "POST /gradio-initiated-analytics/ HTTP/1.1" 200 None
[ DEBUG ] https://api.gradio.app:443 "GET /v2/tunnel-request HTTP/1.1" 200 None
Running on public URL: https://7cc4ab4000f93df6ad.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces
[ DEBUG ] Starting new HTTPS connection (1): 7cc4ab4000f93df6ad.gradio.live:443
[ DEBUG ] https://7cc4ab4000f93df6ad.gradio.live:443 "HEAD / HTTP/1.1" 200 0


[ DEBUG ] Starting new HTTPS connection (1): api.gradio.app:443
[ DEBUG ] https://api.gradio.app:443 "POST /gradio-launched-analytics/ HTTP/1.1" 200 None
[ DEBUG ] https://api.gradio.app:443 "POST /gradio-launched-telemetry/ HTTP/1.1" 200 None
(1, 150, 320, 320)
[ DEBUG ] STREAM b'IHDR' 16 13
[ DEBUG ] STREAM b'sRGB' 41 1
[ DEBUG ] STREAM b'IDAT' 54 5095
input image (320, 320, 3)
reference image (429, 640, 3)
(1, 150, 320, 320)
(1, 150, 320, 320)
