In [None]:
# Install FlashAttention; the model below requests the "flash_attention_2"
# backend only when running on CUDA (eager attention is used on CPU).
# NOTE(review): flash-attn's install instructions recommend
# `--no-build-isolation`; confirm this builds/installs in the target runtime.
!pip install flash_attn

In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForImageTextToText
from transformers.image_utils import load_image
from transformers.modeling_outputs import BaseModelOutput

from torch.profiler import profile, ProfilerActivity, record_function
from torch.utils.benchmark import Timer

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Two demo images, fetched from their URLs with transformers' load_image helper.
image1, image2 = (
    load_image(url)
    for url in (
        "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg",
    )
)


# Initialize processor and model.
# Other checkpoints from the same family that can be swapped in:
# "HuggingFaceTB/SmolVLM-500M-Instruct", "HuggingFaceTB/SmolVLM-Instruct".
model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    # Use the public `attn_implementation` kwarg (the underscore-prefixed form
    # is private). FlashAttention 2 is only requested on CUDA; CPU falls back
    # to the eager attention implementation.
    attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
).to(DEVICE)


# Fraction of visual tokens per image to keep after the connector.
# 1 disables pruning (no hook is attached in that case).
retain = 0.5
if not 0 < retain <= 1:
    # A value above 1 would enlarge image_seq_len with no hook to match it,
    # and a value <= 0 (or NaN) leaves no visual tokens at all.
    raise ValueError(f"retain must be in (0, 1], got {retain!r}")

# Number of visual tokens kept per image. image_seq_len is lowered to K so the
# processor's notion of tokens-per-image agrees with the pruned connector
# output (presumably it controls how many image placeholder tokens are
# inserted into the prompt — confirm against the processor implementation).
K = int(processor.image_seq_len * retain)
if K < 1:
    raise ValueError(
        f"retain={retain!r} keeps no tokens out of {processor.image_seq_len}"
    )
processor.image_seq_len = K

def prune_visual_tokens_hook(module, inputs, outputs, keep=None):
    """Forward hook that keeps a random subset of tokens along dim 1.

    Intended for ``register_forward_hook``: the tensor returned here replaces
    the hooked module's output. Each row along dim 0 gets its own independent
    random subset of positions along dim 1, and the kept positions are
    re-sorted so surviving tokens stay in their original order.

    Args:
        module: The hooked module (unused).
        inputs: The module's positional inputs (unused).
        outputs: Rank-3 tensor; assumed to be (images, tokens, hidden) —
            TODO confirm against the connector's output layout.
        keep: Number of positions to keep per row. Defaults to the
            module-level ``K`` so the hook signature PyTorch calls is unchanged.

    Returns:
        Tensor of shape ``(outputs.shape[0], keep, outputs.shape[-1])``.

    Raises:
        ValueError: If ``keep`` is not in ``[1, outputs.shape[1]]`` — a larger
            value would silently keep every token and a zero value would drop
            them all, either way mismatching what the processor expects.
    """
    if keep is None:
        keep = K
    n_rows, n_tokens = outputs.shape[0], outputs.shape[1]
    if not 1 <= keep <= n_tokens:
        raise ValueError(f"keep must be in [1, {n_tokens}], got {keep}")

    # The `keep` smallest of i.i.d. uniform scores are a uniformly random
    # subset (the same set argsort(scores)[:, :keep] picks, without a full sort).
    scores = torch.rand(n_rows, n_tokens, device=outputs.device)
    idx = scores.topk(keep, dim=-1, largest=False).indices
    # topk does not return indices in positional order; restore it.
    idx = idx.sort(dim=-1).values

    # Broadcast the per-row indices across the last dim and gather those rows.
    index = idx.unsqueeze(-1).expand(-1, -1, outputs.shape[-1])
    return torch.gather(input=outputs, dim=1, index=index)

# Despite the name, this is the connector module (model.model.connector), not
# the vision encoder itself; its output is what the pruning hook replaces
# (presumably the projected visual tokens fed to the language model).
vision_encoder = model.model.connector

# retain == 1 means "keep every token", so a hook is only needed below that.
if retain < 1:
    print("Attaching Hook")
    # The returned handle allows detaching the hook later via handle.remove().
    handle = vision_encoder.register_forward_hook(prune_visual_tokens_hook)

# Create input messages
# A single user turn: two image placeholders followed by the text question.
messages = [
    {
        "role": "user",
        "content": [{"type": "image"} for _ in range(2)]
        + [{"type": "text", "text": "Can you describe the two images?"}],
    },
]

# Render the chat template with the assistant generation prompt appended,
# process it together with both images into PyTorch tensors, and move the
# resulting batch to DEVICE.
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    text=prompt,
    images=[image1, image2],
    return_tensors="pt",
).to(DEVICE)

# with profile(
#     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True, with_stack=False
# ) as prof:
#     # Generate outputs
#     generated_ids = model.generate(**inputs, max_new_tokens=500)
#     generated_texts = processor.batch_decode(
#         generated_ids,
#         skip_special_tokens=True,
#     )
# print(
#     prof.key_averages().table(
#         sort_by="self_cuda_memory_usage",
#         row_limit=20
#     )
# )

# generated_ids = model.generate(**inputs, max_new_tokens=500)
# generated_texts = processor.batch_decode(
#     generated_ids,
#     skip_special_tokens=True
# )
# print(generated_texts[0])

def generate_outputs(max_new_tokens=500):
    """Run one generate-and-decode pass over the prepared ``inputs``.

    This is the statement timed by the torch.utils.benchmark Timer in this
    notebook, so it includes both generation and detokenization.

    Args:
        max_new_tokens: Generation budget passed to ``model.generate``
            (defaults to 500, matching the original hard-coded value).

    Returns:
        list[str]: The decoded texts from ``processor.batch_decode``, with
        special tokens skipped.
    """
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Return the decoded texts instead of discarding them, so the same helper
    # can be used to inspect the output of a pruned run.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

# generated_ids = model.generate(**inputs, max_new_tokens=500)
# torch.cuda.reset_peak_memory_stats()
# generated_texts = processor.batch_decode(
#       generated_ids,
#       skip_special_tokens=True
#   )
# print("Peak allocated:", torch.cuda.max_memory_allocated() / 1024**3, "GB\n")

# Benchmark end-to-end generation. Hand the callable to the Timer through
# `globals` instead of `from __main__ import ...`, which only works when this
# code is executing as the __main__ module.
timer = Timer(
    stmt="generate_outputs()",
    globals={"generate_outputs": generate_outputs},
    num_threads=1,
)

# 100 timed runs of a generation capped at max_new_tokens each: slow by design.
print(timer.timeit(100))


