In [1]:
pip install -q findspark

Note: you may need to restart the kernel to use updated packages.


In [2]:
import findspark
findspark.init('/home/bigdata/Documents/spark-3.0.0')

In [3]:
from pyspark.sql.types import StructType
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import ByteType
from pyspark.sql.types import FloatType
from pyspark.sql.types import BinaryType
from pyspark.sql.types import DataType
from pyspark.sql.types import StructField
import pandas as pd
import numpy as np
from pyspark.sql.session import SparkSession

In [4]:
spark = SparkSession.builder.getOrCreate()

In [5]:
schema = StructType([ StructField("srcip", StringType(), True),   
    StructField("sport", IntegerType(), True),  
    StructField("dstip", StringType(), True),   
    StructField("dsport", IntegerType(), True), 
    StructField("proto", StringType(), True),  
    StructField("state", StringType(), True),
    StructField("dur", FloatType(), True),       
    StructField("sbytes", ByteType(), True),    
    StructField("dbytes", ByteType(), True),    
    StructField("sttl", IntegerType(), True), 
    StructField("dttl", IntegerType(), True),   
    StructField("sloss", IntegerType(), True),
    StructField("dloss", IntegerType(), True),
    StructField("service", StringType(), True),    
    StructField("Sload", FloatType(), True),  
    StructField("Dload", FloatType(), True), 
    StructField("Spkts", IntegerType(), True),    
    StructField("Dpkts", IntegerType(), True),    
    StructField("swin", IntegerType(), True),
    StructField("dwin", IntegerType(), True), 
    StructField("stcpb", IntegerType(), True),  
    StructField("dtcpb", IntegerType(), True),    
    StructField("smeansz", IntegerType(), True),   
    StructField("dmeansz", IntegerType(), True),  
    StructField("trans_depth", IntegerType(), True),  
    StructField("res_bdy_len", IntegerType(), True),    
    StructField("Sjit", FloatType(), True),    
    StructField("Djit", FloatType(), True),     
    StructField("Stime", IntegerType(), True), 
    StructField("Ltime", IntegerType(), True),   
    StructField("Sintpkt", FloatType(), True),
    StructField("Dintpkt", FloatType(), True),    
    StructField("tcprtt", FloatType(), True),  
    StructField("synack", FloatType(), True),    
    StructField("ackdat", FloatType(), True),  
    StructField("is_sm_ips_ports", IntegerType(), True),    
    StructField("ct_state_ttl", IntegerType(), True), 
    StructField("ct_flw_http_mthd", IntegerType(), True), 
    StructField("is_ftp_login", IntegerType(), True),  
    StructField("ct_ftp_cmd", IntegerType(), True),
    StructField("ct_srv_src", IntegerType(), True),    
    StructField ("ct_srv_dst", IntegerType(), True),  
    StructField("ct_dst_ltm", IntegerType(), True),   
    StructField("ct_src_ltm", IntegerType(), True),    
    StructField("ct_src_dport_ltm", IntegerType(), True),    
    StructField("ct_dst_sport_ltm", IntegerType(), True),  
    StructField("ct_dst_src_ltm", IntegerType(), True),  
    StructField("attack_cat", StringType(), True),    
    StructField("Label", IntegerType(), True) 
    
])

In [6]:
df_schema = spark.read.format("csv").option("header", "True").schema(schema).load("/home/bigdata/UNSW-NB15.csv")

In [7]:
df_schema.show()

+----------+-----+-------------+------+-----+-----+--------+------+------+----+----+-----+-----+--------+---------+---------+-----+-----+----+----+----------+----------+-------+-------+-----------+-----------+---------+---------+----------+----------+---------+---------+-------+-------+-------+---------------+------------+----------------+------------+----------+----------+----------+----------+----------+----------------+----------------+--------------+----------+-----+
|     srcip|sport|        dstip|dsport|proto|state|     dur|sbytes|dbytes|sttl|dttl|sloss|dloss| service|    Sload|    Dload|Spkts|Dpkts|swin|dwin|     stcpb|     dtcpb|smeansz|dmeansz|trans_depth|res_bdy_len|     Sjit|     Djit|     Stime|     Ltime|  Sintpkt|  Dintpkt| tcprtt| synack| ackdat|is_sm_ips_ports|ct_state_ttl|ct_flw_http_mthd|is_ftp_login|ct_ftp_cmd|ct_srv_src|ct_srv_dst|ct_dst_ltm|ct_src_ltm|ct_src_dport_ltm|ct_dst_sport_ltm|ct_dst_src_ltm|attack_cat|Label|
+----------+-----+-------------+------+-----+---

In [8]:
df_schema = df_schema.drop("srcip", "sport", "dstip", "dsport", "stime", "ltime", "Label")

In [9]:
from pyspark.sql.functions import col, Column, lit,when, regexp_replace

In [10]:

df_schema = df_schema.withColumn('attack_cat', regexp_replace(col('attack_cat'), " ", ""))

In [11]:
df_schema= df_schema.withColumn('attack_cat', regexp_replace(col('attack_cat'), "Backdoors", "Backdoor"))

In [12]:
df_schema.select('attack_cat').distinct().show()

+--------------+
|    attack_cat|
+--------------+
|         Worms|
|     Shellcode|
|          null|
|       Fuzzers|
|      Analysis|
|           DoS|
|Reconnaissance|
|      Backdoor|
|      Exploits|
|       Generic|
+--------------+



In [13]:
df_schemas = df_schema.withColumn('attack_cat', regexp_replace(col('attack_cat'), "Backdoor", "Backdoors"))

In [14]:
df_schemas.select('attack_cat').distinct().show()

+--------------+
|    attack_cat|
+--------------+
|         Worms|
|     Shellcode|
|          null|
|     Backdoors|
|       Fuzzers|
|      Analysis|
|           DoS|
|Reconnaissance|
|      Exploits|
|       Generic|
+--------------+



In [15]:
df_schemas = df_schema.withColumn('attack_cat', regexp_replace(col('attack_cat'), " " , "normal"))

In [16]:
df_schemas.select('attack_cat').distinct().show()

+--------------+
|    attack_cat|
+--------------+
|         Worms|
|     Shellcode|
|          null|
|       Fuzzers|
|      Analysis|
|           DoS|
|Reconnaissance|
|      Backdoor|
|      Exploits|
|       Generic|
+--------------+



In [17]:
df_schem = df_schemas.fillna({"attack_cat": 'Normal'})

In [18]:
df_schem.select('attack_cat').distinct().show()

+--------------+
|    attack_cat|
+--------------+
|         Worms|
|     Shellcode|
|       Fuzzers|
|      Analysis|
|           DoS|
|Reconnaissance|
|      Backdoor|
|      Exploits|
|        Normal|
|       Generic|
+--------------+



In [19]:
final_data=df_schem.na.drop()

In [20]:
final_data.show()

+-----+-----+---------+------+------+----+----+-----+-----+-------+--------+---------+-----+-----+----+----+-----+-----+-------+-------+-----------+-----------+--------+--------+---------+---------+------+------+------+---------------+------------+----------------+------------+----------+----------+----------+----------+----------+----------------+----------------+--------------+----------+
|proto|state|      dur|sbytes|dbytes|sttl|dttl|sloss|dloss|service|   Sload|    Dload|Spkts|Dpkts|swin|dwin|stcpb|dtcpb|smeansz|dmeansz|trans_depth|res_bdy_len|    Sjit|    Djit|  Sintpkt|  Dintpkt|tcprtt|synack|ackdat|is_sm_ips_ports|ct_state_ttl|ct_flw_http_mthd|is_ftp_login|ct_ftp_cmd|ct_srv_src|ct_srv_dst|ct_dst_ltm|ct_src_ltm|ct_src_dport_ltm|ct_dst_sport_ltm|ct_dst_src_ltm|attack_cat|
+-----+-----+---------+------+------+----+----+-----+-----+-------+--------+---------+-----+-----+----+----+-----+-----+-------+-------+-----------+-----------+--------+--------+---------+---------+------+------+

In [21]:
final_data.summary().select("dur", "dbytes", "sttl", "dttl", "ct_dst_src_ltm").show()

+-----------------+-----------------+------------------+--------------------+------------------+
|              dur|           dbytes|              sttl|                dttl|    ct_dst_src_ltm|
+-----------------+-----------------+------------------+--------------------+------------------+
|           225994|           225994|            225994|              225994|            225994|
|0.407434248147434|1.112348115436693|245.32343779038382|0.009230333548678284|26.790277617989858|
|4.150830810875402|  9.9111587016314| 46.05972276367919|   0.952028783949515|12.612631599831502|
|              0.0|                0|                 0|                   0|                 1|
|           4.0E-6|                0|               254|                   0|                18|
|           8.0E-6|                0|               254|                   0|                28|
|           9.0E-6|                0|               254|                   0|                36|
|         59.99999|           

In [22]:
from pyspark.ml.feature import(VectorAssembler, VectorIndexer, OneHotEncoder, StringIndexer)

In [23]:
attack_indexer = StringIndexer(inputCol='attack_cat', outputCol = 'AttackIndex')

# ONE HOT ENCODING
                               
attack_encoder = OneHotEncoder(inputCol='AttackIndex', outputCol = 'attack_catVec')

In [24]:
proto_indexer = StringIndexer(inputCol='proto', outputCol = 'protoIndex')
proto_encoder = OneHotEncoder(inputCol='protoIndex', outputCol = 'protoVec')

In [25]:
state_indexer = StringIndexer(inputCol='state', outputCol = 'stateIndex')
                              
state_encoder = OneHotEncoder(inputCol='stateIndex', outputCol = 'stateVec')

In [26]:
service_indexer = StringIndexer(inputCol='service', outputCol = 'serviceIndex')
                              
service_encoder = OneHotEncoder(inputCol='serviceIndex', outputCol = 'serviceVec')

In [31]:
assembler = VectorAssembler(inputCols=['protoIndex', 'stateIndex', 'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'serviceIndex', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm'], outputCol='features')

In [32]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

In [33]:
log_reg_UNSW = LogisticRegression(featuresCol = 'features', labelCol = 'AttackIndex')

In [34]:
pipeline = Pipeline(stages = [attack_indexer, proto_indexer, state_indexer, service_indexer, attack_encoder, proto_encoder, state_encoder, service_encoder, assembler, log_reg_UNSW])

In [37]:
train_data, test_data = final_data.randomSplit([0.7,0.3], seed =2018)

In [38]:
fit_model = pipeline.fit(train_data)

In [39]:
result = fit_model.transform(test_data)

In [40]:
result.show()

+-----+-----+------+------+------+----+----+-----+-----+-------+-----------+-----------+-----+-----+----+----+-----+-----+-------+-------+-----------+-----------+---------+---------+---------+---------+------+------+------+---------------+------------+----------------+------------+----------+----------+----------+----------+----------+----------------+----------------+--------------+----------+-----------+----------+----------+------------+-------------+--------------+--------------+-------------+--------------------+--------------------+--------------------+----------+
|proto|state|   dur|sbytes|dbytes|sttl|dttl|sloss|dloss|service|      Sload|      Dload|Spkts|Dpkts|swin|dwin|stcpb|dtcpb|smeansz|dmeansz|trans_depth|res_bdy_len|     Sjit|     Djit|  Sintpkt|  Dintpkt|tcprtt|synack|ackdat|is_sm_ips_ports|ct_state_ttl|ct_flw_http_mthd|is_ftp_login|ct_ftp_cmd|ct_srv_src|ct_srv_dst|ct_dst_ltm|ct_src_ltm|ct_src_dport_ltm|ct_dst_sport_ltm|ct_dst_src_ltm|attack_cat|AttackIndex|protoIndex|sta

In [41]:
train_data.count()

158462

In [42]:
test_data.count()

67532

In [43]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [44]:
my_eval =MulticlassClassificationEvaluator(predictionCol='prediction', labelCol = 'AttackIndex', metricName="accuracy")

In [54]:
result.select('prediction','features','attack_catVec','attackindex','is_ftp_login').show()

+----------+--------------------+-------------+-----------+------------+
|prediction|            features|attack_catVec|attackindex|is_ftp_login|
+----------+--------------------+-------------+-----------+------------+
|       1.0|(41,[0,1,3,4,9,12...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,3,4,9,10...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,3,4,9,10...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,3,4,9,10...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,3,4,9,10...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,3,4,9,10...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,2,3,4,9,...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,2,3,4,9,...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,2,3,4,9,...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,2,3,4,9,...|(9,[1],[1.0])|        1.0|           0|
|       1.0|(41,[0,1,2,3,4,9,...|(9,[1],[1.0])|    

In [55]:
Accuracy= my_eval.evaluate(result)

In [56]:
Accuracy

0.9858881715334953

In [57]:
result.select('attackindex','prediction').show()

+-----------+----------+
|attackindex|prediction|
+-----------+----------+
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
|        1.0|       1.0|
+-----------+----------+
only showing top 20 rows

