In [35]:
## Import Libraries
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql.types import StructType, StructField, DoubleType
from pyspark.ml.clustering import KMeans

## Set seed
seed = 1

In [36]:
## Create Spark Session
spark = SparkSession.builder.appName('kmCodeAlong').getOrCreate()

In [37]:
## Setup Schema
schema = StructType(fields=[StructField('area', DoubleType(), True),
                            StructField('perimeter', DoubleType(), True),
                            StructField('compactness', DoubleType(), True),
                            StructField('length_of_kernel', DoubleType(), True),
                            StructField('width_of_kernel', DoubleType(), True),
                            StructField('asymmetry_coefficient', DoubleType(), True),
                            StructField('length_of_groove', DoubleType(), True)])

In [38]:
## Load Data
df = spark.read.csv('gs://spark-training-data/datasets/seeds_dataset.csv', header=True,
                    inferSchema=False, schema=schema)
df.show(5)
df.printSchema() ## Confirm proper schema

+-----+---------+-----------+------------------+------------------+---------------------+----------------+
| area|perimeter|compactness|  length_of_kernel|   width_of_kernel|asymmetry_coefficient|length_of_groove|
+-----+---------+-----------+------------------+------------------+---------------------+----------------+
|15.26|    14.84|      0.871|             5.763|             3.312|                2.221|            5.22|
|14.88|    14.57|     0.8811| 5.553999999999999|             3.333|                1.018|           4.956|
|14.29|    14.09|      0.905|             5.291|3.3369999999999997|                2.699|           4.825|
|13.84|    13.94|     0.8955|             5.324|3.3789999999999996|                2.259|           4.805|
|16.14|    14.99|     0.9034|5.6579999999999995|             3.562|                1.355|           5.175|
+-----+---------+-----------+------------------+------------------+---------------------+----------------+
only showing top 5 rows

root
 |-- ar

In [39]:
## Assembler & Create modeling df
assembler = VectorAssembler(inputCols=df.columns,
                           outputCol='features')
output_features = assembler.transform(df)
output_features.head(1)

[Row(area=15.26, perimeter=14.84, compactness=0.871, length_of_kernel=5.763, width_of_kernel=3.312, asymmetry_coefficient=2.221, length_of_groove=5.22, features=DenseVector([15.26, 14.84, 0.871, 5.763, 3.312, 2.221, 5.22]))]

In [40]:
## Setup Scaler & Scale Features
scaler = StandardScaler(inputCol='features', outputCol='scaled_features')
output_features_scaled = scaler.fit(output_features).transform(output_features)

In [46]:
## Setup Final Data
final_data = output_features_scaled.select(['scaled_features'])
final_data.show(5)

+--------------------+
|     scaled_features|
+--------------------+
|[5.24452795332028...|
|[5.11393027165175...|
|[4.91116018695588...|
|[4.75650503761158...|
|[5.54696468981581...|
+--------------------+
only showing top 5 rows



In [52]:
## Setup Model & Fit
kmeans = KMeans(featuresCol='scaled_features', k=3, seed=seed)
kmeans_model = kmeans.fit(final_data)

In [53]:
## Evaluate KMeans Model
wssse = kmeans_model.summary.trainingCost
wssse

428.6082011872446

In [54]:
## Get Centers
centers = kmeans_model.clusterCenters()
centers

[array([ 4.96198582, 10.97871333, 37.30930808, 12.44647267,  8.62880781,
         1.80061978, 10.41913733]),
 array([ 4.07497225, 10.14410142, 35.89816849, 11.80812742,  7.54416916,
         3.15410901, 10.38031464]),
 array([ 6.35645488, 12.40730852, 37.41990178, 13.93860446,  9.7892399 ,
         2.41585013, 12.29286107])]

In [55]:
## Make Predictions & Show Clusters
predictions_df = kmeans_model.transform(final_data)
predictions_df.show()

+--------------------+----------+
|     scaled_features|prediction|
+--------------------+----------+
|[5.24452795332028...|         0|
|[5.11393027165175...|         0|
|[4.91116018695588...|         0|
|[4.75650503761158...|         0|
|[5.54696468981581...|         0|
|[4.94209121682475...|         0|
|[5.04863143081749...|         0|
|[4.84929812721816...|         0|
|[5.71536696354628...|         2|
|[5.65006812271202...|         0|
|[5.24452795332028...|         0|
|[4.82180387844584...|         0|
|[4.77368894309428...|         0|
|[4.73588435103234...|         0|
|[4.72213722664617...|         0|
|[5.01426361985209...|         0|
|[4.80805675405968...|         0|
|[5.39230954047151...|         0|
|[5.05206821191403...|         0|
|[4.37158555479908...|         1|
+--------------------+----------+
only showing top 20 rows

