<a href="https://colab.research.google.com/github/akaver/NLP2019/blob/master/Lab11_2019.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Partly based on https://github.com/Kyubyong/bert_ner

In [0]:
!pip install pytorch-pretrained-bert

Collecting pytorch-pretrained-bert
[?25l  Downloading https://files.pythonhosted.org/packages/5d/3c/d5fa084dd3a82ffc645aba78c417e6072ff48552e3301b1fa3bd711e03d4/pytorch_pretrained_bert-0.6.1-py3-none-any.whl (114kB)
[K    100% |████████████████████████████████| 122kB 3.8MB/s 
Installing collected packages: pytorch-pretrained-bert
Successfully installed pytorch-pretrained-bert-0.6.1


In [0]:
!wget -q https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train

In [0]:
!wget -q https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa

In [0]:
from pytorch_pretrained_bert import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


100%|██████████| 213450/213450 [00:00<00:00, 382620.99B/s]


In [0]:
VOCAB = ('<PAD>', 'O', 'LOC', 'PER', 'ORG', 'MISC')
tag2idx = {tag: idx for idx, tag in enumerate(VOCAB)}
idx2tag = {idx: tag for idx, tag in enumerate(VOCAB)}

In [0]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import torch 

device = 'cpu'
if torch.cuda.is_available():
  device = torch.device('cuda')

print(device)



cuda


In [0]:
'''
An entry or sent looks like ...
SOCCER NN B-NP O
- : O O
JAPAN NNP B-NP B-LOC
GET VB B-VP O
LUCKY NNP B-NP O
WIN NNP I-NP O
, , O O
CHINA NNP B-NP B-PER
IN IN B-PP O
SURPRISE DT B-NP O
DEFEAT NN I-NP O
. . O O
Each mini-batch returns the followings:
words: list of input sents. ["The 26-year-old ...", ...]
x: encoded input sents. [N, T]. int64.
is_heads: list of head markers. [[1, 1, 0, ...], [...]]
tags: list of tags.['O O B-MISC ...', '...']
y: encoded tags. [N, T]. int64
seqlens: list of seqlens. [45, 49, 10, 50, ...]
'''
class NerDataset(Dataset):
    def __init__(self, fpath):
        """
        fpath: [train|valid|test].txt
        """
        entries = open(fpath, 'r').read().strip().split("\n\n")
        sents, tags_li = [], [] # list of lists
        for entry in entries:
            words = [line.split()[0] for line in entry.splitlines()]
            tags = ([line.split()[-1] for line in entry.splitlines()])
            tags = [l.lstrip("B-").lstrip("I-") for l in tags]
            sents.append(["[CLS]"] + words + ["[SEP]"])
            tags_li.append(["<PAD>"] + tags + ["<PAD>"])
        self.sents, self.tags_li = sents, tags_li

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx] # words, tags: string list

        # We give credits only to the first piece.
        x, y = [], [] # list of ids
        is_heads = [] # list. 1: the token is the first piece of a word
        for w, t in zip(words, tags):
            tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
            xx = tokenizer.convert_tokens_to_ids(tokens)

            is_head = [1] + [0]*(len(tokens) - 1)

            t = [t] + ["<PAD>"] * (len(tokens) - 1)  # <PAD>: no decision
            yy = [tag2idx[each] for each in t]  # (T,)

            x.extend(xx)
            is_heads.extend(is_head)
            y.extend(yy)

        assert len(x)==len(y)==len(is_heads), f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}"

        # seqlen
        seqlen = len(y)

        # to string
        words = " ".join(words)
        tags = " ".join(tags)
        return words, x, is_heads, tags, y, seqlen


def pad(batch):
    '''Pads to the longest sample'''
    f = lambda x: [sample[x] for sample in batch]
    words = f(0)
    is_heads = f(2)
    tags = f(3)
    seqlens = f(-1)
    maxlen = np.array(seqlens).max()

    f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: <pad>
    x = f(1, maxlen)
    y = f(-2, maxlen)


    f = torch.LongTensor

    return words, f(x), is_heads, tags, f(y), seqlens


In [0]:
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel

class Net(nn.Module):
    def __init__(self, top_rnns=False, vocab_size=None, device='cpu', finetuning=False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')

        self.top_rnns=top_rnns
        if top_rnns:
            self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768//2, batch_first=True)
        self.fc = nn.Linear(768, vocab_size)

        self.device = device
        self.finetuning = finetuning

    def forward(self, x):
        '''
        x: (N, T). int64
        y: (N, T). int64
        Returns
        enc: (N, T, VOCAB)
        '''
        x = x.to(self.device)


        if self.training and self.finetuning:
            # print("->bert.train()")
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]

        if self.top_rnns:
            enc, _ = self.rnn(enc)
        logits = self.fc(enc)
        
        return logits


In [0]:
train_dataset = NerDataset("eng.train")
dev_dataset = NerDataset("eng.testa")

In [0]:
train_dataset[1]

('[CLS] EU rejects German call to boycott British lamb . [SEP]',
 [101, 7270, 22961, 1528, 1840, 1106, 21423, 1418, 2495, 12913, 119, 102],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
 '<PAD> ORG O MISC O O O MISC O O <PAD>',
 [0, 4, 1, 5, 1, 1, 1, 5, 1, 0, 1, 0],
 12)

In [0]:
# for calculating metrics
!pip install sklearn_crfsuite
import sklearn_crfsuite
import sklearn_crfsuite.metrics

Collecting sklearn_crfsuite
  Downloading https://files.pythonhosted.org/packages/25/74/5b7befa513482e6dee1f3dd68171a6c9dfc14c0eaa00f885ffeba54fe9b0/sklearn_crfsuite-0.3.6-py2.py3-none-any.whl
Collecting python-crfsuite>=0.8.3 (from sklearn_crfsuite)
[?25l  Downloading https://files.pythonhosted.org/packages/2f/86/cfcd71edca9d25d3d331209a20f6314b6f3f134c29478f90559cee9ce091/python_crfsuite-0.9.6-cp36-cp36m-manylinux1_x86_64.whl (754kB)
[K    1% |▍                               | 10kB 18.2MB/s eta 0:00:01[K    2% |▉                               | 20kB 1.7MB/s eta 0:00:01[K    4% |█▎                              | 30kB 2.5MB/s eta 0:00:01[K    5% |█▊                              | 40kB 3.3MB/s eta 0:00:01[K    6% |██▏                             | 51kB 4.1MB/s eta 0:00:01[K    8% |██▋                             | 61kB 4.8MB/s eta 0:00:01[K    9% |███                             | 71kB 5.5MB/s eta 0:00:01[K    10% |███▌                            | 81kB 3.4MB/s eta 0:00

In [0]:

def train(model, num_epochs, train_iter, dev_iter):

  criterion = nn.CrossEntropyLoss(ignore_index=0)
  optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

  best_acc = 0
  last_step = 0
  for epoch in range(1, num_epochs+1):
    print("Epoch %d" % epoch)
    
    for i, batch in enumerate(train_iter):
        model.train()
        words, x, is_heads, tags, y, seqlens = batch
        _y = y # for monitoring
        optimizer.zero_grad()
        logits = model(x) # logits: (N, T, VOCAB), y: (N, T)

        logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
        y = y.to(device)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if i == 0:
            print("=====sanity check======")
            print("words:", words[0])
            print("x:", x.cpu().numpy()[0][:seqlens[0]])
            print("tokens:", tokenizer.convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]])
            print("is_heads:", is_heads[0])
            print("y:", _y.cpu().numpy()[0][:seqlens[0]])
            print("tags:", tags[0])
            print("seqlen:", seqlens[0])
            print("=======================")


        if i%10 == 0: # monitoring
            print(f"step: {i}, loss: {loss.item()}")

        if i%100 == 0: # let's evaluate more frequently than evry epoch
            evaluate("dev", dev_iter, model)
            



def evaluate(dataset_name, data_iter, model, full_report=False):
  
  model.eval()
  y_pred_seq = []
  y_seq = []
  with torch.no_grad():
    for batch in data_iter:
      words, x, is_heads, tags, y, seqlens = batch

      logits = model(x)  # y_hat: (N, T)
      y_pred = logits.argmax(-1)
      
      
      for i in range(len(y)):
        head_i = is_heads[i]
        y_i = y[i].cpu().numpy()
        y_pred_i = y_pred[i].cpu().numpy()
            
        y_pred_i = [VOCAB[y_pred_j] for head, y_pred_j in zip(head_i, y_pred_i) if head == 1][1:-1]
        y_i = [VOCAB[y_j] for head, y_j in zip(head_i, y_i) if head == 1][1:-1]

        y_pred_seq.append(y_pred_i)
        y_seq.append(y_i)
  
  accuracy = sklearn_crfsuite.metrics.flat_accuracy_score(y_seq, y_pred_seq)
  
  print('  Evaluation on {} -  acc: {:.4f}%'.format(dataset_name, accuracy))
  if full_report:
    print(sklearn_crfsuite.metrics.flat_classification_report(y_seq, y_pred_seq, labels=["LOC", "MISC", "ORG", "PER"]))
       

   
  

In [0]:
model = Net(False, len(VOCAB), device, True).to(device)

100%|██████████| 404400730/404400730 [00:34<00:00, 11830956.07B/s]


In [0]:
batch_size = 16

train_iter = DataLoader(dataset=train_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=pad)
dev_iter = DataLoader(dataset=dev_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=pad)

In [0]:
train(model, 1, train_iter, dev_iter)

Epoch 1
words: [CLS] CRICKET - GOOCH TO PLAY ANOTHER SEASON FOR ESSEX . [SEP]
x: [  101 15531  9741 22441  1942   118 27157  9244  3048 16972   153 10783
  3663 23096 14697  3048  9637 12342 10719 11414   143  9565   142 12480
 24654   119   102]
tokens: ['[CLS]', 'CR', '##IC', '##KE', '##T', '-', 'GO', '##OC', '##H', 'TO', 'P', '##LA', '##Y', 'AN', '##OT', '##H', '##ER', 'SE', '##AS', '##ON', 'F', '##OR', 'E', '##SS', '##EX', '.', '[SEP]']
is_heads: [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1]
y: [0 1 0 0 0 1 3 0 0 1 1 0 0 1 0 0 0 1 0 0 1 0 2 0 0 1 0]
tags: <PAD> O O PER O O O O O LOC O <PAD>
seqlen: 27
step: 0, loss: 1.72129225730896
  Evaluation on dev -  acc: 0.8312%
step: 10, loss: 0.4828910827636719
step: 20, loss: 0.27751442790031433
step: 30, loss: 0.132721409201622
step: 40, loss: 0.09761400520801544
step: 50, loss: 0.17915427684783936
step: 60, loss: 0.18298524618148804
step: 70, loss: 0.13008172810077667
step: 80, loss: 0.0561047494411468

In [0]:
evaluate("dev", dev_iter, model, full_report=True)

  Evaluation on dev -  acc: 0.9885%
              precision    recall  f1-score   support

         LOC       0.97      0.96      0.96      2094
        MISC       0.89      0.88      0.88      1268
         ORG       0.94      0.92      0.93      2092
         PER       0.98      0.98      0.98      3149

   micro avg       0.95      0.94      0.95      8603
   macro avg       0.94      0.93      0.94      8603
weighted avg       0.95      0.94      0.95      8603

