In [38]:
import os
import numpy as np
from collections import Counter
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import LinearSVC

from generator import Generator
from corpus import ConllCorpusReaderX

import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including any emitted by scikit-learn while fitting or scoring.
warnings.filterwarnings('ignore')

# Paths handed to the feature generator through its `path=` argument
# (presumably where the generated feature matrices are cached — confirm in Generator).
TRAINSET_PATH = "./factrueval_trainset.npz"
TESTSET_PATH = "./factrueval_testset.npz"

In [39]:
# Both splits are read from the same directory with the same column layout;
# only the file name differs. The dev split is constructed first, then the test split.
factrueval_devset, factrueval_testset = (
    ConllCorpusReaderX(
        './factrueval2016_dataset/',
        fileids=split_file,
        columntypes=['words', 'offset', 'len', 'ne'],
    )
    for split_file in ('devset.txt', 'testset.txt')
)

In [40]:
gen = Generator(column_types=['WORD'], context_len=2)

# Gold tags: the second field of every item returned by get_ne().
Y_train = [item[1] for item in factrueval_devset.get_ne()]
Y_test = [item[1] for item in factrueval_testset.get_ne()]

# Every token is wrapped in a one-element list before being handed to the generator
# (presumably one entry per column in `column_types` — confirm against Generator).
train_rows = [[token] for token in factrueval_devset.words()]
X_train = gen.fit_transform(train_rows, Y_train, path=TRAINSET_PATH)

# The test split only goes through transform(), reusing what was fitted above.
test_rows = [[token] for token in factrueval_testset.words()]
X_test = gen.transform(test_rows, path=TESTSET_PATH)

In [41]:
# Removes the O : O cases from the data #
def clean(Y_pred, Y_test):
    """Drop every position where both the predicted and the gold tag are ``'O'``.

    Args:
        Y_pred: sequence of predicted tags.
        Y_test: sequence of gold tags, aligned with ``Y_pred``.

    Returns:
        A ``(Y_pred_fixed, Y_test_fixed)`` pair of NumPy arrays holding only
        the positions where at least one of the two tags differs from ``'O'``.
        Both arrays are filtered with the same mask, so they stay aligned.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    Y_pred = np.asarray(Y_pred)
    Y_test = np.asarray(Y_test)

    # Fail early with a readable message instead of a broadcasting or
    # boolean-index error further down.
    if Y_pred.shape != Y_test.shape:
        raise ValueError(
            "Y_pred and Y_test must have the same shape, got {} and {}".format(
                Y_pred.shape, Y_test.shape))

    # A position is kept unless it is an O : O pair.
    keep = (Y_pred != 'O') | (Y_test != 'O')
    return Y_pred[keep], Y_test[keep]

In [42]:
def _strip_position_prefix(tag):
    """Return `tag` without its two-character positional prefix (e.g. ``B-``); ``"O"`` is returned as is."""
    return tag if tag == "O" else tag[2:]


def _print_f1_report(title, gold, predicted, show_counts=False):
    """Print per-label F1 and the support-weighted F1 over every gold label except ``"O"``."""
    print("")
    print(title)
    counter = Counter(gold)
    if show_counts:
        print(counter)

    # Filter instead of list.remove("O"): "O" is absent from the cleaned gold
    # tags whenever the classifier produced no false positives.
    labels = [label for label in counter if label != "O"]
    if not labels:
        print("No entity labels in the gold data; nothing to score.")
        return

    per_label = f1_score(gold, predicted, average=None, labels=labels)
    for label, score in zip(labels, per_label):
        print('F1 for {} == {}, with {} entities'.format(label, score, counter[label]))

    # The weighted score uses the same label set as the per-label listing.
    print("Weighted Score:", f1_score(gold, predicted, average="weighted", labels=labels))


def run_baseline(clf=None):
    """Fit `clf` on the training features and print strict and non-strict F1 reports.

    Reads the notebook-level ``X_train``, ``Y_train``, ``X_test`` and ``Y_test``,
    and removes O : O pairs with ``clean()`` before scoring.

    * Strict evaluation compares the full tags (positional prefix included).
    * Non-strict evaluation compares the tags with the two-character
      positional prefix stripped, so only the entity type has to match.

    In both reports the ``"O"`` label is excluded from the per-label listing
    and from the weighted score, so the two weighted scores are comparable.

    Args:
        clf: an estimator exposing ``fit`` and ``predict``. When omitted, a
            fresh ``LogisticRegression()`` is created for this call, so no
            fitted state is shared between calls.

    Returns:
        None. The results are printed.
    """
    if clf is None:
        clf = LogisticRegression()

    clf.fit(X_train, Y_train)
    Y_pred = clf.predict(X_test)
    Y_pred_c, Y_test_c = clean(Y_pred, Y_test)

    # Strict evaluation #
    _print_f1_report("# Strict evaluation #", Y_test_c, Y_pred_c)

    # Not strict evaluation #
    Y_pred_c_light = [_strip_position_prefix(tag) for tag in Y_pred_c]
    Y_test_c_light = [_strip_position_prefix(tag) for tag in Y_test_c]
    _print_f1_report("# Not strict evaluation #", Y_test_c_light, Y_pred_c_light,
                     show_counts=True)

In [43]:
# Baseline with the default classifier (logistic regression).
run_baseline()


# Strict evaluation #
F1 for B-Person == 0.8540332906530089, with 694 entities
F1 for E-Person == 0.8702983138780804, with 692 entities
F1 for S-Person == 0.5008665511265165, with 697 entities
F1 for S-Location == 0.632888888888889, with 554 entities
F1 for B-Location == 0.24844720496894407, with 114 entities
F1 for I-Location == 0.024390243902439025, with 74 entities
F1 for S-Org == 0.4945490584737363, with 1300 entities
F1 for B-Org == 0.30454545454545456, with 646 entities
F1 for I-Org == 0.29062768701633707, with 903 entities
F1 for E-Org == 0.34394904458598724, with 600 entities
F1 for S-LocOrg == 0.5724907063197027, with 666 entities
F1 for I-Person == 0.1, with 27 entities
F1 for E-Location == 0.17977528089887637, with 70 entities
F1 for B-LocOrg == 0.3733333333333333, with 49 entities
F1 for E-LocOrg == 0.0888888888888889, with 40 entities
F1 for I-LocOrg == 0.0, with 13 entities
F1 for B-Project == 0.0, with 16 entities
F1 for I-Project == 0.0, with 12 entities
F1 for S-Proje

In [44]:
# NOTE(review): no random_state is passed, so the forest is seeded differently
# on every run and the scores are not exactly reproducible — consider fixing a seed.
run_baseline(RandomForestClassifier())


# Strict evaluation #
F1 for B-Person == 0.8285895003162556, with 694 entities
F1 for E-Person == 0.8451069345430979, with 692 entities
F1 for S-Person == 0.3195767195767195, with 697 entities
F1 for S-Location == 0.554367201426025, with 554 entities
F1 for B-Location == 0.16867469879518074, with 114 entities
F1 for I-Location == 0.14432989690721648, with 74 entities
F1 for S-Org == 0.3782894736842105, with 1300 entities
F1 for B-Org == 0.229988726042841, with 646 entities
F1 for I-Org == 0.19826086956521738, with 903 entities
F1 for E-Org == 0.3171007927519819, with 600 entities
F1 for S-LocOrg == 0.48030303030303034, with 666 entities
F1 for I-Person == 0.0, with 27 entities
F1 for E-Location == 0.02222222222222222, with 70 entities
F1 for B-LocOrg == 0.282051282051282, with 49 entities
F1 for E-LocOrg == 0.16000000000000003, with 40 entities
F1 for I-LocOrg == 0.0625, with 13 entities
F1 for B-Project == 0.0, with 16 entities
F1 for I-Project == 0.0, with 12 entities
F1 for S-Proje

In [45]:
# Same evaluation with a linear SVM (default hyper-parameters).
run_baseline(LinearSVC())


# Strict evaluation #
F1 for B-Person == 0.86034255599473, with 694 entities
F1 for E-Person == 0.8844884488448846, with 692 entities
F1 for S-Person == 0.4868189806678383, with 697 entities
F1 for S-Location == 0.6248927038626609, with 554 entities
F1 for B-Location == 0.32044198895027626, with 114 entities
F1 for I-Location == 0.043010752688172046, with 74 entities
F1 for S-Org == 0.5182341650671786, with 1300 entities
F1 for B-Org == 0.3350895679662803, with 646 entities
F1 for I-Org == 0.33202819107282694, with 903 entities
F1 for E-Org == 0.3883720930232558, with 600 entities
F1 for S-LocOrg == 0.5779122541603631, with 666 entities
F1 for I-Person == 0.1694915254237288, with 27 entities
F1 for E-Location == 0.19130434782608693, with 70 entities
F1 for B-LocOrg == 0.4086021505376344, with 49 entities
F1 for E-LocOrg == 0.0816326530612245, with 40 entities
F1 for I-LocOrg == 0.1, with 13 entities
F1 for B-Project == 0.0, with 16 entities
F1 for I-Project == 0.0, with 12 entities
F1

In [46]:
# NOTE(review): no random_state is passed — results may differ slightly between
# runs; consider fixing a seed.
run_baseline(GradientBoostingClassifier())


# Strict evaluation #
F1 for B-Person == 0.821297429620563, with 694 entities
F1 for E-Person == 0.8152240638428484, with 692 entities
F1 for S-Person == 0.4165029469548134, with 697 entities
F1 for S-Location == 0.6153846153846154, with 554 entities
F1 for B-Location == 0.15, with 114 entities
F1 for I-Location == 0.013513513513513514, with 74 entities
F1 for S-Org == 0.5122910521140609, with 1300 entities
F1 for B-Org == 0.23676012461059187, with 646 entities
F1 for I-Org == 0.22335025380710657, with 903 entities
F1 for E-Org == 0.1918918918918919, with 600 entities
F1 for S-LocOrg == 0.5370496261046908, with 666 entities
F1 for I-Person == 0.03921568627450981, with 27 entities
F1 for E-Location == 0.10937500000000001, with 70 entities
F1 for B-LocOrg == 0.26666666666666666, with 49 entities
F1 for E-LocOrg == 0.07017543859649122, with 40 entities
F1 for I-LocOrg == 0.05405405405405406, with 13 entities
F1 for B-Project == 0.0, with 16 entities
F1 for I-Project == 0.0, with 12 entit