In [1]:
import datasets
from dataset import DialogDataset
import transformers
import torch
from models import SegmentationModel
from utils import check_token_stats, test_prediction, extract_sample

In [3]:
# Experiment configuration: identifies which dataset / model / task / learning-rate run is evaluated.
DATA_NAME = 'multi_woz_v22'
MODEL_TYPE = 'gpt2'
TASK_TYPE = 'classification'  # any other value selects the 'label_reg' label below
LR = 1e-05
# Fall back to CPU when no GPU is visible, so the notebook still runs (slowly) on CPU-only machines;
# on a GPU machine this is identical to torch.device("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label key handed to test_prediction further down.
label_name = 'label_cls' if TASK_TYPE == 'classification' else 'label_reg'
# Checkpoint path built from the config above (LR renders as '1e-05' in the f-string).
# NOTE(review): presumably mirrors the naming used by the training script — keep the two in sync.
PATH = f'saved/{DATA_NAME}_{MODEL_TYPE}_{TASK_TYPE}_{LR}.pt'

In [4]:
# GPT-2 tokenizer with explicit BOS / EOS / PAD strings. '<|endoftext|>' is GPT-2's own special token;
# '<|startoftext|>' and '<|pad|>' are not in the stock vocabulary and get appended to it
# (hence the "Special tokens have been added" notice in the output).
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'gpt2',
    bos_token='<|startoftext|>',
    eos_token='<|endoftext|>',
    pad_token='<|pad|>',
)
# Padding then reuses EOS: '<|pad|>' stays registered in the vocab but is no longer tokenizer.pad_token.
# NOTE(review): if '<|startoftext|>' can appear in model inputs, confirm the model's embedding matrix
# was resized to len(tokenizer) during training.
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Load every split of the dataset (no `split=` argument, so a DatasetDict comes back) and print
# token-count statistics for it. No config name is given, so `datasets` falls back to the dataset's
# default config, as the notice in the output reports. `path` is load_dataset's first positional parameter.
check_token_stats(datasets.load_dataset(DATA_NAME), tokenizer)

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/dongi/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

tokens num -- min: 30, max: 786, mean: 202.56462585034015, std: 86.53795045423138


In [6]:
# Evaluation data: the raw test split, wrapped by the project-local DialogDataset
# (which receives the tokenizer, presumably to encode the utterances).
test_data = datasets.load_dataset(DATA_NAME, split='test')
test_dataset = DialogDataset(test_data, tokenizer)
# One dialog per batch, in dataset order (shuffle=False). Presumably this lets predictions be matched back
# to test_data by index (e.g. NUM below) — confirm test_prediction preserves loader order.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/dongi/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


In [7]:
# Rebuild the architecture on `device`, switch to inference mode, then restore the fine-tuned weights.
model = SegmentationModel(model_name='gpt2', model_type=MODEL_TYPE, task_type=TASK_TYPE).to(device)
model.eval()  # train/eval mode is a module flag, not part of the state dict, so calling it before loading is fine
# map_location places the loaded tensors on `device`, so a checkpoint saved on a GPU also loads on a CPU-only
# machine or a different GPU index. torch.load unpickles the file — only point PATH at trusted checkpoints.
# Kept as the last expression so the cell displays the key-matching result.
model.load_state_dict(torch.load(PATH, map_location=device))

<All keys matched successfully>

In [8]:
# Run inference over the whole test loader. Returns logits, predictions and gold labels, plus aggregate
# recall / precision / F1. The metrics line in the output appears to be printed by test_prediction itself.
(test_logit, test_pred, test_real,
 test_recall, test_precision, test_f1) = test_prediction(model, test_loader, label_name, device)

recall: 0.6694649971542401, precision: 0.6998363825673063, f1-score: 0.6843138680823213


In [9]:
# Position (in test order) of the dialog to inspect.
NUM = 0
# Forward the configured TASK_TYPE rather than a hard-coded 'classification', so that changing the task
# in the config cell also changes how the sample is rendered.
# NOTE(review): the warning echoed in the output comes from inside utils.extract_sample
# (torch.tensor() applied to existing tensors) — fix it there, not here.
extract_sample(NUM, test_logit, test_dataset, tokenizer, TASK_TYPE=TASK_TYPE)

[Predicted Sample] ...
i  need  train  reservations  from  nor wich  to  cam bridge 
        %Prob% : 0.5412
        %Logit% : 0.1651839166879654
 i  have  133  trains  matching  your  request  is  there  a  specific  day  and  time  you  would  like  to  travel 
        %Prob% : 0.7399
        %Logit% : 1.0456702709197998
 id  like  to  leave  on  m onday  and  arrive  by  1800 
        %Prob% : 0.7006
        %Logit% : 0.8503118753433228
 there  are  12  trains  for  the  day  and  time  you  request  would  you  like  to  book  it 
        %Prob% : 0.5714
        %Logit% : 0.2875508964061737
 now 
        %Prob% : 0.8752
        %Logit% : 1.9477922916412354
 before  booking  i  would  also  like  to  know  the  travel  time  price  and  departure  time 
        %Prob% : 0.5631
        %Logit% : 0.2536555826663971
 please 
        %Prob% : 0.8865
        %Logit% : 2.055598258972168
 there  are  12  trains  meeting  your  needs  with  the  first  leaving  at  05 16  and  the  last  on

  for each_sample_id, each_sample_logit in zip(torch.tensor(sample['input_ids']), torch.tensor(sample_logits.squeeze())):


In [10]:
# Print the ground-truth utterances of dialog NUM, one per paragraph, for comparison with the
# predicted segmentation above. The utterance field differs per dataset, so dispatch on the builder name.
print("[Truth Sample] ...")
builder_name = test_dataset.data.info.builder_name
if builder_name == 'multi_woz_v22':
    print(*test_data[NUM]['turns']['utterance'], sep='\n\n')
elif builder_name == 'daily_dialog':
    print(*test_data[NUM]['dialog'], sep='\n\n')
else:
    # Previously this fell through silently, leaving only the header printed.
    raise ValueError(f"Unsupported dataset for truth display: {builder_name!r}")

[Truth Sample] ...
I need train reservations from norwich to cambridge

I have 133 trains matching your request. Is there a specific day and time you would like to travel?

I'd like to leave on Monday and arrive by 18:00.

There are 12 trains for the day and time you request. Would you like to book it now?

Before booking, I would also like to know the travel time, price, and departure time please.

There are 12 trains meeting your needs with the first leaving at 05:16 and the last one leaving at 16:16. Do you want to book one of these?

No hold off on booking for now. Can you help me find an attraction called cineworld cinema?

Yes it is a cinema located in the south part of town what information would you like on it?

Yes, that was all I needed. Thank you very much!

Thank you for using our system.
