# Kosmos-2: Multimodal Large Language Model and OpenVINO

[KOSMOS-2](https://github.com/microsoft/unilm/tree/master/kosmos-2) is a multimodal large language model (MLLM) with new capabilities for multimodal grounding and referring. KOSMOS-2 can understand multimodal input, follow instructions, perceive object descriptions (e.g., bounding boxes), and ground language to the visual world.

Multimodal Large Language Models (MLLMs) have been used successfully as general-purpose interfaces across a wide range of language, vision, and vision-language tasks. MLLMs can perceive general modalities, including text, images, and audio, and generate free-form text responses in zero-shot and few-shot settings.

[In this work](https://arxiv.org/abs/2306.14824), the authors unlock the grounding capability of multimodal large language models. Grounding enables a more convenient and efficient human-AI interaction for vision-language tasks. The user can point to an object or region in the image directly instead of writing a detailed text description of it, and the model understands that image region together with its spatial location. Grounding also enables the model to respond with visual answers (i.e., bounding boxes), which supports more vision-language tasks such as referring expression comprehension. Compared with text-only responses, visual answers are more accurate and resolve coreference ambiguity. In addition, grounding can link noun phrases and referring expressions in the generated free-form text to image regions, providing more accurate, informative, and comprehensive answers.
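In a grounded generation, each noun phrase is wrapped in `<phrase>...</phrase>` and followed by an `<object>` tag holding location tokens. This is how the text gets linked to image regions. The sketch below pulls those pairs out with a regular expression. The helper `extract_grounded_phrases` and its pattern are illustrative assumptions and not part of the Kosmos-2 or transformers API; in practice, use `processor.post_process_generation`, as this notebook does. The sample string is the raw PyTorch generation printed later in this notebook.

```python
import re

# One grounded span: a phrase followed by an <object> tag holding two location tokens.
# Assumption: exactly one (top-left, bottom-right) pair per <object>; spans that carry
# several boxes simply do not match. [^<]* keeps a phrase from running into the next tag.
GROUNDED_SPAN = re.compile(
    r"<phrase>(?P<phrase>[^<]*)</phrase>"
    r"<object><patch_index_(?P<ul>\d{4})><patch_index_(?P<lr>\d{4})></object>"
)


def extract_grounded_phrases(raw_text):
    """Return (phrase, upper_left_patch, lower_right_patch) tuples from raw model output."""
    return [
        (match.group("phrase").strip(), int(match.group("ul")), int(match.group("lr")))
        for match in GROUNDED_SPAN.finditer(raw_text)
    ]


raw = (
    "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044>"
    "<patch_index_0863></object> warming himself by<phrase> a fire</phrase><object>"
    "<patch_index_0005><patch_index_0911></object>."
)
print(extract_grounded_phrases(raw))
# [('a snowman', 44, 863), ('a fire', 5, 911)]
```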


![image](https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg)

#### Table of contents:
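- [Install requirements](#Install-requirements)
- [Original model inference](#Original-model-inference)
- [Convert models to OpenVINO Intermediate representation (IR) format](#Convert-models-to-OpenVINO-Intermediate-representation-(IR)-format)
    - [Convert the vision model](#Convert-the-vision-model)
    - [Convert Image To Text Projection model](#Convert-Image-To-Text-Projection-model)
    - [Convert Text model](#Convert-Text-model)
    - [Select inference device](#Select-inference-device)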


## Install requirements
[back to top ⬆️](#Table-of-contents:)

In [None]:
%pip install -q "transformers>=4.35" Pillow requests ipywidgets "torch==1.13.0" "torchvision==0.14.0"
%pip install -q "openvino>=2023.2.0"

## Original model inference
[back to top ⬆️](#Table-of-contents:)

In [None]:
import requests

from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq


model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

prompt = "<grounding>An image of"

url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png"
image = Image.open(requests.get(url, stream=True).raw)

# The original Kosmos-2 demo saves the image first and then reloads it. For some images, this gives a slightly different image input and changes the generation output.
image.save("new_image.jpg")
image = Image.open("new_image.jpg")

inputs = processor(text=prompt, images=image, return_tensors="pt")

generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds=None,
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    use_cache=True,
    max_new_tokens=128,
)
print(generated_ids)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
# Specify `cleanup_and_extract=False` in order to see the raw model generation.
processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)

print(processed_text)
# `<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.`

# By default, the generated text is cleaned up and the entities are extracted.
processed_text, entities = processor.post_process_generation(generated_text)

print(processed_text)
# `An image of a snowman warming himself by a fire.`

print(entities)
# `[('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]`
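The location tokens in the raw output index a grid of image patches. `<patch_index_0044>` and `<patch_index_0863>` mark the top-left and bottom-right patches of the snowman's box. The sketch below assumes a 32 x 32 grid numbered row by row and takes patch centres. Under those assumptions, the indices map exactly to the normalized boxes printed in `entities`. The helper `patch_indices_to_box` is hypothetical and is not the processor's own code. It covers only the common case where the two patches differ in both row and column.

```python
def patch_indices_to_box(ul_idx, lr_idx, num_patches_per_side=32):
    """Map a (top-left, bottom-right) patch-index pair to a normalized (x1, y1, x2, y2) box.

    Assumes a square grid numbered row by row and uses the centre of each patch.
    """
    cell = 1.0 / num_patches_per_side
    # index = row * num_patches_per_side + column
    ul_row, ul_col = divmod(ul_idx, num_patches_per_side)
    lr_row, lr_col = divmod(lr_idx, num_patches_per_side)
    return (
        (ul_col + 0.5) * cell,
        (ul_row + 0.5) * cell,
        (lr_col + 0.5) * cell,
        (lr_row + 0.5) * cell,
    )


print(patch_indices_to_box(44, 863))  # (0.390625, 0.046875, 0.984375, 0.828125)
print(patch_indices_to_box(5, 911))   # (0.171875, 0.015625, 0.484375, 0.890625)
```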

## Convert models to OpenVINO Intermediate representation (IR) format
[back to top ⬆️](#Table-of-contents:)

In [None]:
import gc
from pathlib import Path

import torch
import openvino as ov


model.config.torchscript = True

models_base_folder = Path("models")


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()

### Convert the vision model
[back to top ⬆️](#Table-of-contents:)

In [None]:
vision_model_ir_path = models_base_folder / "vision_model.xml"


if not vision_model_ir_path.exists():
    with torch.no_grad():
        ov_model = ov.convert_model(model.vision_model, example_input=inputs["pixel_values"])

    ov.save_model(ov_model, vision_model_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("Vision model successfully converted to IR")
else:
    print(f"Vision model will be loaded from {vision_model_ir_path}")
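The projection model expects vision features that have already passed through `post_layernorm` and then been L2-normalized along the feature axis. The Image To Text Projection cell applies these same steps to build its example input. For a single feature vector, the normalization reduces to the pure-Python sketch below. `l2_normalize` is a hypothetical helper that mirrors the defaults of `torch.nn.functional.normalize` (`p=2`, `eps=1e-12`).

```python
import math


def l2_normalize(vector, eps=1e-12):
    """Scale a vector to unit Euclidean length, like torch.nn.functional.normalize(p=2)."""
    norm = math.sqrt(sum(x * x for x in vector))
    # As in PyTorch, the denominator is clamped so an all-zero vector does not divide by zero.
    return [x / max(norm, eps) for x in vector]


print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```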

### Convert Image To Text Projection model
[back to top ⬆️](#Table-of-contents:)

In [None]:
from torch import nn


image_to_text_projection_model_ir_path = models_base_folder / "image_to_text_projection_model.xml"


if not image_to_text_projection_model_ir_path.exists():
    with torch.no_grad():
        # Build the example input the same way the model does before projection:
        # last hidden state -> post_layernorm -> L2 normalization over the feature axis.
        vision_model_output = model.vision_model(inputs["pixel_values"])
        image_embeds = model.vision_model.model.post_layernorm(vision_model_output[0])
        image_embeds = nn.functional.normalize(image_embeds, dim=-1)
        ov_model = ov.convert_model(model.image_to_text_projection, example_input=image_embeds)

    ov.save_model(ov_model, image_to_text_projection_model_ir_path)
    del ov_model
    cleanup_torchscript_cache()
    gc.collect()
    print("Image To Text Projection model successfully converted to IR")
else:
    print(f"Image To Text Projection model will be loaded from {image_to_text_projection_model_ir_path}")
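The projected image features are what the text model consumes. As we understand the transformers implementation, the processor reserves placeholder positions in `input_ids` and flags them in `image_embeds_position_mask`. The text model then overwrites the token embeddings at those positions with the projected features, in order. The real model does this on tensors with boolean indexing. The sketch below reproduces that masked assignment on toy vectors. `inject_image_embeddings` is a hypothetical helper.

```python
def inject_image_embeddings(token_embeddings, image_embeddings, position_mask):
    """Return a copy of token_embeddings with image features written at masked positions.

    position_mask has one flag per token (1 = image slot); image features are consumed
    in order, one per flagged position.
    """
    if sum(position_mask) != len(image_embeddings):
        raise ValueError("the mask must flag exactly one position per image embedding")
    merged = [list(vector) for vector in token_embeddings]
    image_features = iter(image_embeddings)
    for position, flag in enumerate(position_mask):
        if flag:
            merged[position] = list(next(image_features))
    return merged


toy_tokens = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
toy_image_features = [[9.0, 9.0], [8.0, 8.0]]
print(inject_image_embeddings(toy_tokens, toy_image_features, [0, 1, 1, 0]))
# [[0.0, 0.0], [9.0, 9.0], [8.0, 8.0], [3.0, 3.0]]
```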

### Convert Text model 
[back to top ⬆️](#Table-of-contents:)

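In [None]:
first_stage_model_path = models_base_folder / "kosmos_input_embed.xml"
second_stage_model_path = models_base_folder / "kosmos_with_past.xml"


def convert_text_model():
    # Make the text model return plain tuples `(logits, past_key_values)` for tracing.
    model.text_model.model.config.torchscript = True

    with torch.no_grad():
        # The text model consumes the *projected* image features (the output of
        # `image_to_text_projection`), not the raw vision features.
        vision_model_output = model.vision_model(inputs["pixel_values"])
        image_embeds = model.vision_model.model.post_layernorm(vision_model_output[0])
        image_embeds = nn.functional.normalize(image_embeds, dim=-1)
        projected_image_embeds = model.image_to_text_projection(image_embeds)[0]

        # Stage 1: the whole prompt; the image features are written into the
        # positions flagged by `image_embeds_position_mask`.
        conv_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "image_embeds": projected_image_embeds,
            "image_embeds_position_mask": inputs["image_embeds_position_mask"],
        }
        if not first_stage_model_path.exists():
            ov_model = ov.convert_model(model.text_model, example_input=conv_inputs)
            ov.save_model(ov_model, first_stage_model_path)
            del ov_model
            cleanup_torchscript_cache()
            gc.collect()

        if not second_stage_model_path.exists():
            # Stage 2: one new token plus the key/value cache of everything seen so far.
            # No image inputs here: the image is already encoded in the cache.
            past_key_values = model.text_model(**conv_inputs)[1]
            example_input_second_stage = {
                "input_ids": inputs["input_ids"][:, -1:],
                "attention_mask": torch.ones((1, past_key_values[-1][-1].shape[-2] + 1), dtype=torch.long),
                "past_key_values": past_key_values,
            }
            ov_model = ov.convert_model(model.text_model, example_input=example_input_second_stage)
            ov.save_model(ov_model, second_stage_model_path)
            del ov_model
            cleanup_torchscript_cache()
            gc.collect()


convert_text_model()

#### Select inference device
[back to top ⬆️](#Table-of-contents:)

Select the device that OpenVINO will use for model inference from the dropdown list:

In [None]:
import ipywidgets as widgets


core = ov.Core()
DEVICE = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value='AUTO',
    description='Device:',
    disabled=False,
)

DEVICE

The wrappers below expose each compiled OpenVINO model through the attributes and call signatures that `Kosmos2ForConditionalGeneration.generate` uses. This lets them replace the PyTorch submodules of `model` in place.

In [None]:
import numpy as np


class WrapperInternalVisionModel:
    # `generate` calls `self.vision_model.model.post_layernorm`, so the original
    # PyTorch layer norm stays reachable under the same attribute path.
    post_layernorm = model.vision_model.model.post_layernorm


class VisionModelWrapper(torch.nn.Module):
    def __init__(self, model_ir_path):
        super().__init__()
        self.model = WrapperInternalVisionModel()
        self.vision_model = core.compile_model(model_ir_path, DEVICE.value)

    def forward(self, pixel_values, **kwargs):
        last_hidden_state = self.vision_model(pixel_values.numpy())[0]
        # Return a tuple: the caller takes `[0]` to get `last_hidden_state`, whereas
        # indexing a bare tensor would drop the batch dimension.
        return (torch.from_numpy(last_hidden_state),)


class ImageToTextProjectionModelWrapper(torch.nn.Module):
    def __init__(self, model_ir_path):
        super().__init__()
        self.image_to_text_projection = core.compile_model(model_ir_path, DEVICE.value)

    def forward(self, image_embeds, **kwargs):
        image_embeds = self.image_to_text_projection(image_embeds.detach().numpy())[0]
        # The caller unpacks `(image_embeds, projection_attentions)`; attentions are not needed.
        return torch.from_numpy(image_embeds), None


class TextModelWrapper(torch.nn.Module):
    def __init__(self, model_stage_1_ir_path, model_stage_2_ir_path):
        super().__init__()
        self.model_stage_1 = core.compile_model(model_stage_1_ir_path, DEVICE.value)
        self.model_stage_2 = core.compile_model(model_stage_2_ir_path, DEVICE.value)
        self.eos_token_id = model.config.text_config.eos_token_id

    def generate(self, input_ids, attention_mask, image_embeds, image_embeds_position_mask, max_new_tokens=128, **kwargs):
        # Simple greedy decoding; other generation strategies are not implemented here.
        # Inputs are passed in the order of the text model's forward() signature,
        # which is assumed to be the order of the IR inputs.
        outs = self.model_stage_1(
            [input_ids.numpy(), attention_mask.numpy(), image_embeds.numpy(), image_embeds_position_mask.numpy()]
        ).to_tuple()
        logits, past_key_values = outs[0], outs[1:]
        tokens, mask = input_ids.numpy(), attention_mask.numpy()
        for _ in range(max_new_tokens):
            # Most likely next token from the logits of the last position.
            next_token = logits[:, -1].argmax(-1)[:, None].astype(np.int64)
            tokens = np.concatenate([tokens, next_token], axis=1)
            if next_token[0, 0] == self.eos_token_id:
                break
            # The mask covers the cached tokens plus the new one.
            mask = np.concatenate([mask, np.ones_like(next_token)], axis=1)
            outs = self.model_stage_2([next_token, mask, *past_key_values]).to_tuple()
            logits, past_key_values = outs[0], outs[1:]
        return torch.from_numpy(tokens)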
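The text model is converted as two IRs. The first processes the whole prompt. The second ("with past") takes one new token plus the key/value cache. The toy sketch below isolates that decoding control flow, using stand-in callables instead of OpenVINO models. This makes the bookkeeping easy to check: the attention mask always covers the cached tokens plus the new one. `greedy_two_stage_decode`, `toy_stage_1`, and `toy_stage_2` are hypothetical names, and the toy "model" simply predicts the next integer.

```python
def greedy_two_stage_decode(stage_1, stage_2, prompt_ids, max_new_tokens, eos_token_id):
    """Greedy decoding with a prefill stage and a single-token stage that reuses a cache.

    stage_1(ids, mask) and stage_2(ids, mask, cache) stand in for the two IRs: each
    returns (scores_for_next_token, cache), with one cache entry per token seen so far.
    """
    tokens = list(prompt_ids)
    mask = [1] * len(tokens)
    scores, cache = stage_1(tokens, mask)  # prefill: the whole prompt at once
    for _ in range(max_new_tokens):
        next_token = max(range(len(scores)), key=scores.__getitem__)  # greedy argmax
        tokens.append(next_token)
        if next_token == eos_token_id:
            break
        mask.append(1)
        # Invariant: the mask covers every cached token plus the one being fed in.
        if len(mask) != len(cache) + 1:
            raise RuntimeError("attention mask and cache are out of sync")
        scores, cache = stage_2([next_token], mask, cache)  # one token + cache
    return tokens


VOCAB_SIZE = 5


def _scores_after(last_token):
    # Toy "language model": always prefers the token that follows the last one.
    return [1.0 if t == (last_token + 1) % VOCAB_SIZE else 0.0 for t in range(VOCAB_SIZE)]


def toy_stage_1(ids, mask):
    return _scores_after(ids[-1]), list(ids)


def toy_stage_2(ids, mask, cache):
    return _scores_after(ids[-1]), cache + ids


print(greedy_two_stage_decode(toy_stage_1, toy_stage_2, [0, 1], max_new_tokens=10, eos_token_id=4))
# [0, 1, 2, 3, 4]
```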

In [None]:
vision_model_ov = VisionModelWrapper(vision_model_ir_path)
image_to_text_projection_ov = ImageToTextProjectionModelWrapper(image_to_text_projection_model_ir_path)
text_model_ov = TextModelWrapper(first_stage_model_path, second_stage_model_path)


model.vision_model = vision_model_ov
model.image_to_text_projection = image_to_text_projection_ov
model.text_model = text_model_ov

In [None]:
generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds=None,
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    use_cache=True,
    max_new_tokens=128,
)
print(generated_ids)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
# Specify `cleanup_and_extract=False` in order to see the raw model generation.
processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)

print(processed_text)
# Expected to match the PyTorch output above:
# `<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.`

# By default, the generated text is cleaned up and the entities are extracted.
processed_text, entities = processor.post_process_generation(generated_text)

print(processed_text)
# `An image of a snowman warming himself by a fire.`

print(entities)
# `[('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]`
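Each entry in `entities` holds a phrase, its character span in the cleaned text, and boxes in normalized `(x1, y1, x2, y2)` coordinates, as the values in [0, 1] show. To draw these boxes on the input image, scale them by the image width and height. The sketch below does this for a hypothetical 512 x 512 image; in the notebook, `image.size` gives the real `(width, height)`. `to_pixel_box` is a hypothetical helper.

```python
def to_pixel_box(normalized_box, width, height):
    """Scale a normalized (x1, y1, x2, y2) box to integer pixel coordinates."""
    x1, y1, x2, y2 = normalized_box
    # Python's round() rounds exact halves to the nearest even integer.
    return (round(x1 * width), round(y1 * height), round(x2 * width), round(y2 * height))


sample_entities = [
    ("a snowman", (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]),
    ("a fire", (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)]),
]
for phrase, _char_span, boxes in sample_entities:
    for box in boxes:
        print(phrase, to_pixel_box(box, 512, 512))  # hypothetical 512 x 512 image
# a snowman (200, 24, 504, 424)
# a fire (88, 8, 248, 456)
```

You can pass the resulting tuples to `PIL.ImageDraw.Draw(image).rectangle(...)` to produce an annotated image like the one at the top of this notebook.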