## Gemini demo

This notebook assumes you:

- have run the script in `process_tvqa` -- this notebook needs the reconstructed clips and the candidate lists (the `multiple_choice` files)
- have access to a Vertex AI Workbench instance (where this notebook was tested)
- have enabled the Gemini API
- store your files in a Cloud Storage bucket in Google Cloud Platform

In [None]:
import os
import re
import json
import datetime
from google import genai
from pydantic import BaseModel
from tqdm.notebook import tqdm
from google.cloud import storage
from IPython.display import HTML, Markdown, display
from google.genai.types import GenerateContentConfig, Part

In [None]:
def list_files_in_folder(bucket_name, folder_path):
    """Return the names of all objects stored under a prefix in a GCS bucket.

    Args:
        bucket_name: Name of the Cloud Storage bucket.
        folder_path: Prefix used to filter the listing.

    Returns:
        A list of blob names; names ending in '/' ("folder" objects) are left out.
    """
    gcs = storage.Client()
    listing = gcs.bucket(bucket_name).list_blobs(prefix=folder_path)
    return [entry.name for entry in listing if not entry.name.endswith('/')]

def read_json_from_gcs(bucket_name, blob_name):
    """Download a blob from Cloud Storage and parse it as JSON.

    Args:
        bucket_name: Name of the Cloud Storage bucket.
        blob_name: Name of the blob holding the JSON document.

    Returns:
        The parsed JSON value, or None when the download or parsing fails
        (the error is printed rather than raised).
    """
    try:
        target = storage.Client().bucket(bucket_name).blob(blob_name)
        # json.loads accepts the raw bytes directly, so no explicit decode is needed.
        return json.loads(target.download_as_bytes())
    except Exception as err:
        print(f"Error reading file: {err}")
    return None


def write_json_to_gcs(bucket_name, blob_name, data):
    """Upload already-serialized JSON to a Cloud Storage blob.

    Args:
        bucket_name: Name of the Cloud Storage bucket.
        blob_name: Name of the destination blob.
        data: Payload passed as-is to ``upload_from_string``.

    Failures are printed, not raised.
    """
    try:
        destination = storage.Client().bucket(bucket_name).blob(blob_name)
        destination.upload_from_string(data, content_type='application/json')
    except Exception as err:
        print(f"Error writing JSON to GCS: {err}")

def blob_exists(bucket_name, blob_name):
    """Report whether a blob is present in a Cloud Storage bucket.

    Args:
        bucket_name: Name of the Cloud Storage bucket.
        blob_name: Name of the blob to look for.

    Returns:
        The result of the blob's ``exists()`` call, or False when the check
        itself raises (the error is printed).
    """
    try:
        return storage.Client().bucket(bucket_name).blob(blob_name).exists()
    except Exception as err:
        print(f"Error checking blob existence: {err}")
    return False

def extract_clip_id(blob_name):
    """Pull the clip ID out of a blob name.

    The ID has the form ``<lowercase letters>_s<digits>e<digits>_seg<digits>_clip_<digits>``.

    Args:
        blob_name: Blob name (path) to search.

    Returns:
        The first matching clip ID, or None when nothing matches or the
        search raises (the error is printed).
    """
    try:
        found = re.search(r"[a-z]+_s\d+e\d+_seg\d+_clip_\d+", blob_name)
    except Exception as e:
        print(f"Error extracting clip ID: {e}")
        return None

    return found.group(0) if found else None

In [None]:
# Google Cloud project and location for the client below; replace the placeholders.
PROJECT_ID = "YOUR_PROJECT_ID"
LOCATION = "YOUR_LOCATION"

# Gen AI SDK client configured with vertexai=True for the project/location above.
client = genai.Client(
    vertexai=True,
    project=PROJECT_ID,
    location=LOCATION,
)

---

In [None]:
# Model used for every request; it is also embedded in the prediction file names.
MODEL_ID = "gemini-2.0-flash"

# Cloud Storage location of the input data; replace the placeholders.
bucket_name = "YOUR_BUCKET_NAME"
folder_path = "YOUR_DATA_FOLDER/"

# Names of every object stored under the data folder.
files = list_files_in_folder(bucket_name=bucket_name, folder_path=folder_path)

In [None]:
# Index of the input files: file_dict[show_id][clip_id] maps a kind of input
# (video, audio, subtitles, ...) to its blob name, or to a list of blob names
# for the kinds that consist of several files.
file_dict = {}

# (substring identifying the file kind, key in the per-clip dict, True when
# several files share that key). Checked in this order; the first match wins.
_FILE_KINDS = (
    ('.annotated_video.mp4', 'video_path', False),
    ('.audio.mp3', 'audio_path', False),
    ('.multiple_choice.json', 'multiple_choice_path', False),
    ('.multiple_choice/', 'multiple_choice_image_paths', True),
    ('.frames/', 'frame_paths', True),
    ('.subtitles.json', 'subtitle_path', False),
)

for file in files:
    # Skip editor/OS artifacts before trying to interpret the path.
    if '.DS_Store' in file or '.ipynb_checkpoints' in file:
        continue

    # Blob names are read as <data folder>/<show_id>/<clip_id>.<rest>.
    path_parts = file.split('/')
    if len(path_parts) < 3:
        # Report paths that do not fit the layout and move on, instead of
        # filing them under the IDs left over from a previous iteration.
        print(file)
        continue
    show_id = path_parts[1]
    clip_id = path_parts[2].split('.')[0]

    for marker, key, holds_many in _FILE_KINDS:
        if marker not in file:
            continue
        # The clip entry is created only once a file is recognized, so that
        # unknown files never leave an empty clip behind.
        clip_entry = file_dict.setdefault(show_id, {}).setdefault(clip_id, {})
        if holds_many:
            clip_entry.setdefault(key, []).append(file)
        else:
            clip_entry[key] = file
        break
    else:
        # Unrecognized file kind: report it.
        print(file)
        

In [None]:
# Roles predicted for a single subtitle line. ClipRoles, which wraps a list of
# these, is passed to the model as the response schema, so the field names
# must match the keys described in the system instruction.
class ConversationalRoles(BaseModel):
    # Index of the subtitle line being analyzed.
    line_index: int
    # Index of the line this one replies to; per the system instruction this is
    # the line's own index or the index of a previous line.
    reply_to: int
    # Name of the speaker, or "unknown".
    speaker: str
    # Names of the characters directly addressed (the system instruction also
    # allows placeholders such as "crowd" or "none").
    addressees: list[str]
    # Names of the side-participants (the system instruction also allows
    # "none" or "unknown").
    side_participants: list[str]

# Top-level response model for one clip; passed as `response_schema` when
# calling the model.
class ClipRoles(BaseModel):
    # One ConversationalRoles entry per analyzed subtitle line.
    clip_roles: list[ConversationalRoles]

In [None]:
# System prompt sent with every request. The data model shown at the end of the
# prompt must stay in sync with the ConversationalRoles / ClipRoles classes that
# are used as the response schema (it previously omitted `line_index` and left
# its code fence unclosed).
system_instruction = """You are a video analysis assistant.  Your task is to analyze the conversations in a video clip and its associated subtitles. For each dialogue line, you will:

*   determine what previous line it is replying to
*   determine the speaker, addressees, and side-participants

Here's how to determine the reply-to relationship between utterances to resolve conversational threads:

*   The reply-to structure gives us information about floor-claiming and topical change within the clip.
*   The character is saying this line because they want to respond to that previous line. What previous line is this current line replying to?
*   If the speaker of the last line is the same, you can treat it as continuation and put the index of last line as the reply-to.
*   When there is a noticeable change in topic and distribution of other participants' attention, and no previous line **triggers** this current line: write the current line index, indicating the current line replies to itself.

Here's how to determine each role:

*   **Speaker:**  The character who is speaking the line.  Infer this from lip movements, body language, and the context of the dialogue. If a character finishes one line and immediately starts another (very short pause), assume it's the same speaker, UNLESS there's a clear visual indication of a scene or speaker change (e.g., a camera cut to a different person starting to speak).
*   **Addressee(s):** The character(s) the speaker is *directly* addressing. Use these cues:
    *   **Eye Contact:** The most important cue. Who is the speaker looking at?
    *   **Body Orientation:** Is the speaker's body turned towards a particular person or group?
    *   **Dialogue Context:** Does the line contain a name, pronoun ("you"), or clearly refer to a specific individual or group?  ("Hey, John..." or "You all need to...")
    *   **Reactions:** If a character reacts immediately and strongly to a line (e.g., nods, responds verbally, shows surprise), they are likely an addressee.
    *   If the speaker seems to be talking to everyone present, list all characters who appear to be paying attention.
    *   If the speaker is talking to a crowd of unidentifiable characters, write "crowd".
    *   If the speaker is talking to themselves, or no one in particular, write "none".
*   **Side-Participant(s):**  Any character(s) visible in the scene *during the line's timeframe* who are *not* the speaker or addressees. They are present, and their presence is known to other participants. They can potentially join the conversation at any time.
    *   If it is not possible to confidently determine if someone is a side-participant, write "unknown".
    *   If there are no side-participants, write "none".


**Input:**

You will receive a list of subtitle entries.  Each entry will be a dictionary with the following keys:
*   `"line_index"`: (int) The index of the current entry (subtitle line).
*   `"start_time"`: (float) The start time of the subtitle line in seconds.
*   `"end_time"`: (float) The end time of the subtitle line in seconds.
*   `"text"`: (string) The text of the dialogue line.

You will also receive a list of potential participants for you to assign roles from. You must pick from this list.

With all this information, analyze the video segment corresponding to the `start_time` and `end_time` of each subtitle entry.

**Output:**

Provide your output in JSON format, mirroring the structure of the input.  For *each* subtitle entry, add the following keys:
*   `"line_index"`: (int) The line being analyzed.
*   `"reply_to"`: (int) The line index that this current line replies to, could be the same as the current line index or any previous line index.
*   `"speaker"`: (string) The name of the speaker. If you cannot determine the speaker, use "unknown".
*   `"addressees"`: (list of strings) A list of the names of the addressee(s).  This can be an empty list (`[]`) if there are no direct addressees, or `["none"]` if the speaker is speaking generally but to no one in particular.
*   `"side_participants"`: (list of strings) A list of the names of the side-participant(s). This can be an empty list (`[]`), `["none"]`, or `["unknown"]`.

This corresponds to the data model format:

```python
class ConversationalRoles(BaseModel):
    line_index: int  # Index of the line being analyzed
    reply_to: int  # Index of the line being replied to
    speaker: str   # Speaker of the line, or "unknown"
    addressees: list[str]  # List of names, ["crowd"], ["none"], or ["unknown"]
    side_participants: list[str] # List of names, ["crowd"], ["none"], or ["unknown"]

class ClipRoles(BaseModel):
    clip_roles: list[ConversationalRoles]
```
"""

In [None]:
# Folder (blob-name prefix) under which predictions are stored in the bucket.
prefix = 'PRED_gemini'

# Fixed opening of every user prompt; the subtitle lines are appended to it.
# The whitespace-only second line is kept so the prompt text stays unchanged.
_PROMPT_HEADER = (
    "Analyze this video. Pay attention to the bounding boxes and character captions.\n"
    "        \n"
    "Subtitles:\n"
)
# Inputs a clip must have before a request can be built.
_REQUIRED_PATH_KEYS = ('subtitle_path', 'multiple_choice_path', 'video_path')

for show_id in tqdm(file_dict):
    for clip_id, path_dict in tqdm(file_dict[show_id].items(), total=len(file_dict[show_id])):
        pred_blob_name = f"{prefix}/{show_id}/{clip_id}.{MODEL_ID}.json"

        # Resume support: clips that already have a stored prediction are skipped.
        if blob_exists(bucket_name, pred_blob_name):
            continue

        # An incomplete clip is reported and skipped rather than aborting the
        # whole run with a KeyError.
        missing_keys = [key for key in _REQUIRED_PATH_KEYS if key not in path_dict]
        if missing_keys:
            print(f"Skipping {show_id}/{clip_id}: missing {missing_keys}")
            continue

        subtitle_data = read_json_from_gcs(bucket_name, path_dict['subtitle_path'])
        multiple_choice_data = read_json_from_gcs(bucket_name, path_dict['multiple_choice_path'])
        # read_json_from_gcs prints the error and returns None on failure.
        if subtitle_data is None or multiple_choice_data is None:
            print(f"Skipping {show_id}/{clip_id}: could not read subtitles or candidate list")
            continue

        # Each subtitle entry is indexed positionally as start time, end time, text.
        subtitle_list = [
            {"line_index": line_idx, "start_time": item[0], "end_time": item[1], "text": item[2]}
            for line_idx, item in enumerate(subtitle_data)
        ]
        prompt = _PROMPT_HEADER + "".join(
            json.dumps(subtitle) + '\n' for subtitle in subtitle_list
        )
        prompt += "\n Select only from the following participants:"

        # Candidate images are optional; without them only the names are sent.
        mc_images = path_dict.get('multiple_choice_image_paths', [])

        # Each participant name is followed by its reference image, if any.
        candidate_parts = []
        for participant in multiple_choice_data:
            candidate_parts.append(participant)

            # Look for a file whose name contains _{participant with underscores}.jpg.
            image_suffix = f"_{participant.replace(' ', '_')}.jpg"
            matching_image = next(
                (img_filename for img_filename in mc_images if image_suffix in img_filename),
                None,
            )
            # Only add an image part if a candidate image was found.
            if matching_image:
                candidate_parts.append(
                    Part.from_uri(
                        file_uri=f"gs://{bucket_name}/{matching_image}",
                        mime_type="image/jpeg",
                    )
                )

        try:
            video_part = Part.from_uri(
                file_uri=f"gs://{bucket_name}/" + path_dict['video_path'],
                mime_type="video/mp4",
            )

            response = client.models.generate_content(
                model=MODEL_ID,
                contents=[
                    prompt,
                    candidate_parts,
                    video_part,
                ],
                config=GenerateContentConfig(
                    system_instruction=system_instruction,
                    response_mime_type="application/json",
                    response_schema=ClipRoles,
                ),
            )
            pred = response.text
            write_json_to_gcs(bucket_name, pred_blob_name, pred)
        except Exception as e:
            # Best effort: report the failure and go on with the next clip. No
            # prediction is written, so the clip is retried on the next run.
            print(e)
            continue