# Decision Trees Documentation Example

In [0]:
from pyspark.sql import SparkSession
# Reuse the active Spark session if there is one; otherwise start one named 'mytree'.
spark = (
    SparkSession.builder
    .appName('mytree')
    .getOrCreate()
)

In [0]:
from pyspark.ml import Pipeline

In [0]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [0]:
# Load the LIBSVM-format sample file; the result has a `label` column and a
# sparse `features` column (see the preview in the next cell).
# NOTE(review): the path looks like a Databricks FileStore upload location --
# confirm it exists in the target workspace before running.
data = (
    spark.read
    .format('libsvm')
    .load('/FileStore/tables/sample_libsvm_data-1.txt')
)

In [0]:
# Preview the loaded DataFrame; with no arguments show() prints only the first
# 20 rows and truncates the long sparse-vector cells.
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [0]:
# 70/30 train/test split. The weights are normalised by Spark and the split is
# random per row, so the resulting sizes are only approximately 70% / 30%.
# A fixed seed makes the split (and therefore every accuracy figure computed
# from it below) repeatable for the same input data; without it randomSplit
# picks a fresh random seed on every call.
SPLIT_SEED = 42
train_data, test_data = data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [0]:
# One untrained estimator per tree-based algorithm being compared.
# Only the random forest is tuned here (100 trees); the single decision tree
# and the gradient-boosted trees use the library defaults.
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [0]:
# Fit each estimator on the training split only; the three fits are independent
# of each other and each returns a separate fitted model object.
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [0]:
# Score the held-out test split with each fitted model.
# NOTE(review): `rcf_preds` is a transposition of `rfc` (random forest
# classifier). The evaluation cell further down refers to it by this exact
# spelling, so any rename has to be made in both places at once.
dtc_preds = dtc_model.transform(test_data)
rcf_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

# Evaluate Models

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [0]:
# Single evaluator reused for all three models; it reports plain accuracy.
# No column names are passed, so the evaluator's default label/prediction
# columns are used.
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [0]:
# Test-set accuracy of the single decision tree.
acc_eval.evaluate(dtc_preds)

Out[18]: 0.9714285714285714

In [0]:
# Test-set accuracy of the 100-tree random forest
# (`rcf_preds` holds the random-forest predictions despite the transposed name).
acc_eval.evaluate(rcf_preds)

Out[19]: 0.9714285714285714

In [0]:
# Test-set accuracy of the gradient-boosted trees model.
acc_eval.evaluate(gbt_preds)

Out[20]: 0.9714285714285714

# Feature Importance

In [0]:
# Per-feature importance scores learned by the fitted random forest, returned
# as a sparse vector indexed by feature position (692 features in this dataset);
# indices absent from the vector have zero importance.
rfc_model.featureImportances

Out[21]: SparseVector(692, {122: 0.0006, 131: 0.0017, 150: 0.0003, 156: 0.0057, 176: 0.0004, 185: 0.0003, 189: 0.0004, 190: 0.0006, 207: 0.001, 235: 0.0012, 238: 0.0, 243: 0.0021, 267: 0.0004, 271: 0.0099, 272: 0.0077, 273: 0.007, 274: 0.0006, 289: 0.0063, 290: 0.008, 293: 0.0015, 298: 0.0011, 299: 0.0195, 300: 0.0128, 301: 0.0028, 317: 0.0246, 318: 0.0025, 323: 0.0078, 324: 0.0027, 326: 0.0003, 327: 0.0014, 328: 0.0079, 329: 0.0085, 342: 0.0013, 343: 0.0004, 344: 0.0034, 347: 0.0005, 350: 0.0007, 351: 0.02, 352: 0.0008, 356: 0.0068, 358: 0.0006, 360: 0.0005, 370: 0.0028, 372: 0.0151, 373: 0.0114, 374: 0.0015, 375: 0.004, 377: 0.0128, 378: 0.0075, 379: 0.0461, 380: 0.0016, 382: 0.0006, 383: 0.0017, 384: 0.0066, 386: 0.0066, 399: 0.0055, 400: 0.0096, 401: 0.0175, 402: 0.0046, 403: 0.0006, 405: 0.0387, 407: 0.01, 408: 0.0005, 409: 0.0017, 411: 0.0006, 414: 0.0007, 425: 0.0006, 427: 0.0073, 428: 0.0088, 429: 0.0173, 432: 0.0072, 433: 0.0143, 434: 0.0636, 435: 0.0244, 438: 0.001, 440: 0.01