# Spark Clustering: The Walmart Trip Type Data

Let's look at a clustering example in Spark MLlib.

Here, we are going to load the Walmart trip type dataset, which records store visits (identified by `VisitNumber`) along with the type of shopping trip each visit represents (`TripType`). We will load the gzipped CSV file as a Spark DataFrame and view it.

In [None]:
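%matplotlib inline
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

# Reuse the notebook's Spark session if one already exists; otherwise start one.
spark = SparkSession.builder.appName("walmart-kmeans").getOrCreate()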



In [None]:
dataset = spark.read.csv("/data/walmart-triptype/train-transformed.csv.gz", header=True, inferSchema=True)


In [None]:
dataset.show()

## Creating Vectors

We'll again use the VectorAssembler class to combine the data's columns into a single feature vector.

In [None]:
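# Every column except the visit identifier and the label becomes a feature.
columns = dataset.columns
columns.remove('VisitNumber')
columns.remove('TripType')

# Combine the feature columns into a single vector column named "features".
assembler = VectorAssembler(inputCols=columns, outputCol="features")
featureVector = assembler.transform(dataset)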
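Conceptually, VectorAssembler takes the listed input columns, in order, and concatenates their values into one vector per row. The plain-Python sketch below illustrates only that idea (it is not Spark's implementation, and it ignores vector-typed input columns); the row and the `dept_*` column names are hypothetical.

```python
# Plain-Python sketch of the idea behind VectorAssembler: pick the feature
# columns (everything except the ID and the label) and concatenate their
# values, in column order, into a single feature list.
# The sample row and the dept_* column names are hypothetical.

def assemble(row, input_cols):
    """Return the values of input_cols from a dict-like row, in order, as floats."""
    return [float(row[c]) for c in input_cols]

sample_row = {"VisitNumber": 7, "TripType": 26, "dept_A": 0, "dept_B": 2, "dept_C": 0}

# Exclude the visit identifier and the label, as the notebook does.
feature_cols = [c for c in sample_row if c not in ("VisitNumber", "TripType")]

print(feature_cols)                        # ['dept_A', 'dept_B', 'dept_C']
print(assemble(sample_row, feature_cols))  # [0.0, 2.0, 0.0]
```

Dropping `TripType` matters: if the label were left in, the clusters could simply reproduce it rather than reveal structure in the purchase data.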


In [None]:
for row in featureVector.select('features').take(10):
    print("Vector: %s\n" % (str(row)))

Note the output: these are sparse (not dense) vectors. That's because our data IS sparse: any given visit has non-zero values for only a few of the variables, so Spark stores just the positions and values of those entries.
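To make the sparse layout concrete, here is a plain-Python sketch that converts a mostly-zero list into a (size, indices, values) triple and back. It mirrors the `size`, `indices`, and `values` fields of PySpark's `SparseVector`, but it is an illustration, not Spark's code.

```python
# Sketch of a sparse vector: keep the length plus the indices and values of
# the non-zero entries only. This mirrors the size/indices/values fields of
# PySpark's SparseVector, but is plain Python for illustration.

def to_sparse(dense):
    """Return (size, indices, values) for the non-zero entries of a list."""
    indices = [i for i, x in enumerate(dense) if x != 0]
    values = [dense[i] for i in indices]
    return len(dense), indices, values

def to_dense(size, indices, values):
    """Rebuild the full list from a (size, indices, values) triple."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

dense_row = [0.0] * 10
dense_row[2], dense_row[7] = 1.0, 3.0

sparse_row = to_sparse(dense_row)
print(sparse_row)                        # (10, [2, 7], [1.0, 3.0])
assert to_dense(*sparse_row) == dense_row  # the round trip is lossless
```

With many columns and only a handful of non-zeros per visit, storing two short lists is far cheaper than storing every zero.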

## Step 3: Running K-Means

We know there are 39 trip types, so 39 is a good "natural" value of k.

In [None]:
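k = 39  # One cluster per trip type; check with dataset.select('TripType').distinct().count()
kmeans = KMeans().setK(k).setSeed(1)
model = kmeans.fit(featureVector)

# Within Set Sum of Squared Errors (WSSSE) on the training data.
# model.computeCost() was deprecated in Spark 2.4 and removed in 3.0;
# the training summary exposes the same quantity.
wssse = model.summary.trainingCost

print("Within Set Sum of Squared Errors = %s" % wssse)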
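For intuition about what `fit` and the WSSSE are doing, here is a minimal plain-Python sketch of Lloyd's k-means on a toy 2-D dataset. It is not a reimplementation of Spark's KMeans: Spark defaults to k-means|| initialization, while this sketch simply takes the first k points as starting centers, so results on real data can differ. The WSSSE is the sum of squared distances from each point to its assigned center.

```python
# Minimal sketch of Lloyd's k-means plus the WSSSE cost, in plain Python.
# Assumption: initial centers are the first k points (Spark uses k-means||
# by default), so this only illustrates the assign/update loop.

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def closest(point, centers):
    """Index of the center nearest to point (ties go to the lower index)."""
    return min(range(len(centers)), key=lambda j: sq_dist(point, centers[j]))

def lloyd_kmeans(points, k, max_iter=20):
    centers = [list(p) for p in points[:k]]
    for _ in range(max_iter):
        # Assignment step: each point joins its nearest center.
        labels = [closest(p, centers) for p in points]
        # Update step: move each center to the mean of its assigned points.
        new_centers = []
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                new_centers.append([sum(col) / len(members) for col in zip(*members)])
            else:
                new_centers.append(centers[j])  # keep an empty cluster's center
        if new_centers == centers:  # converged: centers stopped moving
            break
        centers = new_centers
    labels = [closest(p, centers) for p in points]
    return centers, labels

def sum_sq_errors(points, centers, labels):
    """WSSSE: total squared distance from each point to its own center."""
    return sum(sq_dist(p, centers[lab]) for p, lab in zip(points, labels))

toy_points = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
toy_centers, toy_labels = lloyd_kmeans(toy_points, k=2)
print(toy_centers)  # [[0.0, 0.5], [10.0, 10.5]]
print(toy_labels)   # [0, 0, 1, 1]
print(sum_sq_errors(toy_points, toy_centers, toy_labels))  # 1.0
```

A lower WSSSE means tighter clusters, but it always shrinks as k grows, so on its own it cannot tell you whether 39 clusters is the "right" number.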

Let's take a look at the transformed dataset, starting with how many visits fall into each cluster.

In [None]:
predictions = model.transform(featureVector)
histogram = predictions.groupBy('prediction').count().orderBy('prediction')
histogram.show(40)

In [None]:
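# Plot cluster sizes. Without x/y, pandas would also draw the 'prediction' column as bars.
histogram.toPandas().plot.bar(x='prediction', y='count', colormap='Greens', legend=False)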

## Step 4: Relate Cluster Numbers to Trip Types

Is there a relationship between cluster numbers and trip types? Discuss the results.

Remember, clustering looks for "natural" groupings in the data; it is not a classifier. If the goal is to predict trip type, we should use a classification algorithm rather than k-means.

In [None]:

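# Show the trip-type breakdown of every cluster, 0 through k-1 (range(0, 38) would skip the last one).
for i in range(k):
    print('Cluster #%d:' % i)
    predictions.filter(predictions.prediction == i).groupBy('TripType').count().orderBy('count', ascending=False).show()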
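One way to condense the per-cluster tables into a few numbers is to record, for each cluster, its most common trip type and that type's share of the cluster, plus an overall "purity": the fraction of visits whose cluster's majority trip type matches their own. The sketch below is plain Python over (prediction, TripType) pairs with made-up sample data. For a sample small enough to bring to the driver, the pairs could come from something like `predictions.select('prediction', 'TripType').collect()`; that wiring is an assumption, not part of the original notebook.

```python
# Sketch: summarize how well clusters line up with trip types.
# For each cluster, find its most common TripType and that type's share;
# purity = fraction of points whose cluster's majority label equals their own.
# The sample pairs below are made up for illustration.
from collections import Counter, defaultdict

def cluster_summary(pairs):
    """pairs: iterable of (prediction, trip_type).

    Returns ({cluster: (majority_trip_type, share)}, overall_purity).
    """
    by_cluster = defaultdict(Counter)
    for cluster, trip_type in pairs:
        by_cluster[cluster][trip_type] += 1

    summary = {}
    matched = total = 0
    for cluster, counts in sorted(by_cluster.items()):
        label, n = counts.most_common(1)[0]
        size = sum(counts.values())
        summary[cluster] = (label, n / size)
        matched += n
        total += size
    return summary, matched / total

pairs = [(0, 40), (0, 40), (0, 39), (1, 7), (1, 7), (1, 7)]
summary, purity = cluster_summary(pairs)
print(summary)  # {0: (40, 0.6666666666666666), 1: (7, 1.0)}
print(purity)   # 0.8333333333333334
```

High purity would suggest the clusters track trip types closely; low purity is consistent with the point above that k-means finds its own groupings, which need not match the labels.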