In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier

In [2]:
spark = (
    SparkSession.builder.appName("iot")
    .master("local[*]")
    .config("spark.driver.host", "localhost")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

In [3]:
output_dir = r"C:\Users\gabyl\spark_outputs\preprocessing"

df = spark.read.parquet(output_dir)
df.show()

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+--------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+----------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state| history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|               hour|             minute|             second|       day|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+--------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+----------+
|1.5261

In [4]:
df = df.withColumn("is_bad", F.when(F.col("label") != "Benign", 1).otherwise(0))

## Feature Engineering
Let's add some time-series features

In [5]:
# Example of rolling feature generation
df.withColumn(
    "activity_count_last_5m",
    F.count('source_ip').over(Window().partitionBy('source_ip').orderBy(F.col('dt').cast('long')).rangeBetween(-5*60,-1))
).show()

+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+------+--------------+-------------------+-------------------+-------------------+-------------------+----------+------+----------------------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes| label|detailed-label|                 dt|               hour|             minute|             second|       day|is_bad|activity_count_last_5m|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+------+--------------+-------------------+-------------------+-------------------+---------

In [6]:
# Lets create some custom functions
def create_custom_window(
    partition_by: str,
    timestamp_col: str,
    window_in_minutes: int,
):
    window = (
        Window()
        .partitionBy(partition_by)
        .orderBy(F.col(timestamp_col).cast('long'))
        .rangeBetween(-window_in_minutes*60, -1)
    )
    
    return window

def generate_rolling_aggregate(
    col: str,
    partition_by: str = None,
    operation: str = "count",
    timestamp_col: str = "dt",
    window_in_minutes: int = 1,
):
    if partition_by is None:
        partition_by = col
    
    if operation == "count":
        return F.count(col).over(
            create_custom_window(
                partition_by=partition_by,
                timestamp_col=timestamp_col,
                window_in_minutes=window_in_minutes
            )
        )
    elif operation == "sum":
        return F.sum(col).over(
            create_custom_window(
                partition_by=partition_by,
                timestamp_col=timestamp_col,
                window_in_minutes=window_in_minutes
            )
        )
    elif operation == "avg":
        return F.avg(col).over(
            create_custom_window(
                partition_by=partition_by,
                timestamp_col=timestamp_col,
                window_in_minutes=window_in_minutes
            )
        )
    else:
        raise ValueError(f"Operation '{operation}' is not defined.")

In [7]:
# Now we apply the custom feature engineering to create several new features.
# This cell won't take any time because it doesn't really apply to the dataframe. It just defines the calculations to take place.
# We'de need to apply a df.show() for example for it to really compute
df = df.withColumns({
    "source_ip_count_last_5m" : generate_rolling_aggregate(col='source_ip', partition_by='source_ip', operation='count', window_in_minutes=5),
    "source_ip_count_last_30m" : generate_rolling_aggregate(col='source_ip', partition_by='source_ip', operation='count', window_in_minutes=30),
    "source_port_count_last_5m" : generate_rolling_aggregate(col='source_port', partition_by='source_port', operation='count', window_in_minutes=5),
    "source_port_count_last_30m" : generate_rolling_aggregate(col='source_port', partition_by='source_port', operation='count', window_in_minutes=30),
    "dest_ip_count_last_5m" : generate_rolling_aggregate(col='dest_ip', partition_by='dest_ip', operation='count', window_in_minutes=5),
    "dest_ip_count_last_30m" : generate_rolling_aggregate(col='dest_ip', partition_by='dest_ip', operation='count', window_in_minutes=30),
    "dest_port_count_last_5m" : generate_rolling_aggregate(col='dest_port', partition_by='dest_port', operation='count', window_in_minutes=5),
    "dest_port_count_last_30m" : generate_rolling_aggregate(col='dest_port', partition_by='dest_port', operation='count', window_in_minutes=30),
    "source_ip_avg_pkts_last_5m": generate_rolling_aggregate(col='orig_pkts', partition_by='source_ip', operation='avg', window_in_minutes=5),
    "source_ip_avg_pkts_last_30m": generate_rolling_aggregate(col='orig_pkts', partition_by='source_ip', operation='avg', window_in_minutes=30),
    "source_ip_avg_bytes_last_5m": generate_rolling_aggregate(col='orig_ip_bytes', partition_by='source_ip', operation='avg', window_in_minutes=5),
    "source_ip_avg_pkts_last_30m": generate_rolling_aggregate(col='orig_ip_bytes', partition_by='source_ip', operation='avg', window_in_minutes=30),
})

In [8]:
output_dir_fe = r"C:\Users\gabyl\spark_outputs\feature_engineering"

df.write.mode("overwrite").partitionBy("day").parquet(output_dir)

In [10]:
df_feat_eng = spark.read.parquet(output_dir)
del df
df_feat_eng.show()

+-------------------+------------------+---------------+-----------+--------------+---------+-----+-------+---------+----------+----------+----------+--------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+------+-----------------------+------------------------+-------------------------+--------------------------+---------------------+----------------------+-----------------------+------------------------+--------------------------+---------------------------+---------------------------+----------+
|                 ts|               uid|      source_ip|source_port|       dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state| history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|               hour|             minute|             second|is_bad|source_ip_count_last_5m|source_ip_count_last_30m|source_por