## VARCO-VISION 14B

* From NCSOFT
* Ranked #35 on the OpenCompass multimodal academic leaderboard (https://rank.opencompass.org.cn/leaderboard-multimodal)

In [None]:
import torch
import requests
from PIL import Image
from transformers import LlavaOnevisionForConditionalGeneration, AutoProcessor

# Add the desired directory to the Python path
import sys
import os
sys.path.append(os.path.abspath('/data/students/earl/llava-dissector/VARCO-VISION-14B-HF'))

import urllib
from io import BytesIO
from PIL import Image

from typing import Optional, Union, List, Tuple

# Sub-class LlavaOnevisionForConditionalGeneration here
class CustomLlavaOnevisionForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
    """LLaVA-OneVision model whose logits are recomputed from a floored hidden state.

    The parent forward pass runs unchanged; afterwards the last entry of the
    returned ``hidden_states`` tuple is clamped from below at ``threshold`` and
    pushed through ``self.language_model.lm_head`` again. The resulting logits
    replace the ones the parent produced.

    Attributes:
        threshold: Lower bound applied element-wise to the last hidden state.
            ``None`` disables the modification. May be reassigned after
            construction (e.g. ``model.threshold = -10``).
    """

    def __init__(self, config, threshold=-1.5):
        """Build the parent model from ``config`` and store ``threshold``."""
        super().__init__(config)
        self.threshold = threshold

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        pixel_values_videos: torch.FloatTensor = None,
        image_sizes_videos: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        vision_aspect_ratio: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ):
        """Run the parent forward pass, then recompute logits from modified hidden states.

        All arguments are forwarded to the parent ``forward`` unchanged, with two
        exceptions: ``output_hidden_states`` is always forced to ``True`` (the
        hidden states are needed to recompute the logits, so the caller's value
        is ignored), and the parent is always asked for a dict-style output so
        its fields can be accessed by name.

        Returns:
            The parent's output with ``logits`` replaced. If the caller resolved
            ``return_dict`` to ``False``, the output is converted to a tuple.

        Note:
            When ``labels`` is given, ``loss`` is whatever the parent computed,
            i.e. it is based on the *unmodified* logits.
        """
        # Remember what the caller asked for; the parent is always called with
        # return_dict=True so that `.hidden_states` / `.logits` exist below.
        wants_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            pixel_values_videos=pixel_values_videos,
            image_sizes_videos=image_sizes_videos,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            vision_aspect_ratio=vision_aspect_ratio,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        # Last entry of the hidden-state tuple.
        # NOTE(review): presumably this is the state that feeds lm_head in the
        # wrapped language model (i.e. after its final norm) — confirm against
        # the installed transformers version.
        hidden_states = outputs.hidden_states[-1]

        # Only project the positions the caller asked for. With the default of
        # 0 the slice `[-0:]` keeps every position, so behaviour is unchanged.
        hidden_states = hidden_states[:, -num_logits_to_keep:, :]
        modified_hidden_states = self.modify_hidden_states(hidden_states)

        # The language model exposes its output projection as `lm_head`.
        outputs.logits = self.language_model.lm_head(modified_hidden_states)

        return outputs if wants_dict else outputs.to_tuple()

    def modify_hidden_states(self, hidden_states):
        """Clamp ``hidden_states`` from below at ``self.threshold``.

        Mathematically the same as ``relu(h - t) + t``, but ``torch.clamp``
        leaves values above the threshold bit-exact, whereas the subtract/add
        round trip rounds them in half precision.

        Args:
            hidden_states: Tensor to modify.

        Returns:
            The clamped tensor, or ``hidden_states`` itself when the threshold
            is ``None``.
        """
        if self.threshold is None:
            return hidden_states
        return torch.clamp(hidden_states, min=self.threshold)

    
# Local checkpoint directory; both the model weights and the processor are
# loaded from here.
model_name = "/data/students/earl/llava-dissector/VARCO-VISION-14B-HF"
model = CustomLlavaOnevisionForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_name)
# Send inputs to the device the loaded model reports instead of a hard-coded
# GPU index, which breaks on hosts with fewer GPUs.
device = model.device


## Extract the object and visualize

In [None]:
# Ask the model for the bounding boxes of one object class in a remote image
# and print the decoded answer.

# `import urllib` alone does not guarantee that the `request` submodule is
# loaded, so import it explicitly before use.
import urllib.request

cls = "baseball player"
# Lower bound applied to the last hidden state by the custom model class.
model.threshold = -10

url = "https://farm3.staticflickr.com/2402/2480652763_e6b62303ee_z.jpg"
# Use a context manager so the HTTP response is closed once the bytes are read.
with urllib.request.urlopen(url=url, timeout=5) as response:
    img_bytes = response.read()
img = Image.open(BytesIO(img_bytes)).convert("RGB")

text = f'Give the normalized bounding box coordinates in the format [x1, y1, x2, y2] of all instances of {cls} in the image.'
# Alternative wording that was also tried:
#text = f'Bounding box coordinates of all instances of {cls} in the image. Do not include coordinates for any other objects or text in the output. Do not output bounding boxes for all other objects.'

# Define a chat history and use `apply_chat_template` to get a correctly
# formatted prompt. Each value in "content" has to be a list of dicts with
# types ("text", "image").
conversation = [
    {
        "role": "user",
        "content": [
            # The prompt must be the string itself; `{text}` would build a set.
            {"type": "text", "text": text},
            {"type": "image"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

EOS_TOKEN = "<|im_end|>"
inputs = processor(images=img, text=prompt, return_tensors='pt').to(device, torch.float16)

output = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
# Drop the prompt tokens so only the newly generated continuation is decoded.
output = processor.decode(output[0][inputs.input_ids.shape[1]:])
if output.endswith(EOS_TOKEN):
    output = output[: -len(EOS_TOKEN)]

output = output.strip()
print(output)

## Visualize using cv2

# Extract the bounding box coordinates from the model output and draw them.
# Example output: "Bounding box coordinates: [[x1, y1, x2, y2], [x1, y1, x2, y2]]"
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np

# One decimal number: "1", "1.", "0.25" or ".25" (but not "." or "1.2.3",
# which float() would reject).
number = r'(\d+\.?\d*|\.\d+)'
# Accepted box syntaxes:
#   [x1, y1, x2, y2] or (x1, y1, x2, y2)  -- comma separated
#   <box>x1 y1 x2 y2</box>                -- whitespace separated
bracketed = r'[\[\(]\s*' + r'\s*,\s*'.join([number] * 4) + r'\s*[\]\)]'
tagged = r'<box>\s*' + r'\s+'.join([number] * 4) + r'\s*</box>'
pattern = f'(?:{bracketed})|(?:{tagged})'

matches = re.findall(pattern, output)
# The pattern has eight groups (four per alternative). For every match the
# groups of the alternative that did not fire are empty strings, so they must
# be skipped before converting to float.
bounding_boxes = [[float(coord) for coord in match if coord] for match in matches]
print("Bounding boxes:", bounding_boxes)

# Visualize using cv2
img_cv = np.array(img)
# Image size does not change while drawing, so read it once.
h, w = img_cv.shape[:2]
for x1, y1, x2, y2 in bounding_boxes:
    # Convert the normalized coordinates to pixel values.
    left, top = int(x1 * w), int(y1 * h)
    right, bottom = int(x2 * w), int(y2 * h)
    # Draw the bounding box and label on the image.
    cv2.rectangle(img_cv, (left, top), (right, bottom), (0, 255, 0), 2)
    # Keep the label inside the image when the box touches the top edge.
    label_y = max(top - 10, 10)
    cv2.putText(img_cv, cls, (left, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

plt.figure(figsize=(10, 10))
plt.imshow(img_cv)
plt.axis('off')
plt.show()
