In [None]:
import os

from pyspark.sql import SparkSession

# Directory holding the tree-methods CSV files (relative to the kernel's working directory).
data_dir = '../../data/Spark_ML/Tree_Methods/'

In [None]:
# Start a Spark session for this notebook (or reuse one that already exists).
spark = (
    SparkSession.builder
    .appName('tree_code_along')
    .getOrCreate()
)

# Load College.csv: the first row supplies column names and Spark infers column types.
college_path = os.path.join(data_dir, 'College.csv')
data = spark.read.csv(college_path, header=True, inferSchema=True)

data.printSchema()

## Format the data for MLlib

In [None]:
from pyspark.ml.feature import VectorAssembler

# Predictor columns packed into a single 'features' vector for MLlib.
# VectorAssembler needs these to be numeric; the 'Private' label is indexed separately.
# Order matters: it fixes the position of each column inside the feature vector.
col_names = [
    'Apps', 'Accept', 'Enroll',
    'Top10perc', 'Top25perc',
    'F_Undergrad', 'P_Undergrad',
    'Outstate', 'Room_Board', 'Books', 'Personal',
    'PhD', 'Terminal', 'S_F_Ratio',
    'perc_alumni', 'Expend', 'Grad_Rate',
]

assembler = VectorAssembler(outputCol='features', inputCols=col_names)
output = assembler.transform(data)

In [None]:
from pyspark.ml.feature import StringIndexer

# The classifiers need a numeric label, so map the string 'Private' column to an index.
indexer = StringIndexer(inputCol='Private', outputCol='PrivateIndex')
indexer_model = indexer.fit(output)
output_fixed = indexer_model.transform(output)

output_fixed.printSchema()

In [None]:
# Keep only what the classifiers consume: the feature vector and the numeric label.
final_data = output_fixed.select('features', 'PrivateIndex')

## Train and Evaluate

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier, GBTClassifier, RandomForestClassifier
from pyspark.ml import Pipeline  # NOTE(review): not used by any cell shown here

# Fixed seed so the 70/30 split and the learners' internal randomness are reproducible.
# Without it, randomSplit picks a fresh random seed on every run.
SEED = 42

tr_data, te_data = final_data.randomSplit([0.7, 0.3], seed=SEED)

dtc = DecisionTreeClassifier(featuresCol='features', labelCol='PrivateIndex', seed=SEED)
rfc = RandomForestClassifier(featuresCol='features', labelCol='PrivateIndex', seed=SEED)
gbt = GBTClassifier(featuresCol='features', labelCol='PrivateIndex', seed=SEED)

dtc_model = dtc.fit(tr_data)
rfc_model = rfc.fit(tr_data)
gbt_model = gbt.fit(tr_data)


In [None]:
# Score the held-out test split with each fitted model (transform is lazy).
dtc_preds, rfc_preds, gbt_preds = (
    fitted.transform(te_data) for fitted in (dtc_model, rfc_model, gbt_model)
)

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Test-set area under the ROC curve, computed from each model's rawPrediction scores.
bin_evaluator = BinaryClassificationEvaluator(labelCol='PrivateIndex', metricName='areaUnderROC')

# Older Spark releases' GBTClassifier did not emit a rawPrediction column, hence the old
# workaround of scoring on 'prediction'. An AUC computed from hard 0/1 labels is coarser and
# not comparable to the DTC/RFC values, so only fall back to it when rawPrediction is absent.
gbt_raw_col = 'rawPrediction' if 'rawPrediction' in gbt_preds.columns else 'prediction'
bin_evaluator_gbt = BinaryClassificationEvaluator(labelCol='PrivateIndex',
                                                  rawPredictionCol=gbt_raw_col,
                                                  metricName='areaUnderROC')

print('DTC: {}'.format(bin_evaluator.evaluate(dtc_preds)))
print('RFC: {}'.format(bin_evaluator.evaluate(rfc_preds)))
print('GBT: {}'.format(bin_evaluator_gbt.evaluate(gbt_preds)))

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Test-set accuracy. Accuracy only needs the hard 'prediction' column, so one evaluator
# serves all three models; the previous GBT-specific copy had an identical configuration.
multi_evaluator = MulticlassClassificationEvaluator(labelCol='PrivateIndex', metricName='accuracy')

for model_name, preds in [('DTC', dtc_preds), ('RFC', rfc_preds), ('GBT', gbt_preds)]:
    print('{}: {}'.format(model_name, multi_evaluator.evaluate(preds)))