In [0]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("mytree").getOrCreate()

In [0]:
from pyspark.ml import Pipeline

In [0]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [0]:
# for regression
from pyspark.ml.regression import RandomForestRegressor, GBTRegressor, DecisionTreeRegressor

In [0]:
data = spark.read.format("libsvm").load("/FileStore/tables/sample_libsvm_data.txt")

In [0]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [0]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [0]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [0]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [0]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [0]:
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")

In [0]:
print("DTC ACCURACY:")
acc_eval.evaluate(dtc_preds)

DTC ACCURACY:
Out[16]: 0.9090909090909091

In [0]:
print("RFC ACCURACY:")
acc_eval.evaluate(rfc_preds)

RFC ACCURACY:
Out[17]: 1.0

In [0]:
print("GBT ACCURACY:")
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:
Out[18]: 0.9090909090909091

In [0]:
rfc_model.featureImportances

Out[19]: SparseVector(692, {126: 0.0002, 146: 0.0005, 147: 0.0005, 179: 0.0006, 184: 0.0014, 207: 0.0006, 218: 0.0009, 231: 0.0006, 232: 0.0016, 237: 0.0006, 240: 0.0006, 243: 0.0009, 244: 0.0009, 262: 0.0081, 263: 0.0083, 272: 0.007, 273: 0.0095, 274: 0.0044, 275: 0.0012, 286: 0.0006, 289: 0.023, 290: 0.0064, 292: 0.0005, 293: 0.0006, 299: 0.0054, 317: 0.0145, 318: 0.0025, 319: 0.0009, 322: 0.0048, 323: 0.0241, 324: 0.0031, 327: 0.0005, 328: 0.0075, 329: 0.006, 330: 0.0018, 331: 0.0025, 343: 0.0023, 345: 0.0038, 347: 0.0002, 350: 0.0094, 351: 0.0009, 352: 0.0012, 353: 0.0005, 354: 0.0047, 356: 0.006, 359: 0.0017, 360: 0.0005, 370: 0.0011, 371: 0.0033, 372: 0.0082, 373: 0.0071, 374: 0.0007, 375: 0.0007, 377: 0.0213, 378: 0.0334, 379: 0.0596, 380: 0.003, 384: 0.0076, 385: 0.0006, 397: 0.0018, 398: 0.0017, 400: 0.0076, 401: 0.0095, 402: 0.0051, 403: 0.0018, 404: 0.0063, 405: 0.0627, 407: 0.0235, 408: 0.002, 411: 0.0006, 412: 0.006, 415: 0.0029, 426: 0.0099, 429: 0.016, 432: 0.0012, 433: 