# Imports

In [1]:
import sys
sys.path.append('../')

In [2]:
import os

In [3]:
from tqdm import tqdm_notebook as tqdm
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from transformers.tokenization_bert import BertTokenizer
from transformers.modeling_bert import BertForTokenClassification, BertConfig, BertModel
from transformers import AdamW, get_linear_schedule_with_warmup

In [4]:
from mlpack.datasets.conll2003 import get_conll2003, get_conll2003_features, convert_examples_to_features_masked
from mlpack.datasets.conll2003 import CoNLL2003Dataset, InputFeatures
from mlpack.bert.ner.model import BertForMaskedNERClassification, BertForNERClassification
from mlpack.bert.ner.train import train
from mlpack.bert.ner.utils import to_fp16, to_device
from mlpack.utils import save_pickle, read_pickle

# Tokenizer

In [5]:
# Cased BERT tokenizer, fetched by model name.
# Alternative kept for reference: build it from a local vocab file instead.
# tokenizer = BertTokenizer('../bert-base-cased/vocab.txt', do_lower_case=False)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# Data

In [6]:
def get_all_words_set_from_examples(examples):
    """Return the distinct words that occur in ``examples``.

    Each example's ``text_a`` is split on single spaces, so consecutive
    spaces produce an empty-string "word" exactly as ``str.split(' ')`` does.

    Args:
        examples: iterable of objects exposing a ``text_a`` string attribute.

    Returns:
        A list containing every distinct word once. The order is arbitrary
        (it comes from iterating a set).
    """
    # Build the set in a single pass; concatenating per-example lists with
    # sum(..., []) copies the accumulated list on every step (quadratic).
    return list({
        word
        for ex in examples
        for word in ex.text_a.split(' ')
    })

In [7]:
# Load the CoNLL-2003 examples and the tag list from the local dataset directory.
# `examples` is indexed by split name below ('train', 'valid').
examples, labels = get_conll2003('../datasets/CoNLL2003/')

In [8]:
# Distinct words of each split; used below to measure how much of the
# validation vocabulary also appears in the training data.
train_words = get_all_words_set_from_examples(examples['train'])
valid_words = get_all_words_set_from_examples(examples['valid'])

In [9]:
# (validation vocabulary size, number of validation words also seen in training).
# valid_words holds distinct words, so the set intersection has the same size as
# filtering the list, without scanning the train_words list once per word.
len(valid_words), len(set(valid_words) & set(train_words))

(9966, 6706)

In [10]:
# Integer label id (starting at 0) -> tag string, in the order given by `labels`.
label_map = dict(enumerate(labels))
label_map

{0: '[PAD]',
 1: 'O',
 2: 'B-MISC',
 3: 'I-MISC',
 4: 'B-PER',
 5: 'I-PER',
 6: 'B-ORG',
 7: 'I-ORG',
 8: 'B-LOC',
 9: 'I-LOC',
 10: '[CLS]',
 11: '[SEP]',
 12: 'X'}

In [11]:
# Attribute names copied from each feature object when caching to disk; the same
# names are passed back as keyword arguments to InputFeatures when loading.
feat_attrs = ['input_ids', 'input_mask', 'segment_ids', 'label_id', 'valid_ids', 'label_mask', 'masked_word']
def features_to_pickle(features, filename):
    """Save ``features`` to ``filename`` as a list of plain dicts.

    Args:
        features: iterable of feature objects exposing every attribute named
            in ``feat_attrs``.
        filename: destination path handed to ``save_pickle``.
    """
    feats = [
        # getattr is the idiomatic spelling of attribute lookup by name
        {attr: getattr(feat, attr) for attr in feat_attrs}
        for feat in features
    ]
    save_pickle(feats, filename)
    
def pickle_to_features(filename):
    """Load ``filename`` with ``read_pickle`` and build one InputFeatures per entry.

    Each loaded entry is expanded as keyword arguments to ``InputFeatures``.

    NOTE(review): presumably backed by the ``pickle`` module — only load cache
    files you produced yourself.
    """
    records = read_pickle(filename)
    return [InputFeatures(**record) for record in records]

In [12]:
train_pickle = 'train_feats.pickle'
valid_pickle = 'valid_feats.pickle'
redo = False  # set to True to ignore existing pickles and rebuild both feature sets


def _load_or_build_features(split, pickle_path, rebuild=False, max_seq_length=128):
    """Return the features for ``examples[split]``, using ``pickle_path`` as a cache.

    Args:
        split: key into the notebook-level ``examples`` mapping ('train' / 'valid').
        pickle_path: cache file; read when it exists and ``rebuild`` is false,
            otherwise (re)written after the features are computed.
        rebuild: when true, recompute even if the cache file exists.
        max_seq_length: third positional argument passed to
            ``convert_examples_to_features_masked`` (presumably the maximum
            sequence length — confirm against mlpack.datasets.conll2003).

    Returns:
        The list of features, either loaded from the cache or freshly built.
    """
    if os.path.exists(pickle_path) and not rebuild:
        return pickle_to_features(pickle_path)
    features = convert_examples_to_features_masked(examples[split], labels, max_seq_length, tokenizer)
    features_to_pickle(features, pickle_path)
    return features


# NOTE(review): a later cell rebuilds both feature sets with do_mask=True,
# replacing whatever is loaded or built here.
features_train = _load_or_build_features('train', train_pickle, rebuild=redo)
features_valid = _load_or_build_features('valid', valid_pickle, rebuild=redo)
    

In [11]:
# Build the features for both splits with do_mask=True.
# NOTE(review): this unconditionally overwrites features_train / features_valid,
# including any values loaded from the pickle cache cell above, and is not cached itself.
features_train = convert_examples_to_features_masked(examples['train'], labels, 128, tokenizer, do_mask=True)
features_valid = convert_examples_to_features_masked(examples['valid'], labels, 128, tokenizer, do_mask=True)

In [12]:
# valid features with masked word on train
onmask_feats = [
    feat for feat in features_valid if feat.masked_word in train_words
]

offmask_feats = [
    feat for feat in features_valid if feat.masked_word not in train_words
]

In [14]:
# Sizes of the two validation subsets: masked word seen in training vs. unseen.
len(onmask_feats), len(offmask_feats)

(47038, 4285)

# Checking

In [15]:
# Pick one validation example/feature pair to inspect by eye.
# NOTE(review): assumes examples['valid'] and features_valid are index-aligned —
# TODO confirm; the masked conversion may emit more than one feature per example.
idx = 2
ex, feat = examples['valid'][idx], features_valid[idx]

In [16]:
# Print every token of the selected feature next to its label_mask and
# input_mask values (columns: token, label_mask, input_mask).
zipped = zip(tokenizer.convert_ids_to_tokens(feat.input_ids), feat.label_mask,
            feat.input_mask)
for tok, lm, im in zipped:
    print(f'{tok:10} {lm} {im}')

[CLS]      0 1
CR         0 1
##IC       0 1
##KE       0 1
##T        0 1
-          0 1
[MASK]     1 1
T          0 1
##A        0 1
##KE       0 1
O          0 1
##VE       0 1
##R        0 1
AT         0 1
TO         0 1
##P        0 1
A          0 1
##FT       0 1
##ER       0 1
IN         0 1
##NI       0 1
##NG       0 1
##S        0 1
VI         0 1
##CT       0 1
##OR       0 1
##Y        0 1
.          0 1
[SEP]      0 1
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]      0 0
[PAD]     

# Dataset

In [17]:
# Wrap the feature lists as datasets: full train, full valid, and the valid
# subset whose masked word does not occur in train_words.
ds_train = CoNLL2003Dataset(features_train)
ds_valid = CoNLL2003Dataset(features_valid)
ds_offvalid = CoNLL2003Dataset(offmask_feats)

In [18]:
# Batches of 32 with pinned memory. Only the training loader shuffles;
# the off-vocabulary validation loader uses 4 workers instead of 8.
dl_train = DataLoader(ds_train, batch_size=32, pin_memory=True, shuffle=True,  num_workers=8)
dl_valid = DataLoader(ds_valid, batch_size=32, pin_memory=True, shuffle=False, num_workers=8)
dl_offvalid = DataLoader(ds_offvalid, batch_size=32, pin_memory=True, shuffle=False, num_workers=4)

# Evaluating

In [19]:
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score
from seqeval.metrics import classification_report

In [20]:
# NOTE(review): on a fresh kernel this cell raises NameError — LABELS is only
# assigned in a later cell (under "Bert Model"). Run that cell first, or move
# the definition above this point.
LABELS

NameError: name 'LABELS' is not defined

In [28]:
def evaluate_fn(model, dataloader, return_conf=False):
    """Run ``model`` over ``dataloader`` in eval mode and print tag-level metrics.

    Depends on the notebook-level names ``device`` and ``LABELS``; both must be
    defined before this function is called.

    Args:
        model: called as ``model(input_ids, input_mask, label_ids, label_mask)``
            and expected to return ``(loss, active_logits, active_labels)``.
        dataloader: yields ``(input_ids, input_mask, label_ids, label_mask)``
            batches.
        return_conf: when true, return the confusion matrix instead of the
            loss/accuracy pair.

    Returns:
        The confusion matrix (rows and columns ordered as ``LABELS``) when
        ``return_conf`` is true. Otherwise the confusion matrix is printed and
        ``(mean of the per-batch losses, mean accuracy over all scored
        positions)`` is returned.

    Side effects:
        Prints the micro-averaged F1 (sklearn, over the flat tag sequence) and
        the seqeval classification report. The model is left in eval mode.
    """
    model.eval()
    batch_losses = []
    hits = []
    gold_ids, pred_ids = [], []

    progress = tqdm(dataloader, desc='Evaluating', leave=False)
    for ids, attn_mask, tag_ids, tag_mask in progress:
        ids, attn_mask, tag_ids, tag_mask = to_device(
            ids, attn_mask, tag_ids, tag_mask, device=device)

        # Inference only: no gradients needed for the forward pass.
        with torch.no_grad():
            loss, logits, gold = model(ids, attn_mask, tag_ids, tag_mask)

        batch_losses.append(loss.item())

        # Predicted class per scored position, moved to NumPy for sklearn.
        batch_pred = logits.argmax(dim=1).cpu().numpy()
        batch_gold = gold.cpu().numpy()

        gold_ids.extend(batch_gold.tolist())
        pred_ids.extend(batch_pred.tolist())
        hits.extend((1 * (batch_pred == batch_gold)).tolist())

    # Map class indices to tag strings for the reports.
    trues = [LABELS[i] for i in gold_ids]
    preds = [LABELS[i] for i in pred_ids]

    conf = confusion_matrix(trues, preds, labels=LABELS)
    f1 = f1_score(trues, preds, average='micro')
    print(f'F1 = {f1}')
    print(classification_report(trues, preds))

    if return_conf:
        return conf

    print(conf)
    return np.array(batch_losses).mean(), np.array(hits).mean()

# Bert Model

Let's do this with the X tag for training and evaluation

In [21]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Uncomment to force CPU execution:
# device = 'cpu'
device

'cuda'

In [22]:
# Tags that are actually scored: every entry of `labels` except the special
# padding / [CLS] / [SEP] / 'X' markers, in their original order.
LABELS = [tag for tag in labels if tag not in ['[PAD]', '[CLS]', '[SEP]', 'X']]
LABELS, len(LABELS)

(['O',
  'B-MISC',
  'I-MISC',
  'B-PER',
  'I-PER',
  'B-ORG',
  'I-ORG',
  'B-LOC',
  'I-LOC'],
 9)

In [23]:
# bert-base-cased configuration with one output class per entry of LABELS;
# output_hidden_states=True is set on the config as well.
config = BertConfig.from_pretrained('bert-base-cased', num_labels=len(LABELS), output_hidden_states=True)

In [24]:
# Built from the config only (no from_pretrained call here); weights are loaded
# from checkpoint files in later cells.
# NOTE(review): weight_O is presumably the loss weight given to the 'O' class —
# confirm in mlpack.bert.ner.model.
model = BertForMaskedNERClassification(config, weight_O=0.1)

In [25]:
# Move the model to the selected device; the returned module's repr is the cell output.
model.to(device)

BertForMaskedNERClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [26]:
# AdamW over all model parameters, learning rate 2e-6 and weight decay 2e-5.
optimizer = AdamW(model.parameters(), lr=2e-6, weight_decay=2e-5)

In [26]:
# Load previously trained weights into the model. map_location remaps the saved
# tensors onto the active device, so a checkpoint written on a GPU can still be
# loaded when the notebook falls back to the CPU.
# model.load_state_dict(torch.load('bertner_lastlayer.ckp'), strict=False)
model.load_state_dict(torch.load('bert_masked_2.ckp', map_location=device), strict=False)

<All keys matched successfully>

In [29]:
# Evaluate on the validation features whose masked word does not occur in train_words.
evaluate_fn(model, dl_offvalid)

HBox(children=(IntProgress(value=0, description='Evaluating', max=134, style=ProgressStyle(description_width='…

F1 = 0.6998833138856476
           precision    recall  f1-score   support

      PER       0.59      0.48      0.53       664
      ORG       0.29      0.27      0.28       272
     MISC       0.19      0.16      0.17       139
      LOC       0.39      0.44      0.41       210

micro avg       0.44      0.39      0.42      1285
macro avg       0.45      0.39      0.42      1285

[[2233   51   11   68   22   64    9   80    2]
 [  62   18    0   10    0   16    0    7    0]
 [  25    3    9    5    4    2    5    2    1]
 [ 180    5    2  232    3   21    1   25    0]
 [ 214    0    2    6  306    1   19    4    4]
 [ 121    8    0   18    0   70    0   15    0]
 [  40    1    3    1    9    7   36    1    1]
 [  79    6    0    8    0   17    0   91    0]
 [  14    0    0    0    0    0    0    1    4]]


(3.423512717820148, 0.6998833138856476)

# Training

In [33]:
class Args:
    """Plain settings holder handed to ``train``.

    All values are class attributes, evaluated once when this cell runs — in
    particular the SummaryWriter is created at class-definition time.
    Meanings marked "presumably" are consumed inside ``mlpack.bert.ner.train``,
    which is not visible here; confirm them there.
    """
    # Device string chosen earlier in the notebook ('cuda' or 'cpu').
    device = device
    # Checked below to decide whether model/optimizer are wrapped with to_fp16.
    fp16 = True
    # Number of epochs; also used to compute t_total.
    num_epochs = 20
    # Presumably the path where train() writes the best checkpoint.
    ckp_path = 'bert_masked.ckp'
    # Divisor of the batches-per-epoch count in t_total; presumably the number
    # of gradient-accumulation steps.
    grad_steps = 1
    # Presumably the gradient-clipping norm.
    max_grad_norm = 1.
    # Only referenced from commented-out code in the checkpoint-loading cell.
    load_state_dict = False
    # Presumably a running iteration counter maintained by train().
    n_iter = 0
    # Presumably the accuracy a validation run must beat before a checkpoint is saved.
    best_acc = 0.8
    # TensorBoard writer logging under 'bert_masked'.
    writer = SummaryWriter('bert_masked')
    # Presumably the starting epoch index.
    epoch = 0
args = Args()

In [34]:
# Total number of scheduler steps: (batches per epoch // grad_steps) * epochs.
t_total = len(dl_train) // args.grad_steps * args.num_epochs

In [35]:
# Linear warmup for 5 * len(dl_train) steps, then linear decay until t_total.
# NOTE(review): unlike t_total, the warmup count is not divided by
# args.grad_steps; the two agree only while grad_steps == 1 — confirm intent.
scheduler = get_linear_schedule_with_warmup(optimizer, 5*len(dl_train), t_total)

In [36]:
# Wrap model and optimizer for mixed-precision training when fp16 is enabled
# (the log printed below reports optimization level O1).
# NOTE(review): the guard on args.n_iter is commented out, so re-running this
# cell wraps the already-wrapped objects again — confirm that is safe.
if args.fp16:# and args.n_iter == 0:
    model, optimizer = to_fp16(model, optimizer)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [37]:
# NOTE(review): this load runs unconditionally — the args.load_state_dict guard
# is commented out, so that flag currently has no effect.
# map_location remaps the saved tensors onto the active device so that a
# checkpoint written on a GPU still loads on a CPU-only machine.
# if args.load_state_dict:
#     if os.path.exists('bertner_lastlayer.ckp'):
print(model.load_state_dict(torch.load('bertner_lastlayer.ckp', map_location=device), strict=False))
#     if os.path.exists(args.ckp_path.replace('.ckp', '_optimizer.ckp')):
#         optimizer.load_state_dict(torch.load(args.ckp_path.replace('.ckp', '_optimizer.ckp'), map_location='cpu'))

_IncompatibleKeys(missing_keys=['loss_fct.weight'], unexpected_keys=[])


In [38]:
# Run training with the settings in args, passing dl_valid and evaluate_fn for
# validation and the warmup scheduler for the learning rate.
train(args, model, dl_train, dl_valid, optimizer, evaluate_fn=evaluate_fn, scheduler=scheduler)

HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[33457  1005   210  1585   911  1429  1268  2626   242]
 [  524   223     3    32     1    49     1    89     0]
 [  161    15    96     1    17     1    43     2    10]
 [  525    12     1  1105    36    60     5    91     1]
 [  238     0     1    28   986     0    28    11     8]
 [  403    32     0    77     0   717    14    97     1]
 [  175     4    14     2    15    21   506     7     7]
 [  574    46     1    43     3    75     7  1088     0]
 [   90     1     7     1    10     4    34     5   105]]
---Valid
Loss 1.4378341644667925
Acc 0.745922880579857


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[33778   759   293  1945   750  1535  1199  2230   244]
 [  505   233     3    45     2    67     2    65     0]
 [  135    11   133     2    13     2    39     2     9]
 [  397     6     1  1260    33    73     7    59     0]
 [  206     0     3    33  1023     1    21     6     7]
 [  340    28     0    86     0   790    13    84     0]
 [  167     2    17     2    12    26   509     7     9]
 [  475    48     1    80     5   108     7  1113     0]
 [   75     0    10     0     5     0    35     3   129]]
---Valid
Loss 1.1514429020279662
Acc 0.7592697231260839


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[36318   401   400  1263   610   910  1023  1613   195]
 [  581   208     4    25     3    41     2    58     0]
 [  119     6   154     2    14     2    38     1    10]
 [  466     5     2  1217    28    47     6    64     1]
 [  188     0     5    33  1037     2    22     4     9]
 [  385    23     1    79     0   753     9    91     0]
 [  156     1    22     0    17    22   512     7    14]
 [  528    32     2    54     2    80     4  1135     0]
 [   65     0    13     0     6     0    32     1   140]]
---Valid
Loss 1.0619013235332941
Acc 0.8080977339594334
Saved new checkpoint


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[36592   519   280  1760   719   652   640  1386   185]
 [  526   251     3    39     5    34     1    63     0]
 [  122     8   148     3    18     1    27     4    15]
 [  358     6     1  1356    40    21     2    50     2]
 [  153     0     1    29  1094     1     7     2    13]
 [  353    30     1   118     1   738     7    93     0]
 [  179     0    19     3    36    17   470     8    19]
 [  479    39     2    99     4    57     3  1154     0]
 [   61     0     7     0    11     0    15     3   160]]
---Valid
Loss 1.0169096018496595
Acc 0.8176256259376887
Saved new checkpoint


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[37250   495   452  1244   651   702   760  1018   161]
 [  532   268     5    27     4    35     3    46     2]
 [   94     7   186     1    18     1    28     0    11]
 [  386     6     5  1310    39    37     5    48     0]
 [  135     0     4    26  1101     2    20     2    10]
 [  367    36     2    87     1   760    10    76     2]
 [  134     0    27     3    33    17   514     7    16]
 [  488    40     1    71     5    90     5  1135     2]
 [   47     0    13     2     9     2    25     0   159]]
---Valid
Loss 1.0088685673674305
Acc 0.8316544239424819
Saved new checkpoint


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[36984   701   204  1710   528   703   648  1130   125]
 [  471   319     3    41     2    35     2    48     1]
 [  112    12   146     6    17     1    29     3    20]
 [  336    10     1  1389    27    31     2    39     1]
 [  138     0     1    40  1096     1     9     3    12]
 [  310    46     1   122     1   788     7    65     1]
 [  136     1    16     5    36    21   510     8    18]
 [  439    59     0   101     3    89     1  1144     1]
 [   47     0     6     3    10     0    17     2   172]]
---Valid
Loss 1.0045381780728893
Acc 0.8290240243165832


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


HBox(children=(IntProgress(value=0, description='Evaluating', max=1604, style=ProgressStyle(description_width=…

[[36534   445   246  2013   462   702   524  1710    97]
 [  460   285     4    45     5    36     3    82     2]
 [  107     5   171     5    17     1    28     1    11]
 [  298     4     1  1419    18    32     6    57     1]
 [  149     0     1    33  1087     7    12     2     9]
 [  283    30     1   132     1   783     7   102     2]
 [  146     1    16     5    31    19   510     7    16]
 [  367    47     0   100     3    67     1  1250     2]
 [   47     0     7     4    10     0    20     2   167]]
---Valid
Loss 1.0437698141682459
Acc 0.8223603452643065


HBox(children=(IntProgress(value=0, max=6363), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


KeyboardInterrupt: 