In [1]:
from pyspark.sql import SparkSession
# Build (or reuse) the notebook's SparkSession under the app name 'mytree'.
spark_builder = SparkSession.builder.appName('mytree')
spark = spark_builder.getOrCreate()

In [2]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier, DecisionTreeClassifier)

In [3]:
# The libsvm source yields 'label' and 'features' columns, which match the estimators' default column names below.
data = spark.read.load('FileStore/tables/sample_libsvm_data.txt', format='libsvm')

In [4]:
# Preview the first 20 rows (DataFrame.show's default row count, spelled out).
data.show(n=20)

In [5]:
# 70/30 train/test split. The fixed seed makes the split, and every metric computed
# from it, repeatable on Restart & Run All (given the same input partitioning).
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [6]:
# Default labelCol='label' / featuresCol='features' match the libsvm columns.
# Explicit seeds make the random forest's random sampling (and any GBT subsampling)
# repeatable across runs instead of relying on an implicit default seed.
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees = 100, seed = 42)
gbt = GBTClassifier(seed = 42)

In [7]:
# Fit each estimator on the training split, in the same order (DTC, RFC, GBT).
dtc_model, rfc_model, gbt_model = (estimator.fit(train_data) for estimator in (dtc, rfc, gbt))

In [8]:
# Score the held-out split with each fitted model (DTC, RFC, GBT).
dtc_preds, rfc_preds, gbt_preds = [fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)]

In [9]:
# Decision-tree predictions on the test split (first 20 rows).
dtc_preds.show(n=20)

In [10]:
# Random-forest predictions on the test split (first 20 rows).
rfc_preds.show(n=20)

In [11]:
# Gradient-boosted-tree predictions on the test split (first 20 rows).
gbt_preds.show(n=20)

In [12]:
# MulticlassClassificationEvaluator also works for binary labels and offers metrics that BinaryClassificationEvaluator lacks (e.g. accuracy, weightedRecall, f1), so it is used for all three models here. It does not report threshold-free metrics such as areaUnderROC; those still require BinaryClassificationEvaluator.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [13]:
# Accuracy evaluator; label/prediction column names are the defaults, written out for clarity.
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy', labelCol='label', predictionCol='prediction')

In [14]:
print('DTC ACCURACY:')
# Bind the score, then display it as the cell's output.
dtc_accuracy = acc_eval.evaluate(dataset=dtc_preds)
dtc_accuracy

In [15]:
print('RFC ACCURACY:')
# Bind the score, then display it as the cell's output.
rfc_accuracy = acc_eval.evaluate(dataset=rfc_preds)
rfc_accuracy

In [16]:
print('GBT ACCURACY:')
# Bind the score, then display it as the cell's output.
gbt_accuracy = acc_eval.evaluate(dataset=gbt_preds)
gbt_accuracy

In [17]:
# Per-feature importance vector of the fitted random forest (indexed by feature position).
rfc_model.featureImportances

In [18]:
# A more realistic example: tree-based classifiers on the College dataset.
# getOrCreate() returns the session already started in the first cell, so passing a new
# appName here would not rename or restart anything. Reuse the existing session explicitly.
spark = SparkSession.builder.getOrCreate()

In [19]:
# Load College.csv using the first row as column names and letting Spark infer column types.
data = spark.read.csv('FileStore/tables/College.csv', header = True, inferSchema = True)
data.printSchema()

In [20]:
# First row, returned as a one-element list of Row objects.
data.head(n=1)

In [21]:
from pyspark.ml.feature import VectorAssembler
# Column names, used to build the assembler's input list below.
list(data.columns)

In [22]:
# The 17 predictor columns packed into one 'features' vector.
# The label 'Private' is handled separately by the StringIndexer below.
feature_cols = [
    'Apps', 'Accept', 'Enroll', 'Top10perc', 'Top25perc',
    'F_Undergrad', 'P_Undergrad', 'Outstate', 'Room_Board', 'Books',
    'Personal', 'PhD', 'Terminal', 'S_F_Ratio', 'perc_alumni',
    'Expend', 'Grad_Rate',
]
assembler = VectorAssembler(outputCol = 'features', inputCols = feature_cols)

In [23]:
# Append the assembled 'features' vector column to the College data.
output = assembler.transform(dataset=data)

In [24]:
from pyspark.ml.feature import StringIndexer
# Map the string label 'Private' to a numeric 'PrivateIndex' column for the classifiers.
indexer = StringIndexer(outputCol = 'PrivateIndex', inputCol = 'Private')

In [25]:
# Learn the label-to-index mapping, then apply it to the same frame.
private_indexer_model = indexer.fit(output)
output_fixed = private_indexer_model.transform(output)

In [26]:
# Confirm the new 'PrivateIndex' column was added alongside the assembled 'features' column.
output_fixed.printSchema()

In [27]:
# Keep only the model inputs: the feature vector and the numeric label.
final_data = output_fixed.select(['features', 'PrivateIndex'])

In [28]:
# 70/30 train/test split. The fixed seed makes the split, and the AUC/accuracy numbers
# below, repeatable on Restart & Run All (given the same input partitioning).
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [29]:
from pyspark.ml.classification import DecisionTreeClassifier, GBTClassifier, RandomForestClassifier

In [30]:
from pyspark.ml.regression import RandomForestRegressor

In [31]:
from pyspark.ml import Pipeline

In [32]:
# All three classifiers read the assembled 'features' vector and the indexed 'PrivateIndex' label.
# Explicit seeds make the random forest's random sampling (and any GBT subsampling)
# repeatable across runs instead of relying on an implicit default seed.
dtc = DecisionTreeClassifier(labelCol = 'PrivateIndex', featuresCol = 'features')
rfc = RandomForestClassifier(numTrees= 150, labelCol = 'PrivateIndex', featuresCol = 'features', seed = 42)
gbt = GBTClassifier(labelCol = 'PrivateIndex', featuresCol = 'features', seed = 42)

In [33]:
# Fit each estimator on the College training split, in order DTC, RFC, GBT.
dtc_model, rfc_model, gbt_model = (estimator.fit(train_data) for estimator in (dtc, rfc, gbt))

In [34]:
# Score the College test split with each fitted model.
dtc_preds, rfc_preds, gbt_preds = [fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)]

In [35]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [36]:
# Area under the ROC curve, computed from the models' 'rawPrediction' scores (both are the defaults, made explicit).
my_binary_eval = BinaryClassificationEvaluator(labelCol = 'PrivateIndex', rawPredictionCol = 'rawPrediction', metricName = 'areaUnderROC')

In [37]:
print('DTC')
# Decision-tree AUC on the College test split.
dtc_auc = my_binary_eval.evaluate(dtc_preds)
print(dtc_auc)

In [38]:
print('RFC')
# Random-forest AUC on the College test split.
rfc_auc = my_binary_eval.evaluate(rfc_preds)
print(rfc_auc)

In [39]:
# Score GBT on its continuous 'rawPrediction' column when the model emits one, as newer
# Spark releases do. AUC from the hard 0/1 'prediction' column uses a single threshold,
# so it is not comparable with the DTC/RFC AUCs above. Fall back to 'prediction' only
# on Spark versions whose GBT output has no 'rawPrediction'.
gbt_score_col = 'rawPrediction' if 'rawPrediction' in gbt_preds.columns else 'prediction'
my_binary_eval2 = BinaryClassificationEvaluator(labelCol = 'PrivateIndex', rawPredictionCol=gbt_score_col)

In [40]:
print("GBT")
print(my_binary_eval2.evaluate(gbt_preds))
# Caveat before concluding GBT is worse: my_binary_eval2 reads GBT scores from its
# rawPredictionCol. If that is the hard 0/1 'prediction' column, the AUC comes from a
# single threshold and will typically look lower than the DTC/RFC AUCs, which are scored
# on 'rawPrediction'. In that case the models are not being compared like-for-like.
# Compare on the same score column first; only then tune GBT (e.g. raise maxIter, which
# controls the number of boosting iterations / trees).

In [41]:
# Accuracy on the College data: compare 'prediction' (default column, made explicit) with the indexed label.
acc_eval = MulticlassClassificationEvaluator(metricName = 'accuracy', labelCol = 'PrivateIndex', predictionCol = 'prediction')

In [42]:
# Test-set accuracy for all three College models, mirroring the accuracy comparison in
# the first section. rfc_acc keeps its original name for the next cell.
dtc_acc = acc_eval.evaluate(dtc_preds)
rfc_acc = acc_eval.evaluate(rfc_preds)
gbt_acc = acc_eval.evaluate(gbt_preds)

In [43]:
# Display the random forest's test-set accuracy on the College data.
rfc_acc