In [14]:
import sys
import os
import numpy as np

# "..." is not a special path component, so abspath() resolves it to
# <cwd>/... and dirname() drops that last component again: SCRIPT_DIR ends up
# being the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath("..."))
# Put the parent of the working directory on the import path; presumably this
# is where the `training` package imported below lives — TODO confirm.
sys.path.append(os.path.dirname(SCRIPT_DIR))

from training.agent_trainer import END_KEY, RESPONSE_KEY_NL, CHAT_START_KEY
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

In [3]:

# Load the dataset saved on disk and keep only its 'test' split.
test_data = load_from_disk("/opt/home/bo_ling/dataset/modeling_data_v1.hf")['test']
# Bare expression so the notebook displays the dataset summary.
test_data

Dataset({
    features: ['text'],
    num_rows: 16958
})

In [5]:
# Checkpoint directory from which both the tokenizer and the model are loaded.
local_output_dir="/opt/home/agent_modeling_data_v1_checkpoint-60000"
#model, tokenizer = load_model_tokenizer_for_generate(local_output_dir)
import torch
# The tokenizer is configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(local_output_dir, padding_side="left")
# Inference is pinned to CPU here; the trailing comment keeps the GPU variant.
device = "cpu" # "cuda:3" if torch.cuda.is_available() else "cpu"
# trust_remote_code=True lets the checkpoint supply its own model code.
model = AutoModelForCausalLM.from_pretrained(local_output_dir, trust_remote_code=True).to(device)

In [6]:
# Sanity check: encode the end-of-turn marker and show the resulting token ids.
# Note that tokenizer.encode returns a list of ids, not a single int.
eos_token_id = tokenizer.encode(END_KEY)
print(END_KEY, eos_token_id)

<|endofsentence|> [50400]


In [7]:
def generate_agent_response(
    texts: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    do_sample: bool = True,
    max_new_tokens: int = 256,
    top_p: float = 0.92,
    top_k: int = 0,
    **kwargs,
) -> str:
    """Generate the continuation of a chat prompt with a causal language model.

    Args:
        texts: Prompt text (the conversation so far) to continue.
        model: Model whose ``generate`` method produces the continuation.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        do_sample: Forwarded to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The decoded newly generated tokens (the prompt is excluded), with
        surrounding whitespace stripped. Special tokens are not removed by
        the decode step, so the text may still end with ``END_KEY``.
    """
    # Use the device of the model that was passed in rather than a notebook
    # global, so the function works for any (model, tokenizer) pair.
    target_device = model.device

    encoded = tokenizer(texts, return_tensors="pt")
    input_ids = encoded.input_ids.to(target_device)

    # Forward the attention mask when the tokenizer provides one, unless the
    # caller already supplied their own through kwargs.
    attention_mask = getattr(encoded, "attention_mask", None)
    if attention_mask is not None:
        kwargs.setdefault("attention_mask", attention_mask.to(target_device))

    # Generation stops at the first token of END_KEY.
    end_key_token_id = tokenizer.encode(END_KEY)[0]

    # Tokenizers without a pad token would otherwise pass None here; fall
    # back to the end-of-turn token in that case.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = end_key_token_id

    gen_tokens = model.generate(
        input_ids,
        pad_token_id=pad_token_id,
        eos_token_id=end_key_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        **kwargs,
    )[0].cpu()

    # The output sequence starts with the prompt tokens; keep only what
    # follows them.
    prompt_length = input_ids.shape[-1]
    return tokenizer.decode(gen_tokens[prompt_length:]).strip()

In [47]:
def print_generated_chats(chat_str):
    """Replay a recorded chat and print model-generated agent turns beside it.

    The chat is split into turns on ``END_KEY`` (after first rewriting the
    ``"\\n\\n###\\n"`` separator to ``END_KEY``). Every turn is printed. For
    each turn that starts with ``RESPONSE_KEY_NL``, the conversation up to
    that point is sent to ``generate_agent_response`` and the generated text
    is printed between banner lines, provided it starts with the response key.

    Relies on the notebook-level ``model`` and ``tokenizer``.

    Args:
        chat_str: Full text of one recorded conversation.
    """
    chats = chat_str.replace("\n\n###\n", END_KEY).split(END_KEY)
    curr_text = chats[0] + END_KEY
    print(chats[0])

    # Loop-invariant values, computed once.
    response_prefix = RESPONSE_KEY_NL.replace("\n", "")
    banner = "*" * 40

    for chat in chats[1:]:
        print(chat)
        if chat.startswith(RESPONSE_KEY_NL):
            generated = generate_agent_response(curr_text, model, tokenizer)
            # Remove the END_KEY marker only as an exact suffix.
            # str.strip(END_KEY) would treat END_KEY as a *set of characters*
            # and could eat real letters at either end of the response.
            if END_KEY and generated.endswith(END_KEY):
                generated = generated[: -len(END_KEY)]
            if generated.replace("\n", "").startswith(response_prefix):
                print("\n")
                print(banner + "START OUTPUT GENERATED" + banner)
                print(generated)
                print(banner + "END OUTPUT GENERATED" + banner)
                print("\n")
        # The recorded turn, not the generated one, extends the context.
        curr_text += chat + END_KEY

In [48]:
# Replay the test conversation at index 0 and print the model's agent turns
# alongside the recorded ones.
i = 0
chat_str = test_data[i]['text']
print_generated_chats(chat_str)
    

Specific Information: The restaurant name is Salt & Straw (Capitol Hill) and the order value is $29.36. The estimated delivery time is 18:38 and latest arrival time is 18:45. Order status is 3.

agent: Hi Kaye, welcome to Chat Support team. Thanks for being a member! I will be happy to assist you today. 

agent: I'm sorry to hear that your order is taking time longer than expected. I understand that your delivery time keeps changing. Allow me a moment while I am checking your order details. 


****************************************START OUTPUT GENERATED****************************************
agent: I can understand your concern, Kaye that your order is taking too long to arrive. I am sorry for the inconvenience caused to you. Please be with me, while I am assisting with your concern. 
****************************************END OUTPUT GENERATED****************************************



Kaye: Thanks Pammy 

agent: I have reviewed the status of your order and can see that your delive

In [49]:
# Same replay as the previous cell, for the test conversation at index 1.
i = 1
chat_str = test_data[i]['text']
print_generated_chats(chat_str)

Specific Information: The restaurant name is Chuy's (4544 McKinney Ave.) and the order value is $18.46. The estimated delivery time is 21:55 and latest arrival time is 21:50. Order status is 5.

agent: Hi Krystal. Welcome to Chat Support team. I will be assisting you with your concern. 

agent: "Sorry to hear about the inconvenience" 


****************************************START OUTPUT GENERATED****************************************
agent: Krystal, what I understand so far is that your order is getting delayed for delivery, right? 
****************************************END OUTPUT GENERATED****************************************



agent: I have checked your order details and I can confirm that a delivery person named Eric is currently en route to pick up your order. 


****************************************START OUTPUT GENERATED****************************************
agent: Ive looked into your order and it looks like a delivery person is on their way to pick up your order. 