# KDD Cup 1999 Data

Data source: http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html

In [1]:
import sklearn
import pandas as pd
from sklearn import preprocessing
from sklearn.utils import resample
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
import numpy as np
from sklearn.decomposition import PCA
from sklearn.neural_network import MLPClassifier 
from sklearn.pipeline import Pipeline
import time
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.externals import joblib
from sklearn.utils import resample

In [2]:
# Record the installed scikit-learn version alongside the results.
print('The scikit-learn version is {}.'.format(sklearn.__version__))

The scikit-learn version is 0.18.1.


In [3]:
# Column names for the data file, in file order: the feature columns followed
# by the class label as the last entry. Kept as whitespace-separated text and
# split into a list of 42 names.
col_names = """
    duration protocol_type service flag src_bytes
    dst_bytes land wrong_fragment urgent hot num_failed_logins
    logged_in num_compromised root_shell su_attempted num_root num_file_creations
    num_shells num_access_files num_outbound_cmds is_host_login is_guest_login count
    srv_count serror_rate srv_serror_rate rerror_rate srv_rerror_rate same_srv_rate
    diff_srv_rate srv_diff_host_rate dst_host_count dst_host_srv_count
    dst_host_same_srv_rate dst_host_diff_srv_rate dst_host_same_src_port_rate
    dst_host_srv_diff_host_rate dst_host_serror_rate dst_host_srv_serror_rate
    dst_host_rerror_rate dst_host_srv_rerror_rate label
""".split()

In [4]:
# Read the raw CSV; the file has no header row, so column names come from col_names.
data = pd.read_csv("data/kddcup.data", header=None, names = col_names)

In [5]:
# Sanity check: (rows, columns) of the freshly loaded frame.
data.shape

(4898431, 42)

# 前処理
## カテゴリ化

In [6]:
# Frequency of each raw label value.
data['label'].value_counts()

smurf.              2807886
neptune.            1072017
normal.              972781
satan.                15892
ipsweep.              12481
portsweep.            10413
nmap.                  2316
back.                  2203
warezclient.           1020
teardrop.               979
pod.                    264
guess_passwd.            53
buffer_overflow.         30
land.                    21
warezmaster.             20
imap.                    12
rootkit.                 10
loadmodule.               9
ftp_write.                8
multihop.                 7
phf.                      4
perl.                     3
spy.                      2
Name: label, dtype: int64

In [7]:
# Two-class label: keep the original value where the label contains 'normal',
# replace every other label with 'atack' (the spelling is the stored value).
is_normal = data['label'].str.contains('normal')
data['label2'] = data['label'].where(is_normal, 'atack')

In [8]:
# Class balance of the two-class label.
data['label2'].value_counts()

atack      3925650
normal.     972781
Name: label2, dtype: int64

In [9]:
# Start the grouped label as a copy of the raw label; the following cells
# overwrite matching rows with their attack category.
data['label3'] = data['label'].copy()

In [10]:
# Rows whose raw label contains any of these names are grouped as 'DoS'.
dos_mask = data['label'].str.contains(
    'back|land|neptune|pod|smurf|teardrop|mailbomb|apache2|processtable|udpstorm'
)
data.loc[dos_mask, 'label3'] = 'DoS'

In [11]:
# Group these attack names as 'U2R', matching the whole label name.
# The earlier regex substring test let the short alternative 'ps' also match
# unrelated labels such as 'ipsweep.' (and 'udpstorm.' where present), so rows
# could be tagged 'U2R' by accident. Labels in this data end with a trailing
# '.', so it is stripped before the exact comparison.
data.loc[
    data.label.str.rstrip('.').isin(
        ['buffer_overflow', 'loadmodule', 'perl', 'rootkit', 'ps', 'xterm', 'sqlattack']
    ),
    'label3',
] = 'U2R'

In [12]:
# Rows whose raw label contains any of these names are grouped as 'R2L'.
r2l_mask = data['label'].str.contains(
    'ftp_write|guess_passwd|imap|multihop|phf|spy|warezclient|warezmaster|snmpgetattack|snmpguess|httptunnel|sendmail|named|xlock|xsnoop|worm'
)
data.loc[r2l_mask, 'label3'] = 'R2L'

In [13]:
# Rows whose raw label contains any of these names are grouped as 'Probe'.
probe_mask = data['label'].str.contains('ipsweep|nmap|portsweep|satan|mscan|saint')
data.loc[probe_mask, 'label3'] = 'Probe'

In [14]:
# Class balance of the grouped label.
data['label3'].value_counts()

DoS        3883370
normal.     972781
Probe        41102
R2L           1126
U2R             52
Name: label3, dtype: int64

In [15]:
# Disabled: optional checkpoint of the labelled frame to disk.
#joblib.dump(data,'dump/20171118/corrected.pkl')

## サンプリング

In [16]:
# Disabled: optional down-sampling to 10000 rows with a fixed seed; left off so
# every loaded row is used.
#data = resample(data,n_samples=10000,random_state=0)

In [17]:
# Disabled: shape check that belongs with the down-sampling step above.
#data.shape

## 数値化

In [18]:
# Encoder that maps protocol_type strings to integer codes.
le_protocol_type = preprocessing.LabelEncoder()

In [19]:
# Learn the set of protocol_type values from this data.
# NOTE(review): the codes only line up with those used when the loaded
# classifier was trained if that training saw the same category values — confirm.
le_protocol_type.fit(data['protocol_type'])

LabelEncoder()

In [20]:
# Replace the string column with its integer codes (overwrites the column).
data['protocol_type'] = le_protocol_type.transform(data['protocol_type'])

In [21]:
# Encoder that maps service strings to integer codes.
le_service = preprocessing.LabelEncoder()

In [22]:
# Learn the set of service values from this data.
# NOTE(review): codes must agree with the encoding used at training time — confirm.
le_service.fit(data['service'])

LabelEncoder()

In [23]:
# Replace the string column with its integer codes (overwrites the column).
data['service'] = le_service.transform(data['service'])

In [24]:
# Encoder that maps flag strings to integer codes.
le_flag = preprocessing.LabelEncoder()

In [25]:
# Learn the set of flag values from this data.
# NOTE(review): codes must agree with the encoding used at training time — confirm.
le_flag.fit(data['flag'])

LabelEncoder()

In [26]:
# Replace the string column with its integer codes (overwrites the column).
data['flag'] = le_flag.transform(data['flag'])

In [27]:
# Summary statistics of the numeric columns after encoding the three
# categorical features.
data.describe()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,...,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0
mean,48.34243,0.4612036,25.31783,7.834247,1834.621,1093.623,5.716116e-06,0.0006487792,7.961733e-06,0.01243766,...,232.9811,189.2142,0.7537132,0.03071111,0.605052,0.006464107,0.1780911,0.1778859,0.0579278,0.05765941
std,723.3298,0.572557,14.83422,2.256784,941431.1,645012.3,0.002390833,0.04285434,0.007215084,0.4689782,...,64.02094,105.9128,0.411186,0.1085432,0.4809877,0.04125978,0.3818382,0.3821774,0.2309428,0.2309777
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,15.0,9.0,45.0,0.0,0.0,0.0,0.0,0.0,...,255.0,49.0,0.41,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,15.0,9.0,520.0,0.0,0.0,0.0,0.0,0.0,...,255.0,255.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,1.0,46.0,9.0,1032.0,0.0,0.0,0.0,0.0,0.0,...,255.0,255.0,1.0,0.04,1.0,0.0,0.0,0.0,0.0,0.0
max,58329.0,2.0,69.0,10.0,1379964000.0,1309937000.0,1.0,3.0,14.0,77.0,...,255.0,255.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [28]:
# Shape check after adding the label2 and label3 columns.
data.shape

(4898431, 44)

## ラベルの分離

In [29]:
# Target 1: the raw per-attack label.
y_test_1 = data['label'].copy()

In [30]:
# Target 2: the two-class label.
y_test_2 = data['label2'].copy()

In [31]:
# Target 3: the grouped attack-category label.
y_test_3 = data['label3'].copy()

In [32]:
# Feature matrix: everything except the three label columns.
x_test= data.drop(['label','label2','label3'],axis=1)

In [33]:
# Shape check of the feature matrix after dropping the three label columns.
x_test.shape

(4898431, 41)

In [34]:
# Row count of the raw-label target; should match x_test.
y_test_1.shape

(4898431,)

In [35]:
# Row count of the two-class target; should match x_test.
y_test_2.shape

(4898431,)

In [36]:
# Row count of the grouped-label target; should match x_test.
y_test_3.shape

(4898431,)

## 標準化

In [37]:
# Scaler used to standardize the feature columns.
ss = preprocessing.StandardScaler()

In [38]:
# Fit the scaler on x_test itself.
# NOTE(review): the scaler is fit on the same rows that are later scored; if the
# loaded classifier was trained with a differently fitted scaler, the feature
# scaling here will not match what the model saw — confirm.
ss.fit(x_test)

StandardScaler(copy=True, with_mean=True, with_std=True)

In [39]:
# Standardize the features; x_test is rebound to the transformed result.
x_test = ss.transform(x_test)

In [40]:
# Feature column names only (no label column), used to re-attach names to the
# scaled feature array. Kept as whitespace-separated text and split into a
# list of 41 names.
col_names2 = """
    duration protocol_type service flag src_bytes
    dst_bytes land wrong_fragment urgent hot num_failed_logins
    logged_in num_compromised root_shell su_attempted num_root num_file_creations
    num_shells num_access_files num_outbound_cmds is_host_login is_guest_login count
    srv_count serror_rate srv_serror_rate rerror_rate srv_rerror_rate same_srv_rate
    diff_srv_rate srv_diff_host_rate dst_host_count dst_host_srv_count
    dst_host_same_srv_rate dst_host_diff_srv_rate dst_host_same_src_port_rate
    dst_host_srv_diff_host_rate dst_host_serror_rate dst_host_srv_serror_rate
    dst_host_rerror_rate dst_host_srv_rerror_rate
""".split()

In [41]:
# Wrap the scaled array in a named frame to check the result of standardization
# (means near 0, standard deviations of 1).
pd.DataFrame(x_test,columns=col_names2).describe()

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
count,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,...,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0,4898431.0
mean,2.128177e-12,2.658016e-12,-5.140332e-12,6.948923e-12,-2.393464e-14,2.700956e-14,-5.428228e-14,1.927643e-15,1.885683e-14,1.4455e-12,...,2.501275e-12,3.067024e-12,2.324952e-12,-2.217487e-14,1.052332e-11,6.272213e-13,8.636616e-12,-9.429649e-12,-1.176099e-11,7.4392e-12
std,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
min,-0.06683319,-0.8055157,-1.706718,-3.471421,-0.001948758,-0.001695507,-0.002390847,-0.01513917,-0.001103485,-0.02652076,...,-3.639139,-1.78651,-1.833023,-0.282939,-1.257937,-0.1566685,-0.4664048,-0.4654536,-0.2508318,-0.249632
25%,-0.06683319,-0.8055157,-0.6955426,0.5165548,-0.001900958,-0.001695507,-0.002390847,-0.01513917,-0.001103485,-0.02652076,...,0.3439331,-1.323866,-0.835907,-0.282939,-1.257937,-0.1566685,-0.4664048,-0.4654536,-0.2508318,-0.249632
50%,-0.06683319,-0.8055157,-0.6955426,0.5165548,-0.001396407,-0.001695507,-0.002390847,-0.01513917,-0.001103485,-0.02652076,...,0.3439331,0.6211316,0.5989668,-0.282939,0.8211187,-0.1566685,-0.4664048,-0.4654536,-0.2508318,-0.249632
75%,-0.06683319,0.9410355,1.39422,0.5165548,-0.0008525545,-0.001695507,-0.002390847,-0.01513917,-0.001103485,-0.02652076,...,0.3439331,0.6211316,0.5989668,0.08557786,0.8211187,-0.1566685,-0.4664048,-0.4654536,-0.2508318,-0.249632
max,80.57274,2.687587,2.944688,0.9596633,1465.813,2030.87,418.2618,69.98945,1940.379,164.1602,...,0.3439331,0.6211316,0.5989668,8.929982,0.8211187,24.08001,2.152506,2.151132,4.079245,4.079791


## 学習済みモデルによる予測と評価

In [42]:
# Load a previously saved classifier from disk.
# NOTE: joblib files are pickles and can execute arbitrary code on load; only
# load files from a trusted source.
# NOTE(review): presumably an MLPClassifier trained on the 10% subset, judging
# by the file name — confirm.
clf = joblib.load('dump/20171118/MLPClassifier10per.pkl')

In [43]:
# Time the prediction over all rows of x_test; t1/t2 are read by the next cell.
t1=time.perf_counter()
pred = clf.predict(x_test)
t2=time.perf_counter()

In [44]:
# Elapsed prediction time; the printed suffix is the Japanese word for "seconds".
print(t2-t1,"秒")

31.907945523002127 秒


In [45]:
# Compare predictions against the grouped labels (y_test_3).
# Per-class precision / recall / F1 and support:
print(classification_report(y_test_3, pred))
# Confusion matrix: rows are true classes, columns are predicted classes.
print(confusion_matrix(y_test_3, pred))

             precision    recall  f1-score   support

        DoS       1.00      1.00      1.00   3883370
      Probe       0.99      0.99      0.99     41102
        R2L       0.59      0.64      0.62      1126
        U2R       0.10      0.50      0.17        52
    normal.       1.00      1.00      1.00    972781

avg / total       1.00      1.00      1.00   4898431

[[3883171      23       1       0     175]
 [     12   40809       3       0     278]
 [      1       0     725       6     394]
 [      0       0       4      26      22]
 [     81     249     490     220  971741]]


In [140]:
# Disabled. WARNING: as written this would write the `data` DataFrame to the
# same path the classifier is loaded from (MLPClassifier10per.pkl), replacing
# the model file. If re-enabled it presumably should dump `clf` — confirm.
#joblib.dump(data,'dump/20171118/MLPClassifier10per.pkl')

['dump/20171118/MLPClassifier10per.pkl']