In [1]:
# !pip install seqeval

In [2]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
import numpy as np
import torch
from seqeval.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt

In [3]:
TRANSFORMER_MODEL_NAME = 'roberta-base'
save_model_name = TRANSFORMER_MODEL_NAME.split('/')[-1]
epochs = 10
batch_size = 4
learning_rate = 2e-5
class_weight_beta = 0.99999 # increase number of nines if you want stronger imbalance compensation
patience = 2
ignored_tags = ['I-bookstay', 'I-stars']
use_history = False

In [4]:
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

In [5]:
def process_intent_list(intent_list):
    intents = set()
    if len(intent_list) == 0:
        intents.add('other')
    for intent in intent_list:
        if intent.startswith('Restaurant'):
            intents.add(intent)
        elif intent.startswith('Hotel'):
            intents.add(intent)
        elif intent.startswith('general'):
            intents.add(intent)
        else:
            intents.add('other')
    # print(f'Original {intent_list}')
    # print(f'Modified {list(intents)}')
    return list(intents)

def process_service_list(service_list):
    services = set()
    if len(service_list) == 0:
        services.add('other')
    for service in service_list:
        if service == 'restaurant':
            services.add('restaurant')
        elif service == 'hotel':
            services.add('hotel')
        else:
            services.add('other')
        if len(services) == 3:
            break
    return list(services)

In [6]:
def preprocess_split(dataset, split):
    df = dataset[split].to_pandas()
    new_df = pd.DataFrame(columns = df.columns)
    for i in range(len(df)):
        # Taken from notebook, to know which lines to skip
        row = df.loc[i]
        if not any(set(row.turns['frames'][turn_id]['service']).intersection(['hotel', 'restaurant']) for turn_id,utt in enumerate(row.turns['utterance'])):
            continue
        
        new_df.loc[len(new_df)] = row
        # new_df.loc[len(new_df) - 1]['services'] = process_service_list(new_df.loc[len(new_df) - 1]['services'])
        # for i, frame_service in [frame['service'] for frame in df.loc[i].turns['frames']]:
            # df.loc[i].turns['frames']
    return new_df

def extract_token_bio_tags(dataset):
    tokens_list = []
    bio_tags_list = []
    useful_pos_list = []
    
    for i in tqdm(range(len(dataset))):
        turns = dataset.loc[i].turns
        for j, (utterance, speaker, dialogue_act, frames) in enumerate(zip(turns['utterance'], turns['speaker'], turns['dialogue_acts'], turns['frames'])):

            if speaker != 0:
                continue
            # Skip using dialogue act intents
            # if 'other' in process_intent_list(dialogue_act['dialog_act']['act_type']):
            #     continue
            # Skip using frame services
            if 'other' in process_service_list(frames['service']):
                continue
            
            if j == 0:
                prev_user_utterance = ''
                prev_user_acts = []
                prev_bot_utterance = ''
                prev_bot_acts = []
            else:
                prev_user_utterance = turns['utterance'][j - 2]
                prev_user_acts = turns['dialogue_acts'][j - 2]['dialog_act']['act_type']
                prev_bot_utterance = turns['utterance'][j - 1]
                prev_bot_acts = turns['dialogue_acts'][j - 1]['dialog_act']['act_type']
            
            composed_prefix = ''
            if use_history:
                composed_prefix = ' | '.join([prev_user_utterance, ', '.join(prev_user_acts), prev_bot_utterance, ', '.join(prev_bot_acts)]) + ' | '
                utterance = composed_prefix + utterance
            
            span_info = dialogue_act['span_info']
            act_slot_names = span_info['act_slot_name']
            act_slot_values = span_info['act_slot_value']
            span_starts = span_info['span_start']
            span_ends = span_info['span_end']
            slots = {slot_name : {'start': start + len(composed_prefix), 'end': end + len(composed_prefix)} for slot_name, start, end in zip(act_slot_names, span_starts, span_ends)}
            
            tokenized = tokenizer(utterance, padding = 'max_length')
            token_tags = [None] * len(tokenized.input_ids)
            
            for c in range(len(composed_prefix), len(utterance)):
                if tokenized.char_to_token(c) is not None:
                    token_tags[tokenized.char_to_token(c)] = 'O'
            
            # for j in range(len(token_tags)):
            #     if tokenized.token_to_word(j) is not None:
            #         token_tags[j] = 'O'
            
            for slot_name in slots:
                slot_start, slot_end = slots[slot_name]['start'], slots[slot_name]['end']
                covered_tokens = list(dict.fromkeys(tokenized.char_to_token(k) for k in range(slot_start, slot_end) if utterance[k] != ' '))
                for j, covered_token in enumerate(covered_tokens):
                    bio_type = 'B-' if j == 0 else 'I-'
                    if bio_type + slot_name not in ignored_tags:
                        token_tags[covered_token] = bio_type + slot_name
            
            # print([*zip(tokenizer.convert_ids_to_tokens(tokenized.input_ids), token_tags)])

            tokens_list.append(tokenized)
            bio_tags_list.append(np.array(token_tags))
            useful_pos_list.append((tokenized.char_to_token(len(composed_prefix)), tokenized.char_to_token(len(utterance) - 1) + 1))
            
    return tokens_list, bio_tags_list, useful_pos_list

In [7]:
dataset = load_dataset('multi_woz_v22')

train = preprocess_split(dataset, 'train')
val = preprocess_split(dataset, 'validation')
test = preprocess_split(dataset, 'test')

train_tokens, train_bio_tags, train_useful_pos = extract_token_bio_tags(train)
possible_bio_tags = sorted(set(filter(lambda tag : tag is not None, np.concatenate(train_bio_tags))))
print(possible_bio_tags)
tag_to_encoding = {tag : encoding for encoding, tag in enumerate(possible_bio_tags)}
# - 100 is default ignore index for the pytorch cross entropy function
tag_to_encoding[None] = -100
encoding_to_tag = {encoding : tag for encoding, tag in enumerate(possible_bio_tags)}
train_encoded_tags = [[tag_to_encoding[tag] for tag in tags] for tags in train_bio_tags]

val_tokens, val_bio_tags, val_useful_pos = extract_token_bio_tags(val)
val_encoded_tags = [[tag_to_encoding[tag] for tag in tags] for tags in val_bio_tags]

test_tokens, test_bio_tags, test_useful_pos = extract_token_bio_tags(test)
test_encoded_tags = [np.array([tag_to_encoding[tag] for tag in tags]) for tags in test_bio_tags]

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/adrian/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 6321/6321 [00:06<00:00, 946.30it/s] 


['B-area', 'B-bookday', 'B-bookpeople', 'B-bookstay', 'B-booktime', 'B-food', 'B-name', 'B-pricerange', 'B-stars', 'B-type', 'I-area', 'I-bookday', 'I-bookpeople', 'I-booktime', 'I-food', 'I-name', 'I-pricerange', 'I-type', 'O']


100%|██████████| 762/762 [00:00<00:00, 1073.59it/s]
100%|██████████| 745/745 [00:00<00:00, 1075.28it/s]


In [8]:
def batchify_tokens_tags(tokens_list, encoded_tags_list, useful_pos_list, batch_size):
    ids_batch = []
    mask_batch = []
    useful_pos_batch = []
    labels_batch = []
    
    if encoded_tags_list is None:
        encoded_tags_list = range(len(tokens_list))
    
    if useful_pos_list is None:
        useful_pos_list = range(len(tokens_list))
    
    for tokens, encoded_tags, useful_pos in zip(tokens_list, encoded_tags_list, useful_pos_list):
        ids_batch.append(tokens.input_ids)
        mask_batch.append(tokens.attention_mask)
        useful_pos_batch.append(useful_pos)
        labels_batch.append(encoded_tags)
        
        if len(ids_batch) == batch_size:
            yield torch.Tensor(ids_batch).long().cuda(), torch.Tensor(mask_batch).cuda(), useful_pos_batch, torch.Tensor(labels_batch).long().cuda()
            ids_batch.clear()
            mask_batch.clear()
            useful_pos_batch.clear()
            labels_batch.clear()
    
    yield torch.Tensor(ids_batch).long().cuda(), torch.Tensor(mask_batch).cuda(), useful_pos_batch, torch.Tensor(labels_batch).long().cuda()
    return None

def outputs_keep_useful_part(logits_batch, labels_batch, useful_pos_batch):
    logits_useful = torch.zeros(logits_batch.shape).cuda()
    labels_useful = torch.zeros(labels_batch.shape).cuda().long()
    for i, useful_pos in enumerate(useful_pos_batch):
        logits_useful[i, useful_pos[0] - 1 : useful_pos[1] + 1, :] = logits_batch[i, useful_pos[0] - 1 : useful_pos[1] + 1, :]
        labels_useful[i, useful_pos[0] - 1 : useful_pos[1] + 1] = labels_batch[i, useful_pos[0] - 1 : useful_pos[1] + 1]
    return logits_useful, labels_useful

In [9]:
def predict(transformer, tokens, batch_size):
    transformer.eval()
    predictions = []
    with torch.no_grad():
        for ids_batch, mask_batch, _, _ in tqdm(batchify_tokens_tags(tokens, None, None, batch_size)):
            
            out = transformer.forward(input_ids = ids_batch, attention_mask = mask_batch)
            res = torch.argmax(out.logits, dim = 2).cpu().detach().numpy()
            predictions.append(res)
    return np.concatenate(predictions)

def useful_flattened_tokens(tokens_list, useful_pos_list):
    return np.concatenate([tokens[useful_pos[0] : useful_pos[1]] for tokens, useful_pos in zip(tokens_list, useful_pos_list)])

In [10]:
transformer = AutoModelForTokenClassification.from_pretrained('saved_models/SF_' + save_model_name, num_labels = len(possible_bio_tags)).cuda()

with torch.no_grad():
    predicted_encoded_tags = predict(transformer, test_tokens, batch_size)
predicted_encoded_tags_flattened = [encoding_to_tag[encoding] for encoding in useful_flattened_tokens(predicted_encoded_tags, test_useful_pos)]
test_encoded_tags_flattened = [encoding_to_tag[encoding] for encoding in useful_flattened_tokens(test_encoded_tags, test_useful_pos)]

acc = accuracy_score(test_encoded_tags_flattened, predicted_encoded_tags_flattened)
report = classification_report([test_encoded_tags_flattened], [predicted_encoded_tags_flattened], digits = 3, zero_division = 0)
print(f"Accuracy: {acc:.3f}")
print(report)

655it [02:09,  5.06it/s]


Accuracy: 0.962
              precision    recall  f1-score   support

        area      0.798     0.955     0.870       447
     bookday      0.889     1.000     0.941       367
  bookpeople      0.842     0.989     0.910       376
    bookstay      0.806     0.981     0.885       259
    booktime      0.900     0.990     0.943       210
        food      0.846     0.971     0.904       378
        name      0.612     0.878     0.721       278
  pricerange      0.792     0.973     0.873       486
       stars      0.922     0.989     0.954       190
        type      0.395     0.992     0.565       243

   micro avg      0.751     0.971     0.847      3234
   macro avg      0.780     0.972     0.857      3234
weighted avg      0.786     0.971     0.862      3234

