# Training script

### Imports

In [9]:
import pandas as pd
from xgboost import XGBClassifier
from sklearn.multioutput import MultiOutputClassifier
from transformers import AutoTokenizer
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
import string
import fasttext
from sklearn.svm import SVC
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
import pickle
from dict_model import DictModel

### Models

In [10]:
class MLP(nn.Module):
    """Small feed-forward multi-label classifier with early stopping.

    The trunk (``self.mlp``) maps ``input`` features to a 50-dimensional
    representation. The output layer (``self.head``) is created inside
    :meth:`fit`, once the number of labels is known, so :meth:`forward` and
    :meth:`predict` can only be used after :meth:`fit` (or after ``head`` has
    been assigned manually).

    All tensors are placed on the device that holds the module's parameters,
    so ``MLP(n).cuda()`` trains on the GPU and a plain ``MLP(n)`` on the CPU.

    Args:
        input: Number of input features.
        epochs: Maximum number of training epochs.
        batch_size: Number of samples per mini-batch.
        patience: Number of consecutive non-improving validation epochs that
            are tolerated before training stops.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, input, epochs = 100, batch_size = 64, patience = 2, lr = 1e-3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.BatchNorm1d(input),
            nn.Linear(input, input),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(input, 100),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU()
        )
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience
        self.lr = lr

    def _device(self):
        """Return the device on which the module's parameters live."""
        return next(self.parameters()).device

    def _iter_batches(self, X, y, min_size = 1):
        """Yield consecutive ``(X, y)`` mini-batches covering every sample.

        Batches with fewer than ``min_size`` rows are skipped; training uses
        ``min_size = 2`` because ``BatchNorm1d`` cannot compute batch
        statistics from a single sample in training mode.
        """
        for start in range(0, X.shape[0], self.batch_size):
            stop = start + self.batch_size
            batch = X[start:stop]
            if batch.shape[0] < min_size:
                continue
            yield batch, y[start:stop]

    def compute_loss(self, X, y, criterion):
        """Return the mean per-batch loss of the model on ``(X, y)``.

        The model is evaluated in eval mode without gradient tracking; the
        previous train/eval mode is restored afterwards.

        Args:
            X: Feature tensor located on the model's device.
            y: Target tensor located on the model's device.
            criterion: Callable ``criterion(prediction, target)`` returning a
                scalar loss tensor.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                losses = [
                    criterion(self.forward(batch), y_true).item()
                    for batch, y_true in self._iter_batches(X, y)
                ]
        finally:
            self.train(was_training)
        return np.mean(losses)

    def fit(self, X, y, X_val, y_val):
        """Train the network and return the loss curves.

        Training stops early once the validation loss has failed to drop
        below the previous epoch's value more than ``patience`` times in a
        row.

        Args:
            X: Training features, array-like of shape (n_samples, input).
            y: Training targets, binary indicator matrix (n_samples, n_labels).
            X_val: Validation features.
            y_val: Validation targets.

        Returns:
            Tuple ``(train_losses, val_losses)`` with one entry per completed
            epoch each.
        """
        device = self._device()
        X = torch.as_tensor(X, dtype = torch.float32, device = device)
        y = torch.as_tensor(y, dtype = torch.float32, device = device)
        X_val = torch.as_tensor(X_val, dtype = torch.float32, device = device)
        y_val = torch.as_tensor(y_val, dtype = torch.float32, device = device)

        # The head must exist before the optimizer is built so that its
        # weights are part of self.parameters().
        self.head = nn.Linear(50, y.shape[1]).to(device)
        optim = torch.optim.Adam(self.parameters(), lr = self.lr)
        criterion = nn.BCEWithLogitsLoss()
        train_losses = []
        val_losses = []
        waited = 0

        self.train()
        for epoch in tqdm(range(self.epochs)):
            epoch_train_loss = []
            for batch, y_true in self._iter_batches(X, y, min_size = 2):
                y_pred = self.forward(batch)
                loss = criterion(y_pred, y_true)
                optim.zero_grad()
                loss.backward()
                optim.step()
                epoch_train_loss.append(loss.item())

            train_losses.append(np.mean(epoch_train_loss))

            epoch_val_loss = self.compute_loss(X_val, y_val, criterion)
            got_worse = len(val_losses) != 0 and val_losses[-1] <= epoch_val_loss
            # Record the loss before a possible break so that both curves
            # always have the same length.
            val_losses.append(epoch_val_loss)
            if got_worse:
                waited += 1
                if waited > self.patience:
                    break
            else:
                waited = 0

        return train_losses, val_losses

    def forward(self, X):
        """Return the raw (pre-sigmoid) logits for ``X``."""
        return self.head(self.mlp(X))

    def predict(self, X):
        """Return a 0/1 float numpy array of predicted labels for ``X``.

        The network outputs logits (it is trained with ``BCEWithLogitsLoss``),
        so they are passed through a sigmoid before thresholding at 0.5.
        Inference runs in eval mode, so dropout is disabled and batch-norm
        uses its running statistics; the previous mode is restored afterwards.
        """
        X = torch.as_tensor(X, dtype = torch.float32, device = self._device())
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probabilities = torch.sigmoid(self.forward(X))
        finally:
            self.train(was_training)
        return (probabilities > 0.5).float().cpu().numpy()

In [11]:
# Model selection: exactly one assignment to `model` should be active; the
# commented-out lines are alternative candidates that were tried.
# model = MLP(37).cuda() # Replace model instantiation with another class here (SVC for example) if wishing to test other models
# model = XGBClassifier(n_estimators = 300, max_depth = 13, learning_rate = 0.01)
# model = XGBClassifier(n_estimators = 100, max_depth = 39 * 2, learning_rate = 0.01)
# model = SVC(C = 1, kernel = 'rbf', gamma = 'scale')
# TODO: something is wrong, since the dict-based model achieves only 0.30 accuracy
# even though just 30/3000 examples in the test set were not previously seen,
# so it should reach roughly 0.99 accuracy
model = DictModel()
# model = SVC(C = 0.1, kernel = 'rbf', gamma = 'scale') # 50.2 % accuracy

# When True, 0.5 is subtracted from the binarized train/test/val inputs
# before they are passed to the model.
normalize_inputs = False

In [12]:
def process_intent_list(intent_list):
    """Collapse dialogue-act intents onto the set this project cares about.

    Intents whose name starts with ``Restaurant``, ``Hotel``, ``Booking`` or
    ``general`` are kept unchanged; every other intent is replaced by
    ``'other'``. An empty input yields ``['other']``.

    Returns:
        The distinct resulting intents as a list (order unspecified).
    """
    kept_prefixes = ('Restaurant', 'Hotel', 'Booking', 'general')
    normalized = {
        intent if intent.startswith(kept_prefixes) else 'other'
        for intent in intent_list
    }
    # len() rather than truthiness: the input may be a numpy array.
    if len(intent_list) == 0:
        normalized.add('other')
    return list(normalized)

def process_service_list(service_list):
    """Collapse service names onto ``'restaurant'``, ``'hotel'`` or ``'other'``.

    An empty input yields ``['other']``. Scanning stops as soon as all three
    possible labels have been seen.

    Returns:
        The distinct resulting labels as a list (order unspecified).
    """
    tracked = ('restaurant', 'hotel')
    labels = set()
    # len() rather than truthiness: the input may be a numpy array.
    if len(service_list) == 0:
        labels.add('other')
    for name in service_list:
        labels.add(next((known for known in tracked if name == known), 'other'))
        if len(labels) == 3:
            # Nothing new can be added any more.
            break
    return list(labels)

### Loading the dataset

In [13]:
def preprocess_split(dataset, split):
    """Return the dialogues of ``split`` that involve a hotel or restaurant.

    A dialogue is kept when at least one of its turns has a frame whose
    ``service`` list contains ``'hotel'`` or ``'restaurant'``.

    Args:
        dataset: Mapping from split name to an object exposing ``to_pandas()``;
            the resulting frame must have a ``turns`` column whose values
            provide ``'frames'`` and ``'utterance'`` sequences.
        split: Name of the split to process (e.g. ``'train'``).

    Returns:
        A DataFrame with the kept rows in their original order, re-indexed
        from 0 so that rows can be addressed with ``.loc[i]``.
    """
    target_services = {'hotel', 'restaurant'}

    def involves_target_service(turns):
        # One frame per utterance; a single matching turn is enough.
        return any(
            target_services.intersection(frame['service'])
            for frame, _ in zip(turns['frames'], turns['utterance'])
        )

    df = dataset[split].to_pandas()
    # Filter with a boolean mask instead of growing a DataFrame row by row,
    # which copies the whole frame on every insertion.
    keep = df['turns'].map(involves_target_service).astype(bool)
    return df.loc[keep].reset_index(drop=True)

def extract_to_be_retrieved_info(dataset):
    """Build (user input, bot output) training pairs from dialogues.

    For every user turn (``speaker == 0``) that is kept, the act types and
    the ``<act_type>-<slot_name>`` strings of that turn are recorded. For the
    bot turn that follows a kept user turn, the ``<domain>-<slot_name>``
    strings of the slots the bot supplied are recorded.

    A user turn and the bot turn after it are skipped when either turn
    contains an intent that ``process_intent_list`` maps to ``'other'``, or
    when the user turn has no following turn at all. The three returned lists
    are therefore index-aligned as long as user and bot turns strictly
    alternate within each dialogue.

    Args:
        dataset: DataFrame with a ``turns`` column; each value provides the
            parallel sequences ``'speaker'``, ``'dialogue_acts'`` and
            ``'frames'``.

    Returns:
        Tuple ``(user_act_types_list, user_slots_per_act_type_list,
        to_be_retrieved_list)``.
    """
    kept_domains = ('hotel', 'restaurant', 'booking', 'general')
    user_act_types_list = []
    user_slots_per_act_type_list = []
    to_be_retrieved_list = []

    for turns in tqdm(dataset['turns']):
        all_dialogue_acts = turns['dialogue_acts']
        # Reset per dialogue: a bot turn that is not preceded by a kept user
        # turn of the same dialogue has no input to be paired with.
        skip_bot = True
        for j, (speaker, dialogue_act, frames) in enumerate(
            zip(turns['speaker'], all_dialogue_acts, turns['frames'])
        ):
            act_types = dialogue_act['dialog_act']['act_type']
            act_slots = dialogue_act['dialog_act']['act_slots']

            if speaker == 0:
                has_reply = j + 1 < len(all_dialogue_acts)
                skip_bot = (
                    'other' in process_intent_list(act_types)
                    or not has_reply
                    or 'other' in process_intent_list(all_dialogue_acts[j + 1]['dialog_act']['act_type'])
                )
                if skip_bot:
                    continue
            elif skip_bot:
                continue

            current_booking_service = [
                service for service in frames['service'] if service in ["hotel", "restaurant"]
            ]

            slots_per_act_type = []
            to_be_retrieved = set()
            for act_type, slots in zip(act_types, act_slots):
                domain = act_type.split('-')[0].lower()
                # A generic "Booking" act is attributed to the turn's service
                # when that service is unambiguous.
                if domain == 'booking' and len(current_booking_service) == 1:
                    domain = current_booking_service[0]
                if domain not in kept_domains:
                    continue

                slot_names = slots['slot_name']
                if speaker == 0:  # When it's the user's turn
                    slots_per_act_type.extend(
                        act_type.lower() + '-' + slot_name
                        for slot_name in slot_names
                        if slot_name != 'none'
                    )
                else:  # When it's the bot's turn
                    # Keep only slots for which the bot supplied a value:
                    # '?' values, 'choice' slots and 'none' are dropped.
                    to_be_retrieved.update(
                        domain + '-' + slot_name
                        for slot_name, slot_value in zip(slot_names, slots['slot_value'])
                        if slot_value != '?' and 'choice' not in slot_name and slot_name != 'none'
                    )

            if speaker == 0:  # When it's the user's turn
                user_act_types_list.append(act_types)
                user_slots_per_act_type_list.append(slots_per_act_type)
            else:  # When it's the bot's turn
                to_be_retrieved_list.append(list(to_be_retrieved))

    return user_act_types_list, user_slots_per_act_type_list, to_be_retrieved_list

In [14]:
dataset = load_dataset('multi_woz_v22')

# Notebook-style cache: preprocessing is slow, so it is only run when `train`
# is not already defined from an earlier execution of this cell. Only `train`
# is probed; the other variables are assumed to exist alongside it.
try:
    train
except NameError:
    train = preprocess_split(dataset, 'train')
    test = preprocess_split(dataset, 'test')
    val = preprocess_split(dataset, 'validation')
    train_user_act_types_list, train_user_slots_per_act_type_list, train_to_be_retrieved_list = extract_to_be_retrieved_info(train)
    test_user_act_types_list, test_user_slots_per_act_type_list, test_to_be_retrieved_list = extract_to_be_retrieved_info(test)
    val_user_act_types_list, val_user_slots_per_act_type_list, val_to_be_retrieved_list = extract_to_be_retrieved_info(val)
else:
    print("Dataset already loaded, moving on")

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/adrian/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

Dataset already loaded, moving on


In [15]:
# MLP and DictModel handle multi-label targets themselves; every other
# estimator is wrapped so that one classifier is fitted per output label.
if not isinstance(model, (MLP, DictModel)):
    model = MultiOutputClassifier(model)

# Fit the label binarizers on the training split only and persist them so
# that the same encoding can be reused when the model is loaded elsewhere.
output_mlb = MultiLabelBinarizer().fit(train_to_be_retrieved_list)
input_mlb = MultiLabelBinarizer().fit(train_user_slots_per_act_type_list)
with open('saved_models/MOVE_RETR_input_mlb.pkl', 'wb') as f:
    pickle.dump(input_mlb, f)
with open('saved_models/MOVE_RETR_output_mlb.pkl', 'wb') as f:
    pickle.dump(output_mlb, f)

print(input_mlb.classes_)
print(output_mlb.classes_)

train_input = input_mlb.transform(train_user_slots_per_act_type_list)
train_output = output_mlb.transform(train_to_be_retrieved_list)

test_input = input_mlb.transform(test_user_slots_per_act_type_list)
test_output = output_mlb.transform(test_to_be_retrieved_list)

val_input = input_mlb.transform(val_user_slots_per_act_type_list)
val_output = output_mlb.transform(val_to_be_retrieved_list)

if normalize_inputs:
    # Shift the 0/1 indicator inputs to -0.5/+0.5.
    train_input = train_input - 0.5
    test_input = test_input - 0.5
    val_input = val_input - 0.5

if not isinstance(model, MLP):
    model.fit(train_input, train_output)
    # NOTE: the file name says DICT, but this branch saves whichever non-MLP
    # model is currently selected.
    with open('saved_models/MOVE_RETR_DICT.pkl', 'wb') as f:
        pickle.dump(model, f)
else:
    # The MLP also needs the validation split for early stopping.
    train_losses, val_losses = model.fit(train_input, train_output, val_input, val_output)
    plt.plot(train_losses)
    plt.plot(val_losses)
    plt.show()
    

['hotel-inform-area' 'hotel-inform-bookday' 'hotel-inform-bookpeople'
 'hotel-inform-bookstay' 'hotel-inform-choice' 'hotel-inform-internet'
 'hotel-inform-name' 'hotel-inform-parking' 'hotel-inform-pricerange'
 'hotel-inform-stars' 'hotel-inform-type' 'hotel-request-address'
 'hotel-request-area' 'hotel-request-internet' 'hotel-request-name'
 'hotel-request-parking' 'hotel-request-phone' 'hotel-request-postcode'
 'hotel-request-pricerange' 'hotel-request-ref' 'hotel-request-stars'
 'hotel-request-type' 'restaurant-inform-area' 'restaurant-inform-bookday'
 'restaurant-inform-bookpeople' 'restaurant-inform-booktime'
 'restaurant-inform-food' 'restaurant-inform-name'
 'restaurant-inform-pricerange' 'restaurant-request-address'
 'restaurant-request-area' 'restaurant-request-food'
 'restaurant-request-name' 'restaurant-request-phone'
 'restaurant-request-postcode' 'restaurant-request-pricerange'
 'restaurant-request-ref']
['booking-bookday' 'booking-bookpeople' 'booking-bookstay'
 'booking

In [16]:
# Evaluate the fitted model on the binarized test split.
predicted_output = model.predict(test_input)

# NOTE(review): both arrays are multi-label indicator matrices, so
# accuracy_score presumably reports exact-match (subset) accuracy, i.e. a
# sample only counts as correct when every label matches - confirm against
# the scikit-learn documentation before comparing it with per-label scores.
acc = accuracy_score(test_output, predicted_output)
# Per-label precision/recall/F1, with rows named after the output labels.
report = classification_report(test_output, predicted_output, target_names = output_mlb.classes_, digits = 3)
print(report)
print(f'acc = {acc}')

                       precision    recall  f1-score   support

      booking-bookday      0.200     0.018     0.034       109
   booking-bookpeople      0.000     0.000     0.000        94
     booking-bookstay      0.274     0.321     0.296        53
     booking-booktime      0.212     0.579     0.310        76
         booking-name      0.185     0.092     0.123       109
          booking-ref      0.535     0.553     0.544       499
        hotel-address      0.700     0.337     0.455       104
           hotel-area      0.378     0.320     0.346       247
       hotel-internet      0.233     0.194     0.211       155
           hotel-name      0.431     0.332     0.375       434
        hotel-parking      0.452     0.182     0.259       154
          hotel-phone      0.633     0.588     0.610        85
       hotel-postcode      0.972     0.530     0.686        66
     hotel-pricerange      0.265     0.220     0.241       236
            hotel-ref      0.000     0.000     0.000  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
