In [1]:
import sys

# Make the project's `src` directory importable. The path is relative, so it
# resolves against the kernel's current working directory.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if "src" not in sys.path:
    sys.path.append("src")

In [2]:
import torch

from ner.model import BiLSTMModel

from nltk.corpus.reader.conll import ConllCorpusReader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the serialized vocabulary bundle and unpack its four lookup tables
# (tokens, POS tags, chunk tags, NER tags) into separate names.
vocab = torch.load("data/conll/vocab.pth")

token_vocab, pos_vocab, chunk_vocab, tags_vocab = (
    vocab[key]
    for key in ("token_vocab", "pos_vocab", "chunk_vocab", "tags_vocab")
)

In [4]:
# Instantiate the BiLSTM tagger with freshly initialised weights.
# NOTE(review): the sizes are hard-coded; presumably they must match the shapes
# stored in the checkpoint loaded later — confirm against the vocab lengths.
model = BiLSTMModel(n_words=23624, n_pos=45, n_chunks=20, n_tags=9)

# Display the parameter names the model expects in a state dict.
model.state_dict().keys()



odict_keys(['word_embedding.weight', 'pos_embedding.weight', 'chunk_embedding.weight', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l0_reverse', 'lstm.weight_hh_l0_reverse', 'lstm.bias_ih_l0_reverse', 'lstm.bias_hh_l0_reverse', 'output_layer.weight', 'output_layer.bias'])

In [5]:
# Read the training checkpoint and map all tensors onto the CPU.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the file. Only use it on checkpoints you produced or trust.
_state_dict = torch.load(
    "logs/train/runs/2024-12-25/15-15-32/checkpoint/ner-epoch=22-val_loss=1.50.ckpt",
    weights_only=False,
    map_location="cpu"
)

# The checkpoint's parameter names carry a leading "model." that the bare
# BiLSTMModel does not use (presumably added by the training wrapper).
# Strip it only when it is the key's prefix: str.replace would also delete any
# later "model." occurrence inside a nested parameter name and corrupt the key.
_prefix = "model."
state_dict = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in _state_dict["state_dict"].items()
}

# Last expression: its return value reports missing/unexpected keys.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
# Corpus reader over the training file; every row is declared to have four
# columns: word, POS tag, chunk tag and named-entity tag.
reader = ConllCorpusReader(
    root="data/conll",
    fileids="train.txt",
    columntypes=("words", "pos", "chunk", "ne"),
)

In [7]:
# NOTE(review): `_grids` is a private method of the reader; it is used here to
# get all four columns per token in one structure.
grid = reader._grids()

# Materialise the grids and drop the first one.
# NOTE(review): presumably the first grid is the file's document-start marker
# rather than a real sentence — confirm against train.txt.
sentences = [*grid][1:]

# Preview the first 32 sentences.
sentences[:32]

[[['EU', 'NNP', 'B-NP', 'B-ORG'],
  ['rejects', 'VBZ', 'B-VP', 'O'],
  ['German', 'JJ', 'B-NP', 'B-MISC'],
  ['call', 'NN', 'I-NP', 'O'],
  ['to', 'TO', 'B-VP', 'O'],
  ['boycott', 'VB', 'I-VP', 'O'],
  ['British', 'JJ', 'B-NP', 'B-MISC'],
  ['lamb', 'NN', 'I-NP', 'O'],
  ['.', '.', 'O', 'O']],
 [['Peter', 'NNP', 'B-NP', 'B-PER'], ['Blackburn', 'NNP', 'I-NP', 'I-PER']],
 [['BRUSSELS', 'NNP', 'B-NP', 'B-LOC'], ['1996-08-22', 'CD', 'I-NP', 'O']],
 [['The', 'DT', 'B-NP', 'O'],
  ['European', 'NNP', 'I-NP', 'B-ORG'],
  ['Commission', 'NNP', 'I-NP', 'I-ORG'],
  ['said', 'VBD', 'B-VP', 'O'],
  ['on', 'IN', 'B-PP', 'O'],
  ['Thursday', 'NNP', 'B-NP', 'O'],
  ['it', 'PRP', 'B-NP', 'O'],
  ['disagreed', 'VBD', 'B-VP', 'O'],
  ['with', 'IN', 'B-PP', 'O'],
  ['German', 'JJ', 'B-NP', 'B-MISC'],
  ['advice', 'NN', 'I-NP', 'O'],
  ['to', 'TO', 'B-PP', 'O'],
  ['consumers', 'NNS', 'B-NP', 'O'],
  ['to', 'TO', 'B-VP', 'O'],
  ['shun', 'VB', 'I-VP', 'O'],
  ['British', 'JJ', 'B-NP', 'B-MISC'],
  ['lamb

In [8]:
def _encode_column(rows, column, vocab, fallback=None):
    """Encode one column of a sentence grid as a batch-of-one id tensor.

    Args:
        rows: the sentence, a list of per-token lists (word, POS, chunk, NE).
        column: index of the column to encode.
        vocab: mapping from string to id; must support ``vocab[key]`` and
            ``vocab.get(key, default)``.
        fallback: key whose id replaces values missing from ``vocab``. When
            ``None`` a missing value raises ``KeyError`` immediately instead of
            silently becoming ``None`` and failing later inside
            ``torch.tensor``.

    Returns:
        A ``torch.long`` tensor with a leading dimension of size 1 added by
        ``unsqueeze(0)`` (assumes the vocab ids are integers).
    """
    if fallback is None:
        ids = [vocab[row[column]] for row in rows]
    else:
        fallback_id = vocab[fallback]
        ids = [vocab.get(row[column], fallback_id) for row in rows]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


# Encode the first sentence. Input features fall back to "<unk>"; gold tags
# must all be known, so they use strict lookup.
sentence = sentences[0]

token = _encode_column(sentence, 0, token_vocab, fallback="<unk>")
pos = _encode_column(sentence, 1, pos_vocab, fallback="<unk>")
chunk = _encode_column(sentence, 2, chunk_vocab, fallback="<unk>")
tags = _encode_column(sentence, 3, tags_vocab)

In [9]:
# Inference only: put the model in evaluation mode (disables training-time
# behaviour such as dropout) and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    pred = model(token, pos, chunk)

# Collapse all leading dimensions so that each row is one score vector.
# The row width is read from the tensor instead of being hard-coded, so it
# cannot drift out of sync with the model's output size.
# NOTE(review): assumes the model's last dimension indexes the classes — confirm.
pred = pred.view(-1, pred.size(-1))

# Flatten the gold tags to one dimension so they line up with the rows of `pred`.
tags = tags.view(-1)

In [10]:
# Print the highest-scoring class index per row, then the gold tag ids,
# one tensor per line for side-by-side comparison.
print(pred.argmax(dim=-1), tags, sep="\n")

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([2, 7, 8, 7, 7, 7, 8, 7, 7])


In [11]:
# Show the current value of `pred` as the cell's output.
pred

tensor([[0.7002, 0.6051, 0.0489, 0.0481, 0.0489, 0.0512, 0.0594, 0.0475, 0.0642,
         0.0552],
        [0.6760, 0.6236, 0.0460, 0.0451, 0.0459, 0.0479, 0.0545, 0.0446, 0.0581,
         0.0511],
        [0.8870, 0.2946, 0.0472, 0.0468, 0.0470, 0.0486, 0.0530, 0.0463, 0.0542,
         0.0497],
        [0.9025, 0.2497, 0.0457, 0.0453, 0.0455, 0.0468, 0.0501, 0.0449, 0.0507,
         0.0475],
        [0.8798, 0.2859, 0.0439, 0.0433, 0.0436, 0.0446, 0.0471, 0.0429, 0.0473,
         0.0451],
        [0.9310, 0.1711, 0.0439, 0.0435, 0.0435, 0.0442, 0.0453, 0.0432, 0.0442,
         0.0436],
        [0.9621, 0.0938, 0.0433, 0.0431, 0.0429, 0.0434, 0.0433, 0.0429, 0.0413,
         0.0418],
        [0.9695, 0.0721, 0.0417, 0.0415, 0.0413, 0.0415, 0.0407, 0.0413, 0.0385,
         0.0398],
        [0.9700, 0.0666, 0.0404, 0.0401, 0.0399, 0.0400, 0.0387, 0.0399, 0.0364,
         0.0382]], grad_fn=<ViewBackward0>)