In [9]:
from pathlib import Path

import torch
from transformers import (
    TrainingArguments,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset, Split, Dataset

from trainer.curriculum_trainer import CurriculumTrainer

In [14]:

from data.dataset.data_augmentations import flatten_conversation, random_mask_beliefs
from gpu import get_device
from utils import print_stage

In [11]:

from data.dataset.tokenize import tokenizer, preprocess_func

In [12]:
from pathlib import Path
from datasets import Dataset, load_dataset, Split
from data.dataset.data_augmentations import flatten_conversation

data_dir = Path("resources/bart/")

data_files = {
    Split.TRAIN: str((data_dir / "train.history_belief").absolute()),
    Split.VALIDATION: str((data_dir / "val.history_belief").absolute()),
    Split.TEST: str((data_dir / "test.history_belief").absolute()),
}

dataset = load_dataset(
    "data/dataset/multiwoz_dataset.py", data_files=data_files
)
dataset = dataset.map(
    flatten_conversation,
    batched=True,
    remove_columns=dataset["train"].column_names,
)



  0%|          | 0/3 [00:00<?, ?it/s]



In [4]:

masked_beliefs_final_dev = load_dataset(
    "json",
    data_files="resources/tokens/masked_beliefs_final_dev_token.json",
).map(preprocess_func, batched=True)["train"]

masked_beliefs_final_test = load_dataset(
    "json",
    data_files="resources/tokens/masked_beliefs_final_test_token.json",
).map(preprocess_func, batched=True)["train"]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



In [5]:
sample = masked_beliefs_final_dev["input_ids"][0]

In [6]:
tokenizer.decode(sample)

'<s><s> <|context|> <|user|> i need to book a hotel in the east that has 4 stars. <|endofcontext|> <|previousbelief|> attraction area not mentioned, attraction name not mentioned, attraction type not mentioned, hospital department not mentioned, hotel area not mentioned, hotel book day not mentioned, hotel book people not mentioned, hotel book stay not mentioned, hotel internet not mentioned, hotel name not mentioned, hotel parking not mentioned, hotel pricerange not mentioned, hotel stars not mentioned, hotel type not mentioned, restaurant area not mentioned, restaurant book day not mentioned, restaurant book people not mentioned, restaurant book time not mentioned, restaurant food not mentioned, restaurant name not mentioned, restaurant pricerange not mentioned, taxi arriveby not mentioned, taxi departure not mentioned, taxi destination not mentioned, taxi leaveat not mentioned, train arriveby not mentioned, train book people not mentioned, train day not mentioned, train departure no

In [20]:

data_dir = Path("resources/bart/")

data_files = {
    Split.TRAIN: str((data_dir / "train.history_belief").absolute()),
    Split.VALIDATION: str((data_dir / "val.history_belief").absolute()),
    Split.TEST: str((data_dir / "test.history_belief").absolute()),
}

dataset = load_dataset(
    "data/dataset/multiwoz_dataset.py", data_files=data_files
)
print_stage("Flattening Conversation")
dataset = dataset.map(
    flatten_conversation,
    batched=True,
    remove_columns=dataset["train"].column_names,
)



  0%|          | 0/3 [00:00<?, ?it/s]





In [21]:

model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base"
)#.to(device)
model.resize_token_embeddings(len(tokenizer))


Embedding(50273, 768)

In [22]:
from datasets import Dataset

In [30]:
sample_dataset = Dataset.from_dict(masked_beliefs_final_test[:10])
masked_beliefs_final_test, dataset["test"]

(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 7372
 }),
 Dataset({
     features: ['conversation_id', 'turn', 'turn_number'],
     num_rows: 7372
 }))

In [31]:
dataset["test"]["turn"][0]

'<s> <|context|> <|user|> i would like a taxi from saint john s college to pizza hut fen ditton . <|endofcontext|> <|belief|> taxi leaveat not mentioned , taxi destination pizza hut fenditton , taxi departure saint johns college , taxi arriveby not mentioned <|endofbelief|> </s>'

In [32]:
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [33]:
from transformers import default_data_collator
from torch.utils.data import DataLoader

In [38]:
from tqdm.auto import tqdm

from postprocessing import postprocessing

In [41]:
import json
from utils.Constants import SLOT_VALS

slot_template = {slot:"" for slot in SLOT_VALS}
def get_slot_map(slot_triplet_str_list):
    slot_map = slot_template.copy()
    for slot_triplet_str in slot_triplet_str_list:
        slot_triplets = slot_triplet_str.split()
        key = slot_triplets[0] + " " + slot_triplets[1]
        val = slot_triplets[2]
        if key not in SLOT_VALS:
            continue
        slot_map[key] = val
    return slot_map

def get_unique_slot_map(preds, targets):
    unique_slots = set()
    pred_map = {}
    target_map = {}
    
    for pred_str in preds:
        triplet = pred_str.split()
        key = triplet[0] + " " + triplet[1]
        val = triplet[2]
        pred_map[key] = val
        unique_slots.add(key)
    
    for target_str in targets:
        triplet = target_str.split()
        key = triplet[0] + " " + triplet[1]
        val = triplet[2]
        target_map[key] = val
        unique_slots.add(key)
    
    return unique_slots.copy(), pred_map.copy(), target_map.copy()
        

In [42]:
from utils.dst import ignore_none, default_cleaning, IGNORE_TURNS_TYPE2

def evaluate_dst(results):
    num_turns = 0
    joint_acc = 0
    slot_acc = 0
    r_slot_acc = 0

    num_slots = len(SLOT_VALS)
    num_r_slots = 0

    clean_tokens = ['<s>', '</s>']

    for dial in tqdm(results.keys()):
        dialogue_pred = results[dial]['generated_turn_belief']
        dialogue_target = results[dial]['target_turn_belief']

        for turn_id, (turn_target, turn_pred) in enumerate(zip(dialogue_target, dialogue_pred)):

            # clean
            for bs in turn_pred:
                if bs in clean_tokens + ['', ' '] or bs.split()[-1] == 'none':
                    turn_pred.remove(bs)

            new_turn_pred = []
            for bs in turn_pred:
                for tok in clean_tokens:
                    bs = bs.replace(tok, '').strip()
                    new_turn_pred.append(bs)
            turn_pred = new_turn_pred

            turn_pred, turn_target = ignore_none(turn_pred, turn_target)
            turn_pred, turn_target = default_cleaning(turn_pred, turn_target)

            join_flag = False

            # calculate joint accuracy
            if set(turn_target) == set(turn_pred):
                joint_acc += 1
                join_flag = True

            pred_slot_map = get_slot_map(turn_pred)
            target_slot_map = get_slot_map(turn_target)

            # calculate slot accuracy
            for slot_key in SLOT_VALS:
                if target_slot_map[slot_key] == pred_slot_map[slot_key]:
                    slot_acc += 1

            # calculate relative slot accuracy
            unique_slots, unique_pred_map, unique_target_map = get_unique_slot_map(turn_pred, turn_target)
            for slot_key in unique_slots:
                if slot_key not in unique_target_map.keys(): continue
                if slot_key not in unique_pred_map.keys(): continue
                if unique_target_map[slot_key] == unique_pred_map[slot_key]:
                    r_slot_acc += 1
            num_r_slots += len(unique_slots)

            num_turns += 1

    print('joint accuracy: {}'.format(joint_acc / num_turns))
    print('slot accuracy: {}'.format(slot_acc / (num_slots * num_turns)))
    print('relative slot accuracy: {}'.format(r_slot_acc / num_r_slots))

In [19]:
from datasets import set_caching_enabled
set_caching_enabled(False)
masked_beliefs_final_test = dataset["test"].map(
    lambda d: random_mask_beliefs(d, 1), remove_columns="turn"
)
# masked_beliefs_final = masked_beliefs_final.map(
#     tokenization,
#     batched=True,
#     remove_columns=masked_beliefs_final["train"].column_names,
# )

  set_caching_enabled(False)


  0%|          | 0/7372 [00:00<?, ?ex/s]

In [20]:
sample = masked_beliefs_final_test["masked"][0]

In [34]:
from data.dataset.multiwoz_dataset import HistoryBelief, parse_raw_belief
HistoryBelief(sample).text, sample

('<s> <|context|> <|user|> i would like a taxi from saint john s college to pizza hut fen ditton . <|endofcontext|> <|previousbelief|> attraction area not mentioned , attraction name not mentioned , attraction type not mentioned , hospital department not mentioned , hotel area not mentioned , hotel book day not mentioned , hotel book people not mentioned , hotel book stay not mentioned , hotel internet not mentioned , hotel name not mentioned , hotel parking not mentioned , hotel pricerange not mentioned , hotel stars not mentioned , hotel type not mentioned , restaurant area not mentioned , restaurant book day not mentioned , restaurant book people not mentioned , restaurant book time not mentioned , restaurant food not mentioned , restaurant name not mentioned , restaurant pricerange not mentioned , taxi arriveby not mentioned , taxi departure not mentioned , taxi destination not mentioned , taxi leaveat not mentioned , train arriveby not mentioned , train book people not mentioned ,

In [53]:
from data.dataset.tokenize import tokenizer
encoding = tokenizer(HistoryBelief(sample).text, return_tensors="pt")

In [54]:

from transformers import (
    default_data_collator,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
)

In [55]:
model = BartForConditionalGeneration.from_pretrained(
        # "facebook/bart-base"
        # "checkpoints/bart_finetune_cur/final/checkpoint-14195"
        # "checkpoints/bart_finetune/final/checkpoint-28390"
        "checkpoints/bart_finetune_cur/course_4/checkpoint-14195"
)

In [56]:

with torch.no_grad():
    output = model(**encoding)

In [57]:
generated_ids = output.logits.argmax(-1)
prediction_texts = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=False
)

In [70]:
prediction_texts[0][:prediction_texts[0].index("</s>")] + "</s>"

'<s><s> <|context|> <|user|> i would like a taxi from saint john s college to pizza hut fen ditton. <|endofcontext|> <|belief|> taxraction type not mentioned <|endofbelief|>  train type saint mentioned, train area not mentioned, attraction department not mentioned, attraction name not mentioned, hotel parking day not mentioned, hotel parking stay none mentioned, hotel parking stay none mentioned, attraction type not mentioned, hotel type pizza mentioned, hotel type not mentioned, hotel pricerange not mentioned, hotel stars not mentioned, hotel type not mentioned, attraction book not mentioned <|endofbelief|>  attraction book day not mentioned, attraction book people none mentioned, attraction book time not mentioned, hotel book not mentioned, restaurant pr pizza mentioned, restaurant areaicerange not mentioned, restaurant departureby not mentioned <|endofbelief|>  attraction departure saint mentioned <|endofbelief|>  attraction arrive pizza mentioned, restaurant arriveat not mentioned,

In [71]:
giu = HistoryBelief(prediction_texts[0][:prediction_texts[0].index("</s>")] + "</s>")

In [72]:
giu.belief

OrderedDict([('attraction area', 'not mentioned'),
             ('attraction name', 'not mentioned'),
             ('attraction type', 'not mentioned'),
             ('hospital department', 'not mentioned'),
             ('hotel area', 'not mentioned'),
             ('hotel book day', 'not mentioned'),
             ('hotel book people', 'not mentioned'),
             ('hotel book stay', 'not mentioned'),
             ('hotel internet', 'not mentioned'),
             ('hotel name', 'not mentioned'),
             ('hotel parking', 'stay none mentioned'),
             ('hotel pricerange', 'not mentioned'),
             ('hotel stars', 'not mentioned'),
             ('hotel type', 'not mentioned'),
             ('restaurant area', 'not mentioned'),
             ('restaurant book day', 'not mentioned'),
             ('restaurant book people', 'not mentioned'),
             ('restaurant book time', 'not mentioned'),
             ('restaurant food', 'not mentioned'),
             ('restaurant

In [75]:
from tqdm.auto import tqdm

previous_belief_text = ""
output_pred = []
for idx in tqdm(range(len(masked_beliefs_final_test))):
    masked_text = masked_beliefs_final_test["masked"][idx]
    history_belief = HistoryBelief(masked_text)
    if masked_beliefs_final_test["turn_number"][idx] != 1:
        # update previous belief
        history_belief.prev_belief = parse_raw_belief(previous_belief_text)
    
    with torch.no_grad():
        output = model(**encoding)
        
    generated_ids = output.logits.argmax(-1)
    prediction_texts = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=False
    )
    clean_pred_text = prediction_texts[0][:prediction_texts[0].index("</s>")] + "</s>"
    pred_hb = HistoryBelief(clean_pred_text)
    previous_belief_text = pred_hb.belief_text
    output_pred.append(pred_hb.text)

  0%|          | 0/7372 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [48]:

dev_loader = DataLoader(
            dataset=masked_beliefs_final_dev,
            batch_size=4,
            collate_fn=default_data_collator,
            shuffle=False,
        )

model.eval()
dev_results = {}
predictions = []
turn_id = 0
for batch in tqdm(dev_loader):
    
    with torch.no_grad():
        output = model(**batch)
    generated_ids = output.logits.argmax(-1)
    prediction_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    
    for prediction_text in prediction_texts:
        gold_text = dataset["validation"]["turn"][turn_id]
        
        dialogue_id = dataset["validation"]["conversation_id"][turn_id]
        if dialogue_id not in dev_results.keys():
            dev_results[dialogue_id] = {
                'generated_turn_belief': [],
                'target_turn_belief': [],
            }
        
        dev_results[dialogue_id]["generated_turn_belief"] += [postprocessing(prediction_text)]
        dev_results[dialogue_id]["target_turn_belief"] += [postprocessing(gold_text)]

        turn_id += 1

  0%|          | 0/3 [00:01<?, ?it/s]

In [50]:
evaluate_dst(dev_results), dev_results

  0%|          | 0/2 [00:00<?, ?it/s]

joint accuracy: 0.0
slot accuracy: 0.8032258064516129
relative slot accuracy: 0.0


(None,
 {1: {'generated_turn_belief': [[], [], [], [], [], [], [], [], []],
   'target_turn_belief': [['hotel area east ', 'hotel stars 4 '],
    ['hotel area east ',
     'hotel parking yes ',
     'hotel stars 4 ',
     'hotel internet yes '],
    ['hotel name wartworth ',
     'hotel area east ',
     'hotel parking yes ',
     'hotel stars 4 ',
     'hotel internet yes '],
    ['hotel name wartworth ',
     'hotel area east ',
     'hotel parking yes ',
     'hotel stars 4 ',
     'hotel internet yes ',
     'hotel book people 1 ',
     'hotel book day friday ',
     'hotel book stay 1'],
    ['hotel name wartworth ',
     'hotel area east ',
     'hotel parking yes ',
     'hotel stars 4 ',
     'hotel internet yes ',
     'hotel book people 1 ',
     'hotel book day friday ',
     'hotel book stay 1 ',
     'train destination bishops stortford ',
     'train day friday ',
     'train departure cambridge'],
    ['hotel name wartworth ',
     'hotel area east ',
     'hotel parking

In [45]:

test_loader = DataLoader(
            dataset=masked_beliefs_final_test,
            batch_size=4,
            collate_fn=default_data_collator,
            shuffle=False,
        )

model.eval()
test_results = {}
predictions = []
turn_id = 0
for batch in tqdm(test_loader):
    
    with torch.no_grad():
        output = model(**batch)
    generated_ids = output.logits.argmax(-1)
    prediction_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    
    for prediction_text in prediction_texts:
        gold_text = dataset["test"]["turn"][turn_id]
        
        dialogue_id = dataset["test"]["conversation_id"][turn_id]
        if dialogue_id not in test_results.keys():
            test_results[dialogue_id] = {
                'generated_turn_belief': [],
                'target_turn_belief': [],
            }
        
        test_results[dialogue_id]["generated_turn_belief"] += [postprocessing(prediction_text)]
        test_results[dialogue_id]["target_turn_belief"] += [postprocessing(gold_text)]

        turn_id += 1

  0%|          | 0/3 [00:00<?, ?it/s]

In [47]:
evaluate_dst(test_results), test_results

  0%|          | 0/2 [00:00<?, ?it/s]

joint accuracy: 0.2
slot accuracy: 0.9354838709677419
relative slot accuracy: 0.0


(None,
 {1: {'generated_turn_belief': [[], [], [], []],
   'target_turn_belief': [['taxi destination pizza hut fenditton ',
     'taxi departure saint johns college '],
    ['taxi leaveat 17:15 ',
     'taxi destination pizza hut fenditton ',
     'taxi departure saint johns college '],
    ['taxi leaveat 17:15 ',
     'taxi destination pizza hut fenditton ',
     'taxi departure saint johns college '],
    ['taxi leaveat 17:15 ',
     'taxi destination pizza hut fenditton ',
     'taxi departure saint johns college ']]},
  2: {'generated_turn_belief': [[], [], [], [], [], []],
   'target_turn_belief': [[],
    [],
    ['attraction name nusha '],
    ['attraction name nusha '],
    ['restaurant food indian ',
     'restaurant area centre ',
     'attraction name nusha '],
    ['restaurant food indian ',
     'restaurant pricerange expensive ',
     'restaurant area centre ',
     'attraction name nusha ']]}})