In [1]:
from pyspark.ml import Pipeline
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, date_format, rand

# Reuse the active SparkSession or start a new one.
# While the session is alive, the Spark web UI is normally reachable at
# http://localhost:4040 (the default port; it may differ if 4040 is taken).
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

# Use 5 shuffle partitions instead of Spark's default.
spark.conf.set('spark.sql.shuffle.partitions', '5')

# Leave the session as the last expression so the notebook renders it.
spark


In [2]:
# Read every CSV file in the retail-data directory into a single DataFrame.
# The first line of each file is treated as the header, and column types are
# inferred from the data rather than being read as plain strings.
df = (
    spark.read
    .options(header='true', inferSchema='true')
    .csv('../data/retail-data/*.csv')
)

# Keep the inferred schema in a variable and expose the data to Spark SQL
# under the view name `retail_data`.
df_schema = df.schema
df.createOrReplaceTempView('retail_data')

# Preview the rows without truncating long cell values, then print the schema.
df.show(truncate=False)
df.printSchema()


+---------+---------+-----------------------------------+--------+-------------------+---------+----------+--------------+
|InvoiceNo|StockCode|Description                        |Quantity|InvoiceDate        |UnitPrice|CustomerID|Country       |
+---------+---------+-----------------------------------+--------+-------------------+---------+----------+--------------+
|537226   |22811    |SET OF 6 T-LIGHTS CACTI            |6       |2010-12-06 08:34:00|2.95     |15987.0   |United Kingdom|
|537226   |21713    |CITRONELLA CANDLE FLOWERPOT        |8       |2010-12-06 08:34:00|2.1      |15987.0   |United Kingdom|
|537226   |22927    |GREEN GIANT GARDEN THERMOMETER     |2       |2010-12-06 08:34:00|5.95     |15987.0   |United Kingdom|
|537226   |20802    |SMALL GLASS SUNDAE DISH CLEAR      |6       |2010-12-06 08:34:00|1.65     |15987.0   |United Kingdom|
|537226   |22052    |VINTAGE CARAVAN GIFT WRAP          |25      |2010-12-06 08:34:00|0.42     |15987.0   |United Kingdom|
|537226   |22705

In [3]:
# Prepare the raw rows for modelling:
#   1. replace nulls with 0 (a numeric fill value, so it targets numeric columns),
#   2. add `day_of_week` holding the full weekday name of InvoiceDate
#      (pattern 'EEEE', e.g. "Monday"),
#   3. coalesce the result down to at most 5 partitions.
prepped_df = (
    df.fillna(0)
    .withColumn('day_of_week', date_format(col('InvoiceDate'), 'EEEE'))
    .coalesce(5)
)

# Preview the prepared rows without truncating long cell values.
prepped_df.show(truncate=False)


+---------+---------+-----------------------------------+--------+-------------------+---------+----------+--------------+-----------+
|InvoiceNo|StockCode|Description                        |Quantity|InvoiceDate        |UnitPrice|CustomerID|Country       |day_of_week|
+---------+---------+-----------------------------------+--------+-------------------+---------+----------+--------------+-----------+
|537226   |22811    |SET OF 6 T-LIGHTS CACTI            |6       |2010-12-06 08:34:00|2.95     |15987.0   |United Kingdom|Monday     |
|537226   |21713    |CITRONELLA CANDLE FLOWERPOT        |8       |2010-12-06 08:34:00|2.1      |15987.0   |United Kingdom|Monday     |
|537226   |22927    |GREEN GIANT GARDEN THERMOMETER     |2       |2010-12-06 08:34:00|5.95     |15987.0   |United Kingdom|Monday     |
|537226   |20802    |SMALL GLASS SUNDAE DISH CLEAR      |6       |2010-12-06 08:34:00|1.65     |15987.0   |United Kingdom|Monday     |
|537226   |22052    |VINTAGE CARAVAN GIFT WRAP         

In [4]:
# Split by date: invoices before 2010-12-12 form the training set,
# invoices on or after that date form the test set.
train_df = prepped_df.filter('InvoiceDate < "2010-12-12"')
test_df = prepped_df.filter('InvoiceDate >= "2010-12-12"')

# Print the row count of each split (training first, then test).
for split_df in (train_df, test_df):
    print(split_df.count())


25281
5821


In [5]:
# --- Feature engineering ---------------------------------------------------
# Turn the weekday name into a numeric category index ...
indexer = StringIndexer(
    inputCol='day_of_week',
    outputCol='day_of_week_index',
)

# ... one-hot encode that index so weekdays are not treated as ordinal ...
encoder = OneHotEncoder(
    inputCol='day_of_week_index',
    outputCol='day_of_week_encoded',
)

# ... and pack the numeric columns plus the encoded weekday into one vector.
vector_assembler = VectorAssembler(
    inputCols=['UnitPrice', 'Quantity', 'day_of_week_encoded'],
    outputCol='features',
)

transformation_pipeline = Pipeline(stages=[indexer, encoder, vector_assembler])

# Fit the transformers on the training split only, so that the test split is
# later encoded with exactly the same weekday -> index mapping.
fitted_pipeline = transformation_pipeline.fit(train_df)
transformed_training = fitted_pipeline.transform(train_df)

# K-means makes several passes over its input; cache the transformed rows so
# the feature pipeline is not recomputed on every pass.
transformed_training.cache()

# --- Clustering ------------------------------------------------------------
# k=20 clusters with a fixed seed so the run is repeatable.
kmeans = KMeans(k=20, seed=1)

km_model = kmeans.fit(transformed_training)

transformed_test = fitted_pipeline.transform(test_df)

# --- Evaluation ------------------------------------------------------------
# Cost = sum of squared distances from each test point to its assigned
# cluster center. `KMeansModel.computeCost` was deprecated in Spark 2.4 and
# removed in Spark 3.0, so use it when present and otherwise compute the same
# quantity from the cluster centers and the model's predictions.
if hasattr(km_model, 'computeCost'):
    test_cost = km_model.computeCost(transformed_test)
else:
    cluster_centers = km_model.clusterCenters()
    test_cost = (
        km_model.transform(transformed_test)
        .select('features', 'prediction')
        .rdd
        .map(
            lambda row: float(
                row['features'].squared_distance(cluster_centers[row['prediction']])
            )
        )
        .sum()
    )

# Leave the cost as the last expression so the notebook displays it.
test_cost

689127.9930536919