In [1]:
import requests
from PIL import Image

import torch
from transformers import (
  AutoProcessor,
  LlavaOnevisionForConditionalGeneration
)


In [2]:
# Load LLaVA-OneVision (Qwen2-7B, chat variant) onto GPU 0, in the dtype recorded in the
# checkpoint config (`torch_dtype="auto"`).
model_id = "llava-hf/llava-onevision-qwen2-7b-ov-chat-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
).to(0)

# Pin the slow image processor explicitly. The checkpoint ships a slow processor, and
# transformers warns that the default will switch to the fast one, which changes outputs
# slightly. Being explicit keeps the attention analyses below reproducible across versions.
processor = AutoProcessor.from_pretrained(
    model_id,
    use_fast=False,
)


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
# Define a chat history and use `apply_chat_template` to get a correctly formatted prompt.
# Each value in "content" has to be a list of dicts with types ("text", "image");
# the {"type": "image"} entry marks where the image placeholder goes in the prompt.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Alternative image source (remote COCO sample) — uncomment to use instead of the local file:
# image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
# raw_image = Image.open(requests.get(image_file, stream=True).raw)
raw_image = Image.open("examples/sample.png")

# Move every tensor to the model's device and cast the floating-point inputs (pixel values)
# to the model's own dtype. The model was loaded with `torch_dtype="auto"`, so hard-coding
# float16 here would break (or silently mismatch) if the checkpoint resolves to another dtype.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(
    model.device, model.dtype
)

# Greedy decoding; the decoded text includes the prompt as well as the answer.
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))


Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


user 
What are these?assistant
The image you've provided is a collage of various photographs. From left to right, top to bottom:

1. A golden retriever dog looking relaxed.
2. A beaver swimming in water among rocks.
3. A red panda resting on a rock.
4. A geyser erupting with steam and water shooting into the air.
5. A colorful bird, possibly a macaw, perched on a branch.
6. A waterfall cascading over rocks.
7. A hot air balloon with a pilot inside, floating in the sky.
8. A white arctic fox walking in snow.

Each photograph captures a different aspect of wildlife, nature, and human activity. The collage as a whole showcases the diversity of the natural world.


In [5]:
# Monkey patcher from Qwen2VL
from typing import List, Optional, Tuple, Union, Literal
import torch
from transformers.models.llava_onevision import LlavaOnevisionForConditionalGeneration

# Calculate dynamic threshold for attention map
def calculate_dynamic_threshold(visual_token_attn_score, percentile=95):
    hist = torch.histc(visual_token_attn_score, bins=100)
    cdf = torch.cumsum(hist, dim=0) / torch.sum(hist)
    threshold = torch.argmax((cdf > percentile / 100).float()).item() / 100
    return threshold


def get_mean_attn_score(output_ids) -> torch.Tensor:
    r"""
    Average the attention maps of a generation run over all layers and heads.

    Expected layout of ``output_ids.attentions`` (as indexed below):
    ``attentions[step][layer]`` with shape ``(batch, heads, query_len, key_len)``. Step 0 is the
    prefill (a square ``(pref_len, pref_len)`` map); each later step has a single query row whose
    key length grows by one. Those rows are stitched under the prefill block to form a square
    ``(full_len, full_len)`` map, where ``full_len`` is the key length of the last step. Entries
    not written by any step stay 0.

    Args:
        output_ids: generation output exposing ``.attentions`` (presumably from
            ``generate(..., output_attentions=True, return_dict_in_generate=True)`` — TODO confirm).
            Batch size must be 1.
    Returns:
        mean_attn: float32 CPU tensor of shape ``(full_len, full_len)``, the mean over layers and
            heads.
    """
    output_attn = output_ids.attentions
    prefill_attn = output_attn[0]
    assert prefill_attn[0].shape[0] == 1, "batch size should be 1"
    pref_len = prefill_attn[0].shape[3]
    full_len = output_attn[-1][0].shape[3]
    num_decode_steps = full_len - pref_len

    # Accumulate head-averaged maps layer by layer instead of stacking every
    # (layer, head, L, L) map: stacking needs num_layers * num_heads times the memory of the
    # result, which is prohibitive for prompts with thousands of image tokens.
    attn_sum = torch.zeros(full_len, full_len, dtype=torch.float)
    for layer_idx, layer_prefill in enumerate(prefill_attn):
        # Reduce over heads on the source device, then transfer only the (L, L) result.
        attn_sum[:pref_len, :pref_len] += layer_prefill[0].float().mean(dim=0).cpu()
        for step in range(num_decode_steps):
            row = pref_len + step
            # Decode step `step + 1` has one query row attending to keys 0..row.
            step_attn = output_attn[step + 1][layer_idx][0, :, 0, :]
            attn_sum[row, : row + 1] += step_attn.float().mean(dim=0).cpu()

    # Every layer has the same head count, so the mean of per-layer head means equals the
    # global mean over (layers, heads).
    return attn_sum / len(prefill_attn)


def get_visual_token_mean_attn_score(mean_attn, inputs, vision_start_token_id, vision_end_token_id) -> Tuple[torch.Tensor, ...]:
    r"""
    Get the attention weights of the visual tokens
    Args:
        mean_attn: the mean attention weights of the prefilling and full attention, shape: (L, L)
        inputs: the inputs of the model
    Returns:
        visual_token_attn_weights: the tuple of the attention weights of the visual tokens, each element shape: (V, V)
    """
    assert inputs["input_ids"].shape[0] == 1, "batch size should be 1"
    pref_len = len(inputs['input_ids'][0])
    vision_start_token_indices = torch.where(
        inputs["input_ids"][0] == vision_start_token_id
    )[0]
    vision_end_token_indices = torch.where(
        inputs["input_ids"][0] == vision_end_token_id
    )[0]
    # assert len(vision_start_token_indices) == len(vision_end_token_indices), "vision start and end token idx should be the same"
    # print(vision_start_token_indices)
    # print(vision_end_token_indices)
    # iterate over multiple images
    visual_token_attn_weights = tuple(
        torch.mean(mean_attn[pref_len:, s + 1 : e], dim=0)
        for s, e in zip(
            vision_start_token_indices, vision_end_token_indices, strict=True
        )
    )
    return visual_token_attn_weights


def get_visual_token_weight(
    visual_token_attn_score,
    threshold,
    keep_weight,
    weighting_type: Literal["linear", "exp", "uniform"] | str = "linear",
    lowest_weight=0.0,
):
    sorted_indices = torch.argsort(visual_token_attn_score, descending=True)
    num_tokens_to_keep = int(len(visual_token_attn_score) * threshold)
    weight_vision_token = torch.zeros_like(visual_token_attn_score, dtype=torch.float)
    weight_vision_token[sorted_indices[:num_tokens_to_keep]] = keep_weight
    if weighting_type == "linear":
        weight_vision_token[sorted_indices[num_tokens_to_keep:]] = torch.linspace(
            lowest_weight, 1.0, len(visual_token_attn_score) - num_tokens_to_keep
        )
    elif weighting_type == "exp":
        weight_vision_token[sorted_indices[num_tokens_to_keep:]] = torch.exp(
            torch.linspace(0, -3, len(sorted_indices) - num_tokens_to_keep)
        )
    elif weighting_type == "uniform":
        weight_vision_token[sorted_indices[num_tokens_to_keep:]] = lowest_weight
    else:
        raise ValueError(f"Invalid weighting type: {weighting_type}")
    return weight_vision_token


In [None]:
from transformers.models.llava_onevision.modeling_llava_onevision import (
    LlavaOnevisionCausalLMOutputWithPast,
)


def patch_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    image_sizes: Optional[torch.LongTensor] = None,
    pixel_values_videos: torch.FloatTensor = None,
    image_sizes_videos: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    vision_aspect_ratio: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **lm_kwargs,
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
    r"""
    Replacement for `LlavaOnevisionForConditionalGeneration.forward` that can re-weight image tokens.

    Mirrors the transformers forward, with one addition. If the model instance has an
    `embed_weight` attribute that is not `None`, the packed image features are multiplied
    row-wise by `embed_weight[:, None]` before being scattered into `inputs_embeds`. The
    attribute is read on every call that receives `pixel_values`. Leave it unset (or set it to
    `None`) to get the unmodified behaviour. `embed_weight` is presumably a 1-D tensor with one
    weight per image token (standard broadcasting applies); it may live on any device/dtype.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:
        [`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.

    Raises:
        ValueError: if neither or both of `input_ids`/`inputs_embeds` are given, if pixel inputs are
            combined with `inputs_embeds`, or if the number of image/video placeholder tokens does
            not match the number of extracted features.

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> import torch
    >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration

    >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
    >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")

    >>> conversation = [
    ...     {
    ...       "role": "user",
    ...       "content": [
    ...           {"type": "text", "text": "What is shown in this image?"},
    ...           {"type": "image"},
    ...         ],
    ...     },
    ... ]
    >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
    >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)

    >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    >>> processor.batch_decode(output, skip_special_tokens=True)[0]
    "user\n\nWhat is shown in this image?\nassistant\ncat"
    ```"""
    # Imported locally: the notebook never imports this helper at module level, and the
    # copied upstream code relied on it being in scope.
    from transformers.utils import is_torchdynamo_compiling

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )
    vision_aspect_ratio = (
        vision_aspect_ratio
        if vision_aspect_ratio is not None
        else self.config.vision_aspect_ratio
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if (
        pixel_values is not None or pixel_values_videos is not None
    ) and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
            "and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    # Images are processed with Anyres
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values,
            image_sizes,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
        image_features, _feature_lens = self.pack_image_features(
            image_features,
            image_sizes,
            image_newline=self.image_newline,
            vision_aspect_ratio=vision_aspect_ratio,
        )

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            n_image_tokens = (input_ids == self.config.image_token_index).sum()
            n_image_features = image_features.shape[0]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

        # Optional per-token re-weighting (the point of this patch). `getattr` so that a model
        # without the attribute behaves like upstream instead of raising AttributeError.
        embed_weight = getattr(self, "embed_weight", None)
        if embed_weight is not None:
            # The weights may live elsewhere (e.g. a CPU float32 tensor) than the features
            # (e.g. CUDA fp16). Move them to the features' device, multiply out-of-place
            # (in-place `*=` with a cross-device operand fails), and cast back to the
            # embedding dtype.
            row_weights = torch.as_tensor(embed_weight, device=image_features.device)
            image_features = (image_features * row_weights[:, None]).to(image_features.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

    # Video are simply embedded and further pooled to decrease seq len
    if pixel_values_videos is not None:
        video_features = self.get_video_features(
            pixel_values_videos,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
        image_newline = (
            self.image_newline[None, None, :]
            .repeat(video_features.shape[0], 1, 1)
            .to(video_features.device)
        )
        video_features = torch.cat((video_features, image_newline), dim=1)
        video_features = video_features.flatten(0, 1)

        special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
        special_video_mask = special_video_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        # Validate against the *video* mask (the image mask may not even exist for
        # video-only inputs).
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_video_mask].numel() != video_features.numel()
        ):
            n_video_tokens = (input_ids == self.config.video_token_index).sum()
            n_video_features = video_features.shape[0]
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **lm_kwargs,
    )

    logits = outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                logits.device
            )
            shift_logits = logits[..., :-1, :][
                shift_attention_mask.to(logits.device) != 0
            ].contiguous()
            shift_labels = labels[..., 1:][
                shift_attention_mask.to(labels.device) != 0
            ].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens (fully qualified: `nn` is not imported in this notebook)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1).to(shift_logits.device),
        )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaOnevisionCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
        video_hidden_states=video_features if pixel_values_videos is not None else None,
    )
