# In [5]:
import os,json,sys,logging
sys.path.append("./share")
sys.path.append("./common")
import pandas as pd
import json
from tqdm import tqdm
from IoTCommon import CIoTCommon
from IoTTotalFeature import CIoTTotalFeature
from IoTSample import CIoTSample
from SHSample import CSHSample
from SHDataProcess import CSHDataProcess
from SHFeatureSelect import CSHFeature
from Config import g_data_root
from SHDataEDA import CSHDataDistribution,CSHDataTest
from SHModelClassify import CSHModelClassify
from SHEvaluation import CSHROC
import warnings
import h2o
# Install an "ignore" filter so warnings are not shown for the rest of the session.
warnings.simplefilter("ignore")
# Sample directory: g_data_root with "sample" appended. No path separator is
# inserted here, so g_data_root presumably already ends with one -- confirm in Config.
g_sample_root = "%ssample"%g_data_root
# Connect to H2O (h2o.init starts a local server if none is reachable, per the
# H2O docs). nthreads=-1 requests all available CPUs; verbose=False suppresses
# the startup report.
h2o.init(nthreads = -1, verbose=False)

# In [6]:
class CIoTModel:
    """Build, train and evaluate a classifier for one attack/protocol pair.

    Typical call order: ``load_sample()`` -> ``preprocess()`` -> ``train()``
    -> ``test()``.

    Attributes are kept public (``m_`` prefix) as in the rest of the project:
      m_attack / m_protocol  -- identifiers forwarded to ``CIoTSample``.
      m_col_y                -- name of the label column (``'Label'``).
      m_col_x                -- column names handed to the model as x_columns.
      m_sample               -- the current working sample.
      m_train / m_test       -- the split produced by ``preprocess()``.
    """

    def __init__(self, attack, protocol):
        self.m_attack = attack
        self.m_protocol = protocol
        self.m_ioTSample = CIoTSample()
        self.m_col_x = []
        self.m_col_y = 'Label'
        # Declared up front so the attribute exists even before load_sample().
        self.m_sample = pd.DataFrame()
        self.m_train = pd.DataFrame()
        self.m_test = pd.DataFrame()
        self.m_model = CSHModelClassify()

    def get_raw_sample(self, attack, protocol, maxCount: int = 20000) -> pd.DataFrame:
        """Return sensor samples followed by the attack sample, as one frame.

        The attack sample is randomly down-sampled to at most ``maxCount``
        rows. The number of its rows whose label is non-zero is divided
        evenly among the sensor types, and each sensor's sample is randomly
        capped at that per-sensor quota. Sensors with no rows are skipped.

        When there are no sensor types at all, only the attack sample is
        returned (previously this raised ``ZeroDivisionError``).
        """
        df_attack = self.m_ioTSample.get_attack_sample(attack, protocol)
        if maxCount < df_attack.shape[0]:
            df_attack = df_attack.sample(maxCount)

        attack_count = int((df_attack[self.m_col_y] != 0).sum())
        # Query the sensor list once; it drives both the quota and the loop.
        sensor_types = list(self.m_ioTSample.get_sensor_type())
        normal_count = attack_count // len(sensor_types) if sensor_types else 0

        # Collect the pieces and concatenate once instead of re-copying a
        # growing frame on every iteration.
        frames = []
        for sensor in sensor_types:
            df_tmp = self.m_ioTSample.get_sensor_sample(sensor, protocol)
            if df_tmp.shape[0] <= 0:
                continue
            if df_tmp.shape[0] > normal_count:
                df_tmp = df_tmp.sample(n=normal_count)
            frames.append(df_tmp)

        # Sensor rows first, attack rows last -- same order as before.
        frames.append(df_attack)
        return pd.concat(frames, ignore_index=True)

    def get_select_sample(self, df_sample: pd.DataFrame) -> pd.DataFrame:
        """Return ``df_sample`` restricted to the columns kept for modelling.

        Dropped columns:
          * ``id``, ``frame.time_utc`` and ``frame.time_delta``;
          * any column of ``object`` dtype (this applies to the label too);
          * any non-label column with fewer than two distinct values.
        """
        skipped_names = ('id', 'frame.time_utc', 'frame.time_delta')
        used_columns = []
        for column_name, dtype in df_sample.dtypes.items():
            if column_name in skipped_names:
                continue
            if dtype == 'object':
                continue
            # The label is exempt from the constant-column filter; whether it
            # has enough classes is checked separately in load_sample().
            if column_name != self.m_col_y and df_sample[column_name].nunique() < 2:
                continue
            used_columns.append(column_name)

        return df_sample[used_columns]

    def load_sample(self, maxCount: int = 20000) -> bool:
        """Load, filter and shuffle the sample into ``m_sample``.

        Returns:
            ``True`` when the label column is present and has at least two
            distinct values; ``False`` otherwise (a message is printed).
        """
        df_sample = self.get_raw_sample(self.m_attack, self.m_protocol, maxCount)
        df_data = self.get_select_sample(df_sample)
        # frac=1 keeps every row but in random order.
        self.m_sample = df_data.sample(frac=1).reset_index(drop=True)
        # NOTE(review): this list still contains the label column (m_col_y), so
        # the label is passed inside x_columns to train()/test(). Confirm that
        # CSHModelClassify excludes it, otherwise the target leaks into the
        # features.
        self.m_col_x = self.m_sample.keys().tolist()

        # The label can be missing here if get_select_sample() dropped it for
        # being of object dtype; treat that like "not enough classes" rather
        # than failing with a KeyError.
        label_missing = self.m_col_y not in self.m_sample.columns
        if label_missing or self.m_sample[self.m_col_y].nunique() < 2:
            print("Less Label category ( class < 2")
            return False

        return True

    def preprocess(self):
        """Format and scale ``m_sample``, then split it into train and test.

        ``m_col_x`` is refreshed from the formatted frame's columns (taken
        before scaling). Call ``load_sample()`` first so ``m_sample`` holds
        data.
        """
        df_data = self.m_ioTSample.format(self.m_sample)
        self.m_col_x = df_data.keys().tolist()
        # get_scale returns a pair; only the scaled frame is needed here.
        self.m_sample, _ = CSHDataProcess.get_scale(df_data, y_column=self.m_col_y)
        self.m_train, self.m_test = CSHSample.split_dataset(self.m_sample)

    def train(self):
        """Fit the model on ``m_train`` and return ``m_model.importance()``."""
        # NOTE(review): train_ratio=0 presumably disables any further internal
        # hold-out because m_train is already a split -- confirm in
        # CSHModelClassify.
        self.m_model.train(
            self.m_train,
            x_columns=self.m_col_x,
            y_column=self.m_col_y,
            train_ratio=0,
        )
        return self.m_model.importance()

    def test(self):
        """Evaluate the fitted model on ``m_test`` and return the result."""
        return self.m_model.evaluate(
            self.m_test,
            x_columns=self.m_col_x,
            y_column=self.m_col_y,
        )

# In [7]:
# Run parameters: the attack to model and the protocol key. The key is written
# with dashes and converted to the colon-separated form here.
attack = "Backdoor_attack"
protocol = ":".join("eth-ethertype-ip".split("-"))

# In [8]:
# Run the whole pipeline for the configured attack/protocol pair. Preprocessing,
# training and evaluation are skipped when load_sample() reports failure.
model = CIoTModel(attack, protocol)
if not model.load_sample():
    print("No Sample")
else:
    model.preprocess()
    df_importance = model.train()
    df_verify = model.test()

# Output: KeyboardInterrupt (execution of the cell above was interrupted)