In [1]:
from torch.utils.data import DataLoader
from model import SimpleSequenceTagger
from ner_dataset import NERDataset
import numpy as np

In [2]:
dev_file = './data/dev.conll'  # path to training data
test_file = './data/test.conll'  # path to validation data
train_file = './data/train.conll'  # path to test data
num_epochs = 20
train_dataset = NERDataset(file=train_file)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
dev_dataset = NERDataset(file=dev_file)
dev_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=True)
test_dataset = NERDataset(file=test_file)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [3]:
seq_tagger = SimpleSequenceTagger(input_dim=50, hidden_dim=100, num_layers=1, class_size=9)
for epoch in range(num_epochs):
    loss = seq_tagger.train(data=train_dataloader, learning_rate=0.01)
    print("Iteration:", epoch, "Loss:", loss)
    metrics = seq_tagger.evaluate(dev_dataloader)
    print('DEV-Data','macro', metrics['f1_scores']['macro'], 'micro', metrics['f1_scores']['micro'])

Iteration: 0 Loss: 16.822690963745117
DEV-Data macro 0.5287793277732318 micro 0.9188115727580701
Iteration: 1 Loss: 10.39950942993164


  p = tp / sum_of_all_classified_elements_class  # calculate precison per class that already includes the tp


DEV-Data macro nan micro 0.9170982438378568
Iteration: 2 Loss: 0.8340609073638916
DEV-Data macro 0.5753273390846158 micro 0.9188115727580701
Iteration: 3 Loss: 0.3662041127681732
DEV-Data macro 0.5787205257669387 micro 0.9210700517892604
Iteration: 4 Loss: 2.5584282875061035
DEV-Data macro 0.5790607375401877 micro 0.922432927066703
Iteration: 5 Loss: 2.1972246170043945
DEV-Data macro 0.582786017141365 micro 0.9216346715470581
Iteration: 6 Loss: 0.8941321969032288
DEV-Data macro 0.5704440044822259 micro 0.9236011058759395
Iteration: 7 Loss: 8.772149085998535
DEV-Data macro 0.5864310133677663 micro 0.9237373934036837
Iteration: 8 Loss: 1.6950182914733887
DEV-Data macro 0.6051110200919964 micro 0.9265604921926717
Iteration: 9 Loss: 15.739957809448242
DEV-Data macro 0.5938624801420157 micro 0.9258206456134885
Iteration: 10 Loss: 1.9878031015396118
DEV-Data macro 0.605007437231607 micro 0.9272029905377517
Iteration: 11 Loss: 2.4825806617736816
DEV-Data macro 0.6035347729455438 micro 0.92582

In [4]:
metrics = seq_tagger.evaluate(test_dataloader)
print('TEST-Data','macro', metrics['f1_scores']['macro'], 'micro', metrics['f1_scores']['micro'])
wrong_word_predictions = metrics['word_statistics']['fail']
wrong_word_predictions = [v for v in wrong_word_predictions if v[1] >= 8]
print('TEST-Data', 'wrong_predicted_words', wrong_word_predictions)
context_statistics = metrics['context_statistics']
print('TEST-Data', 'context_statistics', context_statistics)

TEST-Data macro 0.5933550105318696 micro 0.9140734359857866
TEST-Data wrong_predicted_words [('world', 68), ('new', 55), ('of', 36), ('national', 36), ('madrid', 31), ('real', 25), ('party', 25), ('open', 24), ('united', 24), ('south', 20), ('fe', 20), ('wto', 19), ('santa', 19), ('league', 18), ('newsroom', 16), ('norilsk', 16), ('state', 15), ('city', 15), ('bre-x', 15), ('chicago', 15), ('west', 14), ('york', 14), ('louis', 14), ('university', 14), ('korea', 14), ('east', 13), ('and', 13), ('ny', 13), ('jersey', 12), ('major', 12), ('lara', 12), ('american', 12), ('zealand', 12), ('super', 11), ('gmt', 11), ('melbourne', 11), ('i', 11), ('internet', 10), ('costa', 10), ('san', 10), ('institute', 10), ('western', 10), ('van', 10), ('kroons', 10), ('series', 9), ('colorado', 9), ('fifa', 9), ('ministry', 9), ('international', 9), ('washington', 9), ('est', 9), ('vancouver', 9), ('buffalo', 9), ('barbarians', 9), ('nymex', 9), ('mills', 9), ('council', 8), ('belo', 8), ('sakakibara', 8

In [5]:
metrics['confusion_matrix']

array([[1.4210e+03, 1.3000e+01, 1.2000e+02, 2.1000e+01, 1.3000e+01,
        1.0000e+00, 1.0000e+01, 2.9000e+01, 4.0000e+01],
       [3.0000e+01, 4.1100e+02, 2.8000e+01, 1.4000e+01, 1.1000e+01,
        2.0000e+00, 8.0000e+00, 1.0000e+01, 1.8800e+02],
       [2.7100e+02, 4.0000e+01, 9.4000e+02, 3.8000e+01, 1.8000e+01,
        1.0000e+00, 7.5000e+01, 4.5000e+01, 2.3300e+02],
       [2.3000e+01, 3.0000e+00, 4.6000e+01, 1.0370e+03, 4.0000e+00,
        0.0000e+00, 9.0000e+00, 3.6600e+02, 1.2900e+02],
       [8.5000e+01, 1.7000e+01, 6.0000e+00, 1.3000e+01, 7.6000e+01,
        2.0000e+00, 9.0000e+00, 5.0000e+00, 4.4000e+01],
       [2.0000e+00, 8.0000e+00, 7.0000e+00, 2.0000e+00, 7.0000e+00,
        8.3000e+01, 2.0000e+00, 2.0000e+00, 1.0300e+02],
       [1.4600e+02, 8.0000e+00, 1.2500e+02, 2.1000e+01, 4.3000e+01,
        0.0000e+00, 1.5800e+02, 2.7000e+01, 3.0700e+02],
       [2.0000e+01, 3.0000e+00, 4.6000e+01, 3.4700e+02, 4.0000e+00,
        9.0000e+00, 6.0000e+00, 5.4500e+02, 1.7600e+02],


In [None]:
# world = ambiguity
# 1996-12-06 date

In [None]:
# 2022 https://arxiv.org/pdf/2204.04391.pdf
# https://arxiv.org/pdf/1910.02403.pdf

# https://kiarashk76.github.io/docs/DL4NLP.pdf