In [1]:
# import SparkSession

from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

spark

22/11/06 12:18:21 WARN Utils: Your hostname, pc resolves to a loopback address: 127.0.1.1; using 192.168.170.52 instead (on interface wlp3s0)
22/11/06 12:18:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/11/06 12:18:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
df = spark.read.text("./data/kddcup.data.gz")
df.printSchema()

root
 |-- value: string (nullable = true)



In [5]:
# Split data 

from pyspark.sql.functions import split

split_col = split(df['value'], ',')

df = df.withColumn('Protocol', split_col.getItem(1)) \
        .withColumn('Service', split_col.getItem(2)) \
        .withColumn('flag', split_col.getItem(3)) \
        .withColumn('src_bytes', split_col.getItem(4)) \
        .withColumn('dst_bytes', split_col.getItem(5)) \
        .withColumn('urgent', split_col.getItem(8)) \
        .withColumn('num_failed_login', split_col.getItem(10)) \
        .withColumn('root_shell', split_col.getItem(13)) \
        .withColumn('guest_login', split_col.getItem(21)) \
        .withColumn('label', split_col.getItem(41)) \
        .drop('value')

df.show()

+--------+-------+----+---------+---------+------+----------------+----------+-----------+-------+
|Protocol|Service|flag|src_bytes|dst_bytes|urgent|num_failed_login|root_shell|guest_login|  label|
+--------+-------+----+---------+---------+------+----------------+----------+-----------+-------+
|     tcp|   http|  SF|      215|    45076|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      162|     4528|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      236|     1228|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      233|     2032|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      239|      486|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      238|     1282|     0|               0|         0|          0|normal.|
|     tcp|   http|  SF|      235|     1337|     0|               0|         0|          0|normal.|
|     tcp|

[Stage 0:>                                                          (0 + 1) / 1]                                                                                

In [8]:
# increase the number of partitions
print("Before repartitions : ",df.rdd.getNumPartitions())
df = df.repartition(10)

print("After repartitions : ",df.rdd.getNumPartitions())

df.createOrReplaceTempView("df_kdd_cup")

Before repartitions :  10


[Stage 2:>                                                          (0 + 1) / 1]

After repartitions :  10


In [9]:
# count the number of connections for each label

df.groupBy("label").count().orderBy('count', ascending=False).show()



+----------------+-------+
|           label|  count|
+----------------+-------+
|          smurf.|2807886|
|        neptune.|1072017|
|         normal.| 972781|
|          satan.|  15892|
|        ipsweep.|  12481|
|      portsweep.|  10413|
|           nmap.|   2316|
|           back.|   2203|
|    warezclient.|   1020|
|       teardrop.|    979|
|            pod.|    264|
|   guess_passwd.|     53|
|buffer_overflow.|     30|
|           land.|     21|
|    warezmaster.|     20|
|           imap.|     12|
|        rootkit.|     10|
|     loadmodule.|      9|
|      ftp_write.|      8|
|       multihop.|      7|
+----------------+-------+
only showing top 20 rows



                                                                                

In [10]:
# get the list of protocols that are normal and vulnerable to attacks,
# where there is NOT guest login to the destination addresses

sql_query = """
                SELECT Protocol,
                CASE label
                    WHEN 'normal.' THEN 'no attack'
                    ELSE 'attack'
                END AS State,
                COUNT(*) as freq
                FROM df_kdd_cup
                WHERE guest_login = '0'
                GROUP BY Protocol, State
                ORDER BY Protocol DESC
"""

spark.sql(sql_query).show()

                                                                                

+--------+---------+-------+
|Protocol|    State|   freq|
+--------+---------+-------+
|     udp|no attack| 191348|
|     udp|   attack|   2940|
|     tcp|no attack| 764894|
|     tcp|   attack|1101613|
|    icmp|   attack|2820782|
|    icmp|no attack|  12763|
+--------+---------+-------+



In [16]:
# A descriptive stats based on Protocols and Labels

sql_query = """
                SELECT Protocol,
                CASE label
                    WHEN 'normal.' THEN 'no attack'
                    ELSE 'attack'
                END AS State,
                COUNT(*) AS Freq,
                round(AVG(src_bytes),2) as mean_src_bytes,
                round(AVG(dst_bytes),2) as mean_dst_bytes,
                SUM(urgent) as sum_uregent,
                SUM(num_failed_login) as sum_num_faked_logins,
                SUM(guest_login) as sum_guest_login
                FROM df_kdd_cup
                GROUP BY Protocol, State  
            """

spark.sql(sql_query).show()



+--------+---------+-------+--------------+--------------+-----------+--------------------+---------------+
|Protocol|    State|   Freq|mean_src_bytes|mean_dst_bytes|sum_uregent|sum_num_faked_logins|sum_guest_login|
+--------+---------+-------+--------------+--------------+-----------+--------------------+---------------+
|     tcp|no attack| 768670|       1844.29|       4071.32|       35.0|                96.0|         3776.0|
|     udp|   attack|   2940|          26.4|          0.82|        0.0|                 0.0|            0.0|
|     tcp|   attack|1101928|       4465.81|       2005.96|        4.0|                61.0|          315.0|
|    icmp|   attack|2820782|        931.68|           0.0|        0.0|                 0.0|            0.0|
|    icmp|no attack|  12763|         90.68|           0.0|        0.0|                 0.0|            0.0|
|     udp|no attack| 191348|         98.32|         89.41|        0.0|                 0.0|            0.0|
+--------+---------+-------+

                                                                                

In [19]:
# Get the frequency of sevices for the original UDP and ICMP based attacks
from pyspark.sql.types import StringType

def attack_category(item):
    if item.replace(".","") in ['back', 'land', 'neptune', 'pod', 'smurf', 'teardrop']:
        return "DoS"
    elif item.replace(".","") in ['buffer_overflow', 'loadmodule', 'perl', 'rootkit']:
        return "U2R"
    elif item.replace(".","") in ['dtp_write', 'guess_password', 'multihop', 'phf', 'spy', 'warezclient', 'warezmaster']:
        return "R2L"
    else:
        return "probe"

def center_justify(item):
    return item.center(10)

spark.udf.register("original_attack", attack_category, StringType())
spark.udf.register("center_justify", center_justify, StringType())

sql_query = """
                SELECT
                center_justify(Service) as service,
                center_justify(Protocol) as protocol,
                original_attack(label) as new_label,
                COUNT(*) as freq
                FROM df_kdd_cup
                WHERE (Protocol = 'udp' or Protocol = 'icmp') and label != 'normal.'
                GROUP BY service, new_label, protocol
"""

spark.sql(sql_query).show()



+----------+----------+---------+-------+
|   service|  protocol|new_label|   freq|
+----------+----------+---------+-------+
|  other   |   udp    |    probe|    261|
|  other   |   udp    |      U2R|      3|
| private  |   udp    |      DoS|    979|
|  ecr_i   |   icmp   |    probe|     59|
| private  |   udp    |    probe|   1688|
|  eco_i   |   icmp   |    probe|  12570|
|  ecr_i   |   icmp   |      DoS|2808145|
|  tim_i   |   icmp   |      DoS|      5|
| domain_u |   udp    |    probe|      9|
|  urp_i   |   icmp   |    probe|      3|
+----------+----------+---------+-------+



                                                                                