In [1]:
import numpy as np
import pandas as pd

from src.preprocess.text import SentenceGetter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize

from tqdm.notebook import tqdm

from itertools import chain

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

from src.preprocess.text import sent2features, sent2labels

from sklearn_crfsuite import CRF

In [2]:
ner_dataset = pd.read_csv("/Users/Mikhail_Bulgakov/GitRepo/pos_ner_task/data/ner_dataset.csv", delimiter=',', encoding='unicode_escape')
ner_dataset = ner_dataset.fillna(method="ffill")

In [3]:
sg = SentenceGetter(ner_dataset)

In [4]:
train_data, test_data = train_test_split(sg.get_full_data(), test_size=0.2, random_state=100)

In [5]:
X_train = [sent2features(s) for s in train_data]
y_train = [sent2labels(s) for s in train_data]

X_test = [sent2features(s) for s in test_data]
y_test = [sent2labels(s) for s in test_data]

In [6]:
crf = CRF(algorithm='lbfgs',
          c1=10,
          c2=0.1,
          max_iterations=100,
          all_possible_transitions=False)

In [7]:
crf.fit(X_train,y_train)

In [8]:
y_test_pred = list(chain.from_iterable(crf.predict(X_test)))
y_test = list(chain.from_iterable(y_test))

In [9]:
states = list(set([i[2] for i in chain.from_iterable(train_data)]))

In [18]:
df = pd.DataFrame(confusion_matrix(y_test, y_test_pred, labels=states), index=states, columns=states)
df = df.reindex(sorted(df.columns), axis=1).sort_index()
df

Unnamed: 0,B-art,B-eve,B-geo,B-gpe,B-nat,B-org,B-per,B-tim,I-art,I-eve,I-geo,I-gpe,I-nat,I-org,I-per,I-tim,O
B-art,0,0,21,3,0,24,9,0,0,0,0,0,0,5,0,0,17
B-eve,0,17,4,4,0,11,4,4,0,0,0,0,0,2,0,2,14
B-geo,0,0,6904,26,1,290,122,6,0,0,29,0,0,43,62,9,193
B-gpe,0,0,140,2864,0,23,6,0,0,0,3,0,0,7,5,0,60
B-nat,0,0,6,0,5,7,3,0,0,0,0,0,0,0,0,1,18
B-org,0,0,604,19,0,2681,216,7,0,0,3,1,0,54,70,5,260
B-per,0,0,263,2,0,161,2552,0,0,0,6,0,0,104,141,4,138
B-tim,0,0,63,2,0,11,7,3399,0,0,2,0,0,5,5,59,578
I-art,0,0,0,0,0,3,1,1,0,0,5,0,0,33,8,0,5
I-eve,0,4,1,0,0,1,1,1,0,8,5,0,0,15,5,3,14


In [19]:
df = pd.DataFrame(precision_recall_fscore_support(y_test, y_test_pred, labels=states), index=["precision", "recall", "f1_score", "support"], columns=states).round(2)
df = df.reindex(sorted(df.columns), axis=1).sort_index()
df

  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,B-art,B-eve,B-geo,B-gpe,B-nat,B-org,B-per,B-tim,I-art,I-eve,I-geo,I-gpe,I-nat,I-org,I-per,I-tim,O
f1_score,0.0,0.41,0.86,0.94,0.22,0.72,0.79,0.87,0.0,0.24,0.77,0.42,0.0,0.75,0.85,0.73,0.99
precision,0.0,0.81,0.83,0.95,0.83,0.77,0.82,0.92,0.0,1.0,0.8,0.92,0.0,0.74,0.8,0.82,0.99
recall,0.0,0.27,0.9,0.92,0.12,0.68,0.76,0.82,0.0,0.14,0.73,0.28,0.0,0.76,0.9,0.66,0.99
support,79.0,62.0,7685.0,3108.0,40.0,3920.0,3371.0,4131.0,56.0,58.0,1500.0,40.0,12.0,3370.0,3367.0,1329.0,178824.0


In [14]:
import eli5

In [15]:
eli5.show_weights(crf, top=30)

From \ To,O,B-art,I-art,B-eve,I-eve,B-geo,I-geo,B-gpe,I-gpe,B-nat,I-nat,B-org,I-org,B-per,I-per,B-tim,I-tim
O,3.909,2.15,0.0,1.694,0.0,2.1,0.0,1.681,0.0,1.703,0.0,2.919,0.0,3.881,0.0,2.546,0.0
B-art,0.0,0.0,8.412,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
I-art,-0.383,0.0,8.051,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-eve,0.0,0.0,0.0,0.0,8.234,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
I-eve,-0.0,0.0,0.0,0.0,7.04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-geo,0.93,0.0,0.0,0.0,0.0,0.0,9.26,0.676,0.0,0.0,0.0,0.428,0.0,0.0,0.0,1.877,0.0
I-geo,-0.046,0.0,0.0,0.0,0.0,0.0,7.615,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.137,0.0
B-gpe,0.86,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.864,0.0,0.0,1.547,0.0,0.923,0.0,0.0,0.0
I-gpe,-0.067,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-nat,-0.255,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.292,0.0,0.0,0.0,0.0,0.0,0.0

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,Unnamed: 16_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5,Unnamed: 9_level_5,Unnamed: 10_level_5,Unnamed: 11_level_5,Unnamed: 12_level_5,Unnamed: 13_level_5,Unnamed: 14_level_5,Unnamed: 15_level_5,Unnamed: 16_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6,Unnamed: 9_level_6,Unnamed: 10_level_6,Unnamed: 11_level_6,Unnamed: 12_level_6,Unnamed: 13_level_6,Unnamed: 14_level_6,Unnamed: 15_level_6,Unnamed: 16_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7,Unnamed: 9_level_7,Unnamed: 10_level_7,Unnamed: 11_level_7,Unnamed: 12_level_7,Unnamed: 13_level_7,Unnamed: 14_level_7,Unnamed: 15_level_7,Unnamed: 16_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8,Unnamed: 9_level_8,Unnamed: 10_level_8,Unnamed: 11_level_8,Unnamed: 12_level_8,Unnamed: 13_level_8,Unnamed: 14_level_8,Unnamed: 15_level_8,Unnamed: 16_level_8
Weight?,Feature,Unnamed: 2_level_9,Unnamed: 3_level_9,Unnamed: 4_level_9,Unnamed: 5_level_9,Unnamed: 6_level_9,Unnamed: 7_level_9,Unnamed: 8_level_9,Unnamed: 9_level_9,Unnamed: 10_level_9,Unnamed: 11_level_9,Unnamed: 12_level_9,Unnamed: 13_level_9,Unnamed: 14_level_9,Unnamed: 15_level_9,Unnamed: 16_level_9
Weight?,Feature,Unnamed: 2_level_10,Unnamed: 3_level_10,Unnamed: 4_level_10,Unnamed: 5_level_10,Unnamed: 6_level_10,Unnamed: 7_level_10,Unnamed: 8_level_10,Unnamed: 9_level_10,Unnamed: 10_level_10,Unnamed: 11_level_10,Unnamed: 12_level_10,Unnamed: 13_level_10,Unnamed: 14_level_10,Unnamed: 15_level_10,Unnamed: 16_level_10
Weight?,Feature,Unnamed: 2_level_11,Unnamed: 3_level_11,Unnamed: 4_level_11,Unnamed: 5_level_11,Unnamed: 6_level_11,Unnamed: 7_level_11,Unnamed: 8_level_11,Unnamed: 9_level_11,Unnamed: 10_level_11,Unnamed: 11_level_11,Unnamed: 12_level_11,Unnamed: 13_level_11,Unnamed: 14_level_11,Unnamed: 15_level_11,Unnamed: 16_level_11
Weight?,Feature,Unnamed: 2_level_12,Unnamed: 3_level_12,Unnamed: 4_level_12,Unnamed: 5_level_12,Unnamed: 6_level_12,Unnamed: 7_level_12,Unnamed: 8_level_12,Unnamed: 9_level_12,Unnamed: 10_level_12,Unnamed: 11_level_12,Unnamed: 12_level_12,Unnamed: 13_level_12,Unnamed: 14_level_12,Unnamed: 15_level_12,Unnamed: 16_level_12
Weight?,Feature,Unnamed: 2_level_13,Unnamed: 3_level_13,Unnamed: 4_level_13,Unnamed: 5_level_13,Unnamed: 6_level_13,Unnamed: 7_level_13,Unnamed: 8_level_13,Unnamed: 9_level_13,Unnamed: 10_level_13,Unnamed: 11_level_13,Unnamed: 12_level_13,Unnamed: 13_level_13,Unnamed: 14_level_13,Unnamed: 15_level_13,Unnamed: 16_level_13
Weight?,Feature,Unnamed: 2_level_14,Unnamed: 3_level_14,Unnamed: 4_level_14,Unnamed: 5_level_14,Unnamed: 6_level_14,Unnamed: 7_level_14,Unnamed: 8_level_14,Unnamed: 9_level_14,Unnamed: 10_level_14,Unnamed: 11_level_14,Unnamed: 12_level_14,Unnamed: 13_level_14,Unnamed: 14_level_14,Unnamed: 15_level_14,Unnamed: 16_level_14
Weight?,Feature,Unnamed: 2_level_15,Unnamed: 3_level_15,Unnamed: 4_level_15,Unnamed: 5_level_15,Unnamed: 6_level_15,Unnamed: 7_level_15,Unnamed: 8_level_15,Unnamed: 9_level_15,Unnamed: 10_level_15,Unnamed: 11_level_15,Unnamed: 12_level_15,Unnamed: 13_level_15,Unnamed: 14_level_15,Unnamed: 15_level_15,Unnamed: 16_level_15
Weight?,Feature,Unnamed: 2_level_16,Unnamed: 3_level_16,Unnamed: 4_level_16,Unnamed: 5_level_16,Unnamed: 6_level_16,Unnamed: 7_level_16,Unnamed: 8_level_16,Unnamed: 9_level_16,Unnamed: 10_level_16,Unnamed: 11_level_16,Unnamed: 12_level_16,Unnamed: 13_level_16,Unnamed: 14_level_16,Unnamed: 15_level_16,Unnamed: 16_level_16
+4.511,word.lower():last,,,,,,,,,,,,,,,
+4.185,word.lower():jewish,,,,,,,,,,,,,,,
+3.885,word.lower():trade,,,,,,,,,,,,,,,
+3.865,bias,,,,,,,,,,,,,,,
+3.695,word.lower():hurricane,,,,,,,,,,,,,,,
+3.620,EOS,,,,,,,,,,,,,,,
+3.607,word.lower():month,,,,,,,,,,,,,,,
+3.511,BOS,,,,,,,,,,,,,,,
+3.402,word.lower():christian,,,,,,,,,,,,,,,
+3.338,word.lower():year,,,,,,,,,,,,,,,

Weight?,Feature
+4.511,word.lower():last
+4.185,word.lower():jewish
+3.885,word.lower():trade
+3.865,bias
+3.695,word.lower():hurricane
+3.620,EOS
+3.607,word.lower():month
+3.511,BOS
+3.402,word.lower():christian
+3.338,word.lower():year

Weight?,Feature
1.493,word.lower():english
0.817,word[-3:]:ish
0.699,postag:NNP
0.591,-1:postag[:2]:``
0.591,-1:postag:``
0.547,-1:postag:NN
0.449,postag[:2]:NN
0.429,"-1:word.lower():"""
0.357,-1:postag[:2]:DT
0.357,-1:postag:DT

Weight?,Feature
0.327,-1:word.istitle()
0.065,+1:word.istitle()
0.043,+1:word.lower():.
0.039,+1:postag:.
0.039,+1:postag[:2]:.
0.017,word.istitle()
0.015,+1:postag[:2]:NN
0.001,postag:NNP
-0.144,bias
-0.361,+1:postag[:2]:VB

Weight?,Feature
3.49,-1:word.lower():war
1.439,word.lower():ii
1.439,word[-3:]:II
1.432,word[-2:]:II
0.824,+1:word.lower():open
0.764,word.isupper()
0.519,+1:word.lower():war
0.503,word.lower():world
0.503,word[-3:]:rld
0.392,word.istitle()

Weight?,Feature
1.217,word.lower():games
1.03,word.lower():open
1.028,word[-3:]:pen
0.953,-1:word.istitle()
0.639,-1:word.lower():world
0.638,-1:word.lower():war
0.551,word[-3:]:War
0.548,word.isupper()
0.543,word.lower():war
0.415,word[-3:]:mes

Weight?,Feature
+3.264,-1:word.lower():mr.
+3.049,word.lower():beijing
+2.565,word.lower():israel
+2.444,word.lower():iran
+2.356,word.lower():britain
+2.150,word.lower():ukraine
+2.095,word.lower():washington
+2.024,word.lower():caribbean
+1.980,word.lower():republic
+1.842,word.lower():u.s.

Weight?,Feature
+3.004,-1:word.lower():san
+2.390,word.lower():airport
+2.211,word.lower():republic
+1.842,-1:word.lower():gulf
+1.761,word[-3:]:ast
+1.669,word.lower():city
+1.518,-1:word.lower():middle
+1.514,-1:word.lower():of
+1.500,-1:word.lower():new
+1.436,word.lower():island

Weight?,Feature
+4.621,word.lower():iraqi
+3.786,word.istitle()
+3.648,word.lower():niger
+3.197,word.lower():nepal
+3.141,word[-3:]:pal
+3.092,word.lower():afghan
+2.713,word.lower():jordan
+2.612,postag:NNS
+2.486,word[-3:]:ger
+2.318,word.lower():poland

Weight?,Feature
2.08,-1:postag:NNP
2.011,-1:word.lower():bosnian
1.85,word.istitle()
1.299,word.lower():cypriots
0.635,postag[:2]:JJ
0.557,postag:JJ
0.357,word.lower():cypriot
0.352,word[-3:]:iot
0.309,word[-2:]:ot
0.254,postag:NNS

Weight?,Feature
+5.364,word.lower():katrina
+1.794,word.lower():marburg
+1.683,word[-2:]:N1
+1.586,word.isupper()
+1.579,word.lower():h5n1
+1.579,word[-3:]:5N1
+1.412,word.lower():rita
+1.371,word[-3:]:ita
+1.286,word[-3:]:urg
+1.249,word[-2:]:rg

Weight?,Feature
0.84,-1:postag[:2]:NN
0.558,-1:word.lower():hurricane
0.557,word.lower():katrina
0.534,word[-2:]:na
0.347,word[-3:]:ina
0.336,-1:word.istitle()
0.003,-1:postag:NNP
-0.043,bias

Weight?,Feature
+4.488,word.lower():al-qaida
+4.212,word.lower():philippine
+4.065,word.lower():hamas
+3.156,-1:word.lower():niger
+3.101,word.lower():congress
+2.982,-1:word.lower():mr.
+2.941,word.lower():xinhua
+2.777,word[-3:]:The
+2.685,word[-3:]:ban
+2.624,-1:word.lower():senator

Weight?,Feature
+2.156,word.lower():ministry
+1.801,-1:word.lower():european
+1.664,word.lower():court
+1.616,+1:word.lower():post
+1.598,-1:word.lower():for
+1.582,word[-3:]:for
+1.527,word.lower():department
+1.487,-1:word.lower():u.s.
+1.471,word.lower():bank
+1.414,-1:word.lower():group

Weight?,Feature
+4.979,word.lower():prime
+4.019,word.lower():president
+3.311,word.lower():western
+2.798,BOS
+2.666,word.lower():senator
+2.437,word.lower():obama
+2.397,word[-2:]:r.
+2.381,word.lower():vice
+2.273,word[-2:]:s.
+2.227,+1:word.lower():administration

Weight?,Feature
+1.441,word[-2:]:ez
+1.404,-1:postag:NNP
+1.331,-1:postag:NN
+1.300,word.lower():rice
+1.187,+1:word.lower():of
+1.088,-1:postag[:2]:NN
+1.061,word.lower():annan
+1.020,-1:word.lower():minister
+0.998,+1:word.lower():reports
+0.940,-1:word.lower():condoleezza

Weight?,Feature
+6.247,word[-3:]:day
+4.420,-1:word.lower():week
+3.479,word[-2:]:0s
+3.310,word.lower():february
+3.291,+1:word.lower():week
+3.125,-1:word.lower():months
+2.819,+1:word.lower():year
+2.790,word[-3:]:Day
+2.779,word.lower():january
+2.760,-1:word.lower():month

Weight?,Feature
+4.638,word[-3:]:day
+2.327,word[-2:]:ay
+2.180,word.lower():decades
+2.169,word[-2:]:m.
+2.169,word[-3:]:.m.
+1.697,word[-3:]:ber
+1.596,+1:word.lower():months
+1.500,word.isdigit()
+1.490,+1:word.lower():years
+1.390,word[-2:]:ry
