In [None]:
# 240626

from openai import OpenAI
import os
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import product
import json
from util_trans import *

# Take the API key from the OPENAI_API_KEY environment variable instead of
# embedding it in the notebook, so the secret never ends up in a saved or
# shared copy of this file. Set the variable before starting the kernel.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

In [2]:
class ModelResponder:
    """Query a chat model for every (question file, prompt function, context) combination.

    For each combination the question file is loaded as JSON, split into
    batches, every question in a batch is sent to the model concurrently,
    and the adjusted answers are appended to a JSON file under ``output/``.

    Relies on names that are not defined in this class: the module-level
    ``client``, and ``ExamDataset``, ``DataLoader``, ``collate_fn`` and
    ``output_adjust`` (presumably supplied by ``from util_trans import *``
    -- TODO confirm).
    """

    def __init__(self, model_path, ques_path_list, context_list, prompt_func_list=None, path=None,
                 max_retries=5, retry_delay=60):
        """Store the configuration and pre-compute all combinations to run.

        Args:
            model_path: Model identifier passed as ``model=`` to the chat API.
            ques_path_list: Paths of the JSON question files.
            context_list: System-message contents, one per context variant.
            prompt_func_list: Callables ``(inst, ques) -> user message``.
                When empty or None, a single default that formats the
                question with ``f"{ques}"`` is used.
            path: Label used in output file names; defaults to the basename
                of ``model_path``.
            max_retries: Number of attempts per question before giving up.
            retry_delay: Seconds to wait between failed attempts.
        """
        self.batch_size = 100
        self.model_path = model_path
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.path = path if path is not None else os.path.basename(model_path)

        if not prompt_func_list:
            prompt_func_list = [self._default_prompt]
        self.prompt_func_list = prompt_func_list

        # Keep the position of each item so output file names can refer to it.
        indexed_ques_paths = list(enumerate(ques_path_list))
        indexed_prompt_funcs = list(enumerate(self.prompt_func_list))
        indexed_contexts = list(enumerate(context_list))
        self.combinations = list(product(indexed_ques_paths, indexed_prompt_funcs, indexed_contexts))

    @staticmethod
    def _default_prompt(inst, ques):
        """Default prompt function: the question itself, formatted as a string."""
        return f"{ques}"

    def process_and_update(self, ques_dict, inst, prompt_func, batch_pbar):
        """Ask the model one question, retrying on any exception.

        Args:
            ques_dict: The question item handed to ``prompt_func``.
            inst: System-message content.
            prompt_func: Callable ``(inst, ques_dict) -> user message``.
            batch_pbar: Progress bar; advanced by exactly one per question,
                whether the question succeeds or is abandoned.

        Returns:
            ``output_adjust`` applied to the model's reply, or ``None`` when
            all ``max_retries`` attempts failed.
        """
        try:
            for attempt in range(1, self.max_retries + 1):
                try:
                    response = client.chat.completions.create(
                        model=self.model_path,
                        messages=[
                            {"role": "system", "content": inst},
                            {"role": "user", "content": prompt_func(inst, ques_dict)},
                        ],
                        # Effectively-zero sampling parameters plus a fixed seed,
                        # presumably to make replies as repeatable as possible.
                        temperature=0.0000000000000000000000000000000000000000001,
                        top_p=0.0000000000000000000000000000000000000000001,
                        seed=100,
                        )
                    output = response.choices[0].message.content
                    return output_adjust(output)
                except Exception as e:
                    print(e)
                    if attempt < self.max_retries:
                        print(f"Retrying... ({attempt}/{self.max_retries})")
                        time.sleep(self.retry_delay)
            # Every attempt failed: keep the result list aligned with the
            # questions by returning an explicit None placeholder.
            print(f"Giving up after {self.max_retries} failed attempts; recording None.")
            return None
        finally:
            # Count the question once, regardless of retries or failure.
            batch_pbar.update(1)

    def process_files(self):
        """Run every combination and append the results to its output file.

        The output file of a combination is rewritten after every batch, so
        completed batches are on disk even if a later batch fails. If the
        file already exists, new results are appended to its contents.
        """
        for (ques_idx, ques_path), (prompt_idx, prompt_func), (context_idx, inst) in self.combinations:
            # NOTE(review): opened with the platform default encoding, as before;
            # pass encoding='utf-8' here if the question files are UTF-8.
            with open(ques_path) as f:
                exam = json.load(f)
            filename=f'output/{self.path} [f{prompt_idx+1}_p{context_idx+1}_q{ques_idx+1}].json'
            dataset = ExamDataset(exam, inst, prompt_func)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
            with tqdm(total=len(dataloader),desc=filename, leave=True, position=0) as pbar:
                for i, batch in enumerate(dataloader):
                    with ThreadPoolExecutor() as executor:
                        with tqdm(total=len(batch), desc=f"Batch {i}", leave=True, position=0) as batch_pbar:
                            # executor.map preserves input order, so results line up with the batch.
                            results = list(executor.map(lambda ques: self.process_and_update(ques, inst, prompt_func, batch_pbar), batch))
                    if os.path.exists(filename):
                        # Read with the same encoding the file is written with below.
                        with open(filename, 'r', encoding='utf-8') as file:
                            resp = json.load(file)
                    else:
                        resp = []
                        os.makedirs(os.path.dirname(filename), exist_ok=True)
                    resp.extend(results)
                    with open(filename, 'w', encoding='utf-8') as f:
                        json.dump(resp, f, indent=4, ensure_ascii=False)
                    pbar.update(1)

In [3]:
# Run the default prompt over every question file / context pair with GPT-4o.
# ques_path_list and inst_list are not defined in this notebook; they are
# presumably provided by `from util_trans import *` -- TODO confirm.
model_path = "gpt-4o-2024-05-13"
run = ModelResponder(
    model_path=model_path,
    ques_path_list=ques_path_list,
    context_list=inst_list,
)
run.process_files()

Batch 0:   0%|          | 0/100 [00:00<?, ?it/s]         | 0/21 [00:00<?, ?it/s]

Batch 0: 100%|██████████| 100/100 [00:14<00:00,  6.80it/s]
Batch 1: 100%|██████████| 100/100 [00:14<00:00,  6.67it/s] 1/21 [00:14<04:54, 14.71s/it]
Batch 2: 100%|██████████| 100/100 [00:20<00:00,  4.94it/s] 2/21 [00:29<04:42, 14.88s/it]
Batch 3: 100%|██████████| 100/100 [00:19<00:00,  5.24it/s] 3/21 [00:49<05:12, 17.33s/it]
Batch 4: 100%|██████████| 100/100 [00:16<00:00,  6.19it/s] 4/21 [01:09<05:06, 18.04s/it]
Batch 5: 100%|██████████| 100/100 [00:18<00:00,  5.44it/s] 5/21 [01:25<04:37, 17.37s/it]
Batch 6: 100%|██████████| 100/100 [00:21<00:00,  4.56it/s] 6/21 [01:43<04:25, 17.72s/it]
Batch 7: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s] 7/21 [02:05<04:27, 19.10s/it]
Batch 8: 100%|██████████| 100/100 [00:19<00:00,  5.24it/s] 8/21 [02:23<04:03, 18.71s/it]
Batch 9: 100%|██████████| 100/100 [00:21<00:00,  4.75it/s] 9/21 [02:42<03:45, 18.83s/it]
Batch 10: 100%|██████████| 100/100 [00:15<00:00,  6.40it/s]10/21 [03:03<03:34, 19.53s/it]
Batch 11: 100%|██████████| 100/100 [00:18<00:00,  