# Training script

### Imports

In [1]:
import pandas as pd
from xgboost import XGBClassifier
from sklearn.multioutput import MultiOutputClassifier
from transformers import AutoTokenizer
from transformers import BertModel
from datasets import load_dataset
from torch import nn
import spacy
import nltk
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from sklearn.feature_selection import f_classif, SelectKBest
import string
import fasttext
from sklearn.svm import SVC
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import pickle

BERT_TOKENIZER = True

if not BERT_TOKENIZER:
    embedder = fasttext.load_model('fasttext/cc.en.300.bin')
    nlp = spacy.load("en_core_web_lg")
else:
    model = BertModel.from_pretrained("bert-base-uncased")
    embedding_matrix = model.embeddings.word_embeddings.weight
    transformer_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Models

In [2]:
model = XGBClassifier(n_estimators = 100, max_depth = 39 * 2, learning_rate = 0.01)

### Loading the dataset

In [3]:
def process_intent_list(intent_list):
    intents = set()
    if len(intent_list) == 0:
        intents.add('other')
    for intent in intent_list:
        if intent.startswith('Restaurant'):
            intents.add(intent)
        elif intent.startswith('Hotel'):
            intents.add(intent)
        elif intent.startswith('general'):
            intents.add(intent)
        else:
            intents.add('other')
    # print(f'Original {intent_list}')
    # print(f'Modified {list(intents)}')
    return list(intents)

def process_service_list(service_list):
    services = set()
    if len(service_list) == 0:
        services.add('other')
    for service in service_list:
        if service == 'restaurant':
            services.add('restaurant')
        elif service == 'hotel':
            services.add('hotel')
        else:
            services.add('other')
        if len(services) == 3:
            break
    return list(services)

In [4]:
def preprocess_split(dataset, split):
    df = dataset[split].to_pandas()
    new_df = pd.DataFrame(columns = df.columns)
    for i in range(len(df)):
        # Taken from notebook, to know which lines to skip
        row = df.loc[i]
        if not any(set(row.turns['frames'][turn_id]['service']).intersection(['hotel', 'restaurant']) for turn_id,utt in enumerate(row.turns['utterance'])):
            continue
        
        new_df.loc[len(new_df)] = row
        # new_df.loc[len(new_df) - 1]['services'] = process_service_list(new_df.loc[len(new_df) - 1]['services'])
        # for i, frame_service in [frame['service'] for frame in df.loc[i].turns['frames']]:
            # df.loc[i].turns['frames']
    return new_df

def extract_to_be_retrieved_info(dataset):
    user_act_types_list = []
    user_slots_per_act_type_list = []
    to_be_retrieved_list = []
    to_be_retrieved_all = {
        'hotel': set(),
        'restaurant': set()
    }
    
    for i in tqdm(range(len(dataset))):
        turns = dataset.loc[i].turns
        for j, (utterance, speaker, dialogue_act, frames) in enumerate(zip(turns['utterance'], turns['speaker'], turns['dialogue_acts'], turns['frames'])):
            # if speaker != 1:
                # continue
            # Skip using dialogue act intents
            # print(dialogue_act['dialog_act']['act_type'])
            # if 'other' in process_intent_list(dialogue_act['dialog_act']['act_type']):
            #     continue
            # Skip using frame services
            # if 'other' in process_service_list(frames['service']):
                # continue
                
            act_types = dialogue_act['dialog_act']['act_type']
            act_slots = dialogue_act['dialog_act']['act_slots']
            # print(act_slots)
            
            slots_per_act_type = []
            to_be_retrieved = set()
            for act_type, slots in zip(act_types, act_slots):
                slot_names = slots['slot_name']
                slot_values = slots['slot_value']
                
                domain = act_type.split('-')[0].lower()
                if 'hotel' in domain or 'restaurant' in domain:
                    if speaker == 0: # When it's the user's turn
                        for slot_name in slot_names:
                            slots_per_act_type.append(act_type.lower() + '-' + slot_name)
                    else: # When it's the bot's turn
                        act_type_relevant_slots = [(slot_name, slot_value) for slot_name, slot_value in zip(slot_names, slot_values) if slot_value != '?' and 'choice' not in slot_name and slot_name != 'none']
                        to_be_retrieved = set([domain + '-' + slot_name for slot_name, _ in act_type_relevant_slots])
                        to_be_retrieved_all[domain].update(to_be_retrieved)
                
            if speaker == 0: # When it's the user's turn
                user_act_types_list.append(act_types)
                user_slots_per_act_type_list.append(slots_per_act_type)
            else: # When it's the bot's turn
                to_be_retrieved_list.append(to_be_retrieved)
                
            
            
    return user_act_types_list, user_slots_per_act_type_list, to_be_retrieved_list, to_be_retrieved_all

In [5]:
dataset = load_dataset('multi_woz_v22')

try:
    train
    print("Dataset already loaded, moving on")
except:
    train = preprocess_split(dataset, 'train')
    test = preprocess_split(dataset, 'test')
    val = preprocess_split(dataset, 'validation')
    train_user_act_types_list, train_user_slots_per_act_type_list, train_to_be_retrieved_list, to_be_retrieved_all = extract_to_be_retrieved_info(train)
    test_user_act_types_list, test_user_slots_per_act_type_list, test_to_be_retrieved_list, _ = extract_to_be_retrieved_info(test)
    val_user_act_types_list, val_user_slots_per_act_type_list, val_to_be_retrieved_list, _ = extract_to_be_retrieved_info(val)

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/adrian/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 6321/6321 [00:00<00:00, 9698.42it/s] 
100%|██████████| 745/745 [00:00<00:00, 9573.87it/s]
100%|██████████| 762/762 [00:00<00:00, 9941.49it/s]


In [6]:
print(to_be_retrieved_all)
print(train_user_slots_per_act_type_list[:10])
print()
print(len(train_user_act_types_list))
print(len(train_to_be_retrieved_list))

{'hotel': {'hotel-postcode', 'hotel-type', 'hotel-address', 'hotel-area', 'hotel-pricerange', 'hotel-internet', 'hotel-name', 'hotel-stars', 'hotel-phone', 'hotel-ref', 'hotel-parking'}, 'restaurant': {'restaurant-ref', 'restaurant-postcode', 'restaurant-area', 'restaurant-phone', 'restaurant-address', 'restaurant-name', 'restaurant-pricerange', 'restaurant-food'}}
[['restaurant-inform-area', 'restaurant-inform-pricerange'], ['restaurant-request-food'], ['hotel-inform-pricerange', 'hotel-inform-type', 'restaurant-request-phone'], ['hotel-inform-none'], ['hotel-inform-bookday', 'hotel-inform-bookpeople', 'hotel-inform-bookstay'], [], ['hotel-inform-internet', 'hotel-inform-parking'], ['hotel-inform-area'], ['hotel-inform-pricerange', 'restaurant-inform-pricerange'], ['hotel-inform-bookday', 'hotel-inform-bookpeople', 'hotel-inform-bookstay']]

45794
45794


In [7]:
output_mlb = MultiLabelBinarizer().fit(train_to_be_retrieved_list)
input_mlb = MultiLabelBinarizer().fit(train_user_slots_per_act_type_list)

train_input = input_mlb.transform(train_user_slots_per_act_type_list)
train_output = output_mlb.transform(train_to_be_retrieved_list)

test_input = input_mlb.transform(test_user_slots_per_act_type_list)
test_output = output_mlb.transform(test_to_be_retrieved_list)

val_input = input_mlb.transform(val_user_slots_per_act_type_list)
val_output = output_mlb.transform(val_to_be_retrieved_list)

In [8]:
model = MultiOutputClassifier(model)

model.fit(train_input, train_output)

In [9]:
predicted_output = model.predict(test_input)

acc = accuracy_score(test_output, predicted_output)
report = classification_report(test_output, predicted_output, target_names = output_mlb.classes_, digits = 3)
print(report)
print(f'acc = {acc}')

                       precision    recall  f1-score   support

        hotel-address      0.816     0.374     0.513       107
           hotel-area      0.727     0.080     0.143       201
       hotel-internet      0.833     0.188     0.307       133
           hotel-name      0.588     0.092     0.159       434
        hotel-parking      0.806     0.194     0.312       129
          hotel-phone      0.849     0.689     0.761        90
       hotel-postcode      0.818     0.625     0.709        72
     hotel-pricerange      0.667     0.107     0.184       187
            hotel-ref      0.000     0.000     0.000         2
          hotel-stars      1.000     0.041     0.079       171
           hotel-type      0.769     0.049     0.093       203
   restaurant-address      0.768     0.478     0.589       159
      restaurant-area      0.833     0.051     0.096       197
      restaurant-food      0.591     0.062     0.112       210
      restaurant-name      0.625     0.171     0.268  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
