In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
ss = SparkSession.builder.getOrCreate()
sc = ss.sparkContext

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/02/24 05:32:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/02/24 05:32:03 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Create dataframe
https://spark.apache.org/docs/latest/api/python/pyspark.ml.html?highlight=kmeans#pyspark.ml.clustering.KMeans

In [2]:
#Load the data and create an RDD (16 pixels and label)
pen_raw = sc.textFile("../Data/penbased.dat", 4)\
            .map(lambda x:  x.split(", "))\
            .map(lambda row: [float(x) for x in row])
pen_raw.take(1)

                                                                                

[[47.0,
  100.0,
  27.0,
  81.0,
  57.0,
  37.0,
  26.0,
  0.0,
  0.0,
  23.0,
  56.0,
  53.0,
  100.0,
  90.0,
  40.0,
  98.0,
  8.0]]

In [3]:
#Create a DataFrame
from pyspark.sql.types import *

penschema = StructType([
    StructField("pix1",DoubleType(),True),
    StructField("pix2",DoubleType(),True),
    StructField("pix3",DoubleType(),True),
    StructField("pix4",DoubleType(),True),
    StructField("pix5",DoubleType(),True),
    StructField("pix6",DoubleType(),True),
    StructField("pix7",DoubleType(),True),
    StructField("pix8",DoubleType(),True),
    StructField("pix9",DoubleType(),True),
    StructField("pix10",DoubleType(),True),
    StructField("pix11",DoubleType(),True),
    StructField("pix12",DoubleType(),True),
    StructField("pix13",DoubleType(),True),
    StructField("pix14",DoubleType(),True),
    StructField("pix15",DoubleType(),True),
    StructField("pix16",DoubleType(),True),
    StructField("label",DoubleType(),True)
])

dfpen = ss.createDataFrame(pen_raw, penschema)

In [4]:
dfpen.show()

+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| pix1| pix2|pix3| pix4| pix5| pix6| pix7| pix8| pix9|pix10|pix11|pix12|pix13|pix14|pix15|pix16|label|
+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 47.0|100.0|27.0| 81.0| 57.0| 37.0| 26.0|  0.0|  0.0| 23.0| 56.0| 53.0|100.0| 90.0| 40.0| 98.0|  8.0|
|  0.0| 89.0|27.0|100.0| 42.0| 75.0| 29.0| 45.0| 15.0| 15.0| 37.0|  0.0| 69.0|  2.0|100.0|  6.0|  2.0|
|  0.0| 57.0|31.0| 68.0| 72.0| 90.0|100.0|100.0| 76.0| 75.0| 50.0| 51.0| 28.0| 25.0| 16.0|  0.0|  1.0|
|  0.0|100.0| 7.0| 92.0|  5.0| 68.0| 19.0| 45.0| 86.0| 34.0|100.0| 45.0| 74.0| 23.0| 67.0|  0.0|  4.0|
|  0.0| 67.0|49.0| 83.0|100.0|100.0| 81.0| 80.0| 60.0| 60.0| 40.0| 40.0| 33.0| 20.0| 47.0|  0.0|  1.0|
|100.0|100.0|88.0| 99.0| 49.0| 74.0| 17.0| 47.0|  0.0| 16.0| 37.0|  0.0| 73.0| 16.0| 20.0| 20.0|  6.0|
|  0.0|100.0| 3.0| 72.0| 26.0| 35.0| 85.0| 35.0|100.0| 71.0| 73.0| 97.0| 

## Create dataframe with a feature vector (Exclude the label)

In [5]:
# Merging the data with Vector Assembler.
from pyspark.ml.feature import VectorAssembler
va = VectorAssembler(outputCol="features", inputCols=dfpen.columns[0:-1]) #except the last col.
penlpoints = va.transform(dfpen)

## Apply KMeans algorithm to the data frame

In [6]:
from pyspark.ml.clustering import KMeans
kmeans =  KMeans(k = 10, maxIter = 200, tol = 0.1) 
# k = 10 as there are 10 different handwritten numbers.
model = kmeans.fit(penlpoints)
predictions = model.transform(penlpoints)

22/02/24 05:32:16 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


## Evaluation

In [7]:
# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)

Cluster Centers: 
[19.8019954  74.85648503 43.42056792 96.88641596 64.23944743 82.02762855
 48.2133538  48.3200307  22.54643131 19.0590944  10.742901    7.7789716
 52.20337682  5.67536454 97.96162701  5.70452801]
[88.36734694 97.54761905 53.57568027 88.03146259 22.10884354 62.37159864
  7.53401361 31.30952381 34.08333333  9.5        78.6420068  13.70748299
 60.71768707 28.53316327 14.04846939 22.52380952]
[4.80693842e+01 9.64891587e+01 2.33018213e+01 7.99765828e+01
 3.57675629e+00 5.47701648e+01 4.38204683e+01 4.63616652e+01
 8.67450130e+01 5.51257589e+01 8.74249783e+01 5.97146574e+01
 7.18933218e+01 3.03391154e+01 6.03052905e+01 9.45359931e-02]
[ 3.98775056 62.40757238 30.93095768 73.60356347 72.68040089 89.90757238
 89.47216036 93.67371938 77.68262806 73.44432071 70.4688196  49.76614699
 56.48997773 23.72939866 49.79732739  0.92761693]
[11.21958457 87.22156281 53.1859545  98.61424332 77.41839763 80.21463897
 67.12858556 42.78437191 51.6983185   7.59940653 24.57665678 11.86251236
 34.

In [8]:
from pyspark.ml.evaluation import ClusteringEvaluator
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print("Silhouette with squared euclidean distance = " + str(silhouette))

Silhouette with squared euclidean distance = 0.4267285839299759


In [9]:
# prediction is a group, not an actual label.
predictions.select('label', 'prediction')\
           .groupBy('label', 'prediction')\
           .count()\
           .show(100)

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  3.0|         3|   44|
|  4.0|         9|    3|
|  1.0|         0|  267|
|  0.0|         5|  505|
|  0.0|         6|  484|
|  1.0|         3|  585|
|  6.0|         2|    3|
|  8.0|         6|    6|
|  5.0|         9|  304|
|  2.0|         0|  971|
|  5.0|         8|  558|
|  7.0|         4|  812|
|  8.0|         4|  126|
|  9.0|         9|  654|
|  6.0|         1|  959|
|  4.0|         0|    4|
|  4.0|         1|   46|
|  3.0|         9|  895|
|  0.0|         1|   33|
|  7.0|         0|   10|
|  9.0|         2|  181|
|  8.0|         9|   58|
|  9.0|         0|   14|
|  8.0|         3|    7|
|  1.0|         4|   37|
|  8.0|         0|   29|
|  6.0|         6|    4|
|  2.0|         3|   15|
|  8.0|         7|  399|
|  8.0|         1|   17|
|  9.0|         5|   12|
|  1.0|         9|  105|
|  7.0|         9|   47|
|  1.0|         1|   27|
|  4.0|         2|  958|
|  4.0|         3|   14|
|  7.0|         8|    9|


In [10]:
ss.stop()