# Imports

In [1]:
import sys
sys.path.append('../')

In [2]:
from tqdm import tqdm_notebook as tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from transformers.tokenization_bert import BertTokenizer
from transformers.modeling_bert import BertForTokenClassification, BertConfig, BertModel

In [3]:
from mlpack.datasets.conll2003 import get_conll2003, get_conll2003_features

# Tokenizer

In [4]:
# Lower-cased WordPiece tokenizer; must match the 'bert-base-uncased' checkpoint used by the model below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Data

In [5]:
# Raw CoNLL-2003 examples (indexed by split, e.g. examples['train']) and the full tag vocabulary.
examples, labels = get_conll2003('../datasets/CoNLL2003/')

In [6]:
label_map = dict(enumerate(labels))  # label id -> tag string
label_map

{0: '[PAD]',
 1: 'O',
 2: 'B-MISC',
 3: 'I-MISC',
 4: 'B-PER',
 5: 'I-PER',
 6: 'B-ORG',
 7: 'I-ORG',
 8: 'B-LOC',
 9: 'I-LOC',
 10: '[CLS]',
 11: '[SEP]',
 12: 'X'}

In [7]:
# Build fixed-length (128) model features; word-piece continuations are tagged 'X' (see the check below).
features = get_conll2003_features(examples, labels, 128, tokenizer, sep_tag='X')

In [8]:
# Features are grouped by dataset split.
features.keys()

dict_keys(['train', 'valid', 'test'])

# Checking

In [9]:
idx = 4  # arbitrary training example to inspect
ex = examples['train'][idx]
feat = features['train'][idx]

In [10]:
# Show each word piece next to its tag and whether it contributes to the loss (label_mask).
tokens = tokenizer.convert_ids_to_tokens(feat.input_ids)
for token, label_id, mask in zip(tokens, feat.label_id, feat.label_mask):
    print(f'{token:20} {label_map[label_id]:6} {mask}')

[CLS]                [CLS]  0
germany              B-LOC  1
'                    O      1
s                    X      0
representative       O      1
to                   O      1
the                  O      1
european             B-ORG  1
union                I-ORG  1
'                    O      1
s                    X      0
veterinary           O      1
committee            O      1
werner               B-PER  1
z                    I-PER  1
##wing               X      0
##mann               X      0
said                 O      1
on                   O      1
wednesday            O      1
consumers            O      1
should               O      1
buy                  O      1
sheep                O      1
##me                 X      0
##at                 X      0
from                 O      1
countries            O      1
other                O      1
than                 O      1
britain              B-LOC  1
until                O      1
the                  O      1
scientific

# Dataset

In [11]:
class NERDataset(Dataset):
    """Map-style dataset over pre-computed NER features.

    Each item is a tuple of four tensors built from one feature, in this order:
    (input_ids, input_mask, label_id, label_mask).
    """

    def __init__(self, features):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feature = self.features[idx]
        fields = (feature.input_ids, feature.input_mask, feature.label_id, feature.label_mask)
        return tuple(torch.tensor(field) for field in fields)

In [12]:
ds_train, ds_valid = NERDataset(features['train']), NERDataset(features['valid'])

In [13]:
BATCH_SIZE = 2
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine it is
# useless and may trigger a warning or error, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, shuffle=True,  num_workers=0)
dl_valid = DataLoader(ds_valid, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, shuffle=False, num_workers=0)

In [14]:
# Peek at a single training batch to sanity-check tensor shapes.
input_ids, input_mask, label_ids, label_mask = next(iter(dl_train))

In [15]:
# All four tensors should share the (batch_size, max_seq_len) shape.
input_ids.shape, input_mask.shape, label_ids.shape, label_mask.shape

(torch.Size([2, 128]),
 torch.Size([2, 128]),
 torch.Size([2, 128]),
 torch.Size([2, 128]))

# Bert Model

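Let's train and evaluate with the `X` tag. `X` is assigned to sub-word continuation pieces, which have `label_mask` 0 and are therefore excluded from the loss.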

In [16]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [17]:
# Tags the classifier predicts: every tag except padding, special-token and sub-word markers.
LABELS = [tag for tag in labels if tag not in ('[PAD]', '[CLS]', '[SEP]', 'X')]
LABELS

['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

In [18]:
class BertForNERClassification(nn.Module):
    """Token-level NER classifier on top of a pretrained BERT encoder.

    The classification head is applied to the sum of BERT's last four hidden
    layers. Only positions with ``label_mask == 1`` contribute to the loss and
    are returned as the "active" logits/labels.

    Args:
        model_name: name or path of the pretrained BERT checkpoint.
        num_labels: number of output classes; defaults to ``len(LABELS)``.
        dropout: dropout probability applied before the classifier.
    """

    def __init__(self, model_name='bert-base-uncased', num_labels=None, dropout=0.1):
        super().__init__()
        if num_labels is None:
            num_labels = len(LABELS)
        # output_hidden_states=True makes BERT also return every layer's hidden
        # states, which sum_last_4_layers needs.
        config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
        # Load the pretrained weights; BertModel(config) alone would be randomly initialised.
        self.bert = BertModel.from_pretrained(model_name, config=config)
        self.hidden_size = config.hidden_size

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    @staticmethod
    def sum_last_4_layers(sequence_outputs):
        """Sums the last 4 hidden representations of a sequence output of BERT.
        Args:
        -----
        sequence_outputs: Tuple of tensors of shape (batch, seq_length, hidden_size).
            For BERT base, the Tuple has length 13.

        Returns:
        --------
        summed_layers: Tensor of shape (batch, seq_length, hidden_size)
        """
        last_layers = sequence_outputs[-4:]
        return torch.stack(last_layers, dim=0).sum(dim=0)

    def forward(self, input_ids, input_mask, label_ids, label_mask):
        """Run BERT + classifier and compute the loss over active positions.

        Args:
            input_ids, input_mask, label_ids, label_mask: tensors of shape
                (batch, seq_length), as produced by NERDataset.

        Returns:
            loss: scalar cross-entropy over active positions.
            active_logits: (num_active, num_labels) logits where label_mask == 1.
            active_labels: (num_active,) target ids, shifted down by one.
        """
        outputs = self.bert(input_ids, attention_mask=input_mask)
        # Index instead of tuple-unpacking so this also works when BERT returns a
        # dict-like output object; with output_hidden_states=True the per-layer
        # hidden states are the third element.
        hidden_states = outputs[2]
        out = self.sum_last_4_layers(hidden_states)
        out = out.reshape(-1, out.size(-1))
        out = self.dropout(out)
        logits = self.classifier(out)

        # Positions that take part in the loss (first word piece of each word).
        active = label_mask.reshape(-1) == 1
        active_logits = logits[active]

        # Shift by one because index 0 of `labels` is '[PAD]' and the classifier only
        # covers LABELS. Assumes LABELS == labels[1:1 + num_labels]; TODO confirm if the
        # tag order ever changes.
        active_labels = label_ids.reshape(-1)[active] - 1

        loss = nn.functional.cross_entropy(active_logits, active_labels)

        return loss, active_logits, active_labels

In [19]:
model = BertForNERClassification().to(device)  # nn.Module.to returns the module itself
model

BertForNERClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

# Evaluating

In [20]:
def to_device(*tensors, device='cpu'):
    """Move every given tensor to `device`; returns them as a list in the same order."""
    moved = []
    for tensor in tensors:
        moved.append(tensor.to(device))
    return moved

In [21]:
def evaluate_fn(model, dataloader, device=None):
    """Compute the mean loss and token-level accuracy of `model` on `dataloader`.

    Args:
        model: module whose forward takes (input_ids, input_mask, label_ids, label_mask)
            and returns (loss, active_logits, active_labels).
        dataloader: iterable of 4-tuples of tensors in that order.
        device: device each batch is moved to; defaults to the device of the
            model's first parameter.

    Returns:
        (mean_loss, accuracy) as floats. The loss is averaged over batches. The
        accuracy is averaged over all active tokens of the whole dataloader.
        Either value is NaN when there are no batches / no active tokens.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    losses = []
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating', leave=False):
            input_ids, input_mask, label_ids, label_mask = [t.to(device) for t in batch]
            loss, active_logits, active_labels = model(
                input_ids, input_mask, label_ids, label_mask)

            losses.append(loss.item())
            # Accumulate counts across batches (reduced on-device, so CUDA tensors
            # never need a .numpy() conversion).
            n_correct += (active_logits.argmax(dim=1) == active_labels).sum().item()
            n_total += active_labels.numel()

    mean_loss = sum(losses) / len(losses) if losses else float('nan')
    accuracy = n_correct / n_total if n_total else float('nan')
    return mean_loss, accuracy

# Training

In [22]:
# TensorBoard event files are written under ./nerbert
writer = SummaryWriter('nerbert')

In [27]:
# Fine-tuning pretrained BERT needs a small learning rate (the BERT paper uses 2e-5..5e-5);
# 5e-4 tends to destabilise the pretrained weights.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)
scheduler = None  # no LR schedule; the training loop only steps it when one is set

In [28]:
class Args:
    """Run configuration for the training loop (read as `args.<name>`)."""
    device = device
    # apex AMP (fp16) operates on CUDA tensors only, so enable it only on a GPU.
    fp16 = device == 'cuda'
    num_epochs = 10
    ckp_path = 'bertner.ckp'
    grad_steps = 1        # gradient-accumulation steps per optimizer step
    max_grad_norm = 1.    # gradient-clipping threshold
args = Args()

In [29]:
# Global TensorBoard step counter and best validation accuracy seen so far.
n_iter, best_acc = 0, None

In [30]:
if args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    model, optimizer = amp.initialize(model, optimizer)


def save_model(model, optimizer, path, scheduler=None):
    """Serialise model, optimizer and (optional) scheduler / amp state to `path`."""
    checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    if scheduler is not None:
        checkpoint['scheduler'] = scheduler.state_dict()
    if args.fp16:
        # apex recommends checkpointing amp's loss-scaler state alongside the model.
        checkpoint['amp'] = amp.state_dict()
    torch.save(checkpoint, path)


for ep in tqdm(range(args.num_epochs), desc='Epochs'):
    model.train()
    for step, (input_ids, input_mask, label_ids, label_mask) in tqdm(enumerate(dl_train), leave=False, total=len(dl_train)):
        input_ids, input_mask, label_ids, label_mask = to_device(input_ids, input_mask, label_ids,
                                                                 label_mask, device=device)

        loss, _, _ = model(input_ids, input_mask, label_ids, label_mask)
        # Log the un-normalised loss as a Python float (comparable across grad_steps).
        train_loss = loss.item()

        if args.grad_steps > 1:
            loss = loss / args.grad_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if (step + 1) % args.grad_steps == 0 or (step + 1) == len(dl_train):
            # Clip once per optimizer step, after all accumulated gradients are in.
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            if scheduler is not None:
                scheduler.step()

        if writer:
            writer.add_scalar('loss/train', train_loss, n_iter)
        n_iter += 1

    # evaluate
    valid_loss, valid_acc = evaluate_fn(model, dl_valid)

    if writer:
        writer.add_scalar('loss/valid', valid_loss, ep)
        writer.add_scalar('acc/valid', valid_acc, ep)

        for name, param in model.named_parameters():
            writer.add_histogram(name, param, ep)

    print(f'---Valid\nLoss {valid_loss}\nAcc {valid_acc}', flush=True)

    # Keep only the checkpoint with the best validation accuracy.
    if best_acc is None or valid_acc > best_acc:
        best_acc = valid_acc
        save_model(model, optimizer, args.ckp_path, scheduler=scheduler)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(IntProgress(value=0, description='Epochs', max=30, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, max=7021), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0


KeyboardInterrupt: 