<a href="https://colab.research.google.com/github/hassineElghazel/video-llava-prompt-study/blob/main/Extension/gemini_1_5pro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## General Setup

Imports and persistent storage in Google Drive.

In [None]:
!pip install av

Collecting av
  Downloading av-14.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.6 kB)
Downloading av-14.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.3/35.3 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: av
Successfully installed av-14.4.0


In [None]:
# --- standard library ---
import os
import gc
import json
from pathlib import Path
from io import BytesIO

# --- third-party ---
import numpy as np
import pandas as pd
from tqdm import tqdm
import av
import torch
from PIL import Image
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted, BadRequest, InternalServerError

# --- Colab-specific ---
from google.colab import drive, userdata


In [None]:
# Mount Google Drive so clips, ground truth and checkpoints persist across sessions.
drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [None]:
# Directory of pre-cut per-question clips named "{idx}_{clip_uid}_{annotation_uid}.mp4"
# (the naming scheme prepare_gemini_inputs looks up).
CUT_VIDEO_OUT = os.path.join("/content/drive/MyDrive/model_2.0/model", "cut_videos_6tol")

In [None]:
# Ground-truth sheet: one row per question (see the preview below), with the clip
# identifiers, predicted/exact time windows and the reference answer.
annotations_path = '/content/drive/MyDrive/model_2.0/model/GroundTruth.xlsx'
annotations_df = pd.read_excel(annotations_path)

# Fail fast if the sheet lacks a column that the generation loop reads per row.
required_columns = {
    'idx', 'annotation_uid', 'clip_uid', 'sentence',
    'predicted_times', 'exact_times', 'video_uid', 'ground truth',
}
missing_columns = required_columns - set(annotations_df.columns)
assert not missing_columns, f"GroundTruth.xlsx is missing columns: {sorted(missing_columns)}"

annotations_df.head()

Unnamed: 0,idx,annotation_uid,clip_uid,sentence,predicted_times,exact_times,video_uid,ground truth
0,0,93891b44-c00f-43d9-885b-92fdce39128c,75d3fc52-3776-47d4-b7fd-8074d30b06d1,Where did I put the chopsticks.,"(0.0, 3.75)","(3.319, 4.619)",413fe086-1745-4573-b75b-e7d26ff72df9,beside the stove
1,1,00bb3571-8b35-40ba-8163-a1a0fd20b886,60e7e14d-cbed-46d1-924d-6ce451ea7d7c,What cable did I remove?,"(78.75, 86.25)","(84.32902, 84.71152)",03e90bbc-7d6b-423c-84d9-b5be3eff11c5,usb drive
2,2,d8fb0c3d-39b1-4bb2-b800-fc56f12ab120,e1c79556-e8af-4e26-bc4c-633100277239,Did I close the refrigerator?,"(0.0, 3.75)","(0.65536, 2.00097)",4ce119de-0f42-4bd1-b387-9e19643fdddc,yes
3,3,e8b1352a-6e0c-4f86-8a52-afe93743abb6,efc190a8-45de-4ce5-b480-b722403bcec1,Where was the scissors before I picked it up?,"(93.75, 97.5)","(87.0, 89.306)",ff6d3d52-dda5-46dd-8515-b9b772933030,on the plate rack
4,4,2bc5aa4c-3114-497e-9d13-e229277fe11b,cc6270fd-3c0d-4dda-bcb4-52cefc0224d7,In what location did I open the door?,"(3.75, 7.5)","(6.2829, 8.428)",432cb803-6be5-47bc-8443-6bb5b9051667,in the bathroom


Configure the Gemini API key and define the model.

In [None]:
# Pull the Gemini key from Colab's secret store (never hard-code it), configure
# the SDK with it, and create the model handle used by the generation helpers.
api_key = userdata.get('GEMINI_API_KEY')
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name="gemini-1.5-pro")

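Define the helper functions: prompt construction, sampling clip frames into a single grid image, and the Gemini call for one annotation row.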

In [None]:
import re

def build_general_prompt(sentence: str) -> str:
    """Assemble the Gemini prompt for a single question.

    The prompt is a fixed SYSTEM / CONTEXT / USER preamble followed by exactly
    one TASK INSTRUCTION. The instruction is picked by testing regexes against
    the lower-cased question in priority order; the first matching rule wins
    and questions matching no rule get a generic brief-answer instruction.

    Args:
        sentence: the natural-language question; leading/trailing whitespace
            is stripped before it is embedded in the prompt.

    Returns:
        The complete prompt text, ending with an "Answer:" cue line.
    """
    question = sentence.strip()
    lowered = question.lower()

    def mentions(pattern: str) -> bool:
        return re.search(pattern, lowered) is not None

    preamble = (
        "### SYSTEM:\n"
        "You are a helpful assistant that answers questions based on visual summaries of video segments.\n"
        "Use only the information visible in the video.\n"
        "Be accurate and respond in **5 words or fewer**.\n\n"
        "### CONTEXT:\n"
        "<video>\n"
        "(Visual summary of the video is shown here)\n\n"
        "### USER:\n"
        f"Question: {question}\n"
    )

    action_verbs = r'\b(remove|pick|grab|cut|wash|tie|wipe|press|use|carry|put|insert|drop|place)\b'
    object_nouns = r'\b(tool|machine|device|object|item|equipment)\b'

    # (predicate, instruction) pairs, highest priority first.
    rules = (
        # counting
        (lambda: mentions(r'\bhow many\b'),
         "This is a counting task. Count only the visible items.\n"
         "Return the number.\n"),
        # color
        (lambda: mentions(r'\bwhat (color|colour)\b'),
         "Identify the visible color of the mentioned object.\n"),
        # location (debiased wording: no suggested candidate locations)
        (lambda: mentions(r'\bwhere (is|was|did|do)\b|\bin what location\b'),
         "Identify the object's visible spatial location based only on the video summary.\n"
         "Use only location descriptions directly inferable from the visual context.\n"),
        # object acted upon
        (lambda: mentions(r'\bwhat\b') and mentions(action_verbs),
         "Identify the object involved in the action based on what is clearly visible in the video.\n"),
        # tool / machine
        (lambda: mentions(r'\bwhat\b') and mentions(object_nouns),
         "Determine what object, tool, or machine was used. Base your answer only on visible cues.\n"),
        # human interaction
        (lambda: mentions(r'\bwho did (i|he|she|we|they) (talk|interact)\b'),
         "Name the person involved in the interaction.\n"),
        # yes/no, only when neither "what" nor "who" appears
        (lambda: mentions(r'\bdid\b.*\b(i|we|they|he|she)\b')
                 and not (mentions(r'\bwhat\b') or mentions(r'\bwho\b')),
         "This is a yes/no question. Respond with 'yes' or 'no' based on visible evidence only.\n"),
    )
    fallback = "Answer briefly and accurately using only the information visible in the video summary.\n"

    instruction = next((text for applies, text in rules if applies()), fallback)

    return preamble + "\n### TASK INSTRUCTION:\n" + instruction + "\nAnswer:\n"

In [None]:
def make_grid_image(frames, grid_size=(2, 4)):
    """Tile frames row-major into a single RGB mosaic image.

    Args:
        frames: sequence of PIL images. The first frame's size is used as the
            cell size for every tile (assumes all frames share one resolution,
            as frames decoded from one clip normally do).
        grid_size: ``(rows, cols)`` of the mosaic. The default 2x4 holds the 8
            frames sampled per clip.

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``. Cells left
        without a frame (fewer than ``rows * cols`` frames) stay black.

    Raises:
        ValueError: if ``frames`` is empty, or holds more frames than the grid
            has cells (the surplus would otherwise be pasted outside the
            canvas and silently dropped).
    """
    rows, cols = grid_size
    capacity = rows * cols
    # Validate against the requested grid rather than a fixed count of 8, so
    # non-default grid sizes work; ValueError (not assert) survives `python -O`.
    if len(frames) == 0 or len(frames) > capacity:
        raise ValueError(
            f"Expected 1 to {capacity} frames for a {rows}x{cols} grid, got {len(frames)}"
        )

    cell_width, cell_height = frames[0].size
    grid_image = Image.new('RGB', (cols * cell_width, rows * cell_height))

    for position, frame in enumerate(frames):
        row, col = divmod(position, cols)
        grid_image.paste(frame, (col * cell_width, row * cell_height))

    return grid_image


def read_video_frames(container, indices, pad_to=8):
    """Decode the frames at the requested positions from the first video stream.

    Frames are returned in decode order. A position listed k times in
    ``indices`` contributes k copies of that frame. This keeps the even
    temporal spacing of ``np.linspace(..., dtype=int)`` on clips shorter than
    the sample count, where linspace yields duplicate indices, instead of
    collapsing them and padding only at the end.

    Args:
        container: an opened PyAV input container (anything providing
            ``seek(0)`` and ``decode(video=0)`` whose frames have ``to_ndarray``).
        indices: 0-based frame positions to keep; order does not matter.
        pad_to: minimum number of frames returned. If fewer were decoded (e.g.
            the stream reported more frames than it really has), the last
            decoded frame is repeated. Defaults to 8, the grid size.

    Returns:
        list of ``frame.to_ndarray(format='rgb24')`` results.

    Raises:
        ValueError: if ``indices`` is empty or none of the requested frames
            could be decoded.
    """
    from collections import Counter

    wanted = Counter(int(i) for i in indices)
    if not wanted:
        raise ValueError("No frame indices requested.")
    last_wanted = max(wanted)

    frames = []
    container.seek(0)
    for position, frame in enumerate(container.decode(video=0)):
        if position > last_wanted:
            break  # everything requested has been seen; stop decoding early
        copies = wanted[position]
        if copies:
            array = frame.to_ndarray(format='rgb24')
            frames.extend([array] * copies)

    if not frames:
        raise ValueError("None of the requested frames could be decoded.")

    while len(frames) < pad_to:
        frames.append(frames[-1])

    return frames


def frame_to_pil(frame):
    """Wrap a decoded frame array (``to_ndarray(format='rgb24')`` output) as a PIL image.

    No mode is passed, so PIL infers it from the array's shape and dtype.
    """
    pil_image = Image.fromarray(frame)
    return pil_image


def prepare_gemini_inputs(row, video_segments_dir):
    """Build the ``(images, prompt)`` pair sent to Gemini for one annotation row.

    The clip ``{idx}_{clip_uid}_{annotation_uid}.mp4`` in ``video_segments_dir``
    is sampled at 8 evenly spaced frames, which are tiled into one 2x4 grid
    image.

    Args:
        row: mapping / pandas Series with 'clip_uid', 'annotation_uid',
            'sentence' and 'idx'.
        video_segments_dir: directory holding the cut clips.

    Returns:
        ``([grid_image], prompt)`` on success, or ``(None, None)`` when the
        clip is missing, has no video stream or no decodable frames, or
        decoding fails. The reason is printed in each failure case.
    """
    clip_uid = row['clip_uid']
    annotation_uid = row['annotation_uid']
    sentence = row['sentence']
    idx = int(row['idx'])
    video_segment_path = os.path.join(video_segments_dir, f"{idx}_{clip_uid}_{annotation_uid}.mp4")

    if not os.path.exists(video_segment_path):
        print(f"Video file {video_segment_path} not found.")
        return None, None

    try:
        # The context manager closes the file handle on every path; otherwise
        # one container per row stays open until garbage collection.
        with av.open(video_segment_path) as container:
            if not container.streams.video:
                print(f"No video stream in {video_segment_path}.")
                return None, None

            total_frames = container.streams.video[0].frames
            if total_frames <= 0:
                # PyAV reports 0 when the header carries no frame count; count by
                # decoding rather than sampling from index -1.
                total_frames = sum(1 for _ in container.decode(video=0))
            if total_frames <= 0:
                print(f"No decodable frames in {video_segment_path}.")
                return None, None

            indices = np.linspace(0, total_frames - 1, num=8, dtype=int)
            frames = read_video_frames(container, indices)  # seeks back to 0 itself

        pil_frames = [frame_to_pil(f) for f in frames]
        grid_image = make_grid_image(pil_frames)  # one 2x4 mosaic from 8 frames

        prompt = build_general_prompt(sentence)
        return [grid_image], prompt  # a single image, wrapped in a list

    except Exception as e:
        print(f"Error processing video {video_segment_path}: {e}")
        return None, None


def process_row_with_gemini(row, max_retries=3, retry_delay_s=15.0):
    """Ask Gemini the row's question about its clip and return a result record.

    Args:
        row: one annotation row (e.g. a pandas Series) with the columns 'idx',
            'annotation_uid', 'clip_uid', 'sentence', 'predicted_times',
            'exact_times', 'video_uid' and 'ground truth'.
        max_retries: extra attempts after ``ResourceExhausted`` (quota / rate
            limit) or ``InternalServerError``; 0 disables retrying.
        retry_delay_s: seconds to wait before the first retry; doubled on each
            subsequent retry.

    Returns:
        A dict with the row's identifying and ground-truth columns plus the
        model's ``'answer'``, or ``None`` when the clip could not be turned into
        images, the model returned no content, or the request failed. The
        reason is printed in every ``None`` case.
    """
    import time

    images, prompt = prepare_gemini_inputs(row, CUT_VIDEO_OUT)
    if images is None or not all(isinstance(img, Image.Image) for img in images):
        print(f"⚠️ Skipping idx {row['idx']} due to invalid image(s)")
        return None

    last_attempt = max(0, max_retries)
    try:
        for attempt in range(last_attempt + 1):
            try:
                response = model.generate_content(
                    [prompt] + images,
                    generation_config={"max_output_tokens": 3000},
                )
                break
            except (ResourceExhausted, InternalServerError) as e:
                # Quota and transient server errors are worth retrying; without
                # this the row is simply missing from the results.
                if attempt == last_attempt:
                    raise
                wait_s = retry_delay_s * (2 ** attempt)
                print(f"⏳ {type(e).__name__} for idx {row['idx']}; "
                      f"retry {attempt + 1}/{last_attempt} in {wait_s:.0f}s")
                time.sleep(wait_s)

        # Blocked or empty responses have no parts to read text from.
        if not response.parts:
            print(f"⚠️ No content returned for idx {row['idx']}. Prompt:\n{prompt}\n")
            return None

        answer = response.text.strip()
        print(answer)

        return {
            'idx': row['idx'],
            'annotation_uid': row['annotation_uid'],
            'clip_uid': row['clip_uid'],
            'sentence': row['sentence'],
            'predicted_times': row['predicted_times'],
            'exact_times': row['exact_times'],
            'video_uid': row['video_uid'],
            'ground truth': row['ground truth'],
            # Keep only what follows an echoed "Answer:" cue (whole text if absent).
            'answer': answer.split("Answer:")[-1].strip(),
        }

    except Exception as e:
        print(f"❌ Error generating response for idx {row['idx']}: {e}")
        return None

Feed the prompts and sampled video frames into the model, saving the answers to a CSV checkpoint on Google Drive.

In [None]:
CHECKPOINT_PATH = "/content/drive/MyDrive/model_2.0/model/gemini_video_qa_ans_1.5_pro_3000.csv"
CHECKPOINT_EVERY = 20  # rewrite the CSV after this many *new* answers


def _row_key(record):
    """Identity of one question row, mirroring the clip file name
    ``{idx}_{clip_uid}_{annotation_uid}.mp4``.

    NOTE(review): clip_uid + annotation_uid alone may repeat across questions
    (one annotation can hold several queries), so idx is included.
    """
    return f"{int(record['idx'])}::{record['clip_uid']}::{record['annotation_uid']}"


# Resume from an existing checkpoint: rows answered in a previous run are
# skipped instead of being re-queried and appended a second time.
answers = []
processed_keys = set()
if os.path.exists(CHECKPOINT_PATH):
    try:
        checkpoint_df = pd.read_csv(CHECKPOINT_PATH)
    except pd.errors.EmptyDataError:  # e.g. a column-less file saved with no answers
        checkpoint_df = pd.DataFrame()
    answers = checkpoint_df.to_dict(orient="records")
    processed_keys = {_row_key(record) for record in answers}
    print(f"🔁 Resuming from checkpoint: {len(answers)} examples already processed.")

# Loop with checkpointing
new_since_checkpoint = 0
for _, row in tqdm(annotations_df.iterrows(), total=len(annotations_df), desc="Gemini QA"):
    key = _row_key(row)
    if key in processed_keys:
        continue  # already answered

    result = process_row_with_gemini(row)
    if result:
        answers.append(result)
        processed_keys.add(key)
        new_since_checkpoint += 1

    # Count only newly added answers so failed rows and an empty list never
    # trigger a redundant (or empty) save.
    if new_since_checkpoint >= CHECKPOINT_EVERY:
        pd.DataFrame(answers).to_csv(CHECKPOINT_PATH, index=False)
        print(f"💾 Checkpoint saved with {len(answers)} entries.")
        new_since_checkpoint = 0

    gc.collect()

# Final save
answers_df = pd.DataFrame(answers)
answers_df.to_csv(CHECKPOINT_PATH, index=False)
print("✅ Final save completed. Total entries:", len(answers))

Being used in the kitchen.
Black cable from monitor.
Yes.
On the kitchen counter.
Bathroom doorway.
Hanging on the door.
Next to the workbench.
White adhesive foam strip.
Zero cans on the window.
On the stovetop.
On the workbench.
On the wooden board.
On the table.
Kitchen sink.
On the gray couch.
No video stream in /content/drive/MyDrive/model_2.0/model/cut_videos_6tol/17_679cfee6-7da1-4701-b75a-9e34abb9400a_e3015a5a-3e3e-47f5-a6b9-b77d3648621e.mp4.
⚠️ Skipping idx 17 due to invalid image(s)
Folded in hand.
Inside the toolbox drawer.
Kitchen cabinet.
On the rug.
Bottom right, wooden cabinet.
💾 Checkpoint saved with 20 entries.
No video stream in /content/drive/MyDrive/model_2.0/model/cut_videos_6tol/26_3341648c-88b4-433a-87ac-1fcc9619a4dc_8e2fed00-ea98-4cc4-b8e7-faecffe28d90.mp4.
⚠️ Skipping idx 26 due to invalid image(s)
💾 Checkpoint saved with 20 entries.
Nobody.  No interaction shown.
One rope.
A person in a black hoodie.
Outside the window.
On the kitchen counter.
Shop vacuum.
On 