### The pipeline should work as follows:
1. Core: predict the next action given the current context.
2. Load the data from the benchmark.
3. Connect the data to the core module to obtain the run-time list of predicted actions.
4. Use the CE metric to evaluate the quality of the predicted action list.
5. If the quality is good enough, save the dialogues as new data for further training.

In [None]:
# Set up the OpenAI clients (GPT-4 and GPT-3.5).
#
# The API key is read from the OPENAI_API_KEY environment variable so that no
# secret is stored in the notebook. A key that was ever committed in plain text
# must be treated as leaked and revoked.
import os

from openai import OpenAI

_openai_api_key = os.environ.get("OPENAI_API_KEY")
if not _openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it before running this notebook."
    )

# Two separate client objects are kept so that existing code referring to
# either name keeps working.
clientGPT4 = OpenAI(api_key=_openai_api_key)
clientGPT3_5 = OpenAI(api_key=_openai_api_key)

'''
Set up the AST module for predicting the next action.
For the demo, use the model for the SGD dataset.
'''
"""
Reference: https://github.com/huggingface/transformers/tree/main/examples/pytorch

Adapted from huggingface Transformers
"""

import logging
import os
import sys
from pathlib import Path
import time

import datasets
import transformers
import transformers.trainer_utils as hf_trainer_utils
import numpy as np
import nltk  # Here to have a nice missing dependency error message early on

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
)

from src.data.data_args import DataArguments
from src.data.dataset_loader import DatasetLoader
from src.data.utils import group_col_name
from src.metrics import create_compute_metric_fct, verify_nltk
from src.model.hf_model_args import HfModelArguments
from src.hf_training.hf_training_args import HfSeq2SeqTrainingArgs
# Expose only CUDA device index 1 to this process.
# NOTE(review): the device index is hard-coded; confirm it matches the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

# Module-level logger used by the helper functions in this file.
logger = logging.getLogger(__name__)

def train(trainer, train_dataset, training_args):
    """Train with `trainer`, then persist the model, the metrics and the trainer state.

    Args:
        trainer: object exposing `train`, `save_model`, `log_metrics`,
            `save_metrics` and `save_state`.
        train_dataset: sized dataset; its length is recorded as `train_samples`.
        training_args: forwarded to `get_resume_checkpoint` to pick the
            checkpoint (if any) to resume from.
    """
    logger.info("*** train ***")

    resume_from = get_resume_checkpoint(training_args)
    outcome = trainer.train(resume_from_checkpoint=resume_from)

    # Saves the tokenizer too for easy upload
    trainer.save_model()

    train_metrics = outcome.metrics
    train_metrics["train_samples"] = len(train_dataset)

    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()


def do_eval(trainer, validation_dataset, max_length, num_beams):
    """Evaluate with `trainer` and log/save the resulting metrics under the "eval" split.

    Args:
        trainer: object exposing `evaluate`, `log_metrics` and `save_metrics`.
        validation_dataset: sized dataset; its length is recorded as `eval_samples`.
        max_length: forwarded to `trainer.evaluate`.
        num_beams: forwarded to `trainer.evaluate`.
    """
    logger.info("*** Evaluate ***")

    eval_metrics = trainer.evaluate(
        max_length=max_length,
        num_beams=num_beams,
        metric_key_prefix="eval",
    )
    eval_metrics["eval_samples"] = len(validation_dataset)

    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)

def do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams):
    """Run prediction on `test_dataset`, decode the output and dump it to a text file.

    When the dataset has a `group_col_name` column, prediction is run once per
    group (group ids 0, 1, 2, ... until a group is empty) and per-group metrics
    are averaged. Otherwise a single prediction pass is made over the whole
    dataset.

    Args:
        trainer: object exposing `predict`.
        test_dataset: dataset exposing `column_names`, `filter` and a "labels" column.
        tokenizer: used for `batch_decode` and for its `pad_token_id`.
        training_args: read for `is_mwoz` and `output_dir`.
        data_args: read for `ignore_pad_token_for_loss`.
        model_args: read for `model_name_or_path`.
        max_length: forwarded to `trainer.predict`.
        num_beams: forwarded to `trainer.predict`; in the non-grouped branch it
            is also passed as `num_return_sequences`.

    Returns:
        A `(decoded_preds, decoded_labels)` tuple of post-processed strings.
        As a side effect, one tab-separated "prediction<TAB>label" line per
        pair is written to a predictions file.
    """
    def postprocess_text(preds, labels):
        """Strip each string and put every sentence (per `nltk.sent_tokenize`) on its own line."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels

    def decode(preds, labels):
        """Decode token ids to text, write the pairs to disk and return them."""
        # Only the first element is used when the predictions come as a tuple.
        if isinstance(preds, tuple):
            preds = preds[0]
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the predictions as we can't decode them.
            preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        # The predictions file goes next to the model when `model_name_or_path`
        # is an existing path, otherwise into the training output directory.
        model_path = Path(model_args.model_name_or_path)
        file_name = "pred_mwoz.txt" if training_args.is_mwoz else "preds_test_set.txt"
        if not model_path.exists():
            # model name
            preds_file_path = Path(training_args.output_dir) / file_name
        else:
            preds_file_path = model_path / file_name

        with preds_file_path.open("w") as f:
            for pred, label in zip(decoded_preds, decoded_labels):
                # Newlines are flattened so that each pair stays on a single line.
                label = label.replace("\n", " ")
                pred = pred.replace("\n", " ")
                f.write(f"{pred}\t{label}" + "\n")

        return decoded_preds, decoded_labels

    logger.info("*** Predict ***")

    metrics = {}
    predictions = []
    if group_col_name in test_dataset.column_names:
        group_idx = 0

        # Group ids are assumed to be consecutive integers starting at 0:
        # the loop stops at the first id that yields no rows.
        while True:
            group_dataset = test_dataset.filter(lambda x: x[group_col_name] == group_idx)
            if group_dataset.num_rows == 0:
                # no groups left
                break
            logger.info("Predicting on test group %d", group_idx)

            predict_results = trainer.predict(
                group_dataset,
                metric_key_prefix=f"predict_group_{group_idx}",
                max_length=max_length,
                num_beams=num_beams
            )
            metrics.update(predict_results.metrics)
            metrics[f"predict_samples_group_{group_idx}_size"] = len(group_dataset)

            group_idx += 1

            # NOTE(review): `predictions` is filled here but never read afterwards.
            predictions.append(predict_results.predictions)

        # Unweighted mean of each per-group metric, rounded to 4 decimals.
        # NOTE(review): if the very first group is empty, `group_idx` is 0 and
        # this division fails — confirm that case cannot occur.
        for key in ["loss", "rouge1", "rouge2", "rougeL"]:
            metrics[f"overall_predict_{key}"] = round(
                sum([metrics[f"predict_group_{idx}_{key}"] for idx in range(group_idx)]) / group_idx, 4
            )
    else:
        '''
        here
        '''
        # NOTE(review): `return_dict_in_generate`, `num_return_sequences` and
        # `output_scores` are passed through as extra keyword arguments —
        # presumably generation options; confirm the trainer in use accepts them.
        predict_results = trainer.predict(
            test_dataset,
            metric_key_prefix="test",
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
            num_return_sequences=num_beams,
            output_scores=True,
        )
        print("predict_results: ", predict_results)
        metrics = predict_results.metrics
        metrics["predict_samples_size"] = len(test_dataset)

    # Metric logging/saving is currently disabled:
    # trainer.log(metrics)
    # trainer.log_metrics("test", metrics)
    # trainer.save_metrics("test", metrics)

    # NOTE(review): in the grouped branch `predict_results` holds only the last
    # group's output, yet it is decoded against the labels of the whole test
    # set — confirm this is intended.
    return decode(predict_results.predictions, test_dataset["labels"])


def load_model(model_args, data_args, tokenizer):
    """Load the seq2seq model described by `model_args` and adapt it to `tokenizer`.

    Steps performed:
      * load the config (from `config_name` if set, else `model_name_or_path`)
        and force `min_length` to 5;
      * load the model weights and resize the token embeddings to `len(tokenizer)`;
      * apply the "summarization_cnn" task-specific parameters when present;
      * for mBART tokenizers, default `decoder_start_token_id` to the "en_XX" id;
      * grow the position embeddings when `data_args.max_source_length` exceeds
        `max_position_embeddings` and resizing is allowed.

    Args:
        model_args: model/config/cache/revision/auth settings.
        data_args: read for `max_source_length`.
        tokenizer: tokenizer whose vocabulary size the embeddings must match.

    Returns:
        The loaded and adjusted model.

    Raises:
        ValueError: if `decoder_start_token_id` is still undefined, or if the
            model has too few position embeddings and
            `model_args.resize_position_embeddings` is explicitly falsy.
    """
    auth_token = True if model_args.use_auth_token else None

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )

    # Forcing the generation min length, to avoid models preset for summarization tasks that are usually high
    config.min_length = 5

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        # A ".ckpt" in the name is taken to mean a TensorFlow checkpoint.
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )
    model.resize_token_embeddings(len(tokenizer))

    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("summarization_cnn", {}))

    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id["en_XX"]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids("en_XX")

    # Checked once: the original repeated this exact check twice in a row.
    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    if (
        hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings < data_args.max_source_length
    ):
        if model_args.resize_position_embeddings is None:
            # Unset flag: resize, but tell the user about it.
            logger.warning(
                "Increasing the model's number of position embedding vectors from"
                f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
            )
            model.resize_position_embeddings(data_args.max_source_length)
        elif model_args.resize_position_embeddings:
            model.resize_position_embeddings(data_args.max_source_length)
        else:
            raise ValueError(
                f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
                f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
                f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
                " model's position encodings by passing `--resize_position_embeddings`."
            )

    return model


def get_resume_checkpoint(training_args):
    """Return the checkpoint training should resume from, or None.

    An explicit `training_args.resume_from_checkpoint` takes precedence over a
    checkpoint auto-detected in the output directory. Previously the detected
    checkpoint silently overrode the explicit one.

    Args:
        training_args: read for `resume_from_checkpoint`; also forwarded to
            `get_last_checkpoint`.

    Returns:
        The explicit checkpoint if set, else the detected last checkpoint,
        else None.

    Raises:
        ValueError: propagated from `get_last_checkpoint`.
    """
    # Always scan the output directory, even when an explicit checkpoint is
    # given: the scan also validates the directory and may raise.
    last_checkpoint = get_last_checkpoint(training_args)

    if training_args.resume_from_checkpoint is not None:
        return training_args.resume_from_checkpoint
    return last_checkpoint


def get_last_checkpoint(training_args):
    """Look for the most recent checkpoint inside `training_args.output_dir`.

    The lookup only happens when the output directory exists, training is
    requested (`do_train`) and `overwrite_output_dir` is not set; otherwise
    None is returned immediately.

    Returns:
        Whatever `hf_trainer_utils.get_last_checkpoint` finds, or None.

    Raises:
        ValueError: if the directory is non-empty but holds no checkpoint.
    """
    output_dir = training_args.output_dir
    may_resume = (
        os.path.isdir(output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    )
    if not may_resume:
        return None

    found = hf_trainer_utils.get_last_checkpoint(output_dir)
    if found is None:
        # Refuse to train into a directory that already has unrelated content.
        if os.listdir(output_dir):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming hf_training at {found}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
    return found


def setup_logging(training_args):
    """Configure stdout logging and align library verbosity with the process log level.

    The level comes from `training_args.get_process_log_level()` and is applied
    to this module's logger as well as to `datasets` and `transformers`.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[stdout_handler],
    )

    verbosity = training_args.get_process_log_level()
    logger.setLevel(verbosity)
    datasets.utils.logging.set_verbosity(verbosity)

    hf_logging = transformers.utils.logging
    hf_logging.set_verbosity(verbosity)
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()


def create_data_collector(model, tokenizer, training_args, data_args):
    """Build the `DataCollatorForSeq2Seq` used for batching.

    Labels are padded with -100 when `data_args.ignore_pad_token_for_loss` is
    set, otherwise with the tokenizer's pad token id. Padding is rounded to a
    multiple of 8 only when `training_args.fp16` is enabled.
    """
    if data_args.ignore_pad_token_for_loss:
        label_padding = -100
    else:
        label_padding = tokenizer.pad_token_id

    padding_multiple = 8 if training_args.fp16 else None

    return DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_padding,
        pad_to_multiple_of=padding_multiple,
    )


def setup_wandb(training_args):
    """Prepare Weights & Biases settings when `training_args.use_wandb` is truthy.

    Sets the WANDB_PROJECT environment variable from `wandb_project_name` and
    copies `experiment_name` into `run_name`. Does nothing otherwise.
    """
    if not training_args.use_wandb:
        return

    os.environ["WANDB_PROJECT"] = training_args.wandb_project_name
    training_args.run_name = training_args.experiment_name


def get_args():
    """Parse the model, data and training argument dataclasses.

    The experiment name is extended to
    "<experiment>_<text_column>_<summary_column>_<model_name_or_path>" and the
    output directory is nested under that name. A warning is logged when one of
    the listed T5 checkpoints is used without a source prefix.

    Returns:
        A `(data_args, model_args, training_args)` tuple — note the order
        differs from the parser's dataclass order.
    """
    parser = HfArgumentParser((HfModelArguments, DataArguments, HfSeq2SeqTrainingArgs))
    # Unrecognised arguments are returned as the last element and ignored.
    model_args, data_args, training_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    training_args.experiment_name = "_".join(
        [
            training_args.experiment_name,
            data_args.text_column,
            data_args.summary_column,
            model_args.model_name_or_path,
        ]
    )
    training_args.output_dir = str(Path(training_args.output_dir) / training_args.experiment_name)

    t5_checkpoints = ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b")
    if data_args.source_prefix is None and model_args.model_name_or_path in t5_checkpoints:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
            "`--source_prefix 'summarize: ' `"
        )

    return data_args, model_args, training_args

# def hf_run():
# --- Driver: parse arguments, load tokenizer/data/model and build the trainer ---
data_args, model_args, training_args = get_args()

setup_wandb(training_args)

setup_logging(training_args)

verify_nltk()

# The original format string contained "n_gpu: % distributed", which Python
# parses as a "% d" conversion and renders as "n_gpu:  1istributed ...".
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed hf_training: %s, 16-bits hf_training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

# Seed before the model is created.
set_seed(training_args.seed)

tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

datasets_loader = DatasetLoader(data_args, training_args, tokenizer)
train_dataset, validation_dataset, test_dataset = datasets_loader.load_datasets()

model = load_model(model_args, data_args, tokenizer)

if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
    # A space was missing between the two concatenated literals ("for`%s`").
    logger.warning(
        "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
        "`%s`. This will lead to loss being calculated twice and will take up more memory",
        model.__class__.__name__,
    )
metric_fct = create_compute_metric_fct(tokenizer, data_args, training_args, model_args)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    data_collator=create_data_collector(model, tokenizer, training_args, data_args),
    # Metrics are only computed when predictions are produced by generation.
    compute_metrics=metric_fct if training_args.predict_with_generate else None,
)

if training_args.do_train:
    train(trainer, train_dataset, training_args)

# Generation settings: training-args value wins for max length, data-args value wins for beams.
max_length = (
    training_args.generation_max_length
    if training_args.generation_max_length is not None
    else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams

# Evaluation and prediction are currently disabled in this cell:
# if training_args.do_eval:
#     do_eval(trainer, validation_dataset, max_length, num_beams)

# if training_args.do_predict:
#     results_pred, results_label = do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams)
    # print("results_pred: ", results_pred)
    # print("results_label: ", results_label)

In [None]:
import re

def postprocess_predictions(prediction_str):
    """Normalise a predicted "action [slot, slot, ...]" string.

    Semicolons inside the brackets are treated as commas, each slot is stripped,
    and a slot with only one of its angle brackets ("<name" or "name>") gets the
    missing one added. Strings that do not match the "action [...]" pattern are
    returned unchanged.
    """
    def _balance_brackets(slot):
        # Add whichever angle bracket is missing; leave other slots untouched.
        opens = slot.startswith("<")
        closes = slot.endswith(">")
        if closes and not opens:
            return "<" + slot
        if opens and not closes:
            return slot + ">"
        return slot

    parsed = re.match(r"(.*)\[(.*)]", prediction_str)
    if not parsed:
        return prediction_str

    # action w/ value
    action_name = parsed.group(1).strip()
    raw_slots = parsed.group(2).replace(";", ",").split(",")
    slots = [_balance_brackets(piece.strip()) for piece in raw_slots]
    return f"{action_name} [{', '.join(slots)}]"

def parse_ast_prediction(prediction_str):
    """Split a predicted "action [slot, slot, ...]" string into its parts.

    Semicolons inside the brackets are treated as commas, each slot is stripped,
    and a slot with only one of its angle brackets gets the missing one added.

    Returns:
        `(action_name, slots)`; `("MISSING", ["MISSING"])` when the string does
        not match the "action [...]" pattern.
    """
    def _balance_brackets(slot):
        # Add whichever angle bracket is missing; leave other slots untouched.
        opens = slot.startswith("<")
        closes = slot.endswith(">")
        if closes and not opens:
            return "<" + slot
        if opens and not closes:
            return slot + ">"
        return slot

    parsed = re.match(r"(.*)\[(.*)]", prediction_str)
    if not parsed:
        return "MISSING", ["MISSING"]

    # action w/ value
    action_name = parsed.group(1).strip()
    raw_slots = parsed.group(2).replace(";", ",").split(",")
    return action_name, [_balance_brackets(piece.strip()) for piece in raw_slots]

def compute_ast_acc_metrics(predictions, labels, convo_ids, turn_ids, sequence_scores=None, num_beams=None):
    # print("len(predictions): ", len(predictions))
    # print("len(labels): ", len(labels))
    # print("len(sequence_scores): ", len(sequence_scores))
    from aalpy.utils import load_automaton_from_file

    # load an automaton
    automaton = load_automaton_from_file("./chainPrior/learned_mdp_8000.dot", automaton_type='mdp')
    # print(automaton)
    # visualize the automaton
    # visualize_automaton(automaton)
    automaton = str(automaton)
    # print(automaton)

    automaton_splits = automaton.split('\n')
    # print(automaton_splits)
    automaton_states = automaton_splits[1:33]
    # ['s0 [label="init"];', 's1 [label="pull-up-account"];']
    automaton_transitions = automaton_splits[33:-4]
    # ['s0 -> s0  [label="init:1.0"];', 's0 -> s1  [label="action:0.03"];']

    state_mapping = {}
    for state in automaton_states:
        state_name = state.split(' ')[0]
        state_label = state.split('[label="')[1].split('"];')[0]
        state_mapping[state_name] = state_label

    '''
    state_mapping: {'s0': 'init', 's1': 'pull-up-account', 's2': 'enter-details', 's3': 'verify-identity', 's4': 'make-password', 's5': 'search-timing', 's6': 'search-policy', 's7': 'validate-purchase', 's8': 'search-faq', 's9': 'membership', 's10': 'search-boots', 's11': 'try-again', 's12': 'ask-the-oracle', 's13': 'update-order', 's14': 'promo-code', 's15': 'update-account', 's16': 'search-membership', 's17': 'make-purchase', 's18': 'offer-refund', 's19': 'notify-team', 's20': 'record-reason', 's21': 'search-jeans', 's22': 'shipping-status', 's23': 'search-shirt', 's24': 'instructions', 's25': 'search-jacket', 's26': 'log-out-in', 's27': 'select-faq', 's28': 'subscription-status', 's29': 'send-link', 's30': 'search-pricing', 's31': 'end'}
    '''
    # print(f"state_mapping: {state_mapping}")

    transition_mapping = {}
    for transition in automaton_transitions:
        transition_split = transition.split('->')
        source_state = transition_split[0].strip()
        target_state = transition_split[1].strip().split(' ')[0]
        transition_label = transition_split[1].split('[label="')[1].split('"];')[0]
        transition_action = transition_label.split(':')[0]
        transition_freq = np.log(float(transition_label.split(':')[1])) if float(transition_label.split(':')[1]) > 0 else -10000
        transition_mapping[(state_mapping[source_state], state_mapping[target_state])] = (transition_action, transition_freq)

    # from s0 to s31, if some pair of states are not in the transition_mapping, then the frequency is 0
    for i in range(32):
        for j in range(32):
            if (state_mapping[f's{i}'], state_mapping[f's{j}']) not in transition_mapping:
                transition_mapping[(state_mapping[f's{i}'], state_mapping[f's{j}'])] = ('unknown', -10000)

    # group predictions every 4
    new_predictions = []
    for i in range(0, len(predictions), num_beams):
        new_predictions.append(predictions[i:i+num_beams])
        # new_predictions.append(predictions[i])
    
    new_sequence_scores = []
    for i in range(0, len(sequence_scores), num_beams):
        # print(f"sequence_scores[i:i+4]: {sequence_scores[i:i+4]}")
        # do no use normalization
        new_sequence_scores.append(sequence_scores[i:i+num_beams])

    
    previous_actions = ['init']
    current_convo_id = 999999
    new_new_predictions = []
    not_in_counter = 0
    for new_pred, label1, new_sequence_score, convo_id1, turn_id1 in zip(new_predictions, labels, new_sequence_scores, convo_ids, turn_ids):
        if convo_id1 != current_convo_id:
            previous_actions = ['init']
            current_convo_id = convo_id1
        
        actions = []
        for pred in new_pred:
            actions.append(pred.split(' ')[0].strip())

        rates = []
        for i in range(len(actions)):
            try:
                rate = transition_mapping[(previous_actions[-1], actions[i])][1]
            except:
                rate = -10000
            rates.append(rate)
        rates = np.array(rates)

        '''
        the way to merge the two modules for post processng, v1
        '''
        # print(f"new_sequence_score: {new_sequence_score}, rates: {rates}")
        # merge_scores = 0.999*np.array(new_sequence_score) + 0.001*np.array(rates)
        # else:
        exp_new_sequence_score = [np.exp(score) for score in new_sequence_score]
        exp_rates = [np.exp(rate) for rate in rates]
        norm_exp_new_sequence_score = exp_new_sequence_score / np.sum(exp_new_sequence_score)
        norm_exp_rates = exp_rates / np.sum(exp_rates)
        log_norm_exp_new_sequence_score = [np.log(score) for score in norm_exp_new_sequence_score]
        log_norm_exp_rates = [np.log(rate) for rate in norm_exp_rates]

        # print("log_norm_exp_new_sequence_score: ", log_norm_exp_new_sequence_score)
        # print("log_norm_exp_rates: ", log_norm_exp_rates)
        # print()

        # merge_scores = 0.9*np.array(log_norm_exp_new_sequence_score) + 0.1*np.array(log_norm_exp_rates)
        merge_scores = new_sequence_score
        # merge_scores = rates

        '''
        the way to merge the two modules for post processng, v2
        '''
        # check if the first action predicted by policy model is in the possible subsequent actions of the previous action, if not, check the second action, and so on
        # if none of the actions predicted by the policy model is in the possible subsequent actions of the previous action, then use most probable action from the prior
        # merge_scores = []
        all_possible_actions = ['search-faq', 'search-timing', 'pull-up-account', 'verify-identity', 'membership', 'ask-the-oracle', 'search-shirt', 'search-policy', 'select-faq', 'send-link', 'enter-details', 'log-out-in', 'promo-code', 'notify-team', 'make-purchase', 'validate-purchase', 'update-order', 'subscription-status', 'make-password', 'try-again', 'shipping-status', 'record-reason', 'update-account', 'instructions', 'search-boots', 'search-jeans', 'search-membership', 'search-jacket', 'search-pricing', 'offer-refund']

        current_action_probs = {}
        for possible_action in all_possible_actions:
            if transition_mapping[(previous_actions[-1], possible_action)][0] != 'unknown':
                current_action_probs[possible_action] = transition_mapping[(previous_actions[-1], possible_action)][1]


        # if new_pred[0].split(' ')[0].strip() not in current_action_probs:
        #     not_in_counter += 1
        #     if new_pred[1].split(' ')[0].strip() in current_action_probs:
        #         new_new_predictions.append(new_pred[1])
        #         print(f"2: wrong: {new_pred[0].split(' ')[0].strip()}, correct: {new_pred[1].split(' ')[0].strip()}, label: {label1.split(' ')[0].strip()}")
        #     else:
        #         max_index = np.argmax(list(current_action_probs.values()))
        #         new_new_predictions.append(list(current_action_probs.keys())[max_index] + ' ' + new_pred[0].split(' ')[1].strip())
        #         print(f"chain: wrong: {new_pred[i].split(' ')[0].strip()}, correct: {list(current_action_probs.keys())[max_index]}, label: {label1.split(' ')[0].strip()}")
        # else:
        max_index = np.argmax(merge_scores)
        new_new_predictions.append(new_pred[max_index])

        previous_actions.append(label1.split(' ')[0].strip())

    print(f"not_in_counter: {not_in_counter}, total number of predictions: {len(new_predictions)}")

    """Adapted from ABCD. """
    # print("predictions:", predictions)
    # print("labels:", labels)
    action_preds = []
    action_labels = []

    value_preds = []
    value_labels = []
    
    # print("len(new_new_predictions): ", len(new_new_predictions))
    # print("len(labels): ", len(labels))
    for pred, label in zip(new_new_predictions, labels):
        action_label, values_label = parse_ast_prediction(label)
        values_label.sort()
        # for value in values_label:
        #     action_labels.append(action_label)
        #     value_labels.append(value)
        action_labels.append(action_label)
        value_labels.append(values_label)

        # print("pred str: ", pred)
        action_pred, values_pred = parse_ast_prediction(pred)
        # print(f"parsed results: {action_pred}, {values_pred}")
        values_pred.sort()

        if len(values_pred) > len(values_label):
            values_pred = [v for v in values_label if v in values_pred]
        if len(values_pred) < len(values_label):
            values_pred.extend(["MISSING"] * (len(values_label) - len(values_pred)))

        # for value in values_pred:
        #     action_preds.append(action_pred)
        #     value_preds.append(value)
        action_preds.append(action_pred)
        value_preds.append(values_pred)

    # print("action_preds: ", action_preds)
    # print("action_labels: ", action_labels)
    # print("value_labels: ", value_labels)
    # print("convo_ids: ", convo_ids)
    # print("turn_ids: ", turn_ids)

    action_labels_arrary = np.array(action_labels, dtype=object)
    action_preds_arrary = np.array(action_preds, dtype=object)
    # print(f"action_labels_arrary: {action_labels_arrary}")
    # print(f"action_preds_arrary: {action_preds_arrary}")
    action_match = action_labels_arrary == action_preds_arrary

    # plot confusion matrix
    # print(f"action_labels_arrary: {action_labels_arrary}")
    # print(f"action_preds_arrary: {action_preds_arrary}")
    '''
    acc_actions: {'search-faq': [206, 242], 'search-timing': [26, 30], 'pull-up-account': [675, 709], 'verify-identity': [329, 373], 'membership': [114, 136], 'ask-the-oracle': [147, 171], 'search-shirt': [32, 34], 'search-policy': [31, 44], 'select-faq': [168, 171], 'send-link': [52, 57], 'enter-details': [185, 211], 'log-out-in': [85, 92], 'promo-code': [49, 53], 'notify-team': [70, 75], 'make-purchase': [56, 59], 'validate-purchase': [211, 233], 'update-order': [120, 142], 'subscription-status': [36, 53], 'make-password': [40, 42], 'try-again': [44, 60], 'shipping-status': [61, 86], 'record-reason': [137, 172], 'update-account': [77, 92], 'instructions': [40, 49], 'search-boots': [23, 25], 'search-jeans': [24, 26], 'search-membership': [25, 35], 'search-jacket': [22, 27], 'search-pricing': [24, 36], 'offer-refund': [59, 73]}
    '''
    possibleActions = ['search-faq', 'search-timing', 'pull-up-account', 'verify-identity', 'membership', 'ask-the-oracle', 'search-shirt', 'search-policy', 'select-faq', 'send-link', 'enter-details', 'log-out-in', 'promo-code', 'notify-team', 'make-purchase', 'validate-purchase', 'update-order', 'subscription-status', 'make-password', 'try-again', 'shipping-status', 'record-reason', 'update-account', 'instructions', 'search-boots', 'search-jeans', 'search-membership', 'search-jacket', 'search-pricing', 'offer-refund', 'MISSING']


    action_acc = sum(action_match) / float(len(action_labels))

    value_labels_arrary = np.array(value_labels, dtype=object)
    value_preds_arrary = np.array(value_preds, dtype=object)
    # print(f"value_labels_arrary: {value_labels_arrary}")
    # print(f"value_preds_arrary: {value_preds_arrary}")
    value_match = value_labels_arrary == value_preds_arrary
    # print(f"value_match: {value_match}")
    value_acc = sum(value_match) / float(len(action_labels))

    joint_match = action_match & value_match
    joint_acc = sum(joint_match) / float(len(action_labels))

    # group by convo_ids
    unique_convo_ids = list(set(convo_ids))
    # print(f"unique_convo_ids: {unique_convo_ids}")
    conversations = {}
    for uci in unique_convo_ids:
        turns, correctness = [], []
        correctness_action, correctness_value = [], []
        row_id = 0
        for convo_id, turn_count in zip(convo_ids, turn_ids):
            if convo_id == uci:
                turns.append(turn_count)
                correct = False
                correct_action = False
                correct_value = False
                action_right = action_match[row_id]
                value_right = value_match[row_id]
                # print(f"action_right: {action_right}, value_right: {value_right}")
                
                if action_right:
                    correct_action = True
                else:
                    correct_action = False
                
                if value_right:
                    correct_value = True
                else:
                    correct_value = False

                if action_right and value_right:
                    correct = True
                else:
                    correct = False

                correctness.append(correct)
                correctness_action.append(correct_action)
                correctness_value.append(correct_value)
            row_id += 1

        # sort by turn_counts
        ordered = [cor for _, cor in sorted(zip(turns, correctness), key=lambda tc: tc[0])]
        ordered_action = [cor for _, cor in sorted(zip(turns, correctness_action), key=lambda tc: tc[0])]
        ordered_value = [cor for _, cor in sorted(zip(turns, correctness_value), key=lambda tc: tc[0])]
        conversations[uci] = [ordered, ordered_action, ordered_value]

    # count how many correct
    turn_score, turn_correct = 0, 0
    turn_score_action, turn_correct_action = 0, 0
    turn_score_value, turn_correct_value = 0, 0
    em_joint, em_action, em_value = [], [], []
    my_scores = []
    for convo_id, itm in conversations.items():
        # print(f"convo_id: {convo_id}")
        convo_correctness = itm[0]
        convo_correctness_action = itm[1]
        convo_correctness_value = itm[2]

        # calculate EM
        if sum(convo_correctness) == len(convo_correctness):
            em_joint.append(True)
        else:
            em_joint.append(False)
        if sum(convo_correctness_action) == len(convo_correctness_action):
            em_action.append(True)
        else:
            em_action.append(False)
        if sum(convo_correctness_value) == len(convo_correctness_value):
            em_value.append(True)
        else:
            em_value.append(False)
        
        # print(f"convo_id: {convo_id}, convo_correctness: {convo_correctness}")
        current_score = 0
        convo_length = len(convo_correctness)
        # we use turn_id rather than the true turn_count since turn counts will skip numbers
        # when looping through the conversation due to skipping over customer utterances

        snipet_lens = [1,2,3]

        # for joint correctness
        snipet_lens_joint  = snipet_lens
        snipet_numbers_joint = [0] * len(snipet_lens_joint)
        snipet_correct_joint = [0] * len(snipet_lens_joint)
        # for each dialogue, compute the rate of each length of snipet that is correct, using the sliding window of the length
        for snipet_i in range(len(snipet_lens_joint)):
            # print("convo_length: ", convo_length)
            if snipet_lens_joint[snipet_i] > convo_length:
                continue
            # print(f"snipet_i: {snipet_i}")
            snipet_len = snipet_lens_joint[snipet_i]
            for turn_id in range(convo_length - snipet_len + 1):
                snipet_numbers_joint[snipet_i] += 1
                if sum(convo_correctness[turn_id:turn_id+snipet_len]) == snipet_len:
                    snipet_correct_joint[snipet_i] += 1

        average_counter = 0
        for snipet_i in range(len(snipet_lens_joint)):
            # print(f"snipet_correct_joint: {snipet_correct_joint[snipet_i]}, snipet_numbers_joint: {snipet_numbers_joint[snipet_i]}")
            if snipet_numbers_joint[snipet_i] == 0:
                continue
            snipet_correct_joint[snipet_i] = snipet_correct_joint[snipet_i] / snipet_numbers_joint[snipet_i]
            average_counter += 1
        
        # print(f"snipet_correct: {snipet_correct_joint}")
        # print("average_counter: ", average_counter)
        average_for_dialogue = 0
        for snipet_i in range(len(snipet_lens_joint)):
            average_for_dialogue += snipet_correct_joint[snipet_i]
        average_for_dialogue = average_for_dialogue / len(snipet_lens)
        # average_for_dialogue = average_for_dialogue / average_counter
        # print(f"average_for_dialogue: {average_for_dialogue}")

        turn_score += average_for_dialogue

        # for action correctness
        snipet_lens_action  = snipet_lens
        snipet_numbers_action = [0] * len(snipet_lens_action)
        snipet_correct_action = [0] * len(snipet_lens_action)
        # for each dialogue, compute the rate of each length of snipet that is correct, using the sliding window of the length
        for snipet_i in range(len(snipet_lens_action)):
            # print("convo_length: ", convo_length)
            if snipet_lens_action[snipet_i] > convo_length:
                continue
            # print(f"snipet_i: {snipet_i}")
            snipet_len = snipet_lens_action[snipet_i]
            for turn_id in range(convo_length - snipet_len + 1):
                snipet_numbers_action[snipet_i] += 1
                if sum(convo_correctness_action[turn_id:turn_id+snipet_len]) == snipet_len:
                    snipet_correct_action[snipet_i] += 1

        average_counter = 0
        for snipet_i in range(len(snipet_lens_action)):
            if snipet_numbers_action[snipet_i] == 0:
                continue
            snipet_correct_action[snipet_i] = snipet_correct_action[snipet_i] / snipet_numbers_action[snipet_i]
            average_counter += 1

        # print(f"snipet_correct: {snipet_correct_action}")
        average_for_dialogue = 0
        for snipet_i in range(len(snipet_lens_action)):
            average_for_dialogue += snipet_correct_action[snipet_i]
        average_for_dialogue = average_for_dialogue / len(snipet_lens)
        # average_for_dialogue = average_for_dialogue / average_counter
        # print(f"average_for_dialogue: {average_for_dialogue}")

        turn_score_action += average_for_dialogue

        # for value correctness
        snipet_lens_value  = snipet_lens
        snipet_numbers_value = [0] * len(snipet_lens_value)
        snipet_correct_value = [0] * len(snipet_lens_value)
        # for each dialogue, compute the rate of each length of snipet that is correct, using the sliding window of the length
        for snipet_i in range(len(snipet_lens_value)):
            # print("convo_length: ", convo_length)
            if snipet_lens_value[snipet_i] > convo_length:
                continue
            # print(f"snipet_i: {snipet_i}")
            snipet_len = snipet_lens_value[snipet_i]
            for turn_id in range(convo_length - snipet_len + 1):
                snipet_numbers_value[snipet_i] += 1
                if sum(convo_correctness_value[turn_id:turn_id+snipet_len]) == snipet_len:
                    snipet_correct_value[snipet_i] += 1

        average_counter = 0
        for snipet_i in range(len(snipet_lens_value)):
            if snipet_numbers_value[snipet_i] == 0:
                continue
            snipet_correct_value[snipet_i] = snipet_correct_value[snipet_i] / snipet_numbers_value[snipet_i]
            average_counter += 1

        # print(f"snipet_correct: {snipet_correct_value}")
        average_for_dialogue = 0
        for snipet_i in range(len(snipet_lens_value)):
            average_for_dialogue += snipet_correct_value[snipet_i]
        average_for_dialogue = average_for_dialogue / len(snipet_lens)
        # average_for_dialogue = average_for_dialogue / average_counter
        # print(f"average_for_dialogue: {average_for_dialogue}")

        turn_score_value += average_for_dialogue

    # normalize by total number of turns possible
    '''
    len(convo_ids): 200, len(turn_ids): 200
    '''
    # print(f"len(convo_ids): {len(convo_ids)}, len(turn_ids): {len(turn_ids)}")
    turn_acc = turn_correct / float(len(conversations))
    turn_acc_action = turn_correct_action / float(len(conversations))
    turn_acc_value = turn_correct_value / float(len(conversations))
    final_score = turn_score / float(len(conversations))
    final_score_action = turn_score_action / float(len(conversations))
    final_score_value = turn_score_value / float(len(conversations))
    
    em_action_score = sum(em_action) / float(len(em_action))
    em_value_score = sum(em_value) / float(len(em_value))
    em_joint_score = sum(em_joint) / float(len(em_joint))

    return {
        "EM_action": round(em_action_score, 4),
        "EM_value": round(em_value_score, 4),
        "EM_joint": round(em_joint_score, 4),
        # "turn_acc_joint": round(turn_acc, 4),
        # "turn_acc_action": round(turn_acc_action, 4),
        # "turn_acc_value": round(turn_acc_value, 4),
        "CE_joint": round(final_score, 4),
        "CE_action": round(final_score_action, 4),
        "CE_value": round(final_score_value, 4)
    }

#### Call the LLM to generate the agent response for the predicted action

In [None]:
# Few-shot prompt template for response generation.
# Layout: task instructions, the list of available dialog acts, three worked
# example conversations, and finally the `[[DIALOG]]` placeholder, which
# `call_LLM` replaces with the dialogue to be continued.
# In the examples, most assistant turns carry their dialog acts in parentheses
# directly after "Assistant", followed by the utterance.
# NOTE(review): trailing whitespace inside the literal is part of the prompt text.
hint_prompt = """
The following are conversations between a user and an assistant. Indicated by the dialog acts, the assistant can help the user with checking in or providing information of temperature, time, price, location, and so on.
The response should be coherent, engaging, diverse, informative, and overall good and should be in line with the next action.
The response should be concise and to the point and not exceed 30 words.

Available Dialog acts:
confirm appointment_date, confirm appointment_time, confirm check_in_date, confirm date, confirm destination, confirm event_name, confirm location, confirm number_of_rooms, confirm number_of_seats, confirm pickup_location, confirm pickup_time, confirm property_name, confirm restaurant_name, confirm time, confirm visit_date, inform address, inform category, inform date, inform free_entry, inform from_station, inform genre, inform good_for_kids, inform humidity, inform inbound_arrival_time, inform outbound_arrival_time, inform phone_number, inform price, inform rating, inform ride_fare, inform time, inform to_station, inform wind, inform_count count, offer address, offer airlines, offer album, offer appointment_date, offer appointment_time, offer artist, offer attraction_name, offer car_name, offer category, offer location, offer date, offer event_name, offer inbound_departure_time, offer movie_name, offer outbound_departure_time, offer pickup_location, offer precipitation, offer price, offer property_name, offer restaurant_name, offer rating, offer temperature, offer time, offer title, offer type, offer_intent intent, request appointment_date, request appointment_time, request area, request category, request check_in_date, request date, request departure_date, request destination, request event_name, request genre, request location, request number_of_beds, request number_of_seats, request pickup_time, request restaurant_name, request return_date, request time, request type, request visit_date, notify_failure, notify_success, req_more, goodbye.

Conversation: 
User: i want to make a restaurant reservation for 2 people at half past 11 in the morning.  
Assistant(request restaurant_name [none], request location [none]): what city do you want to dine in? do you have a preferred restaurant? 
User: please find restaurants in san jose. can you try sino?  
Assistant(confirm restaurant_name [sino]. confirm location [san jose]. confirm time [11:30 am]. confirm number_of_seats [2]. confirm date [today]): confirming: i will reserve a table for 2 people at sino in san jose. the reservation time is 11:30 am today. 
User: yes, thanks. what's their phone number?  
Assistant(inform phone_number [408-247-8880], notify_success [none]): your reservation has been made. their phone number is 408-247-8880. 
User: what's their address? do they have vegetarian options on their menu?  
Assistant(inform address [377 santana row #1000]): the street address is 377 santana row #1000. they have good vegetarian options. 
User: thanks very much.  
Assistant(req_more [none]): is there anything else i can help you with? 
User: no, that's all. thanks.

Conversation:
User: i'm looking for something like a performing arts venue to visit, can you help me?  
Assistant(request location [none]): where would you like me to search? 
User: search in paris, france please.  
Assistant(offer attraction_name [arab world institute], offer category [performing arts venue]): there's the arab world institute, which is a performing arts venue. 
User: can i bring children there, also is this place free to go to?  
Assistant(inform good_for_kids [false], inform free_entry [false]): i don't think children would enjoy this venue, and it is not free. 
User: okay, cool. that's all i need for now, bye.

Conversation:
User: hi there, could you find some pop music events for me? 
Assistant: where are you interested in? 
User: something in philly is fine.  
Assistant(offer event_name [conan gray], offer date [march 1st], offer time [8:30 pm]): would you like conan gray event at the fillmore philadelphia occurring on march 1st at 8:30 pm? 
User: sure, could you find something else on march 4th in sd instead?  
Assistant(offer event_name [tuxedo], offer date [march 4th], offer time [7 pm], inform_count count [1]): there is 1 event called tuxedo at the observatory north park at 7 pm on march 4th. 
User: sure, that sounds good!  
Assistant(offer_intent intent [buyeventtickets]): would you like me to purchase the tickets? 
User: perhaps later!  
Assistant(req_more [none]): could i further assist you? 
User: could you get me a bus ride there too? 
Assistant: sure, do you intend to depart on march 4th? 
User: no, i will leave on march 3rd instead. 
Assistant: and where will you be travelling from? 
User: i will travel from lax.  
Assistant(inform_count count [10]): there are 10 available buses, would you like one which departs at 7:40 am with 0 transfers costing $28? 
User: that sounds great! could you get me a roundtrip flight there too?  
Assistant(request return_date [none]): sure, when will you be returning? 
User: i will return next thursday.  
Assistant(request departure_date [march 3rd]): and will you depart on march 3rd or march 4th. 
User: preferably march 3rd, please.  
Assistant(offer airlines [alaska airlines], offer outbound_departure_time [6:55 am], offer inbound_departure_time [9:55 am], offer price [$116]): how about an alaska airlines flight departing at 6:55 am and returning at 9:55 am costing $116? 
User: are there any alternatives?  
Assistant(offer airlines [american airlines], offer outbound_departure_time [8:55 pm], offer inbound_departure_time [7:10 pm], offer price [$127]): how about an american airlines flight departing at 8:55 pm and returning at 7:10 pm costing $127? 
User: which airport will i depart from? 
Assistant: you will depart from lax international airport. 
User: sure, could you get me hotels in that area too?  
Assistant(offer rating [3]): how about a 3 star hotel called bahia resort hotel? 
User: are there any alternatives?  
Assistant(offer rating [2]): how about a 2 star hotel called best western mission bay? 
User: that sounds great too!  
Assistant(offer_intent intent [reservehotel]): would you like me to make a reservation? 
User: no thanks, that is not needed.

Conversation: 
[[DIALOG]]
"""

In [None]:
import warnings
warnings.filterwarnings('ignore')
import json

def call_LLM(dialogue, Action):
    """Generate the assistant's next utterance with gpt-3.5-turbo.

    The dialogue text is substituted for the ``[[DIALOG]]`` placeholder of the
    module-level ``hint_prompt`` and sent as the user message through the
    module-level ``clientGPT3_5`` client.

    Args:
        dialogue: Dialogue history as text; it is inserted verbatim into the
            prompt, so it should end with the open assistant turn to complete.
        Action: The predicted next action. It is not read inside this
            function; the parameter is kept so existing callers keep working.

    Returns:
        The text of the first returned choice, or ``""`` when the API returns
        no text content.
    """
    prompt = hint_prompt.replace("[[DIALOG]]", dialogue)

    messages = [
        {"role": "system", "content": "You are a helpful assistant. You can generate a response to the user's input based on the given previous dialogue and the next action."},
        {"role": "user", "content": prompt},
    ]

    response = clientGPT3_5.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.9,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
    )

    # `message.content` is Optional in the OpenAI SDK; normalise None to an
    # empty string so callers can always treat the result as text.
    content = response.choices[0].message.content
    return content if content is not None else ""

# Load the pre-processed test samples; the file is read as JSON Lines
# (one JSON object per line).
data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/test_AST_abcd_woaction_flow_all.json', 'r') as file:
    for line in file:
        data.append(json.loads(line))

# Turns generated so far. The list is never reset inside the loop, so every
# earlier sample's turn is replayed as history in the prompt of later samples.
context_list = []
# Defined up front so the summary print below also works when `data` is empty.
context = ""

# Demo run: only the first 10 samples are processed.
for sample in data[:10]:
    user_input = sample['input']
    context = user_input
    # save the context to a tmp json file:
    # {"sample_id": 0, "target": "request time [none]", "input": "Context: hi, could you get me a restaurant booking on the 8th please? ", "target_data": "[\"request time\", [\"none\"]]"}
    save_context = {"sample_id": 0, "convo_id": sample['convo_id'], "turn_id": sample['turn_id'], "target": sample['target'], "input": context, "target_data": sample['target_data']}
    # "w" truncates any previous file, so tmp.json always holds exactly this
    # one sample (replaces the former remove-then-append sequence).
    with open("tmp.json", "w") as w:
        json.dump(save_context, w)
        w.write("\n")

    # NOTE(review): presumably the loader is configured to read tmp.json as
    # its test split -- confirm against the loader's data arguments.
    train_dataset, validation_dataset, test_dataset = datasets_loader.load_datasets()
    print(f"num_beams: {num_beams}")
    result_pred, result_label = do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams)

    # Call the LLM model to generate the response
    print("all predictions: ", result_pred)
    action = result_pred[-1]

    # build the context for the next turn for calling the LLM model
    history = "".join(
        f"User: {turn['user']}\nAssistant({turn['action']}): {turn['agent']}\n"
        for turn in context_list
    )
    dialog_with_hint = f"{history}User: {user_input}\nAssistant({action}): "

    # Guard against a missing completion so the punctuation check below
    # cannot fail on None or on an empty string.
    response = call_LLM(dialog_with_hint, action) or ""
    print("User: ", user_input)
    print("** Next Action **: ", action)
    print("Agent: ", response)
    context += str(action) + ". "
    # Make sure the stored utterance ends with sentence punctuation.
    if not response.endswith((".", "?", "!")):
        response += "."
    context += str(response) + " "
    context_list.append({"user": user_input, "action": action, "agent": response})
    print("*" *100)
    print("\n")

print("context: ", context)
print("context_list: ", context_list)

#### only generate the action list, not the full dialogue

In [None]:
# import warnings
# warnings.filterwarnings('ignore')
# import json

# data = []
# with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/test_AST_abcd_woaction_flow_all.json', 'r') as file:
#     for line in file:
#         json_data = json.loads(line)
#         data.append(json_data)

# context_list = []
# distinct_dialogue = {}
# distinct_dialogue["dialogue"] = []
# distinct_dialogue["pred_action_value"] = []
# distinct_dialogue["action_value_label"] = []
# distinct_dialogue["convo_ids"] = []
# distinct_dialogue["turn_ids"] = []
# current_conv_id = 0
# counter_success_dialogues = 0

# if os.path.exists("data/updating/incremental_data.json"):
#     # remove the file
#     os.remove("data/updating/incremental_data.json")

# for dialogue_i in range(len(data)):
#     user_input = data[dialogue_i]['input']
#     context = user_input
#     # save the context to a tmp json file:
#     # {"sample_id": 0, "target": "request time [none]", "input": "Context: hi, could you get me a restaurant booking on the 8th please? ", "target_data": "[\"request time\", [\"none\"]]"}
#     save_context = {"sample_id": data[dialogue_i]['sample_id'], "convo_id": data[dialogue_i]['convo_id'], "turn_id": data[dialogue_i]['turn_id'], "target": data[dialogue_i]['target'], "input": context, "target_data": data[dialogue_i]['target_data']}
#     if os.path.exists("tmp.json"):
#         os.remove("tmp.json")
#     # print(tmp_sample)
#     with open(f"tmp.json", "a") as w:
#         json.dump(save_context, w)
#         w.write("\n")

#     train_dataset, validation_dataset, test_dataset = datasets_loader.load_datasets()
#     result_pred, result_label = do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams)

#     # Call the LLM model to generate the response
#     action = result_pred[-1]

#     action = postprocess_predictions(action)

#     print("context: ", context)
#     print("agent: ", action)
#     print("gold: ", data[dialogue_i]['target'])
#     print("-" * 30)
#     print()

#     if data[dialogue_i]['convo_id'] != current_conv_id:
#         if current_conv_id == 0:
#             pass
#         else:
#             # calculate the CE metric
#             metrics = compute_ast_acc_metrics(distinct_dialogue["pred_action_value"], distinct_dialogue["action_value_label"], distinct_dialogue["convo_ids"], distinct_dialogue["turn_ids"])
            
#             # print("CE_joint: ", metrics["CE_joint"])
#             # print("CE_action: ", metrics["CE_action"])
#             # print("CE_value: ", metrics["CE_value"])
#             if metrics["CE_joint"] > 0.5 and metrics["CE_action"] > 0.5 and metrics["CE_value"] > 0.5:
#                 print("CE_joint: ", metrics["CE_joint"])
#                 print("CE_action: ", metrics["CE_action"])
#                 print("CE_value: ", metrics["CE_value"])
#                 print("EM action: ", metrics["EM_action"])
#                 print("EM value: ", metrics["EM_value"])
#                 print("EM joint: ", metrics["EM_joint"])
#                 print(distinct_dialogue["pred_action_value"])
#                 print(distinct_dialogue["action_value_label"])

#                 counter_success_dialogues += 1

#                 for i in range(len(distinct_dialogue["dialogue"])):
#                     # print(distinct_dialogue["dialogue"][i]["input"])
#                     # print(distinct_dialogue["dialogue"][i]["predicted_action"])
#                     # print(distinct_dialogue["dialogue"][i]["target"])
#                     # print("-" * 30)
#                     with open("data/updating/incremental_data.json", "a") as w:
#                         json.dump(distinct_dialogue["dialogue"][i], w)
#                         w.write("\n")

#                 if counter_success_dialogues == 2:
#                     break

#         distinct_dialogue["dialogue"] = []
#         distinct_dialogue["pred_action_value"] = []
#         distinct_dialogue["action_value_label"] = []
#         distinct_dialogue["convo_ids"] = []
#         distinct_dialogue["turn_ids"] = []

#         save_context['predicted_action'] = action
#         distinct_dialogue["dialogue"].append(save_context)
#         distinct_dialogue["pred_action_value"].append(action)
#         distinct_dialogue["action_value_label"].append(data[dialogue_i]['target'])
#         distinct_dialogue["convo_ids"].append(data[dialogue_i]['convo_id'])
#         distinct_dialogue["turn_ids"].append(data[dialogue_i]['turn_id'])
#         current_conv_id = data[dialogue_i]['convo_id']
#     else:
#         save_context['predicted_action'] = action
#         distinct_dialogue["dialogue"].append(save_context)
#         distinct_dialogue["pred_action_value"].append(action)
#         distinct_dialogue["action_value_label"].append(data[dialogue_i]['target'])
#         distinct_dialogue["convo_ids"].append(data[dialogue_i]['convo_id'])
#         distinct_dialogue["turn_ids"].append(data[dialogue_i]['turn_id'])


In [None]:
import json 
from tqdm import tqdm

# Peek at the raw ABCD dump: parse only the first line of the file as JSON
# and print it; the loop stops right after that first line.
data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/raw/abcd_v1.1.json', 'r') as file:
    for line in tqdm(file):
        data.append(json.loads(line))
        break

print(f"data: {data[0]}")

In [None]:
# Example system messages collected for each action, one list per action.
# The list names use underscores in place of the hyphens of the action names,
# because a hyphen is not allowed in a Python identifier. Every entry ends
# with a comma so adjacent string literals are never concatenated by accident.
pull_up_account = [
    "account has been pulled up for albert sanders.",
    "account has been pulled up for sanya afzal.",
]

enter_details = [
    "details of <username> have been entered.",
    "details of 14 have been entered.",
]

verify_identity = [
    "identity verification in progress ...",
]

make_password = [
    "a password has been generated.",
]

search_timing = [
    "system action: search timing",
]

search_policy = [
    "system action: search policy",
]

validate_purchase = [
    "purchase validation in progress ...",
]

search_faq = [
    "searching the faq pages ...",
]

membership = [
    "membership level of bronze has been noted.",
    "membership level of guest has been noted.",
]

search_boots = [
    "system action: search boots",
]

try_again = [
    "agent is looking for solutions ...",
]

ask_the_oracle = [
    "querying the system for an answer ...",
]

update_order = [
    "order has been updated with change address.",
]

promo_code = [
    "a promo code has been created.",
]

update_account = [
    "account has been updated with renew subscription.",
    "account has been updated with change time.",
]

search_membership = [
    "system action: search membership",
]

make_purchase = [
    "a purchase of <name> was made.",
    "a purchase of calvin klein jacket was made.",
]

offer_refund = [
    "a refund has been made for the amount of $<amount>.",
    "a refund has been made for the amount of $1<amount>.",
]

notify_team = [
    "the website team has been notified.",
]

record_reason = [
    "a reason of paypal has been recorded.",
    "a reason of competitor has been recorded.",
    "a reason of spouse has been recorded.",
]

search_jeans = [
    "system action: search jeans",
]

shipping_status = [
    "shipping status of delivered has been noted.",
    "shipping status of order received has been noted.",
]

search_shirt = [
    "system action: search shirt",
]

instructions = [
    "agent is looking for solutions ...",
]

search_jacket = [
    "system action: search jacket",
]

log_out_in = [
    "agent is looking for solutions ...",
]

select_faq = [
    "faq answer related to boots (how1) was selected.",
    "faq answer related to shirt (other2) was selected.",
]

subscription_status = [
    "querying the system for subscription status ...",
]

send_link = [
    "a link will be sent.",
]

search_pricing = [
    "system action: search pricing",
]

# Summary dict: maps each action name (hyphenated) to one template
# description string. Tokens in angle brackets such as <name> or <amount>
# are placeholders inside the template text.
action_description = {
    "pull-up-account": "account has been pulled up for <name>.",
    "enter-details": "details of <username> have been entered.",
    "verify-identity": "identity verification in progress ...",
    "make-password": "a password has been generated.",
    "search-timing": "system action: search timing, I need to ask a certain question about timing.",
    "search-policy": "system action: search policy, what kind of policy does the customer want to know?",
    "validate-purchase": "purchase validation in progress ...",
    "search-faq": "Answers can be found in the faq pages, searching the faq pages ...",
    "membership": "membership level of <level> has been noted.",
    "search-boots": "system action: search boots, click the boots toggle switch",
    "try-again": "agent is looking for solutions ...",
    "ask-the-oracle": "querying the system for an answer ...",
    "update-order": "order has been updated with <change>.",
    "promo-code": "a promo code has been created.",
    "update-account": "account has been updated with <change>.",
    "search-membership": "system action: search membership, I need to know the membership level of the customer.",
    "make-purchase": "a purchase of <item> was made.",
    "offer-refund": "a refund has been made for the amount of $<amount>.",
    "notify-team": "the website team has been notified.",
    "record-reason": "a reason of <reason> has been recorded.",
    "search-jeans": "system action: search jeans, click the jeans toggle switch",
    "shipping-status": "shipping status of <status> has been noted.",
    "search-shirt": "system action: search shirt, click the shirt toggle switch",
    "instructions": "agent is looking for solutions ..., I will give you some instructions.",
    # Spelling fixed here: the description previously read "jecket".
    "search-jacket": "system action: search jacket, click the jacket toggle switch",
    "log-out-in": "agent is looking for solutions ..., instruct the customer to log out of their account and log back in.",
    "select-faq": "faq answer related to <faq> was selected.",
    "subscription-status": "querying the system for subscription status ...",
    "send-link": "a link will be sent.",
    "search-pricing": "system action: search pricing, price of something."
}

In [13]:
import sacrebleu
from rouge import Rouge
from bert_score import score
import json

def calculate_bert_score(dialogues):
    """Return the mean BERTScore precision, recall and F1 over all dialogues.

    Each item of ``dialogues`` must provide ``"label_utterance"`` (reference)
    and ``"pred_utterance"`` (candidate). Scoring is delegated to
    ``bert_score.score`` with English settings and baseline rescaling; the
    per-item results are averaged and converted with ``.item()``.
    """
    # Single pass over the input, reading the reference before the candidate.
    pairs = [(turn["label_utterance"], turn["pred_utterance"]) for turn in dialogues]
    gold_texts = [gold for gold, _ in pairs]
    predicted_texts = [predicted for _, predicted in pairs]

    precision, recall, f1 = score(predicted_texts, gold_texts, lang='en', verbose=True, rescale_with_baseline=True)
    return precision.mean().item(), recall.mean().item(), f1.mean().item()

def calculate_BLEU_Score(dialogues):
    """Return the corpus-level BLEU score of predictions against labels.

    Each item of ``dialogues`` must provide ``"label_utterance"`` and
    ``"pred_utterance"``. All labels form one reference stream handed to
    ``sacrebleu.corpus_bleu``; the ``.score`` attribute of its result is
    returned.
    """
    # Single pass over the input, reading the reference before the candidate.
    pairs = [(turn["label_utterance"], turn["pred_utterance"]) for turn in dialogues]
    gold_texts = [gold for gold, _ in pairs]
    predicted_texts = [predicted for _, predicted in pairs]

    bleu = sacrebleu.corpus_bleu(predicted_texts, [gold_texts])
    return bleu.score


def calculate_rouge_scores(dialogues):
    """Return averaged ROUGE scores of predictions against labels.

    Each item of ``dialogues`` must provide ``"pred_utterance"`` (hypothesis)
    and ``"label_utterance"`` (reference). The result is whatever
    ``Rouge().get_scores(..., avg=True)`` returns.
    """
    # Single pass over the input, reading the hypothesis before the reference.
    pairs = [(turn["pred_utterance"], turn["label_utterance"]) for turn in dialogues]
    predicted_texts = [predicted for predicted, _ in pairs]
    gold_texts = [gold for _, gold in pairs]

    return Rouge().get_scores(predicted_texts, gold_texts, avg=True)

# Load the generated dialogues (one JSON object per line) and score the
# predicted utterances against the labels with BLEU, ROUGE and BERTScore.
with open("dialogues/abcdASTWOActionFlowAll_wo_chainedPrior.json", "r") as r:
    dialogues = list(map(json.loads, r))

dialog_BLEU = calculate_BLEU_Score(dialogues)
dialog_ROUGE = calculate_rouge_scores(dialogues)
dialog_BERT = calculate_bert_score(dialogues)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


100%|██████████| 18/18 [00:01<00:00, 12.37it/s]


computing greedy matching.


100%|██████████| 57/57 [00:00<00:00, 98.33it/s]

done in 2.05 seconds, 1757.35 sentences/sec





In [14]:
# Report the three utterance-level metrics, one per line.
for metric_label, metric_value in (
    ("dialog_BLEU: ", dialog_BLEU),
    ("dialog_ROUGE: ", dialog_ROUGE),
    ("dialog_BERT: ", dialog_BERT),
):
    print(metric_label, metric_value)

dialog_BLEU:  34.97804356890443
dialog_ROUGE:  {'rouge-1': {'r': 0.6210459256190967, 'p': 0.5383475930749909, 'f': 0.5649412970641914}, 'rouge-2': {'r': 0.5387776203532834, 'p': 0.45980727537952437, 'f': 0.48421015696692354}, 'rouge-l': {'r': 0.6186531625556021, 'p': 0.5364894791571523, 'f': 0.5629253138646001}}
dialog_BERT:  (0.36442309617996216, 0.39747434854507446, 0.38023123145103455)


In [None]:
'''
dialog_BLEU:  0.6122821503865296
dialog_ROUGE:  {'rouge-1': {'r': 0.16851587301587298, 'p': 0.06309809007293622, 'f': 0.08786207706084463}, 'rouge-2': {'r': 0.026555555555555554, 'p': 0.007242070477196334, 'f': 0.010955777178062023}, 'rouge-l': {'r': 0.15362698412698403, 'p': 0.05770404941573766, 'f': 0.0802271028245546}}
dialog_BERT:  (tensor([0.8295, 0.8478, 0.8036, 0.8254, 0.8398, 0.8411, 0.8241, 0.8294, 0.8504,
        0.8602, 0.8308, 0.8472, 0.8366, 0.8490, 0.8224, 0.8406, 0.8175, 0.8025,
        0.8282, 0.8406, 0.8321, 0.8075, 0.8360, 0.8153, 0.8258, 0.8248, 0.8314,
        0.8650, 0.7890, 0.8166, 0.8175, 0.7992, 0.8154, 0.8327, 0.8505, 0.8288,
        0.8278, 0.8261, 0.8228, 0.8768, 0.8196, 0.7798, 0.8394, 0.8461, 0.8167,
        0.8704, 0.8328, 0.8117, 0.8078, 0.8129, 0.7982, 0.8124, 0.8299, 0.8071,
        0.9225, 0.8163, 0.7616, 0.8110, 0.8356, 0.8555, 0.8156, 0.8235, 0.8138,
        0.8602, 0.8062, 0.8387, 0.8261, 0.8409, 0.8401, 0.8041, 0.8427, 0.8343,
        0.8233, 0.8115, 0.8258, 0.8349, 0.8291, 0.8485, 0.8065, 0.8300, 0.8416,
        0.8158, 0.7933, 0.8107, 0.8177, 0.8150, 0.8092, 0.8438, 0.8679, 0.8318,
        0.8566, 0.8124, 0.8295, 0.7854, 0.7797, 0.8130, 0.8132, 0.8932, 0.8113,
        0.8561]), tensor([0.8157, 0.8582, 0.8030, 0.8380, 0.8399, 0.8211, 0.8161, 0.8343, 0.8365,
        0.8502, 0.8002, 0.8354, 0.8205, 0.8386, 0.8148, 0.8523, 0.8122, 0.8135,
        0.8307, 0.9475, 0.8064, 0.8259, 0.8750, 0.8178, 0.8477, 0.8514, 0.8245,
        0.8460, 0.8227, 0.8205, 0.8097, 0.8196, 0.8245, 0.8939, 0.8335, 0.8164,
        0.8303, 0.8418, 0.8207, 0.9143, 0.8164, 0.8234, 0.8236, 0.9356, 0.7739,
        0.8624, 0.8177, 0.8105, 0.8313, 0.8314, 0.8091, 0.8372, 0.8194, 0.8097,
        0.8331, 0.8304, 0.7880, 0.8103, 0.8311, 0.8981, 0.8121, 0.8180, 0.8212,
        0.8515, 0.7684, 0.8373, 0.8215, 0.8380, 0.8511, 0.8066, 0.8427, 0.8283,
        0.8206, 0.8208, 0.7957, 0.8676, 0.8237, 0.8243, 0.8302, 0.8302, 0.8730,
        0.8048, 0.8252, 0.8516, 0.8159, 0.8255, 0.8146, 0.8305, 0.8900, 0.8042,
        0.8890, 0.8106, 0.8354, 0.8033, 0.8086, 0.8095, 0.8145, 0.8976, 0.8179,
        0.8120]), tensor([0.8225, 0.8530, 0.8033, 0.8317, 0.8398, 0.8310, 0.8201, 0.8318, 0.8434,
        0.8552, 0.8152, 0.8412, 0.8285, 0.8438, 0.8186, 0.8464, 0.8149, 0.8080,
        0.8295, 0.8909, 0.8190, 0.8166, 0.8550, 0.8166, 0.8366, 0.8379, 0.8279,
        0.8554, 0.8055, 0.8185, 0.8136, 0.8093, 0.8199, 0.8622, 0.8419, 0.8226,
        0.8290, 0.8339, 0.8217, 0.8952, 0.8180, 0.8010, 0.8314, 0.8886, 0.7947,
        0.8664, 0.8252, 0.8111, 0.8194, 0.8220, 0.8036, 0.8246, 0.8246, 0.8084,
        0.8755, 0.8233, 0.7746, 0.8106, 0.8333, 0.8762, 0.8138, 0.8207, 0.8175,
        0.8558, 0.7869, 0.8380, 0.8237, 0.8394, 0.8456, 0.8054, 0.8427, 0.8313,
        0.8220, 0.8161, 0.8105, 0.8509, 0.8264, 0.8362, 0.8182, 0.8301, 0.8570,
        0.8102, 0.8089, 0.8307, 0.8168, 0.8202, 0.8119, 0.8371, 0.8788, 0.8177,
        0.8725, 0.8115, 0.8324, 0.7943, 0.7939, 0.8112, 0.8138, 0.8954, 0.8146,
        0.8334]))
'''