In [10]:
import os
import numpy as np
from collections import Counter
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import LinearSVC

from generator import Generator
from corpus import ConllCorpusReaderX

import warnings
warnings.filterwarnings('ignore')

EC_PATH = "./bsnlp_ec.npz"
TRUMP_PATH = "./bsnlp_trump.npz"
TRAINSET_PATH = "./factrueval_trainset.npz"

In [11]:
trump_dataset = ConllCorpusReaderX('./bsnlp_dataset/', 
                              fileids='trump.txt', 
                              columntypes=('words', 'ne'))

eu_dataset = ConllCorpusReaderX('./bsnlp_dataset/', 
                              fileids='ec.txt', 
                              columntypes=('words', 'ne'))

factrueval_devset = ConllCorpusReaderX('../FactRuEval/factrueval2016_dataset/', 
                                        fileids='devset.txt', 
                                        columntypes=['words', 'offset', 'len', 'ne'])

In [12]:
gen = Generator(column_types=['WORD'], context_len=2)

Y_train = [el[1] for el in factrueval_devset.get_ne()]

Y_test_eu = [el[1] for el in eu_dataset.get_ne()]
Y_test_trump = [el[1] for el in trump_dataset.get_ne()] 

X_train = gen.fit_transform([[el] for el in factrueval_devset.words()], 
                            Y_train, 
                            path=TRAINSET_PATH)
X_test_eu = gen.transform([[el] for el in eu_dataset.words()], 
                       path=EC_PATH)
X_test_trump = gen.transform([[el] for el in trump_dataset.words()], 
                       path=TRUMP_PATH)

In [13]:
# Избавляет данные от случаев O : O #
def clean(Y_pred, Y_test):
    Y_pred = np.array(Y_pred)
    Y_test = np.array(Y_test)

    Y_pred_i = np.array([Y_pred != 'O'])
    Y_test_i = np.array([Y_test != 'O'])

    indexes = (Y_pred_i | Y_test_i).reshape(Y_pred.shape)

    Y_pred_fixed = Y_pred[indexes]
    Y_test_fixed = Y_test[indexes]
    return Y_pred_fixed, Y_test_fixed

def replace_y(word):
    if word[2:] == "Person":
        return word[:2] + "PER"
    elif word[2:] == "Location":
        return word[:2] + "LOC"
    elif word[2:] == "LocOrg":
        return word[:2] + "LOC"
    elif word[2:] == "Org":
        return word[:2] + "ORG"
    elif word != "O":
        return word[:2] + "MISC"
    else:
        return word

In [19]:
def run_baseline(clf, X_test, Y_test):
    clf.fit(X_train, Y_train)
    Y_pred = clf.predict(X_test)
    Y_pred = [replace_y(el) for el in Y_pred]
    Y_pred_c, Y_test_c = clean(Y_pred, Y_test)

    def get_el(el):
        if el == "O":
            return el
        else:
            return el[2:]

    Y_pred_c_light = [get_el(el) for el in Y_pred_c]
    Y_test_c_light = [get_el(el) for el in Y_test_c]

    # Strict evaluation #

    print("")
    print("# Strict evaluation #")
    counter = Counter(Y_test_c)
    labels = list(counter.keys())
    labels.remove("O")
    results = f1_score(Y_test_c, Y_pred_c, average=None, labels=labels)
    for a, b in zip(labels, results):
        print('F1 for {} == {}, with {} entities'.format(a, b, counter[a]))

    print("Weighted Score:", f1_score(Y_test_c, Y_pred_c, average="weighted", labels=list(counter.keys())))    

    # Not strict evaluation #    

    print("")
    print("# Not strict evaluation #")    
    light_counter = Counter(Y_test_c_light)
    light_labels = list(light_counter.keys())
    light_labels.remove("O")
    print(light_counter)
    light_results = f1_score(Y_test_c_light, Y_pred_c_light, average=None, labels=light_labels)
    for a, b in zip(light_labels, light_results):
        print('F1 for {} == {}, with {} entities'.format(a, b, light_counter[a]))

    print("Weighted Score:", f1_score(Y_test_c_light, Y_pred_c_light, average="weighted", labels=light_labels))

In [20]:
run_baseline(LogisticRegression(), X_test_eu, Y_test_eu)


# Strict evaluation #
F1 for S-ORG == 0.5010989010989011, with 296 entities
F1 for B-ORG == 0.6921348314606741, with 215 entities
F1 for E-ORG == 0.056074766355140186, with 196 entities
F1 for B-MISC == 0.0, with 42 entities
F1 for E-MISC == 0.0, with 32 entities
F1 for S-LOC == 0.48066298342541436, with 148 entities
F1 for B-PER == 0.4878048780487805, with 25 entities
F1 for E-PER == 0.4, with 23 entities
F1 for S-PER == 0.033898305084745756, with 20 entities
F1 for I-ORG == 0.12790697674418605, with 93 entities
F1 for I-PER == 0.0, with 2 entities
F1 for S-MISC == 0.0, with 62 entities
F1 for I-MISC == 0.0, with 22 entities
F1 for B-LOC == 0.6666666666666667, with 13 entities
F1 for E-LOC == 0.3157894736842105, with 13 entities
F1 for I-LOC == 0.0, with 1 entities
Weighted Score: 0.297280813647

# Not strict evaluation #
Counter({'ORG': 800, 'O': 230, 'LOC': 175, 'MISC': 158, 'PER': 70})
F1 for ORG == 0.6529492455418381, with 800 entities
F1 for MISC == 0.0, with 158 entities
F1 for

In [21]:
run_baseline(RandomForestClassifier(), X_test_eu, Y_test_eu)


# Strict evaluation #
F1 for S-ORG == 0.37788018433179726, with 296 entities
F1 for B-ORG == 0.2549019607843137, with 215 entities
F1 for E-ORG == 0.08620689655172414, with 196 entities
F1 for B-MISC == 0.0, with 42 entities
F1 for E-MISC == 0.0, with 32 entities
F1 for S-LOC == 0.4143222506393861, with 148 entities
F1 for B-PER == 0.28571428571428575, with 25 entities
F1 for E-PER == 0.3114754098360656, with 23 entities
F1 for S-PER == 0.0, with 20 entities
F1 for I-ORG == 0.1717171717171717, with 93 entities
F1 for I-PER == 0.0, with 2 entities
F1 for S-MISC == 0.0, with 62 entities
F1 for I-MISC == 0.0, with 22 entities
F1 for B-LOC == 0.11764705882352941, with 13 entities
F1 for E-LOC == 0.07999999999999999, with 13 entities
F1 for I-LOC == 0.0, with 1 entities
Weighted Score: 0.195164161105

# Not strict evaluation #
Counter({'ORG': 800, 'O': 220, 'LOC': 175, 'MISC': 158, 'PER': 70})
F1 for ORG == 0.3606837606837607, with 800 entities
F1 for MISC == 0.0, with 158 entities
F1 for 

In [22]:
run_baseline(LinearSVC(), X_test_eu, Y_test_eu)


# Strict evaluation #
F1 for S-ORG == 0.4646017699115045, with 296 entities
F1 for B-ORG == 0.6369168356997972, with 215 entities
F1 for E-ORG == 0.10699588477366256, with 196 entities
F1 for B-MISC == 0.0, with 42 entities
F1 for E-MISC == 0.0, with 32 entities
F1 for S-LOC == 0.5391849529780564, with 148 entities
F1 for B-PER == 0.4938271604938272, with 25 entities
F1 for E-PER == 0.5555555555555556, with 23 entities
F1 for S-PER == 0.05405405405405406, with 20 entities
F1 for I-ORG == 0.11057692307692309, with 93 entities
F1 for I-PER == 0.0, with 2 entities
F1 for S-MISC == 0.0, with 62 entities
F1 for I-MISC == 0.0, with 22 entities
F1 for B-LOC == 0.6111111111111112, with 13 entities
F1 for E-LOC == 0.16666666666666669, with 13 entities
F1 for I-LOC == 0.0, with 1 entities
Weighted Score: 0.279542155636

# Not strict evaluation #
Counter({'ORG': 800, 'O': 306, 'LOC': 175, 'MISC': 158, 'PER': 70})
F1 for ORG == 0.6683291770573566, with 800 entities
F1 for MISC == 0.0, with 158 en

In [23]:
run_baseline(GradientBoostingClassifier(), X_test_eu, Y_test_eu)


# Strict evaluation #
F1 for S-ORG == 0.4622641509433962, with 296 entities
F1 for B-ORG == 0.516728624535316, with 215 entities
F1 for E-ORG == 0.09401709401709402, with 196 entities
F1 for B-MISC == 0.0, with 42 entities
F1 for E-MISC == 0.0, with 32 entities
F1 for S-LOC == 0.5157232704402517, with 148 entities
F1 for B-PER == 0.4878048780487805, with 25 entities
F1 for E-PER == 0.3789473684210527, with 23 entities
F1 for S-PER == 0.0, with 20 entities
F1 for I-ORG == 0.0877742946708464, with 93 entities
F1 for I-PER == 0.0, with 2 entities
F1 for S-MISC == 0.0, with 62 entities
F1 for I-MISC == 0.0, with 22 entities
F1 for B-LOC == 0.08, with 13 entities
F1 for E-LOC == 0.14814814814814817, with 13 entities
F1 for I-LOC == 0.0, with 1 entities
Weighted Score: 0.253189921731

# Not strict evaluation #
Counter({'ORG': 800, 'O': 277, 'LOC': 175, 'MISC': 158, 'PER': 70})
F1 for ORG == 0.6653465346534653, with 800 entities
F1 for MISC == 0.0, with 158 entities
F1 for LOC == 0.490765171

In [24]:
run_baseline(LogisticRegression(), X_test_trump, Y_test_trump)


# Strict evaluation #
F1 for S-PER == 0.5074626865671642, with 144 entities
F1 for B-PER == 0.8940397350993378, with 140 entities
F1 for E-PER == 0.8214285714285714, with 119 entities
F1 for S-LOC == 0.8405797101449275, with 126 entities
F1 for B-LOC == 0.7317073170731707, with 22 entities
F1 for E-LOC == 0.2285714285714286, with 22 entities
F1 for B-ORG == 0.40740740740740744, with 31 entities
F1 for E-ORG == 0.16666666666666666, with 31 entities
F1 for S-MISC == 0.0, with 3 entities
F1 for I-ORG == 0.23809523809523808, with 23 entities
F1 for S-ORG == 0.37500000000000006, with 17 entities
F1 for B-MISC == 0.0, with 13 entities
F1 for I-MISC == 0.0, with 10 entities
F1 for E-MISC == 0.0, with 13 entities
F1 for I-PER == 0.0, with 3 entities
F1 for I-LOC == 0.125, with 8 entities
Weighted Score: 0.589190717447

# Not strict evaluation #
Counter({'PER': 406, 'LOC': 178, 'ORG': 102, 'O': 45, 'MISC': 39})
F1 for PER == 0.8835443037974684, with 406 entities
F1 for LOC == 0.820652173913043

In [25]:
run_baseline(RandomForestClassifier(), X_test_trump, Y_test_trump)


# Strict evaluation #
F1 for S-PER == 0.2558139534883721, with 144 entities
F1 for B-PER == 0.847682119205298, with 140 entities
F1 for E-PER == 0.8115942028985506, with 119 entities
F1 for S-LOC == 0.7236842105263157, with 126 entities
F1 for B-LOC == 0.47368421052631576, with 22 entities
F1 for E-LOC == 0.07692307692307693, with 22 entities
F1 for B-ORG == 0.1702127659574468, with 31 entities
F1 for E-ORG == 0.2127659574468085, with 31 entities
F1 for S-MISC == 0.0, with 3 entities
F1 for I-ORG == 0.09999999999999999, with 23 entities
F1 for S-ORG == 0.07999999999999999, with 17 entities
F1 for B-MISC == 0.0, with 13 entities
F1 for I-MISC == 0.0, with 10 entities
F1 for E-MISC == 0.0, with 13 entities
F1 for I-PER == 0.0, with 3 entities
F1 for I-LOC == 0.0, with 8 entities
Weighted Score: 0.485500428549

# Not strict evaluation #
Counter({'PER': 406, 'LOC': 178, 'ORG': 102, 'O': 39, 'MISC': 39})
F1 for PER == 0.823841059602649, with 406 entities
F1 for LOC == 0.7007672634271099, w

In [26]:
run_baseline(LinearSVC(), X_test_trump, Y_test_trump)


# Strict evaluation #
F1 for S-PER == 0.44776119402985076, with 144 entities
F1 for B-PER == 0.8859060402684563, with 140 entities
F1 for E-PER == 0.8172043010752688, with 119 entities
F1 for S-LOC == 0.8351648351648353, with 126 entities
F1 for B-LOC == 0.7058823529411765, with 22 entities
F1 for E-LOC == 0.3157894736842105, with 22 entities
F1 for B-ORG == 0.33333333333333337, with 31 entities
F1 for E-ORG == 0.19047619047619047, with 31 entities
F1 for S-MISC == 0.0, with 3 entities
F1 for I-ORG == 0.1333333333333333, with 23 entities
F1 for S-ORG == 0.25, with 17 entities
F1 for B-MISC == 0.0, with 13 entities
F1 for I-MISC == 0.0, with 10 entities
F1 for E-MISC == 0.0, with 13 entities
F1 for I-PER == 0.0, with 3 entities
F1 for I-LOC == 0.0, with 8 entities
Weighted Score: 0.55388349902

# Not strict evaluation #
Counter({'PER': 406, 'LOC': 178, 'ORG': 102, 'O': 64, 'MISC': 39})
F1 for PER == 0.8600508905852418, with 406 entities
F1 for LOC == 0.8052631578947369, with 178 entiti

In [27]:
run_baseline(GradientBoostingClassifier(), X_test_trump, Y_test_trump)


# Strict evaluation #
F1 for S-PER == 0.42105263157894735, with 144 entities
F1 for B-PER == 0.8636363636363635, with 140 entities
F1 for E-PER == 0.786206896551724, with 119 entities
F1 for S-LOC == 0.7986111111111112, with 126 entities
F1 for B-LOC == 0.6956521739130435, with 22 entities
F1 for E-LOC == 0.23529411764705885, with 22 entities
F1 for B-ORG == 0.2857142857142857, with 31 entities
F1 for E-ORG == 0.06060606060606061, with 31 entities
F1 for S-MISC == 0.0, with 3 entities
F1 for I-ORG == 0.05405405405405405, with 23 entities
F1 for S-ORG == 0.21428571428571427, with 17 entities
F1 for B-MISC == 0.0, with 13 entities
F1 for I-MISC == 0.0, with 10 entities
F1 for E-MISC == 0.0, with 13 entities
F1 for I-PER == 0.0, with 3 entities
F1 for I-LOC == 0.0, with 8 entities
Weighted Score: 0.530022059308

# Not strict evaluation #
Counter({'PER': 406, 'LOC': 178, 'ORG': 102, 'O': 52, 'MISC': 39})
F1 for PER == 0.8406524466750314, with 406 entities
F1 for LOC == 0.7823834196891191,