In [1]:
# 240630

from glob import glob
import pickle
import os
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import product
from anthropic import AnthropicVertex
import json
from torch.utils.data import Dataset, DataLoader
from format import *

# Vertex AI region and GCP project used for every Claude request in this notebook.
REGION = 'us-east5'
PROJECT_ID = "phonic-impact-421908"
# Regional Vertex AI endpoint URL (not referenced elsewhere in the visible cells).
ENDPOINT = "https://{}-aiplatform.googleapis.com".format(REGION)

# Export the project id to the environment (presumably for Google client/auth
# libraries that read GOOGLE_CLOUD_PROJECT), then build the shared Claude client.
os.environ.update({'GOOGLE_CLOUD_PROJECT': PROJECT_ID})
client = AnthropicVertex(project_id=PROJECT_ID, region=REGION)

In [2]:
class ModelResponder:
    """Query a Claude model (through the module-level ``client``) for every exam question.

    For each combination of (question file, prompt function, system context), the
    question file is loaded, wrapped in an ``ExamDataset``, sent to the model in
    batches of ``batch_size`` concurrent requests, and the answered question dicts
    are appended to ``output/<path> [f<prompt>_p<context>_q<file>].json`` after
    every batch.

    Args:
        model_path: Model id passed as ``model=`` to ``client.messages.create``.
        ques_path_list: Paths of JSON question files.
        context_list: System prompts; each one is sent as ``system=``.
        prompt_func_list: Callables ``(inst, ques) -> str`` handed to ``ExamDataset``.
            Defaults to a single function that returns the question unchanged.
        path: Output file stem; defaults to ``os.path.basename(model_path)``.
        batch_size: Number of questions sent concurrently per batch (default 100).
    """

    # Retry policy used by process_and_update. These are class attributes so they
    # can be tuned per instance (e.g. ``run.RETRY_WAIT_SECONDS = 1``).
    MAX_RETRIES = 50          # non-quota failures tolerated per question
    QUOTA_WAIT_SECONDS = 60   # back-off after a quota error (does not use a retry)
    RETRY_WAIT_SECONDS = 10   # back-off after any other error

    def __init__(self, model_path, ques_path_list, context_list, prompt_func_list=None, path=None, batch_size=100):
        self.batch_size = batch_size
        self.model_path = model_path
        self.path = path if path is not None else os.path.basename(model_path)

        if not prompt_func_list:
            prompt_func_list = [lambda inst, ques: f"{ques}"]
        self.prompt_func_list = prompt_func_list

        # Keep the indices so each output file name identifies its combination.
        self.combinations = list(product(
            list(enumerate(ques_path_list)),
            list(enumerate(self.prompt_func_list)),
            list(enumerate(context_list)),
        ))

    @staticmethod
    def _is_quota_error(exc):
        """Return True when the exception message mentions a quota (case-insensitive)."""
        return "quota" in str(exc).lower()

    @staticmethod
    def _question_label(ques_dict):
        """Build the 'year-session-number' label used in log messages."""
        return f"{ques_dict['year']}-{ques_dict['session']}-{ques_dict['question_number']}"

    def process_and_update(self, ques_dict, context, prompt_func, pbar):
        """Ask the model one question and store the answer in ``ques_dict['response']``.

        Args:
            ques_dict: Question record. ``'input'`` is sent as the user message, and
                ``'year'``/``'session'``/``'question_number'`` are used in log messages.
            context: System prompt for the request.
            prompt_func: Not used here; kept for interface compatibility. The prompt
                is presumably already applied to ``'input'`` by ``ExamDataset`` (TODO confirm).
            pbar: Progress bar advanced by one when the question is finished,
                whether it succeeded or failed.

        Returns:
            The same ``ques_dict``, mutated. ``'response'`` holds the model's text,
            or ``None`` after ``MAX_RETRIES`` non-quota failures.

        Quota errors wait ``QUOTA_WAIT_SECONDS`` and retry without using up a retry,
        so a persistent quota error keeps this call waiting indefinitely.
        """
        ques = ques_dict['input']
        retries = 0
        while retries < self.MAX_RETRIES:
            try:
                responses = client.messages.create(
                    messages=[{"role": "user", "content": ques}],
                    system=context,
                    model=self.model_path,
                    temperature=0,
                    max_tokens=1024,
                )
                ques_dict['response'] = responses.to_dict()['content'][0]['text']
                pbar.update(1)
                return ques_dict
            except Exception as e:
                # Quota exhaustion is expected to clear with time: wait and try
                # again without counting it against MAX_RETRIES.
                if self._is_quota_error(e):
                    time.sleep(self.QUOTA_WAIT_SECONDS)
                    continue
                retries += 1
                print(e)
                if retries >= self.MAX_RETRIES:
                    break  # don't sleep or announce a retry that won't happen
                print(f"Retrying '{self._question_label(ques_dict)}'... {retries}/{self.MAX_RETRIES}")
                time.sleep(self.RETRY_WAIT_SECONDS)

        print(f"Failed to generate response for '{self._question_label(ques_dict)}', skipping...")
        ques_dict['response'] = None
        # Count failed questions too, so the batch progress bar can reach 100%.
        pbar.update(1)
        return ques_dict

    @staticmethod
    def _load_results(filename):
        """Return the list already stored in ``filename`` (UTF-8 JSON), or ``[]`` if absent."""
        if os.path.exists(filename):
            with open(filename, 'r', encoding='utf-8') as file:
                return json.load(file)
        return []

    @staticmethod
    def _save_results(filename, resp):
        """Write ``resp`` to ``filename`` as UTF-8 JSON, creating its directory if needed."""
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(resp, f, indent=4, ensure_ascii=False)

    def process_files(self):
        """Run every (question file, prompt function, context) combination and save results.

        Results are appended to any existing output file, and the file is rewritten
        after each batch so completed batches are kept if the run is interrupted.
        NOTE(review): questions are not de-duplicated, so re-running a combination
        appends the same questions to its output file again.
        """
        for (ques_idx, ques_path), (prompt_idx, prompt_func), (context_idx, inst) in self.combinations:
            with open(ques_path, encoding='utf-8') as f:
                exam = json.load(f)
            filename = f'output/{self.path} [f{prompt_idx+1}_p{context_idx+1}_q{ques_idx+1}].json'
            dataset = ExamDataset(exam, inst, prompt_func)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
            # Load prior results once per combination instead of re-reading every batch.
            resp = self._load_results(filename)
            with tqdm(total=len(dataloader), desc=filename, leave=True, position=1) as pbar:
                for i, batch in enumerate(dataloader):
                    with ThreadPoolExecutor() as executor:
                        with tqdm(total=len(batch), desc=f"Batch {i}", leave=True, position=0) as batch_pbar:
                            # map() is consumed immediately, so the lambda sees this iteration's values.
                            results = list(executor.map(
                                lambda ques: self.process_and_update(ques, inst, prompt_func, batch_pbar),
                                batch,
                            ))
                    resp.extend(results)
                    self._save_results(filename, resp)
                    pbar.update(1)

In [3]:
# Claude 3.5 Sonnet over every (exam file, prompt function, context) combination.
# exam_list / inst_list are presumably provided by `from format import *` (TODO confirm).
model_path = "claude-3-5-sonnet@20240620"
run = ModelResponder(model_path=model_path, ques_path_list=exam_list, context_list=inst_list)
run.process_files()

Batch 0: 100%|██████████| 100/100 [00:14<00:00,  6.78it/s]        | 0/21 [00:00<?, ?it/s]
Batch 1: 100%|██████████| 100/100 [00:09<00:00, 10.97it/s]        | 1/21 [00:14<04:54, 14.75s/it]
Batch 2: 100%|██████████| 100/100 [00:09<00:00, 11.00it/s]        | 2/21 [00:23<03:37, 11.44s/it]
Batch 3: 100%|██████████| 100/100 [00:09<00:00, 10.65it/s]        | 3/21 [00:32<03:06, 10.37s/it]
Batch 4: 100%|██████████| 100/100 [00:10<00:00,  9.85it/s]        | 4/21 [00:42<02:49,  9.99s/it]
Batch 5: 100%|██████████| 100/100 [00:08<00:00, 11.38it/s]▍       | 5/21 [00:52<02:40, 10.05s/it]
Batch 6: 100%|██████████| 100/100 [00:10<00:00,  9.95it/s]▊       | 6/21 [01:01<02:24,  9.63s/it]
Batch 7: 100%|██████████| 100/100 [00:08<00:00, 11.55it/s]█▎      | 7/21 [01:11<02:16,  9.77s/it]
Batch 8: 100%|██████████| 100/100 [00:09<00:00, 11.02it/s]█▊      | 8/21 [01:20<02:02,  9.42s/it]
Batch 9: 100%|██████████| 100/100 [00:09<00:00, 10.64it/s]██▎     | 9/21 [01:29<01:51,  9.32s/it]
Batch 10: 100%|██████████| 1

In [4]:
# Claude 3 Opus over every (exam file, prompt function, context) combination.
# exam_list / inst_list are presumably provided by `from format import *` (TODO confirm).
model_path = "claude-3-opus@20240229"
run = ModelResponder(model_path=model_path, ques_path_list=exam_list, context_list=inst_list)
run.process_files()

Batch 0:   0%|          | 0/100 [00:00<?, ?it/s] 0%|          | 0/21 [00:00<?, ?it/s]

Batch 0: 100%|██████████| 100/100 [01:31<00:00,  1.10it/s]
Batch 1: 100%|██████████| 100/100 [01:36<00:00,  1.03it/s]    | 1/21 [01:31<30:22, 91.14s/it]
Batch 2: 100%|██████████| 100/100 [01:32<00:00,  1.08it/s]    | 2/21 [03:07<29:53, 94.37s/it]
Batch 3: 100%|██████████| 100/100 [02:26<00:00,  1.47s/it]    | 3/21 [04:40<28:02, 93.49s/it]
Batch 4: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it]    | 4/21 [07:07<32:28, 114.61s/it]
Batch 5: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it]    | 5/21 [08:48<29:19, 109.95s/it]
Batch 6: 100%|██████████| 100/100 [02:38<00:00,  1.59s/it]    | 6/21 [11:30<31:50, 127.36s/it]
Batch 7: 100%|██████████| 100/100 [01:38<00:00,  1.02it/s]    | 7/21 [14:08<32:06, 137.60s/it]
Batch 8: 100%|██████████| 100/100 [02:45<00:00,  1.66s/it]    | 8/21 [15:46<27:05, 125.01s/it]
Batch 9: 100%|██████████| 100/100 [02:30<00:00,  1.50s/it]    | 9/21 [18:32<27:32, 137.72s/it]
Batch 10: 100%|██████████| 100/100 [01:39<00:00,  1.01it/s]   | 10/21 [21:02<25:57, 141.5

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-47'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-28'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-32'... 1/50


Batch 0:  57%|█████▋    | 57/100 [01:19<00:19,  2.26it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-33'... 1/50


Batch 0:  64%|██████▍   | 64/100 [01:20<00:12,  2.92it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-59'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-57'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-1-61'... 1/50


Batch 0: 100%|██████████| 100/100 [02:38<00:00,  1.59s/it]
Batch 1:  50%|█████     | 50/100 [00:13<00:13,  3.81it/s]     | 1/21 [02:38<52:57, 158.85s/it]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-2-47'... 1/50


Batch 1: 100%|██████████| 100/100 [01:33<00:00,  1.07it/s]
Batch 2:  63%|██████▎   | 63/100 [01:20<00:20,  1.77it/s]     | 2/21 [04:12<38:03, 120.20s/it]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-3-67'... 1/50


Batch 2:  66%|██████▌   | 66/100 [01:21<00:17,  1.95it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-3-68'... 1/50


Batch 2: 100%|██████████| 100/100 [02:30<00:00,  1.50s/it]
Batch 3:  19%|█▉        | 19/100 [00:07<00:20,  4.00it/s]     | 3/21 [06:42<40:11, 133.96s/it]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-40'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-46'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-50'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-49'... 1/50


Batch 3:  35%|███▌      | 35/100 [00:12<00:14,  4.54it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-71'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-70'... 1/50


Batch 3:  43%|████▎     | 43/100 [00:14<00:17,  3.31it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2019-4-81'... 1/50


Batch 3: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s]
Batch 4:  61%|██████    | 61/100 [01:19<00:10,  3.68it/s]     | 4/21 [08:17<33:36, 118.60s/it]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2020-2-3'... 1/50


Batch 4:  75%|███████▌  | 75/100 [01:24<00:10,  2.35it/s]

Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2020-2-22'... 1/50
Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}
Retrying '2020-2-28'... 1/50


Batch 4: 100%|██████████| 100/100 [02:32<00:00,  1.52s/it]
Batch 5: 100%|██████████| 100/100 [01:29<00:00,  1.12it/s]    | 5/21 [10:49<34:51, 130.70s/it]
Batch 6: 100%|██████████| 100/100 [02:28<00:00,  1.48s/it]    | 6/21 [12:19<29:11, 116.75s/it]
Batch 7: 100%|██████████| 100/100 [01:30<00:00,  1.10it/s]    | 7/21 [14:47<29:39, 127.10s/it]
Batch 8: 100%|██████████| 100/100 [02:29<00:00,  1.50s/it]    | 8/21 [16:18<25:01, 115.54s/it]
Batch 9: 100%|██████████| 100/100 [01:30<00:00,  1.10it/s]    | 9/21 [18:48<25:15, 126.25s/it]
Batch 10: 100%|██████████| 100/100 [02:28<00:00,  1.48s/it]   | 10/21 [20:18<21:07, 115.23s/it]
Batch 11: 100%|██████████| 100/100 [02:33<00:00,  1.53s/it]   | 11/21 [22:47<20:53, 125.37s/it]
Batch 12: 100%|██████████| 100/100 [01:27<00:00,  1.14it/s]   | 12/21 [25:20<20:04, 133.85s/it]
Batch 13: 100%|██████████| 100/100 [02:30<00:00,  1.51s/it]   | 13/21 [26:47<15:58, 119.84s/it]
Batch 14: 100%|██████████| 100/100 [01:26<00:00,  1.16it/s]   | 14/21 [29:18<15:04

In [3]:
# Claude 3 Sonnet over every (exam file, prompt function, context) combination.
# exam_list / inst_list are presumably provided by `from format import *` (TODO confirm).
model_path = "claude-3-sonnet@20240229"
run = ModelResponder(model_path=model_path, ques_path_list=exam_list, context_list=inst_list)
run.process_files()

Batch 0:   0%|          | 0/100 [00:00<?, ?it/s]