# KDD Cup 1999 Data

Data source: http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html (this notebook uses the 10% subset, `kddcup.data_10_percent`).

In [93]:
import sklearn
import pandas as pd
from sklearn import preprocessing
from sklearn.utils import resample
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
import numpy as np
from sklearn.decomposition import PCA
from sklearn.neural_network import MLPClassifier 
from sklearn.pipeline import Pipeline
import time
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.externals import joblib
from sklearn.utils import resample

In [94]:
# Record the scikit-learn version so the results below can be tied to a release.
print('The scikit-learn version is {}.'.format(sklearn.__version__))

The scikit-learn version is 0.18.1.


In [95]:
# Column names for the header-less KDD Cup 99 file: 41 features followed by
# the class label. The grouping below follows the KDD Cup 1999 task description.
col_names = [
    # basic features of an individual connection
    "duration",
    "protocol_type",
    "service",
    "flag",
    "src_bytes",
    "dst_bytes",
    "land",
    "wrong_fragment",
    "urgent",
    # content features
    "hot",
    "num_failed_logins",
    "logged_in",
    "num_compromised",
    "root_shell",
    "su_attempted",
    "num_root",
    "num_file_creations",
    "num_shells",
    "num_access_files",
    "num_outbound_cmds",
    "is_host_login",
    "is_guest_login",
    # time-window traffic features
    "count",
    "srv_count",
    "serror_rate",
    "srv_serror_rate",
    "rerror_rate",
    "srv_rerror_rate",
    "same_srv_rate",
    "diff_srv_rate",
    "srv_diff_host_rate",
    # destination-host traffic features
    "dst_host_count",
    "dst_host_srv_count",
    "dst_host_same_srv_rate",
    "dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate",
    "dst_host_serror_rate",
    "dst_host_srv_serror_rate",
    "dst_host_rerror_rate",
    "dst_host_srv_rerror_rate",
    # target
    "label",
]

In [96]:
# Load the 10% KDD Cup 99 subset. header=None tells pandas the file has no
# header row, so the names from `col_names` are applied to the columns.
data = pd.read_csv(
    "data/kddcup.data_10_percent",
    names=col_names,
    header=None,
)

In [97]:
# Sanity check: (rows, columns) of the freshly loaded frame.
data.shape

(494021, 42)

# 前処理
## カテゴリ化

In [98]:
# Class distribution of the raw labels (note the trailing '.' on every label).
data.label.value_counts()

smurf.              280790
neptune.            107201
normal.              97278
back.                 2203
satan.                1589
ipsweep.              1247
portsweep.            1040
warezclient.          1020
teardrop.              979
pod.                   264
nmap.                  231
guess_passwd.           53
buffer_overflow.        30
land.                   21
warezmaster.            20
imap.                   12
rootkit.                10
loadmodule.              9
ftp_write.               8
multihop.                7
phf.                     4
perl.                    3
spy.                     2
Name: label, dtype: int64

In [99]:
# Binary target: keep the label where it contains 'normal', replace every
# other label with a single attack class.
# NOTE(review): the replacement value is spelt 'atack' (sic). It is a runtime
# label value, so it is kept verbatim here; renaming it would change label2.
data['label2'] = data['label'].where(
    data['label'].str.contains('normal'),
    other='atack',
)

In [100]:
# Class distribution of the binary (normal vs. attack) target.
data.label2.value_counts()

atack      396743
normal.     97278
Name: label2, dtype: int64

In [101]:
# Start the coarse target as an independent copy of the raw label; the
# following cells overwrite attack names with their attack category.
data['label3'] = data['label'].copy()

In [102]:
# Map denial-of-service attack names to the coarse 'DoS' class.
# Labels are compared exactly (after dropping the trailing '.') instead of
# with an unanchored regex, so an attack name can never match as a substring
# of an unrelated label.
data.loc[
    data.label.str.rstrip('.').isin([
        'back', 'land', 'neptune', 'pod', 'smurf', 'teardrop',
        'mailbomb', 'apache2', 'processtable', 'udpstorm',
    ]),
    'label3',
] = 'DoS'

In [103]:
# Map user-to-root attack names to the coarse 'U2R' class.
# The previous unanchored regex contained the alternative 'ps', which is also
# a substring of 'ipsweep.' (a probe) and 'udpstorm.' (a DoS attack). Those
# rows were tagged U2R here and only ended up correct if a later cell happened
# to overwrite them. Exact matching on the label (trailing '.' removed) makes
# this cell correct on its own, independent of cell order.
data.loc[
    data.label.str.rstrip('.').isin([
        'buffer_overflow', 'loadmodule', 'perl', 'rootkit',
        'ps', 'xterm', 'sqlattack',
    ]),
    'label3',
] = 'U2R'

In [104]:
# Map remote-to-local attack names to the coarse 'R2L' class.
# Labels are compared exactly (after dropping the trailing '.') instead of
# with an unanchored regex, so an attack name can never match as a substring
# of an unrelated label.
data.loc[
    data.label.str.rstrip('.').isin([
        'ftp_write', 'guess_passwd', 'imap', 'multihop', 'phf', 'spy',
        'warezclient', 'warezmaster', 'snmpgetattack', 'snmpguess',
        'httptunnel', 'sendmail', 'named', 'xlock', 'xsnoop', 'worm',
    ]),
    'label3',
] = 'R2L'

In [105]:
# Map probing/scanning attack names to the coarse 'Probe' class.
# Labels are compared exactly (after dropping the trailing '.') instead of
# with an unanchored regex, so an attack name can never match as a substring
# of an unrelated label.
data.loc[
    data.label.str.rstrip('.').isin([
        'ipsweep', 'nmap', 'portsweep', 'satan', 'mscan', 'saint',
    ]),
    'label3',
] = 'Probe'

In [106]:
# Class distribution of the coarse target (DoS / Probe / R2L / U2R / normal.).
data.label3.value_counts()

DoS        391458
normal.     97278
Probe        4107
R2L          1126
U2R            52
Name: label3, dtype: int64

In [107]:
# Disabled: persist the labelled frame to disk with joblib.
# NOTE(review): the target file is named 'corrected.pkl' although this notebook
# loads the 10% file; confirm the intended path before re-enabling.
#joblib.dump(data,'dump/20171118/corrected.pkl')

## サンプリング

In [108]:
# Disabled: optional down-sampling to 10,000 rows (seeded) for quick experiments.
# While commented out, the full frame is used for everything that follows.
#data = resample(data,n_samples=10000,random_state=0)

In [109]:
# Disabled: shape check that belongs to the optional down-sampling step.
#data.shape

## 数値化

In [110]:
# Encoder that maps each distinct protocol_type string to an integer code.
# NOTE(review): integer codes impose an arbitrary order on a nominal feature
# that is later standardised; one-hot encoding may be preferable.
le_protocol_type = preprocessing.LabelEncoder()

In [111]:
# Learn the set of protocol_type categories present in the data.
le_protocol_type.fit(data.protocol_type)

LabelEncoder()

In [112]:
# Replace the protocol_type strings with their integer codes, in place.
# The column no longer holds the original strings after this cell has run.
data['protocol_type'] = le_protocol_type.transform(data['protocol_type'])

In [113]:
# Encoder that maps each distinct service string to an integer code.
le_service = preprocessing.LabelEncoder()

In [114]:
# Learn the set of service categories present in the data.
le_service.fit(data.service)

LabelEncoder()

In [115]:
# Replace the service strings with their integer codes, in place.
# The column no longer holds the original strings after this cell has run.
data['service'] = le_service.transform(data['service'])

In [116]:
# Encoder that maps each distinct connection flag string to an integer code.
le_flag = preprocessing.LabelEncoder()

In [117]:
# Learn the set of flag categories present in the data.
le_flag.fit(data.flag)

LabelEncoder()

In [118]:
# Replace the flag strings with their integer codes, in place.
# The column no longer holds the original strings after this cell has run.
data['flag'] = le_flag.transform(data['flag'])

In [119]:
# Summary statistics of the numeric columns after the three categorical
# features were integer-encoded.
data.describe()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,...,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0
mean,47.979302,0.467132,23.408894,7.842446,3025.61,868.5324,4.5e-05,0.006433,1.4e-05,0.034519,...,232.470778,188.66567,0.75378,0.030906,0.601935,0.006684,0.176754,0.176443,0.058118,0.057412
std,707.746472,0.575606,13.538332,2.250853,988218.1,33040.0,0.006673,0.134805,0.00551,0.782103,...,64.74538,106.040437,0.410781,0.109259,0.481309,0.042133,0.380593,0.380919,0.23059,0.23014
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,14.0,9.0,45.0,0.0,0.0,0.0,0.0,0.0,...,255.0,46.0,0.41,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,14.0,9.0,520.0,0.0,0.0,0.0,0.0,0.0,...,255.0,255.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,1.0,42.0,9.0,1032.0,0.0,0.0,0.0,0.0,0.0,...,255.0,255.0,1.0,0.04,1.0,0.0,0.0,0.0,0.0,0.0
max,58329.0,2.0,65.0,10.0,693375600.0,5155468.0,1.0,3.0,3.0,30.0,...,255.0,255.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [120]:
# Shape check after adding the derived label2 and label3 columns.
data.shape

(494021, 44)

## ラベルの分離

In [121]:
# Target 1: the raw, fine-grained attack label (independent copy).
y_test_1 = data['label'].copy()

In [122]:
# Target 2: the binary normal-vs-attack label (independent copy).
y_test_2 = data['label2'].copy()

In [123]:
# Target 3: the coarse attack-category label (independent copy).
y_test_3 = data['label3'].copy()

In [124]:
# Feature matrix: everything except the three label columns.
# drop() returns a new frame, so `data` itself is left unchanged.
x_test = data.drop(labels=['label', 'label2', 'label3'], axis='columns')

In [125]:
# Feature matrix dimensions (rows, feature columns).
x_test.shape

(494021, 41)

In [126]:
# Row count of the fine-grained target; it should match the feature matrix.
y_test_1.shape

(494021,)

In [127]:
# Row count of the binary target; it should match the feature matrix.
y_test_2.shape

(494021,)

In [128]:
# Row count of the coarse target; it should match the feature matrix.
y_test_3.shape

(494021,)

## 標準化

In [129]:
# Standardiser with default settings.
ss = preprocessing.StandardScaler()

In [130]:
# Learn the per-column statistics used for scaling from the full feature matrix.
ss.fit(x_test)

StandardScaler(copy=True, with_mean=True, with_std=True)

In [131]:
# Standardise the features. The name is rebound to the scaler's return value,
# so the unscaled feature frame is no longer reachable through `x_test`.
# NOTE(review): this cell is not idempotent; running it twice scales data that
# is already scaled. Re-run from the feature-matrix cell if it must be repeated.
x_test = ss.transform(x_test)

In [132]:
# Feature column names only: the same names as `col_names`, in the same order,
# without the trailing "label" entry. Deriving the list instead of repeating
# all 41 literals keeps the two definitions from drifting apart.
col_names2 = [name for name in col_names if name != "label"]

In [133]:
# Wrap the scaled features in a labelled frame purely for display, to confirm
# that every column now has mean ~0 and standard deviation ~1.
pd.DataFrame(x_test,columns=col_names2).describe()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,...,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0,494021.0
mean,-3.625574e-14,-5.905329e-13,-7.556595e-14,-1.583611e-13,-1.320064e-14,-5.49889e-14,2.919715e-14,-6.073146e-15,-1.529804e-14,-6.957988e-14,...,-7.353653e-13,1.95308e-13,3.799918e-13,-1.392772e-13,7.58115e-13,2.700207e-13,2.263813e-13,3.603748e-13,6.990594e-13,-9.654619e-14
std,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,...,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001,1.000001
min,-0.06779172,-0.8115496,-1.729084,-3.484214,-0.003061686,-0.02628733,-0.006673418,-0.04772019,-0.002571468,-0.04413591,...,-3.590542,-1.779188,-1.834994,-0.2828667,-1.250621,-0.1586293,-0.4644176,-0.4632024,-0.2520395,-0.249464
25%,-0.06779172,-0.8115496,-0.6949824,0.5142739,-0.003016149,-0.02628733,-0.006673418,-0.04772019,-0.002571468,-0.04413591,...,0.3479668,-1.345391,-0.8368938,-0.2828667,-1.250621,-0.1586293,-0.4644176,-0.4632024,-0.2520395,-0.249464
50%,-0.06779172,-0.8115496,-0.6949824,0.5142739,-0.002535486,-0.02628733,-0.006673418,-0.04772019,-0.002571468,-0.04413591,...,0.3479668,0.6255576,0.5993962,-0.2828667,0.8270476,-0.1586293,-0.4644176,-0.4632024,-0.2520395,-0.249464
75%,-0.06779172,0.9257531,1.373221,0.5142739,-0.002017381,-0.02628733,-0.006673418,-0.04772019,-0.002571468,-0.04413591,...,0.3479668,0.6255576,0.5993962,0.08323588,0.8270476,-0.1586293,-0.4644176,-0.4632024,-0.2520395,-0.249464
max,82.3474,2.663056,3.072103,0.9585503,701.64,156.011,149.8483,22.20663,544.4371,38.31404,...,0.3479668,0.6255576,0.5993962,8.869697,0.8270476,23.57583,2.163063,2.162027,4.084676,4.095715


## 学習

In [134]:
# Multi-layer perceptron with four hidden layers (100, 100, 100, 10), trained
# with the Adam solver. Every hyper-parameter is spelled out explicitly.
# random_state is pinned (it was None before) so that weight initialisation and
# mini-batch shuffling are repeatable across kernel restarts. Scores from a
# seeded run may differ slightly from outputs recorded with an unseeded run.
clf = MLPClassifier(
    hidden_layer_sizes=(100, 100, 100, 10),
    activation='relu',
    solver='adam',
    alpha=0.0001,
    batch_size='auto',
    learning_rate='constant',
    learning_rate_init=0.001,
    power_t=0.5,
    max_iter=200,
    shuffle=True,
    random_state=0,
    tol=0.0001,
    verbose=False,
    warm_start=False,
    momentum=0.9,
    nesterovs_momentum=True,
    early_stopping=False,
    validation_fraction=0.1,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
)

In [135]:
# Train on the coarse category target (label3) and time the fit.
# NOTE(review): despite the `_test` suffix, `x_test` / `y_test_3` are the data
# the model is trained on; no separate hold-out set is created in this notebook.
t1=time.perf_counter()
clf.fit(x_test, y_test_3)
t2=time.perf_counter()

In [136]:
# Elapsed training time in seconds (the printed unit "秒" means "seconds").
print(t2-t1,"秒")

67.5275362129978 秒


In [137]:
# Predict on the same rows the model was fitted on and time the call.
# t1/t2 are reused, so the training timing is overwritten here.
t1=time.perf_counter()
pred = clf.predict(x_test)
t2=time.perf_counter()

In [138]:
# Elapsed prediction time in seconds (the printed unit "秒" means "seconds").
print(t2-t1,"秒")

2.079974817999755 秒


In [139]:
# Per-class precision/recall/F1 and the confusion matrix (rows = true class,
# columns = predicted class, both in sorted label order).
# NOTE(review): `pred` was computed on the rows used for fitting, so these are
# training-set scores and will overstate performance on unseen traffic.
print(classification_report(y_test_3, pred))
print(confusion_matrix(y_test_3, pred))

             precision    recall  f1-score   support

        DoS       1.00      1.00      1.00    391458
      Probe       0.99      0.99      0.99      4107
        R2L       0.97      0.88      0.92      1126
        U2R       0.60      0.50      0.55        52
    normal.       1.00      1.00      1.00     97278

avg / total       1.00      1.00      1.00    494021

[[391413      2      0      0     43]
 [     1   4086      0      0     20]
 [     0      0    995      1    130]
 [     0      0      3     26     23]
 [     0     22     31     16  97209]]


In [141]:
# Disabled: persist the trained classifier to disk with joblib.
# NOTE(review): `sklearn.externals.joblib` is not available in newer
# scikit-learn releases; import the standalone `joblib` package when re-enabling.
#joblib.dump(clf,'dump/20171118/MLPClassifier10per.pkl')

['dump/20171118/MLPClassifier10per.pkl']