**From the documentation: decision tree, random forest, and gradient-boosted tree classifiers**

In [1]:
from pyspark.sql import SparkSession
# Create (or reuse, if one already exists) the Spark session named 'tree';
# every later cell reads data and trains models through this handle.
spark = (
    SparkSession.builder
    .appName('tree')
    .getOrCreate()
)

In [4]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [6]:
# Read the sample dataset, which is stored in LIBSVM format. The resulting
# DataFrame has a numeric `label` column and a sparse `features` vector column.
df = (
    spark.read
    .format('libsvm')
    .load('./datasets/sample_libsvm_data.txt')
)

# Preview the first rows to confirm the file was parsed as expected.
df.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [7]:
# Split the data roughly 70% / 30% into training and hold-out sets.
# The seed is pinned so that "Restart Kernel -> Run All" yields the same
# partition (and therefore comparable metrics) on every run; without it the
# rows landing in each split change from run to run.
# NOTE: outputs captured below were produced with an unseeded split and will
# differ once this cell is re-executed.
train, test = df.randomSplit([0.7, 0.3], seed=42)

In [10]:
# Three tree-based classifiers configured against the same input columns so
# their results can be compared side by side. Each receives an explicit seed
# so that training is repeatable across kernel restarts (the random forest in
# particular relies on random sampling while building its trees).

# Single decision tree with library-default hyperparameters.
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label', seed=42)

# Random forest; numTrees=100 raises the ensemble size above the library default.
rfc = RandomForestClassifier(featuresCol='features', labelCol='label', numTrees=100, seed=42)

# Gradient-boosted trees with library-default hyperparameters.
gbt = GBTClassifier(featuresCol='features', labelCol='label', seed=42)

In [11]:
# Fit each estimator on the training split, in the same order as they were
# declared: decision tree, random forest, gradient-boosted trees.
dtc_model, rfc_model, gbt_model = (
    estimator.fit(train) for estimator in (dtc, rfc, gbt)
)

In [12]:
# Score the hold-out split with every fitted model. Each result is the test
# frame extended with rawPrediction, probability and prediction columns.
dtc_preds, rfc_preds, gbt_preds = (
    dtc_model.transform(test),
    rfc_model.transform(test),
    gbt_model.transform(test),
)

In [13]:
# Preview the decision-tree predictions (first 20 rows, long cells truncated).
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[155,156,180...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|   [0.0,34.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [14]:
# Preview the random-forest predictions (first 20 rows, long cells truncated).
rfc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|  [68.0,32.0]|[0.68,0.32]|       0.0|
|  0.0|(692,[124,125,126...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[126,127,128...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[126,127,128...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[126,127,128...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[127,128,129...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[129,130,131...|  [87.0,13.0]|[0.87,0.13]|       0.0|
|  0.0|(692,[151,152,153...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[155,156,180...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|  [17.0,83.0]|[0.17,0.83]|       1.0|
|  1.0|(69

In [15]:
# Preview the gradient-boosted-tree predictions (first 20 rows, long cells truncated).
gbt_preds.show(n=20, truncate=True)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.43204087879901...|[0.94604203870495...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[129,130,131...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[151

In [16]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [17]:
# One evaluator, reused for all three models, that reports plain accuracy.
# Column names are left at the evaluator's defaults — presumably `label` and
# `prediction`, matching the prediction frames above; confirm if columns change.
acc_eval = MulticlassClassificationEvaluator(
    metricName='accuracy',
)

In [22]:
# Accuracy on the hold-out split, reported as a tuple in the order
# (decision tree, random forest, gradient-boosted trees).
tuple(
    acc_eval.evaluate(predictions)
    for predictions in (dtc_preds, rfc_preds, gbt_preds)
)

(0.9696969696969697, 1.0, 0.9696969696969697)

In [23]:
# Per-feature importance scores learned by the random forest. Displayed as a
# sparse vector over the 692 input features, so only features with a non-zero
# score are listed (feature index -> importance).
rfc_model.featureImportances

SparseVector(692, {100: 0.0021, 155: 0.0006, 179: 0.0028, 181: 0.0075, 182: 0.0015, 183: 0.0013, 208: 0.0005, 211: 0.0008, 215: 0.0029, 232: 0.0021, 235: 0.0074, 236: 0.0022, 237: 0.0013, 239: 0.0006, 240: 0.0005, 243: 0.0019, 244: 0.0102, 245: 0.0028, 261: 0.0074, 262: 0.0099, 263: 0.0073, 267: 0.0003, 271: 0.0042, 272: 0.0098, 273: 0.0072, 287: 0.0017, 288: 0.0021, 290: 0.002, 291: 0.0054, 294: 0.0013, 299: 0.0042, 300: 0.0083, 301: 0.008, 323: 0.0049, 328: 0.0072, 329: 0.0059, 330: 0.0076, 342: 0.0004, 344: 0.0148, 346: 0.0006, 349: 0.0005, 350: 0.011, 351: 0.0279, 352: 0.0015, 356: 0.0075, 357: 0.0073, 359: 0.0016, 369: 0.0023, 371: 0.0004, 372: 0.0069, 373: 0.0116, 374: 0.0036, 376: 0.0005, 378: 0.0713, 379: 0.0205, 384: 0.0004, 385: 0.0181, 397: 0.0065, 398: 0.0069, 401: 0.0131, 404: 0.0008, 405: 0.02, 406: 0.0173, 407: 0.0368, 410: 0.0003, 414: 0.0025, 425: 0.01, 427: 0.0008, 429: 0.0103, 433: 0.0595, 434: 0.0287, 435: 0.0099, 436: 0.0017, 439: 0.0006, 440: 0.0032, 442: 0.0005, 