### Pipeline should be like:
1. Core: could predict next action given the current context
2. can load the data from the benchmark
3. connect the data to the core module, to get the run-time action prediction list
4. Use CE metric to evaluate the quality of the predicted action list
5. if good enough, then save the dialogues as the new data for further training

#### only generate the action list, not the full dialogue

In [1]:
import re
import numpy as np


def parse_ast_prediction(prediction_str):
    match = re.match(r"(.*)\[(.*)]", prediction_str)
    if match:
        # action w/ value
        action_name = match.group(1).strip()
        slot_str = match.group(2)
        slot_str = slot_str.replace(";", ",")
        slots = [s.strip() for s in slot_str.split(",")]
        for i in range(len(slots)):
            if slots[i].endswith(">") and not slots[i].startswith("<"):
                # add "<" to the beginning of the slot
                slots[i] = "<" + slots[i]
            if slots[i].startswith("<") and not slots[i].endswith(">"):
                # add ">" to the end of the slot
                slots[i] = slots[i] + ">"
    else:
        action_name = "MISSING"
        slots = ["MISSING"]

    return action_name, slots


def parse_context_action(context):
    # Regular expression pattern to match actions (XX-XX, XX-XX-XX or XX-XX-XX-XX)
    action_pattern = r"\b(\w+-\w+(?:-\w+){0,2})(?=\s*\[)"

    # Find all matches in the context string
    actions = re.findall(action_pattern, context)
    return actions

def normalize_log_probs(log_probs):
    """
    Normalizes an array of log probabilities using the Log-Sum-Exp trick
    to prevent numerical underflow.

    Parameters:
    - log_probs: A numpy array containing log probabilities.

    Returns:
    - An array of normalized log probabilities.
    """

    # Find the maximum log probability to avoid underflow issues
    max_log_prob = np.max(log_probs)
    
    # Subtract the max log probability and exponentiate the result
    probs_stable = np.exp(log_probs - max_log_prob) + 1e-9
    
    # Sum the stabilized probabilities
    prob_total_stable = np.sum(probs_stable) + 1e-9
    
    # Normalize the stabilized probabilities
    normalized_probs_stable = probs_stable / prob_total_stable
    
    # Convert back to log probabilities adding the subtracted max_log_prob
    normalized_log_probs_stable = np.log(normalized_probs_stable) + max_log_prob
    
    return normalized_log_probs_stable


In [2]:
from src.metrics import *
from types import SimpleNamespace


data_args = SimpleNamespace(
    test_file='/research/d1/gds/jyzhong/computation_models/LLMFramework/data/processed/test_AST_abcd_waction_full.json',
    max_predict_samples=1000,
)

training_args = SimpleNamespace(
    use_bert_score=False,
    use_ast_metrics=True
)


def create_compute_metric_fct( data_args, training_args):
    def decode(preds, labels):
        # if isinstance(preds, tuple):
        #     preds = preds[0]
        # if data_args.ignore_pad_token_for_loss:
        #     # Replace -100 in the labels as we can't decode them.
        #     preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        # decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        # if data_args.ignore_pad_token_for_loss:
        #     # Replace -100 in the labels as we can't decode them.
        #     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        # decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(preds, labels)

        # model_path = Path(model_args.model_name_or_path)
        # file_name = "pred_mwoz.txt" if training_args.is_mwoz else "preds_test_set.txt"
        # if not model_path.exists():
        #     # model name
        #     preds_file_path = Path(training_args.output_dir) / file_name
        # else:
        #     preds_file_path = model_path / file_name

        # with preds_file_path.open("w") as f:
        #     for pred, label in zip(decoded_preds, decoded_labels):
        #         label = label.replace("\n", " ")
        #         pred = pred.replace("\n", " ")
        #         f.write(f"{pred}\t{label}" + "\n")

        return decoded_preds, decoded_labels

    def parse_predictions(eval_preds):
        preds, labels = eval_preds
        decoded_predictions, decoded_labels = decode(preds, labels)
        return decoded_predictions, decoded_labels

    def compute_em_and_ce(eval_preds):
        predictions, labels = parse_predictions(eval_preds)
        predictions = [parse_workflow_string(w) for w in predictions]
        labels = [parse_workflow_string(w) for w in labels]
        return compute_metrics(labels, predictions, use_bert_score=training_args.use_bert_score)

    def compute_cds_metrics(eval_preds):
        predictions, labels = parse_predictions(eval_preds)
        # print("data_args.test_file", data_args.test_file)
        # print("data_args.max_predict_samples", data_args.max_predict_samples)
        convo_ids, turn_ids = load_raw_test_dataset(data_args.test_file, data_args.max_predict_samples)
        return compute_cds_em_and_ce(predictions, labels, convo_ids, turn_ids)
    
    def compute_ast_metrics(eval_preds, sequence_scores=None):
        predictions, labels = parse_predictions(eval_preds)
        is_eval = True if len(labels) == 3684 else False
        conv_ids, turn_ids = load_raw_test_dataset(data_args.test_file, data_args.max_predict_samples)
        print("predictions:", len(predictions))
        print("labels:", len(labels))
        print("conv_ids:", len(conv_ids))
        print("turn_ids:", len(turn_ids))
        # print("sequence_scores:", len(sequence_scores))
        '''
        predictions: 200
        labels: 50
        conv_ids: 50
        turn_ids: 50
        sequence_scores: (200,)
        '''
        return compute_ast_acc_metrics(predictions, labels, conv_ids, turn_ids, sequence_scores)

    def no_metrics(eval_preds):
        # Evaluation will be done during post hf_training
        preds, labels = eval_preds
        decode(preds, labels)
        return {}

    return compute_ast_metrics

  from .autonotebook import tqdm as notebook_tqdm


#### directly load the model (OK)

In [3]:
import json

# Load model directly
from transformers import (
    AutoConfig,
    AutoModel,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor
)
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(
    "/research/d1/gds/jyzhong/computation_models/LLMFramework/results/abcdASTWActionAll_input_target_t5-small/checkpoint-45700",
    use_fast=False,
    # revision="main",
    # use_auth_token=None,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/research/d1/gds/jyzhong/computation_models/LLMFramework/results/abcdASTWActionAll_input_target_t5-small/checkpoint-45700"
)
model.to(device)

# model.resize_token_embeddings(len(tokenizer))
# model.is_encoder_decoder = True

data = []
with open('/research/d1/gds/jyzhong/computation_models/LLMFramework/data/processed/test_AST_abcd_waction_full.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
from aalpy.utils import load_automaton_from_file

# load an automaton
chained_proir = load_automaton_from_file('/research/d1/gds/jyzhong/computation_models/mdp_guildline_half.dot', automaton_type='mc')

pred_pm = []
preds = []
labels = []


# for i in tqdm(range(len(data))):
for i in tqdm(range(1000)):
    # print('=' * 100)
    # print("input: ", data[i]['input'])
    # print("target: ", data[i]['target'])
    labels.append(data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt")
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 35
    # define decoder start token ids
    outputs = model.generate(**encoder_input_ids, max_new_tokens=100, return_dict_in_generate=True, output_scores=True, do_sample=False, num_beams=num_beams, num_return_sequences=num_beams,)
    # print(outputs.scores)
    try:
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=False
        )
        output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
        length_penalty = model.generation_config.length_penalty
        reconstructed_scores = transition_scores.sum(axis=1).cpu().numpy() / (output_length**length_penalty)
    except:
        reconstructed_scores = np.zeros(num_beams)
        print(f"cannot compute {i}, set prob to 0")

    log_prob = reconstructed_scores
    # log_prob = np.sum(log_prob, axis=1)
    # log_prob = normalize_log_probs(log_prob)
    # print(f'scores: {log_prob}')

    policy_model_action = []
    policy_model_slot = []
    policy_model_prob = []
    policy_max_idx = np.argmax(log_prob)

    # print('+' * 20, 'policy model', '+' * 20)
    for k in range(num_beams):
        pred_str = tokenizer.decode(outputs.sequences[k], skip_special_tokens=True)
        action_name, slots = parse_ast_prediction(pred_str)

        policy_model_action.append(action_name)
        policy_model_slot.append(slots)
        if action_name != 'MISSING':
            policy_model_prob.append(log_prob[k])
        else:
            policy_model_prob.append(-1000)

        if k == policy_max_idx:
            pred_pm.append(tokenizer.decode(outputs.sequences[policy_max_idx], skip_special_tokens=True))
            # print("policy model top pred str: ", tokenizer.decode(outputs.sequences[policy_max_idx], skip_special_tokens=True))
            # print(f"policy model top parsed: {action_name}, {slots}; score: {log_prob[policy_max_idx]}")


    # print('+' * 20, 'chained prior', '+' * 20)
    prev_actions = parse_context_action(input_context)
    prev_actions = ['init'] + prev_actions

    init_state = chained_proir.states[0]
    curr_state = init_state

    prev_actions_with_prob = [(curr_state.output, 1)]
    prev_prob = 0
    idx = 0
    while curr_state.output != prev_actions[-1] and idx <= len(prev_actions) - 1:
        
        idx += 1
        matched = False
        for prosible_state, prob in curr_state.transitions:
            if prosible_state.output == prev_actions[idx]:
                prev_actions_with_prob.append((prosible_state.output, np.log(prob)))
                prev_prob += np.log(prob)
                curr_state = prosible_state
                matched = True
        
        if matched == False:
            break
    # print(f"context parsed prob: {prev_actions_with_prob}, total previous prob: {prev_prob}")
    
    chained_prior_pred = []
    chained_prior_prob = []
    for state in chained_proir.states:
        if state.output == prev_actions[-1]:
            for prosible_state, prob in state.transitions:
                chained_prior_pred.append(prosible_state.output)
                chained_prior_prob.append(np.log(prob) + prev_prob)
    chained_prior_prob = normalize_log_probs(chained_prior_prob) * 5
    # print(f"possible next: {chained_prior_pred}, prob: {chained_prior_prob}")


    alpha = 0.6
    for i, action in enumerate(policy_model_action):
        for j, c_action in enumerate(chained_prior_pred):
            if action == c_action:
                policy_model_prob[i] = policy_model_prob[i] * (1 - alpha) + chained_prior_prob[j] * alpha

    # print('+' * 20, 'chained prior + policy model', '+' * 20)
    top_idx = np.argmax(np.array(policy_model_prob))
    # print("MERGED top pred str: ", f"")
    # print(f"MERGED top parsed: {policy_model_action[top_idx]}, {policy_model_slot[top_idx]}; score: {policy_model_prob[top_idx]}")

    preds.append(tokenizer.decode(outputs.sequences[top_idx], skip_special_tokens=True))

print(preds[:5])
print(labels[:5])

  5%|▍         | 48/1000 [00:33<09:00,  1.76it/s]

cannot compute 47, set prob to 0


 30%|███       | 305/1000 [02:52<05:05,  2.28it/s]

cannot compute 304, set prob to 0


 95%|█████████▌| 952/1000 [08:30<00:17,  2.77it/s]

cannot compute 951, set prob to 0


100%|██████████| 1000/1000 [08:54<00:00,  1.87it/s]

['pull-up-account [chloe zhang]', 'search-timing [none]', 'pull-up-account [albert sanders], account_id>, account_id>]', 'verify-identity [albert sanders, account_id>, account_id>] "account_id>]', "enter-details [account_id>], account_id>, account_id>] 'i've been charged twice for it. how can i help?silver?silver?silver?silver?silver?silver?silver?silver?silver?soulver?soulver?soulver?"]
['search-faq [none]', 'search-timing [none]', 'pull-up-account [albert sanders]', 'verify-identity [albert sanders, <account_id>, <account_id>]', 'membership [silver]']





# Chained + PM

In [5]:
metric = create_compute_metric_fct(data_args, training_args)
results = metric((preds, labels))

predictions: 1000
labels: 1000
conv_ids: 1000
turn_ids: 1000
['search-faq [none]', 'search-timing [none]', 'pull-up-account [albert sanders]', 'verify-identity [albert sanders, <account_id>, <account_id>]', 'membership [silver]', 'ask-the-oracle [none]', 'search-faq [none]', 'search-shirt [none]', 'search-faq [none]', 'search-policy [none]']
['pull-up-account [chloe zhang]', 'search-timing [none]', 'pull-up-account [albert sanders], account_id>, account_id>]', 'verify-identity [albert sanders, account_id>, account_id>] "account_id>]', "enter-details [account_id>], account_id>, account_id>] 'i've been charged twice for it.\nhow can i help?silver?silver?silver?silver?silver?silver?silver?silver?silver?soulver?soulver?soulver?", 'offer-refund [20) - see if account_id> account_id> account_id> account_id> account_id> account_id> account_id> account_id> account_id>]', 'search-faq [none]', 'purchase-tommy hilifiger shirt [tommy hilfiger hilfiger hilfiger hilfiger hilfiger hilfiger hilfiger hi

In [6]:
print(results)

{'EM_action': 0.2143, 'EM_value': 0.1071, 'EM_joint': 0.0786, 'turn_acc_joint': 0.34, 'turn_acc_action': 0.569, 'turn_acc_value': 0.407, 'CE_joint': 0.241, 'CE_action': 0.407, 'CE_value': 0.2962}


# PM Only

In [7]:
metric = create_compute_metric_fct(data_args, training_args)
results = metric((pred_pm, labels))
results

predictions: 1000
labels: 1000
conv_ids: 1000
turn_ids: 1000
['search-faq [none]', 'search-timing [none]', 'pull-up-account [albert sanders]', 'verify-identity [albert sanders, <account_id>, <account_id>]', 'membership [silver]', 'ask-the-oracle [none]', 'search-faq [none]', 'search-shirt [none]', 'search-faq [none]', 'search-policy [none]']
['verify-identity [chloe zhang, michael kors kors kors kors kors kors kors kors kors kors kors kors kors korkors kors kors kors kors kors kors kors kors kors kors kors kors kors', 'verify-identity [chloe zhang, search-timingsnormanuelnuelnuelnuelnuelnuelnuelnueltuelnuelnuelnuelnuelnuelnuelnuelnuelnuelnuelnuel', 'verify-identity [albert sanders, account_id>, account_id>]: "i recently signed up for a subscription but it seems like you guys charged me twice for it.\nlet\'s fix that.\nlet\'s try to fix that.\nmay i have your full name and order id?i have your full account id: account_id>]', 'verify-identity [albert sanders, account_id>, account_id>] "a

{'EM_action': 0.0036,
 'EM_value': 0.0321,
 'EM_joint': 0.0036,
 'turn_acc_joint': 0.107,
 'turn_acc_action': 0.229,
 'turn_acc_value': 0.241,
 'CE_joint': 0.0722,
 'CE_action': 0.1361,
 'CE_value': 0.1357}