In [1]:
# 240630

from glob import glob
import pickle
import os
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import product
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
import vertexai.preview.generative_models as generative_models
import json
from torch.utils.data import Dataset, DataLoader
from format import *

# Initialise the Vertex AI SDK once for the whole notebook. The project ID and
# region are hard-coded here; edit them when running under a different project.
vertexai.init(project='phonic-impact-421908', location="asia-northeast3")

In [2]:
class ModelResponder:
    """Send exam questions to a Vertex AI generative model and save the answers.

    Every combination of (question file, prompt function, context) gets its
    own JSON output file under ``output/``. Questions are processed in
    batches of ``self.batch_size``; each batch is fanned out over a thread
    pool and the output file is rewritten after every batch.
    """

    def __init__(self, model_path, ques_path_list, context_list, prompt_func_list=None, path=None):
        """Store the run configuration and build the model client.

        Args:
            model_path: Model identifier passed to ``GenerativeModel``.
            ques_path_list: Paths of JSON files containing the questions.
            context_list: Context strings; each is sent ahead of the question.
            prompt_func_list: Callables taking ``(inst, ques)``. When empty or
                ``None``, a single default that returns the question text is
                used.
            path: Label used in output file names. Defaults to the basename
                of ``model_path``.
        """
        self.batch_size = 350
        self.model_path = model_path
        self.path = path if path is not None else os.path.basename(model_path)

        if not prompt_func_list:
            def default_prompt_func(inst, ques):
                return f"{ques}"
            prompt_func_list = [default_prompt_func]
        self.prompt_func_list = prompt_func_list

        # Keep the positional index of every item so that output file names
        # can identify which question file / prompt / context produced them.
        indexed_ques_paths = list(enumerate(ques_path_list))
        indexed_prompt_funcs = list(enumerate(self.prompt_func_list))
        indexed_contexts = list(enumerate(context_list))

        self.combinations = list(product(indexed_ques_paths, indexed_prompt_funcs, indexed_contexts))

        self.generation_config = GenerationConfig(
            temperature=0,
            candidate_count=1,
        )
        # All four harm categories use the BLOCK_ONLY_HIGH threshold.
        block_only_high = generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH
        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_only_high,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_only_high,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_only_high,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: block_only_high,
        }

        # The model does not depend on the question or the context (the
        # context is sent as part of ``contents``), so build it once here
        # instead of re-creating it from every worker thread.
        self.model = GenerativeModel(self.model_path)

    def process_and_update(self, ques_dict, context, prompt_func, pbar):
        """Ask the model one question, retrying on errors.

        Args:
            ques_dict: Question record. ``ques_dict['input']`` is sent to the
                model; ``year``, ``session`` and ``question_number`` are only
                used to label log messages.
            context: Text sent ahead of the question.
            prompt_func: Not used by this method; kept so existing callers
                keep working.
            pbar: Progress bar; advanced by one when this question is
                finished, whether it succeeded or was abandoned.

        Returns:
            ``ques_dict`` with the model's text stored under ``'response'``,
            or ``None`` after 5 failed attempts.
        """
        max_retries = 5
        retries = 0
        ques = ques_dict['input']
        # Built up front with .get() so that a missing label key cannot raise
        # a second exception from inside the error handler.
        current = f"{ques_dict.get('year')}-{ques_dict.get('session')}-{ques_dict.get('question_number')}"

        while retries < max_retries:
            try:
                responses = self.model.generate_content(
                    contents=[context, ques],
                    generation_config=self.generation_config,
                    safety_settings=self.safety_settings,
                )
                # Reading .text is kept inside the try: it is the step that
                # may raise when no usable candidate was returned.
                output = responses.text
            except Exception as e:
                message = str(e)
                # The original condition `"Quota" or "quota" in str(e)` was
                # always true, so every error slept 30 s and retried forever.
                if "quota" in message.lower():
                    # Quota errors are waited out and do not consume a retry.
                    time.sleep(30)
                elif "candidate is likely blocked" in message:
                    retries += 1
                    print(f"Candidate blocked '{current}'")
                else:
                    retries += 1
                    print(e)
                    print(f"Retrying '{current}'... {retries}/{max_retries}")
                    time.sleep(10)
            else:
                ques_dict['response'] = output
                pbar.update(1)
                return ques_dict

        print(f"Failed to generate response for '{current}', skipping...")
        pbar.update(1)
        return None

    def process_files(self):
        """Run every (question file, prompt function, context) combination.

        For each combination the question file is loaded, wrapped in
        ``ExamDataset`` and iterated in batches. After each batch the results
        are appended to ``output/<path> [f<i>_p<j>_q<k>].json``. If that file
        already exists its contents are kept and extended. Abandoned
        questions are stored as ``null`` entries, as before.

        ``ExamDataset`` and ``collate_fn`` are not defined in this class;
        presumably they come from ``from format import *`` (verify).
        """
        for (ques_idx, ques_path), (prompt_idx, prompt_func), (context_idx, inst) in self.combinations:
            # Read as UTF-8 to match how the results are written below.
            with open(ques_path, encoding='utf-8') as f:
                exam = json.load(f)
            filename = f'output/{self.path} [f{prompt_idx+1}_p{context_idx+1}_q{ques_idx+1}].json'
            dataset = ExamDataset(exam, inst, prompt_func)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
            with tqdm(total=len(dataloader), desc=filename, leave=True, position=0) as pbar:
                for i, batch in enumerate(dataloader):
                    with ThreadPoolExecutor() as executor:
                        with tqdm(total=len(batch), desc=f"Batch {i}", leave=True, position=0) as batch_pbar:
                            # executor.map preserves input order, so results
                            # line up with the questions in the batch.
                            results = list(executor.map(lambda ques: self.process_and_update(ques, inst, prompt_func, batch_pbar), batch))
                    if os.path.exists(filename):
                        # The file is written as UTF-8 with ensure_ascii=False,
                        # so it must be read back with the same encoding.
                        with open(filename, 'r', encoding='utf-8') as file:
                            resp = json.load(file)
                    else:
                        resp = []
                        os.makedirs(os.path.dirname(filename), exist_ok=True)
                    resp.extend(results)
                    with open(filename, 'w', encoding='utf-8') as f:
                        json.dump(resp, f, indent=4, ensure_ascii=False)
                    pbar.update(1)

In [3]:
# Model to evaluate; the trailing date is the author's own note
# (presumably the model version's release date — confirm).
model_path = 'gemini-1.0-pro-001' # 2024-02-15
# NOTE(review): exam_list and inst_list are not defined in the visible cells —
# presumably they come from `from format import *`; verify before running on a fresh kernel.
run = ModelResponder(model_path, exam_list, inst_list)
# Queries the model for every (question file, prompt function, context)
# combination and writes the JSON results under output/.
run.process_files()

Batch 0:  17%|█▋        | 61/350 [00:07<00:11, 24.56it/s] | 0/6 [00:00<?, ?it/s]