In [30]:
import os
import random
from collections import defaultdict

import numpy as np

import gc

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertTokenizerFast, BertPreTrainedModel, TrainingArguments, Trainer

from seqeval.metrics import precision_score, recall_score, f1_score

In [2]:
TASK = 'atis'

In [3]:
def seed_everything(seed:int = 1004):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = True  # type: ignore

seed_everything(1234)

In [4]:
class IntentClassifier(nn.Module):
    def __init__(self, hidden_size, num_intent_labels, classifier_dropout):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(classifier_dropout)
        self.linear = nn.Linear(hidden_size, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class SlotClassifier(nn.Module):
    def __init__(self, hidden_size, num_slot_labels, classifier_dropout):
        super(SlotClassifier, self).__init__()
        self.dropout = nn.Dropout(classifier_dropout)
        self.linear = nn.Linear(hidden_size, num_slot_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
    
    
class JointBERT(BertPreTrainedModel):
    def __init__(self, config, intent_labels, slot_labels):
        super().__init__(config)
        self.num_intent_labels = len(intent_labels)
        self.num_slot_labels = len(slot_labels)
        self.config = config

        self.bert = BertModel(config)

        classifier_dropout = config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, classifier_dropout)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, classifier_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        token_type_ids = None,
        position_ids = None,
        head_mask = None,
        inputs_embeds = None,
        intent_label_ids = None,
        slot_label_ids = None,
        output_attentions = None,
        output_hidden_states = None,
        # return_dict = None
        ):
        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # return_dict=return_dict,
        )   # sequence_output, pooled_output, (hidden_states), (attentions)

        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.squeeze(), intent_label_ids.squeeze())
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
            total_loss += intent_loss

        # 2. Slot Softmax
        if slot_label_ids is not None:
            loss_fct = nn.CrossEntropyLoss()
            slot_loss = loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_label_ids.view(-1))
            total_loss += slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
    

In [5]:
class LoadDataset:
    def __init__(self, data):
        self.data = data
        
    @classmethod
    def load_dataset(cls, file_name, slot = False):
        data = []
        with open(file_name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if slot:
                    line = line.split()
                data.append(line)
        
        return cls(data)
            

    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [32]:
seq_train = LoadDataset.load_dataset(f'./data/{TASK}/train/seq.in')
seq_dev = LoadDataset.load_dataset(f'./data/{TASK}/dev/seq.in')
seq_test = LoadDataset.load_dataset(f'./data/{TASK}/test/seq.in')

intent_train = LoadDataset.load_dataset(f'./data/{TASK}/train/label')
intent_dev = LoadDataset.load_dataset(f'./data/{TASK}/dev/label')
intent_test = LoadDataset.load_dataset(f'./data/{TASK}/test/label')
intent_labels = LoadDataset.load_dataset(f'./data/{TASK}/intent_label_vocab')

slot_train = LoadDataset.load_dataset(f'./data/{TASK}/train/seq.out', slot = True)
slot_dev = LoadDataset.load_dataset(f'./data/{TASK}/dev/seq.out', slot = True)
slot_test = LoadDataset.load_dataset(f'./data/{TASK}/test/seq.out', slot = True)
slot_labels = LoadDataset.load_dataset(f'./data/{TASK}/slot_label_vocab')

intent_word2idx = defaultdict(int, {k: v for v, k in enumerate(intent_labels)})
intent_idx2word = {v: k for v, k in enumerate(intent_labels)}

slot_word2idx = defaultdict(int, {k: v for v, k in enumerate(slot_labels)})
slot_idx2word = {v: k for v, k in enumerate(slot_labels)}

In [7]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [8]:
model_config = BertConfig.from_pretrained("bert-base-uncased", num_labels = len(intent_idx2word), problem_type = "single_label_classification", id2label = intent_idx2word, label2id = intent_word2idx)
# model_config.classifier_dropout

In [9]:
model = JointBERT.from_pretrained("bert-base-uncased", config = model_config, intent_labels = intent_labels, slot_labels = slot_labels)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device);

Some weights of the model checkpoint at bert-base-uncased were not used when initializing JointBERT: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing JointBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing JointBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of JointBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['slot_classifier.linear.bias', 'i

In [10]:
class TokenizeDataset:
    def __init__(self, seqs, intent_labels, slot_labels, intent_word2idx, slot_word2idx, tokenizer):
        self.seqs = seqs
        self.intent_labels = intent_labels
        self.slot_labels = slot_labels
        
        self.intent_word2idx = intent_word2idx
        self.slot_word2idx = slot_word2idx
        
        self.tokenizer = tokenizer
    
    def align_label(self, seq, intent_label, slot_label):
        tokens = self.tokenizer(seq, padding='max_length', max_length=50, truncation=True)
        token_idxs = tokens.word_ids()
        
        pre_word_idx = None
        slot_label_ids = []
        for word_idx in token_idxs:
            if word_idx != pre_word_idx:
                try:
                    slot_label_ids.append(slot_word2idx[slot_label[word_idx]])
                except:
                    slot_label_ids.append(-100)

            elif word_idx == pre_word_idx or word_idx is None:
                slot_label_ids.append(-100)

            pre_word_idx = word_idx
        
        tokens['intent_label_ids'] = [intent_word2idx[intent_label]]
        tokens['slot_label_ids'] = slot_label_ids
        
        return tokens

    def __getitem__(self, index):
        bert_input = self.align_label(self.seqs[index], self.intent_labels[index], self.slot_labels[index])
        return bert_input
    
    def __len__(self):
        return len(self.seqs)

In [11]:
train_dataset = TokenizeDataset(seq_train, intent_train, slot_train, intent_word2idx, slot_word2idx, tokenizer)
dev_dataset = TokenizeDataset(seq_dev, intent_dev, slot_dev, intent_word2idx, slot_word2idx, tokenizer)
test_dataset = TokenizeDataset(seq_test, intent_test, slot_test, intent_word2idx, slot_word2idx, tokenizer)
print(train_dataset[0])

{'input_ids': [101, 1045, 2215, 2000, 4875, 2013, 6222, 2000, 5759, 2461, 4440, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'intent_label_ids': [13], 'slot_label_ids': [-100, 2, 2, 2, 2, 2, 73, 2, 115, 99, 100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]}


In [12]:
arguments = TrainingArguments(
    output_dir='checkpoints',
    do_train=True,
    do_eval=True,

    num_train_epochs=30,
    learning_rate = 5e-5,

    save_strategy="epoch",
    save_total_limit=2,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    
    report_to = 'none',

    per_device_train_batch_size=128,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    dataloader_num_workers=0,
    fp16=True,

)

trainer = Trainer(
    model,
    arguments,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset
)

Using cuda_amp half precision backend


In [13]:
gc.collect()
torch.cuda.empty_cache()
trainer.train()
model.save_pretrained(f"checkpoints/first_checkpoint")

***** Running training *****
  Num examples = 4478
  Num Epochs = 30
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 1050


Epoch,Training Loss,Validation Loss
1,No log,1.694382
2,No log,0.85117
3,No log,0.534896
4,No log,0.420869
5,No log,0.353123
6,No log,0.309167
7,No log,0.296346
8,No log,0.29057
9,No log,0.25921
10,No log,0.279477


***** Running Evaluation *****
  Num examples = 500
  Batch size = 32
Saving model checkpoint to checkpoints/checkpoint-35
Configuration saved in checkpoints/checkpoint-35/config.json
Model weights saved in checkpoints/checkpoint-35/pytorch_model.bin
Deleting older checkpoint [checkpoints/checkpoint-515] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 500
  Batch size = 32
Saving model checkpoint to checkpoints/checkpoint-70
Configuration saved in checkpoints/checkpoint-70/config.json
Model weights saved in checkpoints/checkpoint-70/pytorch_model.bin
Deleting older checkpoint [checkpoints/checkpoint-3090] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 500
  Batch size = 32
Saving model checkpoint to checkpoints/checkpoint-105
Configuration saved in checkpoints/checkpoint-105/config.json
Model weights saved in checkpoints/checkpoint-105/pytorch_model.bin
Deleting older checkpoint [checkpoints/checkpoint-35] due to args.save_tota

In [14]:
intent_label_ids = []
slot_label_ids = []

with open(f'./data/{TASK}/test/label', 'r', encoding='utf-8') as intent_f, \
    open(f'./data/{TASK}/test/seq.out', 'r', encoding='utf-8') as slot_f:
    for line in intent_f:
        line = line.strip()
        intent_label_ids.append(line)
    for line in slot_f:
        line = line.strip().split()
        slot_label_ids.append(line)

intent_label_ids = np.array(intent_label_ids)
slot_label_ids = np.array(slot_label_ids)

  slot_label_ids = np.array(slot_label_ids)


In [35]:
def predict(model, seqs):
    model.to('cpu')
    pred_intent_ids = []
    pred_slot_ids = []

    for i in range(len(seqs)):
        input_seq = tokenizer(seq_test[i], return_tensors='pt')
        
        model.eval()
        with torch.no_grad():
            _, (intent_logits, slot_logits) = model(**input_seq)

        # Intent
        pred_intent_ids.append(intent_idx2word[intent_logits[0].argmax().item()])

        # Slot
        slot_logits_size = slot_logits[0].shape[0]
        slot_logits_mask = np.array(test_dataset[i]['slot_label_ids'][:slot_logits_size]) != -100
        slot_logits_clean = slot_logits[0][slot_logits_mask]
        pred_slot_ids.append([slot_idx2word[i.item()] for i in slot_logits_clean.argmax(dim=1)])

    return np.array(pred_intent_ids), np.array(pred_slot_ids)

In [36]:
# last_model = JointBERT.from_pretrained("./checkpoints/checkpoint-3090", config = model_config, intent_labels = intent_labels, slot_labels = slot_labels)

In [37]:
pred_intent_ids, pred_slot_ids = predict(model, seq_test)

  return np.array(pred_intent_ids), np.array(pred_slot_ids)


In [38]:
def get_intent_acc(preds, labels):
    acc = (preds == labels).mean()
    return {
        "intent_acc": acc
    }

def get_slot_metrics(preds, labels):
    assert len(preds) == len(labels)
    return {
        "slot_precision": precision_score(labels, preds),
        "slot_recall": recall_score(labels, preds),
        "slot_f1": f1_score(labels, preds)
    }

def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
    """For the cases that intent and all the slots are correct (in one sentence)"""
    # Get the intent comparison result
    intent_result = (intent_preds == intent_labels)

    # Get the slot comparision result
    slot_result = []
    for preds, labels in zip(slot_preds, slot_labels):
        assert len(preds) == len(labels)
        one_sent_result = True
        for p, l in zip(preds, labels):
            if p != l:
                one_sent_result = False
                break
        slot_result.append(one_sent_result)
    slot_result = np.array(slot_result)

    sementic_acc = np.multiply(intent_result, slot_result).mean()
    return {
        "sementic_frame_acc": sementic_acc
    }

def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
    assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
    
    results = {}
    intent_result = get_intent_acc(intent_preds, intent_labels)
    print(intent_result)
    slot_result = get_slot_metrics(slot_preds, slot_labels)
    print(slot_result)
    sementic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)
    print(sementic_result)

    results.update(intent_result)
    results.update(slot_result)
    results.update(sementic_result)

    return results

In [39]:
res = compute_metrics(pred_intent_ids, intent_label_ids, pred_slot_ids, slot_label_ids)

{'intent_acc': 0.9764837625979843}
{'slot_precision': 0.9302977232924694, 'slot_recall': 0.9362002114910116, 'slot_f1': 0.9332396345748418}
{'sementic_frame_acc': 0.8387458006718925}
