# Universal Segmentation with OneFormer and OpenVINO

This tutorial demonstrates how to use the [OneFormer](https://arxiv.org/abs/2211.06220) model from Hugging Face with OpenVINO. It shows how to download the weights and create a PyTorch model with the Hugging Face transformers library, convert the model to the OpenVINO Intermediate Representation (IR) format with the OpenVINO Model Optimizer API, and run model inference.

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/oneformer_architecture.png)

OneFormer is a follow-up to [Mask2Former](https://arxiv.org/abs/2112.01527). Mask2Former still has to be trained separately on instance, semantic and panoptic datasets to reach state-of-the-art results.

OneFormer adds a text module to the Mask2Former framework to condition the model on the respective subtask (instance, semantic or panoptic). This gives even more accurate results, at the cost of increased latency.
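
To make the task conditioning concrete: the OneFormer paper describes the task input as a short text prompt built from the template "the task is {task}". The sketch below only builds that string. `build_task_prompt` and `SUPPORTED_TASKS` are hypothetical names used for illustration, and tokenizing the prompt is left to the processor.

```python
SUPPORTED_TASKS = ("semantic", "instance", "panoptic")


def build_task_prompt(task: str) -> str:
    """Build the text prompt that tells the model which subtask to solve."""
    if task not in SUPPORTED_TASKS:
        raise ValueError(f"Unsupported task {task!r}; expected one of {SUPPORTED_TASKS}")
    # Template as described in the OneFormer paper (assumption: used unchanged here).
    return f"the task is {task}"


for task in SUPPORTED_TASKS:
    print(build_task_prompt(task))
```

The same image can therefore be segmented three different ways just by changing one word in the prompt.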

## Install required libraries

In [1]:
!pip install -q openvino-dev==2023.1.0.dev20230728 gradio

## Prepare the environment
Import all required packages and set paths for models and constant variables.

In [2]:
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Tuple

from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
from PIL import ImageOps

from openvino.tools import mo
from openvino.runtime import serialize, Core

In [3]:
ONNX_PATH = Path("oneformer.onnx")
IR_PATH = Path("oneformer.xml")
OUTPUT_NAMES = ['class_queries_logits', 'masks_queries_logits']

## Load OneFormer fine-tuned on COCO for universal segmentation
Here we use the `from_pretrained` method of `OneFormerForUniversalSegmentation` to load the Swin-L model trained on the COCO dataset.

We also use the Hugging Face processor to prepare the model inputs from images and to post-process the model outputs for visualization.

In [4]:
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
model = OneFormerForUniversalSegmentation.from_pretrained(
    "shi-labs/oneformer_coco_swin_large",
    torchscript=True
)
id2label = model.config.id2label

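## Prepare example inputs for model conversion

The next step converts the PyTorch model to OpenVINO IR directly, without writing an intermediate ONNX file. The conversion traces the model, so it needs an example input with the expected names, shapes and data types. The spatial size is fixed to 800x800 here.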

In [5]:
task_seq_length = processor.task_seq_length
shape = (800, 800)
dummy_input = {
    "pixel_values": torch.randn(1, 3, *shape),  # TODO: make shapes dynamic
    "task_inputs": torch.randn(1, task_seq_length),
    "pixel_mask": torch.randn(1, *shape),  # TODO: make shapes dynamic
}

## Convert the model to OpenVINO IR format
While OpenVINO Runtime can load ONNX models directly, converting a model to IR format lets you take advantage of OpenVINO optimization tools and features. The `mo.convert_model` Python function in OpenVINO Model Optimizer performs the conversion. In the cell below it is called on the PyTorch model together with `example_input`. The function returns an instance of the OpenVINO Model class, which is ready to use in the Python interface and can also be serialized to OpenVINO IR format for later execution.

In [6]:
if not IR_PATH.exists():
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = mo.convert_model(model, example_input=dummy_input)
    serialize(model, IR_PATH)





We can prepare the image using the Hugging Face processor. The OneFormer processor internally consists of an image processor (for the image modality) and a tokenizer (for the text modality). OneFormer is a multimodal model: it uses both images and text to solve image segmentation.

In [7]:
def prepare_inputs(image: Image.Image, task: str):
    """Convert image to model input"""
    image = ImageOps.pad(image, shape)
    inputs = processor(image, [task], return_tensors="pt")
    converted = {
        'pixel_values': inputs['pixel_values'],
        'task_inputs': inputs['task_inputs']
    }
    return converted
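
`prepare_inputs` relies on `ImageOps.pad` to fit an arbitrary image into the fixed 800x800 input while keeping its aspect ratio. The sketch below illustrates the geometry of such a centered resize-and-pad. `letterbox_geometry` is a hypothetical helper, not part of Pillow, and Pillow's own rounding may differ in edge cases.

```python
from typing import Tuple


def letterbox_geometry(width: int, height: int, target: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Return (new_width, new_height, offset_x, offset_y) for a centered resize-and-pad."""
    target_w, target_h = target
    # Scale so the whole image fits inside the target while keeping the aspect ratio.
    scale = min(target_w / width, target_h / height)
    new_w = min(target_w, round(width * scale))
    new_h = min(target_h, round(height * scale))
    # Split the leftover space evenly between the two sides.
    offset_x = (target_w - new_w) // 2
    offset_y = (target_h - new_h) // 2
    return new_w, new_h, offset_x, offset_y


print(letterbox_geometry(1600, 800, (800, 800)))  # (800, 400, 0, 200)
print(letterbox_geometry(600, 1200, (800, 800)))  # (400, 800, 200, 0)
```

A wide image gets bars above and below, a tall image gets bars on the left and right, and a square image fills the input completely.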

In [8]:
def process_output(d):
    """Convert OpenVINO model output to HuggingFace representation for visualization"""
    hf_kwargs = {
        output_name: torch.tensor(d[output_name]) for output_name in OUTPUT_NAMES
    }

    return OneFormerForUniversalSegmentationOutput(**hf_kwargs)

In [9]:
core = Core()
# Read the model from files.
model = core.read_model(model=IR_PATH)
# Compile the model for a specific device.
model = core.compile_model(model=model, device_name="CPU")

The model predicts `class_queries_logits` of shape `(batch_size, num_queries, num_labels + 1)`, where the extra label is the "no object" class,
and `masks_queries_logits` of shape `(batch_size, num_queries, height, width)`.
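
As a rough illustration of how these two outputs turn into a semantic map, the sketch below combines them in NumPy: softmax over the classes (dropping a trailing "no object" class), sigmoid over the masks, then a weighted sum over queries and an argmax per pixel. It assumes class logits of shape `(batch_size, num_queries, num_labels + 1)`. `semantic_map_from_logits` is a hypothetical name, and the processor's own post-processing additionally resizes the result to the target size.

```python
import numpy as np


def semantic_map_from_logits(class_queries_logits: np.ndarray, masks_queries_logits: np.ndarray) -> np.ndarray:
    """Combine per-query class and mask logits into a per-pixel label map.

    class_queries_logits: (batch, num_queries, num_labels + 1), last class = "no object" (assumption).
    masks_queries_logits: (batch, num_queries, height, width).
    Returns an integer array of shape (batch, height, width).
    """
    # Numerically stable softmax over the class axis.
    shifted = class_queries_logits - class_queries_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    class_probs = probs[..., :-1]  # drop the "no object" class
    # Per-pixel mask probability for every query.
    mask_probs = 1.0 / (1.0 + np.exp(-masks_queries_logits))
    # Sum over queries: score of class c at each pixel.
    scores = np.einsum("bqc,bqhw->bchw", class_probs, mask_probs)
    return scores.argmax(axis=1)


# Two queries, two labels, a 1x2 image: query 0 claims the left pixel, query 1 the right one.
class_logits = np.array([[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]])
mask_logits = np.array([[[[10.0, -10.0]], [[-10.0, 10.0]]]])
print(semantic_map_from_logits(class_logits, mask_logits))
```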

## Interactive Demo
Here we define functions for visualization of network outputs to show the inference results in Gradio interface.

In [10]:
class Visualizer:
    @staticmethod
    def extract_legend(handles):
        fig = plt.figure()
        fig.legend(handles=handles, ncol=len(handles) // 20 + 1, loc='center')
        fig.tight_layout()
        return fig
    
    @staticmethod
    def predicted_semantic_map_to_figure(predicted_map):
        segmentation = predicted_map[0]
        # get the used color map
        viridis = plt.get_cmap('viridis', torch.max(segmentation))
        # get all the unique numbers
        labels_ids = torch.unique(segmentation).tolist()
        fig, ax = plt.subplots()
        ax.imshow(segmentation)
        ax.set_axis_off()
        handles = []
        for label_id in labels_ids:
            label = id2label[label_id]
            color = viridis(label_id)
            handles.append(mpatches.Patch(color=color, label=label))
        fig_legend = Visualizer.extract_legend(handles=handles)
        fig.tight_layout()
        return fig, fig_legend
        
    @staticmethod
    def predicted_instance_map_to_figure(predicted_map):
        segmentation = predicted_map[0]['segmentation']
        segments_info = predicted_map[0]['segments_info']
        # get the used color map
        viridis = plt.get_cmap('viridis', torch.max(segmentation))
        fig, ax = plt.subplots()
        ax.imshow(segmentation)
        ax.set_axis_off()
        instances_counter = defaultdict(int)
        handles = []
        # for each segment, draw its legend
        for segment in segments_info:
            segment_id = segment['id']
            segment_label_id = segment['label_id']
            segment_label = id2label[segment_label_id]
            label = f"{segment_label}-{instances_counter[segment_label_id]}"
            instances_counter[segment_label_id] += 1
            color = viridis(segment_id)
            handles.append(mpatches.Patch(color=color, label=label))
            
        fig_legend = Visualizer.extract_legend(handles)
        fig.tight_layout()
        return fig, fig_legend

    @staticmethod
    def predicted_panoptic_map_to_figure(predicted_map):
        segmentation = predicted_map[0]['segmentation']
        segments_info = predicted_map[0]['segments_info']
        # get the used color map
        viridis = plt.get_cmap('viridis', torch.max(segmentation))
        fig, ax = plt.subplots()
        ax.imshow(segmentation)
        ax.set_axis_off()
        instances_counter = defaultdict(int)
        handles = []
        # for each segment, draw its legend
        for segment in segments_info:
            segment_id = segment['id']
            segment_label_id = segment['label_id']
            segment_label = id2label[segment_label_id]
            label = f"{segment_label}-{instances_counter[segment_label_id]}"
            instances_counter[segment_label_id] += 1
            color = viridis(segment_id)
            handles.append(mpatches.Patch(color=color, label=label))
            
        fig_legend = Visualizer.extract_legend(handles)
        fig.tight_layout()
        return fig, fig_legend
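
The instance and panoptic legends number repeated classes ("person-0", "person-1", ...) with a `defaultdict` counter. The naming logic can be sketched on its own, without matplotlib. `instance_labels` is a hypothetical helper, and the `segments_info` and `id2label` values are made-up sample data.

```python
from collections import defaultdict
from typing import Dict, List


def instance_labels(segments_info: List[dict], id2label: Dict[int, str]) -> List[str]:
    """Return one legend label per segment, numbering instances of the same class."""
    counter = defaultdict(int)  # label_id -> number of instances seen so far
    labels = []
    for segment in segments_info:
        label_id = segment["label_id"]
        labels.append(f"{id2label[label_id]}-{counter[label_id]}")
        counter[label_id] += 1
    return labels


sample_segments = [
    {"id": 1, "label_id": 0},
    {"id": 2, "label_id": 0},
    {"id": 3, "label_id": 2},
]
sample_id2label = {0: "person", 2: "car"}
print(instance_labels(sample_segments, sample_id2label))  # ['person-0', 'person-1', 'car-0']
```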

In [None]:
import gradio as gr

def segment(img: Image.Image, task: str):
    """
    Apply segmentation on an image.

    Args:
        img: Input image. It is resized and padded to 800x800 with its aspect ratio preserved.
        task: String describing the segmentation task. Supported values are: "semantic", "instance" and "panoptic".
    Returns:
        Tuple[Figure, Figure]: Segmentation map and legend charts.
    """
    if img is None:
        raise gr.Error('Please load the image or use one from the examples list')
    inputs = prepare_inputs(img, task)
    outputs = model(inputs)
    hf_output = process_output(outputs)
    predicted_map = getattr(processor, f'post_process_{task}_segmentation')(hf_output, target_sizes=[img.size[::-1]])
    return getattr(Visualizer, f'predicted_{task}_map_to_figure')(predicted_map)
    
demo = gr.Interface(
    segment,
    [
        gr.Image(label="Image", type="pil"),
        gr.Radio(["semantic", "instance", "panoptic"], label="Task", value="semantic"),
    ],
    [gr.Plot(label="Result"), gr.Plot(label="Legend")],
    examples=[["sample.jpg", "semantic"]],
    allow_flagging="never"
)


try:
    demo.launch(debug=True)
except Exception:
    demo.launch(share=True, debug=True)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/