In [23]:
import findspark
findspark.init('/home/nick/spark-3.0.1-bin-hadoop2.7')

from pyspark.sql import SparkSession

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

spark = SparkSession.builder.appName('Tree Methods').getOrCreate()

In [14]:
data = spark.read.format('libsvm').load('Tree_Methods/sample_libsvm_data.txt')
data.show(5)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 5 rows



In [7]:
train_data, test_data = data.randomSplit([0.7,0.3])

In [13]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [18]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [19]:
dtc_prediction = dtc_model.transform(test_data)
rfc_prediction = rfc_model.transform(test_data)
gbt_prediction = gbt_model.transform(test_data)

In [22]:
dtc_prediction.show(5)
rfc_prediction.show(5)
gbt_prediction.show(5)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[121,122,123...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,124...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,148...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 5 rows

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[121,122,123...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[122,123,124...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[122,123,148...|  [83.0,17.0]|[0.83,0.1

In [33]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')
print(f"Decision Tree Accuracy: {acc_eval.evaluate(dtc_prediction)}")

acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')
print(f"RandomForestClassifier Accuracy: {acc_eval.evaluate(rfc_prediction)}")

acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')
print(f"GBTClassifier Accuracy: {acc_eval.evaluate(gbt_prediction)}")
# Example data is too clean since this is the documentation example

Decision Tree Accuracy: 1.0
RandomForestClassifier Accuracy: 1.0
GBTClassifier Accuracy: 1.0


In [37]:
rfc_model.featureImportances # Keep this method in mind in the future

SparseVector(692, {150: 0.0006, 154: 0.0017, 158: 0.0005, 159: 0.0012, 180: 0.0008, 181: 0.001, 184: 0.0006, 186: 0.0004, 189: 0.0008, 207: 0.002, 208: 0.0059, 213: 0.004, 233: 0.0005, 235: 0.003, 236: 0.0058, 242: 0.0018, 243: 0.0008, 260: 0.0051, 262: 0.0083, 263: 0.0176, 265: 0.001, 266: 0.0013, 267: 0.0015, 268: 0.0006, 273: 0.0053, 288: 0.0041, 289: 0.0017, 290: 0.0182, 295: 0.0075, 299: 0.0008, 300: 0.0072, 301: 0.0085, 317: 0.0007, 319: 0.0006, 320: 0.0006, 321: 0.0005, 322: 0.0012, 323: 0.0012, 326: 0.0007, 328: 0.0075, 330: 0.0088, 342: 0.0012, 344: 0.0008, 345: 0.0082, 346: 0.0005, 349: 0.001, 350: 0.0105, 351: 0.0206, 352: 0.0002, 355: 0.0027, 358: 0.001, 359: 0.0113, 369: 0.0004, 372: 0.0167, 373: 0.0156, 374: 0.0081, 375: 0.0006, 377: 0.0007, 378: 0.0281, 379: 0.0262, 381: 0.0005, 383: 0.0018, 384: 0.0082, 385: 0.0188, 400: 0.0172, 401: 0.007, 403: 0.0007, 405: 0.0125, 406: 0.0316, 407: 0.0269, 408: 0.0007, 409: 0.0006, 411: 0.0022, 412: 0.0006, 414: 0.0072, 427: 0.009, 42