In [1]:
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa

--2021-05-24 19:36:22--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3281528 (3.1M) [text/plain]
Saving to: ‘eng.train’


2021-05-24 19:36:22 (59.1 MB/s) - ‘eng.train’ saved [3281528/3281528]

--2021-05-24 19:36:22--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 827012 (808K) [text/plain]
Saving to: ‘eng.testa’


2021-05-24 19:36:22 (31.6 MB/s) - ‘eng.testa’ saved [8

In [2]:
from torch.utils.data import Dataset, DataLoader
from typing import List
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import nn
from sklearn.metrics import classification_report

In [3]:
!pip3 install pytorch_pretrained_bert

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |████████████████████████████████| 133kB 26.2MB/s 
[?25hCollecting boto3
[?25l  Downloading https://files.pythonhosted.org/packages/58/f5/06e099d85c66c68f203b356117e9df68e98983f2120b9ae598dc840c20e2/boto3-1.17.79-py2.py3-none-any.whl (131kB)
[K     |████████████████████████████████| 133kB 42.7MB/s 
Collecting s3transfer<0.5.0,>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/63/d0/693477c688348654ddc21dcdce0817653a294aa43f41771084c25e7ff9c7/s3transfer-0.4.2-py2.py3-none-any.whl (79kB)
[K     |████████████████████████████████| 81kB 12.4MB/s 
[?25hCollecting botocore<1.21.0,>=1.20.79
[?25l  Downloading https://files.pythonhosted.org/packages/ad/e5/f49beffe2474490a5e7811d533049a4fb701a6f87c66e55724a6b11c25e2/botocore-1.20.79-py2.py3-none-any.wh

In [4]:
import numpy as np
import torch
from torch.utils import data

from pytorch_pretrained_bert import BertTokenizer

# Module-level tokenizer for the 'bert-base-cased' vocabulary; it is read as a
# global by NerDataset.__getitem__ and by the sanity print in train().
# do_lower_case=False is passed so input text is not lower-cased before
# tokenisation (the checkpoint name says "cased").
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
# Tag inventory. Position in this tuple is the integer class id, so '<PAD>'
# must stay first: pad() fills with 0 and the loss is built with ignore_index=0.
VOCAB = (
    '<PAD>',
    'O',
    'I-LOC',
    'B-PER',
    'I-PER',
    'I-ORG',
    'I-MISC',
    'B-MISC',
    'B-LOC',
    'B-ORG',
)
# id -> tag, and its inverse tag -> id (both in VOCAB order).
idx2tag = dict(enumerate(VOCAB))
tag2idx = {label: position for position, label in idx2tag.items()}

class NerDataset(data.Dataset):
    """Sentences from a CoNLL-style column file, encoded for BERT token tagging.

    Each non-blank line of the file contributes one word (its first
    whitespace-separated field) and one tag (its last field). Blank lines
    separate sentences. Every sentence is wrapped in ``[CLS]`` / ``[SEP]``
    with ``<PAD>`` as the tag of those two special tokens.

    Relies on the module-level ``tokenizer`` and ``tag2idx`` at item-access time.
    """

    def __init__(self, fpath, encoding=None):
        """Read and split the corpus.

        Args:
            fpath: path of the column-formatted corpus file.
            encoding: text encoding handed to ``open``; ``None`` keeps the
                platform default.
        """
        # The context manager guarantees the handle is closed after reading.
        with open(fpath, 'r', encoding=encoding) as fin:
            lines = fin.read().splitlines()

        sents, tags_li = [], []
        words, tags = [], []
        # Group rows by blank (or whitespace-only) lines instead of splitting on
        # the literal "\n\n": runs of several blank lines then neither crash on
        # ``''.split()[0]`` nor produce empty sentences. The trailing "" acts as
        # a sentinel that flushes the last sentence.
        for line in lines + [""]:
            fields = line.split()
            if fields:
                words.append(fields[0])
                tags.append(fields[-1])
            elif words:
                sents.append(["[CLS]"] + words + ["[SEP]"])
                tags_li.append(["<PAD>"] + tags + ["<PAD>"])
                words, tags = [], []
        self.sents, self.tags_li = sents, tags_li

    def __len__(self):
        """Number of sentences read from the file."""
        return len(self.sents)

    def __getitem__(self, idx):
        """Encode sentence ``idx``.

        Returns:
            A 6-tuple ``(words, x, is_heads, tags, y, seqlen)``:
            the space-joined words, the token ids, a 0/1 flag per token marking
            the first piece of each word, the space-joined tags, the tag ids
            (non-first pieces get the ``<PAD>`` id) and the token count.
        """
        words, tags = self.sents[idx], self.tags_li[idx]

        x, y = [], []
        is_heads = []
        for w, t in zip(words, tags):
            if w in ("[CLS]", "[SEP]"):
                tokens = [w]
            else:
                # If the tokenizer yields no pieces for a word, fall back to a
                # single "[UNK]" piece so ids, labels and head flags stay
                # aligned (otherwise the assertion below would fail).
                tokens = tokenizer.tokenize(w) or ["[UNK]"]
            n_pieces = len(tokens)

            x.extend(tokenizer.convert_tokens_to_ids(tokens))
            # Only the first piece of a word is a "head" and carries its tag.
            is_heads.extend([1] + [0] * (n_pieces - 1))
            piece_tags = [t] + ["<PAD>"] * (n_pieces - 1)
            y.extend(tag2idx[each] for each in piece_tags)

        assert len(x)==len(y)==len(is_heads), f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}"

        seqlen = len(y)

        words = " ".join(words)
        tags = " ".join(tags)
        return words, x, is_heads, tags, y, seqlen


def pad(batch):
    '''Collate samples into a batch, right-padding ids and labels with 0.

    Each sample is the tuple produced by ``NerDataset.__getitem__``. Token ids
    (field 1) and label ids (second-to-last field) are padded to the longest
    sequence in the batch and returned as ``torch.LongTensor``; the remaining
    fields are returned as plain per-sample lists.
    '''
    words = [sample[0] for sample in batch]
    is_heads = [sample[2] for sample in batch]
    tags = [sample[3] for sample in batch]
    seqlens = [sample[-1] for sample in batch]
    longest = max(seqlens)

    def right_pad(ids):
        # 0 is the id used for padding.
        return ids + [0] * (longest - len(ids))

    x = torch.LongTensor([right_pad(sample[1]) for sample in batch])
    y = torch.LongTensor([right_pad(sample[-2]) for sample in batch])

    return words, x, is_heads, tags, y, seqlens

100%|██████████| 213450/213450 [00:00<00:00, 21560922.59B/s]


In [5]:
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel

class Net(nn.Module):
    """Pretrained ``bert-base-cased`` encoder followed by one linear tagging layer."""

    def __init__(self, vocab_size=None, device='cpu'):
        """
        Args:
            vocab_size: number of output classes of the linear layer.
            device: device that ``forward`` moves its inputs to.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        # 768 is the input width of the classifier; presumably the hidden size
        # of the BERT-base encoder output (verify against the checkpoint).
        self.fc = nn.Linear(768, vocab_size)

        self.device = device

    def _last_layer(self, x):
        """Run the encoder and return the final entry of its layer outputs."""
        encoded_layers, _ = self.bert(x)
        return encoded_layers[-1]

    def forward(self, x, y, ):
        """Score every token of ``x``.

        Returns:
            ``(logits, y, y_hat)``: the classifier output, the labels moved to
            ``self.device``, and the argmax of the logits over the last axis.
        """
        x, y = x.to(self.device), y.to(self.device)

        if not self.training:
            # Evaluation: the encoder runs in eval mode and without gradients.
            self.bert.eval()
            with torch.no_grad():
                enc = self._last_layer(x)
        else:
            self.bert.train()
            enc = self._last_layer(x)

        logits = self.fc(enc)
        return logits, y, logits.argmax(-1)

In [6]:
# Hyperparameters read as globals by the training script cell.
batch_size = 32  # samples per batch for both the train and eval DataLoader
lr = 1e-4  # learning rate passed to optim.Adam
n_epochs = 3  # number of train-then-eval rounds

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import os
import numpy as np
import argparse

def train(model, iterator, optimizer, criterion):
    """Run one pass over ``iterator``, taking an optimizer step per batch.

    Prints a one-off dump of the first sample of the first batch (decoded with
    the module-level ``tokenizer``) and the loss every 25 steps. Returns nothing.
    """
    model.train()
    for step, (words, x, is_heads, tags, y, seqlens) in enumerate(iterator):
        # Keep the labels exactly as they came out of the iterator: ``y`` is
        # rebound below to whatever the model hands back.
        batch_labels = y

        optimizer.zero_grad()
        logits, y, _ = model(x, y)

        # Flatten to (tokens, classes) / (tokens,) for the criterion.
        n_classes = logits.shape[-1]
        loss = criterion(logits.view(-1, n_classes), y.view(-1))
        loss.backward()

        optimizer.step()

        if step == 0:
            first_len = seqlens[0]
            first_ids = x.cpu().numpy()[0]
            print("=====sanity check======")
            print("words:", words[0])
            print("x:", first_ids[:first_len])
            print("tokens:", tokenizer.convert_ids_to_tokens(first_ids)[:first_len])
            print("is_heads:", is_heads[0])
            print("y:", batch_labels.cpu().numpy()[0][:first_len])
            print("tags:", tags[0])
            print("seqlen:", first_len)
            print("=======================")

        if step % 25 == 0:
            print(f"step: {step}, loss: {loss.item()}")

def eval(model, iterator):
    """Evaluate ``model`` on ``iterator`` and report precision / recall / F1.

    Only word-initial tokens (``is_heads == 1``) are scored, and the leading
    ``[CLS]`` / trailing ``[SEP]`` positions of every sentence are dropped.
    A tag counts as an entity when its id is greater than 1, i.e. it is
    neither ``<PAD>`` (0) nor ``O`` (1).

    Uses the module-level ``idx2tag`` / ``tag2idx`` mappings.

    Returns:
        ``(precision, recall, f1)`` as floats. A ratio whose denominator is
        zero (nothing proposed, no gold entities, or precision + recall == 0)
        is reported as 0.0 instead of nan.
    """
    model.eval()

    gold_ids, pred_ids = [], []
    with torch.no_grad():
        for words, x, is_heads, tags, y, seqlens in iterator:
            _, _, y_hat = model(x, y)

            rows = zip(words, is_heads, tags, y_hat.cpu().numpy().tolist())
            for sent_words, sent_heads, sent_tags, sent_hat in rows:
                # zip() stops at len(sent_heads), which discards the padding
                # columns of the prediction row.
                preds = [idx2tag[hat] for head, hat in zip(sent_heads, sent_hat) if head == 1]
                gold = sent_tags.split()
                assert len(preds)==len(sent_words.split())==len(gold)
                # [1:-1] strips the [CLS] / [SEP] positions.
                gold_ids.extend(tag2idx[t] for t in gold[1:-1])
                pred_ids.extend(tag2idx[p] for p in preds[1:-1])

    # Metrics are computed in memory; no scratch file is written to the
    # working directory.
    y_true = np.array(gold_ids, dtype=np.int64)
    y_pred = np.array(pred_ids, dtype=np.int64)

    # Plain int() instead of the removed ``np.int`` alias.
    num_proposed = int((y_pred > 1).sum())
    num_correct = int(np.logical_and(y_true == y_pred, y_true > 1).sum())
    num_gold = int((y_true > 1).sum())

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")

    precision = num_correct / num_proposed if num_proposed else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    if precision + recall:
        f1 = 2*precision*recall / (precision + recall)
    else:
        f1 = 0.0

    print("precision=%.2f"%precision)
    print("recall=%.2f"%recall)
    print("f1=%.2f"%f1)
    return precision, recall, f1

if __name__=="__main__":
    # End-to-end run: build the model and data loaders, then alternate one
    # training pass with one evaluation pass for n_epochs rounds.

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move the model with .to(device) rather than an unconditional .cuda(), so
    # the CPU fallback chosen above actually works on machines without a GPU.
    model = Net(len(VOCAB), device).to(device)
    # NOTE(review): Net.forward moves its inputs to the fixed `device` string;
    # with more than one visible GPU the DataParallel replicas on other
    # devices may not receive inputs on their own device - confirm before
    # running multi-GPU.
    model = nn.DataParallel(model)

    train_dataset = NerDataset('eng.train')
    eval_dataset = NerDataset('eng.testa')

    train_iter = data.DataLoader(dataset=train_dataset,
                                 batch_size=batch_size,
                                 shuffle=True, 
                                 num_workers=4,
                                 collate_fn=pad)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=pad)

    optimizer = optim.Adam(model.parameters(), lr = lr)
    # Class id 0 is '<PAD>' (also the fill value used by pad()); it is excluded
    # from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(1, n_epochs+1):
        train(model, train_iter, optimizer, criterion)

        print(f"=========eval at epoch={epoch}=========")
        precision, recall, f1 = eval(model, eval_iter)


100%|██████████| 404400730/404400730 [00:07<00:00, 51729088.91B/s]
  cpuset_checked))


words: [CLS] At Rio , they joined up with the national team squad for the journey to Moscow , where Brazil will face Russia in a friendly international on Wednesday . [SEP]
x: [ 101 1335 5470  117 1152 1688 1146 1114 1103 1569 1264 4322 1111 1103
 5012 1106 4116  117 1187 3524 1209 1339 2733 1107  170 4931 1835 1113
 9031  119  102]
tokens: ['[CLS]', 'At', 'Rio', ',', 'they', 'joined', 'up', 'with', 'the', 'national', 'team', 'squad', 'for', 'the', 'journey', 'to', 'Moscow', ',', 'where', 'Brazil', 'will', 'face', 'Russia', 'in', 'a', 'friendly', 'international', 'on', 'Wednesday', '.', '[SEP]']
is_heads: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
y: [0 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 2 1 1 2 1 1 1 1 1 1 1 0]
tags: <PAD> O I-LOC O O O O O O O O O O O O O I-LOC O O I-LOC O O I-LOC O O O O O O O <PAD>
seqlen: 31
step: 0, loss: 2.2725512981414795
step: 25, loss: 0.09014217555522919
step: 50, loss: 0.09056064486503601
step: 75, loss: 0