# Phase 8: Testing & Optimization

Checklist: larger dataset, Spark tuning, batch interval/latency, throughput, backpressure/failure handling.

## Dataset & Paths
- Default large file: `data/CTU-IoT-Malware-Capture-35-1conn.log.labeled.csv` (~1.3GB).
- Topic: `network-traffic`.
- Models: `models/feature_pipeline`, `models/rf_model` or `gbt_model` (adjust to the optimized model).

In [23]:
# Spark config for benchmarking
from pyspark.sql import SparkSession, functions as F

KAFKA_BOOTSTRAP_SERVERS = "kafka:29092"  
KAFKA_TOPIC = "network-traffic"
MODEL_PATH = "../models/gbt_model"
PIPELINE_PATH = "../models/feature_pipeline"
CHECKPOINT = "checkpoint_bench_test"

# Tunables (edit and rerun):
PROCESSING_TIME = "2 seconds"          # micro-batch interval
MAX_OFFSETS_PER_TRIGGER = 5000   # Kafka rate limit; tune for throughput/backpressure
PARTITIONS = 4                   # desired parallelism for Kafka input

extra_conf = {
    "spark.sql.shuffle.partitions": "8",
    "spark.executor.memory": "4g",
    "spark.driver.memory": "2g",
    # Backpressure / rate
    "spark.streaming.backpressure.enabled": "true",
    # Enable Kafka consumer parallelism
    "spark.streaming.kafka.maxRatePerPartition": str(MAX_OFFSETS_PER_TRIGGER),
}

conf_builder = SparkSession.builder.appName("IoT Malware Benchmark")
conf_builder = conf_builder.config(
    "spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0,org.mongodb.spark:mongo-spark-connector_2.12:10.1.1"
)
for k, v in extra_conf.items():
    conf_builder = conf_builder.config(k, v)

spark = conf_builder.getOrCreate()
spark.sparkContext.setLogLevel("WARN")
print(spark.version)

3.5.0


## Load models
Swap to Random Forest if desired.

In [24]:
from pyspark.ml import PipelineModel
from pyspark.ml.classification import GBTClassificationModel

pipeline_model = PipelineModel.load(PIPELINE_PATH)
gbt_model = GBTClassificationModel.load(MODEL_PATH)

## Streaming read from Kafka with tunables

In [25]:
raw = (spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP_SERVERS)
    .option("subscribe", KAFKA_TOPIC)
    .option("startingOffsets", "earliest")
    .option("maxOffsetsPerTrigger", MAX_OFFSETS_PER_TRIGGER)
    .load())

# Force partitions
raw = raw.repartition(PARTITIONS)

## Parse, feature-engineer, predict (reuse from realtime_predictor)

In [26]:
from pyspark.sql.functions import from_json, col, from_unixtime, hour, dayofweek, when
from pyspark.sql.types import *

schema = StructType([
    StructField("ts", DoubleType()),
    StructField("id.orig_h", StringType()),
    StructField("id.orig_p", DoubleType()),
    StructField("id.resp_h", StringType()),
    StructField("id.resp_p", DoubleType()),
    StructField("proto", StringType()),
    StructField("duration", StringType()),
    StructField("orig_bytes", StringType()),
    StructField("resp_bytes", StringType()),
    StructField("conn_state", StringType()),
    StructField("label", StringType()),
    StructField("detailed-label", StringType())
])

parsed = (raw.selectExpr("CAST(value AS STRING) as json")
    .select(from_json(col("json"), schema).alias("data"))
    .select("data.*"))

clean = (parsed
    .withColumn("duration", col("duration").cast("double"))
    .withColumn("orig_bytes", col("orig_bytes").cast("long"))
    .withColumn("resp_bytes", col("resp_bytes").cast("long"))
    .withColumn("orig_port", col("`id.orig_p`").cast("int"))
    .withColumn("resp_port", col("`id.resp_p`").cast("int"))
    .withColumnRenamed("id.orig_h", "id_orig_h")
    .withColumnRenamed("id.resp_h", "id_resp_h")
    .withColumnRenamed("id.orig_p", "id_orig_p")
    .withColumnRenamed("id.resp_p", "id_resp_p")
    .fillna(0, subset=["duration", "orig_bytes", "resp_bytes"]))

features = (clean
    .withColumn("timestamp", from_unixtime("ts").cast("timestamp"))
    .withColumn("hour_of_day", hour("timestamp"))
    .withColumn("day_of_week", dayofweek("timestamp"))
    .withColumn("total_bytes", col("orig_bytes") + col("resp_bytes"))
    .withColumn("bytes_per_sec", (col("orig_bytes") + col("resp_bytes")) / (col("duration") + 0.001))
)

transformed = pipeline_model.transform(features)
predictions = gbt_model.transform(transformed)
final = predictions.withColumn(
    "predicted_label", when(col("prediction") == 0.0, "Malicious").when(col("prediction") == 1.0, "Benign").otherwise("Unknown")
)

## Throughput & latency logging
Collect per-batch metrics from query progress.

In [27]:
import time, json
progress_log = []

def log_progress(query):
    lp = query.lastProgress
    if lp:
        progress_log.append(lp)
        print(json.dumps({
            "id": lp.get("id"),
            "batchId": lp.get("batchId"),
            "inputRows": lp.get("numInputRows"),
            "procTimeMs": lp.get("durationMs", {}).get("addBatch"),
            "avgInputPerSec": lp.get("inputRowsPerSecond"),
            "avgProcPerSec": lp.get("processedRowsPerSecond"),
        }, indent=2))

query = (final.writeStream
    .format("memory")  # in-memory sink for benchmarking
    .queryName("bench_preds")
    .trigger(processingTime=PROCESSING_TIME)
    .option("checkpointLocation", CHECKPOINT)
    .start())

start = time.time()
while time.time() - start < 60:  # run for 1 minute by default
    log_progress(query)
    time.sleep(5)

query.stop()

{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 2,
  "inputRows": 5000,
  "procTimeMs": 217,
  "avgInputPerSec": 2500.0,
  "avgProcPerSec": 17730.496453900712
}
{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 4,
  "inputRows": 5000,
  "procTimeMs": 179,
  "avgInputPerSec": 2500.0,
  "avgProcPerSec": 20746.88796680498
}
{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 7,
  "inputRows": 5000,
  "procTimeMs": 199,
  "avgInputPerSec": 2498.7506246876565,
  "avgProcPerSec": 18315.018315018315
}
{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 9,
  "inputRows": 5000,
  "procTimeMs": 124,
  "avgInputPerSec": 2500.0,
  "avgProcPerSec": 26595.744680851065
}
{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 12,
  "inputRows": 5000,
  "procTimeMs": 168,
  "avgInputPerSec": 2497.5024975024976,
  "avgProcPerSec": 20833.333333333336
}
{
  "id": "d51c9566-baf9-4dbe-80b3-9d53063a4892",
  "batchId": 14,
  "inputRows": 5000,
  "procT

## Analyze logged metrics
Compute summary throughput and latency after run.

In [28]:
import pandas as pd

if progress_log:
    df = pd.DataFrame([{
        "batchId": p.get("batchId"),
        "inputRows": p.get("numInputRows"),
        "procTimeMs": p.get("durationMs", {}).get("addBatch"),
        "inputRowsPerSecond": p.get("inputRowsPerSecond"),
        "processedRowsPerSecond": p.get("processedRowsPerSecond"),
    } for p in progress_log])
    display(df)
    print("Summary:")
    print(df.describe())
else:
    print("No progress logged. Ensure the stream is running and data is flowing.")

Unnamed: 0,batchId,inputRows,procTimeMs,inputRowsPerSecond,processedRowsPerSecond
0,2,5000,217,2500.0,17730.496454
1,4,5000,179,2500.0,20746.887967
2,7,5000,199,2498.750625,18315.018315
3,9,5000,124,2500.0,26595.744681
4,12,5000,168,2497.502498,20833.333333
5,14,5000,100,2500.0,30674.846626
6,17,5000,101,2498.750625,31446.540881
7,19,5000,100,2498.750625,31847.133758
8,22,5000,213,2501.250625,18050.541516
9,24,5000,111,2500.0,27932.960894


Summary:
         batchId  inputRows  procTimeMs  inputRowsPerSecond  \
count  11.000000       11.0   11.000000           11.000000   
mean   14.272727     5000.0  146.727273         2499.432329   
std     8.295672        0.0   48.793628            1.024724   
min     2.000000     5000.0  100.000000         2497.502498   
25%     8.000000     5000.0  101.500000         2498.750625   
50%    14.000000     5000.0  124.000000         2500.000000   
75%    20.500000     5000.0  189.000000         2500.000000   
max    27.000000     5000.0  217.000000         2501.250625   

       processedRowsPerSecond  
count               11.000000  
mean             25074.461277  
std               5976.420245  
min              17730.496454  
25%              19530.953141  
50%              26595.744681  
75%              31060.693753  
max              31847.133758  


## Backpressure & failure handling notes
- `maxOffsetsPerTrigger` and `spark.streaming.kafka.maxRatePerPartition` limit ingest rate. Lower if Spark lags.
- `spark.streaming.backpressure.enabled=true` lets Spark adapt rate for older APIs; Structured Streaming relies on `maxOffsetsPerTrigger`.
- To test backpressure, start with high rate, watch if `procTimeMs` > batch interval; reduce rate or increase partitions/executors.
- Failure simulation: kill the consumer; ensure checkpoint dir is persisted so offsets resume; verify no data loss/duplication in sink.
- For production sinks (e.g., Mongo), swap `format('memory')` for `format('mongodb')` and ensure idempotent writes or unique keys.

## Static batch benchmark
For off-line timing without Kafka, read the large CSV directly, run the pipeline, and time the batch job.
Only a quick, non-streaming sanity check and baseline.

In [35]:

import time
from pyspark.sql import functions as F

file_path = "../data/CTU-IoT-Malware-Capture-35-1conn.log.labeled.csv"

begin = time.time()
raw_df = spark.read.option("sep", "|").csv(file_path, header=True, inferSchema=True)
print(f"Loaded {raw_df.count()} rows in {time.time()-begin:.2f}s")

# Align schema with training/streaming
static_df = (
    raw_df
    .withColumnRenamed("id.orig_h", "id_orig_h")
    .withColumnRenamed("id.resp_h", "id_resp_h")
    .withColumnRenamed("id.orig_p", "id_orig_p")
    .withColumnRenamed("id.resp_p", "id_resp_p")
    .withColumn("duration", F.col("duration").cast("double"))
    .withColumn("orig_bytes", F.col("orig_bytes").cast("long"))
    .withColumn("resp_bytes", F.col("resp_bytes").cast("long"))
    .withColumn("orig_port", F.col("id_orig_p").cast("int"))
    .withColumn("resp_port", F.col("id_resp_p").cast("int"))
    .fillna(0, subset=["duration", "orig_bytes", "resp_bytes"])
)
# Feature columns to match training
static_df = (
    static_df
    .withColumn("timestamp", F.from_unixtime("ts").cast("timestamp"))
    .withColumn("hour_of_day", F.hour("timestamp"))
    .withColumn("day_of_week", F.dayofweek("timestamp"))
    .withColumn("total_bytes", F.col("orig_bytes") + F.col("resp_bytes"))
    .withColumn("bytes_per_sec", (F.col("orig_bytes") + F.col("resp_bytes")) / (F.col("duration") + F.lit(0.001)))
)
print("Static DF schema before pipeline:")
static_df.printSchema()

# Apply feature pipeline and model
static_features = pipeline_model.transform(static_df)
static_preds = gbt_model.transform(static_features)
static_preds.groupBy("prediction").count().show()
print(f"Total time: {time.time()-begin:.2f}s")

Loaded 10447787 rows in 2.88s
Static DF schema before pipeline:
root
 |-- ts: double (nullable = true)
 |-- uid: string (nullable = true)
 |-- id_orig_h: string (nullable = true)
 |-- id_orig_p: double (nullable = true)
 |-- id_resp_h: string (nullable = true)
 |-- id_resp_p: double (nullable = true)
 |-- proto: string (nullable = true)
 |-- service: string (nullable = true)
 |-- duration: double (nullable = false)
 |-- orig_bytes: long (nullable = true)
 |-- resp_bytes: long (nullable = true)
 |-- conn_state: string (nullable = true)
 |-- local_orig: string (nullable = true)
 |-- local_resp: string (nullable = true)
 |-- missed_bytes: double (nullable = true)
 |-- history: string (nullable = true)
 |-- orig_pkts: double (nullable = true)
 |-- orig_ip_bytes: double (nullable = true)
 |-- resp_pkts: double (nullable = true)
 |-- resp_ip_bytes: double (nullable = true)
 |-- tunnel_parents: string (nullable = true)
 |-- label: string (nullable = true)
 |-- detailed-label: string (nullable