In [None]:
# 240626

import os
from tqdm import tqdm
import torch
from vllm import LLM, SamplingParams
from itertools import product
import gc
from format import *
from torch.utils.data import Dataset, DataLoader
import json

# Restrict CUDA to device index 6 for this process. ModelResponder sizes its
# tensor parallelism from torch.cuda.device_count(), so this value also
# determines how many GPUs the model is sharded across.
# NOTE(review): presumably this has to run before anything initializes CUDA — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [None]:
class ModelResponder:
    """Run one vLLM model over every (question file, prompt format, instruction) combination.

    Responses are checkpointed batch by batch to
    ``output/<path> [f<F>_p<P>_q<Q>].json`` where ``F``, ``P`` and ``Q`` are the
    1-based positions of the prompt function, the instruction and the question
    file in the lists given to the constructor.

    Args:
        model_path: Model identifier or local path handed to ``vllm.LLM``.
        ques_path_list: Paths of the JSON exam files to load.
        prompt_func_list: Prompt-formatting callables forwarded to ``ExamDataset``.
        inst_list: Instruction strings forwarded to ``ExamDataset``.
        path: Label used in output file names; defaults to the basename of
            ``model_path``.
        quant, llama3, qwen, yi: Accepted for backward compatibility; not used
            by this class.
    """

    def __init__(self, model_path, ques_path_list, prompt_func_list, inst_list, path=None, quant=False, llama3=False, qwen=False, yi=False):
        # temperature=0 is passed to every generate() call.
        self.sampling_params = SamplingParams(temperature=0)
        self.batch_size = 200
        # Shard the model across every GPU visible to this process.
        self.device_num = torch.cuda.device_count()
        self.model = LLM(model=model_path, tensor_parallel_size=self.device_num, disable_custom_all_reduce=True)
        if path is not None:
            self.path = path
        else:
            self.path = os.path.basename(model_path)

        # Keep each item's position so the output file name can identify it.
        indexed_ques_paths = list(enumerate(ques_path_list))
        indexed_prompt_funcs = list(enumerate(prompt_func_list))
        indexed_insts = list(enumerate(inst_list))

        self.combinations = list(product(indexed_ques_paths, indexed_prompt_funcs, indexed_insts))

    def process_files(self):
        """Generate responses for every combination and append them to the output files.

        Each output file holds a JSON list with one entry per processed batch;
        every entry is itself a list of per-question dicts. Results are appended
        to an existing file rather than replacing it.

        Raises:
            RuntimeError: Any generation error other than a CUDA out-of-memory error.
        """
        for (ques_idx, ques_path), (prompt_idx, prompt_func), (inst_idx, inst) in self.combinations:
            # Read as UTF-8 explicitly so the result does not depend on the locale.
            with open(ques_path, encoding='utf-8') as f:
                exam = json.load(f)
            # The "p" index is the instruction index; the previous code used an
            # undefined name here and failed with NameError.
            filename = f'output/{self.path} [f{prompt_idx+1}_p{inst_idx+1}_q{ques_idx+1}].json'
            dataset = ExamDataset(exam, inst, prompt_func)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
            with tqdm(total=len(dataloader), desc=filename, leave=True, position=0) as pbar:
                for batch in dataloader:
                    records = self._generate_batch(batch, filename)
                    # A skipped batch still rewrites the file, but adds nothing to it.
                    self._append_to_file(filename, [] if records is None else [records])
                    pbar.update(1)

    def _generate_batch(self, batch, filename, max_attempts=2):
        """Generate responses for one collated batch, retrying on CUDA OOM.

        Returns:
            The batch as a list of per-question dicts with a ``response`` key
            added, or ``None`` when every attempt ran out of GPU memory.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                outputs = self.model.generate(batch['input'], sampling_params=self.sampling_params)
                # Only the first candidate of each request is kept.
                batch['response'] = [output.outputs[0].text for output in outputs]
                return self._batch_to_records(batch)
            except RuntimeError as e:
                if "CUDA out of memory" not in str(e):
                    raise  # Re-raise exception if it's not a CUDA OOM error
                print(f"Attempt {attempt}: CUDA OOM on {filename}. Trying again after clearing cache.")
                torch.cuda.empty_cache()  # Try to free some memory
        print(f"Skipping {filename} after repeated OOM errors.")
        return None

    @staticmethod
    def _batch_to_records(batch):
        """Turn a column-oriented batch dict into one plain dict per question."""
        return [
            {
                key: values[row].tolist() if isinstance(values[row], torch.Tensor) else values[row]
                for key, values in batch.items()
            }
            for row in range(len(batch['input']))
        ]

    @staticmethod
    def _append_to_file(filename, new_entries):
        """Append ``new_entries`` to the JSON list stored in ``filename``."""
        if os.path.exists(filename):
            # The file is written as UTF-8 below, so read it back the same way.
            with open(filename, 'r', encoding='utf-8') as file:
                resp = json.load(file)
        else:
            resp = []
            os.makedirs(os.path.dirname(filename), exist_ok=True)
        resp += new_entries
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(resp, f, indent=4, ensure_ascii=False)

    def delete(self):
        """Drop the model reference and ask Python and CUDA to release memory."""
        del self.model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Instruction text handed to ModelResponder (and from there to ExamDataset)
# alongside each prompt function. Built from adjacent literals so the value is
# a single line with no embedded newlines.
inst = (
    "You are taking the Pharmacist Licensing Examination. "
    "Read the question carefully and use your judgment to select one of the options beginning with '①, ②, ③, ④, ⑤' as an answer. "
    "Submit only one symbol from '①, ②, ③, ④, ⑤' without any explanation, as this is an OMR exam."
)

# Every instruction variant to evaluate; only one is defined here.
inst_list = [inst]

# JSON exam files to run each model against.
ques_path_list = ["/home/hwjang/project/LLM/240625/exams_final.json"]

In [None]:
# def main():
#     model_path = "meta-llama/Meta-Llama-3-70B-Instruct"
#     prompt_func_list = [prompt_llama3]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "meta-llama/Llama-2-70b-chat-hf"
#     prompt_func_list = [prompt_llama2]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "/home/hwjang/project/LLM/data/SOLAR-0-70b-16bit"
#     prompt_func_list = [prompt_solar]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()


# def main():
#     model_path = "Qwen/Qwen2-72B-Instruct"
#     prompt_func_list = [prompt_qwen2]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list, qwen=True)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "CohereForAI/c4ai-command-r-plus"
#     prompt_func_list = [prompt_command]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "CohereForAI/c4ai-command-r-v01"
#     prompt_func_list = [prompt_command]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "01-ai/Yi-1.5-34B-Chat"
#     prompt_func_list = [prompt_yi]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list, yi=True)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "moreh/MoMo-72B-lora-1.8.7-DPO"
#     prompt_func_list = [prompt_momo]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "meta-llama/Llama-2-7b-chat-hf"
#     prompt_func_list = [prompt_llama2]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

# def main():
#     model_path = "upstage/SOLAR-10.7B-Instruct-v1.0"
#     prompt_func_list = [prompt_solar]
#     run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
#     run.process_files()
#     run.delete()

def main():
    """Run Meta-Llama-3-8B-Instruct over the configured exams, then release the model.

    Uses the module-level ``ques_path_list`` and ``inst_list`` together with the
    ``prompt_llama3`` prompt function.
    """
    model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
    prompt_func_list = [prompt_llama3]
    run = ModelResponder(model_path, ques_path_list, prompt_func_list, inst_list)
    try:
        run.process_files()
    finally:
        # Release the model even when generation fails; otherwise a failed run
        # in a live kernel keeps holding the model and its GPU memory.
        run.delete()

# Start the evaluation only when this code runs as the top-level program,
# not when it is imported as a module.
if __name__ == '__main__':
    main()