### Pipeline should be like:
1. Core: could predict next action given the current context
2. can load the data from the benchmark
3. connect the data to the core module, to get the run-time action prediction list
4. Use CE metric to evaluate the quality of the predicted action list
5. if good enough, then save the dialogues as the new data for further training

In [1]:
'''
Set the openai GPT-4
'''
from openai import OpenAI
clientGPT4 = OpenAI(api_key="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")
clientGPT3_5 = OpenAI(api_key="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")

'''
Set the AST module for predict the next action: 
For demo, use the model for SGD dataset 
'''
"""
Reference: https://github.com/huggingface/transformers/tree/main/examples/pytorch

Adapted from huggingface Transformers
"""

import logging
import os
import sys
from pathlib import Path
import time

import datasets
import transformers
import transformers.trainer_utils as hf_trainer_utils
import numpy as np
import nltk  # Here to have a nice missing dependency error message early on

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
)

from src.data.data_args import DataArguments
from src.data.dataset_loader import DatasetLoader
from src.data.utils import group_col_name
from src.metrics import create_compute_metric_fct, verify_nltk
from src.model.hf_model_args import HfModelArguments
from src.hf_training.hf_training_args import HfSeq2SeqTrainingArgs

logger = logging.getLogger(__name__)

def train(trainer, train_dataset, training_args):
    logger.info("*** train ***")

    check_point = get_resume_checkpoint(training_args)
    train_result = trainer.train(resume_from_checkpoint=check_point)

    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


def do_eval(trainer, validation_dataset, max_length, num_beams):
    logger.info("*** Evaluate ***")

    metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")

    metrics["eval_samples"] = len(validation_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

def do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams):
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels

    def decode(preds, labels, sequence_scores):
        # print("preds length: ", len(preds))
        # print("type of preds: ", type(preds))

        if isinstance(preds, tuple):
            preds = preds[1]
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        # print("decoded_preds lenght: ", len(decoded_preds))
        # print("type of decoded_preds: ", type(decoded_preds))
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        # print("decoded_preds lenght: ", len(decoded_preds))
        # print("type of decoded_preds: ", type(decoded_preds))

        model_path = Path(model_args.model_name_or_path)
        file_name = "pred_mwoz.txt" if training_args.is_mwoz else "preds_test_set.txt"
        if not model_path.exists():
            # model name
            preds_file_path = Path(training_args.output_dir) / file_name
        else:
            preds_file_path = model_path / file_name

        with preds_file_path.open("w") as f:
            for pred, label in zip(decoded_preds, decoded_labels):
                label = label.replace("\n", " ")
                pred = pred.replace("\n", " ")
                f.write(f"{pred}\t{label}" + "\n")

        # print("decoded_preds: ", decoded_preds, "decoded_labels: ", decoded_labels)

        return decoded_preds, decoded_labels, sequence_scores
    
    logger.info("*** Predict ***")

    metrics = {}
    predictions = []
    if group_col_name in test_dataset.column_names:
        group_idx = 0

        while True:
            group_dataset = test_dataset.filter(lambda x: x[group_col_name] == group_idx)
            if group_dataset.num_rows == 0:
                # no groups left
                break
            logger.info("Predicting on test group %d", group_idx)

            predict_results = trainer.predict(
                group_dataset,
                metric_key_prefix=f"predict_group_{group_idx}",
                max_length=max_length,
                num_beams=num_beams,
            )
            metrics.update(predict_results.metrics)
            metrics[f"predict_samples_group_{group_idx}_size"] = len(group_dataset)

            group_idx += 1

            predictions.append(predict_results.predictions)

        for key in ["loss", "rouge1", "rouge2", "rougeL"]:
            metrics[f"overall_predict_{key}"] = round(
                sum([metrics[f"predict_group_{idx}_{key}"] for idx in range(group_idx)]) / group_idx, 4
            )
    else:
        '''
        here
        '''
        
        # print("num_beams: ", num_beams)
        predict_results = trainer.predict(
            test_dataset, 
            metric_key_prefix="test", 
            max_length=max_length, 
            num_beams=num_beams,
            num_return_sequences=4,
            return_dict_in_generate=True, 
            output_scores=True,
        )
        # print("predict_results: ", predict_results)
        # print("predict_results.sequence_scores: ", predict_results.sequence_scores)
        # print("predict_results.predictions: ", predict_results.predictions)
        metrics = predict_results.metrics
        metrics["predict_samples_size"] = len(test_dataset)

    return decode(predict_results.predictions, test_dataset["labels"], predict_results.sequence_scores)


def load_model(model_args, data_args, tokenizer):
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Forcing the generation min lenght, to avoid models preset for summarization tasks that are usually high
    config.min_length = 5

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model.resize_token_embeddings(len(tokenizer))

    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("summarization_cnn", {}))

    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id["en_XX"]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids("en_XX")

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    if (
        hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings < data_args.max_source_length
    ):
        if model_args.resize_position_embeddings is None:
            logger.warning(
                "Increasing the model's number of position embedding vectors from"
                f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
            )
            model.resize_position_embeddings(data_args.max_source_length)
        elif model_args.resize_position_embeddings:
            model.resize_position_embeddings(data_args.max_source_length)
        else:
            raise ValueError(
                f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
                f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
                f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
                " model's position encodings by passing `--resize_position_embeddings`."
            )

    return model


def get_resume_checkpoint(training_args):
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

    last_checkpoint = get_last_checkpoint(training_args)
    if last_checkpoint is not None:
        checkpoint = last_checkpoint

    return checkpoint


def get_last_checkpoint(training_args):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = hf_trainer_utils.get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming hf_training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return last_checkpoint


def setup_logging(training_args):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def create_data_collector(model, tokenizer, training_args, data_args):
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    return DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )


def setup_wandb(training_args):
    if training_args.use_wandb:
        os.environ["WANDB_PROJECT"] = training_args.wandb_project_name
        training_args.run_name = training_args.experiment_name


def get_args():
    parser = HfArgumentParser((HfModelArguments, DataArguments, HfSeq2SeqTrainingArgs))
    model_args, data_args, training_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    name_parts = [training_args.experiment_name]
    name_parts.extend([data_args.text_column, data_args.summary_column])

    name_parts.append(model_args.model_name_or_path)

    training_args.experiment_name = "_".join(name_parts)

    training_args.output_dir = str(Path(training_args.output_dir).joinpath(training_args.experiment_name))

    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
            "`--source_prefix 'summarize: ' `"
        )
    return data_args, model_args, training_args

# def hf_run():
data_args, model_args, training_args = get_args()

setup_wandb(training_args)

setup_logging(training_args)

verify_nltk()

logger.warning(
    "Process rank: %s, device: %s, n_gpu: % distributed hf_training: %s 16-bits hf_training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

set_seed(training_args.seed)

tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

datasets_loader = DatasetLoader(data_args, training_args, tokenizer)
train_dataset, validation_dataset, test_dataset = datasets_loader.load_datasets()

model = load_model(model_args, data_args, tokenizer)

if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
    logger.warning(
        "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
        "`%s`. This will lead to loss being calculated twice and will take up more memory",
        model.__class__.__name__,
    )
metric_fct = create_compute_metric_fct(tokenizer, data_args, training_args, model_args)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    data_collator=create_data_collector(model, tokenizer, training_args, data_args),
    compute_metrics=metric_fct if training_args.predict_with_generate else None,
)

if training_args.do_train:
    train(trainer, train_dataset, training_args)

max_length = (
    training_args.generation_max_length
    if training_args.generation_max_length is not None
    else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams

# if training_args.do_eval:
#     do_eval(trainer, validation_dataset, max_length, num_beams)

# if training_args.do_predict:
#     results_pred, results_label = do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams)
#     # print("results_pred: ", results_pred)
#     # print("results_label: ", results_label)

  from .autonotebook import tqdm as notebook_tqdm






#### only generate the action list, not the full dialogue

In [2]:
num_beams = 4

print("num_beams: ", num_beams)

num_beams:  4


In [3]:
import re

def postprocess_predictions(prediction_str):
    # print("prediction_str: ", prediction_str)
    match = re.match(r"(.*)\[(.*)]", prediction_str)
    if match:
        # action w/ value
        action_name = match.group(1).strip()
        slot_str = match.group(2)
        slot_str = slot_str.replace(";", ",")
        slots = [s.strip() for s in slot_str.split(",")]
        for i in range(len(slots)):
            if slots[i].endswith(">") and not slots[i].startswith("<"):
                # add "<" to the beginning of the slot
                slots[i] = "<" + slots[i]
            if slots[i].startswith("<") and not slots[i].endswith(">"):
                # add ">" to the end of the slot
                slots[i] = slots[i] + ">"
        post_str = action_name + " " + "[" + ", ".join(slots) + "]"
        # print("post_str: ", post_str)
        return post_str
    else:
        return prediction_str

def parse_ast_prediction(prediction_str):
    match = re.match(r"(.*)\[(.*)]", prediction_str)
    if match:
        # action w/ value
        action_name = match.group(1).strip()
        slot_str = match.group(2)
        slot_str = slot_str.replace(";", ",")
        slots = [s.strip() for s in slot_str.split(",")]
        for i in range(len(slots)):
            if slots[i].endswith(">") and not slots[i].startswith("<"):
                # add "<" to the beginning of the slot
                slots[i] = "<" + slots[i]
            if slots[i].startswith("<") and not slots[i].endswith(">"):
                # add ">" to the end of the slot
                slots[i] = slots[i] + ">"
    else:
        action_name = "MISSING"
        slots = ["MISSING"]

    return action_name, slots

def compute_ast_acc_metrics(predictions, labels, convo_ids, turn_ids):
    # print("predictions: ", predictions)
    # print("labels: ", labels)
    """Adapted from ABCD. """
    action_preds = []
    action_labels = []

    value_preds = []
    value_labels = []

    for pred, label in zip(predictions, labels):

        action_label, values_label = parse_ast_prediction(label)
        values_label.sort()
        # for value in values_label:
        #     action_labels.append(action_label)
        #     value_labels.append(value)
        action_labels.append(action_label)
        value_labels.append(values_label)

        action_pred, values_pred = parse_ast_prediction(pred)
        values_pred.sort()

        if len(values_pred) > len(values_label):
            values_pred = [v for v in values_label if v in values_pred]
        if len(values_pred) < len(values_label):
            values_pred.extend(["MISSING"] * (len(values_label) - len(values_pred)))

        # for value in values_pred:
        #     action_preds.append(action_pred)
        #     value_preds.append(value)
        action_preds.append(action_pred)
        value_preds.append(values_pred)

    # print("action_preds: ", action_preds)
    # print("action_labels: ", action_labels)

    # print("value_preds: ", value_preds)
    # print("value_labels: ", value_labels)

    action_labels_arrary = np.array(action_labels)
    action_preds_arrary = np.array(action_preds)
    # print(f"action_labels_arrary: {action_labels_arrary}")
    # print(f"action_preds_arrary: {action_preds_arrary}")
    action_match = action_labels_arrary == action_preds_arrary
    # print(f"action_match: {action_match}")
    # print()
    action_acc = sum(action_match) / float(len(action_labels))

    value_labels_arrary = np.array(value_labels)
    value_preds_arrary = np.array(value_preds)
    value_match = value_labels_arrary == value_preds_arrary
    value_acc = sum(value_match) / float(len(action_labels))

    joint_match = action_match & value_match
    joint_acc = sum(joint_match) / float(len(action_labels))

    # group by convo_ids
    unique_convo_ids = list(set(convo_ids))
    # print("unique_convo_ids: ", unique_convo_ids)
    conversations = {}
    for uci in unique_convo_ids:
        turns, correctness = [], []
        correctness_action, correctness_value = [], []
        row_id = 0
        for convo_id, turn_count in zip(convo_ids, turn_ids):
            if convo_id == uci:
                turns.append(turn_count)
                correct = False
                correct_action = False
                correct_value = False
                action_right = action_match[row_id]
                value_right = value_match[row_id]
                
                if action_right:
                    correct_action = True
                else:
                    correct_action = False
                
                if value_right:
                    correct_value = True
                else:
                    correct_value = False

                if action_right and value_right:
                    correct = True
                else:
                    correct = False

                correctness.append(correct)
                correctness_action.append(correct_action)
                correctness_value.append(correct_value)
            row_id += 1

        # sort by turn_counts
        ordered = [cor for _, cor in sorted(zip(turns, correctness), key=lambda tc: tc[0])]
        ordered_action = [cor for _, cor in sorted(zip(turns, correctness_action), key=lambda tc: tc[0])]
        ordered_value = [cor for _, cor in sorted(zip(turns, correctness_value), key=lambda tc: tc[0])]
        conversations[uci] = [ordered, ordered_action, ordered_value]

    # print("ordered: ", ordered)
    # print("ordered_action: ", ordered_action)
    # print("ordered_value: ", ordered_value)

    # count how many correct
    turn_score, turn_correct = 0, 0
    turn_score_action, turn_correct_action = 0, 0
    turn_score_value, turn_correct_value = 0, 0
    em_joint, em_action, em_value = [], [], []
    my_scores = []
    for convo_id, itm in conversations.items():
        convo_correctness = itm[0]
        convo_correctness_action = itm[1]
        convo_correctness_value = itm[2]

        # calculate EM
        if sum(convo_correctness) == len(convo_correctness):
            em_joint.append(True)
        else:
            em_joint.append(False)
        if sum(convo_correctness_action) == len(convo_correctness_action):
            em_action.append(True)
        else:
            em_action.append(False)
        if sum(convo_correctness_value) == len(convo_correctness_value):
            em_value.append(True)
        else:
            em_value.append(False)
        
        # print(f"convo_id: {convo_id}, convo_correctness: {convo_correctness}")
        current_score = 0
        convo_length = len(convo_correctness)
        # we use turn_id rather than the true turn_count since turn counts will skip numbers
        # when looping through the conversation due to skipping over customer utterances
        for turn_id in range(convo_length):
            num_remaining = convo_length - turn_id

            num_correct = 0
            num_correct_action = 0
            num_correct_value = 0
            # count up how many were predicted correctly
            tmp_turn_id = turn_id
            while tmp_turn_id < convo_length and convo_correctness[tmp_turn_id]:
                num_correct += 1
                tmp_turn_id += 1
            
            tmp_turn_id = turn_id
            while tmp_turn_id < convo_length and convo_correctness_action[tmp_turn_id]:
                num_correct_action += 1
                tmp_turn_id += 1

            tmp_turn_id = turn_id
            while tmp_turn_id < convo_length and convo_correctness_value[tmp_turn_id]:
                num_correct_value += 1
                tmp_turn_id += 1

            if num_correct > 0:
                turn_correct += 1
            if num_correct_action > 0:
                turn_correct_action += 1
            if num_correct_value > 0:
                turn_correct_value += 1
            # normalize by the number of turns remaining
            turn_score += num_correct / num_remaining
            turn_score_action += num_correct_action / num_remaining
            turn_score_value += num_correct_value / num_remaining
            # current_score += num_correct / num_remaining

        # my_scores.append(current_score / convo_length)

    # normalize by total number of turns possible
    '''
    len(convo_ids): 200, len(turn_ids): 200
    '''
    # print(f"len(convo_ids): {len(convo_ids)}, len(turn_ids): {len(turn_ids)}")
    turn_acc = turn_correct / float(len(convo_ids))
    turn_acc_action = turn_correct_action / float(len(convo_ids))
    turn_acc_value = turn_correct_value / float(len(convo_ids))
    final_score = turn_score / float(len(convo_ids))
    final_score_action = turn_score_action / float(len(convo_ids))
    final_score_value = turn_score_value / float(len(convo_ids))
    
    em_action_score = sum(em_action) / float(len(em_action))
    em_value_score = sum(em_value) / float(len(em_value))
    em_joint_score = sum(em_joint) / float(len(em_joint))

    return {
        "EM_action": round(em_action_score, 4),
        "EM_value": round(em_value_score, 4),
        "EM_joint": round(em_joint_score, 4),
        "turn_acc_joint": round(turn_acc, 4),
        "turn_acc_action": round(turn_acc_action, 4),
        "turn_acc_value": round(turn_acc_value, 4),
        "CE_joint": round(final_score, 4),
        "CE_action": round(final_score_action, 4),
        "CE_value": round(final_score_value, 4)
    }

In [4]:
import warnings
warnings.filterwarnings('ignore')
import json

def call_LLM(dialogue, Action):

    prompt = hint_prompt.replace("[[DIALOG]]", dialogue)

    messages = []
    messages.append({"role": "system", "content": "You are a helpful assistant. You can generate a response to the user's input based on the given previous dialogue and the next action."})
    messages.append({"role": "user", "content": prompt})

    response = clientGPT3_5.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.9,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
    )
    # print(response.choices[0].message.content)

    return response.choices[0].message.content

data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/train_AST_abcd_forDeploy.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

context_list = []
distinct_dialogue = {}
distinct_dialogue["dialogue"] = []
distinct_dialogue["pred_action_value"] = []
distinct_dialogue["action_value_label"] = []
distinct_dialogue["convo_ids"] = []
distinct_dialogue["turn_ids"] = []
current_conv_id = 0
counter_success_dialogues = 0

if os.path.exists("data/processed/incremental_data.json"):
    # remove the file
    os.remove("data/processed/incremental_data.json")

for dialogue_i in range(len(data)):
    user_input = data[dialogue_i]['input']
    context = user_input
    # save the context to a tmp json file:
    # {"sample_id": 0, "target": "request time [none]", "input": "Context: hi, could you get me a restaurant booking on the 8th please? ", "target_data": "[\"request time\", [\"none\"]]"}
    save_context = {"sample_id": data[dialogue_i]['sample_id'], "convo_id": data[dialogue_i]['convo_id'], "turn_id": data[dialogue_i]['turn_id'], "target": data[dialogue_i]['target'], "input": context, "target_data": data[dialogue_i]['target_data']}
    if os.path.exists("tmp.json"):
        os.remove("tmp.json")
    # print(tmp_sample)
    with open(f"tmp.json", "a") as w:
        json.dump(save_context, w)
        w.write("\n")

    train_dataset, validation_dataset, test_dataset = datasets_loader.load_datasets()
    result_pred, result_label, sequence_scores = do_predict(trainer, test_dataset, tokenizer, training_args, data_args, model_args, max_length, num_beams)

    # Call the LLM model to generate the response
    # choose the prediction that has the highest sequence score
    max_score = -1
    max_score_index = -1
    for i in range(len(sequence_scores)):
        if sequence_scores[i] > max_score:
            max_score = sequence_scores[i]
            max_score_index = i

    action = result_pred[max_score_index]
    print("sequence_scores: ", sequence_scores)

    action = postprocess_predictions(action)

    print("context: ", context)
    print("all the pred: ", result_pred)
    print("agent: ", action)
    print("gold: ", data[dialogue_i]['target'])
    print("-" * 30)
    print()

    if data[dialogue_i]['convo_id'] != current_conv_id:
        if current_conv_id == 0:
            pass
        else:
            # calculate the CE metric
            metrics = compute_ast_acc_metrics(distinct_dialogue["pred_action_value"], distinct_dialogue["action_value_label"], distinct_dialogue["convo_ids"], distinct_dialogue["turn_ids"])
            
            # print("CE_joint: ", metrics["CE_joint"])
            # print("CE_action: ", metrics["CE_action"])
            # print("CE_value: ", metrics["CE_value"])
            if metrics["CE_joint"] > 0.5 and metrics["CE_action"] > 0.5 and metrics["CE_value"] > 0.5:
                print("CE_joint: ", metrics["CE_joint"])
                print("CE_action: ", metrics["CE_action"])
                print("CE_value: ", metrics["CE_value"])
                print("EM action: ", metrics["EM_action"])
                print("EM value: ", metrics["EM_value"])
                print("EM joint: ", metrics["EM_joint"])
                print(distinct_dialogue["pred_action_value"])
                print(distinct_dialogue["action_value_label"])

                counter_success_dialogues += 1

                for i in range(len(distinct_dialogue["dialogue"])):
                    # print(distinct_dialogue["dialogue"][i]["input"])
                    # print(distinct_dialogue["dialogue"][i]["predicted_action"])
                    # print(distinct_dialogue["dialogue"][i]["target"])
                    # print("-" * 30)
                    with open("data/processed/incremental_data.json", "a") as w:
                        json.dump(distinct_dialogue["dialogue"][i], w)
                        w.write("\n")

                if counter_success_dialogues == 2:
                    break

        distinct_dialogue["dialogue"] = []
        distinct_dialogue["pred_action_value"] = []
        distinct_dialogue["action_value_label"] = []
        distinct_dialogue["convo_ids"] = []
        distinct_dialogue["turn_ids"] = []

        save_context['predicted_action'] = action
        distinct_dialogue["dialogue"].append(save_context)
        distinct_dialogue["pred_action_value"].append(action)
        distinct_dialogue["action_value_label"].append(data[dialogue_i]['target'])
        distinct_dialogue["convo_ids"].append(data[dialogue_i]['convo_id'])
        distinct_dialogue["turn_ids"].append(data[dialogue_i]['turn_id'])
        current_conv_id = data[dialogue_i]['convo_id']
    else:
        save_context['predicted_action'] = action
        distinct_dialogue["dialogue"].append(save_context)
        distinct_dialogue["pred_action_value"].append(action)
        distinct_dialogue["action_value_label"].append(data[dialogue_i]['target'])
        distinct_dialogue["convo_ids"].append(data[dialogue_i]['convo_id'])
        distinct_dialogue["turn_ids"].append(data[dialogue_i]['turn_id'])


Generating test split: 1 examples [00:00, 36.07 examples/s]
Running tokenizer on test dataset: 100%|██████████| 1/1 [00:00<00:00, 14.29 examples/s]


sequence_scores:  [0.9948888  0.90496075 0.70858437 0.68808514]
context:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you.
all the pred:  ['pull-up-account [joyce wu]', 'verify-identity [joyce wu, pin_number>, pin_number>]', 'verify-identity [joyce wu, joyce wu, pin_number>]', 'ask-the-oracle [none]']
agent:  pull-up-account [joyce wu]
gold:  pull-up-account [joyce wu]
------------------------------



Generating test split: 1 examples [00:00, 67.64 examples/s]
Running tokenizer on test dataset: 100%|██████████| 1/1 [00:00<00:00, 34.26 examples/s]


sequence_scores:  [0.99388754 0.88257563 0.79155517 0.7662047 ]
context:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you. joyce what would you like to know? how can i get into my account if i forgot my password? i can help you reset your password. what is your username? my username is <username>
all the pred:  ['enter-details [username>]', 'verify-identity [joyce wu, pin_number>, pin_number>]', 'validate-purchase [username>, username>, username>]', 'verify-identity [joyce wu, username>, username>]']
agent:  enter-details [<username>]
gold:  enter-details [<username>]
------------------------------



Generating test split: 1 examples [00:00, 48.41 examples/s]
Running tokenizer on test dataset: 100%|██████████| 1/1 [00:00<00:00, 20.53 examples/s]


sequence_scores:  [0.99112546 0.8131446  0.79994786 0.6504266 ]
context:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you. joyce what would you like to know? how can i get into my account if i forgot my password? i can help you reset your password. what is your username? my username is <username> i am unable to get your old password, but i can create a new password for you. is that okay? that's fine! may i have your pin number? my pin number is <pin_number>
all the pred:  ['make-password [none]', 'verify-identity [joyce wu, pin_number>, pin_number>]', 'enter-details [username>]', 'verify-identity [joyce wu, pin_number>, pin_number>]']
agent:  make-password [none]
gold:  make-password [none]
------------------------------



Generating test split: 1 examples [00:00, 54.76 examples/s]
Running tokenizer on test dataset: 100%|██████████| 1/1 [00:00<00:00, 19.22 examples/s]


sequence_scores:  [0.9992626 0.5327671 0.5264842 0.5078061]
context:  Context: hello. how are you today? i would like to keep my premium subscription fee, so i'd like to go ahead and pay for it now. great. can i have your name please? my name is rodriguez domingo.
all the pred:  ['pull-up-account [rodriguez domingo]', 'select-faq [policy_3]', 'record-reason [previous subscription]', 'record-reason [prime subscription]']
agent:  pull-up-account [rodriguez domingo]
gold:  pull-up-account [rodriguez domingo]
------------------------------

CE_joint:  1.0
CE_action:  1.0
CE_value:  1.0
EM action:  1.0
EM value:  1.0
EM joint:  1.0
['pull-up-account [joyce wu]', 'enter-details [<username>]', 'make-password [none]']
['pull-up-account [joyce wu]', 'enter-details [<username>]', 'make-password [none]']


Generating test split: 1 examples [00:00, 83.68 examples/s]
Running tokenizer on test dataset: 100%|██████████| 1/1 [00:00<00:00, 51.40 examples/s]


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [13]:
a = [-0.0051, -0.0999, -0.3445, -0.3738]

print("a: ", a)
# from log to prob
a = np.exp(a)
print("a: ", a)

a:  [-0.0051, -0.0999, -0.3445, -0.3738]
a:  [0.99491298 0.90492791 0.70857455 0.68811452]


In [None]:
# # Load model directly
# from transformers import (
#     AutoConfig,
#     AutoModel,
#     T5ForConditionalGeneration,
#     AutoModelForSeq2SeqLM,
#     AutoTokenizer,
#     DataCollatorForSeq2Seq,
#     HfArgumentParser,
#     Seq2SeqTrainer,
#     set_seed,
#     MBartTokenizer,
#     MBartTokenizerFast,
# )
# import numpy as np

# tokenizer = AutoTokenizer.from_pretrained(
#     "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction_input_target_t5-small/checkpoint-22900",
#     use_fast=False,
#     revision="main",
#     use_auth_token=None,
# )

# model = T5ForConditionalGeneration.from_pretrained(
#     "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction_input_target_t5-small/checkpoint-22900"
# )

# # model.resize_token_embeddings(len(tokenizer))
# # model.is_encoder_decoder = True

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# encoder_input_str = "Predict AST: Context: hello, thank you for contacting acmecorp, how may i help you today? hi. i am trying to make a purchase but my credit card keeps declining. i would happy to help you with this. lets try a few things, first can you look at the expiration date on the card to make sure it is still valid? ok"

# input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


# outputs = model.generate(
#     input_ids,
#     num_beams=4,
#     num_return_sequences=4,
#     no_repeat_ngram_size=1,
#     return_dict_in_generate=True, 
#     output_scores=True,
#     max_length=256,
#     min_length=5,
# )

# print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
# print(tokenizer.decode(outputs[0][1], skip_special_tokens=True))
# print(tokenizer.decode(outputs[0][2], skip_special_tokens=True))
# print(tokenizer.decode(outputs[0][3], skip_special_tokens=True))

# # transition_scores = model.compute_transition_scores(
# #     outputs.sequences, outputs.scores, normalize_logits=True
# # )

# # input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
# # generated_tokens = outputs.sequences[:, input_length:]

# # sequence_scores = []
# # for tok, score in zip(generated_tokens[0], transition_scores[0]):
# #     if tok == 1:
# #         break
# #     # | token | token string | logits | probability
# #     # print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# #     sequence_scores.append(np.exp(score.numpy()))
# # print(np.average(sequence_scores))

# # sequence_scores = []
# # for tok, score in zip(generated_tokens[1], transition_scores[1]):
# #     if tok == 1:
# #         break
# #     # | token | token string | logits | probability
# #     # print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# #     sequence_scores.append(np.exp(score.numpy()))
# # print(np.average(sequence_scores))

# # sequence_scores = []
# # for tok, score in zip(generated_tokens[2], transition_scores[2]):
# #     if tok == 1:
# #         break
# #     # | token | token string | logits | probability
# #     # print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# #     sequence_scores.append(np.exp(score.numpy()))
# # print(np.average(sequence_scores))

# # sequence_scores = []
# # for tok, score in zip(generated_tokens[3], transition_scores[3]):
# #     if tok == 1:
# #         break
# #     # | token | token string | logits | probability
# #     # print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# #     sequence_scores.append(np.exp(score.numpy()))
# # print(np.average(sequence_scores))




search-faq [none]
try-again [none]
search-policy [none]
enter-details [troublemaker]


In [4]:
import json

# Load model directly
from transformers import (
    AutoConfig,
    AutoModel,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
)
import numpy as np

tokenizer = AutoTokenizer.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600",
    use_fast=False,
    revision="main",
    use_auth_token=None,
)

model = T5ForConditionalGeneration.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600"
)

# model.resize_token_embeddings(len(tokenizer))
# model.is_encoder_decoder = True

data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/train_AST_abcd_forDeploy.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

for i in range(10):
    print("input: ", data[i]['input'])
    print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']

    input_ids = tokenizer(input_context, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        num_beams=4,
        num_return_sequences=4,
        no_repeat_ngram_size=1,
        return_dict_in_generate=True, 
        output_scores=True,
        max_length=256,
        min_length=5,
    )
    print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
    print(tokenizer.decode(outputs[0][1], skip_special_tokens=True))
    print(tokenizer.decode(outputs[0][2], skip_special_tokens=True))
    print(tokenizer.decode(outputs[0][3], skip_special_tokens=True))

    print('---------------------------------')
    print()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


input:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you.
target:  pull-up-account [joyce wu]




AttributeError: 'BeamSearchEncoderDecoderOutput' object has no attribute 'sequence_scores'

#### directly load the model (OK)

In [2]:
import json

# Load model directly
from transformers import (
    AutoConfig,
    AutoModel,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor
)
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600",
    use_fast=False,
    # revision="main",
    # use_auth_token=None,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600"
)
model.to(device)

# model.resize_token_embeddings(len(tokenizer))
# model.is_encoder_decoder = True

data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/train_AST_abcd_forDeploy.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

for i in range(10):
    print("input: ", data[i]['input'])
    print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt").input_ids
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 4
    # define decoder start token ids
    input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    input_ids = input_ids * model.config.decoder_start_token_id

    model_kwargs = {
        "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
    }

    beam_scorer = BeamSearchScorer(
        batch_size=1,
        num_beams=num_beams,
        device=model.device,
        num_beam_hyps_to_keep=num_beams,
    )

    logits_processor = LogitsProcessorList(
        [
            MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ]
    )

    outputs = model.beam_search(
        input_ids, 
        beam_scorer, 
        logits_processor=logits_processor, 
        output_scores=True, 
        return_dict_in_generate = True,
        **model_kwargs
    )

    # print("sequence_scores: ", outputs.sequence_scores)
    # print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][1], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][2], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][3], skip_special_tokens=True))

    # print('---------------------------------')
    # print()
    # print("outputs: ", outputs
    # print("sequence_scores: ", outputs.sequences_scores)
    scores = outputs.sequences_scores.cpu().numpy()
    for i in range(len(scores)):
        scores[i] = np.exp(scores[i])
    # print("scores: ", scores)
    # normalize the scores
    scores = scores / np.sum(scores)
    print("scores: ", scores)
    print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
    print(tokenizer.decode(outputs.sequences[1], skip_special_tokens=True))
    print(tokenizer.decode(outputs.sequences[2], skip_special_tokens=True))
    print(tokenizer.decode(outputs.sequences[3], skip_special_tokens=True))
    # output = tokenizer.batch_decode(outputs[0][0], skip_special_tokens=True)
    # print(output)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a custom generation loop instead.


input:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you.
target:  pull-up-account [joyce wu]




scores:  [0.3017998  0.2745201  0.21494927 0.20873083]
pull-up-account [joyce wu]
verify-identity [joyce wu, pin_number>, pin_number>]
verify-identity [joyce wu, joyce wu, pin_number>]
ask-the-oracle [none]
input:  Context: hello! how may i help you? i just wanted to get some more information about my order. i can't access my account because i forgot my password. okay.may i have your full name and account id? my name is joyce wu and my account id is <pin_number> thank you. joyce what would you like to know? how can i get into my account if i forgot my password? i can help you reset your password. what is your username? my username is <username>
target:  enter-details [<username>]
scores:  [0.2894068  0.25699425 0.2304903  0.22310862]
enter-details [username>]
verify-identity [joyce wu, pin_number>, pin_number>]
validate-purchase [username>, username>, username>]
verify-identity [joyce wu, username>, username>]
input:  Context: hello! how may i help you? i just wanted to get some more i

#### directly load the model + chainedPrior

In [3]:
import json

# Load model directly
from transformers import (
    AutoConfig,
    AutoModel,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor
)
import numpy as np
import torch
from aalpy.utils import load_automaton_from_file, save_automaton_to_file, visualize_automaton, generate_random_dfa
from aalpy.automata import Dfa
from aalpy.SULs import AutomatonSUL
from aalpy.oracles import RandomWalkEqOracle
from aalpy.learning_algs import run_Lstar, run_KV

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load an automaton
automaton = load_automaton_from_file('/research/d5/gds/xywen22/project/llm_framework/chainPrior/learned_mdp2.dot', automaton_type='mdp')
# print(automaton)
# visualize the automaton
# visualize_automaton(automaton)
automaton = str(automaton)
print(automaton)

automaton_splits = automaton.split('\n')
# print(automaton_splits)
automaton_states = automaton_splits[1:33]
# ['s0 [label="init"];', 's1 [label="pull-up-account"];']
automaton_transitions = automaton_splits[33:-4]
# ['s0 -> s0  [label="init:1.0"];', 's0 -> s1  [label="action:0.03"];']

state_mapping = {}
for state in automaton_states:
    state_name = state.split(' ')[0]
    state_label = state.split('[label="')[1].split('"];')[0]
    state_mapping[state_name] = state_label

transition_mapping = {}
for transition in automaton_transitions:
    transition_split = transition.split('->')
    source_state = transition_split[0].strip()
    target_state = transition_split[1].strip().split(' ')[0]
    transition_label = transition_split[1].split('[label="')[1].split('"];')[0]
    transition_action = transition_label.split(':')[0]
    transition_freq = float(transition_label.split(':')[1])
    transition_mapping[(state_mapping[source_state], state_mapping[target_state])] = (transition_action, transition_freq)

# from s0 to s31, if some pair of states are not in the transition_mapping, then the frequency is 0
for i in range(32):
    for j in range(32):
        if (state_mapping[f's{i}'], state_mapping[f's{j}']) not in transition_mapping:
            transition_mapping[(state_mapping[f's{i}'], state_mapping[f's{j}'])] = ('unknown', 0.0)

print("state_mapping: ", state_mapping)
print("transition_mapping: ", transition_mapping)


tokenizer = AutoTokenizer.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600",
    use_fast=False,
    # revision="main",
    # use_auth_token=None,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWOAction10P_input_target_t5-small/checkpoint-4600"
)
model.to(device)

# model.resize_token_embeddings(len(tokenizer))
# model.is_encoder_decoder = True

data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/train_AST_abcd_forDeploy.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

# split the data into convo_ids
convos = {}
for i in range(len(data)):
    convo_id = data[i]['convo_id']
    if convo_id not in convos:
        convos[convo_id] = []
    convos[convo_id].append(data[i])

counter = 0
for convo_id, convo_data in convos.items():
    if counter == 10:
        break
    print("convo_id: ", convo_id)
    previous_actions = ['init']
    for i in range(len(convo_data)):
        print("input: ", convo_data[i]['input'])
        print("target: ", convo_data[i]['target'])

        input_context = "Predict AST: " + convo_data[i]['input']
        encoder_input_ids = tokenizer(input_context, return_tensors="pt").input_ids
        encoder_input_ids = encoder_input_ids.to(device)

        # lets run beam search using 3 beams
        num_beams = 4
        # define decoder start token ids
        input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        input_ids = input_ids * model.config.decoder_start_token_id

        model_kwargs = {
            "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
        }

        beam_scorer = BeamSearchScorer(
            batch_size=1,
            num_beams=num_beams,
            device=model.device,
            num_beam_hyps_to_keep=num_beams,
        )

        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ]
        )

        outputs = model.beam_search(
            input_ids, 
            beam_scorer, 
            logits_processor=logits_processor, 
            output_scores=True, 
            return_dict_in_generate = True,
            **model_kwargs
        )

        # print("sequence_scores: ", outputs.sequence_scores)
        # print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
        # print(tokenizer.decode(outputs[0][1], skip_special_tokens=True))
        # print(tokenizer.decode(outputs[0][2], skip_special_tokens=True))
        # print(tokenizer.decode(outputs[0][3], skip_special_tokens=True))

        # print('---------------------------------')
        # print()
        # print("outputs: ", outputs
        # print("sequence_scores: ", outputs.sequences_scores)
        scores = outputs.sequences_scores.cpu().numpy()
        for j in range(len(scores)):
            scores[j] = np.exp(scores[j])
        # print("scores: ", scores)
        # normalize the scores
        scores = scores / np.sum(scores)
        print("scores: ", scores)
        action1 = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(' ')[0].strip()
        action2 = tokenizer.decode(outputs.sequences[1], skip_special_tokens=True).split(' ')[0].strip()
        action3 = tokenizer.decode(outputs.sequences[2], skip_special_tokens=True).split(' ')[0].strip()
        action4 = tokenizer.decode(outputs.sequences[3], skip_special_tokens=True).split(' ')[0].strip()
        # output = tokenizer.batch_decode(outputs[0][0], skip_special_tokens=True)
        # print(output)
        print(f"actions: {action1}, {action2}, {action3}, {action4}")

        rate1 = transition_mapping[(previous_actions[-1], action1)][1]
        rate2 = transition_mapping[(previous_actions[-1], action2)][1]
        rate3 = transition_mapping[(previous_actions[-1], action3)][1]
        rate4 = transition_mapping[(previous_actions[-1], action4)][1]

        print(f"pairs: {previous_actions[-1]} -> {action1}, {previous_actions[-1]} -> {action2}, {previous_actions[-1]} -> {action3}, {previous_actions[-1]} -> {action4}")
        print(f"rates: {rate1}, {rate2}, {rate3}, {rate4}")

        previous_actions.append(convo_data[i]['target'].split(' ')[0].strip())
        print("-" * 30)
        print()

    counter += 1



    

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


digraph learnedModel {
s0 [label="init"];
s1 [label="pull-up-account"];
s2 [label="enter-details"];
s3 [label="verify-identity"];
s4 [label="make-password"];
s5 [label="search-timing"];
s6 [label="search-policy"];
s7 [label="validate-purchase"];
s8 [label="search-faq"];
s9 [label="membership"];
s10 [label="search-boots"];
s11 [label="try-again"];
s12 [label="ask-the-oracle"];
s13 [label="update-order"];
s14 [label="promo-code"];
s15 [label="update-account"];
s16 [label="search-membership"];
s17 [label="make-purchase"];
s18 [label="offer-refund"];
s19 [label="notify-team"];
s20 [label="record-reason"];
s21 [label="search-jeans"];
s22 [label="shipping-status"];
s23 [label="search-shirt"];
s24 [label="instructions"];
s25 [label="search-jacket"];
s26 [label="log-out-in"];
s27 [label="select-faq"];
s28 [label="subscription-status"];
s29 [label="send-link"];
s30 [label="search-pricing"];
s31 [label="end"];
s0 -> s0  [label="init:1.0"];
s0 -> s1  [label="action:0.03"];
s0 -> s2  [label="actio

In [8]:
import random

random.seed(0)

PUNCTUATIONS = ['.', ',', '!', '?', ';', ':']
DATASETS = ['cr', 'sst2', 'subj', 'pc', 'trec']
NUM_AUGS = [1, 2, 4, 8]
PUNC_RATIO = 0.3

# Insert punction words into a given sentence with the given ratio "punc_ratio"
def insert_punctuation_marks(sentence, punc_ratio=PUNC_RATIO):
	words = sentence.split(' ')
	new_line = []
	q = random.randint(1, int(punc_ratio * len(words) + 1))
	qs = random.sample(range(0, len(words)), q)

	for j, word in enumerate(words):
		if j in qs:
			new_line.append(PUNCTUATIONS[random.randint(0, len(PUNCTUATIONS)-1)])
			new_line.append(word)
		else:
			new_line.append(word)
	new_line = ' '.join(new_line)
	return new_line


def main(train_orig):
    for aug in NUM_AUGS:
        data_aug = []
        for line in train_orig:
            sentence = line
            for i in range(aug):
                sentence_aug = insert_punctuation_marks(sentence)
                data_aug.append(sentence_aug)
            data_aug.append(line)
        
        print("aug: ", aug)
        print("data_aug: ", data_aug)


if __name__ == "__main__":
    data = [
        "The pluger gets stuck & you have to manually extract the mess.",
        "The Hitachi is made in Malasyia, and looked cheap compared with the Makita, which is made in the USA.",
        "i must have heard this about a dozen times over the span of 2 weeks , when t-zones never worked .",
        "This model appears to be especially good.",
        "in any case , navigation by artist / album is ok however i miss an ability to navigate by folders .",
        "it's fast, feels great in your hand and looks great too.",
        "now that i have all my music on it there has n't been any more problems .",
        "i have stored around 60 cd 's ( at 160kbps ) on this and have barely touched the available memory .",
        "I do n't want to be a sensationalist but this computer is incredible.",
        "But it wouldn't let me turn the AV protection back on and it wouldn't stop with the alerts."
    ]
    main(data)

aug:  1
data_aug:  ['? The pluger gets stuck ? & you ! have to ? manually extract the mess.', 'The pluger gets stuck & you have to manually extract the mess.', 'The Hitachi is made in Malasyia, , and looked cheap compared with the Makita, which is made ! in the , USA.', 'The Hitachi is made in Malasyia, and looked cheap compared with the Makita, which is made in the USA.', 'i : must have . heard : this about a dozen ! times ? over the span of 2 weeks , when ; t-zones never . worked .', 'i must have heard this about a dozen times over the span of 2 weeks , when t-zones never worked .', 'This model ; appears : to be especially good.', 'This model appears to be especially good.', 'in any case , navigation by artist / album is ok however i miss an ? ability to ; navigate by folders .', 'in any case , navigation by artist / album is ok however i miss an ability to navigate by folders .', ". it's fast, feels great in your hand and : looks great ? too.", "it's fast, feels great in your hand a

In [10]:
import json

data = []
with open('./data/processed/validation_AST_multiwoz.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

dialogue_id = {}
for i in range(len(data)):
    if data[i]['convo_id'] not in dialogue_id:
        dialogue_id[data[i]['convo_id']] = 0
    dialogue_id[data[i]['convo_id']] += 1

counting_dic = {}
print("dialogue_id: ", dialogue_id)
for key, value in dialogue_id.items():
    if value not in counting_dic:
        counting_dic[value] = 0
    counting_dic[value] += 1
print("counting_dic: ", counting_dic)
# order the counting dic
counting_dic = dict(sorted(counting_dic.items(), key=lambda item: item[0]))
print("counting_dic: ", counting_dic)
# sum the counting dic
sum_counting_dic = 0
for key, value in counting_dic.items():
    sum_counting_dic += value
print("sum_counting_dic: ", sum_counting_dic)

top_p = 0
for key, value in counting_dic.items():
    print(f"dialogues with {key} turns: {value} ({value/sum_counting_dic:.2%})")
    top_p += value / sum_counting_dic
    print(f"the first {key} turns with top_p: {top_p:.2%}")


print("max dialogue_id: ", max(dialogue_id, key=dialogue_id.get), dialogue_id[max(dialogue_id, key=dialogue_id.get)])

dialogue_id:  {'PMUL0698.json': 4, 'PMUL3233.json': 9, 'SNG01627.json': 2, 'MUL1719.json': 4, 'MUL0242.json': 8, 'PMUL1072.json': 5, 'PMUL3048.json': 7, 'PMUL1100.json': 7, 'PMUL3979.json': 9, 'MUL1409.json': 7, 'PMUL4828.json': 10, 'SNG0329.json': 2, 'PMUL3314.json': 6, 'MUL1768.json': 4, 'MUL0293.json': 6, 'PMUL0420.json': 13, 'PMUL0858.json': 7, 'MUL1367.json': 9, 'MUL1271.json': 6, 'PMUL0928.json': 5, 'MUL1589.json': 6, 'PMUL3200.json': 4, 'MUL0398.json': 8, 'SNG01735.json': 3, 'PMUL4290.json': 10, 'SNG0551.json': 2, 'MUL2384.json': 6, 'SNG01993.json': 4, 'PMUL2235.json': 4, 'PMUL4075.json': 3, 'PMUL0724.json': 5, 'MUL2160.json': 7, 'PMUL1402.json': 6, 'PMUL1152.json': 5, 'PMUL1121.json': 5, 'SNG02071.json': 1, 'PMUL3215.json': 4, 'PMUL4833.json': 7, 'MUL0344.json': 10, 'MUL2418.json': 4, 'PMUL1181.json': 5, 'MUL1604.json': 4, 'PMUL0287.json': 5, 'MUL2064.json': 12, 'PMUL4581.json': 8, 'MUL1888.json': 5, 'MUL1603.json': 3, 'PMUL1591.json': 5, 'MUL2393.json': 9, 'MUL0300.json': 6, '