# Imports

In [1]:
import sys
sys.path.append('../')

In [2]:
import os

In [3]:
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import f1_score, confusion_matrix


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from transformers.tokenization_bert import BertTokenizer
from transformers.modeling_bert import BertForTokenClassification, BertConfig, BertModel

In [4]:
from mlpack.datasets.conll2003 import get_conll2003, get_conll2003_features, convert_examples_to_features

# Tokenizer

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# Data

In [6]:
examples, labels = get_conll2003('../datasets/CoNLL2003/')

In [7]:
label_map = {i:l for i, l in enumerate(labels, 0)}
label_map

{0: '[PAD]',
 1: 'O',
 2: 'B-MISC',
 3: 'I-MISC',
 4: 'B-PER',
 5: 'I-PER',
 6: 'B-ORG',
 7: 'I-ORG',
 8: 'B-LOC',
 9: 'I-LOC',
 10: '[CLS]',
 11: '[SEP]',
 12: 'X'}

In [8]:
features_train = convert_examples_to_features(examples['train'], labels, 128, tokenizer, sep_tag='same')

In [9]:
features_valid = convert_examples_to_features(examples['valid'], labels, 128, tokenizer, sep_tag='X')

In [10]:
features_test = convert_examples_to_features(examples['test'], labels, 128, tokenizer, sep_tag='X')

# Checking

In [11]:
idx = 10
ex, feat = examples['train'][idx], features_train[idx]

In [12]:
for token, label_id, mask in zip(tokenizer.convert_ids_to_tokens(feat.input_ids), feat.label_id, feat.label_mask):
    print(f'{token:20} {label_map[label_id]:6} {mask}')

[CLS]                [CLS]  0
Spanish              B-MISC 1
Farm                 O      1
Minister             O      1
Loyola               B-PER  1
de                   I-PER  1
Pa                   I-PER  1
##la                 I-PER  1
##cio                I-PER  1
had                  O      1
earlier              O      1
accused              O      1
Fi                   B-PER  1
##sch                I-PER  1
##ler                I-PER  1
at                   O      1
an                   O      1
EU                   B-ORG  1
farm                 O      1
ministers            O      1
'                    O      1
meeting              O      1
of                   O      1
causing              O      1
un                   O      1
##ju                 O      1
##st                 O      1
##ified              O      1
alarm                O      1
through              O      1
"                    O      1
dangerous            O      1
general              O      1
##isation 

# Dataset

In [13]:
class NERDataset(Dataset):

    def __init__(self, features):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feat = self.features[idx]
        return torch.tensor(feat.input_ids), torch.tensor(feat.input_mask), \
                torch.tensor(feat.label_id), torch.tensor(feat.label_mask)

In [14]:
ds_train = NERDataset(features_train)
ds_valid = NERDataset(features_valid)
ds_test  = NERDataset(features_test)

In [15]:
dl_train = DataLoader(ds_train, batch_size=32, pin_memory=True, shuffle=True,  num_workers=8)
dl_valid = DataLoader(ds_valid, batch_size=32, pin_memory=True, shuffle=False, num_workers=8)
dl_test  = DataLoader(ds_test, batch_size=32, pin_memory=True, shuffle=False, num_workers=8)

In [16]:
# input_ids, input_mask, label_ids, label_mask = next(iter(dl_train))

In [17]:
# input_ids.shape, input_mask.shape, label_ids.shape, label_mask.shape

# Evaluating

In [18]:
def _optimizer_ckp_path(ckp_path):
    fmt = ckp_path.split('/')[-1].split('.')[-1]
    optim_path = ckp_path.replace(f'.{fmt}', f'_optimizer.{fmt}')
    return optim_path


def _scheduler_ckp_path(ckp_path):
    fmt = ckp_path.split('/')[-1].split('.')[-1]
    sched_path = ckp_path.replace(f'.{fmt}', f'_lrscheduler.{fmt}')
    return sched_path


def save_model(model, optimizer, ckp_path, scheduler=None):
    torch.save(model.state_dict(), ckp_path)
    # saving the optimizer
    optim_path = _optimizer_ckp_path(ckp_path)
    torch.save(optimizer.state_dict(), optim_path)
    if scheduler:
        sched_path = _scheduler_ckp_path(ckp_path)
        torch.save(scheduler.state_dict(), optim_path)
    print('Saved new checkpoint', flush=True)

In [19]:
def to_device(*tensors, device='cpu'):
    return [
        t.to(device) for t in tensors
    ]

In [20]:
from seqeval.metrics import classification_report
import numpy as np

In [21]:
def remap(input_ids, input_mask, label_ids, label_mask, active_logits, active_labels):
    start = 0
    preds = []
    trues = []
    for i in range(input_ids.shape[0]):
        in_ids, in_mask, l_ids, l_mask = input_ids[i], input_mask[i], label_ids[i] - 1, label_mask[i]
        in_mask =  in_mask[in_mask == 1][1:-1] # tira cls e sep
        l_mask = l_mask[1:len(in_mask) + 1]

        end = start + len(l_mask[l_mask == 1])
        preds_labels = active_logits[start:end]
        true_labels  = active_labels[start:end]

        pred_ents, true_ents, count, count_true = [], [], 0, 0
        for j, lm in enumerate(l_mask):
            if lm == 1:
                pred_ents.append(LABELS[preds_labels[count].item()])
                count += 1
                #true
                true_ents.append(LABELS[true_labels[count_true].item()])
                count_true += 1
            else:
                pl = LABELS[preds_labels[count-1].item()]
                if pl.startswith('B'):
                    pl = 'I-'+pl.split('-')[-1]
                pred_ents.append(pl)
                #true
                pl = LABELS[true_labels[count_true-1].item()]
                if pl.startswith('B'):
                    pl = 'I-'+pl.split('-')[-1]
                true_ents.append(pl)

        preds.append(pred_ents)
        trues.append(true_ents)

        start = end
    return trues, preds

In [22]:
LABELS = [
    l for l in labels if l not in ['[PAD]', '[CLS]', '[SEP]', 'X']
]
LABELS

['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

In [47]:
def evaluate_fn(model, dataloader):
    model.eval()
    losses, accs = [], []
    y_trues, y_preds = [], []
    for input_ids, input_mask, label_ids, label_mask in tqdm(dataloader, desc='Evaluating', leave=False):
        input_ids, input_mask, label_ids, label_mask = to_device(input_ids, input_mask, label_ids,
                                                                 label_mask, device=device)
        with torch.no_grad():
            loss, active_logits, active_labels = model(
                input_ids, input_mask, label_ids, label_mask)
            
        losses.append(loss.item())
        
        active_logits = active_logits.argmax(dim=1).cpu().numpy()
        active_labels = active_labels.cpu().numpy()
        accs = (1 * (active_logits == active_labels)).tolist()
        
        # transforming
        ts, ps = remap(input_ids, input_mask, label_ids, label_mask, active_logits, active_labels)
        y_preds += ps
        y_trues += ts
        
    f1 = f1_score(sum(y_trues, []), sum(y_preds, []), average='micro')
    print(classification_report(y_trues, y_preds), flush=True)
    print('Confusion Matrix\n', confusion_matrix(sum(y_trues, []), sum(y_preds, []), labels=LABELS), flush=True)
    print('F1 = ', f1, flush=True)
    print('F1 macro = ', f1_score(sum(y_trues, []), sum(y_preds, []), average='macro'))
    
    return np.array(losses).mean(), np.array(accs).mean(), f1

# Bert Model

Lets do this with the X tag for training and evaluation

In [24]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'

In [25]:
class BertForNERClassification(nn.Module):
    
    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained('bert-base-cased', output_hidden_states=True)
        self.bert = BertModel(config)
        
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, len(LABELS))
        
    @staticmethod
    def sum_last_4_layers(sequence_outputs):
        """Sums the last 4 hidden representations of a sequence output of BERT.
        Args:
        -----
        sequence_output: Tuple of tensors of shape (batch, seq_length, hidden_size).
            For BERT base, the Tuple has length 13.

        Returns:
        --------
        summed_layers: Tensor of shape (batch, seq_length, hidden_size)
        """
        last_layers = sequence_outputs[-4:]
        return torch.stack(last_layers, dim=0).sum(dim=0)
        
    @staticmethod
    def last_layer(sequence_outputs):
        """Simply returns the last tensor of a list of tensors, indexing -1."""
        return sequence_outputs[-1]
        
    def forward(self, input_ids, input_mask, label_ids, label_mask):
        
        _, _, hidden_states = self.bert(input_ids, attention_mask=input_mask)
        
        out = self.last_layer(hidden_states)
        
        out = self.dropout(out)
        out = self.classifier(out)
        # take the active logits
        label_mask = label_mask.view(-1)
        active_logits = out.view(-1, len(LABELS))[label_mask == 1]
        
        # take the active labels
        active_labels = label_ids.view(-1)[label_mask == 1] - 1 # remove one because of the [PAD] being the 0
        
        # calc the loss
        loss = nn.CrossEntropyLoss(torch.tensor([.5] + 8 *[1.0]).to(device))(active_logits, active_labels)
        
        return loss, active_logits, active_labels

In [35]:
model = BertForNERClassification()
model.to(device)

BertForNERClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [36]:
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5, weight_decay=0)

# Training

In [37]:
writer = SummaryWriter('nerbert')

In [38]:
scheduler = None

In [45]:
class Args:
    device = device
    fp16 = True
    num_epochs = 10
    ckp_path = 'bertner_lastlayer.ckp'
    grad_steps = 1
    max_grad_norm = 1.
    load_state_dict = True
args = Args()

In [46]:
if args.load_state_dict:
    if os.path.exists(args.ckp_path):
        print(model.load_state_dict(torch.load('bertner_lastlayer.ckp')))
#     if os.path.exists(args.ckp_path.replace('.ckp', '_optimizer.ckp')):
#         optimizer.load_state_dict(torch.load(args.ckp_path.replace('.ckp', '_optimizer.ckp'), map_location='cpu'))

<All keys matched successfully>


In [41]:
n_iter = 0
best_acc = None

In [42]:
if args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    model, optimizer = amp.initialize(model, optimizer)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [49]:
for ep in tqdm(range(30), desc='Epochs'):
    model.train()
    if ep > 0:
        losses_train = []
        for step, (input_ids, input_mask, label_ids, label_mask) in tqdm(enumerate(dl_train), leave=False, total=len(dl_train)):
            input_ids, input_mask, label_ids, label_mask = to_device(input_ids, input_mask, label_ids,
                                                                     label_mask, device=device)

            loss, _, _ = model(input_ids, input_mask, label_ids, label_mask)

            if args.grad_steps > 1:
                loss = loss / args.grad_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)

            if (step + 1) % args.grad_steps == 0 or (step + 1) == len(dl_train):
                optimizer.step()
                model.zero_grad()
                if scheduler:
                    scheduler.step()

            losses_train.append(loss.item())
                    
            if writer:
                writer.add_scalar('loss/train', loss, n_iter)
            n_iter += 1

        # print training loss
        print(f'---Training\nLoss {np.array(losses_train).mean()}')
    # evaluate
    valid_loss, valid_acc, valid_f1 = evaluate_fn(model, dl_valid)

    if writer:
        writer.add_scalar('loss/valid', valid_loss, ep)
        writer.add_scalar('acc/valid', valid_acc, ep)

        for name, param in model.named_parameters():
            writer.add_histogram(name, param, ep)

    print(f'---Valid\nLoss {valid_loss}\nAcc {valid_acc}', flush=True)

    if best_acc is None:
        best_acc = valid_f1
        save_model(model, optimizer, args.ckp_path, scheduler=scheduler)
    else:
        if valid_f1 > best_acc:
            best_acc = valid_f1
            save_model(model, optimizer, args.ckp_path,
                       scheduler=scheduler)

HBox(children=(IntProgress(value=0, description='Epochs', max=30, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Evaluating', max=102, style=ProgressStyle(description_width='…

           precision    recall  f1-score   support

      ORG       0.78      0.74      0.76      1341
      LOC       0.91      0.87      0.89      1837
      PER       0.77      0.80      0.78      1836
     MISC       0.80      0.80      0.80       922

micro avg       0.82      0.81      0.81      5936
macro avg       0.82      0.81      0.81      5936

Confusion Matrix
 [[51280    35    81    43    79    34   101    22    19]
 [   92   766     7    27     1    11     0    18     0]
 [  219    17   786     3    72     3    21     2    26]
 [  185     7     3  1520    37    44     5    35     0]
 [  253     2    36    67  3952     7   126     0    51]
 [  120    36     0    92     3  1026    18    45     1]
 [  233     1    51     7   203    35  1842    15    53]
 [   78    25     2    71     1    42     3  1613     2]
 [  157     0    43     2   185     0    63     5  1724]]
F1 =  0.950170859600542
F1 macro =  0.8461004137154432
---Valid
Loss 0.4117719871468754
Acc 0.96143250688705

HBox(children=(IntProgress(value=0, max=439), HTML(value='')))

KeyboardInterrupt: 