# Train a Clustering Model with Amazon SageMaker using PySpark
Amazon SageMaker provides an Apache Spark library (in both Python and Scala) that you can use to integrate your Apache Spark applications with SageMaker. For example, you might use Apache Spark for data preprocessing and SageMaker for model training and hosting. 

1. [Setup](#section1)
2. [Prepare training and test data](#section2)
3. [Train the model](#section3)
4. [Call the Inference Endpoint](#section4)
5. [Analyze results](#section5)

<a id="section1"></a>
## 1. Setup

In [1]:
from pyspark import SparkContext, SparkConf
from sagemaker_pyspark import IAMRole, classpath_jars
from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator
from pyspark.sql.functions import col  # explicit import; a wildcard would shadow built-ins such as sum and max

Starting Spark application

| ID | YARN Application ID | Kind | State | Spark UI | Driver log | Current session? |
|---|---|---|---|---|---|---|
| 27 | application_1616621397322_0028 | pyspark | idle | | | ✔ |

SparkSession available as 'spark'.

<p style="color:red">Important: insert your AWS account ID and S3 bucket name below.</p>
Example: <br>
AWS_ACCOUNT_ID = '000000000000'<br>
S3_BUCKETNAME = 'samplebucket'

In [78]:
AWS_ACCOUNT_ID = ""
S3_BUCKETNAME = ''
S3_PREFIX = 'pysparklab'
region = "us-east-2"
iam_role = "arn:aws:iam::{}:role/sgmrole-pyspark".format(AWS_ACCOUNT_ID)
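
The role ARN above is built with plain string formatting, so an empty or mistyped account ID would likely only surface later, when SageMaker tries to use the role. The sketch below shows one way to fail early. The helper name `build_role_arn` is hypothetical and not part of any AWS library; it only relies on the fact that AWS account IDs are 12 digits long.

```python
import re


def build_role_arn(account_id, role_name):
    """Return an IAM role ARN, rejecting malformed account IDs early."""
    # AWS account IDs are exactly 12 digits.
    if not re.fullmatch(r"\d{12}", account_id):
        raise ValueError(
            "AWS_ACCOUNT_ID must be a 12-digit string, got {!r}".format(account_id)
        )
    return "arn:aws:iam::{}:role/{}".format(account_id, role_name)


# Placeholder account ID from the example above; replace it with your own.
print(build_role_arn("000000000000", "sgmrole-pyspark"))
```

Calling the helper with the empty default (`AWS_ACCOUNT_ID = ""`) raises a `ValueError` instead of producing an unusable ARN.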


<br><br>
<a id="section2"></a>
## 2. Prepare training and test data

### 2.1 Load the data into a DataFrame
Training data

In [3]:
path = "s3://sagemaker-sample-data-{}/spark/mnist/train/".format(region)
training_data = spark.read.format("libsvm").option("numFeatures", "784").load(path)


In [4]:
training_data.count()


60000

<br><br>
Test data

In [5]:
path = "s3://sagemaker-sample-data-{}/spark/mnist/test/".format(region)
test_data = spark.read.format("libsvm").option("numFeatures", "784").load(path)


In [6]:
test_data.count()


10000

<br><br>
### 2.2 Data schema
You can use the same Spark functions you are already familiar with to explore and prepare the data.

In [7]:
training_data.printSchema()


root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)

<br><br>
In each row:
* The label column identifies the image's label. If the image of the handwritten number is the digit 5, the label value is 5.
* The features column stores a vector of Double values. These are the 784 features of the handwritten number. (Each handwritten number is a 28 x 28-pixel image, making 784 features.)
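
To make the row layout concrete, here is a minimal pure-Python sketch of how one LIBSVM-formatted text line maps to a label and a 784-element feature vector. It is an illustration only: Spark's `libsvm` reader does this work for you and returns sparse vectors, and the sample line is made up rather than taken from the MNIST files. The function name `parse_libsvm_line` is hypothetical.

```python
FEATURE_DIM = 28 * 28  # 784 pixels per image


def parse_libsvm_line(line, num_features=FEATURE_DIM):
    """Parse '<label> <index>:<value> ...' into (label, dense feature list)."""
    parts = line.split()
    label = float(parts[0])
    features = [0.0] * num_features  # pixels that are not listed stay at zero
    for item in parts[1:]:
        index, value = item.split(":")
        # LIBSVM indices are 1-based; shift them to 0-based list positions.
        features[int(index) - 1] = float(value)
    return label, features


# Made-up line for illustration, not taken from the MNIST files.
label, features = parse_libsvm_line("5 153:3 154:18 155:126")
print(label, len(features), features[152:155])
```

Only the non-zero pixels appear in the text line, which is why the `features` column is shown in sparse form such as `(784,[202,203,204...` in the DataFrame output.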

In [8]:
training_data.show(1, False)



<br><br>
### 2.3 Find the label distribution

In [16]:
training_data.groupby("label").count().show()


+-----+-----+
|label|count|
+-----+-----+
|  8.0| 5851|
|  0.0| 5923|
|  7.0| 6265|
|  1.0| 6742|
|  4.0| 5842|
|  3.0| 6131|
|  2.0| 5958|
|  6.0| 5918|
|  5.0| 5421|
|  9.0| 5949|
+-----+-----+
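
The counts above are easier to compare as shares of the whole training set. The following sketch copies the numbers from the table into a plain dictionary and computes the percentages locally. It does not touch Spark, and the variable names are chosen for illustration.

```python
# Label counts copied from the groupby output above.
label_counts = {
    0: 5923, 1: 6742, 2: 5958, 3: 6131, 4: 5842,
    5: 5421, 6: 5918, 7: 6265, 8: 5851, 9: 5949,
}

total = sum(label_counts.values())
for label in sorted(label_counts):
    share = 100.0 * label_counts[label] / total
    print("digit {}: {:5d} rows ({:.2f}%)".format(label, label_counts[label], share))

most_common = max(label_counts, key=label_counts.get)
least_common = min(label_counts, key=label_counts.get)
print("total:", total, "most common:", most_common, "least common:", least_common)
```

The total matches the 60000 rows reported by `training_data.count()`, and every digit accounts for roughly 9 to 11 percent of the rows, so the classes are reasonably balanced.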

<br><br>
<a id="section3"></a>
## 3. Train the model

### 3.1 Create the SageMaker Estimator
We will use the k-means algorithm provided by SageMaker. The <b>KMeansSageMakerEstimator</b> extends the SageMakerEstimator class, which handles end-to-end Amazon SageMaker training and deployment tasks.

The estimator takes the instance type and instance count for both the training job and the hosting endpoint, plus the IAM role that SageMaker uses.
<br><br>
You can [learn more about the SageMaker k-means algorithm in the docs](https://docs.aws.amazon.com/sagemaker/latest/dg/algo-kmeans-tech-notes.html).

In [20]:
kmeans_estimator = KMeansSageMakerEstimator(
    trainingInstanceType="ml.m4.xlarge",
    trainingInstanceCount=1,
    endpointInstanceType="ml.m4.xlarge",
    endpointInitialInstanceCount=1,
    sagemakerRole=IAMRole(iam_role))


<br><br>
### 3.2 Set hyperparameters

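Specify the number of clusters with <b>setK()</b> and the feature size with <b>setFeatureDim()</b>. Here, 10 clusters correspond to the 10 digits, and 784 is the number of pixels in each image.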

In [21]:
kmeans_estimator.setK(10)
kmeans_estimator.setFeatureDim(784)
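
To illustrate what these two hyperparameters mean, here is a minimal sketch of the textbook k-means loop (Lloyd's algorithm) in pure Python. It is not SageMaker's implementation, which is a separate, scalable variant described in the linked docs. In this toy example K is 2 (the number of centroids) and the feature dimension is 2 (the length of each point), where the notebook uses 10 and 784. All function names are hypothetical.

```python
import math


def assign(points, centroids):
    """Return the index of the nearest centroid for every point."""
    return [
        min(range(len(centroids)), key=lambda c: math.dist(p, centroids[c]))
        for p in points
    ]


def update(points, assignments, centroids):
    """Move each centroid to the mean of the points assigned to it."""
    new_centroids = []
    for c, old in enumerate(centroids):
        members = [p for p, a in zip(points, assignments) if a == c]
        if not members:
            new_centroids.append(old)  # keep an empty cluster's centroid as is
            continue
        new_centroids.append(
            [sum(m[d] for m in members) / len(members) for d in range(len(old))]
        )
    return new_centroids


def kmeans(points, initial_centroids, iterations=10):
    """Alternate assignment and update steps for a fixed number of rounds."""
    centroids = [list(c) for c in initial_centroids]
    for _ in range(iterations):
        centroids = update(points, assign(points, centroids), centroids)
    return centroids, assign(points, centroids)


# Toy data: two obvious groups in a 2-dimensional feature space.
points = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
centroids, assignments = kmeans(points, [[0.0, 0.0], [10.0, 10.0]])
print(centroids)    # [[0.0, 0.5], [10.0, 10.5]]
print(assignments)  # [0, 0, 1, 1]
```

The labels play no role in this loop. The algorithm only sees feature vectors, which is why cluster numbers are not expected to match digit labels.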


<br><br>
### 3.3 Call the fit method
The <b>fit method</b> of this estimator trains a model using an input DataFrame and returns a SageMakerModel object that you can use to get inferences.

In [22]:
kmeans_model = kmeans_estimator.fit(training_data)


<br><br>
### 3.4 Monitor the training job on the SageMaker Console
Open the <a href="https://us-east-2.console.aws.amazon.com/sagemaker/home?#/jobs">Amazon SageMaker console</a> and watch the SageMaker training job get created and start running.
<br><br>
<img src="https://github.com/lgbaeza/mycloudstuff/raw/main/aws/by-service/emr/pysparknotebook/img/training-jobs.png">

<br><br>
### 3.5 Find the trained model
When the training job is complete, open the <a href="https://console.aws.amazon.com/sagemaker/home?#/models">Amazon SageMaker console</a> to find the trained model.
<img src="https://github.com/lgbaeza/mycloudstuff/raw/main/aws/by-service/emr/pysparknotebook/img/models.png">

<br><br>
<a id="section4"></a>
## 4. Inference

To get inferences from a model hosted in SageMaker, you call the <b>SageMakerModel.transform</b> method. The transform method simplifies the inference process by doing the following under the hood:
* Receives a DataFrame as input
* Serializes the features column to protobuf and sends it to the SageMaker endpoint for inference
* Deserializes the protobuf response into two additional columns (distance_to_cluster and closest_cluster)
* Returns a new DataFrame that contains the input columns plus the two inference columns
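
The shape of that result can be sketched locally. The example below mimics only the output of the steps above: it skips the protobuf serialization and the endpoint call, uses made-up centroids, and assumes that distance_to_cluster is a Euclidean distance. The function name `transform_rows` is hypothetical.

```python
import math


def transform_rows(rows, centroids):
    """Append distance_to_cluster and closest_cluster to each input row."""
    output = []
    for row in rows:
        distances = [math.dist(row["features"], c) for c in centroids]
        closest = min(range(len(centroids)), key=distances.__getitem__)
        # Keep the input columns and add the two inference columns.
        output.append({
            **row,
            "distance_to_cluster": distances[closest],
            "closest_cluster": float(closest),
        })
    return output


# Made-up 2-dimensional rows and centroids, for illustration only.
sample_rows = [
    {"label": 7.0, "features": [0.0, 1.0]},
    {"label": 2.0, "features": [9.0, 9.0]},
]
hypothetical_centroids = [[0.0, 0.0], [10.0, 10.0]]

for row in transform_rows(sample_rows, hypothetical_centroids):
    print(row)
```

As with the real method, the input rows come back unchanged apart from the two added columns.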

### 4.1 Find the SageMaker endpoint
Open the [Amazon SageMaker console](https://console.aws.amazon.com/sagemaker/home?#/endpoints) to find the endpoint created by the <b>fit method</b>.
<img src="https://github.com/lgbaeza/mycloudstuff/raw/main/aws/by-service/emr/pysparknotebook/img/endpoints.png">

<br><br>
### 4.2 Call the transform() method

In [24]:
transformed_data = kmeans_model.transform(test_data)


<br><br>
### 4.3 Work with the inference results
Now you can work with the inference results the same way you would with any Spark DataFrame.

In [25]:
transformed_data.show(5)


+-----+--------------------+-------------------+---------------+
|label|            features|distance_to_cluster|closest_cluster|
+-----+--------------------+-------------------+---------------+
|  7.0|(784,[202,203,204...|  1396.271240234375|            9.0|
|  2.0|(784,[94,95,96,97...|  2037.256103515625|            8.0|
|  1.0|(784,[128,129,130...| 1037.6932373046875|            2.0|
|  0.0|(784,[124,125,126...| 1624.5045166015625|            5.0|
|  4.0|(784,[150,151,159...| 1375.3685302734375|            1.0|
+-----+--------------------+-------------------+---------------+
only showing top 5 rows

<br><br>
<a id="section5"></a>
## 5. Analyze the results
### 5.1 Compare clusters vs labels
Let's find the 15 most frequent combinations of closest_cluster and label.

In [45]:
topClusters = transformed_data.groupby("closest_cluster","label").count().sort(col("count").desc()).take(15)


You will find that the label and the cluster number are not the same, because clustering is an unsupervised algorithm and doesn't take the labels into account.

In [83]:
# Turn the list of Row objects returned by take(15) back into a DataFrame for display
TopClustersDf = spark.createDataFrame(topClusters)
TopClustersDf.show()


+---------------+-----+-----+
|closest_cluster|label|count|
+---------------+-----+-----+
|            5.0|  0.0|  802|
|            3.0|  6.0|  752|
|            2.0|  1.0|  750|
|            6.0|  2.0|  736|
|            8.0|  3.0|  724|
|            0.0|  8.0|  591|
|            1.0|  4.0|  485|
|            9.0|  7.0|  482|
|            1.0|  9.0|  418|
|            9.0|  9.0|  405|
|            4.0|  1.0|  379|
|            0.0|  5.0|  321|
|            7.0|  7.0|  285|
|            8.0|  5.0|  265|
|            9.0|  4.0|  238|
+---------------+-----+-----+

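You can interpret the data as follows:
* Handwritten digits with the label 0 fall most often in cluster 5 (802 images).
* Handwritten digits with the label 4 fall most often in cluster 1 (485 images), but 238 of them fall in cluster 9, so a single digit can be split across clusters.
* Cluster 9 mixes the labels 7, 9, and 4, so a single cluster can also contain several digits.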
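
A quick way to read the table is to find the dominant label in each cluster. The sketch below does that in plain Python using only the 15 rows shown above, so the shares it prints are relative to those rows and not to the full test set. The variable names are chosen for illustration.

```python
from collections import defaultdict

# (closest_cluster, label, count) rows copied from the table above.
pairs = [
    (5, 0, 802), (3, 6, 752), (2, 1, 750), (6, 2, 736), (8, 3, 724),
    (0, 8, 591), (1, 4, 485), (9, 7, 482), (1, 9, 418), (9, 9, 405),
    (4, 1, 379), (0, 5, 321), (7, 7, 285), (8, 5, 265), (9, 4, 238),
]

# Group the counts by cluster: {cluster: {label: count}}.
by_cluster = defaultdict(dict)
for cluster, label, count in pairs:
    by_cluster[cluster][label] = count

dominant = {}
for cluster in sorted(by_cluster):
    labels = by_cluster[cluster]
    dominant[cluster] = max(labels, key=labels.get)
    share = labels[dominant[cluster]] / sum(labels.values())
    print("cluster {}: dominant label {} ({:.0%} of the listed rows)".format(
        cluster, dominant[cluster], share))

# Digits that are not the dominant label of any cluster.
never_dominant = sorted(set(range(10)) - set(dominant.values()))
print("labels without a cluster of their own:", never_dominant)
```

On these rows, the labels 1 and 7 each dominate two clusters, while 5 and 9 dominate none. This is a reminder that k-means groups images by pixel similarity, not by digit identity.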

<br><br>
To learn more about this example, see [the sagemaker-spark README on GitHub](https://github.com/aws/sagemaker-spark/blob/master/README.md).