 

# Introduction to PySpark: Spark SQL<a name="id1"></a>

[Spark SQL](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html) is an Apache Spark module for structured data processing (DataFrames and Datasets). Its main advantage is that it uses information about the structure of the data (the schema) to optimize processing.

There are two ways of interacting with Spark SQL:
* SQL queries: run SQL statements, for example to read data from Hive tables in a Big Data setting. Query results are returned as a Dataset/DataFrame.
* Dataset API (Application Programming Interface):
    * A Dataset is a distributed data collection. This data abstraction was introduced in Spark 1.6. The Dataset API is available in Scala and Java, but not in Python.
    * A DataFrame is a Dataset organized into named columns. It is similar to a table in a relational database, or to an optimized R/Python data frame. DataFrames can be created from structured data files or external databases. The DataFrame API is available in Scala, Java, Python and R.
    
The PySpark SQL module (`pyspark.sql`) contains the following classes:
    
* `pyspark.sql.SparkSession`: Main entry point to use Spark and the DataFrame API.
* `pyspark.sql.DataFrame`: Distributed data collection grouped into named columns.
* `pyspark.sql.Column`: A DataFrame column.
* `pyspark.sql.functions`: Built-in functions available for DataFrames (min, max, col, mean...).
* `pyspark.sql.GroupedData`: Aggregation methods, returned by `DataFrame.groupBy()`.
* `pyspark.sql.Row`: A row in a DataFrame.
* `pyspark.sql.DataFrameNaFunctions`: Methods to handle null or NaN values.
* `pyspark.sql.DataFrameStatFunctions`: Statistics methods.
* `pyspark.sql.types`: List of available data types.


### Starting PySpark<a name="id2"></a>

Before beginning to work with Spark SQL, we must start a Spark session. Since Spark 2.0, **SparkSession** is the entry point for PySpark: it is used to create any PySpark functionality, such as DataFrames.

To create a Spark session, use `SparkSession.builder` and chain the following methods:

* `appName(app_name)`: application name, used to identify the application in the Spark user interface. If no name is given, a randomly generated name is used.
* `config(key, value)`: Spark configuration options.
* `master(master_url)`: where the application runs. When running on a cluster, pass the cluster's master URL. In local mode, `local[x]` runs Spark with x worker threads; ideally x matches the number of processor cores. Common master values:
    * `local`: runs locally with a single worker thread
    * `local[4]`: runs locally with 4 worker threads (4 cores)
    * `yarn`: runs on a Hadoop YARN cluster
    * `mesos://host:port`: runs on an Apache Mesos cluster
    * `spark://master:7077`: runs on a Spark standalone cluster
* `getOrCreate()`: returns the SparkSession that is already running, if there is one; otherwise, it creates a new SparkSession based on the options set in the builder.

In [None]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("Covid19_Vacunacion").getOrCreate()

### Reading data files<a name="id3"></a>

PySpark can read data in different formats, such as *Comma Separated Values (CSV)*, *JavaScript Object Notation (JSON)*, Parquet, etc. To read these files, use `spark.read`.

For instance:

* CSV: `data = spark.read.csv("/path/to/data.csv")`
* JSON: `data = spark.read.json("/path/to/data.json")`
* Parquet: `data = spark.read.parquet("/path/to/data.parquet")`

The dataset used here is available in the [Covid 19 data repository on GitHub](https://github.com/owid/covid-19-data). It is also available on [Kaggle](https://www.kaggle.com/gpreda/covid-world-vaccination-progress).

The data is stored in [`vaccinations.csv`](vaccinations.csv) and [`vaccinations.json`](vaccinations.json) and contains country-by-country data on global COVID-19 vaccinations. We only rely on figures that are verifiable based on public official sources.

This dataset includes some subnational locations (England, Northern Ireland, Scotland, Wales, Northern Cyprus…) and international aggregates (World, continents, European Union…). They can be identified by their `iso_code`, which starts with `OWID_`.

The population estimates we use to calculate per-capita metrics are all based on the last revision of the [United Nations World Population Prospects](https://population.un.org/wpp/). The exact values can be viewed [here](https://github.com/owid/covid-19-data/blob/master/scripts/input/un/population_2020.csv).

The dataset contains the following attributes:

* `location`: name of the country (or region within a country).
* `iso_code`: ISO 3166-1 alpha-3 – three-letter country codes.
* `date`: date of the observation.
* `total_vaccinations`: total number of doses administered. For vaccines that require multiple doses, each individual dose is counted. If a person receives one dose of the vaccine, this metric goes up by 1. If they receive a second dose, it goes up by 1 again. If they receive a third/booster dose, it goes up by 1 again.
* `total_vaccinations_per_hundred`: `total_vaccinations` per 100 people in the total population of the country.
* `daily_vaccinations_raw`: daily change in the total number of doses administered. It is only calculated for consecutive days. This is a raw measure provided for data checks and transparency, but we strongly recommend that any analysis on daily vaccination rates be conducted using `daily_vaccinations` instead.
* `daily_vaccinations`: new doses administered per day (7-day smoothed). For countries that don't report data on a daily basis, we assume that doses changed equally on a daily basis over any periods in which no data was reported. This produces a complete series of daily figures, which is then averaged over a rolling 7-day window. An example of how we perform this calculation can be found [here](https://github.com/owid/covid-19-data/issues/333#issuecomment-763015298).
* `daily_vaccinations_per_million`: `daily_vaccinations` per 1,000,000 people in the total population of the country.
* `people_vaccinated`: total number of people who received at least one vaccine dose. If a person receives the first dose of a 2-dose vaccine, this metric goes up by 1. If they receive the second dose, the metric stays the same.
* `people_vaccinated_per_hundred`: `people_vaccinated` per 100 people in the total population of the country.
* `people_fully_vaccinated`: total number of people who received all doses prescribed by the vaccination protocol. If a person receives the first dose of a 2-dose vaccine, this metric stays the same. If they receive the second dose, the metric goes up by 1.
* `people_fully_vaccinated_per_hundred`: `people_fully_vaccinated` per 100 people in the total population of the country.
* `total_boosters`: total number of COVID-19 vaccination booster doses administered (doses administered beyond the number prescribed by the vaccination protocol).
* `total_boosters_per_hundred`: total number of COVID-19 vaccination booster doses administered per 100 people in the total population.

Note: for `people_vaccinated` and `people_fully_vaccinated` we are dependent on the necessary data being made available, so we may not be able to make these metrics available for some countries.

In [None]:
#Download data
import urllib.request

url = 'https://github.com/owid/covid-19-data/raw/master/public/data/vaccinations/vaccinations.csv'
urllib.request.urlretrieve(url, 'country_vaccinations.csv')

#Load data

data = spark.read.csv('country_vaccinations.csv', sep = ',', header = True, )
type(data)
data.limit(5).toPandas()

Once the DataFrame has been created, it can be manipulated with the *domain-specific language* (DSL) provided by the API classes: DataFrame, Column, GroupedData, etc.

The next subsections show the most common PySpark functions, but many more are available. See the [PySpark SQL API](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html) for detailed information and the definitions of all the functions.

### Spark Schema<a name="id4"></a>

The **`printSchema()`** method prints the DataFrame structure. A schema can be defined with a **StructType**, which is a collection of **StructField** objects. Each StructField defines the column name (string), the column type (DataType), whether the column can contain *null* values (*boolean*) and its metadata, if available (Metadata).

In [None]:
data.printSchema()

Since the file was read without a schema, every column is loaded as a string, which is not the correct type for dates and numbers. We are going to structure our data by providing the schema explicitly.

You can refer to [this guide](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#module-pyspark.sql.types) for information on the supported data types in PySpark.

In [None]:

from pyspark.sql.types import *

data_schema = [
               StructField('location', StringType(), True),
               StructField('iso_code', StringType(), True),
               StructField('date', DateType(), True),
               StructField('total_vaccinations', FloatType(), True),
               StructField('people_vaccinations', FloatType(), True),
               StructField('people_fully_vaccinated', FloatType(), True),
               StructField('total_boosters', FloatType(), True),
               StructField('daily_vaccinations_raw', FloatType(), True),
               StructField('daily_vaccinations', FloatType(), True),
               StructField('total_vaccinations_per_hundred', FloatType(), True),
               StructField('people_vaccinated_per_hundred', FloatType(), True),
               StructField('people_fully_vaccinated_per_hundred', FloatType(), True),
               StructField('total_boosters_per_hundred', FloatType(), True),
               StructField('daily_vaccinations_per_million', FloatType(), True),
               StructField('vaccines', StringType(), True),
               StructField('source_name', StringType(), True),
               StructField('source_website', StringType(), True), 
            ]

final_struc = StructType(fields = data_schema)

data = spark.read.csv(
    'country_vaccinations.csv',
    sep = ',',
    header = True,
    schema = final_struc 
    )

data.printSchema()

### Inspect data<a name="id5"></a>

Methods and attributes to inspect the data: schema, dtypes, show, head, first, take, describe, columns, count, distinct, printSchema. All of them are called on a `pyspark.sql.DataFrame` object.

**schema**: Returns the DataFrame schema (a `StructType`), which includes for each column:
 - Column name
 - Column data type
 - nullable = true/false

In [None]:
data.schema

**dtypes**: Returns a list of `(column_name, data_type)` tuples, where the type is a string such as `'string'`, `'float'` or `'date'`.

In [None]:
data.dtypes

**head(n)**: Returns the first n rows as a *list* of `Row` objects.

In [None]:
data.head(5)

**show([n])**: Prints the first n rows (20 by default) as a table.

In [None]:
#data.show()
data.show(5)

**first()**: Returns the first row as a `Row` object.

In [None]:
data.first()

**take(n)**: Returns the first *n* rows as a list of `Row` objects.

In [None]:
data.take(2)

**describe()**: Returns a DataFrame with summary statistics of the columns (call `show()` to print it):
- count: number of non-null values in the column
- mean: mean of the column values
- stddev: standard deviation of the values
- min: minimum value
- max: maximum value

In [None]:
data.describe().show()
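To make these statistics concrete, here is a plain-Python sketch (not Spark code) of what `describe()` reports for a single numeric column. It assumes, as Spark does, that nulls (`None` here) are skipped and that `stddev` is the sample standard deviation (dividing by n - 1). The function name `describe_column` and the values are hypothetical.

```python
import builtins  # avoids clashing with min/max imported later from pyspark.sql.functions
import math


def describe_column(values):
    """Plain-Python sketch of the statistics describe() reports for one numeric column.

    Nulls (None) are skipped. stddev is the sample standard deviation
    (divides by n - 1), so at least two non-null values are assumed.
    """
    present = [v for v in values if v is not None]
    n = len(present)
    avg = sum(present) / n
    variance = sum((v - avg) ** 2 for v in present) / (n - 1)
    return {
        "count": n,
        "mean": avg,
        "stddev": math.sqrt(variance),
        "min": builtins.min(present),
        "max": builtins.max(present),
    }


summary_stats = describe_column([2.0, 4.0, None, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
print(summary_stats)
```

Note that `count` is 8, not 9: the null value is not counted.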

**columns**: Returns a list with the column names.

In [None]:
data.columns

**df.column**: Selects a DataFrame column as a `Column` object.

Alternatively, you can use this syntax: **df["column"]**.

In [None]:
date = data.date

**count()**: Returns the number of rows in the DataFrame.

It can also be called on a `pyspark.sql.GroupedData` object, where it returns a DataFrame with the number of rows in each group.

In [None]:
data.count()

**distinct()**: Returns a new DataFrame containing only the distinct rows; combine it with `count()` to get the number of different rows.

In [None]:
data.distinct().count()
data.distinct().show(5)

### Column manipulation<a name="id6"></a>

In this section we will see different methods to add, modify and delete DataFrame columns.

**Add column**: To add a new column, call `withColumn(column_name, column)`. It needs two parameters: the name of the new column and a `Column` expression with the data to add.

In this case, we create a new column that duplicates the information in `date`.

In [None]:
data = data.withColumn("new_date", data["date"])

data.show(1)

Exercise: Add a new column with the difference in daily vaccinations between the previous date and the current date. We will use the following methods:

1. `pyspark.sql.Window`: defines the window (partitioning and ordering of the rows) over which a window function is computed.
2. `pyspark.sql.functions.lag(col, offset=1, default=None)`: window function that returns the value `offset` rows before the current row, and `default` if there are fewer than `offset` rows before the current row. For example, an offset of one returns the previous row at any given point in the window partition.

In [None]:
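from pyspark.sql.functions import lag, col
from pyspark.sql.window import Window

# One window per location, ordered by date, so that lag() returns the previous day
# of the same country (without partitionBy, the rows of all countries would be mixed)
w = Window.partitionBy("location").orderBy(col("date"))

data = data.withColumn('prev_daily_vac',
                        lag(data['daily_vaccinations'])
                                 .over(w))

data = data.withColumn("vaccination_difference", data["daily_vaccinations"] - data["prev_daily_vac"])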

data.show(1)
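The following plain-Python sketch (not Spark code) mimics what `lag` over a window partitioned by location and ordered by date computes. The row layout (a list of dicts with the hypothetical keys `location`, `date` and `daily`) and the values are made up for illustration. As in Spark, the first row of each partition has no previous row, so both the lagged value and the difference are null (`None`).

```python
from itertools import groupby


def add_lag_difference(rows, partition_key, order_key, value_key):
    """Plain-Python sketch of lag(value).over(Window.partitionBy(...).orderBy(...)).

    Returns copies of the rows with two extra keys:
    'prev' (value of the previous row in the same partition) and
    'diff' (current value minus 'prev', or None if either is missing).
    """
    result = []
    ordered = sorted(rows, key=lambda r: (r[partition_key], r[order_key]))
    for _, partition in groupby(ordered, key=lambda r: r[partition_key]):
        prev = None  # lag() default: null for the first row of the partition
        for row in partition:
            current = row[value_key]
            diff = None if prev is None or current is None else current - prev
            result.append({**row, "prev": prev, "diff": diff})
            prev = current
    return result


sample_rows = [
    {"location": "Spain", "date": "2021-02-02", "daily": 150.0},
    {"location": "Chile", "date": "2021-02-01", "daily": 80.0},
    {"location": "Spain", "date": "2021-02-01", "daily": 100.0},
    {"location": "Chile", "date": "2021-02-02", "daily": 95.0},
]
for r in add_lag_difference(sample_rows, "location", "date", "daily"):
    print(r["location"], r["date"], r["prev"], r["diff"])
```

Without the partitioning, the first Spain row would receive the last Chile value as its "previous" day, which is why the window should be partitioned by location.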

**Modify column**: To rename a column, use the method `withColumnRenamed(old_name, new_name)`. It takes as input parameters the name of the column to rename and the new name for the column.

Let's rename a column to make it more understandable:

In [None]:
# Rename a column with the withColumnRenamed method
# (if the column does not exist, withColumnRenamed leaves the DataFrame unchanged)

data = data.withColumnRenamed("source_website", "source_url")

data.printSchema()

**Remove column**: `drop(columnName)` returns a new DataFrame without the column `columnName`.

In [None]:
# Drop the columns we are not interested in
data = data.drop("iso_code")

data.show(5)

### Adjust null or empty values

When a dataset from external sources is used, it is common to find missing values. These values may be null, NaN, empty, etc. One of the most common techniques to handle them is to remove them.

To check whether a column has missing values, PySpark includes the following functions:
* **isnan(column)**, from the `pyspark.sql.functions` package, returns true for NaN (not a number) values.
* **isNull()**, a method of `pyspark.sql.Column`, returns true for null values.
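In Python terms, null corresponds to `None` and NaN to `float('nan')`, and the two checks do not overlap. This plain-Python sketch mirrors that behavior, assuming (as Spark's `isnan` does) that a null input is not considered NaN; the helper names `is_null` and `is_nan` are hypothetical.

```python
import math


def is_null(value):
    # Like Column.isNull(): true only for missing values (None / null)
    return value is None


def is_nan(value):
    # Like functions.isnan(): true only for the floating-point NaN value;
    # a null input gives False
    return isinstance(value, float) and math.isnan(value)


values = [1.5, None, float("nan"), 0.0]
print([is_null(v) for v in values])  # [False, True, False, False]
print([is_nan(v) for v in values])   # [False, False, True, False]

# NaN is not equal to itself, so it cannot be found with ==
print(float("nan") == float("nan"))  # False
```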

There are different techniques to deal with this problem:

- **Delete** the rows with nulls in any column: `df.na.drop()` or `df.dropna()`.
- **Replace** the missing value in a row with 0.0 or another fixed value, or with the column **mean** or **median**.
- Replace it with the **most frequent** value of the column. It works fine with well-classified data, but it may introduce some *bias*.
- Use **KNN** (*k-nearest neighbors*): fill the missing value from the most similar rows, measured with a distance metric such as Euclidean, Mahalanobis, Manhattan, Minkowski or Hamming. This is often the most effective method, but also the most computationally expensive, and it must be well understood before applying it.
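As a plain-Python sketch of the replacement strategy (not Spark code), the hypothetical `impute` helper below fills missing values with the mean or median of the values that are present. In Spark, a comparable transformer is `pyspark.ml.feature.Imputer`, whose `strategy` parameter accepts `"mean"` or `"median"`.

```python
import statistics


def impute(values, strategy="mean"):
    """Replace None with the mean or median of the non-missing values (sketch)."""
    present = [v for v in values if v is not None]
    if strategy == "mean":
        fill_value = statistics.mean(present)
    elif strategy == "median":
        fill_value = statistics.median(present)
    else:
        raise ValueError("unsupported strategy: " + strategy)
    return [fill_value if v is None else v for v in values]


column = [1.0, None, 3.0, 8.0]
print(impute(column))                     # [1.0, 4.0, 3.0, 8.0]
print(impute(column, strategy="median"))  # [1.0, 3.0, 3.0, 8.0]
```

The mean is pulled up by the large value 8.0, while the median is not, which is one reason to prefer the median for skewed data.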

In [None]:
#data.na.drop()

In [None]:
from pyspark.sql.functions import mean

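# Replace nulls in a single column by the column mean (left commented out):
#data = data.na.fill(data.select(mean(data['total_vaccinations'])).collect()[0][0], subset=['total_vaccinations'])

# Replace null values in all numeric columns with a new value (0.0)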
data = data.na.fill(0.0)

### Data queries

The PySpark DataFrame API and PySpark SQL include a set of methods to query the data in a fairly simple way: *select, filter, between, when, like, groupBy, aggregations*.

**select(column_name, ...)**: Returns a new DataFrame with one or more columns, whose names are given as parameters. `pyspark.sql.DataFrame.select`

In [None]:
data.select("total_vaccinations").show(2)
data.select("location", "date", "total_vaccinations", "daily_vaccinations").show(3)

**filter(condition)**: Filters the rows based on a given condition. Conditions can be combined with the following operators: AND (`&`), OR (`|`) and NOT (`~`). `pyspark.sql.DataFrame.filter`

What is selected in the following query?

\* `lit()` creates a column with a literal value.

In [None]:
from pyspark.sql.functions import col, lit

data.filter( (col('date') >= lit('2021-02-01')) & (col('date') <= lit('2021-02-15')) ).show(5)
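The parentheses around each comparison in the query above are required because in Python the `&` and `|` operators bind more tightly than comparisons such as `>=`. This plain-Python sketch with integers shows how the unparenthesized form is parsed. With Spark columns, the same mistake usually produces a chained comparison that raises a `ValueError`, because a Column cannot be converted to a bool.

```python
x = 3

# Intended condition: x > 1 AND x < 2
with_parens = (x > 1) & (x < 2)  # True & False -> False

# Without parentheses, & is evaluated first: x > (1 & x) < 2,
# i.e. the chained comparison 3 > 1 < 2 -> True
without_parens = x > 1 & x < 2

print(with_parens, without_parens)  # False True
```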

**between(low_value, high_value)**: Returns *true* if the column value lies between the two values provided as parameters (both bounds included), and *false* otherwise. `pyspark.sql.Column.between`

In [None]:
# Rows whose daily_vaccinations value lies between 100 and 1000 satisfy the filter condition.
data.filter( data.daily_vaccinations.between(100.0, 1000.0) ).show(5)
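A plain-Python sketch (not Spark code) of the semantics of `between`: both bounds are included, and a null value yields null, which `filter` treats as false, so the row is dropped. The values are made up for illustration.

```python
def between(value, low, high):
    """Sketch of Column.between(low, high) for a single value."""
    if value is None:
        return None  # null in, null out; filter() drops such rows
    return low <= value <= high  # inclusive on both ends


daily_vaccinations = [50.0, 100.0, 500.0, 1000.0, 1500.0, None]
kept = [v for v in daily_vaccinations if between(v, 100.0, 1000.0)]
print(kept)  # [100.0, 500.0, 1000.0]
```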

**when(condition, value)**: Returns `value` for the rows where the condition is true; for the other rows it returns the value given to `otherwise()`, or `null` if `otherwise()` is not used. `pyspark.sql.functions.when`

In [None]:
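# 1 if the location administered at least 1000 doses that day; otherwise 0 (first query) or null (second query, without otherwise())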
from pyspark.sql.functions import when
data.select("location","total_vaccinations", "total_boosters_per_hundred", when(data.daily_vaccinations >= 1000.0, 1).otherwise(0).alias("daily_vaccinations")).show(10)
data.select("location","total_vaccinations", "total_boosters_per_hundred", when(data.daily_vaccinations >= 1000.0, 1).alias("daily_vaccinations")).show(10)
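The difference between the two queries above can be sketched in plain Python (not Spark code). Rows whose condition is true get `1`. All other rows get the `otherwise()` value, or null when `otherwise()` is omitted; this includes rows where `daily_vaccinations` is null, since the comparison is then null rather than true. The helper name and values are hypothetical.

```python
def flag_high_days(values, threshold=1000.0, otherwise=None):
    """Sketch of when(col >= threshold, 1).otherwise(otherwise) over a list of values."""
    flags = []
    for v in values:
        condition = None if v is None else v >= threshold  # comparing null gives null
        flags.append(1 if condition is True else otherwise)
    return flags


daily = [500.0, 1000.0, None, 2000.0]
print(flag_high_days(daily, otherwise=0))  # [0, 1, 0, 1]
print(flag_high_days(daily))               # [None, 1, None, 1]
```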

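**like(expression)**: Returns *true* for the column values that match a SQL LIKE expression. `pyspark.sql.Column.like`

\* `like()` uses SQL wildcards: `%` matches any sequence of characters and `_` matches exactly one character. Use `rlike()` for a regular expression.
Locations starting with S: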

In [None]:
data.filter(col("location").like('S%')).select('location').distinct().show()
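To see how the `like()` wildcards relate to regular expressions, this plain-Python sketch (not Spark code) translates a LIKE pattern into a Python regex that must match the whole value. It is a simplification: escape characters are not handled, and matching is case-sensitive, as with Spark's `like()`. The function name `like_to_regex` is hypothetical.

```python
import re


def like_to_regex(pattern):
    """Translate a SQL LIKE pattern into a compiled Python regex (sketch)."""
    parts = []
    for ch in pattern:
        if ch == "%":
            parts.append(".*")  # any sequence of characters, possibly empty
        elif ch == "_":
            parts.append(".")  # exactly one character
        else:
            parts.append(re.escape(ch))  # everything else is literal
    return re.compile("".join(parts), re.DOTALL)


locations = ["Spain", "Sweden", "Israel", "South Africa", "Serbia"]
starts_with_s = [loc for loc in locations if like_to_regex("S%").fullmatch(loc)]
print(starts_with_s)  # ['Spain', 'Sweden', 'South Africa', 'Serbia']
```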

**groupBy(column_name)**: Groups the data by the *column_name* given as input. It returns a `pyspark.sql.GroupedData` object that can be aggregated with mean, min, max, count, etc.

In [None]:
print("Mean by location")
data.groupBy('location').mean().show(5, truncate=False)
print("MIN by location")
data.select("location","total_vaccinations", "people_vaccinations", "people_fully_vaccinated", "daily_vaccinations", "total_vaccinations_per_hundred").groupBy('location').min().show(5)
print("MAX by location")
data.select("location","total_vaccinations", "people_vaccinations", "people_fully_vaccinated", "daily_vaccinations", "total_vaccinations_per_hundred").groupBy('location').max().show(5)

**Aggregation**: `agg` applies one or more aggregate functions to each group of rows (GroupedData) and returns one row per group, with one value per aggregate.

In [None]:
from pyspark.sql.functions import min, max, mean, col,lit

data.filter( (col('date') >= lit('2021-02-01')) & (col('date') <= lit('2021-02-24')) )\
    .groupBy("location")\
    .agg(min("date").alias("Date from"),
         max("date").alias("Date to"),
         min("total_vaccinations").alias("Total vaccinations from"),
         max("total_vaccinations").alias("Total vaccinations to"),
         mean("daily_vaccinations").alias("Mean of daily vaccinations")
        )\
    .orderBy("location")\
    .show(10, truncate=False)
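Conceptually, `groupBy(...).agg(...)` splits the rows by key and reduces each group to one row. The following plain-Python sketch (not Spark code) computes the same kind of result as the query above for a few made-up rows; like Spark's `mean`, it ignores null values. The helper name and row layout are assumptions.

```python
import builtins  # the notebook shadows min/max with pyspark.sql.functions versions
from collections import defaultdict


def summarize_by_location(rows):
    """Group rows by location and compute min/max date and mean daily vaccinations."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["location"]].append(row)

    summary = {}
    for location, group in sorted(groups.items()):
        daily = [r["daily"] for r in group if r["daily"] is not None]  # skip nulls
        summary[location] = {
            "date_from": builtins.min(r["date"] for r in group),  # ISO dates sort as strings
            "date_to": builtins.max(r["date"] for r in group),
            "mean_daily": sum(daily) / len(daily) if daily else None,
        }
    return summary


grouped_rows = [
    {"location": "Spain", "date": "2021-02-01", "daily": 100.0},
    {"location": "Spain", "date": "2021-02-03", "daily": 200.0},
    {"location": "Spain", "date": "2021-02-02", "daily": None},
    {"location": "Chile", "date": "2021-02-05", "daily": 90.0},
]
for location, stats in summarize_by_location(grouped_rows).items():
    print(location, stats)
```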

#### Data handling exercises: 

1. Top 20 locations with the most vaccinations.
    1. Number of vaccinations
    2. Vaccination ratio
    3. Data visualization
2. Vaccination data from any location in the last 4 weeks.
    1. Visualize the vaccination progress in terms of total vaccinations
    2. **How many** locations are ahead?
3. Which country has the highest daily vaccination rate?
    1. Which country vaccinated the most people today (or yesterday)?

In [None]:
#1. Top 20 locations by number of vaccine doses administered
from matplotlib import pyplot as plt
from pyspark.sql.functions import max, mean, desc
# Alternative: print the ranking as a table instead of plotting it
#data.select('location', 'total_vaccinations')\
#                   .groupBy('location')\
#                   .agg(max('total_vaccinations').alias("max_vac"))\
#                   .orderBy(desc('max_vac')).show(20,truncate=False)

sec_df =  data.select('location', 'total_vaccinations')\
                   .groupBy('location')\
                   .agg(max('total_vaccinations').alias("max_vac"))\
                   .orderBy(desc('max_vac'))\
                   .limit(20)\
                   .toPandas()

sec_df.plot(kind = 'bar', x='location', y = sec_df.columns.tolist()[1:], 
                    figsize=(12, 6), ylabel = 'Total Vaccinations', xlabel = 'Locations')
plt.show()

In [None]:
#    2. Vaccination ratio (people vaccinated per hundred)
from matplotlib import pyplot as plt
from pyspark.sql.functions import max, mean, desc
data.select('location', 'total_vaccinations', "people_vaccinated_per_hundred")\
                   .groupBy('location')\
                   .agg(max('people_vaccinated_per_hundred').alias("max_ratio_vac"))\
                   .orderBy(desc('max_ratio_vac')).show(2,truncate=False)

sec_df =  data.select('location', 'total_vaccinations', "people_vaccinated_per_hundred")\
                   .groupBy('location')\
                   .agg(max('people_vaccinated_per_hundred').alias("max_ratio_vac"))\
                   .orderBy(desc('max_ratio_vac'))\
                   .limit(20)\
                   .toPandas()

sec_df.plot(kind = 'bar', x='location', y = sec_df.columns.tolist()[1:], 
                    figsize=(12, 6), ylabel = 'People vaccinated per hundred', xlabel = 'Locations')
plt.show()

In [None]:
#    Ratio of people fully vaccinated (per hundred)

from pyspark.sql.functions import max, mean, desc
data.select('location', 'total_vaccinations', "people_fully_vaccinated_per_hundred")\
                   .groupBy('location')\
                   .agg(max('people_fully_vaccinated_per_hundred').alias("max_ratio_personas"))\
                   .orderBy(desc('max_ratio_personas')).show(2,truncate=False)

sec_df =  data.select('location', 'total_vaccinations', "people_fully_vaccinated_per_hundred")\
                   .groupBy('location')\
                   .agg(max('people_fully_vaccinated_per_hundred').alias("max_ratio_personas"))\
                   .orderBy(desc('max_ratio_personas'))\
                   .limit(20)\
                   .toPandas()

sec_df.plot(kind = 'bar', x='location', y = sec_df.columns.tolist()[1:], 
                    figsize=(12, 6), ylabel = 'People fully vaccinated per hundred', xlabel = 'Locations')
plt.show()

In [None]:
data.printSchema()

In [None]:
#2. Spain in the last weeks
from pyspark.sql.functions import lit, when, max, desc
import matplotlib.pyplot as plt

spain_data = data.select('location', 'date', 'total_vaccinations', 'people_vaccinations', 'people_fully_vaccinated', 'daily_vaccinations', 'total_vaccinations_per_hundred', 'people_vaccinated_per_hundred', 'people_fully_vaccinated_per_hundred', 'daily_vaccinations_per_million')\
                 .filter((data.location=="Spain") & ((data.date >= lit('2021-02-10')) & ((data.date) <= lit('2021-02-24'))))
spain_data.show()

# Visualize the vaccination progress in Spain as the total number of doses administered since the start of the campaign.
# Correct the days without data (filled with 0.0 earlier) so that the plot is not distorted
spain_data = spain_data.withColumn("total_vaccinations", when(((spain_data.date == lit('2021-02-12')) | (spain_data.date == lit('2021-02-13'))),2423045.0).otherwise(spain_data.total_vaccinations))
spain_data = spain_data.withColumn("total_vaccinations", when(((spain_data.date == lit('2021-02-19')) | (spain_data.date == lit('2021-02-20'))),2936011.0).otherwise(spain_data.total_vaccinations))

spain_data.show()

spain_data_pandas = spain_data.toPandas()

plt.plot(spain_data_pandas.date, spain_data_pandas.total_vaccinations)

plt.xlabel("Date")
plt.ylabel("Total vaccinations (Spain)")
plt.title("Vaccination progress in Spain")
plt.xticks(rotation=30)
plt.show()

# How many locations are ahead of Spain in vaccinations?
# Note: the count includes Spain itself and OWID aggregates such as World or the continents
spain_vac = spain_data.filter(spain_data.date == lit('2021-02-21')).select("total_vaccinations").head()
total_vac = data.select('location', 'total_vaccinations')\
                  .groupBy('location')\
                  .agg(max('total_vaccinations').alias("max_vac"))\
                  .orderBy(desc('max_vac'))

print("Total vaccinations in Spain: %d" % spain_vac[0])
total_vac.select('location', 'max_vac')\
    .filter(total_vac.max_vac >= spain_vac[0])\
    .orderBy(desc('max_vac')).count()

In [None]:
from pyspark.sql.functions import mean, desc

#3. Which location vaccinates the most people per day (on average)?
data.select('location', 'daily_vaccinations')\
    .groupBy('location')\
    .agg(mean('daily_vaccinations').alias("mean_vac"))\
    .orderBy(desc('mean_vac')).show(20,truncate=False)

#    Which location vaccinated the most people on the last day with records?
data.filter(data.date == "2021-02-22")\
    .select("location", "daily_vaccinations")\
    .orderBy(desc("daily_vaccinations")).show()

### Machine Learning Library (MLlib)

We will apply the k-means algorithm for our analysis, so we must import it from `pyspark.ml.clustering`.

Spark ML algorithms expect all the input features of a row in a single vector column. The `VectorAssembler` transformer combines the selected numeric columns into that vector column. To create clusters with k-means, we select the columns that describe the characteristics by which we want to group the data.

In [None]:
data.printSchema()

K-means can find clusters in a space of N dimensions. This time, we will define the space using the maximum percentage of people fully vaccinated per country and the mean daily vaccinations per million. This way, we will find groups of countries with similar vaccination progress.
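Before using Spark's implementation, here is a minimal plain-Python sketch of Lloyd's algorithm, the iteration at the core of k-means: assign every point to its nearest center (squared Euclidean distance), move every center to the mean of its points, and repeat until the centers stop changing. Spark's `KMeans` follows the same idea but, by default, chooses the initial centers with the k-means|| method; here the initial centers are passed in, an assumption made to keep the example deterministic. The function names and points are made up.

```python
import builtins  # the notebook shadows min/max with pyspark.sql.functions versions


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd_kmeans(points, initial_centers, max_iter=20):
    """Plain-Python sketch of Lloyd's k-means iterations."""
    centers = [tuple(c) for c in initial_centers]
    for _ in range(max_iter):
        # Assignment step: index of the nearest center for each point
        clusters = [[] for _ in centers]
        for p in points:
            nearest = builtins.min(range(len(centers)), key=lambda i: sq_dist(p, centers[i]))
            clusters[nearest].append(p)
        # Update step: each center moves to the mean of its points
        new_centers = []
        for center, members in zip(centers, clusters):
            if members:
                new_centers.append(tuple(sum(dim) / len(members) for dim in zip(*members)))
            else:
                new_centers.append(center)  # keep the center of an empty cluster
        if new_centers == centers:  # converged
            break
        centers = new_centers
    return centers


demo_points = [(0, 0), (0, 2), (10, 10), (10, 12)]
print(lloyd_kmeans(demo_points, initial_centers=[(0, 0), (10, 10)]))  # [(0.0, 1.0), (10.0, 11.0)]
```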

Let's prepare the data

In [None]:
import matplotlib.pyplot as plt
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import max, mean, desc

df = data.select('location', 'people_fully_vaccinated_per_hundred', 'daily_vaccinations_per_million').groupBy("location")\
    .agg(max("people_fully_vaccinated_per_hundred").alias("max_vac"), mean("daily_vaccinations_per_million").alias("mean_daily_vac")).orderBy(desc("max_vac"))
# The feature columns must be numeric (not strings)
feat_cols = ["max_vac",  "mean_daily_vac"]

vec_assembler = VectorAssembler(inputCols = feat_cols, outputCol='features')

df = df.withColumnRenamed("location", "id")
df = vec_assembler.transform(df)
df.show()

for_plot = df.toPandas()

plt.plot(for_plot.max_vac,  for_plot.mean_daily_vac, 'o', color='black')

To run k-means we only need two columns: `id` (the location, renamed above) and `features`. Let's drop the rest.

In [None]:
km_data = df.drop('max_vac', 'mean_daily_vac')


In [None]:
km_data.show()

### Optimize choice of k
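One disadvantage of KMeans compared to more advanced clustering algorithms is that the algorithm must be told how many clusters, k, it should try to find. To choose k, we fit the model for different values of k and compare a quality measure: we can look for an "elbow" in the cost function (the sum of squared distances from each point to its cluster center) or, as in the following code, compute the silhouette score for each k (higher is better).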

In [None]:
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

# Silhouette score for each k (higher is better). There is one row per location,
# so the model is fitted on all the rows rather than on a small sample.
evaluator = ClusteringEvaluator()  # defaults: silhouette metric, squared Euclidean distance
silhouettes = {}
for k in range(2, 20):
    kmeans = KMeans().setK(k).setSeed(1).setFeaturesCol("features")
    model = kmeans.fit(km_data)
    # Make predictions
    predictions = model.transform(km_data)

    # Evaluate clustering by computing the silhouette score
    silhouette = evaluator.evaluate(predictions)
    silhouettes[k] = silhouette
    print("k =", k)
    print("Silhouette with squared euclidean distance = " + str(silhouette))
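The silhouette score compares, for each point, the mean distance a to the other points of its own cluster with the mean distance b to the points of the nearest other cluster: s = (b - a) / max(a, b). Values close to 1 indicate compact, well-separated clusters. This plain-Python sketch computes the textbook definition with squared Euclidean distance, which is the `ClusteringEvaluator` default; Spark uses a more efficient formulation that may differ in details (for example, for single-point clusters), so treat it as an illustration rather than a reimplementation. The function names and points are hypothetical.

```python
import builtins  # the notebook shadows min/max with pyspark.sql.functions versions


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def mean_dist(p, others):
    """Mean squared distance from point p to a list of points."""
    return sum(sq_dist(p, q) for q in others) / len(others)


def silhouette_score(points, labels):
    """Mean silhouette over all points (sketch; assumes at least two clusters and distinct points)."""
    scores = []
    for i, p in enumerate(points):
        own = [q for j, q in enumerate(points) if labels[j] == labels[i] and j != i]
        if not own:
            scores.append(0.0)  # convention for a single-point cluster
            continue
        a = mean_dist(p, own)
        b = builtins.min(
            mean_dist(p, [q for j, q in enumerate(points) if labels[j] == c])
            for c in set(labels) if c != labels[i]
        )
        scores.append((b - a) / builtins.max(a, b))
    return sum(scores) / len(scores)


sil_points = [(0, 0), (0, 1), (10, 10), (10, 11)]
print(silhouette_score(sil_points, [0, 0, 1, 1]))  # well-separated clusters: close to 1
print(silhouette_score(sil_points, [0, 1, 0, 1]))  # poor clustering: negative
```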

In [None]:
k = 4
kmeans = KMeans().setK(k).setSeed(1).setFeaturesCol("features")
model = kmeans.fit(km_data)
centers = model.clusterCenters()

print("Cluster Centers: ")
for center in centers:
    print(center)

### Assign clusters to events
There is one important thing left to do: assigning the individual rows to the nearest cluster centroid. That can be done with the `transform` method, which adds a `prediction` column to the DataFrame. The prediction value is an integer between 0 and k-1; it is only a cluster label and has no meaning by itself.

In [None]:
transformed = model.transform(km_data).select('id', 'prediction')
rows = transformed.collect()
print(rows[:3])
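What `transform` does for each row can be sketched in plain Python (not Spark code): the predicted label is the index of the closest cluster center, which is why it is always an integer from 0 to k-1. The example compares squared Euclidean distances, which picks the same nearest center as the Euclidean distance; the function name and centers are made up.

```python
import builtins  # the notebook shadows min/max with pyspark.sql.functions versions


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def predict(points, cluster_centers):
    """Return, for each point, the index of its nearest cluster center."""
    return [
        builtins.min(range(len(cluster_centers)), key=lambda c: sq_dist(p, cluster_centers[c]))
        for p in points
    ]


demo_centers = [(0.0, 0.0), (10.0, 10.0)]
print(predict([(1, 1), (9, 8), (0, 2)], demo_centers))  # [0, 1, 0]
```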

From the rows returned by the `collect` method, it is trivial to create a new DataFrame using our Spark session.

In [None]:
# SQLContext is deprecated since Spark 3.0; the SparkSession creates DataFrames directly
df_pred = spark.createDataFrame(rows)
df_pred.show()

In [None]:
df_pred = df_pred.join(df, 'id')
df_pred.show()

### Convert to Pandas dataframe
Typically at this point I would need to do something else with the data, which does not require Spark, so let's convert the Spark dataframe to a good old Pandas dataframe for further processing.

In [None]:
pddf_pred = df_pred.toPandas().set_index('id')
pddf_pred.head()

### Visualize the results
The final step is to visually inspect the output to see whether the KMeans model did a good job. Compare the clusters with the first scatter plot of `max_vac` against `mean_daily_vac` and check whether they match the groups that are visible by eye.

In [None]:
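# Figure.gca(projection=...) was deprecated and later removed in Matplotlib; add_subplot works instead
threedee = plt.figure(figsize=(12,10)).add_subplot(projection='3d')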
threedee.scatter(pddf_pred.max_vac, pddf_pred.mean_daily_vac, c=pddf_pred.prediction)
threedee.set_xlabel('max_vac')
threedee.set_ylabel('mean_daily_vac')
plt.show()

In [None]:
colors = pddf_pred.prediction
x = pddf_pred.max_vac
y = pddf_pred.mean_daily_vac

plt.scatter(x, y, c=colors)
plt.show()