In [22]:
import os
import numpy as np
from collections import Counter
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import LinearSVC

from generator import Generator
from corpus import ConllCorpusReaderX

import warnings
warnings.filterwarnings('ignore')

TRAINSET_PATH = "./conll_trainset.npz"
TESTSETA_PATH = "./conll_testseta.npz"
TESTSETB_PATH = "./conll_testsetb.npz"

In [23]:
conll_trainset = ConllCorpusReaderX('./conll2003_dataset', 
                              fileids='eng.train.txt', 
                              columntypes=('words', 'pos', 'chunk', 'ne'))

conll_testseta = ConllCorpusReaderX('./conll2003_dataset', 
                              fileids='eng.testa.dev.txt', 
                              columntypes=('words', 'pos', 'chunk', 'ne'))

conll_testsetb = ConllCorpusReaderX('./conll2003_dataset', 
                              fileids='eng.testb.test.txt', 
                              columntypes=('words', 'pos', 'chunk', 'ne'))

In [28]:
gen = Generator(column_types=['WORD', 'POS', 'CHUNK'], context_len=2, language='en')

Y_train = [el[1] for el in conll_trainset.get_ne()]
Y_testa = [el[1] for el in conll_testseta.get_ne()] 
Y_testb = [el[1] for el in conll_testsetb.get_ne()] 

X_train = gen.fit_transform(conll_trainset.get_tags(tags=['words', 'pos', 'chunk']), 
                            Y_train, 
                            path=TRAINSET_PATH)
X_testa = gen.transform(conll_testseta.get_tags(tags=['words', 'pos', 'chunk']), 
                        path=TESTSETA_PATH)
X_testb = gen.transform(conll_testsetb.get_tags(tags=['words', 'pos', 'chunk']), 
                        path=TESTSETB_PATH)

In [29]:
# Избавляет данные от случаев O : O #
def clean(Y_pred, Y_test):
    Y_pred = np.array(Y_pred)
    Y_test = np.array(Y_test)

    Y_pred_i = np.array([Y_pred != 'O'])
    Y_test_i = np.array([Y_test != 'O'])

    indexes = (Y_pred_i | Y_test_i).reshape(Y_pred.shape)

    Y_pred_fixed = Y_pred[indexes]
    Y_test_fixed = Y_test[indexes]
    return Y_pred_fixed, Y_test_fixed

In [31]:
def run_baseline(clf=LogisticRegression()):
    clf.fit(X_train, Y_train)
    Y_pred = clf.predict(X_testa)
    Y_pred_c, Y_test_c = clean(Y_pred, Y_testa)

    def get_el(el):
        if el == "O":
            return el
        else:
            return el[2:]

    Y_pred_c_light = [get_el(el) for el in Y_pred_c]
    Y_test_c_light = [get_el(el) for el in Y_test_c]

    # Strict evaluation #

    print("")
    print("# Strict evaluation #")
    counter = Counter(Y_test_c)
    labels = list(counter.keys())
    labels.remove("O")
    results = f1_score(Y_test_c, Y_pred_c, average=None, labels=labels)
    for a, b in zip(labels, results):
        print('F1 for {} == {}, with {} entities'.format(a, b, counter[a]))

    print("Weighted Score:", f1_score(Y_test_c, Y_pred_c, average="weighted", labels=list(counter.keys())))    

    # Not strict evaluation #    

    print("")
    print("# Not strict evaluation #")    
    light_counter = Counter(Y_test_c_light)
    light_labels = list(light_counter.keys())
    light_labels.remove("O")
    print(light_counter)
    light_results = f1_score(Y_test_c_light, Y_pred_c_light, average=None, labels=light_labels)
    for a, b in zip(light_labels, light_results):
        print('F1 for {} == {}, with {} entities'.format(a, b, light_counter[a]))

    print("Weighted Score:", f1_score(Y_test_c_light, Y_pred_c_light, average="weighted", labels=light_labels))

In [32]:
run_baseline()


# Strict evaluation #
F1 for S-ORG == 0.7553699284009546, with 891 entities
F1 for S-LOC == 0.8162111215834119, with 1603 entities
F1 for B-MISC == 0.6466512702078522, with 257 entities
F1 for E-MISC == 0.7616926503340757, with 257 entities
F1 for B-PER == 0.920610089949159, with 1234 entities
F1 for E-PER == 0.9185820023373589, with 1234 entities
F1 for B-LOC == 0.7916666666666666, with 234 entities
F1 for E-LOC == 0.7954545454545455, with 234 entities
F1 for S-PER == 0.6935749588138386, with 608 entities
F1 for S-MISC == 0.8160000000000001, with 665 entities
F1 for I-MISC == 0.421875, with 89 entities
F1 for B-ORG == 0.7111111111111111, with 450 entities
F1 for I-ORG == 0.6833631484794276, with 301 entities
F1 for E-ORG == 0.7298245614035088, with 450 entities
F1 for I-PER == 0.5391304347826087, with 73 entities
F1 for I-LOC == 0.4666666666666667, with 23 entities
Weighted Score: 0.78264408139

# Not strict evaluation #
Counter({'PER': 3149, 'LOC': 2094, 'ORG': 2092, 'MISC': 1268, '

In [33]:
run_baseline(RandomForestClassifier())


# Strict evaluation #
F1 for S-ORG == 0.7048883524441764, with 891 entities
F1 for S-LOC == 0.7398039835599115, with 1603 entities
F1 for B-MISC == 0.6173913043478261, with 257 entities
F1 for E-MISC == 0.6129753914988815, with 257 entities
F1 for B-PER == 0.879905808477237, with 1234 entities
F1 for E-PER == 0.8921259842519684, with 1234 entities
F1 for B-LOC == 0.6375, with 234 entities
F1 for E-LOC == 0.6393088552915767, with 234 entities
F1 for S-PER == 0.5898366606170599, with 608 entities
F1 for S-MISC == 0.7011007620660458, with 665 entities
F1 for I-MISC == 0.42857142857142855, with 89 entities
F1 for B-ORG == 0.6206896551724138, with 450 entities
F1 for I-ORG == 0.6392523364485981, with 301 entities
F1 for E-ORG == 0.6022727272727273, with 450 entities
F1 for I-PER == 0.5043478260869564, with 73 entities
F1 for I-LOC == 0.3684210526315789, with 23 entities
Weighted Score: 0.705466701264

# Not strict evaluation #
Counter({'PER': 3149, 'LOC': 2094, 'ORG': 2092, 'MISC': 1268, '

In [34]:
run_baseline(LinearSVC())


# Strict evaluation #
F1 for S-ORG == 0.7624113475177304, with 891 entities
F1 for S-LOC == 0.8194001276324185, with 1603 entities
F1 for B-MISC == 0.6741071428571429, with 257 entities
F1 for E-MISC == 0.791578947368421, with 257 entities
F1 for B-PER == 0.9275590551181103, with 1234 entities
F1 for E-PER == 0.9221698113207548, with 1234 entities
F1 for B-LOC == 0.8044943820224718, with 234 entities
F1 for E-LOC == 0.8009049773755655, with 234 entities
F1 for S-PER == 0.6816720257234726, with 608 entities
F1 for S-MISC == 0.81629392971246, with 665 entities
F1 for I-MISC == 0.481203007518797, with 89 entities
F1 for B-ORG == 0.735527809307605, with 450 entities
F1 for I-ORG == 0.7063063063063063, with 301 entities
F1 for E-ORG == 0.7417519908987485, with 450 entities
F1 for I-PER == 0.6507936507936508, with 73 entities
F1 for I-LOC == 0.5555555555555555, with 23 entities
Weighted Score: 0.791084163186

# Not strict evaluation #
Counter({'PER': 3149, 'LOC': 2094, 'ORG': 2092, 'MISC': 

In [35]:
run_baseline(GradientBoostingClassifier())


# Strict evaluation #
F1 for S-ORG == 0.6918325326012353, with 891 entities
F1 for S-LOC == 0.7091475502690456, with 1603 entities
F1 for B-MISC == 0.6338028169014085, with 257 entities
F1 for E-MISC == 0.7280898876404495, with 257 entities
F1 for B-PER == 0.870026525198939, with 1234 entities
F1 for E-PER == 0.8664122137404581, with 1234 entities
F1 for B-LOC == 0.7598039215686273, with 234 entities
F1 for E-LOC == 0.7707317073170732, with 234 entities
F1 for S-PER == 0.5297418630751964, with 608 entities
F1 for S-MISC == 0.763877381938691, with 665 entities
F1 for I-MISC == 0.3728813559322034, with 89 entities
F1 for B-ORG == 0.5977859778597786, with 450 entities
F1 for I-ORG == 0.615678776290631, with 301 entities
F1 for E-ORG == 0.612137203166227, with 450 entities
F1 for I-PER == 0.543859649122807, with 73 entities
F1 for I-LOC == 0.5142857142857143, with 23 entities
Weighted Score: 0.706862725912

# Not strict evaluation #
Counter({'PER': 3149, 'LOC': 2094, 'ORG': 2092, 'MISC': 