In [1]:
from my_funcs import (
    default_transformation, read_json_from_data_dir,
    normalize, embed_query_retrieve_examples, iterate_nearest_dialogs,
    make_two_type_msg, get_python_chat_prompt, 
    parse_python_completion, update_dialogue_state, 
    compute_acc, calculate_token_f1, evaluate,
    DataOntologyNormalizer, Ontology,
    copy, defaultdict, random,
    openai, tiktoken, SentenceTransformer
)

import os
import json
import numpy as np
# from rank_bm25 import BM25Okapi
from refpydst.prompt_formats.python.completion_parser import *
from refpydst.prompt_formats.python.completion_parser import parse_python_modified
from refpydst.evaluate_metrics import evaluate
from vllm import LLM, SamplingParams
import torch
from typing import Tuple, List
from sklearn.metrics.pairwise import cosine_similarity

from my_openai_key import get_openai_key
openai.api_key = get_openai_key()

  from tqdm.autonotebook import tqdm, trange


In [5]:
def embed_query_retrieve_examples(embedder, example_pool, query_data, emb_keys, emb_values, label_to_idx, num_retrieved_examples=10):
    # Embed Query based on its turn and previous (predicted) slot values 
    with torch.no_grad():
        query_string = default_transformation(query_data)
        query_emb = embedder.encode(query_string, convert_to_numpy=True).reshape(1, -1)
    
    exmple_generator = (
            (example, score)
            for turn_label, score in iterate_nearest_dialogs(query_emb, emb_keys, emb_values, k=5)
                for example in example_pool
                    if example['ID'] == turn_label.split('_')[0] and example['turn_id'] == int(turn_label.split('_')[-1])
        )

    all_considered_examples: List[Tuple[Turn, float]] = \
        [turn_and_score for _, turn_and_score in zip(range(100), exmple_generator)]
    all_embeddings = np.asarray([
        emb_values[label_to_idx[f"{turn['ID']}_turn_{turn['turn_id']}"]]
        for turn, score in all_considered_examples
    ])
    if len(all_considered_examples) == 0:
        raise ValueError("No examples found in the retriever index.")

    result: List[int] = []
    example_scores = np.asarray([1 - 0.5*(score**2) for turn, score in all_considered_examples])
    assert np.all(np.diff(example_scores) <= 0)  # verifies they are decreasing as expected
    while len(result) < num_retrieved_examples:
        best_idx: int = np.argmax(example_scores).item()
        example_scores[best_idx] = -np.inf
        result.append(best_idx)
        best_emb = all_embeddings[best_idx]
        discount = 0.2 * cosine_similarity(best_emb[None, :], all_embeddings).squeeze(0)
        example_scores = example_scores - discount

    retrieved_exampels = [all_considered_examples[i][0] for i in result][::-1]
    retrieved_exampels = [e for e in retrieved_exampels if e['ID'] != query_data['ID']][-10:]
    return retrieved_exampels

def make_dict(query_data):
    li = []    
    for generated_example in query_data['generated']:
        tmp = {}
        tmp['dialog'] = {}

        if '[context]' in generated_example:
            contxt_end = generated_example.index('[context]') + len('[context]')
        else:
            contxt_end = 0
    
        if '[utterance_sys]' in generated_example:
            sys_start = generated_example.index('[utterance_sys]')
            sys_end = sys_start + len('[utterance_sys]')
        else:
            sys_start = contxt_end+1
            sys_end = sys_start
    
        if '[utterance_usr]' in generated_example:
            usr_start = generated_example.index('[utterance_usr]')
            usr_end = usr_start + len('[utterance_usr]')
        else:
            usr_start = sys_end+1
            usr_end = usr_start
        
        if '[belief_state]' in generated_example:
            belief_start = generated_example.index('[belief_state]')
            belief_end = belief_start + len('[belief_state]')
        else:
            belief_start = None
            belief_end = None

        tmp['last_slot_values'] = dict(eval('{'+generated_example[contxt_end:sys_start]+'}'))        
        tmp['dialog']['sys'] = [generated_example[sys_end:usr_start]]
        tmp['dialog']['usr'] = [generated_example[usr_end:belief_start]]

        tmp['slot_values'] = {}
        bs = generated_example[belief_end:].strip() if belief_end is not None else ''
        for slot_val in bs.split(','):
            s_v_list = slot_val.split(':')
            if len(s_v_list) <= 1:
                continue
            if len(s_v_list) > 2:
                s_v_list[1] = ':'.join(s_v_list[1:])
            slot, val = s_v_list[0].strip(), s_v_list[1].strip()
            tmp['slot_values'][slot] = val
        li.append(tmp)
    return li

query_data_path = 'jun/inference/mw24_20p_dev_wrong_67.json'

with open(query_data_path, 'r') as f:
    query_dataset = json.load(f)
new_data = []
for query_data in query_dataset:
    new_data.append(make_dict(query_data))

SyntaxError: invalid syntax. Perhaps you forgot a comma? (<string>, line 1)

In [None]:
sample_pool_path = 'data/mw21_100p_train.json'
query_data_path = 'jun/inference/mw24_20p_dev_wrong_67.json'
engine = "/data1/home/haesungpyun/models/Meta-Llama-3-70B-Instruct-GPTQ"
quantization = 'gptq'

with open(query_data_path, 'r') as f:
    query_dataset = json.load(f)
with open(sample_pool_path, 'r') as f:
    sample_pool = json.load(f)

# Register all dialogues from the train dataset to example pool and Get all the unique dialogue ids in example pool
example_pool = []
selected_dialog_id_from_split = set()
for dataset in [sample_pool]:
    example_pool += dataset
    selected_dialog_id_from_split.update([dial['ID'] for dial in dataset])

# Load the all train data index
retriever_full_path = '/home/haesungpyun/my_refpydst/outputs/runs/retriever/mw21_100p_train/referred_states'
search_index_filename = os.path.join(retriever_full_path, "train_index.npy")
search_embeddings = np.load(search_index_filename, allow_pickle=True).item()    # {'MUL0720.json_turn_10': np.array([0.1, 0.2, ...]), ...}

# Keep only embeddings for the selected dialogues in split version
emb_dict = {k: v for k, v in search_embeddings.items() if k.split('_')[0] in selected_dialog_id_from_split}
emb_keys = list(emb_dict.keys())
emb_dim = emb_dict[emb_keys[0]].shape[-1]

# Convert embeddings to array and Normalize them
emb_values = np.zeros((len(emb_keys), emb_dim))
for i, k in enumerate(emb_keys):
    emb_values[i] = emb_dict[k]
emb_values = normalize(emb_values)

# Create a label to index mapping  {'MUL0720.json_turn_10': 1, ...} 
label_to_idx = {k: i for i, k in enumerate(emb_keys)}

# Load the model for embed query
embedder = SentenceTransformer(retriever_full_path)

# Tokenizer
# encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')


model = LLM(model="/data1/home/haesungpyun/models/Meta-Llama-3-70B-Instruct-GPTQ", quantization='gptq', enforce_eager=False)
tokenizer = model.get_tokenizer()

terminators =  [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")    
]

stop_sequences = ['--', '\n', ';', '#']

In [None]:
# Query에 대해 retrieve & generate & evaluate
total_log = []
retrieving_samples = True
num_retrieved_examples = 100
n_smapled_examples = 10

random.seed(42)
# Randomly select a query data and Retrieve examples from example pool (train data)
# query_data = query_dataset[random.randint(0, len(dev_dataset))]

prev_data = None
for query_data in query_dataset:
    query_data['pred_prior_context'] = prev_data['pred_slot_values'] if prev_data else {}
    modified_data = copy.deepcopy(query_data)
    modified_data['last_slot_values'] = query_data.get('pred_prior_context', {})

    retrieved_examples = embed_query_retrieve_examples(
        embedder, example_pool, modified_data, 
        emb_keys, emb_values, label_to_idx, num_retrieved_examples=num_retrieved_examples)
    msg_chat, gold_python = get_python_chat_prompt(query_data, retrieved_examples) 

    generated_examples = make_dict(query_data)
    gen_msg_chat, _ = get_python_chat_prompt(query_data, generated_examples)   
    # msg_chat_usr_last, msg_one_prompt = make_two_type_msg(msg_chat)
    # raise ValueError    
    log = defaultdict(dict)
    log['ID-turn-id']= f"{query_data['ID']}-{query_data['turn_id']}"
    log['last_slot_values'] = query_data['last_slot_values'] 
    log['turn_slot_values'] = query_data['turn_slot_values']
    log['slot_values']= query_data['slot_values']
    log['gold-python'] = gold_python

    examples_list = []
    for idx, example in enumerate(retrieved_examples):
        tmp = {}
        tmp['ID_turn-id'] = f"{example['ID']}-{example['turn_id']}"
        tmp['last_slot_values'], tmp['turn_slot_values'], tmp['slot_values'] = \
            example['last_slot_values'], example['turn_slot_values'], example['slot_values']
        examples_list.append(tmp)
    log["retrieve_example"] = examples_list

    samplig_params = SamplingParams(
        n=1, best_of=1, max_tokens=120, 
        temperature=0, stop=stop_sequences,
        stop_token_ids=terminators)
    msg_chat_ids = tokenizer.apply_chat_template(
        msg_chat, add_generation_prompt=True, return_tensors='pt')
    gen_msg_chat = tokenizer.apply_chat_template(
        gen_msg_chat, add_generation_prompt=True, return_tensors='pt')
    
    prompts = [tokenizer.batch_decode(ids, skip_special_tokens=False)[0] for ids in [msg_chat_ids, gen_msg_chat]]
    result = model.generate(prompts, sampling_params=samplig_params)
    
    completions = [{output.outputs[0].text: 1} for output in result]
    best_completion = [comp.strip().replace('agent.state.', '') for dic in completions for comp,_ in dic.items()]

    predicted_prior_context = query_data.get('pred_prior_context', query_data['last_slot_values'])
    batch_pred_s_v = [parse_python_completion(comp, predicted_prior_context) for comp in best_completion]
    batch_pred_turn_s_v = [parse_state_change(comp, predicted_prior_context) for comp in best_completion]

################################################################################
    pred_prev = prev_data.get('pred_slot_values', {}) if prev_data else {}
    pred = update_dialogue_state(pred_prev, batch_pred_turn_s_v[0])

    query_data['og_completion'] = best_completion[0]
    query_data['og_pred_turn_slot_values'] = batch_pred_turn_s_v[0]
    query_data['og_pred_slot_values'] = batch_pred_s_v[0]
    
    log['og_completion'] = best_completion[0]
    log['og_pred_turn_slot_values'] = batch_pred_turn_s_v[0]
    log['og_pred_slot_values'] = batch_pred_s_v[0]
    
    this_jga, this_acc, this_f1 = evaluate(pred, query_data['slot_values'])
    delta_jga, _, _ = evaluate(batch_pred_turn_s_v[0], query_data['turn_slot_values'])
    log['og_full_jga'] = this_jga
    log['og_delta_jga'] = delta_jga

###############################################################################
    pred_prev = prev_data.get('pred_slot_values', {}) if prev_data else {}
    pred = update_dialogue_state(pred_prev, batch_pred_turn_s_v[1])

    query_data['completion'] = best_completion[1]
    query_data['pred_turn_slot_values'] = batch_pred_turn_s_v[1]
    query_data['pred_slot_values'] = batch_pred_s_v[1]
    
    log['completion'] = best_completion[1]
    log['pred_turn_slot_values'] = batch_pred_turn_s_v[1]
    log['pred_slot_values'] = batch_pred_s_v[1]
    
    this_jga, this_acc, this_f1 = evaluate(pred, query_data['slot_values'])
    delta_jga, _, _ = evaluate(batch_pred_turn_s_v[1], query_data['turn_slot_values'])
    log['full_jga'] = this_jga
    log['delta_jga'] = delta_jga

    total_log.append(log)

    prev_data = query_data


In [None]:
# Query에 대해 retrieve & generate & evaluate
total_log = []
retrieving_samples = True
num_retrieved_examples = 100
n_smapled_examples = 10

random.seed(42)
# Randomly select a query data and Retrieve examples from example pool (train data)
# query_data = query_dataset[random.randint(0, len(dev_dataset))]

prev_data = None
for query_data in query_dataset:
    query_data['pred_prior_context'] = prev_data['pred_slot_values'] if prev_data else {}
    modified_data = copy.deepcopy(query_data)
    modified_data['last_slot_values'] = query_data.get('pred_prior_context', {})

    retrieved_examples = embed_query_retrieve_examples(
        embedder, example_pool, modified_data, 
        emb_keys, emb_values, label_to_idx, num_retrieved_examples=num_retrieved_examples)
    msg_chat, gold_python = get_python_chat_prompt(query_data, retrieved_examples) 

    generated_examples = make_dict(query_data)
    gen_msg_chat, _ = get_python_chat_prompt(query_data, generated_examples)   
    # msg_chat_usr_last, msg_one_prompt = make_two_type_msg(msg_chat)
    # raise ValueError    
    log = defaultdict(dict)
    log['ID-turn-id']= f"{query_data['ID']}-{query_data['turn_id']}"
    log['last_slot_values'] = query_data['last_slot_values'] 
    log['turn_slot_values'] = query_data['turn_slot_values']
    log['slot_values']= query_data['slot_values']
    log['gold-python'] = gold_python

    examples_list = []
    for idx, example in enumerate(retrieved_examples):
        tmp = {}
        tmp['ID_turn-id'] = f"{example['ID']}-{example['turn_id']}"
        tmp['last_slot_values'], tmp['turn_slot_values'], tmp['slot_values'] = \
            example['last_slot_values'], example['turn_slot_values'], example['slot_values']
        examples_list.append(tmp)
    log["retrieve_example"] = examples_list

    samplig_params = SamplingParams(
        n=1, best_of=1, max_tokens=120, 
        temperature=0, stop=stop_sequences,
        stop_token_ids=terminators)
    msg_chat_ids = tokenizer.apply_chat_template(
        msg_chat, add_generation_prompt=True, return_tensors='pt')
    gen_msg_chat = tokenizer.apply_chat_template(
        gen_msg_chat, add_generation_prompt=True, return_tensors='pt')
    
    prompts = [tokenizer.batch_decode(ids, skip_special_tokens=False)[0] for ids in [msg_chat_ids, gen_msg_chat]]
    result = model.generate(prompts, sampling_params=samplig_params)
    
    completions = [{output.outputs[0].text: 1} for output in result]
    best_completion = [comp.strip().replace('agent.state.', '') for dic in completions for comp,_ in dic.items()]

    predicted_prior_context = query_data.get('pred_prior_context', query_data['last_slot_values'])
    batch_pred_s_v = [parse_python_completion(comp, predicted_prior_context) for comp in best_completion]
    batch_pred_turn_s_v = [parse_state_change(comp, predicted_prior_context) for comp in best_completion]

################################################################################
    pred_prev = prev_data.get('pred_slot_values', {}) if prev_data else {}
    pred = update_dialogue_state(pred_prev, batch_pred_turn_s_v[0])

    query_data['og_completion'] = best_completion[0]
    query_data['og_pred_turn_slot_values'] = batch_pred_turn_s_v[0]
    query_data['og_pred_slot_values'] = batch_pred_s_v[0]
    
    log['og_completion'] = best_completion[0]
    log['og_pred_turn_slot_values'] = batch_pred_turn_s_v[0]
    log['og_pred_slot_values'] = batch_pred_s_v[0]
    
    this_jga, this_acc, this_f1 = evaluate(pred, query_data['slot_values'])
    delta_jga, _, _ = evaluate(batch_pred_turn_s_v[0], query_data['turn_slot_values'])
    log['og_full_jga'] = this_jga
    log['og_delta_jga'] = delta_jga

###############################################################################
    pred_prev = prev_data.get('pred_slot_values', {}) if prev_data else {}
    pred = update_dialogue_state(pred_prev, batch_pred_turn_s_v[1])

    query_data['completion'] = best_completion[1]
    query_data['pred_turn_slot_values'] = batch_pred_turn_s_v[1]
    query_data['pred_slot_values'] = batch_pred_s_v[1]
    
    log['completion'] = best_completion[1]
    log['pred_turn_slot_values'] = batch_pred_turn_s_v[1]
    log['pred_slot_values'] = batch_pred_s_v[1]
    
    this_jga, this_acc, this_f1 = evaluate(pred, query_data['slot_values'])
    delta_jga, _, _ = evaluate(batch_pred_turn_s_v[1], query_data['turn_slot_values'])
    log['full_jga'] = this_jga
    log['delta_jga'] = delta_jga

    total_log.append(log)

    prev_data = query_data
