In [17]:
## Import Libraries
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

## Set seed
# Single random seed for the notebook; it is passed to randomSplit() when the
# train/test split is made further down.
seed = 42

In [18]:
## Create Spark Session
# Configure the builder with this notebook's application name, then obtain the
# session from it (getOrCreate hands back an existing session when one is active).
session_builder = SparkSession.builder.appName('dtRfExample')
spark = session_builder.getOrCreate()

In [19]:
## Setup Schema
# Explicit column types for the College data set, listed in the order the
# columns appear in the loaded DataFrame. Two string columns (school name and
# the Yes/No 'private' flag), one double (s_f_ratio), everything else integer.
_column_types = [
    ('school', StringType),
    ('private', StringType),
    ('apps', IntegerType),
    ('accept', IntegerType),
    ('enroll', IntegerType),
    ('top_10_perc', IntegerType),
    ('top_25_perc', IntegerType),
    ('f_undergrad', IntegerType),
    ('p_undergrad', IntegerType),
    ('outstate', IntegerType),
    ('room_board', IntegerType),
    ('books', IntegerType),
    ('personal', IntegerType),
    ('phd', IntegerType),
    ('terminal', IntegerType),
    ('s_f_ratio', DoubleType),
    ('perc_alumni', IntegerType),
    ('expend', IntegerType),
    ('grad_rate', IntegerType),
]

# Every field is declared nullable (third StructField argument is True).
schema = StructType(fields=[StructField(col_name, col_type(), True)
                            for col_name, col_type in _column_types])

In [20]:
## Load Data
# Read the College CSV with the explicit schema defined above; the file has a
# header row and type inference is switched off because the schema is supplied.
college_reader = spark.read.options(header=True, inferSchema=False).schema(schema)
df = college_reader.csv('gs://spark-training-data/datasets/College.csv')

# Preview a few rows, then print the column types to confirm the schema was applied.
df.show(5)
df.printSchema() ## Confirm proper schema

+--------------------+-------+----+------+------+-----------+-----------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+
|              school|private|apps|accept|enroll|top_10_perc|top_25_perc|f_undergrad|p_undergrad|outstate|room_board|books|personal|phd|terminal|s_f_ratio|perc_alumni|expend|grad_rate|
+--------------------+-------+----+------+------+-----------+-----------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+
|Abilene Christian...|    Yes|1660|  1232|   721|         23|         52|       2885|        537|    7440|      3300|  450|    2200| 70|      78|     18.1|         12|  7041|       60|
|  Adelphi University|    Yes|2186|  1924|   512|         16|         29|       2683|       1227|   12280|      6450|  750|    1500| 29|      30|     12.2|         16| 10527|       56|
|      Adrian College|    Yes|1428|  1097|   336|         22|         50|  

In [21]:
## Convert private column to an index
# Encode the string 'private' column as a numeric 'private_index' column so it
# can be used as the classification label. In the preview below, rows with
# private == 'Yes' carry private_index 0.0.
indexer = StringIndexer(inputCol='private', outputCol='private_index')
indexer_model = indexer.fit(df)
df_indexed = indexer_model.transform(df)
df_indexed.show(5)

+--------------------+-------+----+------+------+-----------+-----------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+-------------+
|              school|private|apps|accept|enroll|top_10_perc|top_25_perc|f_undergrad|p_undergrad|outstate|room_board|books|personal|phd|terminal|s_f_ratio|perc_alumni|expend|grad_rate|private_index|
+--------------------+-------+----+------+------+-----------+-----------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+-------------+
|Abilene Christian...|    Yes|1660|  1232|   721|         23|         52|       2885|        537|    7440|      3300|  450|    2200| 70|      78|     18.1|         12|  7041|       60|          0.0|
|  Adelphi University|    Yes|2186|  1924|   512|         16|         29|       2683|       1227|   12280|      6450|  750|    1500| 29|      30|     12.2|         16| 10527|       56|          0.0|
|    

In [22]:
## Assembler & Create modeling df
# Numeric predictor columns, in the order they are packed into the feature
# vector. 'school' (identifier) and 'private' (the target) are left out.
feature_cols = [
    'apps', 'accept', 'enroll', 'top_10_perc', 'top_25_perc',
    'f_undergrad', 'p_undergrad', 'outstate', 'room_board',
    'books', 'personal', 'phd', 'terminal', 's_f_ratio',
    'perc_alumni', 'expend', 'grad_rate',
]

# Combine the predictors into a single vector column named 'features'.
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
output_features = assembler.transform(df_indexed)

# Show the first row to check the assembled vector against its source columns.
output_features.head(1)

[Row(school='Abilene Christian University', private='Yes', apps=1660, accept=1232, enroll=721, top_10_perc=23, top_25_perc=52, f_undergrad=2885, p_undergrad=537, outstate=7440, room_board=3300, books=450, personal=2200, phd=70, terminal=78, s_f_ratio=18.1, perc_alumni=12, expend=7041, grad_rate=60, private_index=0.0, features=DenseVector([1660.0, 1232.0, 721.0, 23.0, 52.0, 2885.0, 537.0, 7440.0, 3300.0, 450.0, 2200.0, 70.0, 78.0, 18.1, 12.0, 7041.0, 60.0]))]

In [23]:
## Setup Final Data
# Keep only what the classifiers consume: the feature vector and the numeric label.
final_data = output_features.select('features', 'private_index')
final_data.show(5)

+--------------------+-------------+
|            features|private_index|
+--------------------+-------------+
|[1660.0,1232.0,72...|          0.0|
|[2186.0,1924.0,51...|          0.0|
|[1428.0,1097.0,33...|          0.0|
|[417.0,349.0,137....|          0.0|
|[193.0,146.0,55.0...|          0.0|
+--------------------+-------------+
only showing top 5 rows



In [24]:
## Split into train, test
# 70% of rows for training, 30% held out for testing; the split is seeded with
# the notebook-level seed.
split_weights = [0.7, 0.3]
data_splits = final_data.randomSplit(split_weights, seed=seed)
train_data, test_data = data_splits

In [25]:
## Setup Classification Models & fit training data
# Three tree-based classifiers trained on the same features/label so their
# test-set metrics can be compared side by side.
#
# The notebook-level seed is passed to every estimator. Random forest and
# gradient-boosted trees involve random sampling during training; without an
# explicit seed PySpark falls back to a default derived from hashing the class
# name, which is not guaranteed to be stable across Python processes, so a
# Restart & Run All could otherwise produce different models and metrics.

# Single decision tree (baseline).
dtc = DecisionTreeClassifier(labelCol='private_index', featuresCol='features', seed=seed)
dtc_model = dtc.fit(train_data)

# Random forest.
rfc = RandomForestClassifier(labelCol='private_index', featuresCol='features', seed=seed)
rfc_model = rfc.fit(train_data)

# Gradient-boosted trees.
gbt = GBTClassifier(labelCol='private_index', featuresCol='features', seed=seed)
gbt_model = gbt.fit(train_data)

In [26]:
## Make Predictions for test data
# Score the held-out rows with each fitted model. transform() keeps the input
# columns and appends rawPrediction, probability and prediction (see output).
dtc_preds, rfc_preds, gbt_preds = (
    fitted_model.transform(test_data)
    for fitted_model in (dtc_model, rfc_model, gbt_model)
)

# Preview five scored rows per model, in the order DTC, RFC, GBT.
for preds in (dtc_preds, rfc_preds, gbt_preds):
    preds.show(5)

+--------------------+-------------+-------------+--------------------+----------+
|            features|private_index|rawPrediction|         probability|prediction|
+--------------------+-------------+-------------+--------------------+----------+
|[141.0,118.0,55.0...|          0.0|  [305.0,0.0]|           [1.0,0.0]|       0.0|
|[174.0,146.0,88.0...|          0.0|   [20.0,0.0]|           [1.0,0.0]|       0.0|
|[193.0,146.0,55.0...|          0.0|   [13.0,5.0]|[0.72222222222222...|       0.0|
|[202.0,184.0,122....|          0.0|  [305.0,0.0]|           [1.0,0.0]|       0.0|
|[222.0,185.0,91.0...|          0.0|  [305.0,0.0]|           [1.0,0.0]|       0.0|
+--------------------+-------------+-------------+--------------------+----------+
only showing top 5 rows

+--------------------+-------------+--------------------+--------------------+----------+
|            features|private_index|       rawPrediction|         probability|prediction|
+--------------------+-------------+------------

In [27]:
## Evaluate Models using Binary
# No metricName is given, so the evaluator reports its default metric
# (areaUnderROC according to the PySpark documentation).
my_binary_eval = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction',
                                               labelCol='private_index')

# Score each model's test-set predictions first, then report them together.
dtc_score = my_binary_eval.evaluate(dtc_preds)
rfc_score = my_binary_eval.evaluate(rfc_preds)
gbt_score = my_binary_eval.evaluate(gbt_preds)

print(f'DTC Eval: {dtc_score}')
print(f'RFC Eval: {rfc_score}')
print(f'GBT Eval: {gbt_score}')

DTC Eval: 0.9417249417249418
RFC Eval: 0.9730707888602627
GBT Eval: 0.9623359097043309


In [29]:
## Evaluate Models using Multi
# Accuracy on the test set: compares the 'prediction' column with the
# 'private_index' label for each model's scored rows.
accuracy_eval = MulticlassClassificationEvaluator(metricName='accuracy',
                                                  predictionCol='prediction',
                                                  labelCol='private_index')

# Compute all three accuracies, then print them in the order DTC, RFC, GBT.
dtc_accuracy = accuracy_eval.evaluate(dtc_preds)
rfc_accuracy = accuracy_eval.evaluate(rfc_preds)
gbt_accuracy = accuracy_eval.evaluate(gbt_preds)

print(f'DTC Accuracy: {dtc_accuracy}')
print(f'RFC Accuracy: {rfc_accuracy}')
print(f'GBT Accuracy: {gbt_accuracy}')

DTC Accuracy: 0.925
RFC Accuracy: 0.95
GBT Accuracy: 0.94
