In [None]:
from tqdm import tqdm
import json
from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
import torch
from PIL import Image
from glob import glob
import os
import cv2
import pandas as pd
import gc

In [None]:
# Frame root: one sub-directory per video, each holding that video's extracted
# .jpg frames. Despite the name, `video_ids` holds full directory paths; sorting
# gives a deterministic order so index-based batching later is reproducible.
root_dir = "/kaggle/input/msvd-and-msrvtt/frames"
video_ids = glob(os.path.join(root_dir, "*"))
video_ids.sort()

In [None]:
def load_video_frames(video_folder):
    """Load every ``*.jpg`` frame in ``video_folder`` as an RGB PIL image.

    Frames are returned in sorted filename order.

    Args:
        video_folder: Directory containing the extracted frames of one video.

    Returns:
        list[PIL.Image.Image]: one image per frame; empty if the folder has no
        ``.jpg`` files.

    Raises:
        ValueError: if a frame file exists but OpenCV cannot decode it.
    """
    frames = []
    for path in sorted(glob(os.path.join(video_folder, "*.jpg"))):
        img_bgr = cv2.imread(path)
        if img_bgr is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; fail here with a message that names the file.
            raise ValueError(f"Could not read frame: {path}")
        # OpenCV decodes to BGR channel order while PIL interprets arrays as RGB;
        # without this conversion red and blue are swapped in every frame.
        frames.append(Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)))
    return frames

In [None]:
# Load the 0.5B LLaVA-OneVision checkpoint in fp16 and let transformers place
# the weights across available devices (device_map="auto").
MODEL_ID = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="float16",
    device_map='auto',
)
processor = LlavaOnevisionProcessor.from_pretrained(MODEL_ID, use_fast=False)

# Batched generation with a decoder-only LM needs left padding so that every
# prompt ends exactly where the generated tokens begin.
processor.tokenizer.padding_side = "left"

In [None]:
# Captioning instruction sent with each frame; its token budget matches
# max_new_tokens in generate_kwargs below.
prompt = """
Describe the image in 3 short sentences, under 64 tokens, focusing on scene, objects, and main action.
"""

# Single-turn chat: the text instruction followed by one image placeholder.
conversation_image = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]

# Render the chat template once; the resulting string is reused for every frame.
prompt_image = processor.apply_chat_template(conversation_image, add_generation_prompt=True)

# Deterministic beam-search decoding settings passed to model.generate().
generate_kwargs = dict(
    max_new_tokens=64,
    do_sample=False,
    num_beams=5,
    early_stopping=True,
)

In [None]:
# Caption the median frame of each video whose sorted position lies in
# [START_INDEX, END_INDEX), `batch_size` videos per generate() call.
# Results accumulate in `all_captions` as (video_id, caption) tuples.
START_INDEX = 400   # first video position captioned in this run
END_INDEX = 1000    # exclusive upper bound; lets a long job be split across sessions
batch_size = 4
all_captions = []


def _load_median_frame(video_folder):
    """Load only the middle frame (by sorted filename) of ``video_folder`` as RGB.

    Returns None when the folder has no ``.jpg`` frames or the frame cannot be
    decoded, so the caller can skip that video instead of crashing.
    """
    frame_paths = sorted(glob(os.path.join(video_folder, "*.jpg")))
    if not frame_paths:
        return None
    img_bgr = cv2.imread(frame_paths[len(frame_paths) // 2])
    if img_bgr is None:
        return None
    # OpenCV decodes to BGR; PIL and the processor expect RGB.
    return Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))


stop = min(END_INDEX, len(video_ids))
for batch_start in tqdm(range(START_INDEX, stop, batch_size)):
    # Slicing clamps at the end, and capping at `stop` keeps the last batch
    # from spilling past END_INDEX when batch_size does not divide the range.
    batch_video_ids = video_ids[batch_start:min(batch_start + batch_size, stop)]

    kept_ids, median_frames = [], []
    for video_dir in batch_video_ids:
        frame = _load_median_frame(video_dir)
        if frame is None:
            tqdm.write(f"Skipping {video_dir}: no readable frames")
            continue
        kept_ids.append(video_dir)
        median_frames.append(frame)
    if not median_frames:
        continue

    inputs = processor(
        images=median_frames,
        text=[prompt_image] * len(median_frames),
        padding=True,
        return_tensors="pt",
    ).to(model.device, torch.float16)

    outputs = model.generate(
        **inputs,
        pad_token_id=processor.tokenizer.eos_token_id,
        **generate_kwargs,
    )
    # generate() returns prompt + continuation. With left padding every prompt
    # in the batch has the same length, so slicing it off leaves only the
    # generated caption (no fragile split on the "assistant" role marker).
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    generated_texts = processor.batch_decode(new_tokens, skip_special_tokens=True)

    for video_dir, caption in zip(kept_ids, generated_texts):
        all_captions.append((os.path.basename(video_dir), caption.strip()))

    del median_frames, inputs, outputs, new_tokens, generated_texts
    # Collect first so cycle-held tensors are freed before the CUDA cache is released.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Merge this run's captions into the persistent CSV so that work split across
# several sessions (different index ranges) accumulates in one file.
file_path = "/kaggle/working/image_descriptions.csv"
df_new = pd.DataFrame(all_captions, columns=['id', 'description'])
if os.path.exists(file_path):
    # Read ids as strings so they compare equal to the freshly generated ids
    # (numeric-looking ids would otherwise be parsed as ints).
    df_existing = pd.read_csv(file_path, dtype={'id': str})
    df_new = pd.concat([df_existing, df_new], ignore_index=True)
# Re-running this cell or re-captioning a range would otherwise append
# duplicate rows; keep only the most recent caption per video.
df_new = df_new.drop_duplicates(subset='id', keep='last').reset_index(drop=True)
df_new.to_csv(file_path, index=False)