# Visual-language assistant with LLaVA Next and OpenVINO

[LLaVA-NeXT](https://llava-vl.github.io/blog/2024-01-30-llava-next/) is a new generation of the LLaVA model family that marks a breakthrough in advanced language reasoning over images, introducing improved OCR and expanded world knowledge. [LLaVA](https://llava-vl.github.io) (Large Language and Vision Assistant) is a large multimodal model that aims to develop a general-purpose visual assistant that can follow both language and image instructions to complete various real-world tasks. The idea is to combine the power of large language models (LLMs) with vision encoders like CLIP to create an end-to-end trained neural assistant that understands and acts upon multimodal instructions.

In this tutorial we consider how to convert and optimize the LLaVA-NeXT model from the Transformers library in order to create a multimodal chatbot.

## Prerequisites

In [None]:
%pip install -q "openvino>=2024.0.0" "nncf>=2.9.0" "torch>=2.1" "transformers>=4.39.1" "pillow" "gradio" --extra-index-url https://download.pytorch.org/whl/cpu

In [2]:
from pathlib import Path

# All converted OpenVINO IR files are stored in a single directory.
MODEL_DIR = Path("model")
IMAGE_ENCODER_PATH = MODEL_DIR.joinpath("image_encoder.xml")
INPUT_EMBEDDING_PATH = MODEL_DIR.joinpath("input_embeddings.xml")
LANGUAGE_MODEL_PATH = MODEL_DIR.joinpath("language_model.xml")

# The PyTorch checkpoint has to be loaded as soon as any of the IR files is missing.
requires_pt_model_loading = any(
    not ir_path.exists() for ir_path in (IMAGE_ENCODER_PATH, INPUT_EMBEDDING_PATH, LANGUAGE_MODEL_PATH)
)

## Download PyTorch model

In [3]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
import gc

# The processor is needed for inference even when conversion is skipped, so it is always loaded.
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

# Placeholders; they stay None when every IR file already exists and the checkpoint is not loaded.
image_encoder_model = None
input_embedding_model = None
language_model = None

class ImageEncoder(torch.nn.Module):
    """
    Wraps the vision tower and the multimodal projector into a single traceable module.

    Parameters:
      config:
          model configuration; `vision_feature_layer` and `vision_feature_select_strategy` are read from it
      vision_tower:
          module called with `output_hidden_states=True`; its result must expose `hidden_states`
      multi_modal_projector:
          module applied to the selected hidden state
    """

    def __init__(self, config, vision_tower, multi_modal_projector):
        super().__init__()
        self.config = config
        self.vision_tower = vision_tower
        self.multi_modal_projector = multi_modal_projector

    def forward(self, pixel_values):
        """
        Encode image patches.

        Parameters:
          pixel_values:
              5-D tensor unpacked as (batch, patches, channels, height, width)

        Returns:
          output of the projector for the selected vision hidden state
        """
        batch, patches, channels, rows, cols = pixel_values.shape
        # Fold the patch axis into the batch axis so the vision tower receives a 4-D input.
        flat_pixels = pixel_values.view(batch * patches, channels, rows, cols)
        tower_output = self.vision_tower(flat_pixels, output_hidden_states=True)
        features = tower_output.hidden_states[self.config.vision_feature_layer]
        # "default" drops the first position along dim 1 (presumably the CLS token);
        # any other strategy, e.g. "full", keeps the hidden state untouched.
        if self.config.vision_feature_select_strategy == "default":
            features = features[:, 1:]
        return self.multi_modal_projector(features)

if requires_pt_model_loading:
    # Load the full PyTorch checkpoint only when at least one IR file still has to be produced.
    model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", low_cpu_mem_usage=True)
    # The configuration is saved next to the IR files; the inference wrapper reads it from MODEL_DIR.
    model.config.save_pretrained(MODEL_DIR)
    # Split the model into the three parts that are converted separately.
    image_encoder_model = ImageEncoder(model.config, model.vision_tower, model.multi_modal_projector)
    input_embedding_model = model.get_input_embeddings()
    language_model = model.language_model
    # Drop the top-level wrapper; the sub-modules stay referenced by the names above.
    del model
    gc.collect()

  warn("The installed version of bitsandbytes was compiled without GPU support. "


/home/ea/work/my_optimum_intel/optimum_env/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32


2024-04-02 18:44:43.544801: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-02 18:44:43.546798: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-02 18:44:43.583801: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-02 18:44:43.584756: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tu

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## Convert model to OpenVINO Intermediate Representation

### Image Encoder

In [4]:
import torch
import openvino as ov
import gc

def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation

    Resets TorchScript-related state through private torch APIs. It is called after each
    model conversion in this notebook; presumably this lets the traced representation of
    an already converted module be garbage collected.
    """
    # Clear the JIT class registry.
    torch._C._jit_clear_class_registry()
    # Replace the concrete type store with a fresh, empty instance.
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    # Clear the remaining JIT class state.
    torch.jit._state._clear_class_state()


# Convert the image encoder only when its IR file is not on disk yet.
if not IMAGE_ENCODER_PATH.exists():
    # The example input is 5-D; ImageEncoder.forward unpacks it as
    # (batch, patches, channels, height, width).
    ov_image_encoder = ov.convert_model(image_encoder_model, example_input=torch.zeros((1, 5, 3, 336, 336)))
    ov.save_model(ov_image_encoder, IMAGE_ENCODER_PATH)
    del ov_image_encoder
    cleanup_torchscript_cache()

# The PyTorch module is not used after this point (it is None when model loading was skipped).
del image_encoder_model
gc.collect();

### Text Embedding

In [5]:
# Example embeddings that are reused as `inputs_embeds` when the language model is converted.
llm_input = None

# Compute them only if the language model still has to be converted; this must happen
# before `input_embedding_model` is deleted at the end of this cell.
if not LANGUAGE_MODEL_PATH.exists():
    llm_input = input_embedding_model(torch.ones((2, 2), dtype=torch.int64))

# Convert the embedding layer only when its IR file is not on disk yet.
if not INPUT_EMBEDDING_PATH.exists():
    ov_input_embeddings_model = ov.convert_model(input_embedding_model, example_input=torch.ones((2, 2), dtype=torch.int64))
    ov.save_model(ov_input_embeddings_model, INPUT_EMBEDDING_PATH)
    del ov_input_embeddings_model
    cleanup_torchscript_cache()

# The PyTorch module is not used after this point (it is None when model loading was skipped).
del input_embedding_model
gc.collect();

### Language Model

In [6]:
from typing import Optional, Tuple, List
import nncf
from openvino.runtime import opset13
import numpy as np

def model_has_state(ov_model: ov.Model):
    """
    Tell whether the model is stateful, judged by the presence of sink nodes.

    Parameters:
      ov_model (ov.Model):
          openvino model

    Returns:
      True if the model reports at least one sink else False
    """
    # TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
    sinks = ov_model.get_sinks()
    return len(sinks) != 0


def model_has_input_output_name(ov_model: ov.Model, name: str):
    """
    Helper function for checking that model has specified input or output name

    Parameters:
      ov_model (ov.Model):
          openvino model whose inputs and outputs are inspected
      name (str):
          name of input or output

    Returns:
      True if input or output with requested name exists else False
    """
    # Check every port directly instead of first flattening all names into one list.
    return any(
        name in port.get_names()
        for ports in (ov_model.inputs, ov_model.outputs)
        for port in ports
    )

def fuse_cache_reorder(
    ov_model: ov.Model, not_kv_inputs: List[str], key_value_input_names: List[str], gather_dim: int
):
    """
    Fuses _reorder_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.

    Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
    Should be run before make_stateful. Implements optimum's _reorder_cache
    inside the model in the beginning of each iteration.
    Gather works along given gather_dim dimension that may vary from model to model.
    KV-cache inputs are identified based on names in key_value_input_names.
    Append the new beam_idx parameter to not_kv_inputs.

    Parameters:
      ov_model (`ov.Model`):
          openvino model for processing; modified in place
      not_kv_inputs (`List[str]`):
          list of input nodes in model that not related to past key values;
          the newly added beam_idx input is appended to this list
      key_value_input_names (`List[str]`):
          list of names for key value input layers
      gather_dim (int):
          dimension for gathering cache during reorder pass

    Raises:
      ValueError: if the model already has an input or output named "beam_idx"
    """

    if model_has_input_output_name(ov_model, "beam_idx"):
        raise ValueError("Model already has fused cache")
    # beam_idx is 1-D and reuses the first dimension of inputs_embeds.
    input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0]
    beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({"beam_idx"})  # why list is not accepted?
    ov_model.add_parameters([beam_idx])
    # The parameter that was just added is the last model input.
    not_kv_inputs.append(ov_model.inputs[-1])
    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        # Collect the consumers before the Gather is created, so the Gather itself is not rewired.
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()


def build_state_initializer(ov_model: ov.Model, batch_dim: int):
    """
    Build initialization ShapeOf Expression for all ReadValue ops

    Every ReadValue op receives a zero-filled Broadcast as its argument. The broadcast
    shape uses the minimal length of each output dimension, except for `batch_dim`, which
    is taken at runtime from the first dimension of the "inputs_embeds" input.

    Parameters:
      ov_model (ov.Model):
          openvino model; modified in place
      batch_dim (int):
          index of dimension corresponding to batch size
    """
    # Despite its name, this variable holds the "inputs_embeds" input port.
    input_ids = ov_model.input("inputs_embeds")
    # One-element i64 tensor containing the runtime size of dimension 0 of inputs_embeds.
    batch = opset13.gather(opset13.shape_of(input_ids, output_type="i64"), opset13.constant([0]), opset13.constant(0))
    for op in ov_model.get_ops():
        if op.get_type_name() == "ReadValue":
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            # Plain ints become one-element constants so they can be concatenated with the `batch` node.
            dims = [opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()


def make_stateful(
    ov_model: ov.Model,
    not_kv_inputs: List[str],
    key_value_input_names: List[str],
    key_value_output_names: List[str],
    batch_dim: int,
    num_attention_heads: int,
    num_beams_and_batch: int = None,
):
    """
    Hides kv-cache inputs and outputs inside the model as variables.

    Parameters:
        ov_model (ov.Model):
            openvino model; modified in place
        not_kv_inputs (`List[str]`):
            list of input nodes in model that not related to past key values
        key_value_input_names (`List[str]`):
            list of names for key value input layers
        key_value_output_names (`List[str]`):
            list of names for key value output layers, paired by position with
            key_value_input_names
        batch_dim (int):
            index of batch dimension in key value layers
        num_attention_heads (int):
            number of attention heads for batch dimension initialization
        num_beams_and_batch (int):
            precalculated number of beams and batch for shapes initialization;
            when None, the batch stays dynamic and state initializers are built instead
    """
    import logging

    from openvino._offline_transformations import apply_make_stateful_transformation

    logger = logging.getLogger(__name__)
    static_batch = num_beams_and_batch is not None
    input_output_map = {}

    if static_batch:
        # Set batch size for input_ids and attention mask to avoid dynamic dimension got propagated from the end of the model back to ReadValue
        for model_input in not_kv_inputs:
            shape = model_input.get_partial_shape()
            if shape.rank.get_length() <= 2:  # == 1 for beam_index
                shape[0] = num_beams_and_batch
                model_input.get_node().set_partial_shape(shape)
            else:
                logger.warning(
                    "Rank of %s input of the model is not 2, batch size is not set",
                    model_input.get_any_name(),
                )

    for kv_input_name, kv_output_name in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_input_name] = kv_output_name
        if static_batch:
            kv_input = ov_model.input(kv_input_name)
            shape = kv_input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            kv_input.get_node().set_partial_shape(shape)

    if static_batch:
        # Re-validation model if shapes are altered above
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if not static_batch:
        # With a dynamic batch, the initial state shape has to be derived from the model input at runtime.
        build_state_initializer(ov_model, batch_dim)


def patch_stateful(ov_model):
    """
    Turn a model with explicit kv-cache inputs/outputs into a stateful one, in place.

    The kv-cache ports are located by position: inputs[2:-1] are treated as cache inputs
    and outputs[1:] as cache outputs. Nothing is changed when either list is empty.

    Parameters:
      ov_model (ov.Model):
          openvino model to patch
    """
    kv_inputs = [port.get_any_name() for port in ov_model.inputs[2:-1]]
    kv_outputs = [port.get_any_name() for port in ov_model.outputs[1:]]
    # Every input that carries none of the cache names counts as a regular (non-cache) input.
    other_inputs = [
        port for port in ov_model.inputs if all(alias not in kv_inputs for alias in port.get_names())
    ]
    if not (kv_inputs and kv_outputs):
        return

    # Batch is dimension 0 of the cache tensors; a single head factor keeps the batch unscaled.
    cache_batch_dim = 0
    heads_factor = 1

    fuse_cache_reorder(ov_model, other_inputs, kv_inputs, cache_batch_dim)
    make_stateful(ov_model, other_inputs, kv_inputs, kv_outputs, cache_batch_dim, heads_factor, None)

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino


In [7]:
import nncf

# Conversion switches: hide the kv-cache inside the model, and compress weights in the next step.
make_stateful_model = True
compress_weights = True

# Arguments forwarded to nncf.compress_weights in the compression cell.
compression_configuration = {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64, "ratio": 0.6}

core = ov.Core()

if not LANGUAGE_MODEL_PATH.exists():
    # One eager forward pass yields example past key/values (second element of the output).
    pkv = language_model(inputs_embeds=llm_input, attention_mask=torch.ones((2, 2), dtype=torch.int64))[1]
    # Names are assigned to the converted model's ports by position, in this order.
    model_inputs = ["attention_mask", "position_ids"]
    model_outputs = ["logits"]
    for idx in range(len(pkv)):
        model_inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        model_outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])
    model_inputs.append("inputs_embeds")
    language_model.config.torchscript = True
    position_ids = torch.tensor([[2, 3], [2, 3]])
    ov_model = ov.convert_model(language_model, example_input={"inputs_embeds": llm_input, "attention_mask": torch.ones((2, 4)), "past_key_values": pkv, "position_ids": position_ids})

    # Loop variables are deliberately not called `input`: at notebook scope that would
    # shadow the builtin for every later cell.
    for ov_input, input_name in zip(ov_model.inputs, model_inputs):
        ov_input.get_tensor().set_names({input_name})

    for ov_output, output_name in zip(ov_model.outputs, model_outputs):
        ov_output.get_tensor().set_names({output_name})
    if make_stateful_model:
        patch_stateful(ov_model)
    ov.save_model(ov_model, LANGUAGE_MODEL_PATH)
    del ov_model
    cleanup_torchscript_cache()
    del language_model
    gc.collect()









No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
  if past_key_values_length > 0:
  if seq_len > self.max_seq_len_cached:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  if a.grad is not None:


Output()

INFO:nncf:Statistics of the bitwidth distribution:
+--------------+---------------------------+-----------------------------------+
| Num bits (N) | % all parameters (layers) |    % ratio-defining parameters    |
|              |                           |             (layers)              |
| 8            | 41% (135 / 225)           | 40% (134 / 224)                   |
+--------------+---------------------------+-----------------------------------+
| 4            | 59% (90 / 225)            | 60% (90 / 224)                    |
+--------------+---------------------------+-----------------------------------+


Output()

### Compress Language Model weights to INT4

In [None]:
import nncf

# The compressed IR is stored next to the original one, with an "-int4" suffix in the file name.
LANGUAGE_MODEL_PATH_INT4 = LANGUAGE_MODEL_PATH.with_name(LANGUAGE_MODEL_PATH.name.replace(".xml", "-int4.xml"))
if compress_weights and not LANGUAGE_MODEL_PATH_INT4.exists():
    # Read the full-precision IR and compress it with the configuration defined earlier.
    ov_model = nncf.compress_weights(core.read_model(LANGUAGE_MODEL_PATH), **compression_configuration)
    ov.save_model(ov_model, LANGUAGE_MODEL_PATH_INT4)
    del ov_model
    gc.collect()

### Quantize Image Encoder

In [None]:
# TBD

### Prepare model inference pipeline

In [8]:
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoConfig
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
import numpy as np
import torch
from typing import Optional, Tuple
import openvino as ov


class OVLlavaForCausalLM(GenerationMixin):
    def __init__(self, core, model_dir, device):
        """
        Load and compile the three OpenVINO models that make up the assistant.

        Parameters:
          core:
              OpenVINO core object used to read and compile the models
          model_dir (Path):
              directory containing the IR files and the saved model configuration
          device (str):
              device name passed to `compile_model`
        """
        self.image_encoder = core.compile_model(model_dir / "image_encoder.xml", device)
        self.input_embeddings = core.compile_model(model_dir / "input_embeddings.xml", device)
        # Prefer the INT4-compressed language model; fall back to the uncompressed IR
        # when weight compression was skipped and the INT4 file does not exist.
        language_model_path = model_dir / "language_model-int4.xml"
        if not language_model_path.exists():
            language_model_path = model_dir / "language_model.xml"
        self.model = core.read_model(language_model_path)
        # Input name -> input index.
        self.input_names = {
            port.get_any_name(): idx for idx, port in enumerate(self.model.inputs)
        }
        # Output index -> output port (ports, not names, are passed to `get_tensor`).
        self.output_names = dict(enumerate(self.model.outputs))
        self.key_value_input_names = [
            name for name in self.input_names if name not in ("beam_idx", "inputs_embeds", "attention_mask", "position_ids")
        ]
        # Output ports of the kv-cache; output 0 holds the logits. Ports are stored rather
        # than integer indices so that they can be passed to `InferRequest.get_tensor`.
        self.key_value_output_names = list(self.output_names.values())[1:]
        # A model without explicit kv-cache inputs keeps the cache in its internal state.
        self.stateful = len(self.key_value_input_names) == 0
        compiled_model = core.compile_model(self.model, device)
        self.request = compiled_model.create_infer_request()
        self.config = AutoConfig.from_pretrained(model_dir)
        self.generation_config = GenerationConfig.from_model_config(self.config)
        self.main_input_name = "input_ids"
        self.device = torch.device("cpu")
        # Number of cache tensors per layer (key and value).
        self.num_pkv = 2
        self.next_beam_idx = None
        # NOTE(review): `torch.empty` leaves the values uninitialized and nothing visible here
        # loads the trained `image_newline` weights - confirm where they are supposed to come from.
        self.image_newline = torch.nn.Parameter(torch.empty(self.config.text_config.hidden_size, dtype=torch.float32))
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        # Number of positions already processed by the language model.
        self.past_len = 0

    def can_generate(self):
        """Always report True so that the check done by `GenerationMixin.generate()` passes for this wrapper."""
        return True

    def __call__(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_sizes=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Make the wrapper callable like a torch module by delegating everything to `forward`."""
        return self.forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            image_sizes=image_sizes,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_sizes=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        General inference method

        When `past_key_values` is None, the multimodal prompt is embedded and processed in
        one pass. Otherwise only the last token of `input_ids` is processed and the attention
        mask is extended to cover the cached positions; `attention_mask` is required then.

        Parameters:
          input_ids: token ids of the sequence generated so far
          pixel_values: image patches; only used on the first pass
          attention_mask: attention mask matching `input_ids`
          past_key_values: cache returned by the previous step, or None on the first pass
          position_ids: optional position ids; only forwarded on the first pass
          image_sizes: original image sizes; only used on the first pass

        Returns:
          CausalLMOutputWithPast with `logits` and `past_key_values`; for a stateful model
          the latter is the placeholder `((),)`
        """
        if past_key_values is not None:
            inputs = {}
            if not self.stateful:
                # Key tensor of the first layer; it must be taken before the cache is flattened.
                first_layer_key = past_key_values[0][0]
                past_key_values = tuple(
                    past_key_value
                    for pkv_per_layer in past_key_values
                    for past_key_value in pkv_per_layer
                )
                # Add the past_key_values to the decoder inputs
                inputs = dict(zip(self.key_value_input_names, past_key_values))
            else:
                # The cache lives inside this object's own infer request.
                first_layer_key = self.request.query_state()[0].state.data

            # Only the newest token has to be processed, everything else is cached.
            input_ids = np.array(input_ids)[:, -1:]
            batch_size = input_ids.shape[0]
            inputs["inputs_embeds"] = self.input_embeddings(input_ids)[0]
            if "beam_idx" in self.input_names:
                inputs["beam_idx"] = (
                    self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
                )

            first_layer_past_key_value = torch.from_numpy(first_layer_key[:, :, :, 0])

            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

            # Get the target length
            target_length = input_ids.shape[1]
            past_length = first_layer_past_key_value.shape[-1]

            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Filter out only the tokens that can be un-attended, this can happen
            # if one uses Llava + Fused modules where the cache on the
            # first iteration is already big enough, or if one passes custom cache
            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero-out the places where we don't need to attend
            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            inputs["attention_mask"] = attention_mask
            inputs["position_ids"] = position_ids
        else:
            # A pass without cache starts a new sequence, so the processed-length counter restarts.
            self.past_len = 0
            inputs = self.prepare_multimodal_input(
                input_ids, pixel_values, attention_mask, position_ids, image_sizes
            )

        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()

        logits = torch.from_numpy(self.request.get_tensor(self.output_names[0]).data)

        if not self.stateful:
            # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
            past_key_values = tuple(
                self.request.get_tensor(key).data for key in self.key_value_output_names
            )
            # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
            past_key_values = tuple(
                past_key_values[i : i + self.num_pkv]
                for i in range(0, len(past_key_values), self.num_pkv)
            )
        else:
            # Non-None placeholder: the real cache is kept in the model state.
            past_key_values = ((),)
        self.past_len += inputs["inputs_embeds"].shape[1]
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


    def prepare_multimodal_input(self, input_ids, pixel_values, attention_mask, position_ids, image_sizes=None):
        """
        Preprocessing function for embedding multimodal data

        Builds the language model inputs for the first pass: embeds the text tokens, encodes
        the image patches and merges both into one sequence. The model state (stateful model)
        or empty cache tensors (stateless model) are initialized as well.

        Parameters:
          input_ids: prompt token ids
          pixel_values: image patches, or None for a text-only prompt
          attention_mask: attention mask matching `input_ids`
          position_ids: optional position ids, only used for a text-only prompt
          image_sizes: original image sizes, indexed per image

        Returns:
          dict mapping language model input names to their values
        """
        inputs = {}
        inputs_embeds = torch.from_numpy(self.input_embeddings(input_ids)[0])
        batch_size = input_ids.shape[0]
        if not self.stateful:
            # Feed cache tensors with a zero-length sequence dimension on the first pass.
            for input_name in self.key_value_input_names:
                model_inputs = self.model.input(input_name)
                shape = model_inputs.get_partial_shape()
                shape[0] = batch_size
                if shape[2].is_dynamic:
                    shape[2] = 0
                else:
                    shape[1] = 0
                inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape())
        else:
            self.request.reset_state()
            # Set initial value for the next beam_idx input that will be used at the current iteration
            # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
            self.next_beam_idx = np.arange(batch_size, dtype=int)

        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = (
                self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
            )
        if pixel_values is None:
            # Text-only prompt: there is nothing to merge, so return right away instead of
            # falling through to the image encoder.
            inputs["inputs_embeds"] = inputs_embeds
            inputs["attention_mask"] = attention_mask
            if position_ids is None:
                position_ids = torch.cumsum(attention_mask, axis=1) - 1
                position_ids[attention_mask == 0] = 1
            inputs["position_ids"] = position_ids
            return inputs

        res = self.image_encoder(pixel_values)
        image_features = torch.from_numpy(res[0])
        # The encoder output is flat over all patches; split it back per image.
        split_sizes = [image.shape[0] for image in pixel_values]
        image_features = torch.split(image_features, split_sizes, dim=0)

        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
        height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

        new_image_features = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                # The first patch is kept as is; the remaining patches are rearranged into a grid.
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]

                if height * width != base_image_feature.shape[0]:
                    raise ValueError("The number of patches is not consistent with the image size.")
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.config.vision_config.image_size)
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                # Append the `image_newline` embedding at the end of every row.
                image_feature = torch.cat((image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)), dim=-1)
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
            new_image_features.append(image_feature)
        image_features = torch.stack(new_image_features, dim=0)

        inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, None)
        inputs["inputs_embeds"] = inputs_embeds
        inputs["attention_mask"] = attention_mask
        inputs["position_ids"] = position_ids

        return inputs


    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.

        For a stateful model the cache is not reachable from Python, so the beam indices are
        stored in `next_beam_idx` and passed to the model as its `beam_idx` input on the next
        step. `past_key_values` is returned unchanged in that case.
        """
        if self.stateful:
            self.next_beam_idx = np.array(beam_idx)
            return past_key_values

        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
        return tuple(
            tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
            for layer_past in past_key_values
        )

    
    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """
        Insert image features into the text embeddings at the positions of the image tokens.

        Parameters:
          image_features: tensor unpacked as (num_images, num_image_patches, embed_dim)
          inputs_embeds: embeddings of `input_ids`
          input_ids: token ids, unpacked as (batch_size, sequence_length)
          attention_mask: attention mask matching `input_ids`
          labels: kept for signature compatibility only; this inference-only helper
              ignores it and returns no labels

        Returns:
          tuple (final_embedding, final_attention_mask, position_ids) for the merged sequence

        Raises:
          ValueError: if the number of image tokens does not match the provided image features
        """
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        return final_embedding, final_attention_mask, position_ids

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        When a cache is present, `input_ids` is trimmed to the tokens that were not processed
        yet. `position_ids` are derived from `attention_mask` unless passed via `kwargs`.

        Returns:
          dict with "input_ids" (or "inputs_embeds" on the first step when provided),
          "position_ids", "past_key_values", "use_cache", "attention_mask", "pixel_values"
          and "image_sizes"
        """
        if past_key_values is not None:
            if not self.stateful:
                cache_length = past_length = past_key_values[0][0].shape[2]
            else:
                # The cache of a stateful model is not visible here; use the counter kept by this object.
                cache_length = past_length = self.past_len

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            # NOTE(review): cache_length and past_length are always assigned the same value above,
            # so this condition can never be true as written.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
            }
        )
        return model_inputs

## Run OpenVINO model inference

### Select device

In [9]:
import ipywidgets as widgets

# OpenVINO runtime core; queried here for the devices it reports.
core = ov.Core()

# Dropdown listing every reported device plus "AUTO", preselecting "CPU".
device = widgets.Dropdown(description="Device:", options=core.available_devices + ["AUTO"], value="CPU", disabled=False)

# Leaving the widget as the cell's last expression makes the notebook render it.
device

Dropdown(description='Device:', options=('CPU', 'GPU.0', 'GPU.1', 'AUTO'), value='CPU')

In [17]:
# Build the OpenVINO-backed LLaVA wrapper from the converted models in MODEL_DIR,
# targeting the device currently selected in the dropdown.
ov_llava_model = OVLlavaForCausalLM(
    core,
    MODEL_DIR,
    device.value,
)

In [18]:
from PIL import Image
import requests


from transformers import TextStreamer

# Sample image from the LLaVA repository used for the test run below.
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
# Bound the wait and fail loudly on HTTP errors, so a stalled connection or an
# error page does not surface later as a hang or an obscure PIL decoding failure.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
# Instruction-style prompt; the "<image>" placeholder stands for the picture.
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
# Prints decoded text as it is generated, without echoing the prompt or special tokens.
streamer = TextStreamer(processor, skip_special_tokens=True, skip_prompt=True)

# Turn the prompt and image into the PyTorch tensors passed to `generate`.
inputs = processor(prompt, image, return_tensors="pt")

In [19]:
# Generate up to 20 new tokens; text is printed through `streamer` as it is produced,
# and the resulting token ids are kept in `output`.
output = ov_llava_model.generate(
    **inputs,
    streamer=streamer,
    max_new_tokens=20,
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


The image appears to be a radar chart, which


tensor([[    1,   733, 16289, 28793, 28705, 32000, 28705,    13,  3195,   349,
          4894,   297,   456,  3469, 28804,   733, 28748, 16289, 28793,   415,
          3469,  8045,   298,   347,   264, 24951, 10968, 28725,   690]])

## Interactive demo

In [None]:
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
import re
import time 
from PIL import Image
import torch

def bot_streaming(message, history):
    """Stream the model's answer for one multimodal chat turn.

    Args:
        message: Mapping with the user text under ``"text"`` and the files
            attached to this turn under ``"files"`` (each entry exposing a
            ``"path"``).
        history: Previous turns. A turn whose first element is a tuple is
            treated as an image upload, the tuple's first item being the path.

    Yields:
        The answer text accumulated so far, growing with every streamed chunk.

    Raises:
        gr.Error: If neither this turn nor the history provides an image.
    """
    print(message)

    # Start from "no image" so the check below works even when nothing matches.
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # No image uploaded for this turn: reuse the most recent one from the
        # past turns (images are kept inside tuples).
        for hist in reversed(history):
            if isinstance(hist[0], tuple):
                image = hist[0][0]
                break

    if image is None:
        # The error must be raised; merely constructing it has no effect.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    prompt = f"[INST] <image>\n{message['text']} [/INST]"
    image = Image.open(image).convert("RGB")
    inputs = processor(prompt, image, return_tensors="pt")

    # skip_prompt=True makes the streamer drop the echoed prompt itself, which is
    # more reliable than slicing the output by the length of a hand-built prompt.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)

    # Generation blocks, so it runs in a worker thread while this generator
    # consumes the streamer and forwards partial text.
    thread = Thread(target=ov_llava_model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


# Chat UI backed by `bot_streaming`. Example prompts are currently disabled:
#   examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
#             {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="LLaVA NeXT",
    description="Try [LLaVA NeXT](https://huggingface.co/docs/transformers/main/en/model_doc/llava_next) in this demo. Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
    stop_btn="Stop Generation",
    multimodal=True,
)

try:
    # First attempt: plain launch.
    demo.launch(debug=True)
except Exception:
    # If that fails, retry with `share=True`.
    demo.launch(debug=True, share=True)