In [381]:
#import the libraries, initialise the spark session and then load the data in the spark dataframe
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer

from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("KMeansExample")\
        .getOrCreate()

dataset = spark.read.load('app-category.csv',format="csv",inferschema=True, header=True)

In [382]:
import pandas as pd
import numpy as np

In [383]:
dataset.printSchema()

root
 |-- phoneNumber: integer (nullable = true)
 |-- dataUsage: integer (nullable = true)
 |-- appName: string (nullable = true)
 |-- appCategory: string (nullable = true)



In [384]:
dataset.show(10)

+-----------+---------+--------+-----------+
|phoneNumber|dataUsage| appName|appCategory|
+-----------+---------+--------+-----------+
|     100000|     2430|   Gaana|      Music|
|     100000|     1755|   Gmail|       Mail|
|     100004|     3939| Youtube|  Streaming|
|     100002|     4836| JioNews|       News|
|     100001|     4791|   Gaana|      Music|
|     100003|     3384|JioGames|      Games|
|     100010|     1219| JioNews|       News|
|     100009|     1904|   Gaana|      Music|
|     100001|     2993|     TOI|       News|
|     100007|     4885|Inshorts|       News|
+-----------+---------+--------+-----------+
only showing top 10 rows



In [385]:
#pre-process the data according to the requirements of the model

stringIndexer = StringIndexer(inputCol="appName", outputCol="appNameIndexed")
model = stringIndexer.fit(dataset)
dataset = model.transform(dataset)

stringIndexer = StringIndexer(inputCol="appCategory", outputCol="appCategoryIndexed")
model = stringIndexer.fit(dataset)
dataset = model.transform(dataset)

dataset.printSchema()

root
 |-- phoneNumber: integer (nullable = true)
 |-- dataUsage: integer (nullable = true)
 |-- appName: string (nullable = true)
 |-- appCategory: string (nullable = true)
 |-- appNameIndexed: double (nullable = false)
 |-- appCategoryIndexed: double (nullable = false)



In [386]:
#assemble the features into one vector
vectorAss = VectorAssembler(inputCols=["dataUsage"],outputCol="features")

dataset = vectorAss.transform(dataset)

In [393]:
# Trains a k-means model.
kmeans = KMeans().setK(3).setSeed(1)
model = kmeans.fit(dataset)

# Make predictions
predictions = model.transform(dataset)

In [394]:
# Evaluate clustering by computing Silhouette score
evaluator = ClusteringEvaluator()

silhouette = evaluator.evaluate(predictions)
print("Silhouette with squared euclidean distance = " + str(silhouette))

clist = []

# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)
predictions.show()

Silhouette with squared euclidean distance = 0.7829932313550811
Cluster Centers: 
[3004.60606061]
[4492.79310345]
[1611.68421053]
+-----------+---------+--------+-----------+--------------+------------------+--------+----------+
|phoneNumber|dataUsage| appName|appCategory|appNameIndexed|appCategoryIndexed|features|prediction|
+-----------+---------+--------+-----------+--------------+------------------+--------+----------+
|     100000|     2430|   Gaana|      Music|           6.0|               1.0|[2430.0]|         0|
|     100000|     1755|   Gmail|       Mail|           7.0|               4.0|[1755.0]|         2|
|     100004|     3939| Youtube|  Streaming|           0.0|               2.0|[3939.0]|         1|
|     100002|     4836| JioNews|       News|           2.0|               0.0|[4836.0]|         1|
|     100001|     4791|   Gaana|      Music|           6.0|               1.0|[4791.0]|         1|
|     100003|     3384|JioGames|      Games|           1.0|               3.0|

In [395]:
numPred = predictions.groupby('prediction').count()

numPred.show()

+----------+-----+
|prediction|count|
+----------+-----+
|         1|   29|
|         2|   38|
|         0|   33|
+----------+-----+



In [368]:
numPredDF = predictions.toPandas()
numPrediction = numPredDF[['prediction']]
numPrediction

numPredNP = numPrediction.values
numPredNP = numPredNP.ravel()

In [408]:
#understand the data in the clusters, like the average dataUsage and count of each appCategory
predictions.createOrReplaceTempView("predFile")

df = spark.sql("select appCategory,avg(dataUsage),count(phoneNumber) from predFile where prediction = 0 group by appCategory")
df.show() 

df = spark.sql("select appCategory,avg(dataUsage),count(phoneNumber) from predFile where prediction = 1 group by appCategory")
df.show()

df = spark.sql("select appCategory,avg(dataUsage),count(phoneNumber) from predFile where prediction = 2 group by appCategory")
df.show()

+-----------+------------------+------------------+
|appCategory|    avg(dataUsage)|count(phoneNumber)|
+-----------+------------------+------------------+
|       Mail|            3279.5|                 4|
|      Games|            3150.0|                 6|
|  Streaming|3085.5714285714284|                 7|
|      Music|2887.6666666666665|                 3|
|       News|2836.3076923076924|                13|
+-----------+------------------+------------------+

+-----------+-----------------+------------------+
|appCategory|   avg(dataUsage)|count(phoneNumber)|
+-----------+-----------------+------------------+
|       Mail|          4247.75|                 4|
|      Games|           4625.5|                 4|
|  Streaming|          4497.75|                 4|
|      Music|4483.142857142857|                 7|
|       News|           4542.5|                10|
+-----------+-----------------+------------------+

+-----------+------------------+------------------+
|appCategory|    av

In [407]:
#analyse the dataUsage for each clusters 
df = spark.sql("select prediction,min(dataUsage),max(dataUsage),avg(dataUsage) from predFile group by prediction")
df.show()

+----------+--------------+--------------+------------------+
|prediction|min(dataUsage)|max(dataUsage)|    avg(dataUsage)|
+----------+--------------+--------------+------------------+
|         1|          3762|          5000| 4492.793103448276|
|         2|          1038|          2260|1611.6842105263158|
|         0|          2363|          3721|3004.6060606060605|
+----------+--------------+--------------+------------------+

