### Pipeline should be like:
1. Core: could predict next action given the current context
2. can load the data from the benchmark
3. connect the data to the core module, to get the run-time action prediction list
4. Use CE metric to evaluate the quality of the predicted action list
5. if good enough, then save the dialogues as the new data for further training

In [41]:
import re


def parse_ast_prediction(prediction_str):
    match = re.match(r"(.*)\[(.*)]", prediction_str)
    if match:
        # action w/ value
        action_name = match.group(1).strip()
        slot_str = match.group(2)
        slot_str = slot_str.replace(";", ",")
        slots = [s.strip() for s in slot_str.split(",")]
        for i in range(len(slots)):
            if slots[i].endswith(">") and not slots[i].startswith("<"):
                # add "<" to the beginning of the slot
                slots[i] = "<" + slots[i]
            if slots[i].startswith("<") and not slots[i].endswith(">"):
                # add ">" to the end of the slot
                slots[i] = slots[i] + ">"
    else:
        action_name = "MISSING"
        slots = ["MISSING"]

    return action_name, slots


def parse_context_action(context):
    # Regular expression pattern to match actions (XX-XX, XX-XX-XX or XX-XX-XX-XX)
    action_pattern = r"\b(\w+-\w+(?:-\w+){0,2})(?=\s*\[)"

    # Find all matches in the context string
    actions = re.findall(action_pattern, context)
    return actions

#### directly load the model (OK)

In [42]:
import json

# Load model directly
from transformers import (
    AutoConfig,
    AutoModel,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    set_seed,
    MBartTokenizer,
    MBartTokenizerFast,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor
)
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWMostPActionChainAll_input_target_t5-small/",
    use_fast=False,
    # revision="main",
    # use_auth_token=None,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/results/abcdASTWMostPActionChainAll_input_target_t5-small/"
)
model.to(device)

# model.resize_token_embeddings(len(tokenizer))
# model.is_encoder_decoder = True

data = []
with open('/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/data/processed/dev_AST_abcd_wmostpaction_chain_1p.json', 'r') as file:
    for line in file:
        json_data = json.loads(line)
        data.append(json_data)

for i in range(10):
    print("input: ", data[i]['input'])
    print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt").input_ids
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 4
    # define decoder start token ids
    input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    input_ids = input_ids * model.config.decoder_start_token_id

    model_kwargs = {
        "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
    }

    beam_scorer = BeamSearchScorer(
        batch_size=1,
        num_beams=num_beams,
        device=model.device,
        num_beam_hyps_to_keep=num_beams,
    )

    logits_processor = LogitsProcessorList(
        [
            MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ]
    )

    outputs = model.beam_search(
        input_ids, 
        beam_scorer, 
        logits_processor=logits_processor, 
        output_scores=True, 
        return_dict_in_generate = True,
        **model_kwargs
    )

    # print("sequence_scores: ", outputs.sequence_scores)
    # print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][1], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][2], skip_special_tokens=True))
    # print(tokenizer.decode(outputs[0][3], skip_special_tokens=True))

    # print('---------------------------------')
    # print()
    # print("outputs: ", outputs
    # print("sequence_scores: ", outputs.sequences_scores)
    scores = outputs.sequences_scores.cpu().numpy()
    # for i in range(len(scores)):
    #     scores[i] = np.exp(scores[i])
    # print("scores: ", scores)
    # # normalize the scores
    # scores = scores / np.sum(scores)
    print("scores: ", scores)

    for k in range(3):
        pred_str = tokenizer.decode(outputs.sequences[k], skip_special_tokens=True)
        print("pred str: ", pred_str)
        action_name, slots = parse_ast_prediction(pred_str)
        print(f"parsed results: {action_name}, {slots}")

    actions = parse_context_action(input_context)
    print(f"context parsed results: {actions}")
    # output = tokenizer.batch_decode(outputs[0][0], skip_special_tokens=True)
    # print(output)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


input:  Context: hello. how may i help you today? i was trying to see my history but i can't remember my username. i can definitely help you with that. what is your name and email address? david williams. <email>. Possible Actions: [pull-up-account, enter-details, verify-identity, make-password, search-timing, search-policy, validate-purchase, search-faq, membership, search-boots, try-again, ask-the-oracle, update-order, promo-code, update-account, search-membership, make-purchase, offer-refund, notify-team, record-reason, search-jeans, shipping-status, search-shirt, instructions, search-jacket, log-out-in, select-faq, subscription-status, send-link, search-pricing]
target:  pull-up-account [david williams]




scores:  [-4.8543201e-04 -2.4581455e-01 -6.3823980e-01 -6.7502719e-01]
pred str:  pull-up-account [david williams]
parsed results: pull-up-account, ['david williams']
pred str:  verify-identity [david williams, n/a, n/a]
parsed results: verify-identity, ['david williams', 'n/a', 'n/a']
pred str:  verify-identity [david williams, n/a, n/a [none]
parsed results: verify-identity [david williams, n/a, n/a, ['none']
context parsed results: []
input:  Context: hello. how may i help you today? i was trying to see my history but i can't remember my username. i can definitely help you with that. what is your name and email address? david williams. <email>. pull-up-account ['david williams'] can i also have your zip code and phone number please? <zip_code>. <phone>. thank you so much for verifying that. one moment while i pull this up. Possible Actions: [verify-identity, validate-purchase, record-reason, pull-up-account]
target:  verify-identity [david williams, <zip_code>, <zip_code>]
scores:  

In [40]:
from aalpy.utils import load_automaton_from_file
import random

# load an automaton
chained_proir = load_automaton_from_file("/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/chainPrior/mdp_guildline_half.dot", automaton_type='mc')

data_index = [random.randint(0, len(data)) for i in range(10)]

for i in data_index:
    print('=' * 100)
    print("input: ", data[i]['input'])
    print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt")
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 50
    # define decoder start token ids
    outputs = model.generate(**encoder_input_ids, max_new_tokens=256, return_dict_in_generate=True, output_scores=True, do_sample=False, num_beams=num_beams, num_return_sequences=num_beams,)
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=False
    )

    log_prob = transition_scores.cpu().numpy()
    # print("log_prob: ", log_prob[0])
    log_prob = np.sum(log_prob, axis=1)
    print(f'scores: {log_prob}')

    policy_model_action = []
    policy_model_slot = []
    policy_model_prob = []

    print('+' * 20, 'policy model', '+' * 20)
    for k in range(num_beams):
        pred_str = tokenizer.decode(outputs.sequences[k], skip_special_tokens=True)
        action_name, slots = parse_ast_prediction(pred_str)
        # print(f"policy model top parsed: {action_name}, {slots}; score: {log_prob[k]}")

        policy_model_action.append(action_name)
        policy_model_slot.append(slots)
        if action_name != 'MISSING':
            policy_model_prob.append(log_prob[k])
        else:
            policy_model_prob.append(-10000)

        if k == 0:
            print("policy model top pred str: ", pred_str)
            print(f"policy model top parsed: {action_name}, {slots}; score: {log_prob[k]}")


    print('+' * 20, 'chained prior', '+' * 20)
    prev_actions = parse_context_action(input_context)
    prev_actions = ['init'] + prev_actions
    print(f"context parsed: {prev_actions}")

    init_state = chained_proir.states[0]
    curr_state = init_state

    prev_actions_with_prob = [(curr_state.output, 1)]
    prev_prob = 0
    idx = 0
    while curr_state.output != prev_actions[-1] and idx <= len(prev_actions) - 1:
        print(f"curr_state.output: {curr_state.output}")
        idx += 1
        matched = False
        for prosible_state, prob in curr_state.transitions:
            # print(f"prosible_state: {prosible_state.output}, prob: {prob}")
            if prosible_state.output == prev_actions[idx]:
                # print(f"prev_actions[idx]: {prev_actions[idx]}, curr_state.output: {curr_state.output}")
                prev_actions_with_prob.append((prosible_state.output, np.log(prob)))
                prev_prob += np.log(prob)
                curr_state = prosible_state
                matched = True
        
        if matched == False:
            break
    print(f"context parsed prob: {prev_actions_with_prob}, total previous prob: {prev_prob}")
    
    chained_prior_pred = []
    chained_prior_prob = []
    for state in chained_proir.states:
        # print(f"state: {state.output}")
        if state.output == prev_actions[-1]:
            for prosible_state, prob in state.transitions:
                chained_prior_pred.append(prosible_state.output)
                chained_prior_prob.append(np.log(prob) + prev_prob)
    print(f"possible next: {chained_prior_pred}, prob: {chained_prior_prob}")

    for i, action in enumerate(policy_model_action):
        # print(f"policy model action: {action}, prob: {policy_model_prob[i]}")
        if action == 'MISSING':
            defualt_next = np.argmax(np.array(chained_prior_prob))
            policy_model_action[i] = chained_prior_pred[defualt_next]
            policy_model_prob[i] = chained_prior_prob[defualt_next] * 1000
        else:
            for j, c_action in enumerate(chained_prior_pred):
                if action == c_action:
                    # print(f"policy model prob: {policy_model_prob[i]}, chained prior prob: {chained_prior_prob[j]}")
                    policy_model_prob[i] = policy_model_prob[i] + chained_prior_prob[j] * 100
    
    for i, action in enumerate(chained_prior_pred):
        if action not in policy_model_action:
            policy_model_action.append(action)
            policy_model_slot.append(['MISSING'])
            policy_model_prob.append(chained_prior_prob[i] * 1000)

    print('+' * 20, 'chained prior + policy model', '+' * 20)
    # for i in range(len(policy_model_action)):
    #     print(f"MERGED top parsed: {policy_model_action[i]}, {policy_model_slot[i]}; score: {policy_model_prob[i]}")
    top_idx = np.argmax(np.array(policy_model_prob))
    print("MERGED top pred str: ", f"{policy_model_action[top_idx]} {policy_model_slot[top_idx]}")
    print(f"MERGED top parsed: {policy_model_action[top_idx]}, {policy_model_slot[top_idx]}; score: {policy_model_prob[top_idx]}")
    print()

input:  Context: hi, how may i help you? got a promo code from you 5 days ago. and today i'm putting it in and it isn't working. saying it's invalid. okay, can i have your full name? i'm norman bouchard. pull-up-account ['norman bouchard'] had a code for this guess t shirt. ask-the-oracle ['none'] it appears that there is no error due to the system. may i ask what your membership level is? says i'm a gold member. membership ['gold'] great, as a gold member we value you very highly. we will generate a new code for your right away. Possible Actions: [enter-details, ask-the-oracle, offer-refund, search-faq, membership]
target:  promo-code [none]
scores: [-758.4534  -434.4535  -681.27234 -594.493   -535.6451  -444.94394
 -568.61646 -646.576   -480.01202 -625.2027  -657.003   -552.1121
 -549.9955  -414.09598 -661.4824  -522.9814  -424.09296 -505.14584
 -581.68677 -600.9825  -591.17584 -592.8374  -574.1764  -537.5526
 -542.9673  -590.82306 -530.23737 -567.1177  -577.81824 -590.5803
 -531.300

KeyboardInterrupt: 

#### V2:

In [43]:
from aalpy.utils import load_automaton_from_file
import random

# load an automaton
chained_proir = load_automaton_from_file("/research/d5/gds/xywen22/project/llm_framework/AST_abcd_part/chainPrior/mdp_guildline_half.dot", automaton_type='mc')

data_index = [random.randint(0, len(data)) for i in range(10)]

for i in data_index:
    print('=' * 100)
    print("input: ", data[i]['input'])
    print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt")
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 50
    # define decoder start token ids
    outputs = model.generate(**encoder_input_ids, max_new_tokens=256, return_dict_in_generate=True, output_scores=True, do_sample=False, num_beams=num_beams, num_return_sequences=num_beams,)
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=False
    )

    # log_prob = transition_scores.cpu().numpy()
    # print("log_prob: ", log_prob[0])
    # log_prob = np.sum(log_prob, axis=1)

    output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
    length_penalty = model.generation_config.length_penalty
    reconstructed_scores = transition_scores.sum(axis=1).cpu().numpy() / (output_length ** length_penalty)
    log_prob = reconstructed_scores

    print(f'scores: {log_prob}')

    policy_model_action = []
    policy_model_slot = []
    policy_model_prob = []
    policy_max_idx = np.argmax(log_prob)

    print('+' * 20, 'policy model', '+' * 20)
    for k in range(num_beams):
        pred_str = tokenizer.decode(outputs.sequences[k], skip_special_tokens=True)
        action_name, slots = parse_ast_prediction(pred_str)
        # print(f"policy model top parsed: {action_name}, {slots}; score: {log_prob[k]}")

        policy_model_action.append(action_name)
        policy_model_slot.append(slots)
        if action_name != 'MISSING':
            policy_model_prob.append(log_prob[k])
        else:
            policy_model_prob.append(-10000)

        if k == 0:
            print("policy model top pred str: ", pred_str)
            print(f"policy model top parsed: {action_name}, {slots}; score: {log_prob[k]}")


    print('+' * 20, 'chained prior', '+' * 20)
    prev_actions = parse_context_action(input_context)
    prev_actions = ['init'] + prev_actions
    print(f"context parsed: {prev_actions}")

    init_state = chained_proir.states[0]
    curr_state = init_state

    prev_actions_with_prob = [(curr_state.output, 1)]
    prev_prob = 0
    idx = 0
    while curr_state.output != prev_actions[-1] and idx <= len(prev_actions) - 1:
        print(f"curr_state.output: {curr_state.output}")
        idx += 1
        matched = False
        for prosible_state, prob in curr_state.transitions:
            # print(f"prosible_state: {prosible_state.output}, prob: {prob}")
            if prosible_state.output == prev_actions[idx]:
                # print(f"prev_actions[idx]: {prev_actions[idx]}, curr_state.output: {curr_state.output}")
                prev_actions_with_prob.append((prosible_state.output, np.log(prob)))
                prev_prob += np.log(prob)
                curr_state = prosible_state
                matched = True
        
        if matched == False:
            break
    print(f"context parsed prob: {prev_actions_with_prob}, total previous prob: {prev_prob}")
    
    chained_prior_pred = []
    chained_prior_prob = []
    for state in chained_proir.states:
        # print(f"state: {state.output}")
        if state.output == prev_actions[-1]:
            for prosible_state, prob in state.transitions:
                chained_prior_pred.append(prosible_state.output)
                chained_prior_prob.append(np.log(prob) + prev_prob)
    print(f"possible next: {chained_prior_pred}, prob: {chained_prior_prob}")

    for i, action in enumerate(policy_model_action):
        # print(f"policy model action: {action}, prob: {policy_model_prob[i]}")
        if action == 'MISSING':
            defualt_next = np.argmax(np.array(chained_prior_prob))
            policy_model_action[i] = chained_prior_pred[defualt_next]
            policy_model_prob[i] = chained_prior_prob[defualt_next] * 1000
        else:
            for j, c_action in enumerate(chained_prior_pred):
                if action == c_action:
                    # print(f"policy model prob: {policy_model_prob[i]}, chained prior prob: {chained_prior_prob[j]}")
                    policy_model_prob[i] = policy_model_prob[i] + chained_prior_prob[j] * 100
    
    for i, action in enumerate(chained_prior_pred):
        if action not in policy_model_action:
            policy_model_action.append(action)
            policy_model_slot.append(['MISSING'])
            policy_model_prob.append(chained_prior_prob[i] * 1000)

    print('+' * 20, 'chained prior + policy model', '+' * 20)
    # for i in range(len(policy_model_action)):
    #     print(f"MERGED top parsed: {policy_model_action[i]}, {policy_model_slot[i]}; score: {policy_model_prob[i]}")
    top_idx = np.argmax(np.array(policy_model_prob))
    print("MERGED top pred str: ", f"{policy_model_action[top_idx]} {policy_model_slot[top_idx]}")
    print(f"MERGED top parsed: {policy_model_action[top_idx]}, {policy_model_slot[top_idx]}; score: {policy_model_prob[top_idx]}")
    print()

input:  Context: thank you for contacting acmebrands! how can i help you today? i'm wondering about a weird fee that got added on to my order. that i'm also checking on. wasn't disclosed when i bought it this extra bit. i can take a look at that for you. could you provide me with your name and account id? rodriguez domingo. <account_id>. pull-up-account ['rodriguez domingo'] i've pulled up your account, mr. domingo. which order id was this for? <order_id>. a michael kors shirt. Possible Actions: [verify-identity, validate-purchase, record-reason, pull-up-account]
target:  verify-identity [rodriguez domingo, <account_id>, <account_id>]




scores: [-31.46646322 -36.88648546 -30.28423517 -33.19927572 -29.0894498
 -31.26869846 -30.45401476 -30.84985352 -32.54336825 -33.34528266
 -29.67690158 -30.17221347 -29.24696181 -31.44116821 -28.65219998
 -29.24461263 -29.81454536 -30.13176541 -31.69626194 -28.0696039
 -28.15845917 -28.79051883 -27.61348199 -29.53698176 -29.33618707
 -29.26768663 -26.15303819 -31.44916992 -28.70319824 -24.67834195
 -29.25456814 -26.88825989 -28.76027018 -29.3990153  -25.31437683
 -28.94166528 -25.58459778 -28.70853407 -26.20375784 -28.50030092
 -26.00545432 -26.13295201 -22.38342014 -24.38519734 -23.48244324
 -26.06386497 -23.81065704 -24.30553367 -24.958648   -23.45784505]
++++++++++++++++++++ policy model ++++++++++++++++++++
policy model top pred str:  verify-identity [rodriguez domingo, account_id>, account_id>]
policy model top parsed: verify-identity, ['rodriguez domingo', '<account_id>', '<account_id>']; score: -31.466463216145833
++++++++++++++++++++ chained prior ++++++++++++++++++++
context 

Token indices sequence length is longer than the specified maximum sequence length for this model (546 > 512). Running this sequence through the model will result in indexing errors


++++++++++++++++++++ chained prior ++++++++++++++++++++
context parsed: ['init', 'pull-up-account', 'validate-purchase']
curr_state.output: init
curr_state.output: pull-up-account
context parsed prob: [('init', 1), ('pull-up-account', -0.3042973319756445), ('validate-purchase', -1.4282186111358397)], total previous prob: -1.7325159431114843
possible next: ['record-reason', 'ask-the-oracle', 'pull-up-account', 'make-password', 'verify-identity', 'membership', 'offer-refund', 'n/a', 'membership-privileges', 'shipping-status', 'enter-details', 'subscription-status', 'validate-purchase', 'update-order', 'search-faq', 'promo-code', 'notify-team', 'make-purchase', 'try-again', 'ask-the-oracle', 'update-order', 'record-reason', 'enter-details', 'validate-purchase', 'notify-team', 'shipping-status', 'membership'], prob: [-2.8527188643719614, -4.10548183286733, -8.264364916227002, -9.36297720489511, -6.798027847433574, -4.269227004088348, -6.41853822572867, -4.575485462113065, -3.47687317344495

KeyboardInterrupt: 

In [None]:
from aalpy.utils import load_automaton_from_file

# load an automaton
chained_proir = load_automaton_from_file('/research/d1/gds/jyzhong/computation_models/mdp_guildline_half.dot', automaton_type='mc')

preds = []
labels = []


for i in tqdm(range(len(data))):
    # print('=' * 100)
    # print("input: ", data[i]['input'])
    # print("target: ", data[i]['target'])

    input_context = "Predict AST: " + data[i]['input']
    encoder_input_ids = tokenizer(input_context, return_tensors="pt")
    encoder_input_ids = encoder_input_ids.to(device)

    # lets run beam search using 3 beams
    num_beams = 20
    # define decoder start token ids
    outputs = model.generate(**encoder_input_ids, max_new_tokens=100, return_dict_in_generate=True, output_scores=True, do_sample=False, num_beams=num_beams, num_return_sequences=num_beams,)
    # print(outputs.scores)
    try:
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=False
        )
        output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
        length_penalty = model.generation_config.length_penalty
        reconstructed_scores = transition_scores.sum(axis=1).cpu().numpy() / (output_length**length_penalty)
    except:
        reconstructed_scores = np.zeros(num_beams)
        print(f"cannot compute {i}, set prob to 0")

    log_prob = reconstructed_scores
    # log_prob = np.sum(log_prob, axis=1)
    # log_prob = normalize_log_probs(log_prob)
    # print(f'scores: {log_prob}')

    policy_model_action = []
    policy_model_slot = []
    policy_model_prob = []
    policy_max_idx = np.argmax(log_prob)

    # print('+' * 20, 'policy model', '+' * 20)
    for k in range(num_beams):
        pred_str = tokenizer.decode(outputs.sequences[k], skip_special_tokens=True)
        action_name, slots = parse_ast_prediction(pred_str)

        policy_model_action.append(action_name)
        policy_model_slot.append(slots)
        if action_name != 'MISSING':
            policy_model_prob.appe