In [None]:
# load environment variables (e.g. OPENAI_API_KEY) from .env
import os, sys
from dotenv import load_dotenv

# The project root is the parent of the notebook's working directory; adding it to
# sys.path makes the `src` package importable below.
notebook_dir = os.getcwd()
root_dir = os.path.dirname(notebook_dir)
sys.path.append(root_dir)
load_dotenv()

In [None]:
# import the OpenAI SDK and set the API key from the environment
import openai
from openai import OpenAI

# NOTE(review): explicitly constructed OpenAI() clients likely read OPENAI_API_KEY from the
# environment on their own; this assignment configures the module-level `openai` object — confirm.
openai.api_key = os.environ.get('OPENAI_API_KEY')

In [None]:
from src.utils import load_pickle, save_pickle, save_jsonl, remove_brackets, remove_quotes
import re
from tqdm import tqdm
from datetime import date
import json
import random

In [None]:
# define openai agent
class openai_agent():
    """Thin wrapper around the OpenAI Batch API used for triple extraction.

    Builds chat-completion requests (one per session and per entire/partial
    dialogue), uploads them as a batch job, and exposes helpers to list,
    retrieve, download and cancel batch jobs.

    Args:
        model: short alias ('gpt-4o-mini' or 'gpt-4o'), which is mapped to a
            dated snapshot, or any other model name, which is passed to the API
            unchanged.
        system_prompt: system message prepended to every request.
        temperature: sampling temperature forwarded to the API.
        top_p: nucleus-sampling parameter forwarded to the API.
        num_max_tokens: ``max_tokens`` for each completion.
    """

    # Short aliases -> dated model snapshots actually sent to the API.
    MODEL_SNAPSHOTS = {
        'gpt-4o-mini': 'gpt-4o-mini-2024-07-18',
        'gpt-4o': 'gpt-4o-2024-08-06',
    }

    def __init__(self, model, system_prompt, temperature, top_p, num_max_tokens):
        self.model = model
        self.api_model = self._resolve_api_model(model)
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt
        self.num_max_tokens = num_max_tokens
        self.client = OpenAI()

    @classmethod
    def _resolve_api_model(cls, model):
        """Return the dated snapshot for a known alias, otherwise ``model`` itself.

        Unknown names used to become '' which only failed once the batch job ran;
        passing them through lets callers use any model name directly.
        """
        return cls.MODEL_SNAPSHOTS.get(model, model)

    def create_message(self, custom_id: str, user_prompt: str):
        """Creates message (dictionary) to run OpenAI API.

        Returns one Batch API request line: a POST to /v1/chat/completions whose
        body carries the system prompt, ``user_prompt`` and this agent's sampling
        settings. ``custom_id`` identifies the request's result in the output file.
        """
        return {
            'custom_id': custom_id,
            'method': 'POST',
            'url': "/v1/chat/completions",
            'body': {
                'model': self.api_model,
                'top_p': self.top_p,
                'messages': [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                'max_tokens': self.num_max_tokens,
                'temperature': self.temperature,
            },
        }

    def create_batch(self, data, data_type, prompt):
        """Build, upload and launch a batch job covering every session in ``data``.

        For each session index i two requests are created, with custom ids
        ``session{i}_entire`` (prompt filled with ``data['session_dialogue'][i]``)
        and ``session{i}_partial`` (filled with ``data['partial_session_dialogue'][i]``).
        The dialogue text replaces the ``{input_text}`` placeholder in ``prompt``.

        Args:
            data: mapping with index-aligned sequences 'session_dialogue' and
                'partial_session_dialogue'.
            data_type: label recorded in the batch metadata description.
            prompt: prompt template containing '{input_text}'.

        Returns:
            The batch object returned by ``client.batches.create``.
        """
        batch = []
        for i in tqdm(range(len(data['session_dialogue'])), desc = 'creating data to run batch API (OpenAI)'):
            entire_dialogue, partial_dialogue = data['session_dialogue'][i], data['partial_session_dialogue'][i]
            custom_id = f'session{i}'
            batch.append(self.create_message(custom_id + '_entire', prompt.replace('{input_text}', entire_dialogue)))
            batch.append(self.create_message(custom_id + '_partial', prompt.replace('{input_text}', partial_dialogue)))

        # Stage the requests as a JSONL file for upload.
        # NOTE: relies on the notebook-level global `root_dir`.
        batch_save_fname = os.path.join(root_dir, 'data', 'batch_api_tmp.jsonl')
        save_jsonl(batch, batch_save_fname)

        # `with` closes the handle once the upload finishes (it used to be leaked).
        with open(batch_save_fname, 'rb') as batch_file:
            batch_input_file = self.client.files.create(
                file = batch_file,
                purpose = "batch"
            )

        # The staging file is no longer needed; os.remove avoids shelling out to
        # `rm` with an unquoted path.
        os.remove(batch_save_fname)

        today = date.today().strftime("%Y%m%d")
        return self.client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": f"triple_extraction ({data_type}, {today}, {self.api_model})"
            },
        )

    def retrieve_recent_batches(self, n = 5):
        """Return the ``n`` most recent batch jobs (``client.batches.list(limit=n).data``)."""
        return self.client.batches.list(limit=n).data

    def retrieve_batch(self, batch_id):
        """Return the batch object for ``batch_id``."""
        return self.client.batches.retrieve(batch_id)

    def get_batch_response(self, output_file_id):
        """Download the content of a batch output file (e.g. ``batch.output_file_id``)."""
        return self.client.files.content(output_file_id)

    def cancel_batch_api(self, batch_id):
        """Request cancellation of batch ``batch_id``, print and return the result.

        Uses the SDK's ``batches.cancel`` rather than a hand-built ``requests``
        call, so authentication, base URL and error handling match the rest of
        the client. API errors now raise instead of being printed as JSON.
        """
        cancelled_batch = self.client.batches.cancel(batch_id)
        print(cancelled_batch)
        return cancelled_batch

In [None]:
# ---- experiment configuration (global constants) ----
# data selection: both values are used to build the input/output data paths
SCRIPT = 'friends'
DATA_TYPE = 'normal'   # alternatives: newname, shuffled
TRIAL_IDX = 0

# model and sampling settings forwarded to every request
MODEL = 'gpt-4o'
SYSTEM_PROMPT = 'You are a helpful assistant.'
TEMPERATURE = 0.7
TOP_P = 0.95
NUM_MAX_TOKENS = 1000

In [None]:
# Paths: output directory for extracted triples, the simulated-log input, and the prompt template.
triple_save_dir = os.path.join(root_dir, 'data', 'triple', SCRIPT, DATA_TYPE)
# exist_ok makes this a no-op when the directory exists (no exists()/makedirs() race).
os.makedirs(triple_save_dir, exist_ok=True)

data_fname = os.path.join(root_dir, 'data', 'simul-log', SCRIPT, 'processed', DATA_TYPE, f'trial{TRIAL_IDX}.pickle')
prompt_fname = os.path.join(root_dir, 'prompt', 'triple-extraction.v9-2.txt')

In [None]:
# Load the processed dialogue data and the prompt template.
# NOTE(review): load_pickle presumably wraps pickle.load, which can execute arbitrary
# code — only load trusted, locally generated files.
data = load_pickle(data_fname)
with open(prompt_fname, 'r') as prompt_file:  # close the handle instead of leaking it
    prompt = prompt_file.read()

In [None]:
# Batch API wrapper configured from the global settings above.
agent = openai_agent(
    model=MODEL,
    system_prompt=SYSTEM_PROMPT,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    num_max_tokens=NUM_MAX_TOKENS,
)

In [None]:
# Upload the requests and launch the batch job (completion window: 24h).
batch = agent.create_batch(data=data, data_type=DATA_TYPE, prompt=prompt)

In [None]:
# Show the most recent batch jobs (status, output_file_id, ...).
agent.retrieve_recent_batches(n=5)

In [None]:
# NOTE(review): output file id copied from a finished batch; update it for each new run.
batch_output_file_id = 'file-As84A4AnToHCojDg68y3dC'
batch_output = agent.get_batch_response(batch_output_file_id)
all_response = batch_output.text

In [None]:
# Save all_response as a text file, then reopen it in the next cell.
raw_output_fname = os.path.join(root_dir, 'data', 'batch_response_tmp.txt')
with open(raw_output_fname, 'w') as response_file:
    response_file.write(all_response)
print(f'[Saved] {raw_output_fname}')

In [None]:
# Read the temporary dump back (one JSON object per line), then delete it.
with open(raw_output_fname, 'r') as f:
    # Skip blank lines (e.g. an extra empty line at the end) so json.loads never sees ''.
    raw_output = [line.strip() for line in f if line.strip()]

# os.remove instead of `rm` through the shell: no quoting issues, works cross-platform.
os.remove(raw_output_fname)

In [None]:
# Decode every output line into a dict (one per batch request).
data_json = [json.loads(line) for line in raw_output]

In [None]:
def parse_triple_content(content):
    """Split one model response string into a list of triple strings.

    Surrounding '[' / ']' characters and every '{' / '}' are stripped first. Then:
    - multi-line responses: one entry per non-blank line, cleaned with
      ``remove_brackets`` and ``remove_quotes`` (from src.utils);
    - single-line responses: the ', '-separated tokens are regrouped three at a time
      and re-joined with ', '; trailing tokens that do not fill a group of three are
      dropped.

    Always returns a list. Previously the regrouped list was only assigned inside the
    loop, so responses with fewer than three tokens (and empty ones) leaked through
    as raw strings.
    """
    cleaned = re.sub(r'[{}]', '', content.strip('[]').strip())
    if '\n' in cleaned:
        return [remove_quotes(remove_brackets(line.strip())) for line in cleaned.split('\n') if line.strip()]
    if not cleaned:
        return []
    tokens = cleaned.split(', ')
    return [', '.join(tokens[j*3:(j+1)*3]) for j in range(len(tokens)//3)]


# custom_id -> {'raw': model output text, 'processed': list of triple strings}
output_dct = dict()
for response_line in tqdm(data_json):
    id_ = response_line['custom_id']
    raw_result_ = response_line['response']['body']['choices'][0]['message']['content']
    output_dct[id_] = {
        'raw': raw_result_,
        'processed': parse_triple_content(raw_result_),
    }

In [None]:
# Pair results by custom_id rather than by position: Batch API output lines are not
# guaranteed to come back in input order, so positional pairing could silently attach
# results to the wrong session. Ids follow create_batch's 'session{i}_entire' /
# 'session{i}_partial' scheme, so index i here lines up with index i of the input data.
entire_session_triples_raw, entire_session_triples_processed, partial_session_triples_raw, partial_session_triples_processed = [], [], [], []
for i in tqdm(range(len(output_dct)//2), desc = 'splitting whole result into triples extracted from entire/partial session dialogue'):
    entire_key, partial_key = f'session{i}_entire', f'session{i}_partial'
    # A missing key means that request is absent from the batch output (e.g. it failed).
    assert entire_key in output_dct and partial_key in output_dct, f'missing batch result for session{i}'
    entire_session_triples_raw.append(output_dct[entire_key]['raw'])
    entire_session_triples_processed.append(output_dct[entire_key]['processed'])
    partial_session_triples_raw.append(output_dct[partial_key]['raw'])
    partial_session_triples_processed.append(output_dct[partial_key]['processed'])

In [None]:
# check data manually: entire/partial results must align one-to-one by session index
assert len(entire_session_triples_processed) == len(partial_session_triples_processed)
print(len(entire_session_triples_processed), len(partial_session_triples_processed))

# Spot-check up to 5 random sessions (random.sample raises ValueError if k > population).
num_samples = min(5, len(entire_session_triples_processed))
random_indices = random.sample(range(len(entire_session_triples_processed)), num_samples)
for index in random_indices:
    print('='*40)
    print(f'\n[Index: {index}]')
    print('\n- Entire:')
    print(entire_session_triples_processed[index])
    print('\n- Partial:')
    print(partial_session_triples_processed[index])

In [None]:
# Package the entire- and partial-session results separately for saving.
entire_result = {
    'entire_session_triples_raw': entire_session_triples_raw,
    'entire_session_triples_processed': entire_session_triples_processed,
}
partial_result = {
    'partial_session_triples_raw': partial_session_triples_raw,
    'partial_session_triples_processed': partial_session_triples_processed,
}

print(entire_result.keys())
print(partial_result.keys())

In [None]:
#version=1
#version=2 # 20250507
#version = 3 # 20250529, temperature = 0
#version = 4 # 20250529, temperature = 0.75
version = 5 # 20250604, temp = 0.7, top_p = 0.95


def _session_save_path(session_kind):
    """Create (if needed) the save directory for ``session_kind`` and return the pickle path.

    Layout: <triple_save_dir>/<session_kind>/trial<TRIAL_IDX>/<api_model>/v<version>.pickle
    """
    save_dir = os.path.join(triple_save_dir, session_kind, f'trial{TRIAL_IDX}', f'{agent.api_model}')
    os.makedirs(save_dir, exist_ok=True)
    return os.path.join(save_dir, f'v{version}.pickle')


entire_save_fname = _session_save_path('entire-session')
save_pickle(entire_result, entire_save_fname)

partial_save_fname = _session_save_path('partial-session')
save_pickle(partial_result, partial_save_fname)

In [None]:
# Display one randomly chosen processed entry (as a one-element list) for a quick look.
processed_entries = entire_result['entire_session_triples_processed']
random.sample(processed_entries, 1)