In [1]:
# The bert2bert and BART models were tested on Kaggle; this notebook only tests the rule-based approach.
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import torch
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
from slot_filler import map_slot_value

In [2]:
# Notebook-wide configuration.
TRANSFORMER_MODEL_NAME = 'facebook/bart-base'
# Last path component of the model identifier (the part after the final '/').
save_model_name = TRANSFORMER_MODEL_NAME.rsplit('/', 1)[-1]

# Training hyperparameters.
epochs, batch_size = 10, 4
learning_rate = 2e-5
patience = 2

# NOTE(review): not referenced in the visible cells; presumably toggles
# dialogue-history input elsewhere — confirm before relying on it.
use_history = False

In [3]:
def process_intent_list(intent_list):
    """Reduce a list of intent names to the ones this notebook cares about.

    Intents starting with 'Restaurant', 'Hotel' or 'general' are kept as-is;
    any other intent is replaced by the single label 'other'. An empty input
    also yields 'other'.

    Args:
        intent_list: sized iterable of intent-name strings.

    Returns:
        A list of unique labels. It is built from a set, so the order of the
        elements is not guaranteed.
    """
    kept_prefixes = ('Restaurant', 'Hotel', 'general')

    # No intents at all counts as 'other'.
    if len(intent_list) == 0:
        return ['other']

    labels = {
        name if name.startswith(kept_prefixes) else 'other'
        for name in intent_list
    }
    return list(labels)

def process_service_list(service_list):
    """Reduce a list of service names to 'restaurant', 'hotel' and 'other'.

    Every service that is not exactly 'restaurant' or 'hotel' is reported as
    'other'. An empty input also yields 'other'.

    Args:
        service_list: sized iterable of service-name strings.

    Returns:
        A list of unique labels drawn from {'restaurant', 'hotel', 'other'}.
        It is built from a set, so the order of the elements is not guaranteed.
    """
    known_services = ('restaurant', 'hotel')
    found = set()

    if len(service_list) == 0:
        found.add('other')

    for name in service_list:
        label = next((known for known in known_services if name == known), 'other')
        found.add(label)
        # All three possible labels have been seen; nothing more can change.
        if len(found) == 3:
            break

    return list(found)

In [4]:
def _mentions_hotel_or_restaurant(turns):
    """Return True if any turn's frame lists 'hotel' or 'restaurant' as a service."""
    frames = turns['frames']
    return any(
        set(frames[turn_id]['service']).intersection(['hotel', 'restaurant'])
        for turn_id in range(len(turns['utterance']))
    )


def preprocess_split(dataset, split):
    """Load one split as a DataFrame and keep only hotel/restaurant dialogues.

    A dialogue (row) is kept when at least one of its turns has a frame whose
    'service' list contains 'hotel' or 'restaurant'.

    Args:
        dataset: mapping from split name to an object exposing ``to_pandas()``.
        split: name of the split to load, e.g. 'train'.

    Returns:
        A new DataFrame with the same columns as the split, containing only
        the kept rows in their original order, re-indexed from 0.
    """
    df = dataset[split].to_pandas()

    # Build the row mask once and filter in a single step. Appending rows one
    # at a time with ``.loc[len(new_df)] = row`` re-allocates on every insert
    # and turns every column into object dtype.
    keep = pd.Series(
        [_mentions_hotel_or_restaurant(turns) for turns in df['turns']],
        index=df.index,
        dtype=bool,  # explicit dtype keeps an empty mask boolean
    )
    return df.loc[keep].reset_index(drop=True)

def extract_token_bio_tags(dataset):
    """Collect (utterance span, slot value) string pairs from a preprocessed split.

    Despite the name, no BIO tags are produced: for each annotated span the
    function records the raw substring of the utterance and the slot value
    stored alongside it.

    Only turns with ``speaker == 0`` are used (presumably the user side —
    confirm against the dataset schema), and only when every service in the
    turn's frames is 'hotel' or 'restaurant'.

    Args:
        dataset: DataFrame indexed 0..n-1 with a ``turns`` column holding the
            'utterance', 'speaker', 'dialogue_acts' and 'frames' sequences.

    Returns:
        A 3-tuple ``(unmapped, mapped, mapped)``: the utterance substrings,
        the slot values, and the same slot-value list object a second time.
    """
    surface_forms = []
    slot_values = []

    for row_idx in tqdm(range(len(dataset))):
        turns = dataset.loc[row_idx].turns
        per_turn = zip(
            turns['utterance'],
            turns['speaker'],
            turns['dialogue_acts'],
            turns['frames'],
        )
        for utterance, speaker, dialogue_act, frames in per_turn:
            if speaker != 0:
                continue
            # Filtering is done on the frame services, not on the
            # dialogue-act intents.
            if 'other' in process_service_list(frames['service']):
                continue

            span_info = dialogue_act['span_info']
            names = span_info['act_slot_name']
            values = span_info['act_slot_value']
            starts = span_info['span_start']
            ends = span_info['span_end']

            # Keyed by slot name: if a name repeats within one turn, only the
            # last span for that name survives.
            spans_by_slot = {}
            for name, start, end, value in zip(names, starts, ends, values):
                spans_by_slot[name] = (start, end, value)

            for start, end, value in spans_by_slot.values():
                surface_forms.append(utterance[start:end])
                slot_values.append(value)

    return surface_forms, slot_values, slot_values

In [5]:
# Load MultiWOZ 2.2 and reduce each split to hotel/restaurant dialogues.
dataset = load_dataset('multi_woz_v22')

train, val, test = (
    preprocess_split(dataset, split_name)
    for split_name in ('train', 'validation', 'test')
)

# Per split: raw utterance spans and the slot values they should map to.
# The third return value duplicates the second, so it is only kept for test.
train_unmapped, train_mapped, _ = extract_token_bio_tags(train)
val_unmapped, val_mapped, _ = extract_token_bio_tags(val)
test_unmapped, test_mapped, test_mapped_string = extract_token_bio_tags(test)

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/adrian/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 6321/6321 [00:00<00:00, 14388.80it/s]
100%|██████████| 762/762 [00:00<00:00, 12724.79it/s]
100%|██████████| 745/745 [00:00<00:00, 12580.75it/s]


In [6]:
# transformer = EncoderDecoderModel.from_pretrained('saved_models/MAP_' + save_model_name).cuda()

# Run the rule-based mapper over every test span.
predicted_mapped_string = [
    map_slot_value(unmapped_string) for unmapped_string in tqdm(test_unmapped)
]

# Compare case-insensitively on the ground-truth side.
# NOTE(review): assumes map_slot_value already returns lower-case strings — confirm.
test_mapped_string = [s.lower() for s in test_mapped_string]

# scikit-learn's documented argument order is (y_true, y_pred).
acc = accuracy_score(test_mapped_string, predicted_mapped_string)
print(f"Accuracy: {acc:.3f}")

# List every mismatch. The loop variable is named `truth` so that it does not
# overwrite the `test` DataFrame created when the splits were loaded.
for pred, truth in zip(predicted_mapped_string, test_mapped_string):
    if pred != truth:
        print(f'pred = "{pred}", truth = "{truth}"')

  0%|          | 0/3241 [00:00<?, ?it/s]

100%|██████████| 3241/3241 [00:00<00:00, 130348.07it/s]


Accuracy: 0.984
pred = "same area", truth = "north"
pred = "same area", truth = "east"
pred = "same group of people", truth = "5"
pred = "same day", truth = "thursday"
pred = "same area", truth = "centre"
pred = "same area", truth = "north"
pred = "same area", truth = "centre"
pred = "same number of people", truth = "4"
pred = "same day", truth = "tuesday"
pred = "same day", truth = "monday"
pred = "same day", truth = "saturday"
pred = "same area", truth = "east"
pred = "same day", truth = "friday"
pred = "same amount of people", truth = "7"
pred = "same group of people", truth = "6"
pred = "same area", truth = "south"
pred = "same group of people", truth = "2"
pred = "same day", truth = "monday"
pred = "same day", truth = "sunday"
pred = "same area", truth = "east"
pred = "same group of people", truth = "6"
pred = "same day", truth = "sunday"
pred = "same area", truth = "south"
pred = "same area", truth = "centre"
pred = "same day", truth = "monday"
pred = "same area", truth = "centre