In [1]:
from os import getcwd, path
import sys
import matplotlib.pyplot as plt

BASE_PATH = path.dirname(getcwd())
sys.path.append(BASE_PATH)

from entities_recognition.bilstm.train import trainIters, evaluate_all
from config import START_TAG, STOP_TAG

import torch
print(torch.__version__)

0.4.0a0+5463a4a


In [2]:
TRAIN_PATH = path.join(BASE_PATH, 'data/CoNLL-2003/eng.train')
print(TRAIN_PATH)

/Users/2359media/Documents/botbot-nlp/data/CoNLL-2003/eng.train


In [3]:
import io
import string

def ident(x):
    return x

def read_conll_2003(filename):
    all_data = []

    current_txt = []
    current_tags = []
    tagset = []

    fin = io.open(filename, 'r', encoding='utf-8', newline='\n', errors='ignore')
    for line in fin:
        line = line.strip()
        if len(line) > 0: # skip blank lines
            tmp = line.split(' ')
            if tmp[0] != '-DOCSTART-':
                current_txt.append(tmp[0])
                current_tags.append(tmp[-1])
                tagset.append(tmp[-1])
        else:
            if len(current_txt) > 0:
                all_data.append((current_txt, ' '.join(current_tags)))
                current_txt = []
                current_tags = []
    fin.close()

    tagset = list(set(tagset))
    tag_to_ix = {tag: key for key, tag in enumerate(tagset)}
    tag_to_ix[START_TAG] = len(tagset)
    tag_to_ix[STOP_TAG] = len(tagset) + 1

    print(tag_to_ix)
    print('Loaded %s sentences' % len(all_data))
    
    return tag_to_ix, all_data

In [None]:
from entities_recognition.bilstm.predict import read_tags

tag_to_ix, training_data = read_conll_2003(TRAIN_PATH)
result = []
for sentence, tag_seq in training_data:
    assert len(sentence) == len(tag_seq.split(' '))
#     print(read_tags(tokens_in, tag_seq.split(' ')))

{'I-ORG': 0, 'I-LOC': 1, 'B-ORG': 2, 'B-MISC': 3, 'B-LOC': 4, 'O': 5, 'I-MISC': 6, 'I-PER': 7, '<START>': 8, '<STOP>': 9}
Loaded 14041 sentences


In [None]:
losses, model = trainIters(training_data, 
                           tag_to_ix,
                           learning_rate=1e-2,
                           gradual_unfreeze=True,
                           weight_decay=0,
                           optimizer='sgd',
                           n_iters=50,
                           log_every=1,
                           tokenizer=ident,
                           verbose=1)

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
fig, ax = plt.subplots()
plt.plot(losses)

Importing /Users/2359media/Documents/botbot-nlp/data/fasttext/crawl-300d-2M.vec...
8m 14s (- 403m 57s) (1 2%) 640192.1250
22m 35s (- 542m 2s) (2 4%) 602723.0625
36m 47s (- 576m 30s) (3 6%) 449270.0312
168m 58s (- 1943m 7s) (4 8%) 202334.1562
187m 16s (- 1685m 27s) (5 10%) 154062.6875
205m 20s (- 1505m 47s) (6 12%) 144066.7656
223m 43s (- 1374m 21s) (7 14%) 135290.1250
241m 59s (- 1270m 29s) (8 16%) 126969.2891


In [None]:
model.eval()
torch.save(model.state_dict(), 'bilstm-rnn-conll2003-vanilla.bin')

Model recall

In [None]:
evaluate_all(model, training_data, tag_to_ix, tokenizer=ident)

In [None]:
TEST_PATH_A = path.join(BASE_PATH, 'data/CoNLL-2003/eng.testa')
TEST_PATH_B = path.join(BASE_PATH, 'data/CoNLL-2003/eng.testb')
_, testing_data_a = read_conll_2003(TEST_PATH_A)
_, testing_data_b = read_conll_2003(TEST_PATH_B)

Accuracy on test sets

In [None]:
evaluate_all(model, testing_data_a, tag_to_ix, tokenizer=ident)

In [None]:
evaluate_all(model, testing_data_b, tag_to_ix, tokenizer=ident)

In [None]:
from entities_recognition.bilstm.predict import predict
from common.utils import wordpunct_tokenize

test_data = [
    'I live in Ho Chi Minh City, nice place, though my hometown is in Hanoi. I do miss it sometimes',
    'Trump’s role in midterm elections roils Republicans',
    'Kenya bans film about 2 girls in love because it’s ‘too hopeful’',
    'G.O.P. leaders and White House aides are trying to prepare President Trump for trouble in House and Senate races.'
]
predict(model, test_data, tag_to_ix, tokenizer=wordpunct_tokenize, delimiter=' ')

In [None]:
import json
with open('tag_to_ix.json', 'w') as tagfile:
    json.dump(tag_to_ix, tagfile)

`./conlleval < testa.out.txt`
```
processed 51578 tokens with 5942 phrases; found: 5958 phrases; correct: 5199.
accuracy:  97.93%; precision:  87.26%; recall:  87.50%; FB1:  87.38
              LOC: precision:  91.93%; recall:  91.78%; FB1:  91.86  1834
             MISC: precision:  87.27%; recall:  83.30%; FB1:  85.24  880
              ORG: precision:  78.03%; recall:  83.15%; FB1:  80.51  1429
              PER: precision:  89.81%; recall:  88.49%; FB1:  89.14  1815
```
`./conlleval < testb.out.txt`
```
processed 46666 tokens with 5879 phrases; found: 5703 phrases; correct: 4591.
accuracy:  95.89%; precision:  80.50%; recall:  78.09%; FB1:  79.28
              LOC: precision:  86.08%; recall:  88.61%; FB1:  87.33  1717
             MISC: precision:  71.13%; recall:  73.36%; FB1:  72.23  724
              ORG: precision:  74.11%; recall:  79.11%; FB1:  76.53  1773
              PER: precision:  86.23%; recall:  79.41%; FB1:  82.68  1489
```