In [17]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set(font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import classification_report, make_scorer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
import scipy.stats
# import eli5

In [18]:
from itertools import chain

import nltk
import sklearn
import scipy.stats

import sklearn_crfsuite
from sklearn_crfsuite import scorers,CRF
from sklearn_crfsuite.metrics import flat_classification_report, flat_f1_score
from sklearn_crfsuite import metrics
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

In [19]:
data = pd.read_pickle('data_short.pickle')

In [20]:
text, tags = list(data['text']), list(data['tag'])

In [21]:
text_new, tags_new = [], [] # сплит по точкам с запятой и дроп О
for words, tag_list in zip(text, tags):
    last_separator_idx = 0
    for i in range(len(words)):
        if words[i] == ';' and any(['i' in tag for tag in tag_list[last_separator_idx: i]]):
            text_new.append(words[last_separator_idx: i])
            tags_new.append(tag_list[last_separator_idx: i])
            last_separator_idx = i
        if i + 1 == len(words):
            text_new.append(words[last_separator_idx: i])
            tags_new.append(tag_list[last_separator_idx: i]) 
len(text_new)

12334

In [22]:
# Feature set
def word2features(sent,i):
    word = text_new[sent][i]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],

        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),

    }
    if i > 0:
        word1 = text_new[sent][i-1]

        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
        })
    else:
        features['BOS'] = True

    if i < len(text_new[sent])-1:
        word1 = text_new[sent][i+1]

        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
        })
    else:
        features['EOS'] = True

    return features

In [23]:
def sent2features(sent):
    return [word2features(sent, i) for i in range(len(text_new[sent]))]

def sent2labels(sent):
    return [label for label in tags_new[sent]]

In [24]:
#Creating the train and test set
X = [sent2features(s) for s in range(len(text_new))]
y = [sent2labels(s) for s in range(len(tags_new))]

In [25]:
len(X)

12334

In [26]:
crf = CRF(algorithm='lbfgs',
          c1=0.1,
          c2=0.1,
          max_iterations=100,
          all_possible_transitions=False)

In [27]:
X_filtered, y_filtered = [], [] 
for x_, y_ in zip(X, y):
    if len(x_) == len(y_):
        X_filtered.append(x_)
        y_filtered.append(y_)
len(X_filtered), len(y_filtered)

(12334, 12334)

In [28]:
%%time
try:
    crf.fit(X_filtered[:5000], y_filtered[:5000])
except AttributeError:
    pass

CPU times: total: 42.1 s
Wall time: 43 s


In [29]:
def get_score(y_true, y_pred, y_all):
    joined_true = []
    joined_pred = []
    for sent_true, sent_pred in zip(y_true, y_pred):
        joined_true.extend(sent_true)
        joined_pred.extend(sent_pred)
    
    true_transformed, pred_transformed = label_encode(joined_true, joined_pred, y_all)
    print(classification_report(true_transformed, pred_transformed, target_names=dict_tegs))

In [30]:
def label_encode(y_true, y_pred, y_all):
    y_all_flat = []
    for sent in y_all:
        y_all_flat.extend(sent)
    encoder = LabelEncoder().fit(y_all_flat)
    return encoder.transform(y_true), encoder.transform(y_pred)

In [31]:
import pickle 

In [32]:
with open('dict_tegs_short.pkl', 'rb') as dict_tegs:
    dict_tegs = pickle.load(dict_tegs)

In [33]:
dict_tegs

{'o': 1,
 'i-e-treb': 2,
 'i-e-v': 3,
 'i-e-org': 4,
 'i-e-prod': 5,
 'i-e-reg': 6,
 'i-e-obj': 7,
 'i-e-okved': 8,
 'i-e-block': 9}

In [34]:
get_score(y_filtered[5000:], crf.predict(X_filtered[5000:]), y_filtered)

              precision    recall  f1-score   support

           o       0.76      0.83      0.80    254582
    i-e-treb       0.00      0.00      0.00       794
       i-e-v       0.01      0.00      0.00       977
     i-e-org       0.43      0.07      0.12      1247
    i-e-prod       0.41      0.22      0.28     21679
     i-e-reg       0.39      0.11      0.17      4882
     i-e-obj       0.44      0.12      0.19      3513
   i-e-okved       0.59      0.68      0.63     83143
   i-e-block       0.46      0.25      0.32     39320

    accuracy                           0.69    410137
   macro avg       0.39      0.25      0.28    410137
weighted avg       0.67      0.69      0.67    410137



In [33]:
get_score(y_filtered[5000:], crf.predict(X_filtered[5000:]), y_filtered)

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       492
           1       0.00      0.00      0.00       270
           2       0.05      0.27      0.08        11
           3       0.00      0.00      0.00         5
           5       0.25      0.15      0.19        13
           6       0.00      0.00      0.00       538
           8       0.30      0.08      0.12       283
           9       0.64      0.19      0.30       834
          10       0.52      0.11      0.18       589
          11       0.33      0.01      0.02       472
          12       0.00      0.00      0.00        49
          13       0.00      0.00      0.00      1394
          14       0.58      0.21      0.31      6340
          15       0.93      0.23      0.36       363
          16       0.32      0.18      0.23       581
          17       0.02      0.01      0.01       677
          18       0.00      0.00      0.00       515
          19       0.35    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
