In [None]:
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import torch

from openvino.tools import mo

In [None]:
class MyModel(torch.nn.Module):
    """Adapter that exposes a keyword-argument model through a positional ``forward``.

    The i-th positional tensor passed to ``forward`` is forwarded to the wrapped
    model as the keyword ``self.keywords[i]``, so a plain tuple of example tensors
    (as used with ``mo.convert_model`` below) can drive the model.

    Convention: a tensor whose elements are all NaN (e.g.
    ``torch.tensor([float('nan')])``) means "argument not provided" and is
    forwarded as ``None``. Positional tensors beyond ``len(self.keywords)`` are
    ignored; keywords with no matching positional tensor are not passed at all.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Positional order accepted by ``forward``; example inputs must follow it.
        self.keywords = [
            'pixel_values',
            'task_inputs',
            'text_inputs',
            'mask_labels',
            'class_labels',
            'pixel_mask',
            'output_auxiliary_logits',
            'output_hidden_states',
            'output_attentions',
            'return_dict'
        ]

    def forward(self, *tensors):
        """Map positional tensors onto keyword arguments and run the wrapped model."""
        kwargs = {
            keyword: None if tensor.isnan().all() else tensor
            for keyword, tensor in zip(self.keywords, tensors)
        }
        # Invoke the module itself rather than ``.forward`` so that registered
        # forward hooks (and other ``nn.Module.__call__`` machinery) still run.
        return self.model(**kwargs)

In [None]:
# load OneFormer fine-tuned on ADE20k for universal segmentation
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")


def _not_provided():
    """Return a fresh all-NaN sentinel tensor; MyModel.forward maps it to ``None``."""
    return torch.tensor([float('nan')])


# Positional example inputs, in the same order as MyModel.keywords.
example_input = (
    torch.randn(1, 3, 512, 683),  # pixel_values
    torch.randn(1, 77),  # task_inputs
    _not_provided(),  # text_inputs
    _not_provided(),  # mask_labels
    _not_provided(),  # class_labels
    # NOTE(review): the processor presumably emits pixel_mask as an integer 0/1 mask;
    # confirm that a random float tensor is the intended example here.
    torch.randn(1, 512, 683),  # pixel_mask
    _not_provided(),  # output_auxiliary_logits
    _not_provided(),  # output_hidden_states
    _not_provided(),  # output_attentions
    _not_provided(),  # return_dict
)

# The cells below call `model(**inputs)` and read `model.config.id2label` (the
# Hugging Face PyTorch model API), so `model` must stay bound to that model.
# The converted OpenVINO graph is stored under its own name instead.
ov_model = mo.convert_model(
    MyModel(model),
    example_input=example_input,
    onnx_opset_version=11,
)

In [None]:
url = (
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
)
# Bound the wait on a stalled connection, and surface HTTP errors directly instead of
# handing an error page to PIL, which would then fail with an opaque decode error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image

In [None]:

# Semantic Segmentation
inputs = processor(image, ["semantic"], return_tensors="pt")
for k, v in inputs.items():
    print(f'{k}: {v.shape}')
with torch.no_grad():
    outputs = model(**inputs)
# model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
# and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits

# post_process_semantic_segmentation returns a list with one (height, width) map per
# image in the batch; take the single image's map so downstream code gets a tensor.
# PIL's image.size is (width, height), hence the reversal for target_sizes.
predicted_semantic_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]



In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm


def draw_semantic_segmentation(segmentation, id2label=None):
    """Plot a semantic segmentation map with a legend naming the classes it contains.

    Args:
        segmentation: 2-D tensor of non-negative integer class ids, e.g. one entry of
            ``processor.post_process_semantic_segmentation(...)``.
        id2label: mapping from class id to class name. Defaults to the
            notebook-global ``model.config.id2label``.
    """
    if id2label is None:
        id2label = model.config.id2label

    # Use one colormap and one normalization (ids scaled by max_id) for both the image
    # and the legend, so each legend patch has exactly the color its class is drawn with.
    cmap = plt.get_cmap('viridis')
    max_id = int(segmentation.max())
    label_ids = torch.unique(segmentation).tolist()

    _, ax = plt.subplots()
    # 'nearest' keeps class ids discrete instead of blending neighbours into
    # colors that belong to no class.
    ax.imshow(segmentation, cmap=cmap, vmin=0, vmax=max_id, interpolation='nearest')

    handles = [
        mpatches.Patch(
            color=cmap(label_id / max_id if max_id else 0.0),
            label=id2label[label_id],
        )
        for label_id in label_ids
    ]
    ax.legend(handles=handles)


draw_semantic_segmentation(predicted_semantic_map)

In [None]:
# Instance Segmentation
inputs = processor(image, ["instance"], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
# model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
# and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits

# you can pass them to processor for instance postprocessing; [0] selects the single
# image of the batch, and image.size is reversed because PIL reports (width, height)
predicted_instance_map = processor.post_process_instance_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
# Last expression is the cell's displayed output: show the shape actually computed.
f"👉 Instance Predictions Shape: {list(predicted_instance_map.shape)}"


In [None]:
# Panoptic Segmentation
inputs = processor(image, ["panoptic"], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
# model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
# and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits

# you can pass them to processor for panoptic postprocessing; [0] selects the single
# image of the batch, and image.size is reversed because PIL reports (width, height)
predicted_panoptic_map = processor.post_process_panoptic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
# Last expression is the cell's displayed output: show the shape actually computed.
f"👉 Panoptic Predictions Shape: {list(predicted_panoptic_map.shape)}"