In [None]:
import os, sys
from dotenv import load_dotenv
# Project root is taken to be the parent of the notebook's working directory;
# it is put on sys.path so that `src.*` modules can be imported below.
root_dir = os.path.dirname(os.path.abspath(os.getcwd()))
# Guard the append so that re-running this cell does not pile up duplicate entries.
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Load environment variables via python-dotenv; presumably this is where
# ANTHROPIC_API_KEY (read with os.getenv further down) comes from - confirm a .env file exists.
load_dotenv()

In [None]:
import pickle
from tqdm import tqdm
import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

import random
# Seed the stdlib RNG so the random.sample(...) previews later in the notebook are repeatable.
random.seed(42)

In [None]:
from src.utils import load_pickle, save_pickle

For Claude message batches, refer to https://docs.anthropic.com/ko/api/messages-batch-examples

In [None]:
# Global configuration for this extraction run.

# Location of the input pickle: data/simul-log/<SCRIPT>/processed/<DATA_TYPE>/trial<TRIAL_IDX>.pickle
SCRIPT: str = 'friends'
DATA_TYPE: str = 'normal'
TRIAL_IDX: int = 3

# Sampling / generation parameters forwarded to every batch request.
TEMPERATURE: int = 0
TOP_P: int = 1
NUM_MAX_TOKENS: int = 1000
SYSTEM_PROMPT: str = 'You are a helpful assistant'

# Short alias that anthropic_bot resolves to a dated model id (claude-opus-4-20250514).
MODEL: str = 'claude-opus-4'

In [None]:
# load data
# Prompt template; its '{input_text}' placeholder is substituted per dialogue when requests are built.
prompt_fname = os.path.join(root_dir, 'prompt', 'triple-extraction.v9-2.txt') # 20250824
# Use a context manager so the file handle is closed deterministically, and pin the
# encoding so the template is decoded the same way regardless of the machine's locale.
with open(prompt_fname, 'r', encoding = 'utf-8') as prompt_file:
    prompt = prompt_file.read().strip()

# Pre-processed simulation log for the configured script / data type / trial.
# The cells below read its 'session_dialogue', 'partial_session_dialogue' and 'date' entries.
data_fname = os.path.join(root_dir, 'data', 'simul-log', SCRIPT, 'processed', DATA_TYPE, f'trial{TRIAL_IDX}.pickle')
data = load_pickle(data_fname)

In [None]:
class anthropic_bot():
    """Helper around the Anthropic SDK client for submitting and polling Message Batches.

    Attributes:
        api_model: Dated API model identifier resolved from the ``model`` argument.
        client: ``anthropic.Anthropic`` client created with the given API key.
    """

    # Short aliases accepted by __init__ -> dated model identifiers sent to the API.
    # The 3.x identifiers use dashes ('claude-3-7-...'), not dots.
    _MODEL_ALIASES = {
        'claude-sonnet-4': 'claude-sonnet-4-20250514',
        'claude-sonnet-3.7': 'claude-3-7-sonnet-20250219',
        'claude-sonnet-3.5': 'claude-3-5-sonnet-20241022',
        'claude-opus-4': 'claude-opus-4-20250514',
    }

    # data_type -> (key of `data` that holds the dialogues, tqdm progress description)
    _DIALOGUE_SOURCES = {
        'partial': ('partial_session_dialogue', 'Extracting partial dialogues'),
        'entire': ('session_dialogue', 'Extracting entire dialogues'),
    }

    def __init__(self, model, api_key):
        """Resolve the model identifier and create the API client.

        Args:
            model: One of the aliases in ``_MODEL_ALIASES`` or one of the dated
                identifiers those aliases map to.
            api_key: Anthropic API key handed to ``anthropic.Anthropic``.

        Raises:
            ValueError: If ``model`` is not a supported alias or identifier.
        """
        if model in self._MODEL_ALIASES:
            self.api_model = self._MODEL_ALIASES[model]
        elif model in self._MODEL_ALIASES.values():
            self.api_model = model
        else:
            # Fail here instead of leaving `api_model` unset and crashing later
            # with an AttributeError when the first request is built.
            raise ValueError(
                f'Unsupported model {model!r}; expected one of {sorted(self._MODEL_ALIASES)}'
            )
        self.client = anthropic.Anthropic(api_key = api_key)

    def create_batch_message(self, data, max_tokens, temperature, system_prompt, top_p, data_type, prompt_template = None):
        """Submit one Message Batch containing one request per session dialogue.

        Each request gets ``custom_id = 'session_<n>_<data_type>'`` with ``n`` starting at 1,
        so results can be matched back to their session afterwards.

        Args:
            data: Mapping with a 'partial_session_dialogue' list (for ``data_type='partial'``)
                or a 'session_dialogue' list (for ``data_type='entire'``).
            max_tokens: ``max_tokens`` for every request.
            temperature: ``temperature`` for every request.
            system_prompt: ``system`` prompt for every request.
            top_p: ``top_p`` for every request.
            data_type: 'partial' or 'entire'.
            prompt_template: Template containing the '{input_text}' placeholder. Defaults to
                the notebook-level ``prompt`` variable when omitted.

        Returns:
            The batch object returned by ``client.messages.batches.create``.

        Raises:
            ValueError: If ``data_type`` is unknown or there are no dialogues to send.
        """
        if data_type not in self._DIALOGUE_SOURCES:
            raise ValueError(
                f'Unknown data_type {data_type!r}; expected one of {sorted(self._DIALOGUE_SOURCES)}'
            )
        dialogue_key, progress_desc = self._DIALOGUE_SOURCES[data_type]

        if prompt_template is None:
            prompt_template = prompt # notebook-level template loaded in the data cell

        # Iterate over the list that is actually sent, so the number of requests
        # always matches the number of dialogues of the requested type.
        dialogues = data[dialogue_key]
        if len(dialogues) == 0:
            raise ValueError(f'data[{dialogue_key!r}] is empty; nothing to submit')

        requests = []
        for session_idx, dialogue in enumerate(tqdm(dialogues, desc = progress_desc)):
            requests.append(
                Request(
                    custom_id = f'session_{session_idx+1}_{data_type}',
                    params = MessageCreateParamsNonStreaming(
                        model = self.api_model,
                        max_tokens = max_tokens,
                        temperature = temperature,
                        system = system_prompt,
                        top_p = top_p,
                        messages = [{'role': 'user',
                                     'content': prompt_template.replace('{input_text}', dialogue)}]
                    )
                )
            )
        return self.client.messages.batches.create(requests = requests)

    def retrieve_batch_status(self, batch_id):
        """Return the batch object for ``batch_id`` as reported by the API."""
        return self.client.messages.batches.retrieve(batch_id)

In [None]:
# Build the batch helper for the configured model; the key is read from the environment
# (None is passed through when ANTHROPIC_API_KEY is not set).
bot = anthropic_bot(
    model = MODEL,
    api_key = os.environ.get('ANTHROPIC_API_KEY'),
)

In [None]:
# Submit the partial-session dialogues as one Message Batch.
# Note: every execution of this cell calls batches.create again, i.e. submits a new batch.
partial_batch = bot.create_batch_message(
    data = data,
    max_tokens = NUM_MAX_TOKENS,
    temperature = TEMPERATURE,
    system_prompt = SYSTEM_PROMPT,
    top_p = TOP_P,
    data_type = 'partial',
)
print(partial_batch)

In [None]:
#entire_batch = bot.create_batch_message(data, NUM_MAX_TOKENS, TEMPERATURE, SYSTEM_PROMPT, TOP_P, 'entire')
#print(entire_batch)

In [None]:
# Ids of the already-submitted partial-session batches, keyed by trial index.
# Selecting by TRIAL_IDX keeps the batch id in step with the data loaded above,
# instead of relying on commenting lines in and out by hand.
PARTIAL_BATCH_IDS_BY_TRIAL = {
    1: 'msgbatch_01LkRRRcj2s9nYTjfJ35CYFy',
    2: 'msgbatch_016oQ9rQHD1RbJ1Z2UdatVU5',
    3: 'msgbatch_01RyKA5NVrVBqeyhdcXGMgWk',
}

if TRIAL_IDX not in PARTIAL_BATCH_IDS_BY_TRIAL:
    raise KeyError(
        f'No partial batch id recorded for trial {TRIAL_IDX}; '
        'add it to PARTIAL_BATCH_IDS_BY_TRIAL after submitting the batch.'
    )

batch_ids = dict()
batch_ids['partial'] = PARTIAL_BATCH_IDS_BY_TRIAL[TRIAL_IDX]

# Uncomment once an entire-session batch exists for this trial.
#batch_ids['entire'] = 'msgbatch_01RU2ftxkLnd7jpLXZRH3XUd'

In [None]:
# Fetch the current status of the partial-session batch registered in batch_ids.
bot.retrieve_batch_status(batch_ids['partial'])

In [None]:
# Fetch the current status of the entire-session batch, if one has been registered.
# batch_ids only has an 'entire' entry when that line is enabled in the cell above,
# so look it up defensively instead of failing with a KeyError.
entire_batch_id = batch_ids.get('entire')
if entire_batch_id is None:
    print("No 'entire' batch id registered in batch_ids; skipping status check.")
else:
    display(bot.retrieve_batch_status(entire_batch_id))

---

After all batches are completed

Parse raw data

In [None]:
all_result = {}
# triples from partial-session dialogues

full_batch_id = batch_ids['partial'] # partial, entire

num_partial_sessions = len(data['partial_session_dialogue'])
assert num_partial_sessions == len(data['date'])

# --- 1. Download the raw model output for every session -----------------------------
# Batch results are not guaranteed to arrive in request order, so each text is stored
# at the session index encoded in its custom_id ('session_<n>_partial', n starting at 1).
# This keeps raw[i] aligned with data['date'][i]. The results stream is read only once.
all_partial_session_triples_raw = [None] * num_partial_sessions
failed_requests = []

for result in bot.client.messages.batches.results(full_batch_id):
    session_idx = int(result.custom_id.split('_')[1]) - 1
    if not 0 <= session_idx < num_partial_sessions:
        raise ValueError(f'unexpected custom_id {result.custom_id!r} for {num_partial_sessions} sessions')
    if result.result.type != 'succeeded':
        # errored / canceled / expired entries carry no message to read
        failed_requests.append((result.custom_id, result.result.type))
        continue
    raw_text = result.result.message.content[0].text
    print(result.custom_id, raw_text)
    all_partial_session_triples_raw[session_idx] = raw_text

if failed_requests:
    raise RuntimeError(f'batch {full_batch_id} contains unsuccessful requests: {failed_requests}')

missing_sessions = [idx + 1 for idx, text in enumerate(all_partial_session_triples_raw) if text is None]
assert not missing_sessions, f'no batch result for sessions: {missing_sessions}'

all_result['partial_session_triples_raw'] = all_partial_session_triples_raw

# --- 2. Parse '[{head, relation, tail}, ...]' out of each raw text -------------------
# NOTE(review): components are split on ', ' and on ':', so a value that itself contains
# ', ' makes the candidate fail the 3-element check (dropped) and a value containing ':'
# is cut to the text after its last ':' - confirm the prompt's output format rules this out.
all_partial_session_triples_processed = []
all_partial_session_quadruples_processed = []

for session_idx, raw_text in enumerate(tqdm(all_partial_session_triples_raw, desc = 'parsing raw triple text from each session partial dialolgue')):
    current_session_triples = []
    current_session_quadruples = []
    if '[' in raw_text and ']' in raw_text:
        # text between the first '[' and the first ']' that follows it
        list_body = raw_text.split('[')[1].split(']')[0]
        for triple_candidate in list_body.split('}'):
            if '{' not in triple_candidate:
                continue
            components = triple_candidate.split('{')[1].split(', ')
            if len(components) != 3: # triple should consist of 3 elements
                continue
            # keep only the value part of 'key: value' and drop surrounding quotes
            head, relation, tail = [
                component.split(':')[-1].strip().replace('"', '').replace("'", '')
                for component in components
            ]
            current_quadruple = {
                'head': head.capitalize(),
                'relation': relation,
                'tail': tail,
                'start_date': data['date'][session_idx],
            }
            current_session_triples.append(
                f'{current_quadruple["head"]}, {current_quadruple["relation"]}, {current_quadruple["tail"]}'
            )
            current_session_quadruples.append(current_quadruple)
    all_partial_session_triples_processed.append(current_session_triples)
    all_partial_session_quadruples_processed.append(current_session_quadruples)
print(len(all_partial_session_triples_processed), len(all_partial_session_quadruples_processed), len(data['date']))

# 'partial_session_triples_processed'    <- built from the batch API results
# 'partial_session_quadruples_processed' <- same triples with the session's timestamp attached
all_result['partial_session_triples_processed'] = all_partial_session_triples_processed
all_result['partial_session_quadruples_processed'] = all_partial_session_quadruples_processed

# Preview a couple of random sessions (fewer when there are fewer than two sessions).
preview_size = min(2, len(all_partial_session_triples_processed))
display(random.sample(all_result['partial_session_triples_processed'], preview_size))
display(random.sample(all_result['partial_session_quadruples_processed'], preview_size))

In [None]:
# triples from entire-session dialogues

# batch_ids only has an 'entire' entry when an entire-session batch was registered,
# so skip this cell cleanly instead of failing with a KeyError.
full_batch_id = batch_ids.get('entire') # partial, entire

if full_batch_id is None:
    print("No 'entire' batch id registered in batch_ids; skipping entire-session parsing.")
else:
    num_entire_sessions = len(data['session_dialogue'])
    assert num_entire_sessions == len(data['date'])

    # --- 1. Download the raw model output for every session -------------------------
    # Batch results are not guaranteed to arrive in request order, so each text is stored
    # at the session index encoded in its custom_id ('session_<n>_entire', n starting at 1).
    # This keeps raw[i] aligned with data['date'][i]. The results stream is read only once.
    all_entire_session_triples_raw = [None] * num_entire_sessions
    failed_requests = []

    for result in bot.client.messages.batches.results(full_batch_id):
        session_idx = int(result.custom_id.split('_')[1]) - 1
        if not 0 <= session_idx < num_entire_sessions:
            raise ValueError(f'unexpected custom_id {result.custom_id!r} for {num_entire_sessions} sessions')
        if result.result.type != 'succeeded':
            # errored / canceled / expired entries carry no message to read
            failed_requests.append((result.custom_id, result.result.type))
            continue
        raw_text = result.result.message.content[0].text
        print(result.custom_id, raw_text)
        all_entire_session_triples_raw[session_idx] = raw_text

    if failed_requests:
        raise RuntimeError(f'batch {full_batch_id} contains unsuccessful requests: {failed_requests}')

    missing_sessions = [idx + 1 for idx, text in enumerate(all_entire_session_triples_raw) if text is None]
    assert not missing_sessions, f'no batch result for sessions: {missing_sessions}'

    all_result['entire_session_triples_raw'] = all_entire_session_triples_raw

    # --- 2. Parse '[{head, relation, tail}, ...]' out of each raw text ---------------
    # NOTE(review): components are split on ', ' and on ':', so a value that itself contains
    # ', ' makes the candidate fail the 3-element check (dropped) and a value containing ':'
    # is cut to the text after its last ':' - confirm the prompt's output format rules this out.
    all_entire_session_triples_processed = []
    all_entire_session_quadruples_processed = []

    for session_idx, raw_text in enumerate(tqdm(all_entire_session_triples_raw, desc = 'parsing raw triple text from each session entire dialolgue')):
        current_session_triples = []
        current_session_quadruples = []
        if '[' in raw_text and ']' in raw_text:
            # text between the first '[' and the first ']' that follows it
            list_body = raw_text.split('[')[1].split(']')[0]
            for triple_candidate in list_body.split('}'):
                if '{' not in triple_candidate:
                    continue
                components = triple_candidate.split('{')[1].split(', ')
                if len(components) != 3: # triple should consist of 3 elements
                    continue
                # keep only the value part of 'key: value' and drop surrounding quotes
                head, relation, tail = [
                    component.split(':')[-1].strip().replace('"', '').replace("'", '')
                    for component in components
                ]
                current_quadruple = {
                    'head': head.capitalize(),
                    'relation': relation,
                    'tail': tail,
                    'start_date': data['date'][session_idx],
                }
                current_session_triples.append(
                    f'{current_quadruple["head"]}, {current_quadruple["relation"]}, {current_quadruple["tail"]}'
                )
                current_session_quadruples.append(current_quadruple)
        all_entire_session_triples_processed.append(current_session_triples)
        all_entire_session_quadruples_processed.append(current_session_quadruples)
    print(len(all_entire_session_triples_processed), len(all_entire_session_quadruples_processed), len(data['date']))

    # 'entire_session_triples_processed'    <- built from the batch API results
    # 'entire_session_quadruples_processed' <- same triples with the session's timestamp attached
    all_result['entire_session_triples_processed'] = all_entire_session_triples_processed
    all_result['entire_session_quadruples_processed'] = all_entire_session_quadruples_processed

    # Preview a couple of random sessions (fewer when there are fewer than two sessions).
    preview_size = min(2, len(all_entire_session_triples_processed))
    display(random.sample(all_result['entire_session_triples_processed'], preview_size))
    display(random.sample(all_result['entire_session_quadruples_processed'], preview_size))

In [None]:
# Normalise the casing of head and tail in every entire-session quadruple, in place.
# str.capitalize() upper-cases the first character and lower-cases the rest.
# The quadruple dicts are mutated directly, so no write-back into all_result is needed.
# When no entire-session results were parsed the key is absent and the loop is a no-op.
for current_session_quadruples in tqdm(all_result.get('entire_session_quadruples_processed', []), desc = 'capitalizing head of each quadruple'):
    for quadruple in current_session_quadruples:
        quadruple['head'] = quadruple['head'].capitalize()
        quadruple['tail'] = quadruple['tail'].capitalize()

In [None]:
# Preview up to two randomly chosen sessions of entire-session quadruples.
# Tolerates missing entire-session results and fewer than two sessions (shows [] / one item).
entire_quadruples_for_preview = all_result.get('entire_session_quadruples_processed', [])
random.sample(entire_quadruples_for_preview, min(2, len(entire_quadruples_for_preview)))

In [None]:
# 'entire_session_triples_processed' <- rebuilt from 'entire_session_quadruples_processed'
# Each triple becomes a [head, relation, tail] list with head and tail capitalized.
# Quadruples that do not have exactly 4 fields are treated as malformed and removed.

if 'entire_session_quadruples_processed' not in all_result:
    # Nothing was parsed for entire-session dialogues (no 'entire' batch registered).
    print("No entire-session quadruples in all_result; skipping triple rebuild.")
else:
    all_entire_session_triples_processed = []

    for current_session_quadruples in tqdm(all_result['entire_session_quadruples_processed'], desc = 'creating processed triples out of processed quadruples'):
        # Filter in place (slice assignment) so the list stored in all_result is the one updated.
        current_session_quadruples[:] = [
            quadruple for quadruple in current_session_quadruples if len(quadruple) == 4
        ]
        all_entire_session_triples_processed.append([
            [quadruple['head'].capitalize(), quadruple['relation'], quadruple['tail'].capitalize()]
            for quadruple in current_session_quadruples
        ])

    all_result['entire_session_triples_processed'] = all_entire_session_triples_processed

    print(len(all_result['entire_session_triples_processed'])) # 788

In [None]:
# assert that every session has equal number of triple and quadruple
partial_raw = all_result['partial_session_triples_raw']
partial_triples = all_result['partial_session_triples_processed']
partial_quadruples = all_result['partial_session_quadruples_processed']

# List-level check, done once: one entry per session in each of the three lists.
assert len(partial_raw) == len(partial_quadruples)
assert len(partial_triples) == len(partial_quadruples)

# Per-session check: the number of triples must match the number of quadruples.
for i, (session_triples, session_quadruples) in enumerate(tqdm(zip(partial_triples, partial_quadruples), total = len(partial_triples), desc = 'asserting that num_triples == num_quadruples in all sessions')):
    assert len(session_triples) == len(session_quadruples), (
        f'session {i+1}: {len(session_triples)} triples vs {len(session_quadruples)} quadruples'
    )

In [None]:
script = 'friends'

# Dated model id used as the output directory name.
# Other ids used so far: 'claude-3.7-sonnet-20250219', 'claude-sonnet-4-20250514'
model_name = 'claude-opus-4-20250514'

# Which extraction run with this model the saved triples belong to (e.g. 2 for the second run).
version = 1

# Single definition of the output root so the location is changed in one place.
# NOTE(review): this is a machine-specific absolute path; it looks like it could be
# os.path.join(root_dir, 'data', 'triple') - confirm before switching.
TRIPLE_ROOT = '/home/edlab/jwyang/research/dialsim-agent/data/triple'

# Session kinds to save. Add 'entire' to also save the entire-session triples
# (requires the 'entire_session_*' entries to be present in all_result).
SESSION_KINDS_TO_SAVE = ('partial',)

for session_kind in SESSION_KINDS_TO_SAVE:
    result_to_save = {
        f'{session_kind}_session_triples_raw': all_result[f'{session_kind}_session_triples_raw'],
        f'{session_kind}_session_triples_processed': all_result[f'{session_kind}_session_triples_processed'],
    }
    save_dir = os.path.join(TRIPLE_ROOT, script, 'normal', f'{session_kind}-session', f'trial{TRIAL_IDX}', model_name)
    # exist_ok avoids the separate exists() check and the race between check and creation.
    os.makedirs(save_dir, exist_ok = True)
    save_fname = os.path.join(save_dir, f'v{version}.pickle')
    save_pickle(result_to_save, save_fname)
    # print saved filename
    print(save_fname)