In [1]:
from pyspark.sql import SparkSession

# Stop any session left over from an earlier run of this cell before
# building a new one. Best-effort: a failure here is reported, not raised.
try:
    spark.stop()
except NameError:
    # Fresh kernel: no `spark` variable has been defined yet, nothing to stop.
    pass
except Exception as stop_error:
    # Unlike a bare `except:`, this does not swallow KeyboardInterrupt or
    # SystemExit, and it tells the reader that the old session misbehaved.
    print(f"⚠️ Could not stop existing Spark session: {stop_error!r}")

# Create NEW session with the Kafka source package on the classpath.
# NOTE(review): the connector coordinate pins Scala 2.12 / Spark 3.3.0 —
# confirm it matches the Spark version installed in this environment.
spark = (
    SparkSession.builder
    .appName("IoT Malware Detector")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0")
    .config("spark.sql.streaming.kafka.useDeprecatedOffsetFetching", "false")
    .getOrCreate()
)

print("✅ Fresh Spark session with Kafka!")

✅ Fresh Spark session with Kafka!


In [7]:
# Batch-read the "network-traffic" topic from the earliest available offset.
df_kafka = (
    spark.read
    .format("kafka")
    .option("kafka.bootstrap.servers", "big-data-final-project-kafka-1:29092")
    .option("subscribe", "network-traffic")
    .option("startingOffsets", "earliest")
    .load()
)

# count() is the first action executed on df_kafka, so run it BEFORE
# announcing success: if the read fails, the exception is raised here and the
# "Connected" message is never printed.
total_messages = df_kafka.count()

# NOTE(review): the message says "via IP" although the bootstrap server above
# is given as a hostname; wording kept unchanged to match recorded output.
print("✅ Connected to Kafka via IP!")
print(f"Total messages: {total_messages}")

✅ Connected to Kafka via IP!
Total messages: 156103


In [8]:
from pyspark.sql.functions import from_json, col
from pyspark.sql.types import *

# Field layout of each JSON message as (field name, Spark type) pairs.
# duration / orig_bytes / resp_bytes are kept as strings at this stage; the
# parsed sample contains "-" placeholders in those fields.
field_types = [
    ("ts", DoubleType),
    ("id.orig_h", StringType),
    ("id.orig_p", DoubleType),
    ("id.resp_h", StringType),
    ("id.resp_p", DoubleType),
    ("proto", StringType),
    ("duration", StringType),
    ("orig_bytes", StringType),
    ("resp_bytes", StringType),
    ("conn_state", StringType),
    ("label", StringType),
    ("detailed-label", StringType),
]
schema = StructType(
    [StructField(field_name, spark_type()) for field_name, spark_type in field_types]
)

# Kafka delivers the payload in the `value` column: cast it to text, parse it
# against the schema, then flatten the resulting struct into top-level columns.
raw_json = df_kafka.selectExpr("CAST(value AS STRING) as json")
parsed = raw_json.select(from_json(col("json"), schema).alias("data"))
df = parsed.select("data.*")

# Preview the parsed rows and report how many records were read.
df.show(5)
print(f"Total records: {df.count()}")

+-------------------+-----------+---------+--------------+---------+-----+--------+----------+----------+----------+---------+--------------------+
|                 ts|  id.orig_h|id.orig_p|     id.resp_h|id.resp_p|proto|duration|orig_bytes|resp_bytes|conn_state|    label|      detailed-label|
+-------------------+-----------+---------+--------------+---------+-----+--------+----------+----------+----------+---------+--------------------+
|  1.5267562618665E9|192.168.2.5|  38792.0|200.168.87.203|  59353.0|  tcp|2.998333|         0|         0|        S0|Malicious|PartOfAHorizontal...|
|1.526756268874876E9|192.168.2.5|  38792.0|200.168.87.203|  59353.0|  tcp|       -|         -|         -|        S0|Malicious|PartOfAHorizontal...|
|1.526756272877722E9|192.168.2.5|  38793.0|200.168.87.203|  59353.0|  tcp|2.997182|         0|         0|        S0|Malicious|PartOfAHorizontal...|
|1.526756279884959E9|192.168.2.5|  38793.0|200.168.87.203|  59353.0|  tcp|       -|         -|         -|       

In [9]:
from pyspark.sql.functions import col

# Data cleaning & type casting.
# The string-typed measurement columns are cast to numeric types, and the
# dotted `id.*` columns are copied under flat names (the originals are kept,
# so this adds columns rather than renaming them).
numeric_casts = {
    "duration": "double",
    "orig_bytes": "long",
    "resp_bytes": "long",
}
flat_columns = [
    ("orig_port", col("`id.orig_p`").cast("int")),
    ("resp_port", col("`id.resp_p`").cast("int")),
    ("orig_ip", col("`id.orig_h`")),
    ("resp_ip", col("`id.resp_h`")),
]

typed_df = df
for numeric_name, target_type in numeric_casts.items():
    typed_df = typed_df.withColumn(numeric_name, col(numeric_name).cast(target_type))
for flat_name, flat_expr in flat_columns:
    typed_df = typed_df.withColumn(flat_name, flat_expr)

# Nulls in the three measurement columns are replaced with 0.
df_cleaned = typed_df.fillna(0, subset=list(numeric_casts))

print("Schema after cleaning:")
df_cleaned.printSchema()

Schema after cleaning:
root
 |-- ts: double (nullable = true)
 |-- id.orig_h: string (nullable = true)
 |-- id.orig_p: double (nullable = true)
 |-- id.resp_h: string (nullable = true)
 |-- id.resp_p: double (nullable = true)
 |-- proto: string (nullable = true)
 |-- duration: double (nullable = false)
 |-- orig_bytes: long (nullable = true)
 |-- resp_bytes: long (nullable = true)
 |-- conn_state: string (nullable = true)
 |-- label: string (nullable = true)
 |-- detailed-label: string (nullable = true)
 |-- orig_port: integer (nullable = true)
 |-- resp_port: integer (nullable = true)
 |-- orig_ip: string (nullable = true)
 |-- resp_ip: string (nullable = true)



In [10]:
# 3. Statistical Analysis: Timing & Bytes
from pyspark.sql.functions import avg, stddev, min, max

print("Statistics for Duration and Bytes by Label:")

# One mean per measurement column, aliased as avg_<column>, grouped by label.
averaged_columns = ["duration", "orig_bytes", "resp_bytes"]
mean_expressions = [
    avg(column_name).alias(f"avg_{column_name}") for column_name in averaged_columns
]
label_means = df_cleaned.groupBy("label").agg(*mean_expressions)
label_means.show(truncate=False)

Statistics for Duration and Bytes by Label:
+---------+------------------+------------------+----------------+
|label    |avg_duration      |avg_orig_bytes    |avg_resp_bytes  |
+---------+------------------+------------------+----------------+
|Malicious|2.4333049935738016|22.744885100318672|72.3492712793682|
|Benign   |4.463892593474415 |39.607804232804234|97.831569664903 |
+---------+------------------+------------------+----------------+



In [11]:
# Feature Extraction
# Derived per-connection features built from the cleaned byte counts and
# duration.
division_guard = 0.000001  # added to denominators so they are never exactly 0
byte_sum = col("orig_bytes") + col("resp_bytes")

df_features = (
    df_cleaned
    .withColumn("total_bytes", byte_sum)
    .withColumn("bytes_per_sec", byte_sum / (col("duration") + division_guard))
    .withColumn("orig_bytes_ratio", col("orig_bytes") / (byte_sum + division_guard))
)

print("Features extracted:")
preview_columns = ["label", "duration", "total_bytes", "bytes_per_sec", "orig_bytes_ratio"]
df_features.select(*preview_columns).show(5)

Features extracted:
+---------+--------+-----------+-------------+----------------+
|    label|duration|total_bytes|bytes_per_sec|orig_bytes_ratio|
+---------+--------+-----------+-------------+----------------+
|Malicious|2.998333|          0|          0.0|             0.0|
|Malicious|     0.0|          0|          0.0|             0.0|
|Malicious|2.997182|          0|          0.0|             0.0|
|Malicious|     0.0|          0|          0.0|             0.0|
|Malicious|2.996286|          0|          0.0|             0.0|
+---------+--------+-----------+-------------+----------------+
only showing top 5 rows



In [12]:
# Heuristic Signature Generation
# Rank (destination port, protocol) pairs by how often they occur in
# malicious traffic relative to benign traffic.
# Note: risk_ratio divides raw row counts (benign_count + 1 avoids a zero
# denominator); the counts are not normalised by the size of each class.
group_keys = ["resp_port", "proto"]
is_malicious = col("label").contains("Malicious")

# 1. Occurrences of each (port, proto) pair in malicious rows
malicious_ports = (
    df_cleaned.filter(is_malicious)
    .groupBy(*group_keys)
    .count()
    .withColumnRenamed("count", "malicious_count")
)

# 2. Occurrences of each (port, proto) pair in the remaining rows
benign_ports = (
    df_cleaned.filter(~is_malicious)
    .groupBy(*group_keys)
    .count()
    .withColumnRenamed("count", "benign_count")
)

# 3. Keep every malicious pair; pairs without a benign match get benign_count 0
joined_counts = malicious_ports.join(benign_ports, group_keys, "left_outer")
signatures = (
    joined_counts.fillna(0, subset=["benign_count"])
    .withColumn("risk_ratio", col("malicious_count") / (col("benign_count") + 1))
    .orderBy(col("risk_ratio").desc())
)

print("Potential Heuristic Signatures (High Risk Ports):")
signatures.show(10)

# Turn the five highest-ranked pairs into human-readable rules
print("Generated Rules:")
top_signatures = signatures.limit(5).collect()
rule_lines = [
    f"IF dest_port == {sig['resp_port']} AND proto == '{sig['proto']}' THEN POTENTIAL MALWARE (Risk Ratio: {sig['risk_ratio']:.2f})"
    for sig in top_signatures
]
for rule_line in rule_lines:
    print(rule_line)

Potential Heuristic Signatures (High Risk Ports):
+---------+-----+---------------+------------+-----------------+
|resp_port|proto|malicious_count|benign_count|       risk_ratio|
+---------+-----+---------------+------------+-----------------+
|    59353|  tcp|          23295|           0|          23295.0|
|       22|  tcp|         128264|        2591|49.48456790123457|
|     2407|  tcp|              8|           0|              8.0|
+---------+-----+---------------+------------+-----------------+

Generated Rules:
IF dest_port == 59353 AND proto == 'tcp' THEN POTENTIAL MALWARE (Risk Ratio: 23295.00)
IF dest_port == 22 AND proto == 'tcp' THEN POTENTIAL MALWARE (Risk Ratio: 49.48)
IF dest_port == 2407 AND proto == 'tcp' THEN POTENTIAL MALWARE (Risk Ratio: 8.00)


In [13]:
# 2. Statistical Analysis: Port Usage for Malicious Traffic
# Count malicious rows per (destination port, protocol) and list the ten
# most frequent pairs.
print("Top 10 Destination Ports for Malicious Traffic:")
malicious_rows = df_cleaned.where(col("label").contains("Malicious"))
port_usage = malicious_rows.groupBy("resp_port", "proto").count()
port_usage.sort(col("count").desc()).show(10)

Top 10 Destination Ports for Malicious Traffic:
+---------+-----+------+
|resp_port|proto| count|
+---------+-----+------+
|       22|  tcp|128264|
|    59353|  tcp| 23295|
|     2407|  tcp|     8|
+---------+-----+------+



In [14]:
# 1. Label Distribution
# Number of rows per label value, largest group first.
print("Distribution of Labels:")
label_counts = df_cleaned.groupBy("label").count()
label_counts.sort(col("count").desc()).show(truncate=False)

Distribution of Labels:
+---------+------+
|label    |count |
+---------+------+
|Malicious|151567|
|Benign   |4536  |
+---------+------+

