In [1]:
# 240630

from glob import glob
import pickle
import os
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import product
import vertexai
from vertexai.language_models import ChatModel, InputOutputTextPair
import vertexai.preview.generative_models as generative_models
import json
from torch.utils.data import Dataset, DataLoader
from format import *

# Configure the Vertex AI SDK with the project and location passed below.
# NOTE(review): the project ID is hard-coded here; consider reading it from an
# environment variable or a config cell so the notebook can be shared as-is.
vertexai.init(project='phonic-impact-421908', location="asia-northeast3")

In [2]:
class ModelResponder:
    """Send exam questions to a Vertex AI chat model and store the answers as JSON.

    Every combination of (question file, prompt function, context) is processed
    in turn. Each combination is written to its own file under ``output/``,
    named after ``path`` and the 1-based indices of the prompt function,
    context and question file.

    Args:
        model_path: Model identifier handed to ``ChatModel.from_pretrained``.
        ques_path_list: Paths of the JSON question files to load.
        context_list: Contexts passed to ``start_chat(context=...)``.
        prompt_func_list: Optional prompt functions forwarded to ``ExamDataset``.
            When empty or ``None`` a single function that returns the question
            unchanged is used.
        path: Optional prefix for output file names; defaults to the basename
            of ``model_path``.
        batch_size: Number of questions sent concurrently per batch and written
            to disk together (defaults to the previously hard-coded 350).
    """

    def __init__(self, model_path, ques_path_list, context_list, prompt_func_list=None, path=None, batch_size=350):
        self.batch_size = batch_size
        self.model_path = model_path
        self.path = path if path is not None else os.path.basename(model_path)

        # Loaded on first use and then shared by all worker threads.
        self.chat_model = None

        if not prompt_func_list:
            def default_prompt_func(inst, ques):
                return f"{ques}"
            prompt_func_list = [default_prompt_func]
        self.prompt_func_list = prompt_func_list

        indexed_ques_paths = list(enumerate(ques_path_list))
        indexed_prompt_funcs = list(enumerate(self.prompt_func_list))
        indexed_contexts = list(enumerate(context_list))

        # Cartesian product; the indices are kept so output files can be named after them.
        self.combinations = list(product(indexed_ques_paths, indexed_prompt_funcs, indexed_contexts))

        # NOTE(review): these settings are built but not referenced anywhere else in this class.
        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }

    def _get_chat_model(self):
        """Return the chat model, loading it from ``self.model_path`` on first use."""
        model = getattr(self, 'chat_model', None)
        if model is None:
            model = ChatModel.from_pretrained(self.model_path)
            self.chat_model = model
        return model

    @staticmethod
    def _question_label(ques_dict):
        """Build the 'year-session-question_number' label used in log messages."""
        return f"{ques_dict['year']}-{ques_dict['session']}-{ques_dict['question_number']}"

    def process_and_update(self, ques_dict, context, prompt_func, pbar):
        """Ask the model one question and store the answer in ``ques_dict['response']``.

        Args:
            ques_dict: Question record; ``ques_dict['input']`` is sent to the model.
                ``year``, ``session`` and ``question_number`` are read only for
                log messages when a request fails.
            context: Context passed to ``start_chat``.
            prompt_func: Unused here; kept so existing callers keep working.
            pbar: Progress bar; advanced by one once the question is finished,
                whether it was answered or skipped.

        Returns:
            The same ``ques_dict`` with ``'response'`` set to the model's text,
            or to ``None`` after ``max_retries`` failed attempts.
        """
        max_retries = 5
        retries = 0
        ques = ques_dict['input']
        chat_model = self._get_chat_model()

        while retries < max_retries:
            try:
                chat = chat_model.start_chat(context=context)
                responses = chat.send_message(ques, temperature=0)
                output = responses.text
            except Exception as e:
                current = self._question_label(ques_dict)
                message = str(e)
                if type(e).__name__ == "RetryError":
                    # Matched by class name because RetryError is not imported in this
                    # notebook and, being an Exception subclass, could never reach a
                    # separate `except RetryError` clause placed after this one.
                    time.sleep(60)
                elif "quota" in message.lower():
                    # Quota errors do not count as a retry: wait for quota to refill.
                    time.sleep(30)
                elif "candidate is likely blocked" in message:
                    retries += 1
                    print(f"Candidate blocked '{current}'")
                else:
                    retries += 1
                    print(e)
                    print(f"Retrying '{current}'... {retries}/{max_retries}")
                    if retries < max_retries:
                        # No point waiting after the final attempt.
                        time.sleep(10)
            else:
                ques_dict['response'] = output
                pbar.update(1)
                return ques_dict

        # All counted attempts failed: record the miss and keep the batch going.
        print(f"Failed to generate response for '{self._question_label(ques_dict)}', skipping...")
        ques_dict['response'] = None
        pbar.update(1)
        return ques_dict

    def process_files(self):
        """Run every (question file, prompt function, context) combination.

        For each combination the question file is loaded, wrapped in an
        ``ExamDataset`` and processed in batches of ``self.batch_size``; the
        questions of a batch are sent concurrently through a thread pool.
        After each batch the results are appended to the combination's output
        file. If that file already exists, its contents are loaded and the new
        results are appended to them.
        """
        for (ques_idx, ques_path), (prompt_idx, prompt_func), (context_idx, inst) in self.combinations:
            # NOTE(review): the question file is opened with the platform default
            # encoding; confirm whether it should be read as UTF-8.
            with open(ques_path) as f:
                exam = json.load(f)
            filename = f'output/{self.path} [f{prompt_idx+1}_p{context_idx+1}_q{ques_idx+1}].json'
            dataset = ExamDataset(exam, inst, prompt_func)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)

            # Load the model once, before any worker thread needs it.
            self._get_chat_model()

            with tqdm(total=len(dataloader), desc=filename, leave=True, position=1) as pbar:
                for i, batch in enumerate(dataloader):
                    with ThreadPoolExecutor() as executor:
                        with tqdm(total=len(batch), desc=f"Batch {i}", leave=True, position=0) as batch_pbar:
                            results = list(executor.map(lambda ques: self.process_and_update(ques, inst, prompt_func, batch_pbar), batch))

                    if os.path.exists(filename):
                        # Read with the same encoding the file is written with below.
                        with open(filename, 'r', encoding='utf-8') as file:
                            resp = json.load(file)
                    else:
                        resp = []
                        os.makedirs(os.path.dirname(filename), exist_ok=True)
                    resp.extend(results)
                    with open(filename, 'w', encoding='utf-8') as f:
                        json.dump(resp, f, indent=4, ensure_ascii=False)
                    pbar.update(1)

In [3]:
# Run the full sweep against chat-bison@001. Because no `path` is passed, the
# output files are named after the model, e.g. "output/chat-bison@001 [f1_p1_q1].json".
# NOTE(review): exam_list and inst_list are not defined in the visible cells;
# presumably they come from `from format import *` — confirm.
model_path = 'chat-bison@001' # 2023-07-10 (presumably the model version's release date — confirm)

run = ModelResponder(model_path, exam_list, inst_list)
run.process_files()

Batch 0:   0%|          | 0/350 [00:00<?, ?it/s]      | 0/6 [00:00<?, ?it/s]

Batch 0: 100%|██████████| 350/350 [00:32<00:00, 10.92it/s]
Batch 1: 100%|██████████| 350/350 [00:48<00:00,  7.17it/s]6 [00:32<02:40, 32.06s/it]
Batch 2: 100%|██████████| 350/350 [00:26<00:00, 13.13it/s]6 [01:20<02:47, 41.92s/it]
Batch 3: 100%|██████████| 350/350 [00:49<00:00,  7.00it/s]6 [01:47<01:44, 34.96s/it]
Batch 4: 100%|██████████| 350/350 [00:25<00:00, 13.88it/s]6 [02:37<01:21, 40.90s/it]
Batch 5: 100%|██████████| 350/350 [00:24<00:00, 14.33it/s]6 [03:02<00:35, 35.24s/it]
output/chat-bison@001 [f1_p1_q1].json: 100%|██████████| 6/6 [03:27<00:00, 34.53s/it]
Batch 0: 100%|██████████| 350/350 [01:09<00:00,  5.03it/s]6 [00:00<?, ?it/s]
Batch 1: 100%|██████████| 350/350 [00:23<00:00, 14.66it/s]6 [01:09<05:47, 69.55s/it]
Batch 2: 100%|██████████| 350/350 [00:25<00:00, 13.79it/s]6 [01:33<02:50, 42.68s/it]
Batch 3: 100%|██████████| 350/350 [00:57<00:00,  6.14it/s]6 [01:58<01:44, 34.78s/it]
Batch 4: 100%|██████████| 350/350 [00:23<00:00, 14.76it/s]6 [02:55<01:27, 43.57s/it]
Batch 5: 100%|

In [3]:
# Same sweep as the previous cell, but against chat-bison@002. This rebinds the
# notebook-level names `model_path` and `run`.
# NOTE(review): exam_list and inst_list are not defined in the visible cells;
# presumably they come from `from format import *` — confirm.
model_path = 'chat-bison@002' # 2023-12-06 (presumably the model version's release date — confirm)

run = ModelResponder(model_path, exam_list, inst_list)
run.process_files()

Batch 0:   0%|          | 0/350 [00:00<?, ?it/s]      | 0/6 [00:00<?, ?it/s]

Batch 0: 100%|██████████| 350/350 [00:25<00:00, 13.56it/s]
Batch 1: 100%|██████████| 350/350 [01:37<00:00,  3.60it/s]6 [00:25<02:09, 25.82s/it]
Batch 2: 100%|██████████| 350/350 [00:21<00:00, 15.98it/s]6 [02:03<04:31, 67.89s/it]
Batch 3: 100%|██████████| 350/350 [01:56<00:00,  3.00it/s]6 [02:25<02:20, 46.90s/it]
Batch 4: 100%|██████████| 350/350 [00:20<00:00, 17.42it/s]6 [04:21<02:28, 74.39s/it]
Batch 5: 100%|██████████| 350/350 [00:54<00:00,  6.36it/s]6 [04:41<00:54, 54.82s/it]
output/chat-bison@002 [f1_p1_q1].json: 100%|██████████| 6/6 [05:36<00:00, 56.13s/it]
Batch 0: 100%|██████████| 350/350 [00:19<00:00, 17.70it/s]6 [00:00<?, ?it/s]
Batch 1: 100%|██████████| 350/350 [00:21<00:00, 16.05it/s]6 [00:19<01:38, 19.78s/it]
Batch 2: 100%|██████████| 350/350 [00:20<00:00, 17.03it/s]6 [00:41<01:23, 20.98s/it]
Batch 3: 100%|██████████| 350/350 [01:18<00:00,  4.45it/s]6 [01:02<01:02, 20.80s/it]
Batch 4: 100%|██████████| 350/350 [00:20<00:00, 17.00it/s]6 [02:20<01:27, 43.66s/it]
Batch 5: 100%|