In [1]:
from src.models.HMM import HMMTagger
import numpy as np
import pandas as pd

from src.preprocess.text import SentenceGetter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize

from tqdm.notebook import tqdm

from itertools import chain

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

In [2]:
# Location of the NER corpus CSV.
# NOTE(review): this is an absolute, machine-specific path; point it at your
# local copy of the repository before running on another machine.
NER_DATASET_PATH = "/Users/Mikhail_Bulgakov/GitRepo/pos_ner_task/data/ner_dataset.csv"

# Load the corpus and forward-fill missing cells in a single expression so the
# cell is idempotent on re-run. `.ffill()` replaces the deprecated
# `fillna(method="ffill")` spelling and produces the same result.
# Presumably the fill propagates a sentence identifier that is only present on
# the first row of each sentence -- TODO confirm against the CSV layout.
ner_dataset = pd.read_csv(NER_DATASET_PATH, delimiter=',', encoding='unicode_escape').ffill()

In [3]:
# Wrap the loaded frame in the project's SentenceGetter; its get_full_data()
# output is what gets split into train/test below.
# Presumably it groups token rows into per-sentence sequences -- confirm in
# src.preprocess.text.
sg = SentenceGetter(ner_dataset)

In [4]:
# Hold out a fixed fraction of the data for evaluation; the fixed seed keeps
# the split reproducible between runs.
TEST_FRACTION = 0.2
SPLIT_SEED = 100

full_data = sg.get_full_data()
train_data, test_data = train_test_split(
    full_data,
    test_size=TEST_FRACTION,
    random_state=SPLIT_SEED,
)

In [5]:
# Distinct values of field 2 (used as HMM states / labels) and field 0 (used as
# HMM observations) across every token of the training split.
#
# Sorted rather than `list(set(...))`: iteration order of a set of strings
# depends on the per-process hash seed, so the unsorted form yields a different
# ordering on every kernel restart. Sorting makes the ordering reproducible.
states = sorted({token[2] for token in chain.from_iterable(train_data)})
# key=str keeps the sort well-defined even if some token value was parsed as a
# non-string type when the CSV was read.
observations = sorted({token[0] for token in chain.from_iterable(train_data)}, key=str)

In [6]:
# Build the project's HMM tagger from the label set and the vocabulary
# collected from the training split, then fit it on the training sentences.
hmm = HMMTagger(states, observations)
hmm.fit(train_data)

100%|██████████| 38367/38367 [00:01<00:00, 36403.37it/s]


In [7]:
# Model input: for each held-out sentence, the list of field-0 values of its tokens.
test_x = [[token[0] for token in sentence] for sentence in test_data]
# Reference labels: field 2 of every token, flattened across all held-out
# sentences in order.
test_y = [token[2] for sentence in test_data for token in sentence]

In [8]:
# Run the tagger over the held-out sentences, then flatten the per-sentence
# results into one sequence for comparison with test_y.
predicted_sentences = hmm.predict(test_x)
test_y_pred = [tag for sentence_tags in predicted_sentences for tag in sentence_tags]

100%|██████████| 9592/9592 [00:27<00:00, 345.60it/s]


In [11]:
# Token-level confusion counts: rows are the reference labels, columns the
# predicted labels, both displayed in alphabetical order.
confusion_counts = confusion_matrix(test_y, test_y_pred, labels=states)
df = (
    pd.DataFrame(confusion_counts, index=states, columns=states)
    .reindex(sorted(states), axis=1)
    .sort_index()
)
df

Unnamed: 0,B-art,B-eve,B-geo,B-gpe,B-nat,B-org,B-per,B-tim,I-art,I-eve,I-geo,I-gpe,I-nat,I-org,I-per,I-tim,O
B-art,0,0,6,6,0,14,7,0,0,0,1,0,0,4,3,1,37
B-eve,0,0,10,3,0,14,2,2,0,0,0,0,0,14,0,1,16
B-geo,0,0,6649,15,0,165,88,9,0,0,29,1,0,19,66,3,641
B-gpe,0,0,192,2787,0,21,6,0,0,0,34,0,0,14,7,2,45
B-nat,0,0,0,0,0,2,1,0,0,0,0,0,0,2,0,1,34
B-org,0,0,635,51,0,2350,151,3,0,0,2,0,0,86,78,4,560
B-per,0,0,154,104,0,62,2383,1,0,0,7,0,0,81,288,6,285
B-tim,0,0,47,0,0,6,4,3096,0,0,1,0,0,3,4,13,957
I-art,0,0,2,0,0,3,0,0,0,0,5,0,0,28,7,0,11
I-eve,0,0,0,0,0,2,0,3,0,0,6,0,0,24,1,3,19


In [12]:
# Per-label precision, recall, F-score and support on the held-out tokens.
# zero_division=0 states explicitly what to report for labels the tagger never
# predicts: the scores are 0.0, exactly as before, but scikit-learn no longer
# emits its undefined-metric warning into the cell output.
prf_scores = precision_recall_fscore_support(
    test_y,
    test_y_pred,
    labels=states,
    zero_division=0,
)
df = pd.DataFrame(
    prf_scores,
    index=["precision", "recall", "f1_score", "support"],
    columns=states,
).round(2)
# Columns in alphabetical label order; sort_index() also orders the metric rows
# alphabetically (f1_score, precision, recall, support).
df = df.reindex(sorted(df.columns), axis=1).sort_index()
df

  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,B-art,B-eve,B-geo,B-gpe,B-nat,B-org,B-per,B-tim,I-art,I-eve,I-geo,I-gpe,I-nat,I-org,I-per,I-tim,O
f1_score,0.0,0.0,0.85,0.89,0.0,0.68,0.74,0.82,0.0,0.0,0.71,0.45,0.0,0.72,0.82,0.57,0.99
precision,0.0,0.0,0.83,0.88,0.0,0.78,0.79,0.92,0.0,0.0,0.75,0.92,0.0,0.71,0.74,0.84,0.98
recall,0.0,0.0,0.87,0.9,0.0,0.6,0.71,0.75,0.0,0.0,0.67,0.3,0.0,0.74,0.93,0.43,0.99
support,79.0,62.0,7685.0,3108.0,40.0,3920.0,3371.0,4131.0,56.0,58.0,1500.0,40.0,12.0,3370.0,3367.0,1329.0,178824.0
