In [None]:
!pip install qwen-vl-utils

In [None]:
import torch
import os
import pandas as pd
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Model checkpoint used to caption the images.
model_name = "Qwen/Qwen2-VL-2B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Place the whole model on `device` at load time. Chaining `.to(device)` onto a
# model loaded with device_map="auto" is redundant, and it raises once
# accelerate has dispatched/offloaded the model.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name, torch_dtype="auto", device_map=device
)

processor = AutoProcessor.from_pretrained(model_name)

In [None]:
# Root folder that the CSV's `filepath` values are joined onto.
image_folder = "/content/drive/MyDrive"
# Original SkyScript metadata; kept for reference, not used by the cells below.
csv_file_path = "/content/drive/MyDrive/SkyScript_train_unfiltered_5M.csv"
# Working annotations file: loaded here and rewritten by the captioning cell,
# so re-running the notebook resumes from the captions already saved.
output_csv_file = "/content/drive/MyDrive/updated_annotations.csv"

# Read captions as text. If the column is still entirely empty, pandas would
# otherwise infer float64 (all NaN), and later string assignments hit pandas'
# incompatible-dtype FutureWarning. Empty cells still load as NaN.
df_filtered = pd.read_csv(output_csv_file, dtype={"generated_captions": "object"})

In [None]:
# Caption every row of `df_filtered` that has no caption yet, store the result
# in its `generated_captions` column, and checkpoint the CSV so an interrupted
# run can resume where it stopped.
prompt_text = "Generate a caption for the satellite image."

save_every = 100      # write a checkpoint after this many NEW captions
max_new_tokens = 128  # generation length cap per caption


def _has_caption(value) -> bool:
    """Return True if `value` is a usable caption.

    Missing values (NaN/None), empty or whitespace-only strings, and the
    literal string "nan" (any case) all count as "no caption".
    """
    return pd.notna(value) and str(value).strip().lower() not in ("", "nan")


def _caption_image(img_path: str) -> str:
    """Generate one caption for the image at `img_path` using the global `model` and `processor`."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text_input = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # process_vision_info loads the image from the path inside `messages`,
    # so the file is not opened separately here.
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text_input],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # follow wherever the model was actually placed

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


# Hold captions as Python objects. An all-NaN column is float64 after read_csv,
# and assigning text into it triggers pandas' incompatible-dtype warning.
df_filtered["generated_captions"] = df_filtered["generated_captions"].astype(object)

# Work out up front which rows still need a caption, instead of iterating and
# logging every already-captioned row.
pending = [
    (idx, img_file)
    for idx, img_file, existing in zip(
        df_filtered.index, df_filtered["filepath"], df_filtered["generated_captions"]
    )
    if not _has_caption(existing)
]
print(f"ℹ️ {len(df_filtered) - len(pending)} rows already have a caption; {len(pending)} to process.")

new_since_save = 0
missing_images = 0
try:
    for idx, img_file in pending:
        img_path = os.path.join(image_folder, img_file) if isinstance(img_file, str) else None
        if img_path is None or not os.path.exists(img_path):
            missing_images += 1
            continue

        caption = _caption_image(img_path)
        df_filtered.at[idx, "generated_captions"] = caption
        print(f"✅ Processed {idx}: {img_file} → {caption}")

        # Count new captions rather than testing `idx % save_every`. The index
        # test silently skips a checkpoint whenever that particular row was
        # already captioned or its image is missing.
        new_since_save += 1
        if new_since_save >= save_every:
            df_filtered.to_csv(output_csv_file, index=False)
            new_since_save = 0
            print(f"💾 {output_csv_file} file was saved (Step {idx}).")
finally:
    # Save even if generation raises or the cell is interrupted, so finished
    # captions are not lost and the next run resumes from them.
    df_filtered.to_csv(output_csv_file, index=False)

if missing_images:
    print(f"⚠️ {missing_images} rows were skipped because their image file was not found.")
print("✅ All captions were generated and updated csv file was saved.")