In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import Window
import plotly.express as px

In [2]:
# Attach to a running Spark session if there is one, otherwise start a new
# one registered under the application name "iot".
builder = SparkSession.builder.appName("iot")
spark = builder.getOrCreate()

# Keep the notebook output readable: only ERROR-level log lines are emitted.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

# Last expression of the cell: display the Spark version in use.
sc.version


'4.0.0'

Credits to Anton T. Ruberts:
https://www.youtube.com/watch?v=2LG2hUQxLmA

Dataset from
https://www.kaggle.com/datasets/agungpambudi/network-malware-detection-connection-analysis?resource=download

In [3]:
# Both pieces are raw strings so every Windows backslash stays literal.
# The previous non-raw "\IoT-Malware.csv" contained the invalid escape "\I",
# which Python reports with a warning (and will reject in a future release).
data_dir = r"C:\Users\gabyl\OneDrive\Desktop\Proyectos\pyspark\data"
file_path = data_dir + r"\IoT-Malware.csv"

# The file is pipe-delimited with a header row; column types are inferred
# from the data, which costs an extra pass over the file.
df = spark.read.option('delimiter', '|').csv(file_path, inferSchema=True, header=True)
df.show(5)

  file_path = r"C:\Users\gabyl\OneDrive\Desktop\Proyectos\pyspark\data" + "\IoT-Malware.csv"


+-------------------+------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|                 ts|               uid|      id.orig_h|id.orig_p|      id.resp_h|id.resp_p|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|    label|      detailed-label|
+-------------------+------------------+---------------+---------+---------------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+---------+--------------------+
|1.525879831015811E9|CUmrqr4svHuSXJy5z7|192.168.100.103|  51524.0| 65.127.233.163|     23.0|  tcp|      -|2.999051|         0|         0|     

In [4]:
# Print the column names and the types Spark inferred while reading the CSV.
df.printSchema()

root
 |-- ts: double (nullable = true)
 |-- uid: string (nullable = true)
 |-- id.orig_h: string (nullable = true)
 |-- id.orig_p: double (nullable = true)
 |-- id.resp_h: string (nullable = true)
 |-- id.resp_p: double (nullable = true)
 |-- proto: string (nullable = true)
 |-- service: string (nullable = true)
 |-- duration: string (nullable = true)
 |-- orig_bytes: string (nullable = true)
 |-- resp_bytes: string (nullable = true)
 |-- conn_state: string (nullable = true)
 |-- local_orig: string (nullable = true)
 |-- local_resp: string (nullable = true)
 |-- missed_bytes: double (nullable = true)
 |-- history: string (nullable = true)
 |-- orig_pkts: double (nullable = true)
 |-- orig_ip_bytes: double (nullable = true)
 |-- resp_pkts: double (nullable = true)
 |-- resp_ip_bytes: double (nullable = true)
 |-- tunnel_parents: string (nullable = true)
 |-- label: string (nullable = true)
 |-- detailed-label: string (nullable = true)



### Pre-processing

In [5]:
# Derive a proper timestamp column from the epoch-seconds "ts" column:
# from_unixtime renders the epoch as a 'yyyy-MM-dd HH:mm:ss' string and
# to_timestamp parses that string back into a timestamp.
df = df.withColumn("dt", F.to_timestamp(F.from_unixtime("ts")))

In [6]:
# Spot-check the first three values of the derived "dt" column.
df.select("dt").show(3)

+-------------------+
|                 dt|
+-------------------+
|2018-05-09 17:30:31|
|2018-05-09 17:30:31|
|2018-05-09 17:30:31|
+-------------------+
only showing top 3 rows


In [7]:
# Give the dotted connection-endpoint columns plain, descriptive names.
endpoint_renames = {
    "id.orig_h": "source_ip",
    "id.orig_p": "source_port",
    "id.resp_h": "dest_ip",
    "id.resp_p": "dest_port",
}
df = df.withColumnsRenamed(endpoint_renames)

In [8]:
# List the current column names of the working frame.
df.columns

['ts',
 'uid',
 'source_ip',
 'source_port',
 'dest_ip',
 'dest_port',
 'proto',
 'service',
 'duration',
 'orig_bytes',
 'resp_bytes',
 'conn_state',
 'local_orig',
 'local_resp',
 'missed_bytes',
 'history',
 'orig_pkts',
 'orig_ip_bytes',
 'resp_pkts',
 'resp_ip_bytes',
 'tunnel_parents',
 'label',
 'detailed-label',
 'dt']

### Dataset quality checks

In [9]:
# Earliest and latest timestamp present in the data.
earliest = F.min("dt").alias("min_date")
latest = F.max("dt").alias("max_date")
date_bounds = df.agg(earliest, latest)
date_bounds.show()

+-------------------+-------------------+
|           min_date|           max_date|
+-------------------+-------------------+
|2018-05-09 17:30:31|2018-05-14 09:24:43|
+-------------------+-------------------+



In [10]:
# Shape of the frame as (rows, columns); count() runs a Spark job.
n_rows = df.count()
n_cols = len(df.columns)
n_rows, n_cols

(1008748, 24)

In [11]:
# Columns whose cardinality we want to inspect (everything except the raw
# timestamp, the connection uid and the derived "dt").
to_analyse = [
    "source_ip",
    "source_port",
    "dest_ip",
    "dest_port",
    "proto",
    "service",
    "duration",
    "orig_bytes",
    "resp_bytes",
    "conn_state",
    "local_orig",
    "local_resp",
    "missed_bytes",
    "history",
    "orig_pkts",
    "orig_ip_bytes",
    "resp_pkts",
    "resp_ip_bytes",
    "tunnel_parents",
    "label",
    "detailed-label",
]

# A single-row frame holding the number of distinct values of each column.
unique_counts = df.agg(
    *(F.countDistinct(F.col(c)).alias(c) for c in to_analyse)
)

# show() prints the table itself and returns None, so wrapping it in print()
# would only add a stray "None" line to the output.
unique_counts.show()

+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+-----+--------------+
|source_ip|source_port|dest_ip|dest_port|proto|service|duration|orig_bytes|resp_bytes|conn_state|local_orig|local_resp|missed_bytes|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|tunnel_parents|label|detailed-label|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+----------+----------+------------+-------+---------+-------------+---------+-------------+--------------+-----+--------------+
|    15004|      28243| 597107|    65426|    3|      5|   16650|       171|       479|        11|         1|         1|           1|    126|       54|         1249|       69|         1141|             1|    2|             3|
+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+---

In [12]:
# Pull the single aggregate row to the driver under its own name. Rebinding
# `unique_counts` to the Row would make this cell fail when re-run, because a
# Row has no .first() method.
unique_counts_row = unique_counts.first()

# A column with exactly one distinct value carries no information.
static_cols = [
    name
    for name, n_distinct in unique_counts_row.asDict().items()
    if n_distinct == 1
]
print(static_cols)

# drop() ignores names that are not present, so re-running is harmless.
df = df.drop(*static_cols)

['local_orig', 'local_resp', 'missed_bytes', 'tunnel_parents']


### Count Nulls

In [13]:
# Turn the "-" placeholder into real nulls so it can be counted as missing.
df = df.replace("-", None)

non_static_cols = [name for name in df.columns if name not in static_cols]

# when() yields a non-null literal only for null cells, and count() counts
# non-null values, so each expression is the number of nulls in that column.
null_count_exprs = [
    F.count(F.when(F.col(name).isNull(), name)).alias(name)
    for name in non_static_cols
]
df.select(null_count_exprs).show()

+---+---+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+---+
| ts|uid|source_ip|source_port|dest_ip|dest_port|proto|service|duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|label|detailed-label| dt|
+---+---+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+---+
|  0|  0|        0|          0|      0|        0|    0|1005507|  796300|    796300|    796300|         0|  17421|        0|            0|        0|            0|    0|        469275|  0|
+---+---+---------+-----------+-------+---------+-----+-------+--------+----------+----------+----------+-------+---------+-------------+---------+-------------+-----+--------------+---+



### Time Series

In [14]:
# Add one truncated-timestamp column per granularity; each column is named
# after the unit it is truncated to.
df = df.withColumns(
    {unit: F.date_trunc(unit, "dt") for unit in ("day", "hour", "minute")}
)

In [15]:
# Number of connections per hour and label, in chronological order,
# collected to pandas for rich display.
hourly_label_counts = (
    df.groupBy("hour", "label")
    .agg(F.count("uid").alias('counts'))
    .orderBy("hour")
)
hourly_label_counts.toPandas()

Unnamed: 0,hour,label,counts
0,2018-05-09 17:00:00,Benign,2197
1,2018-05-09 17:00:00,Malicious,2623
2,2018-05-09 18:00:00,Malicious,5420
3,2018-05-09 18:00:00,Benign,4346
4,2018-05-09 19:00:00,Malicious,5402
...,...,...,...
221,2018-05-14 07:00:00,Malicious,4399
222,2018-05-14 08:00:00,Benign,4036
223,2018-05-14 08:00:00,Malicious,4325
224,2018-05-14 09:00:00,Benign,1608


### Univariate Data Analysis

In [16]:
# Rows per protocol. groupBy().count() already names its column "count";
# the former DataFrame-level .alias("count") renamed nothing and was dropped.
proto_counts = df.groupBy("proto").count()

# Share of each protocol: an empty partitionBy() makes the window span the
# whole frame, so the sum is the grand total.
whole_frame = Window.partitionBy()
inter_df = proto_counts.withColumn(
    "percent", F.round(F.col("count") / F.sum(F.col("count")).over(whole_frame), 4)
)

# Sort last: the window stage reshuffles rows, so an earlier orderBy is not
# guaranteed to survive it.
inter_df = inter_df.orderBy("count", ascending=True)
inter_df.show()

+-----+------+-------+
|proto| count|percent|
+-----+------+-------+
| icmp| 17421| 0.0173|
|  udp|408193| 0.4047|
|  tcp|583134| 0.5781|
+-----+------+-------+



In [17]:
def counts(df, var):
    """Show the frequency table of one column and plot it as a bar chart.

    Args:
        df: Spark DataFrame to summarise.
        var: Name of the column to group by.

    Returns:
        None. The table is printed with ``show()`` and the Plotly figure is
        rendered with ``fig.show()``.
    """
    # An empty partitionBy() makes the window span the whole frame, so the
    # windowed sum is the grand total used as the denominator.
    whole_frame = Window.partitionBy()
    var_counts = (
        df.groupBy(var)
        .count()
        .withColumn(
            "percent",
            F.round(F.col("count") / F.sum(F.col("count")).over(whole_frame), 4),
        )
        # Sort after the window stage: the window reshuffles rows, so sorting
        # earlier does not guarantee the final order.
        .orderBy("count", ascending=False)
    )
    var_counts.show()

    # Title the figure so it can be read on its own when skimming.
    fig = px.bar(var_counts.toPandas(), x=var, y="count", title=f"Row count by {var}")
    fig.show()

# Categorical columns to profile, one table and one bar chart each.
categorical_cols = ["proto", "service", "conn_state", "history", "label"]
for categorical_col in categorical_cols:
    counts(df, categorical_col)

+-----+------+-------+
|proto| count|percent|
+-----+------+-------+
|  tcp|583134| 0.5781|
|  udp|408193| 0.4047|
| icmp| 17421| 0.0173|
+-----+------+-------+



+-------+-------+-------+
|service|  count|percent|
+-------+-------+-------+
|   NULL|1005507| 0.9968|
|   http|   3238| 0.0032|
|   dhcp|      1|    0.0|
|    dns|      1|    0.0|
|    ssh|      1|    0.0|
+-------+-------+-------+



+----------+------+-------+
|conn_state| count|percent|
+----------+------+-------+
|        S0|971229| 0.9628|
|       OTH| 17421| 0.0173|
|        SF| 13881| 0.0138|
|       REJ|  4346| 0.0043|
|      RSTR|  1465| 0.0015|
|    RSTOS0|   197| 2.0E-4|
|     RSTRH|    66| 1.0E-4|
|      RSTO|    47|    0.0|
|        S2|    40|    0.0|
|        SH|    29|    0.0|
|        S1|    27|    0.0|
+----------+------+-------+



+---------+------+-------+
|  history| count|percent|
+---------+------+-------+
|        S|569605| 0.5647|
|        D|401606| 0.3981|
|     NULL| 17421| 0.0173|
|       Dd|  6569| 0.0065|
|       Sr|  4346| 0.0043|
| ShAdDafF|  2744| 0.0027|
| ShADadfF|  1385| 0.0014|
| ShAdDaFf|  1258| 0.0012|
|  ShADafF|   740| 7.0E-4|
|   ShADar|   397| 4.0E-4|
| ShAdDaFr|   277| 3.0E-4|
| ShAdDfFr|   228| 2.0E-4|
|        R|   181| 2.0E-4|
|    ShADr|   164| 2.0E-4|
| ShADdfFa|   151| 1.0E-4|
|ShAdDaftF|   150| 1.0E-4|
|   ShAdDr|   123| 1.0E-4|
|  ShADafr|   120| 1.0E-4|
|ShAdDafrR|   120| 1.0E-4|
|ShAdDarfR|   115| 1.0E-4|
+---------+------+-------+
only showing top 20 rows


+---------+------+-------+
|    label| count|percent|
+---------+------+-------+
|Malicious|539473| 0.5348|
|   Benign|469275| 0.4652|
+---------+------+-------+



## Prepare for Modelling

## Full Pipeline

In [18]:
# Columns that hold a single distinct value and are dropped.
static_cols = ["local_orig", "local_resp", "missed_bytes", "tunnel_parents"]

# Measurement columns: all are cast to double and share one numeric sentinel.
numeric_cols = [
    "duration",
    "orig_bytes",
    "resp_bytes",
    "orig_ip_bytes",
    "orig_pkts",
    "resp_pkts",
    "resp_ip_bytes",
]
recast_cols = {name: F.col(name).cast("double") for name in numeric_cols}

# String columns whose nulls become an explicit "missing" category.
string_fill_cols = ["history", "proto", "service", "conn_state"]
fill_vals = {
    **{name: -999999 for name in numeric_cols},
    **{name: "missing" for name in string_fill_cols},
}

# Truncated-timestamp columns, each named after its granularity.
time_buckets = {
    unit: F.date_trunc(unit, "dt") for unit in ("day", "hour", "minute", "second")
}

# Plain names for the dotted connection-endpoint columns.
endpoint_renames = {
    "id.orig_h": "source_ip",
    "id.orig_p": "source_port",
    "id.resp_h": "dest_ip",
    "id.resp_p": "dest_port",
}

# Re-read the raw file so the pipeline does not depend on earlier cells.
raw_data = spark.read.option("delimiter", '|').csv(
    file_path, inferSchema=True, header=True
)

preprocessed_data = (
    raw_data.withColumn("dt", F.to_timestamp(F.from_unixtime("ts")))
    .withColumns(time_buckets)
    .withColumnsRenamed(endpoint_renames)
    .drop(*static_cols)
    # "-" placeholders must become nulls before the casts and the fill.
    .replace("-", None)
    .withColumns(recast_cols)
    .fillna(fill_vals)
)

preprocessed_data.show()


+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
|                 ts|               uid|      source_ip|source_port|        dest_ip|dest_port|proto|service| duration|orig_bytes|resp_bytes|conn_state|history|orig_pkts|orig_ip_bytes|resp_pkts|resp_ip_bytes|    label|      detailed-label|                 dt|                day|               hour|             minute|             second|
+-------------------+------------------+---------------+-----------+---------------+---------+-----+-------+---------+----------+----------+----------+-------+---------+-------------+---------+-------------+---------+--------------------+-------------------+-------------------+-------------------+-------------------+----

## Write Out

In [19]:
# Local output location expressed as a file:/// URI with forward slashes.
output_dir = r'file:///C:/Users/gabyl/OneDrive/Desktop/Proyectos/pyspark/output/preprocessing'

# Persist the result of the full pipeline (`preprocessed_data`), not the
# exploratory `df`, which was never recast or null-filled.
# NOTE(review): the recorded run failed with UnsatisfiedLinkError in
# NativeIO$Windows.access0. On Windows this usually means the Hadoop native
# binaries (winutils.exe / hadoop.dll) are not on HADOOP_HOME\bin / PATH.
# Confirm the local environment before re-running.
preprocessed_data.write.parquet(output_dir + "/df.parquet", mode="overwrite")

Py4JJavaError: An error occurred while calling o678.parquet.
: java.lang.UnsatisfiedLinkError: 'boolean org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(java.lang.String, int)'
	at org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(Native Method)
	at org.apache.hadoop.io.nativeio.NativeIO$Windows.access(NativeIO.java:817)
	at org.apache.hadoop.fs.FileUtil.canRead(FileUtil.java:1415)
	at org.apache.hadoop.fs.FileUtil.list(FileUtil.java:1620)
	at org.apache.hadoop.fs.RawLocalFileSystem.listStatus(RawLocalFileSystem.java:739)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:2078)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:2122)
	at org.apache.hadoop.fs.ChecksumFileSystem.listStatus(ChecksumFileSystem.java:961)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:2078)
	at org.apache.hadoop.fs.FileSystem.listStatus(FileSystem.java:2122)
	at org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter.getAllCommittedTaskPaths(FileOutputCommitter.java:334)
	at org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter.commitJobInternal(FileOutputCommitter.java:404)
	at org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter.commitJob(FileOutputCommitter.java:377)
	at org.apache.parquet.hadoop.ParquetOutputCommitter.commitJob(ParquetOutputCommitter.java:46)
	at org.apache.spark.internal.io.HadoopMapReduceCommitProtocol.commitJob(HadoopMapReduceCommitProtocol.scala:194)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$writeAndCommit$3(FileFormatWriter.scala:275)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
	at org.apache.spark.util.Utils$.timeTakenMs(Utils.scala:481)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.writeAndCommit(FileFormatWriter.scala:275)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeWrite(FileFormatWriter.scala:306)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:189)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:195)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:117)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:115)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:129)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$eagerlyExecuteCommands$2(QueryExecution.scala:155)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$8(SQLExecution.scala:162)
	at org.apache.spark.sql.execution.SQLExecution$.withSessionTagsApplied(SQLExecution.scala:268)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$7(SQLExecution.scala:124)
	at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
	at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:112)
	at org.apache.spark.sql.artifact.ArtifactManager.withClassLoaderIfNeeded(ArtifactManager.scala:106)
	at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:111)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$6(SQLExecution.scala:124)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:291)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$1(SQLExecution.scala:123)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId0(SQLExecution.scala:77)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:233)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$eagerlyExecuteCommands$1(QueryExecution.scala:155)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:654)
	at org.apache.spark.sql.execution.QueryExecution.org$apache$spark$sql$execution$QueryExecution$$eagerlyExecute$1(QueryExecution.scala:154)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$3.applyOrElse(QueryExecution.scala:169)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$3.applyOrElse(QueryExecution.scala:164)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:470)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:86)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:470)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:37)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:360)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:356)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:446)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:164)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$lazyCommandExecuted$1(QueryExecution.scala:126)
	at scala.util.Try$.apply(Try.scala:217)
	at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1378)
	at org.apache.spark.util.LazyTry.tryT$lzycompute(LazyTry.scala:46)
	at org.apache.spark.util.LazyTry.tryT(LazyTry.scala:46)
	at org.apache.spark.util.LazyTry.get(LazyTry.scala:58)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:131)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:192)
	at org.apache.spark.sql.classic.DataFrameWriter.runCommand(DataFrameWriter.scala:622)
	at org.apache.spark.sql.classic.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:273)
	at org.apache.spark.sql.classic.DataFrameWriter.saveInternal(DataFrameWriter.scala:241)
	at org.apache.spark.sql.classic.DataFrameWriter.save(DataFrameWriter.scala:118)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:369)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:108)
	at java.base/java.lang.Thread.run(Thread.java:842)


In [None]:
# Read back from the location the "Write Out" cell writes to; the former
# relative path "processed.pq" is never produced by this notebook.
read_in = spark.read.parquet(output_dir + "/df.parquet")
read_in.show()