# Imports

In [1]:
import sys
sys.path.append('../')

In [2]:
import os

In [3]:
from tqdm import tqdm_notebook as tqdm
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from transformers.tokenization_bert import BertTokenizer
from transformers.modeling_bert import BertForTokenClassification, BertConfig, BertModel

In [4]:
from mlpack.datasets.conll2003 import get_conll2003, get_conll2003_features, convert_examples_to_features_masked
from mlpack.datasets.conll2003 import convert_examples_to_features
from mlpack.datasets.conll2003 import CoNLL2003Dataset, InputFeatures, InputExample
from mlpack.bert.ner.model import BertForMaskedNERClassification, BertForNERClassification
from mlpack.bert.ner.train import train
from mlpack.bert.ner.utils import to_fp16, to_device
from mlpack.utils import save_pickle, read_pickle

# Tokenizer

In [5]:
# Cased WordPiece vocabulary; keep in sync with the checkpoint passed to BertConfig.
PRETRAINED_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)

# Data

In [6]:
# Sentence-level CoNLL-2003 examples per split, plus the full tag list.
conll_dir = '../datasets/CoNLL2003/'
examples, labels = get_conll2003(conll_dir)

In [7]:
def words_and_labels(examples):
    """Flatten sentence-level examples into one ``InputExample`` per word.

    Each example's ``text_a`` is split on single spaces and paired positionally
    with the entries of its ``label`` sequence. Every word becomes its own
    ``InputExample`` with ``guid=None`` and a one-element label list, in the
    original order.

    NOTE(review): each word is classified without its sentence context, so the
    model cannot tell B-* from I-* for the same surface word. The validation
    confusion matrix in this notebook shows no I-* predictions at all.

    Args:
        examples: iterable of sentence-level examples exposing ``text_a``
            (space-separated words) and ``label`` (one tag per word).

    Returns:
        list of word-level ``InputExample`` objects.

    Raises:
        ValueError: if an example has a different number of words and labels.
            Pairing is done per sentence: zipping one global word list against
            one global label list would let a single mismatch silently shift
            every later label onto the wrong word.
    """
    word_examples = []
    for sent_idx, ex in enumerate(examples):
        words = ex.text_a.split(' ')
        tags = list(ex.label)
        if len(words) != len(tags):
            raise ValueError(
                f'example {sent_idx}: {len(words)} words but {len(tags)} labels'
            )
        word_examples.extend(
            InputExample(guid=None, text_a=word, label=[tag])
            for word, tag in zip(words, tags)
        )
    return word_examples

In [8]:
# Word-level training examples (one InputExample per word).
examples_train = words_and_labels(examples=examples['train'])

In [9]:
# Word-level validation examples (one InputExample per word).
examples_valid = words_and_labels(examples=examples['valid'])

In [10]:
# 15 is presumably max_seq_length: the feature check below is padded to 15 positions.
features_train = convert_examples_to_features(
    examples_train, labels, 15, tokenizer
)

In [11]:
# Same settings as the training features (presumably max_seq_length=15).
features_valid = convert_examples_to_features(
    examples_valid, labels, 15, tokenizer
)

# Checking the features: token / mask alignment of one example

In [12]:
idx = 2
# `features_valid` was built from the word-level `examples_valid`, so take the
# example from that list. The sentence-level `examples['valid']` list is not
# aligned with it. (Assumes convert_examples_to_features emits one feature per
# example, in order.)
ex, feat = examples_valid[idx], features_valid[idx]

In [13]:
tokens = tokenizer.convert_ids_to_tokens(feat.input_ids)
# One row per position: token, label_mask flag, input_mask flag.
for token, label_flag, attn_flag in zip(tokens, feat.label_mask, feat.input_mask):
    print(f'{token:10} {label_flag} {attn_flag}')

[CLS]      0 1
L          1 1
##EI       0 1
##CE       0 1
##ST       0 1
##ER       0 1
##S        0 1
##H        0 1
##IR       0 1
##E        0 1
[SEP]      0 1
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0


# Dataset

In [14]:
# Wrap each split's pre-computed features in a torch Dataset (train first, then valid).
ds_train, ds_valid = (CoNLL2003Dataset(split_features)
                      for split_features in (features_train, features_valid))

In [15]:
# Shared loader settings; only the training split is shuffled.
loader_opts = dict(batch_size=32, pin_memory=True, num_workers=4)
dl_train = DataLoader(ds_train, shuffle=True, **loader_opts)
dl_valid = DataLoader(ds_valid, shuffle=False, **loader_opts)

# Evaluating

In [16]:
import numpy as np
from sklearn.metrics import confusion_matrix

In [17]:
def evaluate_fn(model, dataloader, return_conf=False, device=None, label_names=None):
    """Evaluate ``model`` on ``dataloader`` and report token-level metrics.

    Runs the model without gradients over every batch. It collects each
    batch's loss and the argmax prediction for every active position the
    model returns, then builds a confusion matrix over ``label_names``.

    Args:
        model: called as ``model(input_ids, input_mask, label_ids, label_mask)``
            and returning ``(loss, active_logits, active_labels)``.
        dataloader: yields ``(input_ids, input_mask, label_ids, label_mask)``.
        return_conf: if True, return the confusion matrix instead of printing
            it and returning the averages.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, not a notebook global defined in a later
            cell.
        label_names: class names indexed by label id, used as the
            confusion-matrix rows/columns. Defaults to the notebook-global
            ``LABELS``.

    Returns:
        The confusion matrix when ``return_conf`` is True. Otherwise
        ``(mean_batch_loss, token_accuracy)``, where accuracy is averaged over
        all active positions.
    """
    if device is None:
        # Inputs must live where the weights live.
        device = next(model.parameters()).device
    if label_names is None:
        label_names = LABELS

    # Remember the caller's mode. This function is also called from inside the
    # training loop, and must not leave dropout disabled for later training steps.
    was_training = model.training
    model.eval()
    losses, y_trues, y_preds = [], [], []
    try:
        for input_ids, input_mask, label_ids, label_mask in tqdm(dataloader, desc='Evaluating', leave=False):
            input_ids, input_mask, label_ids, label_mask = to_device(
                input_ids, input_mask, label_ids, label_mask, device=device)
            with torch.no_grad():
                loss, active_logits, active_labels = model(
                    input_ids, input_mask, label_ids, label_mask)
            losses.append(loss.item())
            # Predicted label id per active position (argmax along dim 1).
            y_preds.extend(active_logits.argmax(dim=1).cpu().tolist())
            y_trues.extend(active_labels.cpu().tolist())
    finally:
        model.train(was_training)

    conf = confusion_matrix([label_names[y] for y in y_trues],
                            [label_names[y] for y in y_preds],
                            labels=label_names)
    if return_conf:
        return conf
    print(conf)

    accuracy = (np.array(y_trues) == np.array(y_preds)).mean()
    return np.array(losses).mean(), accuracy

# BERT Model

Let's do this with the X tag for training and evaluation.

In [28]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'  # to force CPU runs, assign 'cpu' unconditionally instead

In [19]:
# Classes the model predicts: every tag except padding/special-token tags and 'X'.
_excluded_tags = ('[PAD]', '[CLS]', '[SEP]', 'X')
LABELS = [tag for tag in labels if tag not in _excluded_tags]
LABELS, len(LABELS)

(['O',
  'B-MISC',
  'I-MISC',
  'B-PER',
  'I-PER',
  'B-ORG',
  'I-ORG',
  'B-LOC',
  'I-LOC'],
 9)

In [20]:
# Head sized to LABELS; output_hidden_states requests the encoder's hidden states too.
config = BertConfig.from_pretrained(
    'bert-base-cased',
    num_labels=len(LABELS),
    output_hidden_states=True,
)

In [24]:
# weight_O is presumably the loss weight for the 'O' class (the model has a
# `loss_fct.weight` entry per the load_state_dict output); confirm in mlpack.bert.ner.model.
O_WEIGHT = 0.1
model = BertForNERClassification(config, weight_O=O_WEIGHT)

In [29]:
# Move parameters to the selected device (the cell displays the module repr).
model.to(device=device)

BertForNERClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [33]:
# NOTE(review): in the saved run this cell executed as In [33], i.e. after the
# checkpoint-loading cell below (In [32]). On Restart & Run All it evaluates the
# model before any checkpoint is loaded.
evaluate_fn(model, dataloader=dl_valid)

HBox(children=(IntProgress(value=0, description='Evaluating', max=1606, style=ProgressStyle(description_width=…

[[42145    26     0    88     0   395     0   105     0]
 [  114   541     0    35     0   147     0    85     0]
 [  108   147     0    29     0    48     0    14     0]
 [  300    13     0  1021     0   357     0   151     0]
 [  198    11     0   731     0   203     0   164     0]
 [  109    21     0   149     0   877     0   185     0]
 [  186    21     0    55     0   411     0    78     0]
 [   89    28     0    70     0   307     0  1343     0]
 [   24     1     0    12     0    90     0   130     0]]


(3.336868798707256, 0.8941824695300027)

In [32]:
# Load tensors onto CPU first so a GPU-saved checkpoint also loads on a CPU-only
# machine; load_state_dict then copies them into the model's parameters.
# strict=False because the checkpoint lacks `loss_fct.weight` (see output).
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('bertner_lastlayer.ckp', map_location='cpu'), strict=False)

_IncompatibleKeys(missing_keys=['loss_fct.weight'], unexpected_keys=[])

# Training

In [69]:
# NOTE(review): `scheduler` is not passed to train() in this notebook, and this
# cell ran (In [69]) after the training cell (In [64]); it appears unused.
scheduler = None

In [47]:
class Args:
    # Training configuration handed to mlpack's train(); all fields are class attributes.

    # --- hardware / precision ---
    device = device  # copied from the notebook-level `device`
    fp16 = True

    # --- optimisation ---
    num_epochs = 10
    grad_steps = 1
    max_grad_norm = 1.0

    # --- checkpointing ---
    ckp_path = 'bertner_masked_2.ckp'
    load_state_dict = True
    best_acc = 0.86  # presumably the score a new checkpoint must beat — confirm in train()

    # --- progress counters (presumably advanced by train()) ---
    n_iter = 0
    epoch = 0

    # Created once, when this class body executes.
    writer = SummaryWriter('nerbert_masked_2')


args = Args()

In [62]:
# Wrap model and optimizer for mixed precision (the output below is apex amp's O1 banner).
# NOTE(review): `optimizer` is not created in any visible cell. It must already exist
# in the kernel, so this fails on Restart & Run All. TODO: define it in an earlier cell.
# NOTE(review): the `args.n_iter == 0` guard is commented out, so re-running this
# cell presumably re-initialises amp on an already-wrapped model/optimizer.
if args.fp16:# and args.n_iter == 0:
    model, optimizer = to_fp16(model, optimizer)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [71]:
# Optionally warm-start from the last-layer checkpoint (note: not args.ckp_path).
# map_location='cpu' lets a GPU-saved checkpoint load on any machine, and
# load_state_dict copies the tensors into the model's parameters.
# strict=False tolerates keys missing from the checkpoint (e.g. loss_fct.weight).
# torch.load unpickles the file, so only use trusted checkpoints.
# TODO(review): optimizer state (args.ckp_path with an '_optimizer' suffix) is not restored.
init_ckp = 'bertner_lastlayer.ckp'
if args.load_state_dict and os.path.exists(init_ckp):
    print(model.load_state_dict(torch.load(init_ckp, map_location='cpu'), strict=False))

_IncompatibleKeys(missing_keys=['loss_fct.weight'], unexpected_keys=[])


In [64]:
# Run mlpack's training loop, evaluating on dl_valid with evaluate_fn.
# NOTE(review): `optimizer` is not defined in any visible cell; it must come from
# earlier kernel state.
train(args, model, dl_train, dl_valid, optimizer, evaluate_fn=evaluate_fn)

HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0


KeyboardInterrupt: 