# Gemini Pro 1.5 Inference

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Constants that parameterise this inference cell."""

    # --- Input annotations and result location ---
    DATA_FILE = '/x20/users/cp/cplizzari/selected_wip_v7.csv'
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_frames/'
    # Formatted with the sampling rate to obtain the results file name.
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints (disabled):
    #   'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    #   'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:v2_s_dense_shared'
    MAX_LENGTH = 8192
    TEMPERATURE = 0
    TOP_P = 0.95

    # --- Question formatting ---
    # One of 'OpenQA', 'CloseQA' or 'Mixed'.
    QA_TYPE = 'OpenQA'
    # Only consulted when QA_TYPE is 'Mixed'.
    CLOSE_QA_WEIGHT = 50

    # --- Execution ---
    SAMPLING_RATES = [1]
    BATCH_SIZE = 1
    SHUFFLE_DATA = False
    # Number of threads in the worker pool.
    MAX_WORKERS = 8


class DataLoader:
    """Minimal batching iterator over an indexable dataset.

    Args:
        dataset: Object supporting ``len()`` and integer ``__getitem__``.
        batch_size: Number of items per batch; the last batch may be shorter.
        shuffle: If true, the index order is shuffled once at construction.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            # A non-positive batch size never advances the cursor, so
            # iteration would not terminate.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Set the cursor here so next() also works before iter() is called.
        self.start = 0

    def __iter__(self):
        """Rewind to the first batch and return the loader itself."""
        self.start = 0
        return self

    def __next__(self):
        """Return the next batch as a list of dataset items."""
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch


class BaseDataset:
    """Holds the annotation table read from a CSV file.

    Args:
        data_file: Path of the CSV file passed to ``pandas.read_csv``.
    """

    def __init__(self, data_file):
        self.annotations = self.load_annotations(data_file)

    def load_annotations(self, data_file):
        """Read the CSV and return it as a DataFrame without empty records."""
        annotations = pd.read_csv(data_file)
        # Rows whose every field is empty would otherwise become items whose
        # video id and question are both NaN.
        return annotations.dropna(how='all').reset_index(drop=True)

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)


class QADataset(BaseDataset):
    """Turns annotation rows into prompt/answer dictionaries.

    Args:
        data_file: Annotation file forwarded to ``BaseDataset``.
        qa_type: 'OpenQA', 'CloseQA' or 'Mixed' (a type is drawn per item).
        CloseQA_weight: Weight of 'CloseQA' when ``qa_type`` is 'Mixed';
            'OpenQA' receives ``100 - CloseQA_weight``.
    """

    # Instruction block placed in front of the question for OpenQA prompts.
    # A previous, frame-grounded prompt (disabled) read:
    #   "These are frames from a video that I want to upload. "
    #   "Use the visual cues to answer the question: {question}. "
    #   "You need to answer the question in any case and not demand additional context information. "
    #   "Note: All actions mentioned refer to the person recording the video."
    _OPEN_QA_PREAMBLE = (
        "You are allowed to hallucinate!\n"
        "We are doing a scientific experiment to check biases in text questions.\n"
        "The user will provide you a question in text form and you need to try your best to answer.\n"
        "Not answering is not an option.\n"
        "If you make mistakes is not an issue just try your best!\n"
        "Keep your answers very short and to the point.\n"
        "Remember you don't need to be correct, in case you feel like you cannot answer just do a random guess or use common sense knowledge.\n"
        "Don't excuse yourself or don't clarify that it's I guess, I know already!\n"
        "\n"
    )

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Video ids whose 'Answer_closed' field could not be parsed.
        self.wrong_files = []

    def __getitem__(self, index):
        """Return the prompt, reference answer and metadata for row ``index``.

        Returns:
            dict with keys 'video_id', 'question_answer' (the prompt text),
            'question', 'answer', 'task' (the QA type used) and 'category'.

        Raises:
            NotImplementedError: If the QA type is neither 'OpenQA' nor
                'CloseQA' after resolving 'Mixed'.
        """
        row = self.annotations.iloc[index]
        video_id = row['Video UID']
        question = row['Question']
        category = row['Category']
        answer = str(row.get('Answer_open', ''))

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = f"{self._OPEN_QA_PREAMBLE}{question}"
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the item.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            # Keep the letter with the reference answer: the model is asked
            # to reply with a letter, and the shuffle is not stored elsewhere.
            answer_str = f'({self.choice_indices[answer_index]}) {answer}'
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Build a ``ModelClient`` for ``Config.MODEL_URL`` with this cell's decoding settings."""
    formatting = model_client.FormattingOptions(enable_formatting=True)
    # nucleus_top_p (Config.TOP_P) is not passed here; that argument is disabled.
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting,
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries,
                    max_retries=5, retry_delay_s=10):
    """Send one QA item (its frames plus the prompt text) to the model.

    Args:
        batch: Item dict with keys 'video_id', 'question', 'question_answer',
            'category' and 'answer'.
        sampling_rate: Value used to build the frames directory path.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            are already answered and must be skipped.
        max_retries: Maximum number of generation attempts (at least one is
            always made).
        retry_delay_s: Seconds to wait between failed attempts.

    Returns:
        A result dict with keys "V", "Q", "QA", "A", "C", "M", or ``None``
        when the item is skipped, has no frames, yields an empty answer, or
        every attempt fails.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    frames_dir = f'/x20/users/cp/cplizzari/uniform_sampling_temporal_v6/_{sampling_rate}/{uid}_{question}_{sampling_rate}_frames'

    # Check existence before listing so a missing directory is reported
    # instead of being handed to ListDir.
    if not gfile.Exists(frames_dir):
        print(f"Frames directory does not exist: {frames_dir}")
        return None
    image_paths = sorted(gfile.ListDir(frames_dir))
    if not image_paths:
        print(f"Frames directory is empty: {frames_dir}")
        return None

    # One image chunk per frame, in sorted file-name order.
    prompt = []
    for image_path in image_paths:
        with gfile.Open(os.path.join(frames_dir, image_path), 'rb') as frame_file:
            frame_bytes = frame_file.read()
        prompt.append(
            model_client.ContentChunk(
                value=frame_bytes,
                mimetype='image/jpg',
                metadata=model_client.Metadata(role=roles.ROLE_USER)
            )
        )

    prompt.append(
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    )

    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            # Joining builds the answer in one step, so a stream that fails
            # midway leaves no partial text behind for the next attempt.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
        except Exception as e:
            if attempt == attempts:
                print(f"Error processing {uid}: {e}")
                return None
            print(f"Error generating stream for {uid}: {e}. Retrying in {retry_delay_s} seconds...")
            time.sleep(retry_delay_s)
        else:
            break

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """
    Perform bulk inference using multithreading.

    Results already stored in ``output_file_path`` with a non-empty "A" field
    are kept, and their ("V", "Q") pairs are handed to the workers so those
    items are skipped. New results are appended; the file is rewritten each
    time the total reaches a multiple of 50 and once more at the end.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file holding the list of result dicts.
        sampling_rate: Forwarded to ``process_qa_item``.
        client: Forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
            else:
                # Drop entries with a missing or empty answer so they are retried.
                model_response = [entry for entry in existing_data if entry.get("A", "") != ""]
                existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}

    # Flatten the batches so every item is submitted, whatever the batch size.
    items = [item for batch in data_loader for item in batch]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(process_qa_item, item, sampling_rate, client, existing_entries): idx
            for idx, item in enumerate(items)
        }

        for future in tqdm(as_completed(futures), total=len(items), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as e:
                # One failing item must not abort the run and discard the
                # results gathered since the last save.
                print(f"Item {futures[future]} failed: {e}")
                continue
            if result:
                model_response.append(result)
                existing_entries.add((result["V"], result["Q"]))

                # Periodically save to prevent data loss
                if len(model_response) % 50 == 0:
                    with gfile.GFile(output_file_path, 'w') as fi:
                        json.dump(model_response, fi)
                        print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Run bulk inference once for every entry of ``Config.SAMPLING_RATES``."""
    model = initialize_client()

    # Create the results directory if it is not there yet.
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for rate in Config.SAMPLING_RATES:
        # One results file per sampling rate.
        results_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=rate)
        results_path = os.path.join(Config.OUTPUT_DIR, results_name)

        # A fresh dataset and loader are built for each sampling rate.
        loader = DataLoader(
            QADataset(Config.DATA_FILE, Config.QA_TYPE, Config.CLOSE_QA_WEIGHT),
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA,
        )

        print(f"Running bulk inference for sampling rate: {rate}")
        perform_bulk_inference(
            data_loader=loader,
            output_file_path=results_path,
            sampling_rate=rate,
            client=model,
            max_workers=Config.MAX_WORKERS,
        )


# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Running bulk inference for sampling rate: 1.5


Processing QA Items:  22%|██▏       | 119/550 [01:33<03:18,  2.17it/s]

Frames directory does not exist: /x20/users/cp/cplizzari/uniform_sampling_temporal_sampling_rate_v6/_1.5/nan_nan_1.5_frames
Saved 350 entries to /x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling/results_1.5.json

Processing QA Items:  63%|██████▎   | 346/550 [03:14<02:20,  1.45it/s]


Saved 400 entries to /x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling/results_1.5.json

Processing QA Items: 100%|██████████| 550/550 [05:14<00:00,  1.75it/s]


Frames directory does not exist: /x20/users/cp/cplizzari/uniform_sampling_temporal_sampling_rate_v6/_1.5/055f3cf1-1133-4260-b5a0-31e7ca1726a1_396.3912530165856_441.4351463667155.mp4_In which order does the person perform the following actions: pour water in the pan, rinsing the knife, stirring meat, eating a piece of bread. _1.5_frames
Frames directory does not exist: /x20/users/cp/cplizzari/uniform_sampling_temporal_sampling_rate_v6/_1.5/10341975-6612-4137-b0c2-703847ad4dba_364.778727881237_465.1785665187631_How many pans does the person interact with?_1.5_frames
Frames directory does not exist: /x20/users/cp/cplizzari/uniform_sampling_temporal_sampling_rate_v6/_1.5/10341975-6612-4137-b0c2-703847ad4dba_364.778727881237_465.1785665187631_How many times does the person open the tap?_1.5_frames
Final results saved to /x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling/results_1.5.json





# EgoTaskQA inference

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Constants that parameterise this inference cell."""

    # --- Input annotations and result location ---
    DATA_FILE = '/x20/users/cp/cplizzari/selected_wip_v7.csv'
    OUTPUT_DIR = '/cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/'
    # Formatted with the sampling rate to obtain the results file name.
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints (disabled):
    #   'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    #   'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen:///mbns/iz/home/courier/alessiot/chiara_cvpr:/lmroot:blade:aip-serving-alessiot-gemini_flash_s_2m'
    MAX_LENGTH = 8192
    TEMPERATURE = 0
    TOP_P = 0.95

    # --- Question formatting ---
    # One of 'OpenQA', 'CloseQA' or 'Mixed'.
    QA_TYPE = 'OpenQA'
    # Only consulted when QA_TYPE is 'Mixed'.
    CLOSE_QA_WEIGHT = 50

    # --- Execution ---
    SAMPLING_RATES = [1]
    BATCH_SIZE = 1
    SHUFFLE_DATA = False
    # Number of threads in the worker pool.
    MAX_WORKERS = 64


class DataLoader:
    """Minimal batching iterator over an indexable dataset.

    Args:
        dataset: Object supporting ``len()`` and integer ``__getitem__``.
        batch_size: Number of items per batch; the last batch may be shorter.
        shuffle: If true, the index order is shuffled once at construction.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            # A non-positive batch size never advances the cursor, so
            # iteration would not terminate.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Set the cursor here so next() also works before iter() is called.
        self.start = 0

    def __iter__(self):
        """Rewind to the first batch and return the loader itself."""
        self.start = 0
        return self

    def __next__(self):
        """Return the next batch as a list of dataset items."""
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch


class BaseDataset:
    """Holds the annotations parsed from a JSON file.

    Args:
        data_file: Path of the JSON annotation file.
    """

    def __init__(self, data_file):
        self.annotations = self.load_annotations(data_file)

    def load_annotations(self, data_file):
        """Parse ``data_file`` as JSON and return the decoded object."""
        # The context manager closes the handle once parsing is done.
        with gfile.GFile(data_file, 'r') as fi:
            return json.load(fi)

    def __len__(self):
        """Return the number of annotations."""
        return len(self.annotations)


class QADataset(BaseDataset):
    """Turns annotation records into prompt/answer dictionaries.

    Args:
        data_file: Annotation file forwarded to ``BaseDataset``.
        qa_type: 'OpenQA', 'CloseQA' or 'Mixed' (a type is drawn per item).
        CloseQA_weight: Weight of 'CloseQA' when ``qa_type`` is 'Mixed';
            'OpenQA' receives ``100 - CloseQA_weight``.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Ids of items whose 'Answer_closed' field was missing or unparsable.
        self.wrong_files = []

    def __getitem__(self, index):
        """Return the prompt, reference answer and metadata for item ``index``.

        Returns:
            dict with keys 'video_id', 'question_answer' (the prompt text),
            'question', 'answer', 'task' (the QA type used) and 'category'.

        Raises:
            NotImplementedError: If the QA type is neither 'OpenQA' nor
                'CloseQA' after resolving 'Mixed'.
        """
        row = self.annotations[index]
        video_id = row['interval']
        question = row['question']
        category = row['reasoning_type']
        answer = row['answer'].strip()

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"Question: {question}\n"
                "\n"
                "We are doing a scientific experiment to check biases in questions.\n"
                "The user has provided you a question in text form and one or more frames from a video you need to try your best to answer.\n"
                "Not answering is not an option.\n"
                "If you make mistakes is not an issue just try your best!\n"
                "Keep your answers very short and to the point.\n"
                "Remember to try to ground your answer on the provided frame(s). You should try to be correct, but in case you feel like you cannot answer just do your best or use common sense knowledge.\n"
                "Don't excuse yourself or don't clarify that it's I guess, I know already!\n"
                "\n"
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the item.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            # Keep the letter with the reference answer: the model is asked
            # to reply with a letter, and the shuffle is not stored elsewhere.
            answer_str = f'({self.choice_indices[answer_index]}) {answer}'
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Build a ``ModelClient`` for ``Config.MODEL_URL`` with this cell's decoding settings."""
    formatting = model_client.FormattingOptions(enable_formatting=True)
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting,
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries, list_values,
                    max_retries=5, retry_delay_s=10):
    """Send one QA item (its frames plus the prompt text) to the model.

    Args:
        batch: Item dict with keys 'video_id', 'question', 'question_answer',
            'category' and 'answer'.
        sampling_rate: Not used by this function; kept in the signature for
            the existing callers.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            are already answered and must be skipped.
        list_values: Collection of ``"{video_id}_{question}"`` keys; items
            whose key is absent are skipped.
        max_retries: Maximum number of generation attempts (at least one is
            always made).
        retry_delay_s: Seconds to wait between failed attempts.

    Returns:
        A result dict with keys "V", "Q", "QA", "A", "C", "M", or ``None``
        when the item is skipped, has no frames, yields an empty answer, or
        every attempt fails.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    # Only items named in the selection list are processed.
    if f'{uid}_{question}' not in list_values:
        return None

    frames_dir = '/cns/lu-d/home/alessiot/ttl=30d/single_frame/' + uid

    # Check existence before listing so a missing directory is skipped
    # instead of being handed to ListDir.
    if not gfile.Exists(frames_dir):
        return None
    image_paths = sorted(gfile.ListDir(frames_dir))
    if not image_paths:
        return None

    # One image chunk per frame, in sorted file-name order.
    prompt = []
    for image_path in image_paths:
        with gfile.Open(frames_dir + '/' + image_path, 'rb') as frame_file:
            frame_bytes = frame_file.read()
        prompt.append(
            model_client.ContentChunk(
                value=frame_bytes,
                mimetype='image/jpg',
                substream_name='',
                metadata=model_client.Metadata(role=roles.ROLE_USER),
            )
        )

    prompt.append(
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    )

    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            # Joining builds the answer in one step, so a stream that fails
            # midway leaves no partial text behind for the next attempt.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
        except Exception as e:
            if attempt == attempts:
                print(f"Error processing {uid}: {e}")
                return None
            print(f"Error generating stream for {uid}: {e}. Retrying in {retry_delay_s} seconds...")
            time.sleep(retry_delay_s)
        else:
            break

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """
    Perform bulk inference using multithreading.

    Only items whose "{video_id}_{question}" key appears in the selection file
    read below are processed. Results already stored in ``output_file_path``
    with a non-empty "A" field are kept, and their ("V", "Q") pairs are handed
    to the workers so those items are skipped. New results are appended; the
    file is rewritten each time the total reaches a multiple of 50 and once
    more at the end.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file holding the list of result dicts.
        sampling_rate: Forwarded to ``process_qa_item``.
        client: Forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    with gfile.Open("/x20/users/cp/cplizzari/EgoTaskQA/output_v_values_500.txt", "r") as file:
        list_values = [line.strip() for line in file]
    print(len(list_values))
    # Each worker tests membership once per item; a set makes that O(1).
    selected_keys = set(list_values)

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
            else:
                # Drop entries with a missing or empty answer so they are retried.
                model_response = [entry for entry in existing_data if entry.get("A", "") != ""]
                existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}

    # Flatten the batches so every item is submitted, whatever the batch size.
    items = [item for batch in data_loader for item in batch]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(process_qa_item, item, sampling_rate, client, existing_entries, selected_keys): idx
            for idx, item in enumerate(items)
        }

        for future in tqdm(as_completed(futures), total=len(items), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as e:
                # One failing item must not abort the run and discard the
                # results gathered since the last save.
                print(f"Item {futures[future]} failed: {e}")
                continue
            if result:
                model_response.append(result)
                existing_entries.add((result["V"], result["Q"]))

                # Periodically save to prevent data loss
                if len(model_response) % 50 == 0:
                    with gfile.GFile(output_file_path, 'w') as fi:
                        json.dump(model_response, fi)
                        print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Run bulk inference once for every entry of ``Config.SAMPLING_RATES``."""
    model = initialize_client()

    # Create the results directory if it is not there yet.
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for rate in Config.SAMPLING_RATES:
        # One results file per sampling rate.
        results_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=rate)
        results_path = os.path.join(Config.OUTPUT_DIR, results_name)

        # The annotation path and QA type are fixed here rather than read
        # from Config; items are served one at a time, in file order.
        qa_dataset = QADataset('/x20/users/cp/cplizzari/EgoTaskQA/data/qa/direct/test_qas.json', 'OpenQA')
        loader = DataLoader(qa_dataset, batch_size=1, shuffle=False)

        print(f"Running bulk inference for sampling rate: {rate}")
        perform_bulk_inference(
            data_loader=loader,
            output_file_path=results_path,
            sampling_rate=rate,
            client=model,
            max_workers=Config.MAX_WORKERS,
        )

# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Running bulk inference for sampling rate: 1
501


Processing QA Items:   6%|▌         | 518/8783 [01:08<41:52,  3.29it/s]

Saved 50 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json
Saved 100 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:   8%|▊         | 683/8783 [01:43<29:39,  4.55it/s]


Saved 150 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  10%|▉         | 866/8783 [02:18<42:45,  3.09it/s]


Saved 200 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  12%|█▏        | 1053/8783 [02:51<19:04,  6.76it/s]


Saved 250 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  14%|█▍        | 1211/8783 [03:25<24:07,  5.23it/s]


Saved 300 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  16%|█▌        | 1395/8783 [03:59<37:19,  3.30it/s]


Saved 350 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  18%|█▊        | 1571/8783 [04:34<19:56,  6.03it/s]


Saved 400 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items:  99%|█████████▉| 8721/8783 [05:01<00:00, 3475.56it/s]


Saved 450 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items: 100%|█████████▉| 8770/8783 [05:34<00:00, 75.01it/s]


Saved 500 entries to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json

Processing QA Items: 100%|██████████| 8783/8783 [05:43<00:00, 25.54it/s]


Final results saved to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/results_1.json





# EgoSchema Inference

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Constants that parameterise this inference cell."""

    # --- Result location ---
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/EgoSchema_results_closeQA'
    # Formatted with the sampling rate to obtain the results file name.
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints (disabled):
    #   'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    #   'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:goldfish_shared'
    MAX_LENGTH = 8192
    TEMPERATURE = 0
    TOP_P = 0.95

    # --- Question formatting ---
    # One of 'OpenQA', 'CloseQA' or 'Mixed'.
    QA_TYPE = 'CloseQA'
    # Only consulted when QA_TYPE is 'Mixed'.
    CLOSE_QA_WEIGHT = 50

    # --- Execution ---
    SAMPLING_RATES = [1.5]
    BATCH_SIZE = 1
    SHUFFLE_DATA = False
    # Number of threads in the worker pool.
    MAX_WORKERS = 8


# Maps a zero-based option index to its answer letter.
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
NUM_TO_LETTER = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}

def load_correct_answers(file_path):
    """
    Load the correct-answer mapping from a JSON file.

    The values are returned exactly as stored in the file; no
    number-to-letter conversion is applied.

    Args:
        file_path: Path of a JSON file containing an object (key -> answer).

    Returns:
        A new dict with the same keys and values as the JSON object.
    """
    # The context manager closes the handle once parsing is done.
    with gfile.Open(file_path, 'r') as fi:
        correct_answers = json.load(fi)
    return dict(correct_answers.items())

class BaseDataset:
    """Annotations restricted to the items that have a known correct answer.

    Args:
        data_file: Path of the JSON annotation file (a list of dicts that
            each carry a 'q_uid' key).
    """

    def __init__(self, data_file):
        self.annotations = self.load_annotations(data_file)
        self.correct_answers = load_correct_answers('/x20/users/cp/cplizzari/EgoSchema_annotations/subset_answers.json')
        self.filter_and_add_correct_answers()

    def load_annotations(self, data_file):
        """Parse ``data_file`` as JSON and return the decoded object."""
        # The context manager closes the handle once parsing is done.
        with gfile.GFile(data_file, 'r') as fi:
            return json.load(fi)

    def filter_and_add_correct_answers(self):
        """Keep only annotations with a correct answer and attach it.

        Each kept annotation dict is updated in place with a
        'correct_answer' entry taken from ``self.correct_answers``.
        """
        filtered_annotations = []
        for item in self.annotations:
            video_id = item['q_uid']
            if video_id in self.correct_answers:
                item['correct_answer'] = self.correct_answers[video_id]
                filtered_annotations.append(item)
        self.annotations = filtered_annotations

    def __len__(self):
        """Return the number of retained annotations."""
        return len(self.annotations)


class QADataset(BaseDataset):
    """Turns five-option annotations into prompt/answer dictionaries.

    Args:
        data_file: Annotation file forwarded to ``BaseDataset``.
        qa_type: 'CloseQA', 'OpenQA' or 'Mixed' (a type is drawn per item).
        CloseQA_weight: Weight of 'CloseQA' when ``qa_type`` is 'Mixed';
            'OpenQA' receives ``100 - CloseQA_weight``.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type  # CloseQA, OpenQA, Mixed
        self.choice_indices = ['A', 'B', 'C', 'D', 'E']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight

    def __getitem__(self, index):
        """Return the prompt and reference answer for item ``index``.

        Returns:
            dict with keys 'video_id', 'question_answer' (the prompt text),
            'question' and 'answer'. For 'CloseQA' the answer is the lettered
            option, e.g. "(B) xx"; for 'OpenQA' it is the stored
            'correct_answer' value unchanged.

        Raises:
            NotImplementedError: If the QA type is neither 'OpenQA' nor
                'CloseQA' after resolving 'Mixed'.
        """
        row = self.annotations[index]
        video_id = row['q_uid']
        question = row['question']
        options = [row[f'option {i}'] for i in range(5)]
        answer = row['correct_answer']

        qa_type = self.qa_type
        if qa_type == 'Mixed':  # randomly choose a QA type
            qa_type = random.choices(['CloseQA', 'OpenQA'], weights=[self.CloseQA_weight, self.openqa_weight], k=1)[0]

        if qa_type == 'OpenQA':
            question_str = f"{question}"
            answer_str = answer
        elif qa_type == 'CloseQA':
            # 'correct_answer' is used as an index into the option list.
            answer = options[answer]
            random.shuffle(options)
            answer_index = options.index(answer)
            # ["(A) xx", "(B) xx", ...]
            choices = [
                f'({letter}) {option}'
                for letter, option in zip(self.choice_indices, options)
            ]
            choices_str = ' '.join(choices)
            example_question = "What is 2 + 2?"
            example_choices_str = ", ".join(["A. 3", "B. 4", "C. 5", "D. 6"])
            example_answer = "[B]"

            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                "Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                "Here is an example to illustrate the format: "
                f"Example Question: {example_question} Example Choices: {example_choices_str}. "
                # Trailing space keeps this sentence apart from the next one.
                f"Example Answer: {example_answer}. "
                "You need to answer the question in any case and not demand additional context information. "
                "Use your commonsense knowledge to be able to answer the question."
            )
            answer_str = choices[answer_index]  # "(A/B/C/D/E) xx"
        else:
            raise NotImplementedError

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
        }



class DataLoader:
    """Minimal batching iterator over an indexable dataset.

    Args:
        dataset: Object supporting ``len()`` and integer ``__getitem__``.
        batch_size: Number of items per batch; the last batch may be shorter.
        shuffle: If true, the index order is shuffled once at construction.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            # A non-positive batch size never advances the cursor, so
            # iteration would not terminate.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Set the cursor here so next() also works before iter() is called.
        self.start = 0

    def __iter__(self):
        """Rewind to the first batch and return the loader itself."""
        self.start = 0
        return self

    def __next__(self):
        """Return the next batch as a list of dataset items."""
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch




def initialize_client():
    """Create the model client used for every request in this cell.

    The client targets ``Config.MODEL_URL`` and carries a default generation
    config with a fixed seed of 0, formatting enabled, the configured
    temperature and the configured maximum length. Nucleus top-p is not
    passed (the argument is kept commented out below).

    Returns:
        The constructed ``model_client.ModelClient``.
    """
    formatting = model_client.FormattingOptions(enable_formatting=True)
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        # nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting,
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries):
    """Query the model for one QA item, attaching the video's frames.

    Args:
        batch: Item dict with the keys 'video_id', 'question',
            'question_answer' (the full prompt text) and 'answer'.
        sampling_rate: Not used by this function; kept so the call signature
            stays the same for callers.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            already have a stored answer.

    Returns:
        A result dict with keys 'V', 'Q', 'QA', 'A', 'C', 'M', or ``None`` when
        the item was already answered, has no frames, or the model returned
        an empty string.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    # Frames are read from the same directory that is listed.
    frames_dir = '/x20/users/cp/cplizzari/frames_32/' + uid
    image_paths = gfile.ListDir(frames_dir)
    if not image_paths:
        return None

    # One image chunk per frame, in sorted file-name order.
    prompt = []
    for image_path in sorted(image_paths):
        # Close each frame file as soon as its bytes are read.
        with gfile.Open(os.path.join(frames_dir, image_path), 'rb') as frame_file:
            frame_bytes = frame_file.read()
        prompt.append(
            model_client.ContentChunk(
                value=frame_bytes,
                mimetype='image/jpg',
                metadata=model_client.Metadata(role=roles.ROLE_USER)
            )
        )

    # The question text goes last, after all frames.
    prompt.append(
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    )

    retry_delay_s = 10
    # Retries are unbounded: the loop only ends once a stream completes.
    while True:
        try:
            # Rebuild the text from scratch on every attempt so a stream that
            # failed midway cannot leave a partial prefix in the final answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as e:
            print(f"Error generating stream for {uid}: {e}. Retrying in {retry_delay_s} seconds...")
            time.sleep(retry_delay_s)

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": 'all'
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """
    Perform bulk inference using multithreading.

    Previously stored results are loaded from ``output_file_path`` (entries
    with an empty answer are discarded so they get retried), every item from
    ``data_loader`` is submitted to a thread pool, and the accumulated results
    are written back as JSON whenever their count reaches a multiple of 50 and
    once more at the end.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file used both to resume from and to save to.
        sampling_rate: Forwarded unchanged to ``process_qa_item``.
        client: Model client forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                existing_data = []
        # Filter out entries with empty answers
        model_response = [entry for entry in existing_data if entry["A"] != ""]
        existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}

    # Flatten the batches so that every item is processed, not only the
    # first element of each batch.
    items = [item for batch in data_loader for item in batch]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_qa_item, item, sampling_rate, client, existing_entries)
            for item in items
        ]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
            result = future.result()
            if result is None:
                continue
            model_response.append(result)
            # The same set object was handed to the workers above.
            existing_entries.add((result["V"], result["Q"]))

            # Periodically save to prevent data loss
            if len(model_response) % 50 == 0:
                with gfile.GFile(output_file_path, 'w') as fi:
                    json.dump(model_response, fi)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Run bulk inference once per entry in ``Config.SAMPLING_RATES``.

    Creates the client, makes sure the output directory exists, and for each
    sampling rate builds the CloseQA dataset from the questions JSON file and
    writes results to a file named after that rate.
    """
    client = initialize_client()

    # Create the output directory on first use.
    out_dir = Config.OUTPUT_DIR
    if not gfile.Exists(out_dir):
        gfile.MakeDirs(out_dir)

    for sampling_rate in Config.SAMPLING_RATES:
        # One result file per sampling rate.
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        output_file_path = os.path.join(out_dir, result_name)

        # The dataset is rebuilt on every pass through the loop.
        dataset = QADataset(
            '/x20/users/cp/cplizzari/EgoSchema_annotations/questions.json',
            'CloseQA',
        )
        data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader,
            output_file_path,
            sampling_rate,
            client,
            max_workers=Config.MAX_WORKERS,
        )


if __name__ == "__main__":
    main()


# Commonsense - Gemini

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Static settings read by the functions in this notebook cell."""

    # --- Paths ---
    OUTPUT_DIR = "/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_commonsense/"
    DATA_FILE = "/x20/users/cp/cplizzari/selected_wip_v7.csv"

    # --- Model endpoint and generation settings ---
    # Alternative endpoints, kept for reference:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = "evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:v2_s_dense_shared"
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Question format ---
    QA_TYPE = "OpenQA"  # One of: 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # Only consulted when QA_TYPE is 'Mixed'

    # --- Inference loop ---
    SAMPLING_RATES = [1.5]
    MAX_WORKERS = 8  # Thread-pool size
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Result file name; formatted with the sampling rate ---
    RESULT_FILE_TEMPLATE = "results_{sampling_rate}.json"


class DataLoader:
    """Iterate over a dataset in fixed-size batches.

    ``iter(loader)`` resets the cursor and returns the loader itself; each
    ``next()`` returns a list holding the next ``batch_size`` dataset items
    (fewer for the last batch). When ``shuffle`` is true the index order is
    randomised once, in the constructor.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Dataset positions in visiting order.
        self.indices = [*range(len(dataset))]
        if shuffle:
            random.shuffle(self.indices)

    def __iter__(self):
        # Start again from the first batch.
        self.start = 0
        return self

    def __next__(self):
        count = len(self.indices)
        begin = self.start
        if begin >= count:
            raise StopIteration
        # Never read past the end of the index list.
        stop = begin + self.batch_size
        if stop > count:
            stop = count
        batch = [self.dataset[pos] for pos in self.indices[begin:stop]]
        self.start = stop
        return batch


class BaseDataset:
    """Dataset whose annotations are the rows of a CSV file."""

    def __init__(self, data_file):
        # Annotations are loaded once, at construction time.
        self.annotations = self.load_annotations(data_file)

    def load_annotations(self, data_file):
        """Read ``data_file`` with ``pandas.read_csv`` and return the frame."""
        table = pd.read_csv(data_file)
        return table

    def __len__(self):
        """Return the number of annotation rows."""
        row_count = len(self.annotations)
        return row_count


class QADataset(BaseDataset):
    """Turns annotation rows into prompt/answer dicts for open or closed QA.

    Attributes:
        qa_type: 'OpenQA', 'CloseQA' or 'Mixed' (a type is drawn per item).
        choice_indices: Letters used to label multiple-choice options.
        CloseQA_weight / openqa_weight: Relative weights for the 'Mixed' draw.
        wrong_files: Video UIDs whose 'Answer_closed' value could not be
            parsed and for which placeholder distractors were used.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt and reference answer for the row at ``index``.

        Returns:
            Dict with keys 'video_id', 'question_answer' (prompt text),
            'question', 'answer', 'task' (the QA type actually used) and
            'category'.

        Raises:
            NotImplementedError: If the QA type is not one of the supported
                values.
        """
        row = self.annotations.iloc[index]
        video_id = row['Video UID']
        question = row['Question']
        category = row['Category']
        answer = str(row.get('Answer_open', ''))

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"Answer the following question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Use you commonsense knowledge to be able to answer the question."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
                # Anything other than a sequence of distractors is treated
                # like an unparsable cell.
                if not isinstance(wrong_answers, (list, tuple)):
                    raise ValueError("Answer_closed is not a list of answers")
            except (ValueError, SyntaxError, KeyError, TypeError):
                # Fall back to placeholder distractors and remember the video.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            choices = [answer] + list(wrong_answers[:3])
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({self.choice_indices[idx]}) {choice}'
                for idx, choice in enumerate(choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Build the ``ModelClient`` shared by all worker threads in this cell.

    The default generation config uses seed 0, enables formatting, and takes
    the temperature and maximum length from ``Config``. Nucleus top-p is not
    passed (the argument is kept commented out below).

    Returns:
        The constructed ``model_client.ModelClient`` for ``Config.MODEL_URL``.
    """
    formatting_options = model_client.FormattingOptions(enable_formatting=True)
    sampling_config = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        # nucleus_top_p=Config.TOP_P,
    )
    token_config = model_client.make_token_generation_config(
        sampling_config=sampling_config,
        length=Config.MAX_LENGTH,
    )
    default_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting_options,
        token_generation=token_config,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=default_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries):
    """Query the model for one QA item using a text-only prompt.

    Args:
        batch: Item dict with the keys 'video_id', 'question',
            'question_answer' (the full prompt text), 'category' and 'answer'.
        sampling_rate: Not used by this function; kept so the call signature
            stays the same for callers.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            already have a stored answer.

    Returns:
        A result dict with keys 'V', 'Q', 'QA', 'A', 'C', 'M', or ``None`` when
        the item was already answered or the model returned an empty string.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    # The prompt holds a single text chunk: no frames are attached here.
    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    ]

    retry_delay_s = 10
    # Retries are unbounded: the loop only ends once a stream completes.
    while True:
        try:
            # Rebuild the text from scratch on every attempt so a stream that
            # failed midway cannot leave a partial prefix in the final answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as e:
            print(f"Error generating stream for {uid}: {e}. Retrying in {retry_delay_s} seconds...")
            time.sleep(retry_delay_s)

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """
    Perform bulk inference using multithreading.

    Previously stored results are loaded from ``output_file_path`` (entries
    with an empty answer are discarded so they get retried), every item from
    ``data_loader`` is submitted to a thread pool, and the accumulated results
    are written back as JSON whenever their count reaches a multiple of 50 and
    once more at the end.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file used both to resume from and to save to.
        sampling_rate: Forwarded unchanged to ``process_qa_item``.
        client: Model client forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                existing_data = []
        # Filter out entries with empty answers
        model_response = [entry for entry in existing_data if entry["A"] != ""]
        existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}

    # Flatten the batches so that every item is processed, not only the
    # first element of each batch.
    items = [item for batch in data_loader for item in batch]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_qa_item, item, sampling_rate, client, existing_entries)
            for item in items
        ]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
            result = future.result()
            if not result:
                continue
            model_response.append(result)
            # The same set object was handed to the workers above.
            existing_entries.add((result["V"], result["Q"]))

            # Periodically save to prevent data loss
            if len(model_response) % 50 == 0:
                with gfile.GFile(output_file_path, 'w') as fi:
                    json.dump(model_response, fi)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Run bulk inference once per entry in ``Config.SAMPLING_RATES``.

    Creates the client, makes sure the output directory exists, and for each
    sampling rate builds the dataset and loader from ``Config`` and writes
    results to a file named after that rate.
    """
    client = initialize_client()

    # Create the output directory on first use.
    out_dir = Config.OUTPUT_DIR
    if not gfile.Exists(out_dir):
        gfile.MakeDirs(out_dir)

    for sampling_rate in Config.SAMPLING_RATES:
        # One result file per sampling rate.
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        output_file_path = os.path.join(out_dir, result_name)

        # The dataset and loader are rebuilt on every pass through the loop.
        dataset = QADataset(Config.DATA_FILE, Config.QA_TYPE, Config.CLOSE_QA_WEIGHT)
        data_loader = DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA,
        )

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader,
            output_file_path,
            sampling_rate,
            client,
            max_workers=Config.MAX_WORKERS,
        )


if __name__ == "__main__":
    main()


Running bulk inference for sampling rate: 1.5


Processing QA Items:   9%|▊         | 48/550 [00:07<01:16,  6.60it/s]

Error generating stream for 03f0258c-8bd2-4545-8440-aa4a8afcfd31_659.0838040176359_665.5749393156977: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 0ece1a2a-2da0-4b28-ac8c-bf9f04c0d17c_1490.0790667735569_1528.1639409597753: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Retr

Processing QA Items:  18%|█▊        | 98/550 [01:13<02:02,  3.70it/s]


Error generating stream for 00c450a5-49d0-4647-a45c-bf214cea169d_1098.6108431421267_1117.7239668578732: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 12cc8ad3-5cee-457f-8657-1db8ed812558_-0.10553267045454545_34.46792267045454: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota)

Processing QA Items:  27%|██▋       | 148/550 [02:11<01:37,  4.13it/s]


Error generating stream for 176e0cc4-f5f3-48a4-9032-7c9c04e003f9_1198.4239571244834_1263.0419453028603: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 07b1c874-9dc1-42bc-87ff-dffa9bef14fb_1614.6537332841376_1665.3687700491955: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); R

Processing QA Items:  36%|███▌      | 198/550 [03:09<05:02,  1.16it/s]


Error generating stream for 0928af96-a3fd-4930-9bc1-2fa43d01bb53_0.3564398571421623_64.67324454285783: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 0928af96-a3fd-4930-9bc1-2fa43d01bb53_0.3564398571421623_64.67324454285783: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Ret

Processing QA Items:  45%|████▌     | 249/550 [04:09<19:09,  3.82s/it]


Error generating stream for 105d3303-8e2d-4c20-96ff-e9a8ff325109_464.2746674856418_561.7851980684015: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 01f5ee9f-09e9-4d3d-b682-aa64b6a57858_867.4916071698378_988.9360245020739: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Retry

Processing QA Items:  54%|█████▍    | 298/550 [04:16<00:59,  4.26it/s]


Error generating stream for 0a4dded4-0ee3-4c9d-b13e-f9f2fa85dfa3_71.22006531893004_109.61192468106997: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 2112c43b-7dd4-4331-be3b-94ec951b682d_1736.165753432497_1844.8491643008363: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Ret

Processing QA Items:  63%|██████▎   | 349/550 [05:13<00:30,  6.53it/s]


Error generating stream for 0a4dded4-0ee3-4c9d-b13e-f9f2fa85dfa3_71.22006531893004_109.61192468106997: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota) [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 0b38da1c-2b62-4c05-a68e-e8f2aafa47a8_1012.8982398930917_1044.9694501069082: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota) [extensible_stubs::OVERLOADED_

Processing QA Items:  73%|███████▎  | 399/550 [06:18<00:32,  4.70it/s]


Error generating stream for 06546c45-e0f8-4d75-a184-39e19552c6b7_156.02273494082317_177.16150505917682: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:2] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 115774b6-534d-444f-b7aa-d1b834eb0ee7_66.51224605407714_72.79878034592284: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota) [ex

Processing QA Items:  82%|████████▏ | 449/550 [07:16<00:20,  5.00it/s]


Error generating stream for 18ef9181-3da3-4217-8a98-1405e87db8fc_921.8039535064169_1024.3653275602499: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota) [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 002d2729-df71-438d-8396-5895b349e8fd_2035.4266694854234_2047.877591623903: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota) [extensible_stubs::OVERLOADED_T

Processing QA Items:  91%|█████████ | 500/550 [08:14<00:32,  1.55it/s]


Error generating stream for 10341975-6612-4137-b0c2-703847ad4dba_364.778727881237_465.1785665187631: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 07884569-3860-4a20-8c85-278eefbd678e_1148.4864739947864_1167.0297379638075: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Retr

Processing QA Items: 100%|██████████| 550/550 [09:02<00:00,  1.01it/s]


Error generating stream for 0be670d2-3216-4261-ab73-f9941b69e04c_1281.435662760334_1311.883297239666: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); RetryingStub: [attempts:3] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_REQUEST (2)]. Retrying in 30 seconds...
Error generating stream for 0be670d2-3216-4261-ab73-f9941b69e04c_987.6359027603339_1006.250337239666: A retryable error could not be retried due to too many retries per Extensible Stubs request (see go/xs-retries-per-request). (old status: generic::resource_exhausted: User cplizzari is out of quota for model evergreen://blade:gdm-aip-agent-generate-service-prod/lmroot:v2_s_dense_shared, retry in a few seconds, see go/lm-proxy-quota); Retry




# Commonsense - EgoSchema

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Static settings read by the functions in this notebook cell."""

    # --- Paths ---
    OUTPUT_DIR = "/x20/users/cp/cplizzari/benchmark/benchmark_v7/EgoSchema_results_closeQA_commonsense_private"

    # --- Model endpoint and generation settings ---
    # Alternative endpoints, kept for reference:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = "evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:goldfish_shared"
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Question format ---
    QA_TYPE = "CloseQA"  # One of: 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # Only consulted when QA_TYPE is 'Mixed'

    # --- Inference loop ---
    SAMPLING_RATES = [1.5]
    MAX_WORKERS = 8  # Thread-pool size
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Result file name; formatted with the sampling rate ---
    RESULT_FILE_TEMPLATE = "results_{sampling_rate}.json"


# Maps a zero-based option index (0-4) to its answer letter ('A'-'E').
NUM_TO_LETTER = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}

def load_correct_answers(file_path):
    """
    Load the answer key from a JSON file.

    The values are returned exactly as stored in the file; no
    number-to-letter conversion is applied here.

    Args:
        file_path: Path of the JSON file holding a mapping from key to answer.

    Returns:
        A new dict with the same keys and values as the JSON object.
    """
    # Open through a context manager so the handle is closed after parsing.
    with gfile.Open(file_path, 'r') as answers_file:
        correct_answers = json.load(answers_file)
    return dict(correct_answers)

class BaseDataset:
    """JSON-backed list of question records, limited to those with an answer.

    After construction ``annotations`` holds only the records whose 'q_uid'
    appears in the answer key, each extended with a 'correct_answer' entry.
    """

    def __init__(self, data_file):
        self.annotations = self.load_annotations(data_file)
        self.correct_answers = load_correct_answers('/x20/users/cp/cplizzari/EgoSchema_annotations/subset_answers.json')
        self.filter_and_add_correct_answers()

    def load_annotations(self, data_file):
        """Parse ``data_file`` as JSON and return the decoded object."""
        # Open through a context manager so the handle is closed after parsing.
        with gfile.GFile(data_file, 'r') as annotations_file:
            return json.load(annotations_file)

    def filter_and_add_correct_answers(self):
        """Drop records without a known answer and attach the answer to the rest.

        Kept records are modified in place: each gains a 'correct_answer' key.
        """
        filtered_annotations = []
        for item in self.annotations:
            video_id = item['q_uid']
            if video_id in self.correct_answers:
                item['correct_answer'] = self.correct_answers[video_id]
                filtered_annotations.append(item)
        self.annotations = filtered_annotations

    def __len__(self):
        """Return the number of retained records."""
        return len(self.annotations)


class QADataset(BaseDataset):
    """Builds five-option multiple-choice (or open) prompts from the records.

    Attributes:
        qa_type: 'CloseQA', 'OpenQA' or 'Mixed' (a type is drawn per item).
        choice_indices: Letters used to label the options.
        CloseQA_weight / openqa_weight: Relative weights for the 'Mixed' draw.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type  # CloseQA, OpenQA, Mixed
        self.choice_indices = ['A', 'B', 'C', 'D', 'E']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight

    def __getitem__(self, index):
        """Return the prompt and reference answer for the record at ``index``.

        Returns:
            Dict with keys 'video_id', 'question_answer' (prompt text),
            'question' and 'answer'.

        Raises:
            NotImplementedError: If the QA type is not supported.
        """
        row = self.annotations[index]
        video_id = row['q_uid']
        question = row['question']
        options = [row[f'option {number}'] for number in range(5)]
        answer = row['correct_answer']
        print(answer)

        qa_type = self.qa_type
        if qa_type == 'Mixed':  # randomly choose a QA type
            (qa_type,) = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1,
            )

        if qa_type not in ('OpenQA', 'CloseQA'):
            raise NotImplementedError

        if qa_type == 'OpenQA':
            question_str = f"{question}"
            # NOTE(review): `answer` is still the raw 'correct_answer' value
            # here, not the option text - confirm this is intended.
            answer_str = answer
        else:
            choices = list(options)
            # 'correct_answer' is used as a position into the option list.
            answer = choices[answer]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            # Label each shuffled option: "(A) xx", "(B) xx", ...
            labelled = [
                f'({self.choice_indices[position]}) {option_text}'
                for position, option_text in enumerate(choices)
            ]
            choices_str = ' '.join(labelled)
            example_choices_str = ", ".join(["A. 3", "B. 4", "C. 5", "D. 6"])

            # NOTE(review): there is no space between the example answer and
            # the following sentence; kept as-is so the prompt is unchanged.
            prompt_parts = [
                f"Question: {question} Choices: {choices_str}. ",
                "Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. ",
                "Here is an example to illustrate the format: ",
                f"Example Question: What is 2 + 2? Example Choices: {example_choices_str}. ",
                "Example Answer: [B].",
                "You need to answer the question in any case and not demand additional context information. ",
                "Use you commonsense knowledge to be able to answer the question.",
            ]
            question_str = "".join(prompt_parts)
            answer_str = labelled[answer_index]  # "(LETTER) option text"
            print('correct answer: ', answer_str)

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
        }



class DataLoader:
    """Sequential mini-batch iterator over an indexable, sized dataset.

    The visiting order is computed once in the constructor (optionally
    shuffled with ``random.shuffle``); every ``iter()`` call rewinds to the
    first batch and replays that same order.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)

    def __iter__(self):
        # The loader acts as its own iterator, so rewinding is a cursor reset.
        self.start = 0
        return self

    def __next__(self):
        total = len(self.indices)
        if self.start >= total:
            raise StopIteration
        stop = min(self.start + self.batch_size, total)
        items = [self.dataset[i] for i in self.indices[self.start:stop]]
        # Advance the cursor only after the batch was built successfully.
        self.start = stop
        return items




def initialize_client():
    """Create the ``ModelClient`` used for all generation requests.

    The endpoint, temperature and maximum generation length are read from
    ``Config``; the seed is fixed to 0 and formatting is enabled.
    """
    # NOTE: nucleus_top_p (Config.TOP_P) is not passed in this variant.
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries):
    """Query the model for one QA item and package its answer.

    Args:
        batch: Item dict with keys 'video_id', 'question', 'question_answer'
            (the prompt text sent to the model) and 'answer' (stored as the
            reference answer).
        sampling_rate: Not used by this function; accepted so the signature
            matches what the caller passes.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            already have a stored answer.

    Returns:
        A dict with keys "V", "Q", "QA", "A", "C", "M", or ``None`` when the
        item is skipped (already answered, empty ``frames_32/<video_id>``
        listing, or an empty model response).
    """
    retry_delay_seconds = 10

    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    # Items whose frames_32/<uid> listing is empty are skipped; the listed
    # paths are not otherwise used (the prompt below is text only).
    image_paths = gfile.ListDir('/x20/users/cp/cplizzari/frames_32/' + uid)
    if not image_paths:
        return None

    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER),
        )
    ]

    # Retry until one request streams to completion.
    while True:
        try:
            # The answer is rebuilt from scratch on every attempt so chunks
            # received before a mid-stream failure cannot leak into (and
            # duplicate text in) the final answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as err:
            # The logged delay now reflects the delay actually applied.
            print(
                f"Error generating stream for {uid}: {err}. "
                f"Retrying in {retry_delay_seconds} seconds..."
            )
            time.sleep(retry_delay_seconds)

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": 'all'
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """Run `process_qa_item` over every batch in a thread pool and save results.

    Entries already stored in `output_file_path` with a non-empty "A" field
    are kept, and their ("V", "Q") pairs are handed to the workers so those
    items are skipped. Results are written to the same file as a JSON list
    every 50 entries and once more at the end.

    Args:
        data_loader: Iterable of batches; only the first element of each
            batch (``batch[0]``) is processed.
        output_file_path: JSON results file, read for resuming and
            overwritten on every save.
        sampling_rate: Forwarded unchanged to `process_qa_item`.
        client: Forwarded unchanged to `process_qa_item`.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                # Filter out entries with empty answers
                model_response = [entry for entry in existing_data if entry["A"] != ""]
                existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batches = list(data_loader)
        futures = {
            executor.submit(process_qa_item, batch[0], sampling_rate, client, existing_entries): idx
            for idx, batch in enumerate(batches)
        }

        for future in tqdm(as_completed(futures), total=len(batches), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as err:
                # future.result() re-raises whatever the worker raised. One
                # failing item must not abort the run and throw away the
                # answers collected since the last checkpoint; the item is
                # not stored, so a later run will attempt it again.
                print(f"Item {futures[future]} failed: {err!r}. Skipping it.")
                continue
            if result is None:
                continue

            model_response.append(result)
            existing_entries.add((result["V"], result["Q"]))

            # Periodically save to prevent data loss
            if len(model_response) % 50 == 0:
                with gfile.GFile(output_file_path, 'w') as fi:
                    json.dump(model_response, fi)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Entry point: run inference once per configured sampling rate.

    For each value in ``Config.SAMPLING_RATES`` a fresh dataset/loader pair is
    built from the EgoSchema question file and the answers are written to a
    per-rate JSON file inside ``Config.OUTPUT_DIR``.
    """
    client = initialize_client()

    # Create the output directory on first use.
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for sampling_rate in Config.SAMPLING_RATES:
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        result_path = os.path.join(Config.OUTPUT_DIR, result_name)

        # NOTE: the QA type is fixed to 'CloseQA' here rather than read from
        # Config.QA_TYPE.
        qa_dataset = QADataset(
            '/x20/users/cp/cplizzari/EgoSchema_annotations/questions.json',
            'CloseQA',
        )
        loader = DataLoader(qa_dataset, batch_size=1, shuffle=False)

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=loader,
            output_file_path=result_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS,
        )


if __name__ == "__main__":
    main()


Error generating stream for 026a2f15-c454-4c28-80e0-24c85d7f4ecf: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User is calling agent evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority too quickly.); RetryingStub: [attempts:2] [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 47f4c828-f238-459f-91c3-6b221db54c5b: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: throttling::THROTTLED_CLIENT: Request throttled at the client by AdaptiveThrottler. 525002922 { 3 { 1: "wiz-magi-servo" } }) [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Running bulk inference for sampling rate: 1.5
3
correct answer:  (E) C is cleaning dishes.
4
correct answer:  (B) To clean th

Processing QA Items: 100%|██████████| 500/500 [01:06<00:00,  7.52it/s]

Error generating stream for 057f8774-15c2-4e2e-b9fd-75f26d4b3b83: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User is calling agent evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority too quickly.) [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 049249dc-bdad-48c4-bdc0-511814c5781c: A load-shedding retryable throttled error could not be retried due to Extensible Stubs retrying limits (see go/stubs-retries). (old status: generic::resource_exhausted: User is calling agent evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority too quickly.) [extensible_stubs::OVERLOADED_TOO_MANY_RETRIES_PER_STUB (3)]. Retrying in 30 seconds...
Error generating stream for 05ad5736-88f5-42bb-ac9f-689e199c50de: A load-shedding retryable throttled error could not be retried due to Extensib




# Commonsense - EgoTaskQA

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Run settings for this cell, kept as class-level constants."""

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints, currently disabled:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen:///mbns/iz/home/courier/alessiot/chiara_cvpr:/lmroot:blade:aip-serving-alessiot-gemini_flash_s_2m'
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Storage locations ---
    OUTPUT_DIR = '/cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_single_frame/'
    DATA_FILE = '/x20/users/cp/cplizzari/selected_wip_v7.csv'

    # --- Question format ---
    QA_TYPE = 'OpenQA'  # one of 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # only relevant when QA_TYPE is 'Mixed'

    # --- Execution ---
    SAMPLING_RATES = [1.5]
    MAX_WORKERS = 64  # number of worker threads
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Result file name pattern (formatted with sampling_rate) ---
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'


class DataLoader:
    """Iterates over an indexable, sized dataset in fixed-size batches.

    The index order is fixed at construction time (shuffled once when
    `shuffle` is true); each ``iter()`` call restarts from the first batch.
    The last batch may be shorter than `batch_size`.

    Raises:
        ValueError: If `batch_size` is smaller than 1. Such a value would
            otherwise make iteration yield empty batches forever.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)

    def __iter__(self):
        self.start = 0
        return self

    def __next__(self):
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch


class BaseDataset:
    """Holds the annotations decoded from a JSON file."""

    def __init__(self, data_file):
        self.annotations = self.load_annotations(data_file)

    def load_annotations(self, data_file):
        """Parse `data_file` as JSON and return the decoded object."""
        # Open through a context manager so the handle is closed as soon as
        # parsing finishes instead of being left to garbage collection.
        with gfile.GFile(data_file, 'r') as fi:
            return json.load(fi)

    def __len__(self):
        return len(self.annotations)


class QADataset(BaseDataset):
    """Turns annotation rows into text prompts for open or multiple-choice QA.

    Every row is read for 'interval', 'question', 'reasoning_type' and
    'answer'; multiple-choice prompts additionally read 'Answer_closed'.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        """
        Args:
            data_file: Path of the JSON annotation file.
            qa_type: 'OpenQA', 'CloseQA' or 'Mixed' (type drawn per item).
            CloseQA_weight: Weight of 'CloseQA' in the 'Mixed' draw; 'OpenQA'
                gets ``100 - CloseQA_weight``.
        """
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Video ids whose 'Answer_closed' field was missing or unparsable.
        # The original code appended to a bare `wrong_files` name that is not
        # defined in the visible part of this file.
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt dict for annotation `index`.

        Returns:
            Dict with keys 'video_id', 'question_answer' (prompt text),
            'question', 'answer', 'task' (the QA type actually used) and
            'category'.

        Raises:
            NotImplementedError: If the QA type is not recognised.
        """
        row = self.annotations[index]
        video_id = row['interval']
        question = row['question']
        category = row['reasoning_type']
        answer = row['answer'].strip()

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                """You are allowed to hallucinate!
We are doing a scientific experiment to check biases in text questions.
The user will provide you a question in text form and you need to try your best to answer.
Not answering is not an option.
If you make mistakes is not an issue just try your best!
Keep your answers very short and to the point.
Remember you don't need to be correct, in case you feel like you cannot answer just do a random guess or use common sense knowledge.
Don't excuse yourself or don't clarify that it's I guess, I know already!

"""
                f"{question}"
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the item.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            # Correct answer plus at most three distractors, in random order.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Create the ``ModelClient`` used for all generation requests.

    Endpoint, temperature, nucleus top-p and maximum generation length come
    from ``Config``; the seed is fixed to 0 and formatting is enabled.
    """
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries, list_values):
    """Query the model for one QA item and package its answer.

    Args:
        batch: Item dict with keys 'video_id', 'question', 'question_answer'
            (the prompt text sent to the model), 'category' and 'answer'.
        sampling_rate: Not used by this function; accepted so the signature
            matches what the caller passes.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(video_id, question)`` pairs that
            already have a stored answer.
        list_values: Collection of ``"<video_id>_<question>"`` keys; items
            whose key is absent are skipped.

    Returns:
        A dict with keys "V", "Q", "QA", "A", "C", "M", or ``None`` when the
        item is skipped (already answered, not in `list_values`, or an empty
        model response).
    """
    retry_delay_seconds = 10

    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    # Restrict the run to the items named in the allow-list.
    if f'{uid}_{question}' not in list_values:
        return None

    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER),
        )
    ]

    # Retry until one request streams to completion.
    while True:
        try:
            # The answer is rebuilt from scratch on every attempt so chunks
            # received before a mid-stream failure cannot leak into (and
            # duplicate text in) the final answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as err:
            # The logged delay now reflects the delay actually applied.
            print(
                f"Error generating stream for {uid}: {err}. "
                f"Retrying in {retry_delay_seconds} seconds..."
            )
            time.sleep(retry_delay_seconds)

    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """Run `process_qa_item` over every batch in a thread pool and save results.

    Only items whose ``"<video_id>_<question>"`` key appears in the allow-list
    file are evaluated. Entries already stored in `output_file_path` with a
    non-empty "A" field are kept and skipped. Results are written to the same
    file as a JSON list every 50 entries and once more at the end.

    Args:
        data_loader: Iterable of batches; only the first element of each
            batch (``batch[0]``) is processed.
        output_file_path: JSON results file, read for resuming and
            overwritten on every save.
        sampling_rate: Forwarded unchanged to `process_qa_item`.
        client: Forwarded unchanged to `process_qa_item`.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    with gfile.Open("/x20/users/cp/cplizzari/EgoTaskQA/output_v_values_500.txt", "r") as file:
        list_values = [line.strip() for line in file]
    print(len(list_values))
    # Every item is tested against the allow-list, so use a set for
    # constant-time membership instead of scanning the list each time.
    allowed_keys = set(list_values)

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                # Filter out entries with empty answers
                model_response = [entry for entry in existing_data if entry["A"] != ""]
                existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batches = list(data_loader)
        futures = {
            executor.submit(process_qa_item, batch[0], sampling_rate, client, existing_entries, allowed_keys): idx
            for idx, batch in enumerate(batches)
        }

        for future in tqdm(as_completed(futures), total=len(batches), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as err:
                # future.result() re-raises whatever the worker raised. One
                # failing item must not abort the run and throw away the
                # answers collected since the last checkpoint; the item is
                # not stored, so a later run will attempt it again.
                print(f"Item {futures[future]} failed: {err!r}. Skipping it.")
                continue
            if not result:
                continue

            model_response.append(result)
            existing_entries.add((result["V"], result["Q"]))

            # Periodically save to prevent data loss
            if len(model_response) % 50 == 0:
                with gfile.GFile(output_file_path, 'w') as fi:
                    json.dump(model_response, fi)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")



def main():
    """Entry point: run inference once per configured sampling rate.

    For each value in ``Config.SAMPLING_RATES`` a fresh dataset/loader pair is
    built from the EgoTaskQA test question file and the answers are written
    to a per-rate JSON file inside ``Config.OUTPUT_DIR``.
    """
    client = initialize_client()

    # Create the output directory on first use.
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for sampling_rate in Config.SAMPLING_RATES:
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        result_path = os.path.join(Config.OUTPUT_DIR, result_name)

        # NOTE: the QA type is fixed to 'OpenQA' here rather than read from
        # Config.QA_TYPE.
        qa_dataset = QADataset(
            '/x20/users/cp/cplizzari/EgoTaskQA/data/qa/direct/test_qas.json',
            'OpenQA',
        )
        loader = DataLoader(qa_dataset, batch_size=1, shuffle=False)

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=loader,
            output_file_path=result_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS,
        )

if __name__ == "__main__":
    main()


Running bulk inference for sampling rate: 1.5
501


Processing QA Items: 100%|██████████| 8783/8783 [00:00<00:00, 131985.86it/s]

Final results saved to /cns/lu-d/home/alessiot/ttl=30d/uniform_egotaskQA_commonsense_flash_hallucinate/results_1.5.json





# Commonsense - EgoThink

In [None]:
import os
import glob
import pandas as pd
from google3.pyglib import gfile

class Config:
    """Run settings for this cell, kept as class-level constants."""

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints, currently disabled:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen:///mbns/iz/home/courier/alessiot/chiara_cvpr/:lmroot:blade:aip-serving-alessiot-gemini_mpp_2m'
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Storage location ---
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/EgoThink_commonsense/'

    # --- Question format ---
    QA_TYPE = 'OpenQA'  # one of 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # only relevant when QA_TYPE is 'Mixed'

    # --- Execution ---
    SAMPLING_RATES = [1.5]
    MAX_WORKERS = 8  # number of worker threads
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Result file name pattern (formatted with sampling_rate) ---
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'

class BaseDataset:
    """Loads every ``.parquet`` file found one directory level below `data_dir`."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.annotations = self.load_annotations(data_dir)

    def load_annotations(self, data_dir):
        """Read all parquet files in the sub-directories of `data_dir`.

        Each file's rows get a 'category' column holding the name of the
        sub-directory the file came from. Files that fail to load are
        reported and skipped.

        Returns:
            One DataFrame with the rows of all readable files concatenated.

        Raises:
            ValueError: If no parquet file was found or none could be read.
        """
        # Collect the .parquet files sitting directly inside each sub-directory.
        parquet_files = []
        for subdir in gfile.ListDir(data_dir):
            subdir_path = os.path.join(data_dir, subdir)
            if not gfile.IsDirectory(subdir_path):
                continue
            parquet_files.extend(
                os.path.join(subdir_path, name)
                for name in gfile.ListDir(subdir_path)
                if name.endswith('.parquet')
            )

        data_frames = []
        for path in parquet_files:
            try:
                df = pd.read_parquet(path)
                # The parent directory name doubles as the question category.
                df['category'] = os.path.basename(os.path.dirname(path))
                data_frames.append(df)
            except Exception as e:
                print(f"Error reading {path}: {e}")

        # pd.concat raises a bare "No objects to concatenate" on an empty
        # list; fail with the same exception type but say what was looked for.
        if not data_frames:
            raise ValueError(
                f"No objects to concatenate: {len(parquet_files)} .parquet file(s) "
                f"found under {data_dir!r} and none could be read."
            )

        full_df = pd.concat(data_frames, ignore_index=True)
        print(full_df.keys())
        print(f"Total rows after concatenation: {len(full_df)}")
        return full_df

    def __len__(self):
        return len(self.annotations)

class QADataset(BaseDataset):
    """Turns annotation rows into text prompts for open or multiple-choice QA.

    Every row is read for 'question', 'category' and 'answer';
    multiple-choice prompts additionally read 'Answer_closed'.
    """

    def __init__(self, data_dir, qa_type, CloseQA_weight=50):
        """
        Args:
            data_dir: Directory handed to `BaseDataset` for loading.
            qa_type: 'OpenQA', 'CloseQA' or 'Mixed' (type drawn per item).
            CloseQA_weight: Weight of 'CloseQA' in the 'Mixed' draw; 'OpenQA'
                gets ``100 - CloseQA_weight``.
        """
        super().__init__(data_dir)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Row indices whose 'Answer_closed' field was missing or unparsable.
        # The original fallback used `wrong_files.append(video_id)`, but
        # neither name is defined in this method or this cell.
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt dict for row `index`.

        Returns:
            Dict with keys 'question_answer' (prompt text), 'question',
            'answer', 'task' (the QA type actually used) and 'category'.

        Raises:
            NotImplementedError: If the QA type is not recognised.
        """
        row = self.annotations.iloc[index]
        question = row['question']
        category = row['category']
        answer = row['answer']

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"Answer the following question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Use you commonsense knowledge to be able to answer the question."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the row.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(index)

            # Correct answer plus at most three distractors, in random order.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }

def initialize_client():
    """Create the ``ModelClient`` used for all generation requests.

    Endpoint, temperature, nucleus top-p and maximum generation length come
    from ``Config``; the seed is fixed to 0 and formatting is enabled.
    """
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries, e):
    """Query the model for one QA item and package its answer.

    Args:
        batch: Item dict with keys 'question', 'question_answer' (the prompt
            text sent to the model), 'category' and 'answer'.
        sampling_rate: Not used by this function; accepted so the signature
            matches what the caller passes.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Collection of ``(e, question)`` pairs that already
            have a stored answer.
        e: Identifier stored under "V" in the result; the caller passes the
            item's position in the data loader.

    Returns:
        A dict with keys "V", "Q", "QA", "A", "C", "M", or ``None`` when the
        item already has a stored answer.
    """
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip items that were answered in a previous run. Without this check a
    # resumed run queried every item again and appended duplicate entries.
    if (e, question) in existing_entries:
        return None

    print(question_answer)
    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER),
        )
    ]

    # Retry until one request streams to completion.
    while True:
        try:
            # The answer is rebuilt from scratch on every attempt so chunks
            # received before a mid-stream failure cannot leak into the
            # final answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        # The exception must not be bound to `e`: Python deletes the
        # `except ... as` target when the handler ends, which would unbind
        # the `e` parameter and make the return below fail after any retry.
        except Exception as err:
            print(f"Error generating stream for: {err}. Retrying in 30 seconds...")
            time.sleep(30)

    return {
        "V": e,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=4):
    """Run `process_qa_item` over every batch in a thread pool and save results.

    Entries already stored in `output_file_path` are kept, and their
    ("V", "Q") pairs are handed to the workers. Results are written to the
    same file as a JSON list every 100 entries and once more at the end.

    Args:
        data_loader: Iterable of batches; only the first element of each
            batch (``batch[0]``) is processed, and its position in the loader
            is passed to the worker as the item identifier.
        output_file_path: JSON results file, read for resuming and
            overwritten on every save.
        sampling_rate: Forwarded unchanged to `process_qa_item`.
        client: Forwarded unchanged to `process_qa_item`.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                model_response = existing_data
                existing_entries = {(entry["V"], entry["Q"]) for entry in existing_data}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        batches = list(data_loader)
        futures = {
            executor.submit(process_qa_item, batch[0], sampling_rate, client, existing_entries, idx): idx
            for idx, batch in enumerate(batches)
        }

        for future in tqdm(as_completed(futures), total=len(batches), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as err:
                # future.result() re-raises whatever the worker raised. One
                # failing item must not abort the run and throw away the
                # answers collected since the last checkpoint; the item is
                # not stored, so it stays unanswered in the output file.
                print(f"Item {futures[future]} failed: {err!r}. Skipping it.")
                continue
            if not result:
                continue

            model_response.append(result)
            existing_entries.add((result["V"], result["Q"]))

            # Periodically save to prevent data loss
            if len(model_response) % 100 == 0:
                with gfile.GFile(output_file_path, 'w') as fi:
                    json.dump(model_response, fi)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)
        print(f"Final results saved to {output_file_path}")


def main():
    """Entry point: run inference once per configured sampling rate.

    For each value in ``Config.SAMPLING_RATES`` a fresh dataset/loader pair is
    built from the EgoThink data directory using the QA and batching settings
    in ``Config``, and the answers are written to a per-rate JSON file inside
    ``Config.OUTPUT_DIR``.
    """
    client = initialize_client()

    # Create the output directory on first use.
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for sampling_rate in Config.SAMPLING_RATES:
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        result_path = os.path.join(Config.OUTPUT_DIR, result_name)

        qa_dataset = QADataset(
            data_dir='/x20/users/cp/cplizzari/EgoThink/EgoThink/',
            qa_type=Config.QA_TYPE,
            CloseQA_weight=Config.CLOSE_QA_WEIGHT,
        )
        loader = DataLoader(
            qa_dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA,
        )

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=loader,
            output_file_path=result_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS,
        )

if __name__ == "__main__":
    main()


ValueError: No objects to concatenate

In [None]:
import os
import glob
import pandas as pd
from google3.pyglib import gfile

class Config:
    """Run settings for this cell, kept as class-level constants."""

    # --- Model endpoint and decoding parameters ---
    # Alternative endpoints, currently disabled:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:goldfish_shared'
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Storage location ---
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/EgoThink_commonsense/'

    # --- Question format ---
    QA_TYPE = 'OpenQA'  # one of 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # only relevant when QA_TYPE is 'Mixed'

    # --- Execution ---
    SAMPLING_RATES = [1.5]
    MAX_WORKERS = 8  # number of worker threads
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Result file name pattern (formatted with sampling_rate) ---
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'

class BaseDataset:
    """Loads every ``.parquet`` file found one directory level below `data_dir`."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.annotations = self.load_annotations(data_dir)

    def load_annotations(self, data_dir):
        """Read all parquet files in the sub-directories of `data_dir`.

        Each file's rows get a 'category' column holding the name of the
        sub-directory the file came from. Files that fail to load are
        reported and skipped.

        Returns:
            One DataFrame with the rows of all readable files concatenated.

        Raises:
            ValueError: If no parquet file was found or none could be read.
        """
        # Collect the .parquet files sitting directly inside each sub-directory.
        parquet_files = []
        for subdir in gfile.ListDir(data_dir):
            subdir_path = os.path.join(data_dir, subdir)
            if not gfile.IsDirectory(subdir_path):
                continue
            parquet_files.extend(
                os.path.join(subdir_path, name)
                for name in gfile.ListDir(subdir_path)
                if name.endswith('.parquet')
            )

        data_frames = []
        for path in parquet_files:
            try:
                df = pd.read_parquet(path)
                # The parent directory name doubles as the question category.
                df['category'] = os.path.basename(os.path.dirname(path))
                data_frames.append(df)
            except Exception as e:
                print(f"Error reading {path}: {e}")

        # pd.concat raises a bare "No objects to concatenate" on an empty
        # list; fail with the same exception type but say what was looked for.
        if not data_frames:
            raise ValueError(
                f"No objects to concatenate: {len(parquet_files)} .parquet file(s) "
                f"found under {data_dir!r} and none could be read."
            )

        full_df = pd.concat(data_frames, ignore_index=True)
        print(full_df.keys())
        print(f"Total rows after concatenation: {len(full_df)}")
        return full_df

    def __len__(self):
        return len(self.annotations)

class QADataset(BaseDataset):
    """Turns annotation rows into text prompts for open or multiple-choice QA.

    Args:
        data_dir: Directory passed on to ``BaseDataset``.
        qa_type: ``'OpenQA'``, ``'CloseQA'`` or ``'Mixed'``. With ``'Mixed'``
            the type is drawn at random for every item access.
        CloseQA_weight: Relative weight (out of 100) of ``'CloseQA'`` when
            ``qa_type`` is ``'Mixed'``.
    """

    def __init__(self, data_dir, qa_type, CloseQA_weight=50):
        super().__init__(data_dir)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Row indices whose 'Answer_closed' value was missing or unparsable
        # and for which placeholder distractors were used.
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt record for row ``index``.

        Returns:
            dict with keys ``question_answer`` (full prompt), ``question``,
            ``answer`` (reference answer), ``task`` (the QA type actually
            used) and ``category``.

        Raises:
            NotImplementedError: If the QA type is not one of the supported
                values.
        """
        row = self.annotations.iloc[index]
        question = row['question']
        category = row['category']
        answer = row['answer']

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"Answer the following question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Use you commonsense knowledge to be able to answer the question."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the row.
                # The previous code appended an undefined `video_id` to an
                # undefined `wrong_files` here, which raised NameError.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(index)

            # At most four options, matching the four choice letters.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({self.choice_indices[idx]}) {choice}'
                for idx, choice in enumerate(choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }

def initialize_client():
    """Create the ``ModelClient`` for ``Config.MODEL_URL``.

    The default generation config uses seed 0, enables formatting, and takes
    temperature, nucleus top-p and maximum length from ``Config``.
    """
    formatting = model_client.FormattingOptions(enable_formatting=True)
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting,
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries, e):
    """Send one question to the model and return the result record.

    Args:
        batch: Item dict with keys ``question``, ``question_answer``,
            ``category`` and ``answer``.
        sampling_rate: Accepted for signature compatibility; not used here.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Set of ``(item index, question)`` pairs that already
            have a stored result; such items are skipped.
        e: Index of the item, stored under ``"V"`` in the result.

    Returns:
        dict with keys ``V``, ``Q``, ``QA``, ``A`` (model text), ``C``
        (reference answer) and ``M`` (category), or ``None`` when the item
        already has a stored result.
    """
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Resume support: do not query the model again for a stored result.
    if (e, question) in existing_entries:
        return None

    print(question_answer)
    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    ]

    # Retry until the stream completes. The handler variable must not be
    # called `e`: Python unbinds the `except ... as` name when the handler
    # ends, which would destroy the `e` parameter used in the result below.
    while True:
        try:
            # Rebuilt from scratch on every attempt so that chunks received
            # before a failure are not prepended to the retried answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as err:
            print(f"Error generating stream for item {e}: {err}. Retrying in 30 seconds...")
            time.sleep(30)

    return {
        "V": e,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=4):
    """Run ``process_qa_item`` over every item of ``data_loader`` in a thread pool.

    Results already stored in ``output_file_path`` are loaded first and kept.
    New results are appended, written back every 100 entries, and written
    once more at the end, including when a worker raises.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file holding the list of result records.
        sampling_rate: Forwarded unchanged to ``process_qa_item``.
        client: Model client forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                model_response = existing_data
                existing_entries = {(entry["V"], entry["Q"]) for entry in existing_data}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    def save(message):
        with gfile.GFile(output_file_path, 'w') as fi:
            json.dump(model_response, fi)
        print(message)

    # Flatten the batches so that every item is processed; taking only
    # batch[0] dropped the rest of each batch whenever batch_size > 1.
    # With batch_size == 1 the item index equals the former batch index.
    items = [item for batch in data_loader for item in batch]

    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(process_qa_item, item, sampling_rate, client, existing_entries, idx): idx
                for idx, item in enumerate(items)
            }

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
                result = future.result()
                if not result:
                    continue
                model_response.append(result)
                existing_entries.add((result["V"], result["Q"]))

                # Periodically save to prevent data loss
                if len(model_response) % 100 == 0:
                    save(f"Saved {len(model_response)} entries to {output_file_path}")
    finally:
        # Runs on failure too, so results gathered since the last periodic
        # save are not lost when a worker raises.
        save(f"Final results saved to {output_file_path}")


def main():
    """Run text-only inference once per entry of ``Config.SAMPLING_RATES``.

    Creates the output directory when missing, builds the dataset and loader,
    and writes each run to a result file named after its sampling rate.
    """
    client = initialize_client()

    # Make sure there is somewhere to write the results.
    output_dir_exists = gfile.Exists(Config.OUTPUT_DIR)
    if not output_dir_exists:
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for sampling_rate in Config.SAMPLING_RATES:
        # One result file per sampling rate.
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        output_file_path = os.path.join(Config.OUTPUT_DIR, result_name)

        # The dataset and loader are rebuilt for every run.
        dataset = QADataset(
            data_dir='/x20/users/cp/cplizzari/EgoThink/EgoThink/',
            qa_type=Config.QA_TYPE,
            CloseQA_weight=Config.CLOSE_QA_WEIGHT,
        )
        data_loader = DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA,
        )

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=data_loader,
            output_file_path=output_file_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS,
        )


if __name__ == "__main__":
    main()


ValueError: No objects to concatenate

# Only Text - Gemini


In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Settings for the text-only inference run in this cell."""

    # --- Paths ---------------------------------------------------------
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/'
    DATA_FILE = '/x20/users/cp/cplizzari/selected_wip_v7.csv'

    # --- Model ---------------------------------------------------------
    # Alternative endpoints kept for reference:
    # MODEL_URL = 'evergreen2:///mbns/iz/home/courier/alessiot/chiara_cvpr/lmroot:v2_s_dense_shared'
    # MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:goldfish_shared'
    MODEL_URL = 'evergreen:///mbns/iz/home/courier/alessiot/chiara_cvpr:/lmroot:blade:aip-serving-alessiot-gemini_mpp_2m'
    TEMPERATURE = 0
    TOP_P = 0.95
    MAX_LENGTH = 8192

    # --- Question format -----------------------------------------------
    QA_TYPE = 'OpenQA'  # One of 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT = 50  # Only consulted when QA_TYPE is 'Mixed'

    # --- Inference loop ------------------------------------------------
    SAMPLING_RATES = [1.0]
    MAX_WORKERS = 1  # Thread pool size
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # --- Output naming -------------------------------------------------
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'


class DataLoader:
    """Minimal batching iterator over an indexable dataset.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        batch_size: Number of items per batch; must be at least 1. The last
            batch may be shorter.
        shuffle: When true, the index order is shuffled once at construction
            and reused for every pass over the data.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Such a value made
            iteration yield empty batches forever.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Set here as well so next() works before the first iter() call.
        self.start = 0

    def __iter__(self):
        """Restart from the first batch and return the loader itself."""
        self.start = 0
        return self

    def __next__(self):
        """Return the next batch as a list of dataset items."""
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch


class BaseDataset:
    """Dataset backed by a single CSV annotation table."""

    def __init__(self, data_file):
        table = self.load_annotations(data_file)
        self.annotations = table

    def load_annotations(self, data_file):
        """Parse ``data_file`` with ``pandas.read_csv`` and return the DataFrame."""
        frame = pd.read_csv(data_file)
        return frame

    def __len__(self):
        """Number of annotation rows."""
        row_count = len(self.annotations)
        return row_count


class QADataset(BaseDataset):
    """Turns CSV annotation rows into text prompts for open or multiple-choice QA.

    Args:
        data_file: CSV file passed on to ``BaseDataset``. The rows are read
            through the columns ``Video UID``, ``Question``, ``Category``,
            ``Answer_open`` and, for closed QA, ``Answer_closed``.
        qa_type: ``'OpenQA'``, ``'CloseQA'`` or ``'Mixed'``. With ``'Mixed'``
            the type is drawn at random for every item access.
        CloseQA_weight: Relative weight (out of 100) of ``'CloseQA'`` when
            ``qa_type`` is ``'Mixed'``.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Video UIDs whose 'Answer_closed' value was missing or unparsable
        # and for which placeholder distractors were used.
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt record for row ``index``.

        Returns:
            dict with keys ``video_id``, ``question_answer`` (full prompt),
            ``question``, ``answer`` (reference answer), ``task`` (the QA
            type actually used) and ``category``.

        Raises:
            NotImplementedError: If the QA type is not one of the supported
                values.
        """
        row = self.annotations.iloc[index]
        video_id = row['Video UID']
        question = row['Question']
        category = row['Category']
        answer = str(row.get('Answer_open', ''))

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"These are frames from a video that I want to upload. "
                f"Use the visual cues to answer the question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Fall back to placeholder distractors and remember the video.
                # The list lives on the instance: the module-level
                # `wrong_files` used before is not defined in this cell, so
                # the handler itself raised NameError.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            # At most four options, matching the four choice letters.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({self.choice_indices[idx]}) {choice}'
                for idx, choice in enumerate(choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Return a ``ModelClient`` bound to ``Config.MODEL_URL``.

    Generation defaults: seed 0, formatting enabled, and temperature,
    nucleus top-p and maximum length read from ``Config``.
    """
    formatting_options = model_client.FormattingOptions(enable_formatting=True)
    sampling_config = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_config = model_client.make_token_generation_config(
        sampling_config=sampling_config,
        length=Config.MAX_LENGTH,
    )
    default_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting_options,
        token_generation=token_config,
    )
    new_client = model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=default_config,
    )
    return new_client


def process_qa_item(batch, sampling_rate, client, existing_entries):
    """Send one question to the model and return the result record.

    Args:
        batch: Item dict with keys ``video_id``, ``question``,
            ``question_answer``, ``category`` and ``answer``.
        sampling_rate: Accepted for signature compatibility; not used here.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Set of ``(video_id, question)`` pairs that already
            have a stored result; such items are skipped.

    Returns:
        dict with keys ``V``, ``Q``, ``QA``, ``A`` (model text), ``C``
        (reference answer) and ``M`` (category). ``None`` when the item was
        already processed or the model returned an empty answer.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        return None

    prompt = [
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    ]

    # Retry until the stream completes.
    while True:
        try:
            # Rebuilt from scratch on every attempt so that chunks received
            # before a failure are not prepended to the retried answer.
            text = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as err:
            # The message now states the delay that is actually applied.
            print(f"Error generating stream for {uid}: {err}. Retrying in 10 seconds...")
            time.sleep(10)

    # An empty answer is not recorded, so the item is attempted again on the
    # next run.
    if text == '':
        return None

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=1):
    """Run ``process_qa_item`` over every item of ``data_loader`` in a thread pool.

    Stored results with a non-empty answer are loaded from
    ``output_file_path`` and kept. New results are appended, written back
    every 50 entries, and written once more at the end, including when a
    worker raises.

    Args:
        data_loader: Iterable yielding batches (lists) of item dicts.
        output_file_path: JSON file holding the list of result records.
        sampling_rate: Forwarded unchanged to ``process_qa_item``.
        client: Model client forwarded to ``process_qa_item``.
        max_workers: Size of the thread pool.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                # Entries with an empty answer are dropped so they get retried.
                model_response = [entry for entry in existing_data if entry["A"] != ""]
                existing_entries = {(entry["V"], entry["Q"]) for entry in model_response}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    def save(message):
        with gfile.GFile(output_file_path, 'w') as fi:
            json.dump(model_response, fi)
        print(message)

    # Flatten the batches so that every item is processed; taking only
    # batch[0] dropped the rest of each batch whenever batch_size > 1.
    items = [item for batch in data_loader for item in batch]

    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(process_qa_item, item, sampling_rate, client, existing_entries): idx
                for idx, item in enumerate(items)
            }

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
                result = future.result()
                if not result:
                    continue
                model_response.append(result)
                existing_entries.add((result["V"], result["Q"]))

                # Periodically save to prevent data loss
                if len(model_response) % 50 == 0:
                    save(f"Saved {len(model_response)} entries to {output_file_path}")
    finally:
        # Runs on failure too, so results gathered since the last periodic
        # save are not lost when a worker raises.
        save(f"Final results saved to {output_file_path}")



def main():
    """Run text-only inference once per entry of ``Config.SAMPLING_RATES``.

    Creates the output directory when missing, loads the CSV dataset, and
    writes each run to a result file named after its sampling rate.
    """
    client = initialize_client()

    # Make sure there is somewhere to write the results.
    output_dir_exists = gfile.Exists(Config.OUTPUT_DIR)
    if not output_dir_exists:
        gfile.MakeDirs(Config.OUTPUT_DIR)

    for sampling_rate in Config.SAMPLING_RATES:
        # One result file per sampling rate.
        result_name = Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        output_file_path = os.path.join(Config.OUTPUT_DIR, result_name)

        # The dataset and loader are rebuilt for every run.
        dataset = QADataset(Config.DATA_FILE, Config.QA_TYPE, Config.CLOSE_QA_WEIGHT)
        data_loader = DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA,
        )

        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=data_loader,
            output_file_path=output_file_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS,
        )


if __name__ == "__main__":
    main()


Running bulk inference for sampling rate: 1.0


FileError: couldn't open file '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/results_1.0.json' with mode 'r'. error: '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/results_1.0.json: No read permission for user alessiot/alessiot, request_info:  on file/symlink: /users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/results_1.0.json [PERMISSION_DENIED]'., status=/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/results_1.0.json: No read permission for user alessiot/alessiot, request_info:  on file/symlink: /users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_text_only/results_1.0.json [PERMISSION_DENIED]

# Evaluate results

In [None]:
import json
import ast
from IPython import display
import re
import argparse
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
import os

# Module-level model client shared by every generate_response() call in this
# cell (the function reads this global rather than taking a client argument).
# Generation defaults: seed 0, temperature 0, nucleus top-p 0.95, formatting
# enabled, length 8192.
# NOTE(review): seed 0 with temperature 0 is presumably meant to make the
# judge's answers repeatable - confirm against the model_client docs.
client = model_client.ModelClient(
    model_url='evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:goldfish_shared',
    default_config=model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=model_client.make_token_generation_config(
            sampling_config=model_client.make_sampling_config(
                temperature=0,
                nucleus_top_p=0.95,
            ),
            length=8192,
        ),
    ),
)

def parse_args():
    """Parse the four required command-line options of the evaluation script.

    Returns:
        ``argparse.Namespace`` with ``pred_paths``, ``cvrr_dataset_path``,
        ``output_dir`` and ``api_key``, each a string.
    """
    required_options = (
        ("--pred_paths", "Comma-separated list of paths to prediction files."),
        ("--cvrr_dataset_path", "Folder path to CVRR-ES dataset."),
        ("--output_dir", "The path to save prediction json files."),
        ("--api_key", "OpenAI API key."),
    )
    parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()

def _parse_judge_reply(completion):
    """Extract the verdict dictionary from the judge model's raw text reply.

    Returns the parsed dict, or ``None`` when no dictionary literal can be
    recovered from ``completion``.
    """
    # DOTALL lets the match span line breaks (replies are often pretty-printed
    # over several lines). The greedy pattern is tried first so that a '}'
    # inside the 'reason' text does not cut the dictionary short; the
    # non-greedy one is the fallback when extra braces trail the dictionary.
    for pattern in (r'\{.*\}', r'\{.*?\}'):
        match = re.search(pattern, completion, re.DOTALL)
        if match is None:
            continue
        try:
            parsed = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError, TypeError):
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def generate_response(single_dict):
    """
    Generates a response for a single QA pair using the model client.

    Args:
        single_dict: dict with keys ``q`` (question), ``a`` (ground-truth
            answer) and ``pred`` (predicted answer).

    Returns:
        The dictionary parsed from the judge model's reply (the prompt asks
        for keys ``pred``, ``score`` and ``reason``), or ``None`` when the
        reply could not be parsed.
    """
    question = single_dict['q']
    answer = single_dict['a']
    pred = single_dict['pred']

    prompt = [
        model_client.ContentChunk(
            value=f'''[
            {{
                "role": "system",
                "content":
                    "You are an intelligent chatbot designed for evaluating the correctness of AI assistant predictions for question-answer pairs. "
                    "Your task is to compare the predicted answer with the ground-truth answer and determine if the predicted answer is correct or not. Here's how you can accomplish the task:"
                    "------"
                    "##INSTRUCTIONS: "
                    "- Focus on the correctness and accuracy of the predicted answer with the ground-truth.\n"
                    "- Consider uncertain predictions, such as 'it is impossible to answer the question from the video', as incorrect, unless the ground truth answer also says that.\n"
            }},
            {{
                "role": "user",
                "content":
                    "Please evaluate the following video-based question-answer pair:\n\n"
                    f"Question: {question}\n"
                    f"Ground truth correct Answer: {answer}\n"
                    f"Predicted Answer: {pred}\n\n"
                    "Provide your evaluation as a correct/incorrect prediction along with the score where the score is an integer value between 0 (fully wrong) and 5 (fully correct). The middle score provides the percentage of correctness."
                    "Please generate the response in the form of a Python dictionary string with keys 'pred', 'score' and 'reason', where value of 'pred' is a string of 'correct' or 'incorrect', value of 'score' is in INTEGER, not STRING and value of 'reason' should provide the reason behind the decision."
                    "Only provide the Python dictionary string."
            }}
        ]''',
            mimetype='text/plain',
            substream_name='',
            metadata=model_client.Metadata(
                role=roles.ROLE_USER
            ),
        )
    ]

    try:
        while True:
            # Start every attempt from an empty string so text streamed
            # before a failure is not mixed into the retried reply.
            completion = ''
            try:
                for content in client.generate_stream(prompt):
                    completion += content.as_text()
                    display.clear_output()
                break
            except Exception as retry_error:
                # `Exception` rather than a bare except, so that
                # KeyboardInterrupt can still stop the retry loop.
                print(f"Error in generating stream: {retry_error}. Retrying in 30 seconds...")
                time.sleep(30)

        response_dict = _parse_judge_reply(completion)
        if response_dict is None:
            print(f"Could not parse an evaluation dictionary for question: {question}")
        return response_dict
    except Exception as stream_error:
        print(f"Error in generating stream: {stream_error}")
        return None

def annotate_bulk(prediction_set, output_dir, pred_path):
    """
    Evaluates question-answer pairs using bulk inference and multithreading.

    NOTE: a second ``annotate_bulk`` is defined right after this one in the
    same cell and replaces it at run time.

    Args:
        prediction_set: List of dicts accepted by ``generate_response``.
        output_dir: Base directory for the results.
        pred_path: Path of the prediction file being evaluated; its directory
            (with "/benchmark/" replaced by "/eval_results/") and base name
            determine where the results are written.
    """
    # Determine the output path by removing the .json extension and using the same directory
    base_filename = os.path.basename(pred_path).replace(".json", "")
    # NOTE: os.path.join discards earlier components when a later one is
    # absolute, so output_dir has no effect for an absolute pred_path.
    output_path = os.path.join(output_dir, os.path.dirname(pred_path).replace("/benchmark/", "/eval_results/"), base_filename)
    print(f"Output path: {output_path}")

    # Ensure the output directory exists
    if not gfile.Exists(output_path):
        gfile.MakeDirs(output_path)

    save_path = f"{output_path}/eval_results.json"
    result_qa_pair = []

    with ThreadPoolExecutor(max_workers=10) as executor:
        responses = list(tqdm(executor.map(generate_response, prediction_set), total=len(prediction_set)))

    for qa, response_dict in zip(prediction_set, responses):
        if not response_dict:
            continue
        try:
            if response_dict["pred"] == 'correct':
                # A 'correct' verdict is always stored with the top score.
                final_prediction = {'pred': 'correct', 'score': 5, 'reason': response_dict['reason']}
            else:
                final_prediction = {'pred': 'incorrect', 'score': response_dict['score'], 'reason': response_dict['reason']}
        except (KeyError, TypeError) as err:
            # A reply missing a key is skipped like a failed reply instead of
            # aborting the run before anything is saved.
            print(f"Skipping malformed evaluation for question {qa['q']!r}: {err!r}")
            continue

        result_qa_pair.append([final_prediction, qa])

    with gfile.GFile(save_path, "w") as f:
        json.dump(result_qa_pair, f)
def annotate_bulk(prediction_set, output_dir, pred_path):
    """
    Evaluates question-answer pairs using bulk inference and multithreading, skipping already processed pairs.

    Args:
        prediction_set: List of dicts accepted by ``generate_response``; the
            keys ``v`` and ``q`` identify an entry.
        output_dir: Base directory for the results.
        pred_path: Path of the prediction file being evaluated; its directory
            (with "/benchmark/" replaced by "/eval_results/") and base name
            determine where the results are written.
    """
    # Determine the output path by removing the .json extension and using the same directory
    base_filename = os.path.basename(pred_path).replace(".json", "")
    # NOTE: os.path.join discards earlier components when a later one is
    # absolute, so output_dir has no effect for an absolute pred_path.
    output_path = os.path.join(output_dir, os.path.dirname(pred_path).replace("/benchmark/", "/eval_results/"), base_filename)
    print(f"Output path: {output_path}")

    # Ensure the output directory exists
    if not gfile.Exists(output_path):
        gfile.MakeDirs(output_path)

    save_path = f"{output_path}/eval_results.json"
    result_qa_pair = []

    # Load existing results if available
    processed_questions = set()
    if gfile.Exists(save_path):
        with gfile.GFile(save_path, "r") as f:
            try:
                existing_results = json.load(f)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {save_path}. Starting with an empty result list.")
                existing_results = []
        for result, original_qa in existing_results:
            # Entries with an empty verdict are evaluated again, so they are
            # dropped here; keeping them would store the pair twice.
            if not result.get("pred"):
                continue
            # Use a tuple of (video, question) as a unique identifier to avoid duplicates
            processed_questions.add((original_qa['v'], original_qa['q']))
            result_qa_pair.append([result, original_qa])

    # Filter out entries that are already processed based on the (video, question) tuple
    to_process = [qa for qa in prediction_set if (qa['v'], qa['q']) not in processed_questions]
    print(f"Evaluating {len(to_process)} new entries (skipping {len(processed_questions)} already processed).")

    # Proceed with inference only for new entries
    with ThreadPoolExecutor(max_workers=5) as executor:
        responses = list(tqdm(executor.map(generate_response, to_process), total=len(to_process)))

    for qa, response_dict in zip(to_process, responses):
        if not response_dict:
            continue
        try:
            if response_dict["pred"] == 'correct':
                # A 'correct' verdict is always stored with the top score.
                final_prediction = {'pred': 'correct', 'score': 5, 'reason': response_dict['reason']}
            else:
                final_prediction = {'pred': 'incorrect', 'score': response_dict['score'], 'reason': response_dict['reason']}
        except (KeyError, TypeError) as err:
            # A reply missing a key is skipped like a failed reply (and is
            # retried on the next run) instead of aborting before the save.
            print(f"Skipping malformed evaluation for question {qa['q']!r}: {err!r}")
            continue

        result_qa_pair.append([final_prediction, qa])

    # Save the combined results (old + new)
    with gfile.GFile(save_path, "w") as f:
        json.dump(result_qa_pair, f)
    print(f"Results saved to {save_path}")

def main(pred_paths, cvrr_dataset_path, output_dir, api_key):
    """Evaluate every prediction file listed in ``pred_paths``.

    Each file is a JSON list of records with keys ``Q``, ``A``, ``C``, ``M``
    and ``V``; they are converted to the lower-case record format and handed
    to ``annotate_bulk``. ``cvrr_dataset_path`` and ``api_key`` are accepted
    but not read by this function.
    """
    source_keys = ('Q', 'A', 'C', 'M', 'V')

    for pred_path in pred_paths:
        with gfile.GFile(pred_path, "r") as f:
            json_pred = json.load(f)

        qa_pairs = []
        for record in json_pred:
            question, pred, answer, cat, video = (record[key] for key in source_keys)
            qa_pairs.append({"v": video, "q": question, "a": answer, "pred": pred, "cat": cat})

        print(f"Total number of QA pairs for {pred_path}: {len(qa_pairs)}")

        annotate_bulk(qa_pairs, output_dir, pred_path)
        print(f"Evaluation for {pred_path} completed!")

# Run settings for this cell - replace with your actual values.
# Prediction files to evaluate; main() processes them one by one in this order.
pred_paths = [
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_shuffled/results_0.1.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_shuffled/results_1.0.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_shuffled/results_0.5.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_random/results_0.1.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_random/results_1.0.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_random/results_0.5.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/gemini_flash_frames_prompt2_fixed/results_4.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/gemini_flash_frames_prompt2_fixed/results_8.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/gemini_flash_frames_prompt2_fixed/results_16.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/results_internlm_frames/results_32.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/results_internlm_frames/results_8.json',
    '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_pro_sampling/results_1.0.json'

]
cvrr_dataset_path = '.'  # Path to the CVRR dataset; passed to main(), which does not read it
output_dir = '/x20/users/cp/cplizzari/eval_results/benchmark_v7'  # Base directory for saving results
api_key = 'sk-your_api_key_here'  # Placeholder; passed to main(), which does not read it

# Runs the evaluation immediately when the cell is executed.
main(pred_paths, cvrr_dataset_path, output_dir, api_key)

# Evaluate - Human Eval


In [None]:
import json
import ast
from IPython import display
import re
import argparse
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
import os

# Module-level model client shared by every generate_response() call in this
# cell (the function reads this global rather than taking a client argument).
# Generation defaults: seed 0, temperature 0, nucleus top-p 0.95, formatting
# enabled, length 8192.
# NOTE(review): seed 0 with temperature 0 is presumably meant to make the
# judge's answers repeatable - confirm against the model_client docs.
client = model_client.ModelClient(
    model_url='evergreen://blade:gdm-aip-fastpath-agent-generate-service-prod/lmroot:goldfish_shared',
    default_config=model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=model_client.make_token_generation_config(
            sampling_config=model_client.make_sampling_config(
                temperature=0,
                nucleus_top_p=0.95,
            ),
            length=8192,
        ),
    ),
)

def parse_args():
    """Read the evaluation script's required options from the command line.

    Returns:
        ``argparse.Namespace`` carrying ``pred_paths``, ``cvrr_dataset_path``,
        ``output_dir`` and ``api_key`` as strings.
    """
    help_by_flag = {
        "--pred_paths": "Comma-separated list of paths to prediction files.",
        "--cvrr_dataset_path": "Folder path to CVRR-ES dataset.",
        "--output_dir": "The path to save prediction json files.",
        "--api_key": "OpenAI API key.",
    }
    cli = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    for flag in help_by_flag:
        cli.add_argument(flag, required=True, help=help_by_flag[flag])
    parsed = cli.parse_args()
    return parsed

def generate_response(single_dict):
    """Ask the judge model to grade one prediction and return its verdict.

    Args:
        single_dict: mapping with at least the keys ``'q'`` (question),
            ``'a'`` (ground-truth answer) and ``'pred'`` (predicted answer).

    Returns:
        The dict literal found in the model reply (the prompt asks for the
        keys ``'pred'``, ``'score'`` and ``'reason'``), or ``None`` when the
        reply contains no parsable dict.

    The request is retried every 30 seconds, without limit, until the stream
    completes; ``KeyboardInterrupt`` is not swallowed, so a stuck run can
    still be stopped by hand.
    """
    question = single_dict['q']
    answer = single_dict['a']
    pred = single_dict['pred']

    # NOTE(review): the inner `f"` prefixes below sit inside the outer f-string,
    # so the characters `f"` are sent to the model verbatim. The prompt text is
    # deliberately left unchanged so scores stay comparable with earlier runs.
    prompt = [
        model_client.ContentChunk(
            value=f'''[
            {{
                "role": "system",
                "content":
                    "You are an intelligent chatbot designed for evaluating the correctness of AI assistant predictions for question-answer pairs. "
                    "Your task is to compare the predicted answer with the ground-truth answer and determine if the predicted answer is correct or not. Here's how you can accomplish the task:"
                    "------"
                    "##INSTRUCTIONS: "
                    "- Focus on the correctness and accuracy of the predicted answer with the ground-truth.\n"
                    "- Consider uncertain predictions, such as 'it is impossible to answer the question from the video', as incorrect, unless the ground truth answer also says that.\n"
            }},
            {{
                "role": "user",
                "content":
                    "Please evaluate the following video-based question-answer pair:\n\n"
                    f"Question: {question}\n"
                    f"Ground truth correct Answer: {answer}\n"
                    f"Predicted Answer: {pred}\n\n"
                    "Provide your evaluation as a correct/incorrect prediction along with the score where the score is an integer value between 0 (fully wrong) and 5 (fully correct). The middle score provides the percentage of correctness."
                    "Please generate the response in the form of a Python dictionary string with keys 'pred', 'score' and 'reason', where value of 'pred' is a string of 'correct' or 'incorrect', value of 'score' is in INTEGER, not STRING and value of 'reason' should provide the reason behind the decision."
                    "Only provide the Python dictionary string."
            }}
        ]''',
            mimetype='text/plain',
            substream_name='',
            metadata=model_client.Metadata(
                role=roles.ROLE_USER
            ),
        )
    ]

    while True:
        try:
            # Rebuilt from scratch on every attempt so that text received from
            # a stream that failed half-way is never mixed into the retry.
            completion = ''.join(
                content.as_text() for content in client.generate_stream(prompt)
            )
            break
        except Exception as stream_error:
            print(f"Error in generating stream: {stream_error}. Retrying in 30 seconds...")
            time.sleep(30)

    # Greedy match with DOTALL: take everything from the first '{' to the last
    # '}', so multi-line dicts and braces inside the reason text are kept whole.
    match = re.search(r'\{.*\}', completion, re.DOTALL)
    if match is None:
        print(f"No dictionary found in model reply: {completion!r}")
        return None

    try:
        response_dict = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as parse_error:
        print(f"Error parsing model reply: {parse_error}")
        return None

    if not isinstance(response_dict, dict):
        print(f"Model reply is not a dictionary: {response_dict!r}")
        return None
    return response_dict


def _normalize_judgement(response_dict):
    """Convert a raw judge reply into the record stored on disk.

    Returns a dict with the keys 'pred', 'score' and 'reason', or None when
    the reply is not a dict or lacks the fields needed to build that record.
    """
    if not isinstance(response_dict, dict) or 'pred' not in response_dict:
        return None
    reason = response_dict.get('reason', '')
    # Case/whitespace-insensitive so a reply such as "Correct" is not
    # misclassified as incorrect.
    if str(response_dict['pred']).strip().lower() == 'correct':
        # A prediction judged correct always gets the full score.
        return {'pred': 'correct', 'score': 5, 'reason': reason}
    if 'score' not in response_dict:
        return None
    return {'pred': 'incorrect', 'score': response_dict['score'], 'reason': reason}


def annotate_bulk(prediction_set, output_dir, pred_path):
    """Grade QA pairs with the judge model, resuming from earlier results.

    Args:
        prediction_set: list of dicts with the keys 'v', 'q', 'a' and 'pred'.
        output_dir: base directory for the results (see the note on
            ``output_path`` below).
        pred_path: path of the prediction CSV; used to derive the output folder.

    Returns:
        A list of ``[verdict, qa_dict]`` pairs: previously saved results
        followed by the newly graded ones. The same list is written as JSON to
        ``<output_path>/eval_results.json``.

    Pairs whose (video, question) already has a non-empty verdict in the saved
    file are not sent to the model again. Replies that cannot be turned into a
    verdict are skipped (and counted) instead of aborting the run; they are
    picked up again on the next invocation.
    """
    # Output folder: pred_path's directory with "/benchmark/" replaced by
    # "/eval_results/", plus the CSV name without its extension.
    # NOTE: when pred_path is absolute, os.path.join discards output_dir, so
    # the folder is determined by pred_path alone.
    base_filename = os.path.basename(pred_path).replace(".csv", "")
    output_path = os.path.join(output_dir, os.path.dirname(pred_path).replace("/benchmark/", "/eval_results/"), base_filename)
    print(f"Output path: {output_path}")

    # Ensure the output directory exists
    if not gfile.Exists(output_path):
        gfile.MakeDirs(output_path)

    save_path = f"{output_path}/eval_results.json"
    result_qa_pair = []

    # Load existing results if available
    processed_questions = set()
    if gfile.Exists(save_path):
        with gfile.GFile(save_path, "r") as f:
            existing_results = json.load(f)
        for result, original_qa in existing_results:
            # Only entries with a non-empty verdict count as already processed.
            if result.get("pred") != "":
                # (video, question) uniquely identifies a QA pair.
                processed_questions.add((original_qa['v'], original_qa['q']))
        result_qa_pair.extend(existing_results)

    # Filter out entries that are already processed based on the (video, question) tuple
    to_process = [qa for qa in prediction_set if (qa['v'], qa['q']) not in processed_questions]
    print(f"Evaluating {len(to_process)} new entries (skipping {len(processed_questions)} already processed).")

    def save_results():
        with gfile.GFile(save_path, "w") as f:
            json.dump(result_qa_pair, f)

    unusable_replies = 0
    try:
        with ThreadPoolExecutor(max_workers=5) as executor:
            # executor.map yields results in input order, so index `e` below
            # always refers to to_process[e].
            responses = executor.map(generate_response, to_process)
            for e, response_dict in enumerate(tqdm(responses, total=len(to_process))):
                final_prediction = _normalize_judgement(response_dict)
                if final_prediction is None:
                    unusable_replies += 1
                else:
                    result_qa_pair.append([final_prediction, to_process[e]])

                # Checkpoint every 50 inputs, whether or not this particular
                # reply was usable.
                if (e + 1) % 50 == 0:
                    save_results()
                    print(f"Progress saved to {save_path} at index {e + 1}")
    finally:
        # Save the combined results (old + new), even if the run was interrupted.
        save_results()
        print(f"Results saved to {save_path}")

    if unusable_replies:
        print(f"Skipped {unusable_replies} entries with missing or malformed judge replies.")
    return result_qa_pair
import csv
def main(pred_paths, cvrr_dataset_path, output_dir, api_key):
    """Load one prediction CSV and grade every row with the judge model.

    Args:
        pred_paths: path of a single CSV file (opened with gfile) that has the
            columns 'Video UID', 'Question', 'Answer_open', 'Response' and
            'Category'.
        cvrr_dataset_path: accepted for interface compatibility; not used here.
        output_dir: forwarded to annotate_bulk().
        api_key: accepted for interface compatibility; not used here.

    Returns:
        The list returned by annotate_bulk().

    Raises:
        KeyError: if the CSV header lacks one of the required columns.
    """
    required_columns = ('Video UID', 'Question', 'Answer_open', 'Response', 'Category')

    # Read the CSV file
    with gfile.Open(pred_paths, mode="r") as csv_file:
        csv_reader = csv.DictReader(csv_file)

        # Fail up front with the full list of missing columns rather than on
        # the first row that happens to touch one of them.
        header = csv_reader.fieldnames or []
        missing = [column for column in required_columns if column not in header]
        if missing:
            raise KeyError(f"Missing columns in {pred_paths}: {missing}")

        qa_pairs = [
            {
                "v": row['Video UID'],
                "q": row['Question'],
                "a": row['Answer_open'],
                "pred": row['Response'],
                "cat": row['Category'],
            }
            for row in csv_reader
        ]

    sum_qa_pairs = len(qa_pairs)
    print(f"Total number of QA pairs from {pred_paths}: {sum_qa_pairs}")

    result_qa_pair = annotate_bulk(qa_pairs, output_dir, pred_paths)
    print(f"Evaluation for {pred_paths} completed!")
    return result_qa_pair
# Replace these with your actual values
# Despite the plural name this is a single CSV path: main() opens it directly.
pred_paths ='/x20/users/cp/cplizzari/benchmark/benchmark_v7/human_eval_questions.csv'


# main() accepts cvrr_dataset_path and api_key but does not read them.
cvrr_dataset_path = '.'  # Specify the path to your CVRR dataset
output_dir = '/x20/users/cp/cplizzari/eval_results/benchmark_v7'  # Base directory for saving results
api_key = 'sk-your_api_key_here'  # Replace with your actual API key

# Grade every row of the CSV; the list of [verdict, qa] pairs is kept for
# inspection in later cells.
result_qa_pair = main(pred_paths, cvrr_dataset_path, output_dir, api_key)

Total number of QA pairs from /x20/users/cp/cplizzari/benchmark/benchmark_v7/human_eval_questions.csv: 549
Output path: /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions
Evaluating 548 new entries (skipping 1 already processed).


  0%|          | 0/548 [00:00<?, ?it/s]

2


  0%|          | 1/548 [00:03<36:01,  3.95s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 1


  0%|          | 2/548 [00:07<32:17,  3.55s/it]

3
4
5
6
7
8


  1%|▏         | 8/548 [00:15<15:06,  1.68s/it]

9
10
11
12
13
14


  3%|▎         | 14/548 [00:15<07:30,  1.18it/s]

15
16
17
18
19
20


  4%|▎         | 20/548 [00:16<04:40,  1.88it/s]

21


  4%|▍         | 21/548 [00:17<05:03,  1.73it/s]

22


  4%|▍         | 22/548 [00:18<04:53,  1.79it/s]

23


  4%|▍         | 23/548 [00:25<13:48,  1.58s/it]

24


  4%|▍         | 24/548 [00:53<53:02,  6.07s/it]

25


  5%|▍         | 26/548 [00:54<33:56,  3.90s/it]

26
27
28


  5%|▌         | 28/548 [00:58<28:49,  3.33s/it]

29
30
31


  6%|▌         | 31/548 [01:01<18:30,  2.15s/it]

32
33
34


  6%|▌         | 34/548 [01:01<11:35,  1.35s/it]

35
36
37
38


  7%|▋         | 38/548 [01:04<09:39,  1.14s/it]

39


  7%|▋         | 39/548 [01:07<10:56,  1.29s/it]

40


  7%|▋         | 40/548 [01:08<10:42,  1.26s/it]

41
42
43
44
45
46


  8%|▊         | 46/548 [01:10<06:19,  1.32it/s]

47


  9%|▊         | 47/548 [01:39<34:18,  4.11s/it]

48


  9%|▉         | 48/548 [01:39<29:36,  3.55s/it]

49


  9%|▉         | 49/548 [01:39<25:16,  3.04s/it]

50
51
52


  9%|▉         | 51/548 [01:44<22:41,  2.74s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 51
53
54
55
56
57
58
59
60


 11%|█         | 60/548 [01:44<07:19,  1.11it/s]

61
62
63


 11%|█▏        | 63/548 [01:47<07:29,  1.08it/s]

64
65


 12%|█▏        | 65/548 [01:48<06:47,  1.18it/s]

66
67
68
69


 13%|█▎        | 69/548 [01:49<05:14,  1.52it/s]

70


 13%|█▎        | 70/548 [02:50<58:24,  7.33s/it]

71


 13%|█▎        | 71/548 [02:51<51:32,  6.48s/it]

72
73


 13%|█▎        | 73/548 [02:55<41:01,  5.18s/it]

74
75
76


 14%|█▍        | 76/548 [02:55<25:52,  3.29s/it]

77


 14%|█▍        | 77/548 [02:56<22:19,  2.84s/it]

78


 14%|█▍        | 78/548 [02:56<19:14,  2.46s/it]

79


 14%|█▍        | 79/548 [02:57<15:49,  2.02s/it]

80
81


 15%|█▍        | 81/548 [02:59<12:56,  1.66s/it]

82


 15%|█▍        | 82/548 [03:01<14:02,  1.81s/it]

83


 15%|█▌        | 83/548 [03:03<14:13,  1.84s/it]

84
85


 16%|█▌        | 85/548 [03:03<09:11,  1.19s/it]

86
87


 16%|█▌        | 87/548 [03:04<06:34,  1.17it/s]

88


 16%|█▌        | 88/548 [03:04<05:32,  1.38it/s]

89


 16%|█▌        | 89/548 [03:05<05:10,  1.48it/s]

90


 17%|█▋        | 91/548 [03:11<11:55,  1.56s/it]

91
92
93


 17%|█▋        | 93/548 [03:39<51:06,  6.74s/it]

94
95
96


 18%|█▊        | 96/548 [03:44<32:29,  4.31s/it]

97


 18%|█▊        | 97/548 [03:45<28:12,  3.75s/it]

98
99
100
101
102


 20%|█▉        | 108/548 [03:49<07:34,  1.03s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 101
103
104
105
106
107
108
109
110


 20%|██        | 110/548 [03:52<08:20,  1.14s/it]

111


 20%|██        | 111/548 [03:57<11:01,  1.51s/it]

112
113
114


 21%|██        | 114/548 [03:59<09:27,  1.31s/it]

115
116


 21%|██        | 116/548 [05:05<1:04:02,  8.90s/it]

117
118


 22%|██▏       | 118/548 [05:06<49:02,  6.84s/it]  

119
120


 22%|██▏       | 120/548 [05:08<37:42,  5.29s/it]

121
122
123


 22%|██▏       | 123/548 [05:11<26:01,  3.67s/it]

124


 23%|██▎       | 124/548 [05:13<23:42,  3.36s/it]

125


 23%|██▎       | 125/548 [05:13<20:08,  2.86s/it]

126
127
128
129


 24%|██▎       | 129/548 [05:22<17:44,  2.54s/it]

130
131
132
133
134


 24%|██▍       | 134/548 [05:25<11:32,  1.67s/it]

135


 25%|██▍       | 135/548 [05:26<10:20,  1.50s/it]

136
137
138


 25%|██▌       | 138/548 [05:26<07:08,  1.05s/it]

139


 25%|██▌       | 139/548 [06:21<57:41,  8.46s/it]

140


 26%|██▌       | 140/548 [07:26<2:01:35, 17.88s/it]

141
142
143
144
145
146
147
148
149
150
151
152


 28%|██▊       | 151/548 [07:32<35:22,  5.35s/it]  

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 151
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180


 33%|███▎      | 180/548 [07:36<09:25,  1.54s/it]

181
182


 33%|███▎      | 182/548 [07:36<08:48,  1.44s/it]

183
184


 34%|███▎      | 184/548 [08:05<15:34,  2.57s/it]

185
186


 34%|███▍      | 186/548 [08:10<15:35,  2.58s/it]

187


 35%|███▍      | 191/548 [08:11<10:26,  1.76s/it]

188
189
190
191


 35%|███▌      | 192/548 [08:12<10:02,  1.69s/it]

192


 35%|███▌      | 193/548 [08:12<09:01,  1.53s/it]

193


 35%|███▌      | 194/548 [08:12<07:54,  1.34s/it]

194
195
196


 36%|███▌      | 197/548 [08:13<05:17,  1.11it/s]

197


 36%|███▌      | 198/548 [08:14<05:38,  1.03it/s]

198


 36%|███▋      | 199/548 [08:16<05:50,  1.00s/it]

199
200
201


 37%|███▋      | 201/548 [08:23<11:08,  1.93s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 201
202
203
204


 37%|███▋      | 205/548 [08:23<05:47,  1.01s/it]

205
206


 38%|███▊      | 207/548 [09:20<46:05,  8.11s/it]

207
208
209


 38%|███▊      | 210/548 [09:20<29:35,  5.25s/it]

210


 39%|███▊      | 211/548 [09:21<26:07,  4.65s/it]

211


 39%|███▊      | 212/548 [09:27<26:46,  4.78s/it]

212
213
214
215
216
217
218
219
220


 40%|████      | 221/548 [09:27<08:11,  1.50s/it]

221
222


 41%|████      | 223/548 [09:28<07:21,  1.36s/it]

223
224


 41%|████▏     | 227/548 [09:29<04:58,  1.08it/s]

225
226
227
228


 42%|████▏     | 229/548 [09:30<04:10,  1.27it/s]

229


 42%|████▏     | 230/548 [10:03<29:28,  5.56s/it]

230
231


 42%|████▏     | 232/548 [10:04<21:43,  4.13s/it]

232
233


 43%|████▎     | 234/548 [10:05<15:32,  2.97s/it]

234


 43%|████▎     | 235/548 [10:07<14:42,  2.82s/it]

235
236
237
238


 44%|████▎     | 239/548 [10:08<07:36,  1.48s/it]

239


 44%|████▍     | 240/548 [10:10<07:53,  1.54s/it]

240


 44%|████▍     | 241/548 [10:10<07:01,  1.37s/it]

241
242
243
244


 45%|████▌     | 247/548 [10:11<03:09,  1.59it/s]

245
246
247
248


 45%|████▌     | 249/548 [10:12<02:59,  1.67it/s]

249


 46%|████▌     | 250/548 [10:14<03:49,  1.30it/s]

250
251


 46%|████▌     | 251/548 [10:21<09:35,  1.94s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 251
252
253


 46%|████▋     | 254/548 [11:16<45:28,  9.28s/it]

254


 47%|████▋     | 255/548 [11:18<38:40,  7.92s/it]

255
256
257
258


 47%|████▋     | 259/548 [11:19<19:39,  4.08s/it]

259


 47%|████▋     | 260/548 [11:19<16:48,  3.50s/it]

260


 48%|████▊     | 261/548 [11:19<14:16,  2.99s/it]

261


 48%|████▊     | 262/548 [11:21<13:04,  2.74s/it]

262
263
264


 48%|████▊     | 265/548 [11:22<07:45,  1.64s/it]

265
266
267


 49%|████▉     | 269/548 [11:25<05:00,  1.08s/it]

268
269


 49%|████▉     | 271/548 [11:25<03:36,  1.28it/s]

270
271


 50%|████▉     | 272/548 [11:25<03:00,  1.53it/s]

272


 50%|████▉     | 273/548 [11:27<04:04,  1.12it/s]

273


 50%|█████     | 274/548 [11:28<03:46,  1.21it/s]

274
275
276


 51%|█████     | 277/548 [11:59<26:56,  5.97s/it]

277


 51%|█████     | 278/548 [11:59<21:39,  4.81s/it]

278


 51%|█████     | 279/548 [12:00<17:43,  3.95s/it]

279
280
281


 51%|█████▏    | 282/548 [12:01<09:19,  2.10s/it]

282


 52%|█████▏    | 283/548 [12:02<08:39,  1.96s/it]

283
284
285


 52%|█████▏    | 286/548 [12:03<05:29,  1.26s/it]

286
287


 53%|█████▎    | 288/548 [12:04<04:19,  1.00it/s]

288
289


 53%|█████▎    | 291/548 [12:05<02:51,  1.50it/s]

290


 53%|█████▎    | 292/548 [12:05<02:37,  1.62it/s]

291


 53%|█████▎    | 293/548 [12:06<02:38,  1.61it/s]

292
293


 54%|█████▍    | 295/548 [12:06<01:56,  2.17it/s]

294


 54%|█████▍    | 296/548 [12:08<02:57,  1.42it/s]

295


 55%|█████▍    | 299/548 [12:08<01:40,  2.47it/s]

296
297
298


 55%|█████▍    | 300/548 [13:10<56:50, 13.75s/it]

299
300


 55%|█████▍    | 301/548 [13:18<50:20, 12.23s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 301
301
302
303
304
305
306
307
308
309
310
311
312
313


 58%|█████▊    | 316/548 [13:19<07:53,  2.04s/it]

314
315
316


 58%|█████▊    | 318/548 [13:20<06:45,  1.76s/it]

317
318
319


 59%|█████▊    | 321/548 [14:22<26:42,  7.06s/it]

320
321


 59%|█████▉    | 323/548 [14:23<21:17,  5.68s/it]

322
323


 59%|█████▉    | 326/548 [14:23<14:04,  3.80s/it]

324
325


 60%|█████▉    | 328/548 [14:25<10:25,  2.84s/it]

326
327


 60%|██████    | 329/548 [14:26<08:26,  2.31s/it]

328
329
330


 61%|██████    | 332/548 [14:28<05:33,  1.54s/it]

331


 61%|██████    | 333/548 [14:28<04:41,  1.31s/it]

332
333
334
335


 61%|██████▏   | 337/548 [14:29<02:35,  1.36it/s]

336


 62%|██████▏   | 338/548 [14:30<02:40,  1.31it/s]

337
338
339
340


 62%|██████▏   | 342/548 [14:31<01:49,  1.87it/s]

341


 63%|██████▎   | 343/548 [14:32<01:57,  1.74it/s]

342


 63%|██████▎   | 344/548 [15:03<19:27,  5.72s/it]

343
344
345


 63%|██████▎   | 347/548 [15:05<12:01,  3.59s/it]

346


 64%|██████▎   | 348/548 [15:05<10:12,  3.06s/it]

347
348
349
350


 64%|██████▍   | 351/548 [15:16<10:44,  3.27s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 351
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365


 67%|██████▋   | 367/548 [15:18<02:39,  1.13it/s]

366


 67%|██████▋   | 368/548 [16:16<13:32,  4.52s/it]

367


 67%|██████▋   | 369/548 [16:16<12:29,  4.19s/it]

368


 68%|██████▊   | 370/548 [16:17<11:12,  3.78s/it]

369


 68%|██████▊   | 371/548 [16:18<10:11,  3.46s/it]

370
371
372
373
374
375


 69%|██████▉   | 380/548 [16:20<03:38,  1.30s/it]

376
377
378
379


 70%|██████▉   | 381/548 [16:20<03:14,  1.16s/it]

380
381


 70%|██████▉   | 383/548 [16:21<02:44,  1.00it/s]

382


 70%|███████   | 384/548 [16:26<04:11,  1.53s/it]

383
384
385
386
387
388


 71%|███████   | 390/548 [18:28<32:07, 12.20s/it]

389
390
391
392
393
394
395
396
397
398


 73%|███████▎  | 401/548 [18:38<13:20,  5.44s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 401
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413


 76%|███████▌  | 417/548 [19:21<08:32,  3.91s/it]

414
415
416


 77%|███████▋  | 420/548 [19:22<07:23,  3.47s/it]

417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438


 81%|████████  | 442/548 [20:14<05:01,  2.84s/it]

439
440
441
442
443
444
445
446
447
448


 82%|████████▏ | 451/548 [20:30<04:08,  2.56s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 451
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475


 87%|████████▋ | 479/548 [21:13<02:18,  2.01s/it]

476
477
478
479


 88%|████████▊ | 483/548 [21:14<02:00,  1.86s/it]

480


 88%|████████▊ | 484/548 [21:16<01:57,  1.84s/it]

481


 89%|████████▊ | 485/548 [21:16<01:51,  1.77s/it]

482


 89%|████████▉ | 490/548 [21:17<01:14,  1.28s/it]

483
484
485
486
487


 90%|████████▉ | 491/548 [21:18<01:10,  1.24s/it]

488
489


 90%|████████▉ | 493/548 [21:18<00:58,  1.06s/it]

490


 90%|█████████ | 494/548 [21:19<00:51,  1.04it/s]

491
492


 91%|█████████ | 496/548 [21:19<00:43,  1.21it/s]

493


 91%|█████████ | 497/548 [21:21<00:45,  1.13it/s]

494


 91%|█████████ | 498/548 [21:21<00:42,  1.17it/s]

495
496
497
498


 91%|█████████▏| 501/548 [21:36<02:03,  2.62s/it]

Progress saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json at index 501
499
500
501


 92%|█████████▏| 505/548 [22:24<04:56,  6.90s/it]

502


 92%|█████████▏| 506/548 [22:25<04:16,  6.11s/it]

503
504
505


 93%|█████████▎| 509/548 [22:29<02:44,  4.22s/it]

506
507
508


 93%|█████████▎| 512/548 [22:29<01:39,  2.78s/it]

509
510


 94%|█████████▍| 514/548 [22:30<01:13,  2.17s/it]

511
512


 94%|█████████▍| 516/548 [22:30<00:55,  1.72s/it]

513


 94%|█████████▍| 517/548 [22:31<00:46,  1.49s/it]

514
515


 95%|█████████▍| 519/548 [22:31<00:30,  1.06s/it]

516
517


 95%|█████████▌| 521/548 [22:33<00:26,  1.01it/s]

518


 95%|█████████▌| 522/548 [22:36<00:38,  1.49s/it]

519
520


 96%|█████████▌| 524/548 [22:38<00:28,  1.20s/it]

521
522


 96%|█████████▌| 526/548 [23:05<01:54,  5.22s/it]

523
524
525


 97%|█████████▋| 529/548 [23:07<01:04,  3.38s/it]

526
527


 97%|█████████▋| 531/548 [23:09<00:47,  2.79s/it]

528


 97%|█████████▋| 532/548 [23:12<00:43,  2.70s/it]

529
530


 97%|█████████▋| 534/548 [23:12<00:26,  1.86s/it]

531
532
533
534


 98%|█████████▊| 538/548 [23:14<00:12,  1.24s/it]

535
536
537
538
539


 99%|█████████▉| 543/548 [23:15<00:03,  1.32it/s]

540


 99%|█████████▉| 544/548 [23:16<00:02,  1.37it/s]

541


 99%|█████████▉| 545/548 [23:16<00:01,  1.52it/s]

542
543
544


100%|██████████| 548/548 [23:17<00:00,  2.55s/it]

545





Results saved to /x20/users/cp/cplizzari/eval_results/benchmark_v7/human_eval_questions/eval_results.json
Evaluation for /x20/users/cp/cplizzari/benchmark/benchmark_v7/human_eval_questions.csv completed!


# Random Sampling - Gemini Flash

In [None]:
import os
import json
import time
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Static settings for the random-sampling inference run in this cell."""

    # Directories and File Paths
    # main() creates OUTPUT_DIR if needed and writes one result file per
    # sampling rate into it.
    OUTPUT_DIR = '/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_random/'
    # Annotation CSV, loaded with pandas by the dataset classes.
    DATA_FILE = '/x20/users/cp/cplizzari/selected_wip_v7.csv'

    # Model Configuration
    MODEL_URL = 'evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:v2_s_dense_shared'
    # Passed to make_sampling_config() in initialize_client().
    TEMPERATURE = 0
    TOP_P = 0.95
    # Passed as `length` to make_token_generation_config().
    MAX_LENGTH = 8192

    # QA Configuration
    QA_TYPE = 'OpenQA'  # Options: 'OpenQA', 'CloseQA', 'Mixed'
    # Weight given to CloseQA when sampling the type; OpenQA gets 100 minus this.
    CLOSE_QA_WEIGHT = 50  # Used only if QA_TYPE is 'Mixed'

    # Inference Configuration
    # Each value is substituted into the frames directory path and into the
    # result file name.
    SAMPLING_RATES = [0.1, 0.5, 1.0]
    MAX_WORKERS = 8  # Number of threads for multithreading
    # Both are forwarded to DataLoader.
    BATCH_SIZE = 1
    SHUFFLE_DATA = False

    # Output File Naming
    # Formatted with sampling_rate=<rate> and joined onto OUTPUT_DIR.
    RESULT_FILE_TEMPLATE = 'results_{sampling_rate}.json'


class DataLoader:
    """Minimal batch iterator over an indexable, sized dataset.

    Iterating yields lists of up to ``batch_size`` items fetched with
    ``dataset[idx]``; the last batch may be shorter. The visiting order is
    fixed at construction time (shuffled once if ``shuffle`` is true), and
    every call to ``iter()`` restarts from the first batch.

    Args:
        dataset: object supporting ``len()`` and integer indexing.
        batch_size: number of items per batch; must be at least 1.
        shuffle: if true, visit the items in a random order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (such a value would
            otherwise make iteration never terminate).
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Initialised here so next() also works before the first iter() call.
        self.start = 0

    def __iter__(self):
        self.start = 0
        return self

    def __next__(self):
        total = len(self.indices)
        if self.start >= total:
            raise StopIteration
        end = min(self.start + self.batch_size, total)
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        # Advance only after the items were fetched successfully.
        self.start = end
        return batch


class BaseDataset:
    """Annotation table read from a CSV file.

    The table is stored on ``self.annotations`` and ``len()`` of the dataset
    is the length of that table.
    """

    def __init__(self, data_file):
        # Loading goes through load_annotations() so that subclasses can
        # replace how the file is read.
        loaded = self.load_annotations(data_file)
        self.annotations = loaded

    def load_annotations(self, data_file):
        """Read ``data_file`` with ``pandas.read_csv`` and return the result."""
        table = pd.read_csv(data_file)
        return table

    def __len__(self):
        """Return the number of entries in the annotation table."""
        table = self.annotations
        return len(table)


class QADataset(BaseDataset):
    """Question-answering view over the annotation CSV.

    Each item is a dict with the prompt text for the model plus the metadata
    needed to store the result (see ``__getitem__``).

    Args:
        data_file: CSV path handed to ``BaseDataset``.
        qa_type: ``'OpenQA'``, ``'CloseQA'`` or ``'Mixed'``.
        CloseQA_weight: weight (out of 100) of CloseQA when ``qa_type`` is
            ``'Mixed'``; OpenQA receives the remainder.

    Attributes:
        wrong_files: video ids whose ``'Answer_closed'`` cell could not be
            parsed and for which placeholder distractors were used.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Kept on the instance: the previous code appended to a module-level
        # `wrong_files` that this cell never defines, which raised NameError
        # the first time a row failed to parse.
        self.wrong_files = []

    def __getitem__(self, index):
        """Build the prompt and metadata for the annotation row at ``index``.

        Returns:
            dict with the keys 'video_id', 'question_answer' (full prompt
            text), 'question', 'answer', 'task' (the QA type actually used)
            and 'category'.

        Raises:
            NotImplementedError: if the QA type is not one of the supported ones.
        """
        row = self.annotations.iloc[index]
        video_id = row['Video UID']
        question = row['Question']
        category = row['Category']
        answer = str(row.get('Answer_open', ''))

        qa_type = self.qa_type
        if qa_type == 'Mixed':
            # Draw the concrete type for this item according to the weights.
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"These are frames from a video that I want to upload. "
                f"Use the visual cues to answer the question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            try:
                wrong_answers = ast.literal_eval(row['Answer_closed'])
            except (ValueError, SyntaxError, KeyError):
                # Missing or unparsable distractors: fall back to placeholders
                # and remember the video for later inspection.
                wrong_answers = ['A', 'B', 'C']
                self.wrong_files.append(video_id)

            # Correct answer plus at most three distractors, in random order.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            answer_index = choices.index(answer)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = choices[answer_index]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Create the model client described by ``Config``.

    The generation config uses seed 0 with formatting enabled; temperature,
    nucleus top-p and length are taken from ``Config.TEMPERATURE``,
    ``Config.TOP_P`` and ``Config.MAX_LENGTH``.

    Returns:
        model_client.ModelClient bound to ``Config.MODEL_URL``.
    """
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=model_client.FormattingOptions(enable_formatting=True),
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries):
    """Run the model on the frames of one QA item and return its answer.

    Args:
        batch: one dataset item, a dict with the keys 'video_id', 'question',
            'question_answer', 'category' and 'answer'.
        sampling_rate: value substituted into the frames directory path.
        client: model client exposing ``generate_stream(prompt)``.
        existing_entries: collection of (video_id, question) pairs that
            already have a stored answer.

    Returns:
        dict with the keys "V" (video id), "Q" (question), "QA" (prompt text),
        "A" (model answer), "C" (reference answer) and "M" (category), or
        ``None`` when the item was already processed or its frames directory
        does not exist.

    The model call is retried every 30 seconds, without limit, until it
    succeeds.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        print(f"Skipping existing entry for video ID: {uid} and question: {question}")
        return None

    frames_dir = f'/x20/users/cp/cplizzari/random_sampling_temporal_sampling_rate_v6/_{sampling_rate}/{uid}_{question}_{sampling_rate}_frames'

    if not gfile.Exists(frames_dir):
        print(f"Frames directory does not exist: {frames_dir}")
        return None

    # One image chunk per frame, in sorted file-name order, followed by the
    # text question.
    prompt = []
    for image_name in sorted(gfile.ListDir(frames_dir)):
        # Context manager so every frame file is closed after reading.
        with gfile.Open(os.path.join(frames_dir, image_name), 'rb') as image_file:
            image_bytes = image_file.read()
        prompt.append(
            model_client.ContentChunk(
                value=image_bytes,
                mimetype='image/png',
                metadata=model_client.Metadata(role=roles.ROLE_USER)
            )
        )

    prompt.append(
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    )

    while True:
        try:
            # Rebuilt on every attempt so that text received from a stream
            # that failed half-way is not prepended to the retried answer.
            text = ''.join(content.as_text() for content in client.generate_stream(prompt))
            break  # Exit loop if successful
        except Exception as e:
            print(f"Error generating stream for {uid}: {e}. Retrying in 30 seconds...")
            time.sleep(30)

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=4):
    """Run process_qa_item() over a data loader with a thread pool.

    Results already stored in ``output_file_path`` are loaded first and their
    (video, question) pairs are passed to the workers so those items are
    skipped. New results are appended, written out every time the total
    reaches a multiple of 100, and written once more at the end (also when
    the run is interrupted).

    Args:
        data_loader: iterable of batches (lists of dataset items). Every item
            of every batch is processed.
        output_file_path: JSON file holding the list of result dicts.
        sampling_rate: forwarded to process_qa_item().
        client: forwarded to process_qa_item().
        max_workers: size of the thread pool.

    An item whose worker raises is reported and left out of the results, so a
    later run can retry it; it does not abort the remaining items.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
            else:
                model_response = existing_data
                existing_entries = {(entry["V"], entry["Q"]) for entry in existing_data}

    # Flatten the batches: the previous code submitted only batch[0], which
    # silently dropped items whenever the batch size was above 1.
    items = [item for batch in data_loader for item in batch]
    failed_items = 0

    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_qa_item, item, sampling_rate, client, existing_entries)
                for item in items
            ]

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
                try:
                    result = future.result()
                except Exception as e:
                    failed_items += 1
                    print(f"Worker failed, item left for a later run: {e}")
                    continue

                if result:
                    model_response.append(result)
                    existing_entries.add((result["V"], result["Q"]))

                    # Periodically save to prevent data loss
                    if len(model_response) % 100 == 0:
                        with gfile.GFile(output_file_path, 'w') as fi:
                            json.dump(model_response, fi)
                            print(f"Saved {len(model_response)} entries to {output_file_path}")
    finally:
        # Final save after all processing (or after an interruption)
        with gfile.GFile(output_file_path, 'w') as fi:
            json.dump(model_response, fi)
            print(f"Final results saved to {output_file_path}")

    if failed_items:
        print(f"{failed_items} items failed and were not saved.")

def main():
    """Run bulk inference once for every rate in ``Config.SAMPLING_RATES``.

    For each rate the results go to
    ``Config.OUTPUT_DIR / Config.RESULT_FILE_TEMPLATE`` formatted with that
    rate. The model client and the dataset are created once and shared by all
    rates; a fresh DataLoader is built per rate.
    """
    # Initialize client
    client = initialize_client()

    # Ensure the output directory exists
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    # The annotation CSV does not depend on the sampling rate, so it is read
    # once instead of once per rate.
    dataset = QADataset(Config.DATA_FILE, Config.QA_TYPE, Config.CLOSE_QA_WEIGHT)

    for sampling_rate in Config.SAMPLING_RATES:
        # Set the output file path based on the current sampling rate
        output_file_path = os.path.join(
            Config.OUTPUT_DIR,
            Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        )

        data_loader = DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA
        )

        # Run bulk inference
        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=data_loader,
            output_file_path=output_file_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS
        )


# Run the whole sweep when this code executes as the main module.
# NOTE(review): in a notebook kernel __name__ is normally "__main__", so
# executing this cell starts the inference immediately — confirm that is intended.
if __name__ == "__main__":
    main()


# Shuffled frames - Gemini Flash


In [None]:
import os
import json
import time
import random
import random
import ast
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from google3.pyglib import gfile
from google3.learning.deepmind.evergreen.model_access.client.python import model_client
import google3.learning.gemini.format.python.roles as roles
from IPython import display


class Config:
    """Static settings for this inference run.

    Every setting is a class attribute and is read as ``Config.NAME``.
    """

    # --- Paths -------------------------------------------------------------
    # Directory that receives one results file per sampling rate.
    OUTPUT_DIR: str = "/x20/users/cp/cplizzari/benchmark/benchmark_v7/uniform_gemini_flash_sampling_random/"
    # CSV file with the question/answer annotations.
    DATA_FILE: str = "/x20/users/cp/cplizzari/selected_wip_v7.csv"

    # --- Model -------------------------------------------------------------
    MODEL_URL: str = "evergreen2://blade:gdm-aip-agent-generate-service-prod-high-priority/lmroot:v2_s_dense_shared"
    TEMPERATURE: int = 0
    TOP_P: float = 0.95
    MAX_LENGTH: int = 8192

    # --- Question/answer format ----------------------------------------------
    QA_TYPE: str = "OpenQA"  # Options: 'OpenQA', 'CloseQA', 'Mixed'
    CLOSE_QA_WEIGHT: int = 50  # Used only if QA_TYPE is 'Mixed'

    # --- Inference loop ------------------------------------------------------
    SAMPLING_RATES: list = [0.1, 0.5, 1.0]
    MAX_WORKERS: int = 8  # Number of threads for multithreading
    BATCH_SIZE: int = 1
    SHUFFLE_DATA: bool = False

    # --- Output naming -------------------------------------------------------
    # Formatted with ``sampling_rate=...`` to get the results file name.
    RESULT_FILE_TEMPLATE: str = "results_{sampling_rate}.json"


class DataLoader:
    """Minimal batcher that walks an indexable dataset in fixed-size chunks.

    The dataset only needs ``__len__`` and integer ``__getitem__``. Index
    order is fixed at construction time: sequential, or shuffled once with
    the module-level ``random`` generator when ``shuffle`` is true. Each call
    to ``iter()`` restarts from the first batch using that same order.

    Args:
        dataset: Object supporting ``len(dataset)`` and ``dataset[i]``.
        batch_size: Number of items per batch; the last batch may be shorter.
        shuffle: Whether to shuffle the index order once at construction.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A non-positive size
            would otherwise never advance the cursor and iterate forever.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(self.indices)
        # Initialise the cursor here so next() works even without a prior iter().
        self.start = 0

    def __iter__(self):
        """Reset the cursor to the first batch and return ``self``."""
        self.start = 0
        return self

    def __next__(self):
        """Return the next batch as a list of dataset items.

        Raises:
            StopIteration: When every index has been consumed.
        """
        if self.start >= len(self.indices):
            raise StopIteration
        end = min(self.start + self.batch_size, len(self.indices))
        batch = [self.dataset[idx] for idx in self.indices[self.start:end]]
        self.start = end
        return batch


class BaseDataset:
    """Base dataset that loads a CSV annotation table and reports its length."""

    def __init__(self, data_file):
        """Load the annotations from ``data_file`` into ``self.annotations``."""
        loaded = self.load_annotations(data_file)
        self.annotations = loaded

    def load_annotations(self, data_file):
        """Read ``data_file`` with ``pandas.read_csv`` and return the result."""
        table = pd.read_csv(data_file)
        return table

    def __len__(self):
        """Return the length of the loaded annotation table."""
        return len(self.annotations)


class QADataset(BaseDataset):
    """Dataset that turns annotation rows into prompt/answer dictionaries.

    Each row must provide the columns ``'Video UID'``, ``'Question'`` and
    ``'Category'``; ``'Answer_open'`` is read with an empty-string default and
    ``'Answer_closed'`` is only read for closed (multiple-choice) questions.

    Args:
        data_file: CSV file passed to ``BaseDataset``.
        qa_type: ``'OpenQA'``, ``'CloseQA'`` or ``'Mixed'``. With ``'Mixed'``
            the type is drawn at random per item using the two weights.
        CloseQA_weight: Relative weight of ``'CloseQA'`` in ``'Mixed'`` mode;
            ``'OpenQA'`` gets ``100 - CloseQA_weight``.

    Attributes:
        wrong_files: Video UIDs whose ``'Answer_closed'`` cell was missing or
            could not be parsed into a list, in the order encountered.
    """

    def __init__(self, data_file, qa_type, CloseQA_weight=50):
        super().__init__(data_file)
        self.qa_type = qa_type
        self.choice_indices = ['A', 'B', 'C', 'D']
        self.CloseQA_weight = CloseQA_weight
        self.openqa_weight = 100 - CloseQA_weight
        # Previously the fallback path appended to an undefined global named
        # ``wrong_files`` and crashed with NameError; track it per instance.
        self.wrong_files = []

    def _load_wrong_answers(self, row, video_id):
        """Return the distractor answers for ``row`` as a list.

        The ``'Answer_closed'`` cell is parsed with ``ast.literal_eval``. If
        the column is missing, the cell is not a valid literal, or the literal
        is not a list/tuple, placeholder distractors are returned and
        ``video_id`` is recorded in ``self.wrong_files``.
        """
        try:
            parsed = ast.literal_eval(row['Answer_closed'])
            if not isinstance(parsed, (list, tuple)):
                raise ValueError(f"expected a list of answers, got {type(parsed).__name__}")
            return list(parsed)
        except (ValueError, SyntaxError, KeyError, TypeError):
            self.wrong_files.append(video_id)
            return ['A', 'B', 'C']

    def __getitem__(self, index):
        """Build the prompt dictionary for the row at position ``index``.

        Returns:
            A dict with keys ``'video_id'``, ``'question_answer'`` (the full
            prompt text), ``'question'``, ``'answer'`` (the reference answer
            text), ``'task'`` (the QA type actually used) and ``'category'``.

        Raises:
            NotImplementedError: If the QA type is not one of the supported
                values.
        """
        row = self.annotations.iloc[index]
        video_id = row['Video UID']
        question = row['Question']
        category = row['Category']
        answer = str(row.get('Answer_open', ''))

        qa_type = self.qa_type
        if self.qa_type == 'Mixed':
            # Draw the concrete type for this item according to the weights.
            qa_type = random.choices(
                ['CloseQA', 'OpenQA'],
                weights=[self.CloseQA_weight, self.openqa_weight],
                k=1
            )[0]

        if qa_type == 'OpenQA':
            question_str = (
                f"These are frames from a video that I want to upload. "
                f"Use the visual cues to answer the question: {question}. "
                f"You need to answer the question in any case and not demand additional context information. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            answer_str = answer
        elif qa_type == 'CloseQA':
            wrong_answers = self._load_wrong_answers(row, video_id)

            # Correct answer plus at most three distractors, in random order.
            choices = [answer] + wrong_answers[:3]
            random.shuffle(choices)
            choices_str = ' '.join(
                f'({letter}) {choice}'
                for letter, choice in zip(self.choice_indices, choices)
            )
            question_str = (
                f"Question: {question} Choices: {choices_str}. "
                f"Please answer by returning only the letter that corresponds to the correct answer, in the form [LETTER]. "
                f"Note: All actions mentioned refer to the person recording the video."
            )
            # The reference stays the answer text (not its letter), as before.
            answer_str = choices[choices.index(answer)]
        else:
            raise NotImplementedError(f"QA type '{qa_type}' is not implemented.")

        return {
            'video_id': video_id,
            'question_answer': question_str,
            'question': question,
            'answer': answer_str,
            'task': qa_type,
            'category': category
        }


def initialize_client():
    """Create the ``ModelClient`` for ``Config.MODEL_URL``.

    The default generation config uses seed 0, formatting enabled, and the
    temperature, nucleus top-p and length taken from ``Config``.

    Returns:
        The configured ``model_client.ModelClient`` instance.
    """
    formatting = model_client.FormattingOptions(enable_formatting=True)
    sampling = model_client.make_sampling_config(
        temperature=Config.TEMPERATURE,
        nucleus_top_p=Config.TOP_P,
    )
    token_generation = model_client.make_token_generation_config(
        sampling_config=sampling,
        length=Config.MAX_LENGTH,
    )
    generation_config = model_client.make_generation_config(
        seed=0,
        formatting_options=formatting,
        token_generation=token_generation,
    )
    return model_client.ModelClient(
        model_url=Config.MODEL_URL,
        default_config=generation_config,
    )


def process_qa_item(batch, sampling_rate, client, existing_entries,
                    max_retries=None, retry_delay=30):
    """Query the model for one QA item using its frames in random order.

    The frames are read from a directory derived from the video id, question
    and sampling rate, shuffled with the module-level ``random`` generator,
    and sent to the model followed by the question text.

    Args:
        batch: Item dict with keys ``'video_id'``, ``'question'``,
            ``'question_answer'``, ``'category'`` and ``'answer'``.
        sampling_rate: Sampling rate used to locate the frames directory.
        client: Object exposing ``generate_stream(prompt)``.
        existing_entries: Container of ``(video_id, question)`` pairs that
            already have a result; such items are skipped.
        max_retries: Maximum number of retries after a failed generation.
            ``None`` (the default) retries indefinitely.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        A dict with keys ``"V"`` (video id), ``"Q"`` (question), ``"QA"``
        (prompt text), ``"A"`` (model output), ``"C"`` (reference answer) and
        ``"M"`` (category); or ``None`` when the item is skipped, its frames
        directory is missing, or the retry budget is exhausted.
    """
    uid = batch['video_id']
    question = batch['question']
    question_answer = batch['question_answer']
    category = batch['category']
    answer = batch['answer']

    # Skip processing if entry already exists
    if (uid, question) in existing_entries:
        print(f"Skipping existing entry for video ID: {uid} and question: {question}")
        return None

    frames_dir = f'/x20/users/cp/cplizzari/random_sampling_temporal_sampling_rate_v6/_{sampling_rate}/{uid}_{question}_{sampling_rate}_frames'

    if not gfile.Exists(frames_dir):
        print(f"Frames directory does not exist: {frames_dir}")
        return None

    image_paths = gfile.ListDir(frames_dir)

    # Build the prompt from the frames in shuffled order. Each file is opened
    # in a ``with`` block so its handle is released as soon as it is read.
    prompt = []
    for image_path in random.sample(image_paths, len(image_paths)):
        with gfile.Open(os.path.join(frames_dir, image_path), 'rb') as frame_file:
            frame_bytes = frame_file.read()
        prompt.append(
            model_client.ContentChunk(
                value=frame_bytes,
                mimetype='image/png',
                metadata=model_client.Metadata(role=roles.ROLE_USER)
            )
        )

    prompt.append(
        model_client.ContentChunk(
            value=question_answer,
            mimetype='text/plain',
            metadata=model_client.Metadata(role=roles.ROLE_USER)
        )
    )

    attempt = 0
    while True:
        attempt += 1
        # Start from an empty buffer on every attempt: a stream that fails
        # midway must not leave partial text in front of the retried answer.
        pieces = []
        try:
            for content in client.generate_stream(prompt):
                pieces.append(content.as_text())
        except Exception as e:
            if max_retries is not None and attempt > max_retries:
                print(f"Error processing {uid}: {e}")
                return None
            print(f"Error generating stream for {uid}: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            text = ''.join(pieces)
            break

    return {
        "V": uid,
        "Q": question,
        "QA": question_answer,
        "A": text,
        "C": answer,
        "M": category
    }


def _dump_responses(output_file_path, model_response):
    """Overwrite ``output_file_path`` with ``model_response`` serialised as JSON."""
    with gfile.GFile(output_file_path, 'w') as fi:
        json.dump(model_response, fi)


def perform_bulk_inference(data_loader, output_file_path, sampling_rate, client, max_workers=4):
    """
    Perform bulk inference using multithreading.

    Results already stored in ``output_file_path`` are loaded first so that
    their ``(video id, question)`` pairs are skipped. Every item of every
    batch is submitted to a thread pool running ``process_qa_item``; new
    results are appended to the loaded list, checkpointed whenever the list
    length is a multiple of 100, and written once more at the end.

    Args:
        data_loader: Iterable yielding batches (lists) of QA item dicts.
        output_file_path: JSON file to resume from and write results to.
        sampling_rate: Sampling rate forwarded to ``process_qa_item``.
        client: Model client forwarded to ``process_qa_item``.
        max_workers: Number of worker threads.
    """
    model_response = []
    existing_entries = set()

    # Load existing responses if the output file exists
    if gfile.Exists(output_file_path):
        with gfile.GFile(output_file_path, 'r') as fi:
            try:
                existing_data = json.load(fi)
                model_response = existing_data
                existing_entries = {(entry["V"], entry["Q"]) for entry in existing_data}
            except json.JSONDecodeError:
                print(f"JSON decode error for file {output_file_path}. Starting with an empty response.")
                model_response = []
                existing_entries = set()

    # Flatten the batches so that every item is processed, not only the
    # first element of each batch.
    items = [item for batch in data_loader for item in batch]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_qa_item, item, sampling_rate, client, existing_entries)
            for item in items
        ]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing QA Items"):
            try:
                result = future.result()
            except Exception as e:
                # A failing item must not abort the run and discard the
                # results collected since the last checkpoint.
                print(f"Error while processing a QA item: {e}")
                continue
            if result:
                model_response.append(result)
                existing_entries.add((result["V"], result["Q"]))

                # Periodically save to prevent data loss
                if len(model_response) % 100 == 0:
                    _dump_responses(output_file_path, model_response)
                    print(f"Saved {len(model_response)} entries to {output_file_path}")

    # Final save after all processing
    _dump_responses(output_file_path, model_response)
    print(f"Final results saved to {output_file_path}")



def main():
    """Run bulk inference once for every rate in ``Config.SAMPLING_RATES``.

    Creates the model client, ensures ``Config.OUTPUT_DIR`` exists, loads the
    annotation dataset once, and then calls ``perform_bulk_inference`` per
    sampling rate with a fresh data loader and a per-rate results file.
    """
    # Initialize client
    client = initialize_client()

    # Ensure the output directory exists
    if not gfile.Exists(Config.OUTPUT_DIR):
        gfile.MakeDirs(Config.OUTPUT_DIR)

    # The dataset does not depend on the sampling rate, so read the CSV once
    # instead of re-loading it on every loop iteration.
    dataset = QADataset(Config.DATA_FILE, Config.QA_TYPE, Config.CLOSE_QA_WEIGHT)

    for sampling_rate in Config.SAMPLING_RATES:
        # Set the output file path based on the current sampling rate
        output_file_path = os.path.join(
            Config.OUTPUT_DIR,
            Config.RESULT_FILE_TEMPLATE.format(sampling_rate=sampling_rate)
        )

        # A new loader per rate keeps the per-run shuffle behaviour unchanged.
        data_loader = DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=Config.SHUFFLE_DATA
        )

        # Run bulk inference
        print(f"Running bulk inference for sampling rate: {sampling_rate}")
        perform_bulk_inference(
            data_loader=data_loader,
            output_file_path=output_file_path,
            sampling_rate=sampling_rate,
            client=client,
            max_workers=Config.MAX_WORKERS
        )


# Entry point: call main() only when this code runs as the top-level script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
