# Start Spark Application

In [1]:
import org.apache.spark.ml._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

val s3Prefix = "s3://"
val s3BucketName = "dalin-ml-pipeline"
val dataSourcePath = "/transformed-csv/*.csv"

val sageMakerInputPrefix = "sagemaker/trainingInput"
val sageMakerOutputPrefix = "sagemaker/trainingOutput/XGBoost"
val sageMakerRoleArn = "arn:aws:iam::263690384742:role/SparkSageMakerRole"

val spark = SparkSession.builder.getOrCreate

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
0,application_1574882374057_0001,spark,idle,Link,Link,✔


import org.apache.spark.ml._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
s3Prefix: String = s3://
s3BucketName: String = dalin-ml-pipeline
dataSourcePath: String = /transformed-csv/*.csv
sageMakerInputPrefix: String = sagemaker/trainingInput
sageMakerOutputPrefix: String = sagemaker/trainingOutput/XGBoost
sageMakerRoleArn: String = arn:aws:iam::263690384742:role/SparkSageMakerRole
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@7e62ffb1


***
***
# Load Processed Data from S3 

In [2]:
import spark.implicits._
val originalDF = spark.read.format("csv").
    option("header", "true").
    load(s"$s3Prefix$s3BucketName$dataSourcePath").
    withColumnRenamed("TOTCHG", "label")

import spark.implicits._
originalDF: org.apache.spark.sql.DataFrame = [AGE: string, AGE_NEONATE: string ... 96 more fields]


In [3]:
originalDF.printSchema

root
 |-- AGE: string (nullable = true)
 |-- AGE_NEONATE: string (nullable = true)
 |-- AMONTH: string (nullable = true)
 |-- AWEEKEND: string (nullable = true)
 |-- DIED: string (nullable = true)
 |-- DISCWT: string (nullable = true)
 |-- DISPUNIFORM: string (nullable = true)
 |-- DQTR: string (nullable = true)
 |-- DRG: string (nullable = true)
 |-- DRGVER: string (nullable = true)
 |-- DRG_NoPOA: string (nullable = true)
 |-- DXVER: string (nullable = true)
 |-- ELECTIVE: string (nullable = true)
 |-- FEMALE: string (nullable = true)
 |-- HCUP_ED: string (nullable = true)
 |-- HOSP_DIVISION: string (nullable = true)
 |-- HOSP_NIS: string (nullable = true)
 |-- I10_DX1: string (nullable = true)
 |-- I10_DX2: string (nullable = true)
 |-- I10_DX3: string (nullable = true)
 |-- I10_DX4: string (nullable = true)
 |-- I10_DX5: string (nullable = true)
 |-- I10_DX6: string (nullable = true)
 |-- I10_DX7: string (nullable = true)
 |-- I10_DX8: string (nullable = true)
 |-- I10_DX9: string 

***
***
# Filter Data with Single Dianosis Category and Single Procedure
> <b>ICD-10-CM</b> refers to International Classification of Diseases, 10th Revision, Clinical Modification provided by the Centers for Medicare and Medicaid Services and the National Center for Health Statistics, for medical coding and reporting in the United States. (Wikipedia) <br>Similarly, <b>ICD-10-PCS</b> refers to the procedures. 

>The first three characters of an ICD-10 code designate the category of the diagnosis.

Link to [ICD Codes](https://icd.codes)

In [4]:
import spark.implicits._
val oneDiagDF = originalDF.filter($"I10_DX2".isNull && 
                                  $"I10_DX1".isNotNull && 
                                  $"I10_PR2".isNull && 
                                  $"I10_PR1".isNotNull)

import spark.implicits._
oneDiagDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AGE: string, AGE_NEONATE: string ... 96 more fields]


In [5]:
oneDiagDF.count

res5: Long = 68685


In [6]:
oneDiagDF.select("I10_DX1").groupBy("I10_DX1").count.sort(desc("count")).show

+-------+-----+
|I10_DX1|count|
+-------+-----+
|  Z3800|19258|
|  Z3801| 7376|
|  K3580| 4386|
|   P599| 2203|
|   K352| 1430|
|  M1711| 1230|
|  M1611| 1217|
|   K353| 1194|
|  M1712| 1098|
|  M1612| 1027|
|  K8000|  895|
|   Q400|  512|
|   M179|  506|
|S42412A|  344|
|  Z3831|  328|
|   K810|  274|
|   P593|  266|
|  K8012|  222|
|   P819|  220|
|  K3589|  217|
+-------+-----+
only showing top 20 rows



***
***
# Filter Data with Z38 Diagnosis Category
### (ICD-10-CM) Z38: Liveborn infants according to place of birth and type of delivery

In [7]:
val z38DF = oneDiagDF.filter($"I10_DX1".contains("Z38"))
z38DF.count

z38DF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AGE: string, AGE_NEONATE: string ... 96 more fields]
res7: Long = 27118


***
***
# Convert and Clean Data
### Convert all 14 numeric columns to double type and remove invalid/missing values

In [8]:
val numericColumns = Array("label","AGE","AGE_NEONATE","AMONTH","AWEEKEND",
                           "DIED","DQTR","ELECTIVE","FEMALE","HCUP_ED",
                           "I10_NDX","I10_NECAUSE","I10_NPR","LOS")

val toDouble = udf((s: String, lowerLimit: Int) => {
  if (s != null && (s forall Character.isDigit) && s.toDouble >= lowerLimit) s.toDouble else -1
})

var convertedDF = z38DF
for (colName <- numericColumns) {
  convertedDF = convertedDF.withColumn(colName, toDouble(col(colName), lit(0)))
}
var filteredDF = convertedDF
for (colName <- numericColumns) {
  filteredDF = filteredDF.filter(col(colName) >= 0)
}

filteredDF.persist

numericColumns: Array[String] = Array(label, AGE, AGE_NEONATE, AMONTH, AWEEKEND, DIED, DQTR, ELECTIVE, FEMALE, HCUP_ED, I10_NDX, I10_NECAUSE, I10_NPR, LOS)
toDouble: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,DoubleType,Some(List(StringType, IntegerType)))
convertedDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AGE: string, AGE_NEONATE: string ... 96 more fields]
filteredDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AGE: double, AGE_NEONATE: double ... 96 more fields]
res13: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AGE: double, AGE_NEONATE: double ... 96 more fields]


In [9]:
filteredDF.count

res14: Long = 26926


In [10]:
filteredDF.select("label").describe().show

+-------+-----------------+
|summary|            label|
+-------+-----------------+
|  count|            26926|
|   mean|4949.703706454728|
| stddev|15824.61485217541|
|    min|            125.0|
|    max|         984109.0|
+-------+-----------------+



***
***
# Select Features
### Simple Correlation-based Selection

In [11]:
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.ChiSquareTest

for (colName <- numericColumns) {
  println(s"with ${colName}: " + filteredDF.stat.corr("label", colName))
}

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.ChiSquareTest
with label: 1.0
with AGE: NaN
with AGE_NEONATE: -4.992307632583083E-4
with AMONTH: 0.022599905130392093
with AWEEKEND: -0.010414153296519459
with DIED: 0.09640357162165143
with DQTR: 0.016621095690527915
with ELECTIVE: -6.748808916453236E-4
with FEMALE: 0.043313805347155614
with HCUP_ED: 0.010142957421192722
with I10_NDX: NaN
with I10_NECAUSE: 0.1482466540380826
with I10_NPR: NaN
with LOS: 0.8254548216058005


***
***
# Understand Selected Features<br>
## Nominal Feature: ICD-10CM
### Subcategory of Z38

1. Z3800: Single liveborn infant, delivered vaginally
2. Z3801: Single liveborn infant, delivered by cesarean
3. Z3831: Twin liveborn infant, delivered by cesarean
4. Z3830: Twin liveborn infant, delivered vaginally
5. Z381: Single liveborn infant, born outside hospital
6. Z382: Single liveborn infant, unspecified as to place of birth
7. Z3862: Triplet liveborn infant, delivered by cesarean
8. Z384: Twin liveborn infant, born outside hospital

In [12]:
filteredDF.select("I10_DX1").groupBy("I10_DX1").count.sort(desc("count")).show

+-------+-----+
|I10_DX1|count|
+-------+-----+
|  Z3800|19117|
|  Z3801| 7328|
|  Z3831|  325|
|  Z3830|  104|
|   Z381|   40|
|   Z382|    6|
|  Z3862|    5|
|   Z384|    1|
+-------+-----+



***
## Nominal Feature: ICD-10-PCS
### List of Single Procedure Associated with Z38
### Top 5 ICD-10-PCS Procedures
1. 0VTTXZZ: Resection of Prepuce, External Approach
2. F13Z0ZZ: Hearing Screening Assessment
3. 3E0234Z: Introduction of Serum, Toxoid and Vaccine into Muscle, Percutaneous Approach
4. F13ZM6Z: Evoked Otoacoustic Emissions, Screening Assessment using Otoacoustic Emission (OAE) Equipment
5. F13Z01Z: Hearing Screening Assessment using Audiometer

<p><br><br>Some of the other main procedure categories include:</p>

- B24: Ultrasonography
- 0VJ: Inspection
- 4A0: Measurement
- 5A0: Assistance
- 6A6: Phototherapy
- 6A8: Ultraviolet Light Therapy

In [13]:
filteredDF.select("I10_PR1").groupBy("I10_PR1").count.sort(desc("count")).show(81, false)

+-------+-----+
|I10_PR1|count|
+-------+-----+
|0VTTXZZ|21085|
|F13Z0ZZ|2040 |
|3E0234Z|1644 |
|F13ZM6Z|888  |
|F13Z01Z|297  |
|3E0134Z|185  |
|F13ZL7Z|162  |
|3E00X4Z|78   |
|F13ZLZZ|72   |
|F13ZN6Z|48   |
|5A09357|45   |
|F13Z3ZZ|40   |
|6A600ZZ|40   |
|6A601ZZ|38   |
|F13ZMZZ|28   |
|0VBTXZZ|26   |
|3E023GC|17   |
|F13ZQKZ|17   |
|6A801ZZ|14   |
|3E0334Z|13   |
|8E0KXY7|13   |
|0CN7XZZ|12   |
|0VTT0ZZ|11   |
|0BH17EZ|8    |
|F13Z08Z|7    |
|B24DZZZ|7    |
|3E0F7GC|6    |
|6A800ZZ|5    |
|4A03XR1|4    |
|F13ZQZZ|4    |
|5A09457|4    |
|4A05XLZ|4    |
|069Y3ZZ|3    |
|0BH18EZ|3    |
|3E0434Z|2    |
|5A0935Z|2    |
|0CB7XZZ|2    |
|B24BZZZ|2    |
|3E0236Z|2    |
|06HY33Z|2    |
|3E0336Z|2    |
|5A09557|2    |
|0VTT4ZZ|2    |
|0YQA0ZZ|2    |
|5A1935Z|2    |
|0VNTXZZ|1    |
|4A043R1|1    |
|03973ZZ|1    |
|021Q0JB|1    |
|0VJSXZZ|1    |
|059Y3ZZ|1    |
|5A0955Z|1    |
|3E0604Z|1    |
|3E02329|1    |
|0WQF0ZZ|1    |
|5A2204Z|1    |
|F13ZNZZ|1    |
|10D00Z1|1    |
|3E1H38Z|1    |
|BT0BYZZ

In [14]:
filteredDF.select("I10_PR1").distinct.count

res20: Long = 81


***
## Nominal Feature: HOSP_DIVISION
### Census Division of Hospital
- Division 1 (New England): Maine, New Hampshire, Vermont, Massachusetts, Rhode Island, Connecticut
- Division 2 (Mid-Atlantic): New York, Pennsylvania, New Jersey
- Division 3 (East North Central): Wisconsin, Michigan, Illinois, Indiana, Ohio
- Division 4 (West North Central): Missouri, North Dakota, South Dakota, Nebraska, Kansas, Minnesota, Iowa
- Division 5 (South Atlantic): Delaware, Maryland, District of Columbia, Virginia, West Virginia, North Carolina, South Carolina, Georgia, Florida
- Division 6 (East South Central) Kentucky, Tennessee, Mississippi, Alabama
- Division 7 (West South Central) Oklahoma, Texas, Arkansas, Louisiana
- Division 8 (Mountain) Idaho, Montana, Wyoming, Nevada, Utah, Colorado, Arizona, New Mexico
- Division 9 (Pacific) Alaska, Washington, Oregon, California, Hawaii

In [15]:
filteredDF.select("HOSP_DIVISION").groupBy("HOSP_DIVISION").count.sort(desc("count")).show

+-------------+-----+
|HOSP_DIVISION|count|
+-------------+-----+
|            3| 6494|
|            5| 5406|
|            2| 4654|
|            4| 3483|
|            7| 2548|
|            6| 1790|
|            9| 1057|
|            8|  995|
|            1|  499|
+-------------+-----+



***
## Numeric Feature: LOS
### Length of Stay
<p>Calculated by subtracting the admission date from the discharge date</p>

In [16]:
filteredDF.select("LOS").describe().show

+-------+------------------+
|summary|               LOS|
+-------+------------------+
|  count|             26926|
|   mean|2.0655500259971773|
| stddev| 2.062796005750767|
|    min|               0.0|
|    max|             182.0|
+-------+------------------+



In [17]:
filteredDF.filter($"LOS" > 30).select("I10_DX1","I10_PR1","HOSP_DIVISION","LOS").show

+-------+-------+-------------+-----+
|I10_DX1|I10_PR1|HOSP_DIVISION|  LOS|
+-------+-------+-------------+-----+
|  Z3831|5A09557|            5| 38.0|
|   Z384|0BH18EZ|            5| 53.0|
|  Z3862|04HY33Z|            5| 55.0|
|  Z3801|3E0F7GC|            5|112.0|
|  Z3831|0BH17EZ|            5| 78.0|
|  Z3801|3E0F7GC|            5|182.0|
|  Z3801|0VTTXZZ|            5| 46.0|
|  Z3801|6A601ZZ|            5| 32.0|
|  Z3831|6A601ZZ|            5| 65.0|
|  Z3801|06H033T|            2| 90.0|
|  Z3831|0YQA0ZZ|            2| 75.0|
|  Z3862|02HV33Z|            2| 72.0|
|  Z3800|021Q0JB|            2| 38.0|
|  Z3800|6A601ZZ|            2| 62.0|
|  Z3831|0BH17EZ|            2| 38.0|
|  Z3831|6A601ZZ|            2| 59.0|
|  Z3801|5A09557|            2| 52.0|
+-------+-------+-------------+-----+



***
***
# Prepare Data using Spark Pipeline

- indexes nominal features and maps each of those featurse to a binary vector (sparse vector).
- combines the list of feature columns into a single vector column.

In [18]:
import scala.collection.mutable.ArrayBuffer

// array of Spark Pipeline Stages
var stages = ArrayBuffer[PipelineStage]()

// transforms nominal features
val charColumns = Array("I10_DX1","I10_PR1","HOSP_DIVISION")
for(colName <- charColumns) {
  var indexer = new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName+"_IND")
  var encoder = new OneHotEncoderEstimator()
    .setInputCols(Array(indexer.getOutputCol))
    .setOutputCols(Array(colName+"_ENC"))
    .setHandleInvalid("keep")
  
  stages += indexer
  stages += encoder
}

// final feature columns
val featureColumns = Array("LOS","I10_DX1_ENC","I10_PR1_ENC","HOSP_DIVISION_ENC")

val assembler = new VectorAssembler().
    setInputCols(featureColumns).
    setOutputCol("features")

stages += assembler

import scala.collection.mutable.ArrayBuffer
stages: scala.collection.mutable.ArrayBuffer[org.apache.spark.ml.PipelineStage] = ArrayBuffer()
charColumns: Array[String] = Array(I10_DX1, I10_PR1, HOSP_DIVISION)
featureColumns: Array[String] = Array(LOS, I10_DX1_ENC, I10_PR1_ENC, HOSP_DIVISION_ENC)
assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_51844443de59
res33: scala.collection.mutable.ArrayBuffer[org.apache.spark.ml.PipelineStage] = ArrayBuffer(strIdx_9f05a57541ae, oneHotEncoder_488bc5503659, strIdx_8c860ec532a1, oneHotEncoder_51926bfde7c9, strIdx_5d6f7f011aa8, oneHotEncoder_940669a4f1d3, vecAssembler_51844443de59)


In [19]:
val pipeline = new Pipeline().setStages(stages.toArray)
val preparedDF = pipeline.fit(filteredDF).transform(filteredDF)

pipeline: org.apache.spark.ml.Pipeline = pipeline_46c3cabd7a3a
preparedDF: org.apache.spark.sql.DataFrame = [AGE: double, AGE_NEONATE: double ... 103 more fields]


In [20]:
preparedDF.select("label","features").show(false)

+------+---------------------------------+
|label |features                         |
+------+---------------------------------+
|5933.0|(99,[0,2,9,90],[3.0,1.0,1.0,1.0])|
|4641.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4839.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4179.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4179.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|2329.0|(99,[0,1,9,90],[1.0,1.0,1.0,1.0])|
|4259.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4203.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4135.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4647.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4179.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|2650.0|(99,[0,1,9,90],[1.0,1.0,1.0,1.0])|
|2444.0|(99,[0,1,9,90],[1.0,1.0,1.0,1.0])|
|4659.0|(99,[0,2,9,90],[2.0,1.0,1.0,1.0])|
|4573.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4683.0|(99,[0,1,9,90],[2.0,1.0,1.0,1.0])|
|4675.0|(99,[0,2,9,90],[2.0,1.0,1.0,1.0])|
|2440.0|(99,[0,1,9,90],[1.0,1.0,1.0,1.0])|
|2222.0|(99,[0,1,9,90],[1.0,1.0,1.0,1.0])|
|3506.0|(99,[0,2,9,90],[2.0,1.0,1.0,1.0])|
+------+---

***
***
### Check String Index and Encoded Vector for Nominal Features

In [21]:
preparedDF.select("I10_DX1","I10_DX1_IND","I10_DX1_ENC").dropDuplicates.show

+-------+-----------+-------------+
|I10_DX1|I10_DX1_IND|  I10_DX1_ENC|
+-------+-----------+-------------+
|   Z382|        5.0|(8,[5],[1.0])|
|  Z3862|        6.0|(8,[6],[1.0])|
|   Z381|        4.0|(8,[4],[1.0])|
|  Z3800|        0.0|(8,[0],[1.0])|
|  Z3801|        1.0|(8,[1],[1.0])|
|   Z384|        7.0|(8,[7],[1.0])|
|  Z3831|        2.0|(8,[2],[1.0])|
|  Z3830|        3.0|(8,[3],[1.0])|
+-------+-----------+-------------+



In [22]:
preparedDF.select("I10_PR1","I10_PR1_IND","I10_PR1_ENC").dropDuplicates.show

+-------+-----------+---------------+
|I10_PR1|I10_PR1_IND|    I10_PR1_ENC|
+-------+-----------+---------------+
|F13ZL7Z|        6.0| (81,[6],[1.0])|
|059Y3ZZ|       67.0|(81,[67],[1.0])|
|5A0955Z|       78.0|(81,[78],[1.0])|
|009U3ZX|       48.0|(81,[48],[1.0])|
|3E023GC|       17.0|(81,[17],[1.0])|
|10D00Z1|       63.0|(81,[63],[1.0])|
|5A09457|       31.0|(81,[31],[1.0])|
|00U107Z|       76.0|(81,[76],[1.0])|
|0VNTXZZ|       69.0|(81,[69],[1.0])|
|3E02329|       61.0|(81,[61],[1.0])|
|3E0504Z|       54.0|(81,[54],[1.0])|
|F13ZNZZ|       53.0|(81,[53],[1.0])|
|4A02XFZ|       73.0|(81,[73],[1.0])|
|5A09358|       72.0|(81,[72],[1.0])|
|8E0KXY7|       19.0|(81,[19],[1.0])|
|5A2204Z|       66.0|(81,[66],[1.0])|
|0H51XZZ|       49.0|(81,[49],[1.0])|
|4A03XR1|       28.0|(81,[28],[1.0])|
|5A12012|       74.0|(81,[74],[1.0])|
|0VBTXZZ|       15.0|(81,[15],[1.0])|
+-------+-----------+---------------+
only showing top 20 rows



In [23]:
preparedDF.select("HOSP_DIVISION","HOSP_DIVISION_IND","HOSP_DIVISION_ENC").dropDuplicates.show

+-------------+-----------------+-----------------+
|HOSP_DIVISION|HOSP_DIVISION_IND|HOSP_DIVISION_ENC|
+-------------+-----------------+-----------------+
|            2|              2.0|    (9,[2],[1.0])|
|            6|              5.0|    (9,[5],[1.0])|
|            7|              4.0|    (9,[4],[1.0])|
|            8|              7.0|    (9,[7],[1.0])|
|            9|              6.0|    (9,[6],[1.0])|
|            3|              0.0|    (9,[0],[1.0])|
|            4|              3.0|    (9,[3],[1.0])|
|            5|              1.0|    (9,[1],[1.0])|
|            1|              8.0|    (9,[8],[1.0])|
+-------------+-----------------+-----------------+



***
***
# Split Training and Testing Data

In [24]:
val randomSplitDSs: Array[Dataset[Row]] = preparedDF.select("label","features").randomSplit(Array(0.7, 0.3), 11)
val trainingDS: Dataset[Row] = randomSplitDSs(0)
val testingDS: Dataset[Row] = randomSplitDSs(1)

randomSplitDSs: Array[org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]] = Array([label: double, features: vector], [label: double, features: vector])
trainingDS: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
testingDS: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]


In [25]:
testingDS.select("label","features").
  limit(100).
  write.
  format("libsvm").
  option("header", "true").
  mode("overwrite").
  save("s3://dalin-ml-pipeline/testing")

***
***
# Train and Build SageMaker Model

In [26]:
import com.amazonaws.services.sagemaker.sparksdk.{CustomNamePolicyFactory, EndpointCreationPolicy, IAMRole, S3DataPath, SageMakerModel}
import com.amazonaws.services.sagemaker.sparksdk.algorithms.XGBoostSageMakerEstimator

import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
val uid = DateTimeFormatter.ofPattern("yyyyMMddHHmm").format(LocalDateTime.now)

val xgboostEstimator = new XGBoostSageMakerEstimator(
  sagemakerRole=IAMRole(sageMakerRoleArn),
  trainingInstanceType = "ml.m5.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.m5.xlarge",
  endpointInitialInstanceCount = 1,
  trainingInputS3DataPath = S3DataPath(s3BucketName, sageMakerInputPrefix),
  trainingOutputS3DataPath = S3DataPath(s3BucketName, sageMakerOutputPrefix),
  endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_CONSTRUCT, // DO_NOT_CREATE
  namePolicyFactory = new CustomNamePolicyFactory(s"Z38-training-$uid",
                                                  s"Z38-model-$uid",
                                                  "Z38-endpointConfig",
                                                  "Z38-endpoint")
)
xgboostEstimator.setNumRound(15)
xgboostEstimator.setObjective("reg:linear")

import com.amazonaws.services.sagemaker.sparksdk.{CustomNamePolicyFactory, EndpointCreationPolicy, IAMRole, S3DataPath, SageMakerModel}
import com.amazonaws.services.sagemaker.sparksdk.algorithms.XGBoostSageMakerEstimator
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
uid: String = 201911271926
xgboostEstimator: com.amazonaws.services.sagemaker.sparksdk.algorithms.XGBoostSageMakerEstimator = sagemaker_c7d3a501ffcb
res41: xgboostEstimator.type = sagemaker_c7d3a501ffcb
res42: xgboostEstimator.type = sagemaker_c7d3a501ffcb


In [27]:
val model: SageMakerModel = xgboostEstimator.fit(trainingDS)

model: com.amazonaws.services.sagemaker.sparksdk.SageMakerModel = sagemaker_c7d3a501ffcb


***
***
# Evaluate Model Performance
### Using Spark RegressionMetrics

In [28]:
val predictions: DataFrame = model.transform(testingDS)
// predictions.show(50)
val predictionAndLabels: RDD[(Double, Double)] = predictions.select($"label",$"prediction").as[(Double, Double)].rdd

predictions: org.apache.spark.sql.DataFrame = [label: double, features: vector ... 1 more field]
predictionAndLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[282] at rdd at <console>:57


In [29]:
import org.apache.spark.mllib.evaluation.RegressionMetrics

// Instantiate metrics object
val metrics = new RegressionMetrics(predictionAndLabels)

// Squared error
println(s"MSE = ${metrics.meanSquaredError}")
println(s"RMSE = ${metrics.rootMeanSquaredError}")

// R-squared
println(s"R-squared = ${metrics.r2}")

// Mean absolute error
println(s"MAE = ${metrics.meanAbsoluteError}")

import org.apache.spark.mllib.evaluation.RegressionMetrics
metrics: org.apache.spark.mllib.evaluation.RegressionMetrics = org.apache.spark.mllib.evaluation.RegressionMetrics@37c2c797
MSE = 5.003855044569909E7
RMSE = 7073.793214796365
R-squared = 0.7833784302303666
MAE = 1767.2243561865569


# Clean up SageMaker Resources
#### Deletes SageMakerModel, Endpoint Configuration, and Endpoint created by the SageMakerModel

In [30]:
import com.amazonaws.services.sagemaker.sparksdk.SageMakerResourceCleanup

val resource_cleanup = new SageMakerResourceCleanup(model.sagemakerClient)
resource_cleanup.deleteResources(model.getCreatedResources)

import com.amazonaws.services.sagemaker.sparksdk.SageMakerResourceCleanup
resource_cleanup: com.amazonaws.services.sagemaker.sparksdk.SageMakerResourceCleanup = com.amazonaws.services.sagemaker.sparksdk.SageMakerResourceCleanup@2b3c0cb1
