# Load Dataset

In [1]:
%pip install pyspark
%pip install seaborn
%pip install seaborn matplotlib pandas numpy



Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# imports
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import split, col, count, when, sum, expr, udf
from pyspark.sql.types import FloatType, DoubleType
from pyspark.sql.window import Window
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml import Pipeline

import math
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# from google.colab import files

In [None]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("NetTraffic").getOrCreate()
spark.sparkContext.setLogLevel("INFO")

print("✅ Spark is running!")


In [None]:
# Path to the file (same folder as notebook)
file_path = "dataset_malware.csv" # change name if needed

# Load the CSV into a Spark DataFrame
df = spark.read.csv(file_path, header=True, inferSchema=True, sep="|")

# Show first 5 rows
df.show(5)

+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|                 ts|               uid|    id.orig_h|id.orig_p|     id.resp_h|id.resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|          label|detailed-label|
+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|1.545402842863612E9|CdNmOg26ZIaBRzPvWj|192.168.1.196|  59932.0|104.248.160.24|     80.0|  tcp|      -|3.097754|         0|         0|        S0|      

# Light Data Exploration

## Display first 5 rows in dataset

In [None]:
# print(df_split.head())
# Show first 5 rows
df.show(5)

+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|                 ts|               uid|    id_orig_h|id_orig_p|     id_resp_h|id_resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|          label|detailed-label|
+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|1.545402842863612E9|CdNmOg26ZIaBRzPvWj|192.168.1.196|  59932.0|104.248.160.24|     80.0|  tcp|   NULL|3.097754|       0.0|       0.0|        S0|      

## Show schema to identify column types

In [55]:
df.printSchema()

root
 |-- ts: double (nullable = true)
 |-- uid: string (nullable = true)
 |-- id_orig_h: string (nullable = true)
 |-- id_orig_p: float (nullable = true)
 |-- id_resp_h: string (nullable = true)
 |-- id_resp_p: float (nullable = true)
 |-- proto: string (nullable = true)
 |-- service: string (nullable = true)
 |-- duration: float (nullable = true)
 |-- orig_bytes: float (nullable = true)
 |-- resp_bytes: float (nullable = true)
 |-- conn_state: string (nullable = true)
 |-- local_orig: string (nullable = true)
 |-- local_resp: string (nullable = true)
 |-- missed_bytes: float (nullable = true)
 |-- history: string (nullable = true)
 |-- orig_pkts: float (nullable = true)
 |-- orig_ip_bytes: float (nullable = true)
 |-- resp_pkts: float (nullable = true)
 |-- resp_ip_bytes: float (nullable = true)
 |-- tunnel_parents: string (nullable = true)
 |-- label: string (nullable = true)
 |-- detailed-label: string (nullable = true)



## Check number of rows and columns

In [53]:
num_rows = df.count()
print(f"🧾 Number of rows: {num_rows}")
num_columns = len(df.columns)
print(f"🧾 Number of columns: {num_columns}")

🧾 Number of rows: 10447787
🧾 Number of columns: 23


# Data Cleaning

## Cast '-' to NaN change column names so that Spark will not misinterpret them

In [None]:
from pyspark.sql.functions import col, when

for column in df.columns:
    df = df.withColumn(column, when(col(f"`{column}`") == "-", None).otherwise(col(f"`{column}`")))

# Rename columns to replace dots with underscores
new_column_names = [c.replace(".", "_") for c in df.columns]
df = df.toDF(*new_column_names)

# Check
df.show(5)

+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|                 ts|               uid|    id_orig_h|id_orig_p|     id_resp_h|id_resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|          label|detailed-label|
+-------------------+------------------+-------------+---------+--------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------------+--------------+
|1.545402842863612E9|CdNmOg26ZIaBRzPvWj|192.168.1.196|  59932.0|104.248.160.24|     80.0|  tcp|   NULL|3.097754|         0|         0|        S0|      

## Dealing with missing values

In [None]:
from pyspark.sql.functions import col, sum as spark_sum

null_values = (
    df.select([
        spark_sum(col(c).isNull().cast("int")).alias(c)
        for c in df.columns
    ])
    .toPandas()
    .T.rename(columns={0: "Count"})
)

display(null_values)

Unnamed: 0,Count
ts,0
uid,0
id_orig_h,0
id_orig_p,0
id_resp_h,0
id_resp_p,0
proto,0
service,10446261
duration,4432615
orig_bytes,4432615


Since duration, orig_bytes, resp_bytes all have same number of null values (4432615), very likely they are from the same few rows. All three columns are critical for modelling so it is best to drop the rows containing null values in those fields.

Others like service, local_orig, local_resp, tunnel_parents amd detailed_label all have more than 10446261 out of 10447787 null values. These will not help our model and will just bloat memory, so we can drop these columns.

In [None]:
from pyspark.sql.functions import when, col

# Columns you care about
columns_to_clean = ['duration', 'orig_bytes', 'resp_bytes']

# Replace "NULL" or "null" strings with actual nulls (None)
for c in columns_to_clean:
    df = df.withColumn(c, when((col(c) == "NULL") | (col(c) == "null"), None).otherwise(col(c)))

# Now drop rows where any of these are null
df = df.dropna(subset=columns_to_clean)

from pyspark.sql.functions import col, sum as spark_sum

# now drop columns with largely null values
columns_to_drop = ["service", "local_orig", "local_resp", "tunnel_parents", "missed_bytes"]
df = df.drop(*columns_to_drop)


null_values = (
    df.select([
        spark_sum(col(c).isNull().cast("int")).alias(c)
        for c in df.columns
    ])
    .toPandas()
    .T.rename(columns={0: "Count"})
)

display(null_values)

Unnamed: 0,Count
ts,0
uid,0
id_orig_h,0
id_orig_p,0
id_resp_h,0
id_resp_p,0
proto,0
service,6013804
duration,0
orig_bytes,0


They are indeed from the same few rows. now there are 0 NULL count for duration, orig_bytes and resp_bytes.

## check number of rows and colums after dropping rows with null values

In [11]:
num_rows = df.count()
print(f"🧾 Number of rows: {num_rows}")
num_columns = len(df.columns)
print(f"🧾 Number of columns: {num_columns}")

🧾 Number of rows: 6015172
🧾 Number of columns: 23


# Deeper Data Exploration

## Checking for Distribution of Target Variable


In [12]:
# Register DataFrame as a SQL table
df.createOrReplaceTempView("malware_data")

# Define SQL Query to count the number of rows with 0 and 1 in the target variable
query = """
SELECT
    label,
    COUNT(*) AS count
FROM
    malware_data
GROUP BY
    label
"""

# Run the SQL query
result = spark.sql(query)
result.show()

+------------------+-------+
|             label|  count|
+------------------+-------+
|            Benign|4108340|
|   Malicious   C&C|     33|
|Malicious   Attack|      3|
|  Malicious   DDoS|1906796|
+------------------+-------+



Roughly 2:1 ratio, mildly imbalanced.

## Summary statistics for numeric columns

In [67]:
# convert numerical columns
cols_to_convert = [
    "duration", "orig_bytes", "resp_bytes",
    "missed_bytes", "orig_pkts", "orig_ip_bytes", "resp_pkts", "resp_ip_bytes"
]

# Convert to FloatType
for column in cols_to_convert:
    df = df.withColumn(column, col(column).cast(FloatType()))

# Descriptive stats with Pandas for nicer formatting
desc_stats = df.describe(cols_to_convert)
desc_stats_pd = desc_stats.toPandas().set_index('summary').T
display(desc_stats_pd)

summary,count,mean,stddev,min,max
duration,6015172,4.238091014245806,36.329979153533,2e-06,85755.84
orig_bytes,6015172,356012104.9232792,3467692492.0430346,0.0,66205577000.0
resp_bytes,6015172,5591.520989258495,12956898.601109171,0.0,31720511500.0
missed_bytes,6015172,317.3389582209785,778288.884426687,0.0,1908819460.0
orig_pkts,6015172,7.045628121689621,3458.601411709873,0.0,4216883.0
orig_ip_bytes,6015172,291.06053143617504,96841.1278582448,0.0,118072720.0
resp_pkts,6015172,0.00204549429343,1.917404129819029,0.0,4621.0
resp_ip_bytes,6015172,0.8743738333666934,523.3913984945711,0.0,413488.0


## Boxplots for numerical variables

In [None]:
# Boxplots
# Set the number of columns in the grid
# numeric_cols = df_split.select_dtypes(include=["number"]).columns
numeric_cols = [col_name for col_name, dtype in df.dtypes if dtype in ['int', 'double', 'float']]

df_pandas = df.select(numeric_cols).toPandas()

# Create subplots
num_cols = 5

# Calculate the number of rows needed based on the number of variables
num_vars = len(numeric_cols)
num_rows = (num_vars + num_cols - 1) // num_cols

# Create subplots
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 10))

# Flatten the axes array for easy iteration
axes = axes.flatten()

# Iterate through each variable and create a horizontal boxplot
for i, col_name in enumerate(numeric_cols):
    sns.boxplot(data=df_pandas[col_name], ax=axes[i], color='skyblue')
    axes[i].set_title(col_name)

# Remove any empty subplots
for i in range(num_vars, num_rows * num_cols):
    fig.delaxes(axes[i])

# Adjust spacing between subplots
plt.tight_layout(rect=(0, 0, 1.5, 1.5))

# Show the grid of boxplots
plt.show()

## Categorical Value Distibution

In [13]:
categorical_cols = ["proto", "conn_state", "history"]

for col in categorical_cols:
    count_df = df.groupBy(col).count().orderBy("count", ascending=False).toPandas()
    print(f"=== {col} Value Counts ===")
    display(count_df)


=== proto Value Counts ===


Unnamed: 0,proto,count
0,tcp,6013405
1,udp,1745
2,icmp,22


=== conn_state Value Counts ===


Unnamed: 0,conn_state,count
0,S0,4105597
1,RSTOS0,1841173
2,OTH,65558
3,REJ,1949
4,SF,747
5,RSTO,52
6,RSTR,47
7,S2,26
8,S1,19
9,SH,2


=== history Value Counts ===


Unnamed: 0,history,count
0,S,4104511
1,I,1841162
2,DTT,65534
3,Sr,1948
4,D,1086
...,...,...
57,ShAdDaRRR,1
58,ShAfdtF,1
59,ShADdFf,1
60,HaDdAr,1


In [14]:
from pyspark.sql.functions import countDistinct

# Count distinct values for each column
cardinality = df.select(
    countDistinct("id_orig_h").alias("id_orig_h_cardinality"),
    countDistinct("id_orig_p").alias("id_orig_p_cardinality"),
    countDistinct("id_resp_h").alias("id_resp_h_cardinality"),
    countDistinct("id_resp_p").alias("id_resp_p_cardinality")
).collect()[0]

# Print the cardinality of each column
print(f"Cardinality of id_orig_h: {cardinality['id_orig_h_cardinality']}")
print(f"Cardinality of id_orig_p: {cardinality['id_orig_p_cardinality']}")
print(f"Cardinality of id_resp_h: {cardinality['id_resp_h_cardinality']}")
print(f"Cardinality of id_resp_p: {cardinality['id_resp_p_cardinality']}")

Cardinality of id_orig_h: 28
Cardinality of id_orig_p: 65536
Cardinality of id_resp_h: 4098969
Cardinality of id_resp_p: 16


id_orig_h has low cardinality -> 28 so likely represent small number of internal source IPs (devices on your network)

id_resp_p has low cardinality -> destination ports, likely standard ones like 80, 443, 22. tells you what kind of service was accessed. 

## Correlation matrix for numeric columns

In [15]:
from pyspark.sql.types import FloatType, DoubleType

# Define numeric features to keep (excluding ID-like columns)
true_numeric_cols = [
    "duration", "orig_bytes", "resp_bytes",
    "missed_bytes", "orig_pkts", "orig_ip_bytes", "resp_pkts", "resp_ip_bytes"
]

# Convert PySpark DataFrame to pandas DataFrame for correlation calculation
pandas_df = df.select(numeric_cols).toPandas()

# Plot correlation heatmap
plt.figure(figsize=(12, 6))
sns.heatmap(pandas_df.corr(), annot=True, cmap="coolwarm", fmt=".2f")
plt.title("Correlation Between Numeric Features")
plt.show()

: 

: 

(need to rewrite as we using different dataset)
From the correlation matrix, it can be seen that:
1. Strong Positive Correlation (Close to 1.0)
  - `orig_pkts` and `orig_ip_bytes` are perfectly correlated (1.00), suggesting that packet count and total IP bytes from the origin are directly proportional. One of the variables could be dropped to reduce redundancy.
  - `resp_bytes` and `resp_ip_bytes` are similarly strongly correlated (0.98). This could also mean one of the columns might be redundant.
  - `duration` and `resp_pkts` are also strongly correlated (0.95). This makes sense as longer durations correlate with more response packets, indicating sustained communication during the duration.
2. Weak/Negligible Correlations (Close to 0):
  - Port numbers (`id.orig_p`, `id.resp_p`) have near-zero correlations with other features suggesting that ports do not directly influence traffic metrics (e.g. duration, bytes). This is expected as ports are identifiers, not quantitative measures.
  - Except for its strong positive correlation to `resp_ip_bytes`, `resp_bytes` and most other features have weak/negligible correlations (0.00-0.21) possibly indicating isolated response behaviors.

## Mutual Information Scoring

In [None]:
# Discretise numerical features
numeric_cols = [col_name for col_name, dtype in df.dtypes if dtype in ['int', 'double', 'float']]

# compute mutual info
def compute_mutual_info(df, feature_col, target_col="label", n_bins=5):
    # 5 quantile based bins per numerical col
    if feature_col in numeric_cols:
        df = df.withColumn(
            f"{feature_col}_bin",
            F.ntile(n_bins).over(Window.orderBy(feature_col)))
        feature_col = f"{feature_col}_bin"

    # Compute joint probability P(X,Y)
    joint_prob = (
        df.groupBy(feature_col, target_col)
        .agg(F.count("*").alias("count"))
        .withColumn("p_xy", F.col("count") / df.count())
    )

    # Compute marginal probabilities P(X) and P(Y)
    p_x = joint_prob.groupBy(feature_col).agg(F.sum("p_xy").alias("p_x"))
    p_y = joint_prob.groupBy(target_col).agg(F.sum("p_xy").alias("p_y"))

    # Calculate MI terms: p_xy * log2(p_xy / (p_x * p_y))
    mi_terms = (
        joint_prob.join(p_x, feature_col)
        .join(p_y, target_col)
        .withColumn(
            "mi_term",
            F.col("p_xy") * F.log2(F.col("p_xy") / (F.col("p_x") * F.col("p_y")))
        )
    )

    # Sum MI terms to get final score
    mi = mi_terms.agg(F.sum("mi_term").alias("mi")).collect()[0]["mi"]
    return float(mi) if mi else 0.0

# Compute MI for all features
df_mis = df.select("*")
mi_results = {}

for feature in df_mis.columns:  # Combine all features
    if feature == "label":
        continue
    mi_score = compute_mutual_info(df_mis, feature)
    mi_results[feature] = mi_score
    print(f"MI for {feature}: {mi_score:.4f}")

# Sort the MI and display results
sorted_mi_results = sorted(mi_results.items(), key=lambda x: x[1], reverse=True)
print("\n=== Features sorted by Mutual Information ===")
for feature, mi_score in sorted_mi_results:
    print(f"{feature}: {mi_score:.4f}")

MI for ts: 0.0007
MI for uid: 0.0282
MI for id_orig_h: 0.0072
MI for id_orig_p: 0.0025
MI for id_resp_h: 0.0278
MI for id_resp_p: 0.0016
MI for proto: 0.0116
MI for duration: 0.0039
MI for orig_bytes: 0.0024
MI for resp_bytes: 0.0004
MI for conn_state: 0.0230
MI for history: 0.0176
MI for orig_pkts: 0.0010
MI for orig_ip_bytes: 0.0017
MI for resp_pkts: 0.0023
MI for resp_ip_bytes: 0.0023
MI for detailed-label: 0.0056

=== Features sorted by Mutual Information ===
uid: 0.0282
id_resp_h: 0.0278
conn_state: 0.0230
history: 0.0176
proto: 0.0116
id_orig_h: 0.0072
detailed-label: 0.0056
duration: 0.0039
id_orig_p: 0.0025
orig_bytes: 0.0024
resp_pkts: 0.0023
resp_ip_bytes: 0.0023
orig_ip_bytes: 0.0017
id_resp_p: 0.0016
orig_pkts: 0.0010
ts: 0.0007
resp_bytes: 0.0004


(need to rewrite as we using different dataset now)
**Interpretation of Mutual Information (MI) Scoring:**

MI scoring is a measure of how much information one variable (feature) provides about another (target variable, in this case, `label`). It helps identify which features are most relevant for predicting the target, where a higher MI score means the feature is more informative for the prediction and should be prioritised for model training.

From the results:
The top 5 informative features are
1. `uid` (0.0282)
2. `id_resp_h` (0.0278)
3. `conn_state` (0.0230)
4. `history` (0.0176)
5. `proto` (0.0116)

On the other hand, the features with low/zero MI score are `resp_bytes` (0.0004) and `ts` (0.0007) which have very low MI and may have a negligible effect on prediction.

**Data Leakage Issue**

While `uid` has a high MI score, it could suggest that the `uid` is leaking information about the `label` as `uid` is supposed to be a unique identifier for the connection. Since it is unique per row, it has zero predictive power on new data and cannot be generalised onto new samples. This could cause the model to perform extremely well on training data but fail on unseen data.

# Feature Engineering

## Engineering new Meaningful Features

In [None]:
from pyspark.sql.functions import col, when, lit, log1p

# Total Bytes
df = df.withColumn("total_bytes", col("orig_bytes") + col("resp_bytes"))

# Byte Ratio (orig_bytes / (resp_bytes + 1))
df = df.withColumn("byte_ratio", col("orig_bytes") / (col("resp_bytes") + lit(1)))

# Packet Ratio (orig_pkts / (resp_pkts + 1))
df = df.withColumn("pkt_ratio", col("orig_pkts") / (col("resp_pkts") + lit(1)))

# Total Packet Count
df = df.withColumn("total_pkts", col("orig_pkts") + col("resp_pkts"))

# Throughput = total_bytes / (duration + 1e-6)
df = df.withColumn("throughput", col("total_bytes") / (col("duration") + lit(1e-6)))

# Data-to-packet efficiency
df = df.withColumn("efficiency", col("total_bytes") / (col("total_pkts") + lit(1)))

df.select("total_bytes", "byte_ratio", "throughput", "is_asymmetric").show(5)


## Encoding of Categorical Variables:

Categorical varaiables included are History, proto, conn_state, and id_resp_p. From data exploration, cardinality of each of them are 61, 3, 11, 16 respectively.

History has high cardinality so need to do frequency encoding.
Others have relative low cardinality so one-hot encoding is feasible.

In [None]:
from pyspark.sql.functions import count

## Frequency Encoding of High Cardinality Features such as History

# Calculate frequency counts for each category in the 'history' column
history_counts = df.groupBy("history").agg(count("*").alias("history_count"))

# Join the frequency counts back to the original DataFrame
df = df.join(history_counts, "history", "left")

# Create a new column 'history_freq_encoded' with the frequency counts
df = df.withColumn("history_freq_encoded", col("history_count"))

# Optionally drop the intermediate 'history_count' column
df = df.drop("history_count")

# Show the DataFrame with the frequency encoded column
df.select("history", "history_freq_encoded").show()

In [None]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# one hot encoding of categorical features with low cardinality

# Create StringIndexers for "proto" and "conn_state"
proto_indexer = StringIndexer(inputCol="proto", outputCol="proto_index")
conn_state_indexer = StringIndexer(inputCol="conn_state", outputCol="conn_state_index")
port_indexer = StringIndexer(inputCol="id_resp_p", outputCol="id_resp_p_index")

# Create OneHotEncoders for the indexed columns
proto_encoder = OneHotEncoder(inputCol="proto_index", outputCol="proto_encoded")
conn_state_encoder = OneHotEncoder(inputCol="conn_state_index", outputCol="conn_state_encoded")
port_encoder = OneHotEncoder(inputCol="id_resp_p_index", outputCol="id_resp_p_encoded")

# Fit and transform the DataFrame
df = proto_indexer.fit(df).transform(df)
df = conn_state_indexer.fit(df).transform(df)
df = port_indexer.fit(df).transform(df)

df = proto_encoder.fit(df).transform(df)
df = conn_state_encoder.fit(df).transform(df)
df = port_encoder.fit(df).transform(df)

# Show the DataFrame with the one-hot encoded columns
df.select("proto", "proto_index", "proto_encoded", "conn_state", "conn_state_index", "conn_state_encoded", "id_resp_p", "id_resp_p_index", "id_resp_p_encoded").show(5)

## Scaling of Numerical Features

In [None]:
# Assuming 'df' is your DataFrame and the features are already created

# Define the columns to be scaled
cols_to_scale = ["duration", "byte_ratio", "pkt_ratio", "total_bytes", "total_pkts", "efficiency", "throughput"]

# Create a VectorAssembler to combine the features into a single vector column
assembler = VectorAssembler(inputCols=cols_to_scale, outputCol="features_unscaled")
df = assembler.transform(df)

# Create a StandardScaler to scale the features
scaler = StandardScaler(inputCol="features_unscaled", outputCol="features_scaled", withStd=True, withMean=True)
scaler_model = scaler.fit(df)
df = scaler_model.transform(df)

# Show the scaled features
df.select("features_scaled").show(5, truncate=False)


## Assemble Features for Modelling

In [None]:
from pyspark.ml.feature import VectorAssembler

final_features = [
    "features_scaled",           # your numeric features
    "proto_encoded",             # one-hot categorical
    "conn_state_encoded",
    "history_encoded",
    "id_resp_p_encoded"
]

assembler = VectorAssembler(inputCols=final_features, outputCol="features")
df_model = assembler.transform(df)

# Model Training and Evaluation

## Logistic Regression

In [None]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

from pyspark.ml.feature import StringIndexer

label_indexer = StringIndexer(inputCol="label", outputCol="label_index")
df_model = label_indexer.fit(df_model).transform(df_model)


lr = LogisticRegression(
    featuresCol="features",
    labelCol="label_index",
    maxIter=20,
    regParam=0.1,
    elasticNetParam=0.0  # L2 regularization (Ridge)
)

lr_model = lr.fit(df_model)

predictions = lr_model.transform(df_model)
predictions.select("label", "label_index", "probability", "prediction").show(5, truncate=False)


evaluator_auc = BinaryClassificationEvaluator(
    labelCol="label_index",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

auc = evaluator_auc.evaluate(predictions)
print(f"🔥 ROC-AUC: {auc:.4f}")


#confusion matrix:
predictions.groupBy("label_index", "prediction").count().show()

train, test = df_model.randomSplit([0.8, 0.2], seed=42)
lr_model = lr.fit(train)
predictions = lr_model.transform(test)


## Random Forest

In [None]:
# to be filled

## Gradient Boosted Trees

In [None]:
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

# gbt regressor
gbt = GBTClassifier(featuresCol="features",
                    labelCol="label_index", 
                    maxIter=50, 
                    maxDepth=5)
# gbt pipeline
gbt_pipeline = Pipeline(stages=[
    label_indexer,
    encoder, # insert encoders !!
    assembler,
    gbt
])
# train the model
gbt_model = gbt_pipeline.fit(train)

# Make predictions with Gradient Boosted Tree
gbt_predictions = gbt_model.transform(test)

# evaluate AUC using BinaryClassificationEvaluator
evaluator_auc = BinaryClassificationEvaluator(
    labelCol="label_index",
    rawPredictionCol="rawPrediction",  # GBT outputs rawPrediction
    metricName="areaUnderROC"
)

gbt_auc = evaluator_auc.evaluate(gbt_predictions)
print(f"GBT AUC: {gbt_auc:.4f}")

evaluator_acc = MulticlassClassificationEvaluator(
    labelCol="label_index", 
    predictionCol="prediction", 
    metricName="accuracy"
)

gbt_accuracy = evaluator_acc.evaluate(gbt_predictions)
print(f"Accuracy: {gbt_accuracy:.4f}")

## MLP

In [None]:
# to be filled

# Hyperparameter Tuning & Cross-Validation

In [None]:
# to be done

# Model Comparison

write interpretations here