<a href="https://colab.research.google.com/github/deepali17043/NetworkIntrusionDetection/blob/main/project_evaluating_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports/Misc

Note: If running for small dataset on Google Colab, uncomment the following three lines/cells

In [40]:
# !pip install pyspark



In [41]:
# %cd /path-to-project/
# replace path-to-project with your working directory that also has the small dataset.

/content/drive/MyDrive/Summer24/BigData/Project


In [42]:
# !ls
# confirm that the input file is present

 generate_small_data.ipynb	   project_report.gdoc
 NF_UQ_NIDS_v2.csv.bz2		  'Screenshot 2024-08-03 at 4.07.31 PM.png'
 output_dir			   small_dataset
 project_evaluating_models.ipynb   small_NF_UQ_NIDS_v2.csv
 Project.ipynb			   spark-3.1.1-bin-hadoop3.2
 project_output_dir		   spark-3.1.1-bin-hadoop3.2.tgz
 project_presentation.gslides


In [43]:
import sys
import numpy as np

from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml.stat import Correlation
from pyspark.ml import Pipeline
from pyspark.mllib.evaluation import MulticlassMetrics

In [44]:
# Initialize Spark Session
spark = SparkSession.builder.appName('NIDS_eval').getOrCreate()

Uncomment the following two lines if running for small dataset on Google Colab

In [45]:
# sys.argv[1] = 'small_NF_UQ_NIDS_v2.csv'
# sys.argv[2] = 'project_output_dir'

In [46]:
input_file = sys.argv[1]
output_dir = sys.argv[2]

In [47]:
output_log = []

# Read Data and EDA

In [48]:
# Load the data
data = spark.read.csv(input_file, header=True, inferSchema=True)

In [49]:
data_schema = data.schema
log = f'Data Read successful, schema infered:\n{data_schema}'

data.printSchema()
output_log.append(log)

root
 |-- IPV4_SRC_ADDR: string (nullable = true)
 |-- L4_SRC_PORT: integer (nullable = true)
 |-- IPV4_DST_ADDR: string (nullable = true)
 |-- L4_DST_PORT: integer (nullable = true)
 |-- PROTOCOL: integer (nullable = true)
 |-- L7_PROTO: double (nullable = true)
 |-- IN_BYTES: integer (nullable = true)
 |-- IN_PKTS: integer (nullable = true)
 |-- OUT_BYTES: integer (nullable = true)
 |-- OUT_PKTS: integer (nullable = true)
 |-- TCP_FLAGS: integer (nullable = true)
 |-- CLIENT_TCP_FLAGS: integer (nullable = true)
 |-- SERVER_TCP_FLAGS: integer (nullable = true)
 |-- FLOW_DURATION_MILLISECONDS: integer (nullable = true)
 |-- DURATION_IN: integer (nullable = true)
 |-- DURATION_OUT: integer (nullable = true)
 |-- MIN_TTL: integer (nullable = true)
 |-- MAX_TTL: integer (nullable = true)
 |-- LONGEST_FLOW_PKT: integer (nullable = true)
 |-- SHORTEST_FLOW_PKT: integer (nullable = true)
 |-- MIN_IP_PKT_LEN: integer (nullable = true)
 |-- MAX_IP_PKT_LEN: integer (nullable = true)
 |-- SRC_TO

In [50]:
data = data.drop('Dataset')

In [51]:
output_label = 'Label'
output_attack = 'Attack_index'
output_columns = [output_label, output_attack]

In [52]:
categorical_columns = [field for (field, dataType) in data.dtypes if dataType == "string"]

In [53]:
data_desc = data.describe()
log = f'Data Description:\n{data_desc}'
data_desc.show()
output_log.append(log)

+-------+-------------+------------------+-------------+------------------+------------------+-----------------+------------------+-----------------+-----------------+------------------+------------------+------------------+------------------+--------------------------+-----------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-----------------------+-----------------------+----------------------+---------------------+-----------------------+----------------------+-------------------------+-------------------------+------------------------+-------------------------+-------------------------+--------------------------+---------------------------+-----------------+------------------+------------------+------------------+------------------+-----------------+-----------------+--------------------+-------------------+--------+
|summary|IPV4_SRC_ADDR|       L4_SRC_PORT|IPV4_DST_ADDR|       L4_DST_PORT|

In [54]:
# Index Catagorical columns to get correlation matrix
indexers = [StringIndexer(inputCol=column, outputCol=column+"_index") for column in categorical_columns]
encoders = [OneHotEncoder(inputCol=column+"_index", outputCol=column+"_encoded") for column in categorical_columns if column != 'Attack']
pipeline = Pipeline(stages=indexers+encoders)
data = pipeline.fit(data).transform(data)

In [55]:
data = data.drop(*categorical_columns)
indexed_cols = [column+"_index" for column in categorical_columns if column != 'Attack']
data = data.drop(*indexed_cols)

In [56]:
updated_schema = data.schema
log = f'Data Schema after indexing and encoding:\n{updated_schema}'

data.printSchema()
output_log.append(log)

root
 |-- L4_SRC_PORT: integer (nullable = true)
 |-- L4_DST_PORT: integer (nullable = true)
 |-- PROTOCOL: integer (nullable = true)
 |-- L7_PROTO: double (nullable = true)
 |-- IN_BYTES: integer (nullable = true)
 |-- IN_PKTS: integer (nullable = true)
 |-- OUT_BYTES: integer (nullable = true)
 |-- OUT_PKTS: integer (nullable = true)
 |-- TCP_FLAGS: integer (nullable = true)
 |-- CLIENT_TCP_FLAGS: integer (nullable = true)
 |-- SERVER_TCP_FLAGS: integer (nullable = true)
 |-- FLOW_DURATION_MILLISECONDS: integer (nullable = true)
 |-- DURATION_IN: integer (nullable = true)
 |-- DURATION_OUT: integer (nullable = true)
 |-- MIN_TTL: integer (nullable = true)
 |-- MAX_TTL: integer (nullable = true)
 |-- LONGEST_FLOW_PKT: integer (nullable = true)
 |-- SHORTEST_FLOW_PKT: integer (nullable = true)
 |-- MIN_IP_PKT_LEN: integer (nullable = true)
 |-- MAX_IP_PKT_LEN: integer (nullable = true)
 |-- SRC_TO_DST_SECOND_BYTES: double (nullable = true)
 |-- DST_TO_SRC_SECOND_BYTES: double (nullable

In [57]:
log = 'Preprocessing complete'
print(log)
output_log.append(log)

Preprocessing complete


#### Assemble features and create train-test split

In [58]:
feature_columns = [field for (field, dataType) in data.dtypes if (dataType in ['double', 'int']) & (field not in output_columns)]

In [59]:
# Assemble features
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)

# Select label and features columns
data_multi = data.select("features", "Attack_index")
data_binary = data.select("features", "Label")

# Split the data
(Train_multi, Test_multi) = data_multi.randomSplit([0.7, 0.3], seed=42)
(Train_binary, Test_binary) = data_binary.randomSplit([0.7, 0.3], seed=42)

In [60]:
log = 'Data Splits created'
print(log)
output_log.append(log)

Data Splits created


# Random Forest Classifier

### Binary Classification - Attack (1) or not (0)

In [61]:
log = 'Binary Classification - Attack (1) or not (0) - starting\nLoading saved model...'
print(log)
output_log.append(log)

Binary Classification - Attack (1) or not (0) - starting
Loading saved model...


In [62]:
rf_model_binary = None

In [63]:
try:
    rf_model_binary = RandomForestClassificationModel.load(f'{output_dir}/best_model_binary')
    log = f'Best Model for binary classification loaded from {output_dir}/best_model_binary'
    print(log)
    output_log.append(log)
except:
    log = f'No saved model found at {output_dir}/best_model_binary'
    print(log)
    output_log.append(log)

Best Model for binary classification loaded from project_output_dir/best_model_binary


In [64]:
importances = rf_model_binary.featureImportances
feat_imp = [f'{feature_columns[i]}: {importances[i]}' for i in range(len(feature_columns))]
feat_imp = sorted(feat_imp, key=lambda x: float(x.split(':')[1]), reverse=True)
feat_imp = '\n'.join(feat_imp)
log = f'Feature importances:\n{feat_imp}'
print(log)
output_log.append(log)

Feature importances:
DURATION_IN: 0.14657789588095893
L7_PROTO: 0.11765263208289697
FLOW_DURATION_MILLISECONDS: 0.11424629142653457
L4_DST_PORT: 0.07595447211449416
DST_TO_SRC_SECOND_BYTES: 0.06460483827146225
MAX_TTL: 0.05378901108124499
SRC_TO_DST_SECOND_BYTES: 0.043696263836712414
TCP_WIN_MAX_IN: 0.04128242619369633
MIN_TTL: 0.03572678026950683
MIN_IP_PKT_LEN: 0.03292065434668193
OUT_BYTES: 0.02947619935959515
MAX_IP_PKT_LEN: 0.02764452140108603
DNS_QUERY_TYPE: 0.027259140600857667
LONGEST_FLOW_PKT: 0.025279986201862864
DNS_TTL_ANSWER: 0.021351132183223744
OUT_PKTS: 0.018544220458133826
CLIENT_TCP_FLAGS: 0.016777840202903148
TCP_WIN_MAX_OUT: 0.014794074913660036
SHORTEST_FLOW_PKT: 0.013668041705127768
L4_SRC_PORT: 0.013523288741305466
IN_BYTES: 0.012582713386226097
DNS_QUERY_ID: 0.012034251514953447
NUM_PKTS_UP_TO_128_BYTES: 0.009240722399647981
IN_PKTS: 0.00884596582420533
RETRANSMITTED_IN_PKTS: 0.00558232593363075
TCP_FLAGS: 0.0038690689481156934
ICMP_IPV4_TYPE: 0.0032523729860650

In [65]:
if rf_model_binary is not None:
    predictions = rf_model_binary.transform(Test_binary)

    preds_and_labels = predictions.select(['prediction','Label']).rdd.map(lambda row: (float(row['prediction']), float(row['Label'])))
    metrics = MulticlassMetrics(preds_and_labels)

    conf_matrix = metrics.confusionMatrix().toArray()
    log = f'Confusion Matrix:\n{conf_matrix}'
    print(log)
    output_log.append(log)

    FPR = metrics.falsePositiveRate(0.0)
    log = f'False Positive Rate: {FPR}'
    print(log)
    output_log.append(log)

    TPR = metrics.truePositiveRate(0.0)
    log = f'True Positive Rate: {TPR}'
    print(log)
    output_log.append(log)

    precision = metrics.precision(1.0)
    log = f'Precision: {precision}'
    print(log)
    output_log.append(log)

    recall = metrics.recall(1.0)
    log = f'Recall: {recall}'
    print(log)
    output_log.append(log)

    f1 = metrics.fMeasure(1.0)
    log = f'F1 Score: {f1}'
    print(log)
    output_log.append(log)

    accuracy = metrics.accuracy
    log = f'Accuracy: {accuracy}'
    print(log)
    output_log.append(log)

    log = 'Binary Classification - Attack (1) or not (0) - complete'
    print(log)
    output_log.append(log)



Confusion Matrix:
[[37131.  1071.]
 [ 3842. 74315.]]
False Positive Rate: 0.0491574651022941
True Positive Rate: 0.971964818595885
Precision: 0.9857931180855862
Recall: 0.9508425348977059
F1 Score: 0.9680024488254104
Accuracy: 0.9577772239362662
Binary Classification - Attack (1) or not (0) - complete


### Multi-class classification - Attack type

In [66]:
log = 'Multi-class classification - Attack type - starting\nLoading saved model...'
print(log)
output_log.append(log)

Multi-class classification - Attack type - starting
Loading saved model...


In [67]:
rf_model_multi = None

In [68]:
try:
    rf_model_multi = RandomForestClassificationModel.load(f'{output_dir}/best_model_multiclass')
    log = f'Best Model for multi-class classification loaded from {output_dir}/best_model_multiclass'
    print(log)
    output_log
except:
    log = f'No saved model found at {output_dir}/best_model_multiclass'
    print(log)
    output_log.append(log)

Best Model for multi-class classification loaded from project_output_dir/best_model_multiclass


In [69]:
importances = rf_model_multi.featureImportances
feat_imp = [f'{feature_columns[i]}: {importances[i]}' for i in range(len(feature_columns))]
feat_imp = sorted(feat_imp, key=lambda x: float(x.split(':')[1]), reverse=True)
feat_imp = '\n'.join(feat_imp)
log = f'Feature importances:\n{feat_imp}'
print(log)
output_log.append(log)

Feature importances:
L7_PROTO: 0.16957579237392256
DURATION_IN: 0.13985018078641906
NUM_PKTS_128_TO_256_BYTES: 0.12699122867955612
FLOW_DURATION_MILLISECONDS: 0.07665390609761917
MAX_IP_PKT_LEN: 0.06874763338592021
SHORTEST_FLOW_PKT: 0.06113466225992124
DST_TO_SRC_SECOND_BYTES: 0.05525962336182866
NUM_PKTS_UP_TO_128_BYTES: 0.05050782459569403
LONGEST_FLOW_PKT: 0.04362938621426567
TCP_WIN_MAX_IN: 0.03953709911696397
SRC_TO_DST_SECOND_BYTES: 0.03857702593204118
MAX_TTL: 0.019410827544666533
TCP_FLAGS: 0.013987667414481927
IN_PKTS: 0.01372958919337835
OUT_BYTES: 0.013005600761801697
TCP_WIN_MAX_OUT: 0.0102236026342548
IN_BYTES: 0.00956153302687746
CLIENT_TCP_FLAGS: 0.009267323614093399
DNS_TTL_ANSWER: 0.007434776939601664
ICMP_IPV4_TYPE: 0.007222796011137007
L4_DST_PORT: 0.005583901330022035
OUT_PKTS: 0.0048804288667535535
MIN_TTL: 0.0043338422834504255
DURATION_OUT: 0.002274705353804831
RETRANSMITTED_OUT_PKTS: 0.002025689175218738
PROTOCOL: 0.0019837198411906576
SERVER_TCP_FLAGS: 0.00182

In [70]:
if rf_model_multi is not None:
    predictions = rf_model_multi.transform(Test_multi)

    preds_and_labels = predictions.select(['prediction','Attack_index']).rdd
    metrics = MulticlassMetrics(preds_and_labels)

    # conf_matrix = metrics.confusionMatrix().toArray()
    # log = f'Confusion Matrix:\n{conf_matrix}'
    # print(log)
    # output_log.append(log)

    FPR = metrics.falsePositiveRate(1.0)
    log = f'False Positive Rate: {FPR}'
    print(log)
    output_log.append(log)

    TPR = metrics.truePositiveRate(1.0)
    log = f'True Positive Rate: {TPR}'
    print(log)
    output_log.append(log)

    precision = metrics.precision(1.0)
    log = f'Precision: {precision}'
    print(log)
    output_log.append(log)

    recall = metrics.recall(1.0)
    log = f'Recall: {recall}'
    print(log)
    output_log.append(log)

    f1 = metrics.fMeasure(1.0)
    log = f'F1 Score: {f1}'
    print(log)
    output_log.append(log)

    accuracy = metrics.accuracy
    log = f'Accuracy: {accuracy}'
    print(log)
    output_log.append(log)

    log = 'Multi-class classification - Attack type - complete'
    print(log)
    output_log.append(log)

False Positive Rate: 0.012845671357884978
True Positive Rate: 0.9719185487746276
Precision: 0.9680806509512984
Recall: 0.9719185487746276
F1 Score: 0.9699958036088965
Accuracy: 0.894705179659502
Multi-class classification - Attack type - complete


# Write output to output file and conclude

In [71]:
output_rdd = spark.sparkContext.parallelize(output_log)
output_rdd.saveAsTextFile(f'{output_dir}/evaluation_logs')

In [72]:
spark.stop()