In [None]:
# Session Connection Time
# Bytes Transferred
# Kali Trace Used
# Servers Corrupted
# Pages Corrupted
# Location
# WPM Typing Speed

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('kmeans_project').getOrCreate()

In [2]:
dataset = spark.read.csv('data/hack_data.csv', inferSchema=True, header=True)
print(dataset.count(), len(dataset.columns))
dataset.printSchema()
dataset.show(5)
dataset.head(1)

334 7
root
 |-- Session_Connection_Time: double (nullable = true)
 |-- Bytes Transferred: double (nullable = true)
 |-- Kali_Trace_Used: integer (nullable = true)
 |-- Servers_Corrupted: double (nullable = true)
 |-- Pages_Corrupted: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- WPM_Typing_Speed: double (nullable = true)

+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|            Location|WPM_Typing_Speed|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+
|                    8.0|           391.09|              1|             2.96|            7.0|            Slovenia|           72.37|
|                   20.0|           720.99|              0|             3.04|            9.0|British Virgin Is...|          

[Row(Session_Connection_Time=8.0, Bytes Transferred=391.09, Kali_Trace_Used=1, Servers_Corrupted=2.96, Pages_Corrupted=7.0, Location='Slovenia', WPM_Typing_Speed=72.37)]

In [17]:
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

In [11]:
print(dataset.columns)
feat_cols = dataset.columns
feat_cols.remove('Location')
print(feat_cols)

['Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used', 'Servers_Corrupted', 'Pages_Corrupted', 'Location', 'WPM_Typing_Speed']
['Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used', 'Servers_Corrupted', 'Pages_Corrupted', 'WPM_Typing_Speed']


In [13]:
assembler = VectorAssembler(inputCols=feat_cols, outputCol='features')
final_data = assembler.transform(dataset)
final_data.printSchema()

root
 |-- Session_Connection_Time: double (nullable = true)
 |-- Bytes Transferred: double (nullable = true)
 |-- Kali_Trace_Used: integer (nullable = true)
 |-- Servers_Corrupted: double (nullable = true)
 |-- Pages_Corrupted: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- WPM_Typing_Speed: double (nullable = true)
 |-- features: vector (nullable = true)



In [16]:
scaler = StandardScaler(inputCol='features', outputCol='scaledFeatures')
scaler_model = scaler.fit(final_data)
cluster_final_data = scaler_model.transform(final_data)
cluster_final_data.printSchema()

root
 |-- Session_Connection_Time: double (nullable = true)
 |-- Bytes Transferred: double (nullable = true)
 |-- Kali_Trace_Used: integer (nullable = true)
 |-- Servers_Corrupted: double (nullable = true)
 |-- Pages_Corrupted: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- WPM_Typing_Speed: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- scaledFeatures: vector (nullable = true)



In [20]:
for i in range(2, 8):
    kmeans = KMeans(featuresCol='scaledFeatures', k=i)
    model = kmeans.fit(cluster_final_data)
    centers = model.clusterCenters()
    # print(centers)

    evaluator = ClusteringEvaluator()
    prediction = model.transform(cluster_final_data)
    silhouette = evaluator.evaluate(prediction)
    if i < 5:
        prediction.select('prediction').groupBy('prediction').count().show()
    print(i, silhouette, model.summary.trainingCost)

# print(prediction.select('prediction').groupBy('prediction').count().show(truncate=False))

+----------+-----+
|prediction|count|
+----------+-----+
|         1|  167|
|         0|  167|
+----------+-----+

2 0.6683623593283755 601.7707512676691
+----------+-----+
|prediction|count|
+----------+-----+
|         1|   88|
|         2|   79|
|         0|  167|
+----------+-----+

3 0.30412315937808737 434.75507308487596
+----------+-----+
|prediction|count|
+----------+-----+
|         1|   83|
|         3|   46|
|         2|   38|
|         0|  167|
+----------+-----+

4 0.20713135316189968 412.379893480279
5 -0.09143163023790521 245.42401838138517
6 -0.19010616305778094 232.62054194167905
7 -0.13133734690302626 213.28836940747334
