01_data_cleaning_and_eda_pyspark.py

Load the large CICIDS2017 CSV with Spark, clean it, compute basic EDA statistics, and save the cleaned data as Parquet for the next steps.

In [1]:
#Install the required libraries
!pip install pyspark
!pip install spark
from pyspark.sql import SparkSession
# Initialize a Spark session
spark = SparkSession.builder.appName("01_data_cleaning_and_eda").getOrCreate()
spark.sparkContext.setLogLevel("WARN")



In [2]:
# --- Imports -----------------------------------------------------------------
import argparse
import os

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, isnan, lit, when
from pyspark.sql.types import DoubleType

# --- Parameters --------------------------------------------------------------
# Declared through argparse so the same code can run as a script. Inside the
# notebook an empty argv is parsed, so every option takes its default value.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input",
    default="/content/data/cicids2017_combinenew.csv",
)
parser.add_argument(
    "--out_parquet",
    default="/content/data/cleaned.parquet",
)
parser.add_argument(
    "--sample_fraction",
    type=float,
    default=0.02,
    help="Fraction to sample for heavy EDA plots/stats (small fraction recommended)",
)

args = parser.parse_args(args=[])


In [3]:
# Load the raw CSV. Column names come from the header row, and Spark infers the column
# types (inferSchema costs an extra pass over the file).
print("Reading CSV:", args.input)
cicids_data_galla145 = (
    spark.read
    .options(header="true", inferSchema="true")
    .csv(args.input)
)
print("Initial count:", cicids_data_galla145.count())
# Show only the first 40 column names to keep the output short
print("Columns:", cicids_data_galla145.columns[:40])

Reading CSV: /content/data/cicids2017_combinenew.csv
Initial count: 2830743
Columns: [' Destination Port', ' Flow Duration', ' Total Fwd Packets', ' Total Backward Packets', 'Total Length of Fwd Packets', ' Total Length of Bwd Packets', ' Fwd Packet Length Max', ' Fwd Packet Length Min', ' Fwd Packet Length Mean', ' Fwd Packet Length Std', 'Bwd Packet Length Max', ' Bwd Packet Length Min', ' Bwd Packet Length Mean', ' Bwd Packet Length Std', 'Flow Bytes/s', ' Flow Packets/s', ' Flow IAT Mean', ' Flow IAT Std', ' Flow IAT Max', ' Flow IAT Min', 'Fwd IAT Total', ' Fwd IAT Mean', ' Fwd IAT Std', ' Fwd IAT Max', ' Fwd IAT Min', 'Bwd IAT Total', ' Bwd IAT Mean', ' Bwd IAT Std', ' Bwd IAT Max', ' Bwd IAT Min', 'Fwd PSH Flags', ' Bwd PSH Flags', ' Fwd URG Flags', ' Bwd URG Flags', ' Fwd Header Length34', ' Bwd Header Length', 'Fwd Packets/s', ' Bwd Packets/s', ' Min Packet Length', ' Max Packet Length']


In [4]:
# List the Spark type inferred for each of the first 40 columns
data_types = cicids_data_galla145.dtypes
print("Column data_types (first 40):")
for name, dtype in data_types[:40]:
    print(f"  {name}: {dtype}")

Column data_types (first 40):
   Destination Port: int
   Flow Duration: int
   Total Fwd Packets: int
   Total Backward Packets: int
  Total Length of Fwd Packets: int
   Total Length of Bwd Packets: int
   Fwd Packet Length Max: int
   Fwd Packet Length Min: int
   Fwd Packet Length Mean: double
   Fwd Packet Length Std: double
  Bwd Packet Length Max: int
   Bwd Packet Length Min: int
   Bwd Packet Length Mean: double
   Bwd Packet Length Std: double
  Flow Bytes/s: double
   Flow Packets/s: double
   Flow IAT Mean: double
   Flow IAT Std: double
   Flow IAT Max: int
   Flow IAT Min: int
  Fwd IAT Total: int
   Fwd IAT Mean: double
   Fwd IAT Std: double
   Fwd IAT Max: int
   Fwd IAT Min: int
  Bwd IAT Total: int
   Bwd IAT Mean: double
   Bwd IAT Std: double
   Bwd IAT Max: int
   Bwd IAT Min: int
  Fwd PSH Flags: int
   Bwd PSH Flags: int
   Fwd URG Flags: int
   Bwd URG Flags: int
   Fwd Header Length34: bigint
   Bwd Header Length: int
  Fwd Packets/s: double
   Bwd Packets/s: 

In [5]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, NumericType

# Count "missing" values per column in a single aggregation pass:
#   - any type:     null
#   - numeric type: also NaN
#   - string type:  also the empty string
# The per-column expression is named `column`, not `col`, so it no longer shadows
# the `col` function imported earlier from pyspark.sql.functions.
exprs = []
for field in cicids_data_galla145.schema.fields:
    c = field.name
    column = F.col(c)

    condition = column.isNull()
    if isinstance(field.dataType, NumericType):
        condition = condition | F.isnan(column)
    if isinstance(field.dataType, StringType):
        condition = condition | (column == '')

    # count() skips nulls, so this counts exactly the rows where `condition` is true
    exprs.append(F.count(F.when(condition, F.lit(1))).alias(c))

null_counts = cicids_data_galla145.select(exprs)

print("Null counts (show):")
null_counts.show(truncate=False)



Null counts (show):
+-----------------+--------------+------------------+-----------------------+---------------------------+----------------------------+----------------------+----------------------+-----------------------+----------------------+---------------------+----------------------+-----------------------+----------------------+------------+---------------+--------------+-------------+-------------+-------------+-------------+-------------+------------+------------+------------+-------------+-------------+------------+------------+------------+-------------+--------------+--------------+--------------+--------------------+------------------+-------------+--------------+------------------+------------------+-------------------+------------------+-----------------------+--------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+--------------+--------------------+---------------------+---------------------+-----

In [6]:
# Drop any column in which more than half of the rows are missing (counts from null_counts)
total = cicids_data_galla145.count()
nulls = null_counts.first().asDict()
missing_threshold = 0.5 * total
cols_to_drop = [name for name, n_missing in nulls.items() if n_missing > missing_threshold]
if not cols_to_drop:
    print("No columns dropped for missingness.")
else:
    print("Dropping columns with >50% missing:", cols_to_drop)
    cicids_data_galla145 = cicids_data_galla145.drop(*cols_to_drop)

No columns dropped for missingness.


In [7]:
# Remove rows that are exact duplicates across all columns, counting rows before and after
before = cicids_data_galla145.count()
cicids_data_galla145 = cicids_data_galla145.drop_duplicates()
after = cicids_data_galla145.count()

In [8]:
# Row counts before and after de-duplication, one per line
print(before, after, sep="\n")

2830743
2499784


In [9]:
# Candidate label columns: names containing "label" or "attack", or exactly "class"
# (case-insensitive). The first candidate wins; None when there is none.
possible_labels = []
for name in cicids_data_galla145.columns:
    lowered = name.lower()
    if "label" in lowered or "attack" in lowered or lowered == "class":
        possible_labels.append(name)
label_col = next(iter(possible_labels), None)
print("Detected label column:", label_col)

Detected label column:  Label


In [10]:
if label_col:
    # Trimmed string copy of the raw label
    label_text = F.trim(F.col(label_col).cast("string"))
    cicids_data_galla145 = cicids_data_galla145.withColumn("label_str", label_text)

    # Binary target: benign-looking labels -> 0; everything else (attacks, null labels) -> 1
    benign_markers = ('benign', 'normal', '0', 'none')
    is_benign = F.lower(F.col("label_str")).isin(*benign_markers)
    cicids_data_galla145 = cicids_data_galla145.withColumn(
        "y_binary", F.when(is_benign, F.lit(0)).otherwise(F.lit(1)).cast("integer")
    )
    print("Label counts:")
    cicids_data_galla145.groupBy("y_binary").count().show()

Label counts:
+--------+-------+
|y_binary|  count|
+--------+-------+
|       1| 425878|
|       0|2073906|
+--------+-------+



In [11]:
label_cols_to_exclude = {label_col, "label_str", "y_binary"}  # columns that must never be cast to numeric

# Plain decimal text such as "12", "-3" or "4.5". The raw string keeps "\d" and "\."
# as regex escapes rather than (invalid) Python string escapes.
NUMERIC_TEXT_PATTERN = r'^-?\d+(\.\d+)?$'

# Convert string feature columns whose non-empty values are all numeric text to DoubleType.
# The previous per-row `when(rlike, cast).otherwise(<original string>)` mixed a double
# branch with a string branch. Spark must give both branches one common result type
# (typically string), so the column was not actually converted. That cell also ended
# with a mis-indented duplicate statement that raised IndentationError. Here each column
# is converted as a whole, and only when no non-empty value fails the pattern.
for c, t in cicids_data_galla145.dtypes:
    if c in label_cols_to_exclude or t != 'string':
        continue  # labels and already-typed columns are left alone
    value = F.col(c)
    has_non_numeric = (
        cicids_data_galla145
        .filter(value.isNotNull() & (value != '') & ~value.rlike(NUMERIC_TEXT_PATTERN))
        .limit(1)
        .count() > 0
    )
    if has_non_numeric:
        print(f"Keeping string column {c!r}: it contains non-numeric values")
        continue
    # Matching values are cast to double. Empty strings (the only non-matching values
    # left) become null, so later imputation treats them as missing.
    cicids_data_galla145 = cicids_data_galla145.withColumn(
        c, F.when(value.rlike(NUMERIC_TEXT_PATTERN), value.cast(DoubleType()))
    )


In [12]:
from pyspark.sql import functions as F

# Every column whose Spark type string is not "string" counts as numeric here
numeric_cols = [name for name, dtype in cicids_data_galla145.dtypes if dtype != "string"]

print("Numeric columns count:", len(numeric_cols))


Numeric columns count: 79


In [13]:
# Copy of the data in which every column listed in numeric_cols is re-cast to double with
# try_cast (null instead of an error for values that cannot be converted). Other columns
# pass through unchanged.
numeric_lookup = set(numeric_cols)
projection = []
for name in cicids_data_galla145.columns:
    if name in numeric_lookup:
        projection.append(F.expr(f"try_cast(`{name}` as double)").alias(name))
    else:
        projection.append(F.col(name))
safe_df = cicids_data_galla145.select(*projection)


In [14]:
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Build safe_df from each column's Spark type string: non-string columns are
# try_cast to double, and string columns are kept as they are.
projection = []
for name, dtype in cicids_data_galla145.dtypes:
    if dtype == "string":
        projection.append(F.col(name))
    else:
        projection.append(F.expr(f"try_cast(`{name}` as double)").alias(name))
safe_df = cicids_data_galla145.select(*projection)


In [15]:
# Names of the safe_df columns whose type is exactly DoubleType
numeric_cols = [field.name for field in safe_df.schema if isinstance(field.dataType, DoubleType)]

print("Imputer-safe numeric columns:", len(numeric_cols))


Imputer-safe numeric columns: 79


In [16]:
# Replace null or empty-string values in every string column with the literal 'missing'
string_cols = [name for name, dtype in cicids_data_galla145.dtypes if dtype == 'string']
for c in string_cols:
    is_blank = F.col(c).isNull() | (F.col(c) == '')
    nnull = cicids_data_galla145.filter(is_blank).count()
    if nnull == 0:
        continue
    cicids_data_galla145 = cicids_data_galla145.withColumn(
        c, F.when(is_blank, F.lit('missing')).otherwise(F.col(c))
    )
    print(f"Filled {nnull} nulls in string column {c}")

In [17]:
# Recompute the list of non-string (numeric) columns after the string fill above
numeric_cols = [name for name, dtype in cicids_data_galla145.dtypes if dtype != "string"]


In [18]:
# Sanity check: no column whose name contains "label" should be treated as numeric
[name for name in numeric_cols if "label" in name.lower()]


[]

In [19]:
# Strip leading/trailing whitespace from every column name (the CSV header yields names
# such as " Label"). toDF renames all columns positionally in a single projection.
df = cicids_data_galla145.toDF(*[name.strip() for name in cicids_data_galla145.columns])


In [20]:
# Label columns to keep out of numeric processing (names after whitespace stripping).
# NOTE(review): "y_binary" is not listed, so the next cells treat it as a numeric
# feature (cast to double and passed through the Imputer). Consider excluding the target.
label_cols = {"label_str", "Label"}


In [21]:
from pyspark.sql.types import NumericType

# Numeric-typed columns of df, excluding anything named in label_cols
numeric_cols = []
for field in df.schema.fields:
    if field.name in label_cols:
        continue
    if isinstance(field.dataType, NumericType):
        numeric_cols.append(field.name)


In [22]:
# Cast the numeric columns to double in one projection, and map +/-Infinity to null so the
# median Imputer in the next cell fills those values like any other missing value.
#  - Spark's Imputer only treats null/NaN as missing, so infinite values would pass
#    through unchanged into the cleaned parquet. NOTE(review): CICIDS2017 rate columns
#    such as "Flow Bytes/s" / "Flow Packets/s" are known to contain "Infinity". Confirm
#    with a per-column count of isin(inf, -inf).
#  - A single select() replaces one withColumn() per column, which keeps the query plan small.
_POS_INF, _NEG_INF = float("inf"), float("-inf")
_numeric_lookup = set(numeric_cols)


def _to_finite_double(name):
    """Return column `name` try-cast to double (null if not castable), with +/-inf replaced by null."""
    as_double = F.expr(f"try_cast(`{name}` as double)")
    return (
        F.when(as_double.isin(_POS_INF, _NEG_INF), F.lit(None).cast("double"))
        .otherwise(as_double)
        .alias(name)
    )


df = df.select([
    _to_finite_double(c) if c in _numeric_lookup else F.col(c)
    for c in df.columns
])


In [23]:
from pyspark.ml.feature import Imputer

# Fill missing (null/NaN) values in every numeric column with that column's median,
# writing the result back into the same column names. Spark computes the median
# approximately (see Imputer's relativeError parameter).
imputer = Imputer(strategy="median", inputCols=numeric_cols, outputCols=numeric_cols)
imputer_model = imputer.fit(df)
df = imputer_model.transform(df)


In [24]:
# Print the schema after casting and imputation (numeric columns are now double)
df.printSchema()


root
 |-- Destination Port: double (nullable = true)
 |-- Flow Duration: double (nullable = true)
 |-- Total Fwd Packets: double (nullable = true)
 |-- Total Backward Packets: double (nullable = true)
 |-- Total Length of Fwd Packets: double (nullable = true)
 |-- Total Length of Bwd Packets: double (nullable = true)
 |-- Fwd Packet Length Max: double (nullable = true)
 |-- Fwd Packet Length Min: double (nullable = true)
 |-- Fwd Packet Length Mean: double (nullable = true)
 |-- Fwd Packet Length Std: double (nullable = true)
 |-- Bwd Packet Length Max: double (nullable = true)
 |-- Bwd Packet Length Min: double (nullable = true)
 |-- Bwd Packet Length Mean: double (nullable = true)
 |-- Bwd Packet Length Std: double (nullable = true)
 |-- Flow Bytes/s: double (nullable = true)
 |-- Flow Packets/s: double (nullable = true)
 |-- Flow IAT Mean: double (nullable = true)
 |-- Flow IAT Std: double (nullable = true)
 |-- Flow IAT Max: double (nullable = true)
 |-- Flow IAT Min: double (nulla

In [25]:
from pyspark.sql import functions as F

# Replace the text label with a numeric binary label: benign -> 0.0, anything else -> 1.0.
# The previous version mapped only "BENIGN" and "MALICIOUS" and sent every other value
# through cast("double"). That yields null for non-numeric attack names, or a runtime
# error when ANSI mode is on. NOTE(review): CICIDS2017 labels are attack names such as
# "DoS Hulk" / "PortScan"; confirm with df.groupBy("Label").count().
# The benign set and the trim/lower-casing mirror the y_binary definition, so
# Label == y_binary (as double).
BENIGN_LABELS = ('benign', 'normal', '0', 'none')
normalized_label = F.lower(F.trim(F.col("Label").cast("string")))
df = df.withColumn(
    "Label",
    F.when(normalized_label.isin(*BENIGN_LABELS), F.lit(0.0)).otherwise(F.lit(1.0)),
)


In [26]:
# Write every column except "Label" to parquet (label_str and y_binary are kept).
# NOTE(review): a later cell overwrites the same path with label_str also dropped,
# so this full write appears redundant.
df_to_write = df.drop("Label")

df_to_write.write.format("parquet").mode("overwrite").save(args.out_parquet)


In [27]:
# Draw a reproducible (seed 42) sample without replacement for EDA. Spark's fraction is
# approximate, so the sample size varies around fraction * row count.
sample = df.sample(fraction=args.sample_fraction, withReplacement=False, seed=42)
sample_count = sample.count()
print("Sample size for EDA:", sample_count)

Sample size for EDA: 50015


In [28]:
# Final cleaned output: drop the label columns "Label" and "label_str" (y_binary stays),
# then overwrite the parquet at args.out_parquet.
label_columns_to_drop = ("Label", "label_str")
df_to_write = df.drop(*label_columns_to_drop)

df_to_write.write.format("parquet").mode("overwrite").save(args.out_parquet)

In [29]:
# Inspect the EDA sample's schema (the sample has the same columns as df)
sample.printSchema()


root
 |-- Destination Port: double (nullable = true)
 |-- Flow Duration: double (nullable = true)
 |-- Total Fwd Packets: double (nullable = true)
 |-- Total Backward Packets: double (nullable = true)
 |-- Total Length of Fwd Packets: double (nullable = true)
 |-- Total Length of Bwd Packets: double (nullable = true)
 |-- Fwd Packet Length Max: double (nullable = true)
 |-- Fwd Packet Length Min: double (nullable = true)
 |-- Fwd Packet Length Mean: double (nullable = true)
 |-- Fwd Packet Length Std: double (nullable = true)
 |-- Bwd Packet Length Max: double (nullable = true)
 |-- Bwd Packet Length Min: double (nullable = true)
 |-- Bwd Packet Length Mean: double (nullable = true)
 |-- Bwd Packet Length Std: double (nullable = true)
 |-- Flow Bytes/s: double (nullable = true)
 |-- Flow Packets/s: double (nullable = true)
 |-- Flow IAT Mean: double (nullable = true)
 |-- Flow IAT Std: double (nullable = true)
 |-- Flow IAT Max: double (nullable = true)
 |-- Flow IAT Min: double (nulla

In [30]:
# Drop the numeric "Label" column from the EDA sample; label_str and y_binary remain
sample_clean = sample.drop("Label")


In [31]:
import os

# Export the EDA sample to CSV for plotting outside Spark. toPandas() collects the whole
# sample onto the driver, so keep args.sample_fraction small.
eda_csv_path = "/content/results/sample_for_eda.csv"
# pandas' to_csv does not create missing parent directories, so create the results folder first
os.makedirs(os.path.dirname(eda_csv_path), exist_ok=True)
sample_clean.toPandas().to_csv(eda_csv_path, index=False)
