# In [None]:
#crop analysis

import torch
import random
import numpy as np
import difflib
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torchvision.transforms import ToTensor, ToPILImage
from IPython.display import display
from datasets import load_dataset



# Single checkpoint identifier shared by the model, processor and tokenizer.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

# Weights are loaded in float16; device placement is left to device_map="auto".
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)

# Pixel-constraint parameters for the processor; adjust as needed.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    min_pixels=256 * 28 * 28,
    max_pixels=512 * 28 * 28,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)



def get_prediction(image, k, mask_end=483):
    """
    Runs model inference given an image, with part of the input masked out.

    The prompt and image are encoded with the module-level ``processor``;
    attention-mask positions ``k`` (inclusive) through ``mask_end``
    (exclusive) of the single batch row are then zeroed before calling
    ``model.generate``.

    Args:
        image: The image handed to the chat template and to the processor
            (presumably a PIL image -- confirm against the dataset).
        k: First input position to mask. Must be non-negative. Values at or
            beyond the masked range leave the attention mask untouched.
        mask_end: Exclusive upper bound of the masked range. Defaults to
            483, the bound the original experiment hard-coded. It is clamped
            to the actual sequence length so that shorter inputs no longer
            raise ``IndexError``.

    Returns:
        The decoded text of the newly generated tokens (prompt tokens
        stripped, special tokens skipped) for the single input.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    prompt = "List all the entities in the image"

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt}
        ]
    }]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The image is passed to the processor directly. Follow the model's
    # device instead of assuming "cuda", since the model is placed with
    # device_map="auto".
    inputs = processor(
        text=[text],
        images=image,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    print(inputs)
    print(inputs.input_ids.shape)

    # Zero the attention mask over [k, mask_end), never past the real length.
    seq_len = inputs.input_ids.shape[1]
    stop = min(mask_end, seq_len)
    inputs.attention_mask[0, k:stop] = 0

    print(inputs.attention_mask)

    generated_ids = model.generate(**inputs, max_new_tokens=128)

    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    return output_text[0]
   

# NOTE(review): `results` is never appended to in the visible code -- confirm
# whether a later cell fills it.
results = []
dataset = load_dataset("yyyyifan/VLQA", split="test")

# NOTE(review): the original cell carried a bare "29" comment here; its
# meaning is not evident from the code.
#29

# Number of mask start positions to sweep. It equals the upper bound (483)
# used by get_prediction's masking range, so every start position 0..482 is
# tried once.
K = 483

# The same sample (index 0) is used for every iteration, so fetch it once
# instead of re-reading it from the dataset on each pass.
image = dataset[0]["image"]

for idx in range(K):
    print("Sample Image:")
    display(image)

    # Mask start position for this run.
    k = idx
    baseline_pred = get_prediction(image, k)

    print("Baseline prediction:")
    print(baseline_pred)

    display_images = [image]
    display_titles = ["Original"]


