In [1]:
from itertools import chain

import numpy as np
import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer,classification_report
from sklearn.model_selection import cross_val_score, RandomizedSearchCV, train_test_split, cross_val_predict

import sklearn_crfsuite
from sklearn_crfsuite import scorers, CRF, metrics

In [2]:
import pandas as pd

data = pd.read_csv('final_dataset.csv', encoding='latin1', low_memory=False, dtype={'Sentence #': str, 'Word': str, 'POS': str, 'Tag': str})
filled_data = data.fillna(method='ffill')

class SentenceGetter(object):
    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, p, t) for w, p, t in zip(s["Word"].values.tolist(),
                                                           s["POS"].values.tolist(),
                                                            s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence #").apply(agg_func)
        self.sentences = [s for s in self.grouped]
  
    def get_next(self):     
        try:
             s = self.grouped["Sentence: {}".format(self.n_sent)]
             self.n_sent += 1
             return s
        except:
             return None       

getter = SentenceGetter(filled_data)
sentences = getter.sentences

In [3]:
%%time
# train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
# test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:postag[:2]': postag1[:2],
        })
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:postag[:2]': postag1[:2],
        })
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

# X = [sent2features(s) for s in train_sents]
# y = [sent2labels(s) for s in train_sents]

# X_test = [sent2features(s) for s in test_sents]
# y_test = [sent2labels(s) for s in test_sents]

X = [sent2features(s) for s in sentences]
y = [sent2labels(s) for s in sentences]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# print(X_test[7])

Wall time: 385 ms


In [4]:
%%time
crf = CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=False
)
crf.fit(X_train, y_train)

# pred = cross_val_predict(estimator=crf, X=X, y=y, cv=5, n_jobs=-1, verbose=1)
# report = classification_report(y_pred=pred, y_true=y)
# print(report)

labels = list(crf.classes_)
labels.remove('O')

Wall time: 13.5 s


In [5]:
print(labels)

['B-MAL', 'B-OS', 'B-DVEC', 'I-DVEC', 'B-BEH', 'B-CAP', 'I-CAP', 'I-BEH', 'I-MAL', 'I-OS']


In [6]:
%%time
# define fixed parameters and parameters to search
crf = CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=False
)
params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.5),
}

# use the same metric for evaluation
f1_scorer = make_scorer(metrics.flat_f1_score,
                        average='weighted', labels=labels)

# search
rs = RandomizedSearchCV(crf, params_space,
                        cv=3,
                        verbose=4,
                        n_jobs=-1,
                        n_iter=20,
                        scoring=f1_scorer)
rs.fit(X_train, y_train)

Fitting 3 folds for each of 20 candidates, totalling 60 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:  2.4min
[Parallel(n_jobs=-1)]: Done  60 out of  60 | elapsed:  8.9min finished


Wall time: 9min 8s


In [7]:
import pickle
filename = '1_crf_opti_model.sav'
pickle.dump(rs, open(filename, 'wb'))

In [8]:
print('best params:', rs.best_params_)
print('best cv score:', rs.best_score_)
print('model size {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))

best params: {'c1': 0.041906528896993624, 'c2': 0.01433007013268679}
best cv score: 0.8410489533317788
model size 0.25M


In [9]:
crf = rs.best_estimator_
y_pred = crf.predict(X_test)

sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)

print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

       B-BEH      0.922     0.808     0.861       146
       I-BEH      1.000     0.750     0.857        32
       B-CAP      0.982     0.900     0.939        60
       I-CAP      0.979     0.939     0.958        49
      B-DVEC      0.875     0.778     0.824        63
      I-DVEC      0.932     0.788     0.854        52
       B-MAL      0.909     0.964     0.936       249
       I-MAL      1.000     1.000     1.000         1
        B-OS      0.861     0.939     0.899        33
        I-OS      0.818     1.000     0.900         9

   micro avg      0.920     0.883     0.901       694
   macro avg      0.928     0.887     0.903       694
weighted avg      0.922     0.883     0.900       694



In [10]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f"%(label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:])


Top likely transitions:
B-DVEC -> I-DVEC  5.639431
I-BEH  -> I-BEH   5.086817
B-CAP  -> I-CAP   4.950763
B-OS   -> I-OS    4.451265
B-MAL  -> I-MAL   3.757583
B-BEH  -> I-BEH   3.743590
O      -> O       3.610088
I-OS   -> I-OS    3.521754
I-CAP  -> I-CAP   3.433305
I-DVEC -> I-DVEC  3.014713
I-MAL  -> I-MAL   2.414155
O      -> B-BEH   1.427862
O      -> B-MAL   1.196830
O      -> B-DVEC  1.113900
B-MAL  -> B-BEH   1.051131
O      -> B-CAP   0.852822
B-MAL  -> O       0.750119
B-CAP  -> B-MAL   0.359316
B-BEH  -> O       0.217070
I-CAP  -> O       0.187649

Top unlikely transitions:
B-BEH  -> O       0.217070
I-CAP  -> O       0.187649
O      -> B-OS    0.093559
B-CAP  -> B-CAP   0.027732
B-CAP  -> B-DVEC  -0.068609
B-OS   -> O       -0.279608
I-DVEC -> O       -0.340921
I-MAL  -> O       -0.684048
I-BEH  -> O       -0.747532
B-CAP  -> O       -0.815387
B-DVEC -> B-DVEC  -0.957932
B-MAL  -> B-OS    -1.123167
B-OS   -> B-DVEC  -1.225512
B-DVEC -> I-BEH   -1.505942
I-OS   -> O       -1.

In [11]:
import eli5

eli5.show_weights(crf, top=30)

From \ To,O,B-BEH,I-BEH,B-CAP,I-CAP,B-DVEC,I-DVEC,B-MAL,I-MAL,B-OS,I-OS
O,3.61,1.428,-6.163,0.853,-3.201,1.114,-2.498,1.197,0.0,0.094,0.0
B-BEH,0.217,0.0,3.744,0.0,0.0,-3.582,0.0,0.0,0.0,0.0,0.0
I-BEH,-0.748,0.0,5.087,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-CAP,-0.815,0.0,0.0,0.028,4.951,-0.069,0.0,0.359,0.0,0.0,0.0
I-CAP,0.188,0.0,0.0,0.0,3.433,0.0,0.0,0.0,0.0,0.0,0.0
B-DVEC,-2.595,0.0,-1.506,0.0,0.0,-0.958,5.639,0.0,0.0,0.0,0.0
I-DVEC,-0.341,0.0,0.0,0.0,0.0,0.0,3.015,0.0,0.0,0.0,0.0
B-MAL,0.75,1.051,0.0,0.0,0.0,0.0,0.0,0.0,3.758,-1.123,0.0
I-MAL,-0.684,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.414,0.0,0.0
B-OS,-0.28,0.0,0.0,0.0,0.0,-1.226,0.0,0.0,0.0,0.0,4.451

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5,Unnamed: 9_level_5,Unnamed: 10_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6,Unnamed: 9_level_6,Unnamed: 10_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7,Unnamed: 9_level_7,Unnamed: 10_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8,Unnamed: 9_level_8,Unnamed: 10_level_8
Weight?,Feature,Unnamed: 2_level_9,Unnamed: 3_level_9,Unnamed: 4_level_9,Unnamed: 5_level_9,Unnamed: 6_level_9,Unnamed: 7_level_9,Unnamed: 8_level_9,Unnamed: 9_level_9,Unnamed: 10_level_9
Weight?,Feature,Unnamed: 2_level_10,Unnamed: 3_level_10,Unnamed: 4_level_10,Unnamed: 5_level_10,Unnamed: 6_level_10,Unnamed: 7_level_10,Unnamed: 8_level_10,Unnamed: 9_level_10,Unnamed: 10_level_10
+7.128,+1:word.lower():continues,,,,,,,,,
+6.065,+1:word.lower():several,,,,,,,,,
+6.020,+1:word.lower():network,,,,,,,,,
+5.920,+1:word.lower():remains,,,,,,,,,
+5.620,-1:word.lower():specific,,,,,,,,,
+5.413,word.lower():malozbot,,,,,,,,,
+5.310,+1:word.lower():trickbot,,,,,,,,,
+5.200,word.lower():and,,,,,,,,,
+5.173,-1:word.lower():investigating,,,,,,,,,
+4.949,word.lower():using,,,,,,,,,

Weight?,Feature
+7.128,+1:word.lower():continues
+6.065,+1:word.lower():several
+6.020,+1:word.lower():network
+5.920,+1:word.lower():remains
+5.620,-1:word.lower():specific
+5.413,word.lower():malozbot
+5.310,+1:word.lower():trickbot
+5.200,word.lower():and
+5.173,-1:word.lower():investigating
+4.949,word.lower():using

Weight?,Feature
+10.821,word.lower():overlaying
+10.714,word.lower():uninstallation
+10.203,word.lower():intercepting
+9.429,word.lower():masquerades
+8.585,word.lower():imitating
+8.051,word.lower():mimicking
+7.810,word.lower():pretends
+7.750,word.lower():concealing
+7.187,+1:word.lower():yet
+6.870,word.lower():mimicked

Weight?,Feature
+6.565,+1:word.lower():money
+5.391,-1:word.lower():displaying
+5.182,-1:word.lower():display
+4.850,+1:word.lower():posing
+4.828,-1:word.lower():install
+4.456,-1:word.lower():harvest
+4.298,-1:word.lower():forward
+4.213,-1:word.lower():acquire
+3.980,+1:word.lower():systems
+3.949,word.lower():uninstallation

Weight?,Feature
+7.893,word.lower():command-and-control
+7.692,+1:word.lower():network
+7.528,word.lower():keylogger
+7.208,word.lower():keylogging
+7.207,word.lower():uninstallation
+6.694,word.lower():uninstalled
+6.173,word.lower():password-stealing
+6.157,word.lower():packed
+6.061,word.lower():keyloggers
+5.928,word.lower():information-stealing

Weight?,Feature
+4.902,-1:word.lower():command
+4.899,word.lower():man-in-the-browser
+4.041,-1:word.lower():data
+3.829,word.lower():privileges
+3.374,+1:word.lower():â
+3.372,-1:word.lower():&
+3.316,-1:word.lower():screen
+3.108,-1:word.lower():called
+3.036,word.lower():fluxing
+2.912,-1:word.lower():remote

Weight?,Feature
+8.566,word.lower():accessibility
+8.501,word.lower():sms-based
+5.988,-1:word.lower():accessibility
+5.627,word.lower():office
+5.462,-1:word.lower():using
+5.459,word.lower():e-mails
+4.980,-1:word.lower():ever-evolving
+4.852,word[-3:]:APK
+4.852,word[-2:]:PK
+4.829,word.lower():adsense

Weight?,Feature
+4.585,+1:word.lower():campaign
+4.384,+1:word.lower():network
+4.321,-1:word.lower():on
+4.064,word.lower():documents
+3.949,+1:word.lower():application
+3.853,-1:word.lower():removable
+3.772,+1:word.lower():messages
+3.719,-1:word.lower():login
+3.618,-1:word.lower():social
+3.265,+1:word.lower():file

Weight?,Feature
+6.177,word.lower():zbot's
+5.910,word[-3:]:cub
+5.542,word.lower():svpeng
+5.130,word.lower():faketoken
+5.040,word.lower():triada
+4.769,word.lower():marcher
+4.732,word.lower():zeus
+4.447,word.lower():tspyozbot
+4.231,word[-2:]:.C
+4.231,word.lower():android.marcher.c

Weight?,Feature
+3.966,-1:word.lower():trojan-spy.win32
+3.966,+1:word.lower():zbot.gen
+2.939,word.lower():zloader
+2.939,-1:word.lower():terdot
+2.545,word.lower():zbot.gen
+2.545,word[-3:]:gen
+1.948,word[-3:]:der
+1.815,word[-2:]:en
+1.377,word[-2:]:er
+1.338,+1:word.lower():/

Weight?,Feature
+11.705,word.lower():android-based
+6.430,-1:word.lower():information
+5.189,+1:word.lower():server
+4.342,+1:word.lower():machines
+4.002,+1:word.lower():â
+3.900,+1:word.lower():computer
+3.894,+1:word.lower():smartphones
+3.546,word.lower():android
+3.302,word.lower():windows
+3.255,word[-3:]:sed

Weight?,Feature
+2.816,-1:word.lower():windows
+2.774,-1:word.lower():server
+2.398,-1:postag:VBZ
+2.366,postag:CD
+2.366,postag[:2]:CD
+2.235,postag:NNP
+1.760,word[-3:]:sta
+1.760,word.lower():vista
+1.628,word[-2:]:ta
+1.064,word[-2:]:XP
