In [2]:
from pyspark.sql import SparkSession

In [4]:
# Obtain the notebook's Spark session under the application name 'mytree';
# getOrCreate() returns an already-running session when one exists.
spark = (
    SparkSession.builder
    .appName('mytree')
    .getOrCreate()
)

In [5]:
from pyspark.ml import Pipeline

In [6]:
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,
                                       DecisionTreeClassifier)

In [7]:
# Load the sample dataset through Spark's 'libsvm' data source.
libsvm_reader = spark.read.format('libsvm')
data = libsvm_reader.load('Data/Spark_for_Machine_Learning/Tree_Methods/sample_libsvm_data.txt')

In [8]:
# Preview the first 20 rows with long cell values truncated
# (these are show()'s default settings, written out explicitly).
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [9]:
# Hold out roughly 30% of the rows for evaluation. randomSplit assigns rows
# at random, so the seed is fixed to keep the train/test partition -- and
# every score computed from it -- repeatable when the notebook is re-run.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [10]:
# Three tree-based classifiers to compare. No column names are passed, so
# each one uses its default label/features columns.
# The seed is pinned on every estimator so that any randomness used during
# training is repeatable when the notebook is re-run.
dtc = DecisionTreeClassifier(seed=42)
rfc = RandomForestClassifier(numTrees=100, seed=42)  # forest of 100 trees
gbt = GBTClassifier(seed=42)

In [12]:
# Fit every estimator on the training partition, in the order
# decision tree -> random forest -> gradient-boosted trees.
dtc_model, rfc_model, gbt_model = [
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
]

In [14]:
# Apply each fitted model to the held-out partition to obtain its predictions.
dtc_preds, rfc_preds, gbt_preds = [
    fitted.transform(test_data)
    for fitted in (dtc_model, rfc_model, gbt_model)
]

In [16]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [17]:
# Evaluator that reports plain accuracy; the prediction/label column names
# are the evaluator's defaults, written out explicitly.
acc_eval = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='accuracy',
)

In [18]:
# Print the heading, then leave the decision tree's accuracy as the final
# expression so the notebook renders it as the cell result.
print('DTC ACCURACY:')
dtc_accuracy = acc_eval.evaluate(dtc_preds)
dtc_accuracy

DTC ACCURACY:


0.9473684210526315

In [19]:
# Print the heading, then leave the random forest's accuracy as the final
# expression so the notebook renders it as the cell result.
print('RFC ACCURACY:')
rfc_accuracy = acc_eval.evaluate(rfc_preds)
rfc_accuracy

RFC ACCURACY:


1.0

In [20]:
# Print the heading, then leave the gradient-boosted model's accuracy as the
# final expression so the notebook renders it as the cell result.
print('GBT ACCURACY:')
gbt_accuracy = acc_eval.evaluate(gbt_preds)
gbt_accuracy

GBT ACCURACY:


0.9473684210526315

In [21]:
# Feature-importance vector of the fitted random forest, displayed as the
# cell result.
rfc_importances = rfc_model.featureImportances
rfc_importances

SparseVector(692, {131: 0.0005, 147: 0.0008, 152: 0.0001, 179: 0.0016, 181: 0.0035, 186: 0.0008, 187: 0.0011, 209: 0.0006, 212: 0.0006, 213: 0.0007, 216: 0.0022, 231: 0.0018, 236: 0.0058, 243: 0.0053, 244: 0.0197, 245: 0.0061, 262: 0.0001, 271: 0.0021, 273: 0.0058, 289: 0.0061, 290: 0.0018, 291: 0.0008, 317: 0.0135, 318: 0.0014, 327: 0.0054, 328: 0.0055, 330: 0.007, 343: 0.0012, 350: 0.032, 351: 0.0188, 352: 0.0068, 353: 0.0015, 355: 0.0025, 357: 0.0206, 358: 0.0081, 359: 0.0078, 374: 0.0016, 375: 0.0007, 377: 0.008, 378: 0.0348, 379: 0.0289, 384: 0.0016, 385: 0.024, 388: 0.0033, 398: 0.0046, 399: 0.0015, 401: 0.0022, 405: 0.0188, 406: 0.0346, 407: 0.0186, 410: 0.0011, 412: 0.0006, 414: 0.0031, 426: 0.0071, 427: 0.035, 428: 0.0014, 429: 0.0196, 430: 0.0005, 432: 0.0021, 433: 0.0119, 434: 0.0528, 435: 0.024, 438: 0.0007, 440: 0.0014, 442: 0.0089, 443: 0.0057, 453: 0.0008, 455: 0.0184, 456: 0.0167, 459: 0.0004, 461: 0.019, 462: 0.0411, 463: 0.0157, 468: 0.0074, 482: 0.0077, 483: 0.0289, 