In [1]:
import sys

# Make the project's `src/` directory importable (resolved relative to the current
# working directory). Guarded so re-running this cell does not append duplicates.
SRC_DIR = "src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
import torch

from ner.model import BERT
from ner.infer import NERInfer

from nltk.corpus.reader.conll import ConllCorpusReader

In [3]:
vocab = torch.load("data/conll/vocab.pth")

# The saved object holds one lookup table per column; unpack them by key, in order.
token_vocab, pos_vocab, chunk_vocab, tags_vocab = (
    vocab[name] for name in ("token_vocab", "pos_vocab", "chunk_vocab", "tags_vocab")
)

In [4]:
# id -> tag view of the tag vocabulary, handy for reading predicted class ids.
dict(zip(tags_vocab.values(), tags_vocab.keys()))

{1: 'O',
 2: 'B-LOC',
 3: 'B-PER',
 4: 'B-ORG',
 5: 'I-PER',
 6: 'I-ORG',
 7: 'B-MISC',
 8: 'I-LOC',
 9: 'I-MISC',
 0: '<pad>'}

In [5]:
# Token vocabulary size; passed as BERT(vocab_size=...) in the next cell.
len(token_vocab)

23625

In [6]:
# Instantiate the model sized to the token vocabulary; its weights are loaded
# from the training checkpoint in a later cell.
model = BERT(vocab_size=len(token_vocab))

# Parameter/buffer names the checkpoint's state_dict must match when loaded below.
model.state_dict().keys()

odict_keys(['embedding.token.weight', 'embedding.position.pe', 'embedding.segment.weight', 'transformer_blocks.0.attention.linear_layers.0.weight', 'transformer_blocks.0.attention.linear_layers.0.bias', 'transformer_blocks.0.attention.linear_layers.1.weight', 'transformer_blocks.0.attention.linear_layers.1.bias', 'transformer_blocks.0.attention.linear_layers.2.weight', 'transformer_blocks.0.attention.linear_layers.2.bias', 'transformer_blocks.0.attention.output_linear.weight', 'transformer_blocks.0.attention.output_linear.bias', 'transformer_blocks.0.feed_forward.w_1.weight', 'transformer_blocks.0.feed_forward.w_1.bias', 'transformer_blocks.0.feed_forward.w_2.weight', 'transformer_blocks.0.feed_forward.w_2.bias', 'transformer_blocks.0.input_sublayer.norm.a_2', 'transformer_blocks.0.input_sublayer.norm.b_2', 'transformer_blocks.0.output_sublayer.norm.a_2', 'transformer_blocks.0.output_sublayer.norm.b_2', 'transformer_blocks.1.attention.linear_layers.0.weight', 'transformer_blocks.1.atte

In [7]:
# Training checkpoint (the filename suggests a PyTorch Lightning ModelCheckpoint).
# SECURITY: weights_only=False fully unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
_state_dict = torch.load(
    "logs/ner-epoch=48-val_loss=0.22.ckpt",
    weights_only=False,
    map_location="cpu",  # load onto CPU regardless of the device it was saved from
)

# Checkpoint keys carry a "model." prefix that the bare BERT module's keys do not
# (compare the key listing above). Strip it only at the start of each key:
# str.replace would also rewrite any "model." occurring later in a key name.
CKPT_PREFIX = "model."
state_dict = {
    (k[len(CKPT_PREFIX):] if k.startswith(CKPT_PREFIX) else k): v
    for k, v in _state_dict["state_dict"].items()
}
# load_state_dict is strict by default, so any missing/unexpected key raises.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [8]:
# CoNLL test split: whitespace-separated columns per token (word, POS, chunk, NE),
# with sentences separated by blank lines.
reader = ConllCorpusReader(
    root="data/conll",
    fileids="test.txt",
    columntypes=("words", "pos", "chunk", "ne"),
)

In [9]:
# NOTE(review): `_grids()` is a private NLTK API. Each grid is one sentence as a list
# of rows [word, pos, chunk, ne] (see the output below).
grid = reader._grids()

# Skip the first grid — presumably the block left by the file's leading
# `-DOCSTART-` header. NOTE(review): only the first one is skipped; if the file
# contains further `-DOCSTART-` document headers they may still appear in
# `sentences` (possibly as empty grids) — confirm before relying on indices.
sentences = list(grid)[1:]
sentences[:32]

[[['SOCCER', 'NN', 'B-NP', 'O'],
  ['-', ':', 'O', 'O'],
  ['JAPAN', 'NNP', 'B-NP', 'B-LOC'],
  ['GET', 'VB', 'B-VP', 'O'],
  ['LUCKY', 'NNP', 'B-NP', 'O'],
  ['WIN', 'NNP', 'I-NP', 'O'],
  [',', ',', 'O', 'O'],
  ['CHINA', 'NNP', 'B-NP', 'B-PER'],
  ['IN', 'IN', 'B-PP', 'O'],
  ['SURPRISE', 'DT', 'B-NP', 'O'],
  ['DEFEAT', 'NN', 'I-NP', 'O'],
  ['.', '.', 'O', 'O']],
 [['Nadim', 'NNP', 'B-NP', 'B-PER'], ['Ladki', 'NNP', 'I-NP', 'I-PER']],
 [['AL-AIN', 'NNP', 'B-NP', 'B-LOC'],
  [',', ',', 'O', 'O'],
  ['United', 'NNP', 'B-NP', 'B-LOC'],
  ['Arab', 'NNP', 'I-NP', 'I-LOC'],
  ['Emirates', 'NNPS', 'I-NP', 'I-LOC'],
  ['1996-12-06', 'CD', 'I-NP', 'O']],
 [['Japan', 'NNP', 'B-NP', 'B-LOC'],
  ['began', 'VBD', 'B-VP', 'O'],
  ['the', 'DT', 'B-NP', 'O'],
  ['defence', 'NN', 'I-NP', 'O'],
  ['of', 'IN', 'B-PP', 'O'],
  ['their', 'PRP$', 'B-NP', 'O'],
  ['Asian', 'JJ', 'I-NP', 'B-MISC'],
  ['Cup', 'NNP', 'I-NP', 'I-MISC'],
  ['title', 'NN', 'I-NP', 'O'],
  ['with', 'IN', 'B-PP', 'O'],
  ['a', 

In [10]:
i = 145
sentence = sentences[i]

print(sentence)

# Each row is [word, pos, chunk, ne], following the reader's columntypes.
unk_id = token_vocab["<unk>"]
# The outer list makes a batch of one: shape (1, seq_len).
token = torch.tensor(
    [[token_vocab.get(row[0], unk_id) for row in sentence]], dtype=torch.long
)
# Index (not .get) so an unknown NE label raises a KeyError naming the label,
# instead of a None reaching torch.tensor with a less informative error.
tags = torch.tensor(
    [[tags_vocab[row[3]] for row in sentence]], dtype=torch.long
)

[['Result', 'NN', 'B-NP', 'O'], ['of', 'IN', 'B-PP', 'O'], ['an', 'DT', 'B-NP', 'O'], ['English', 'NNP', 'I-NP', 'B-MISC'], ['F.A.', 'NNP', 'I-NP', 'I-MISC'], ['Challenge', 'NNP', 'I-NP', 'I-MISC']]


In [11]:
# Token ids of the selected sentence, shape (1, seq_len).
token

tensor([[  969,     5,    41,   539,  9211, 15328]])

In [12]:
model.eval()  # inference mode (e.g. disables dropout)
with torch.no_grad():
    # Second argument: segment ids, all zeros here.
    # NOTE(review): confirm this matches the segment ids used during training.
    pred = model(token, torch.zeros_like(token))

print(pred)

# Flatten (batch, seq, num_tags) -> (batch * seq, num_tags). The class count comes
# from the logits themselves rather than a hard-coded 10.
pred = pred.view(-1, pred.size(-1))

tags = tags.view(-1)

tensor([[[-5.5110, 21.1741,  9.7446, 10.6033, 11.0002, 10.4307, 12.0456,
           6.6333,  6.4055,  8.0601],
         [-5.4464, 18.6616,  9.6520,  7.1484, 11.0816,  6.8671, 12.1711,
           5.8416,  5.0920,  6.7776],
         [-5.9140, 24.2937, 13.4970, 10.9316, 10.9122, 11.2912, 12.5782,
           8.2423,  8.9436,  8.9204],
         [-2.4249,  5.1367,  9.8078,  7.5192,  8.8406, -3.4010, -1.2867,
          18.6100, -5.6940,  6.2814],
         [-2.2606,  4.2572,  5.5185, -0.8565,  4.4824, -2.0144,  3.8608,
           6.7404,  0.6126,  5.2062],
         [-1.9662,  1.0415,  1.2734,  0.8577,  1.9768,  0.7867,  2.4589,
           2.0460, -0.8256,  1.9975]]])


In [13]:
# Predicted tag ids (first line) vs. gold tag ids (second line), position by position.
print(pred.argmax(dim=-1), tags, sep="\n")

tensor([1, 1, 1, 7, 7, 6])
tensor([1, 1, 1, 7, 9, 9])


In [14]:
# tag -> id mapping (the inverse of the id -> tag table shown earlier).
tags_vocab

{'O': 1,
 'B-LOC': 2,
 'B-PER': 3,
 'B-ORG': 4,
 'I-PER': 5,
 'I-ORG': 6,
 'B-MISC': 7,
 'I-LOC': 8,
 'I-MISC': 9,
 '<pad>': 0}

In [15]:
# High-level wrapper built from the same checkpoint and vocab files loaded manually above.
infer = NERInfer(
    device="cpu",
    model_path="logs/ner-epoch=48-val_loss=0.22.ckpt",
    token_file="data/conll/vocab.pth",
)

In [16]:
# End-to-end prediction on raw text. The printed id list (presumably NERInfer's own
# tokenization) matches `token` from the manual pipeline above.
infer.predict("Result of an English F.A. Challenge")

[969, 5, 41, 539, 9211, 15328]


[{'entity': 'O', 'score': 0.4055, 'index': 0, 'word': 'Result'},
 {'entity': 'O', 'score': 0.348, 'index': 1, 'word': 'of'},
 {'entity': 'O', 'score': 0.9544, 'index': 2, 'word': 'an'},
 {'entity': 'B-MISC', 'score': 1.0, 'index': 3, 'word': 'English'},
 {'entity': 'B-MISC', 'score': 0.3033, 'index': 4, 'word': 'F.A.'},
 {'entity': 'I-ORG', 'score': 0.4071, 'index': 5, 'word': 'Challenge'}]

In [19]:
# A list of sentences returns one list of per-word predictions per sentence.
infer.predict(["Result of an English F.A. Challenge", "London is a great place"])

[[969, 5, 41, 539, 9211, 15328], [230, 30, 8, 983, 453]]


[[{'entity': 'O', 'score': 0.4055, 'index': 0, 'word': 'Result'},
  {'entity': 'O', 'score': 0.348, 'index': 1, 'word': 'of'},
  {'entity': 'O', 'score': 0.9544, 'index': 2, 'word': 'an'},
  {'entity': 'B-MISC', 'score': 1.0, 'index': 3, 'word': 'English'},
  {'entity': 'B-MISC', 'score': 0.3033, 'index': 4, 'word': 'F.A.'},
  {'entity': 'I-ORG', 'score': 0.4071, 'index': 5, 'word': 'Challenge'}],
 [{'entity': 'B-LOC', 'score': 0.2687, 'index': 0, 'word': 'London'},
  {'entity': 'O', 'score': 0.3755, 'index': 1, 'word': 'is'},
  {'entity': 'O', 'score': 0.9557, 'index': 2, 'word': 'a'},
  {'entity': 'O', 'score': 0.7027, 'index': 3, 'word': 'great'},
  {'entity': 'O', 'score': 0.8196, 'index': 4, 'word': 'place'}]]