In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('TreeMethod').getOrCreate()

In [2]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,DecisionTreeClassifier)

In [3]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

In [4]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [5]:
train, test = data.randomSplit([0.7,0.3])

In [6]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [7]:
dtc_model = dtc.fit(train)
rfc_model = rfc.fit(train)
gbt_model = gbt.fit(train)

In [8]:
dtc_pred = dtc_model.transform(test)
rfc_pred = rfc_model.transform(test)
gbt_pred = gbt_model.transform(test)

In [9]:
dtc_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[98,99,100,1...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[100,101,102...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[150,151,152...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [10]:
rfc_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[98,99,100,1...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[100,101,102...|  [68.0,32.0]|[0.68,0.32]|       0.0|
|  0.0|(692,[125,126,127...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[126,127,128...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[128,129,130...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[150,151,152...|  [85.0,15.0]|[0.85,0.15]|       0.0|
|  0.0|(692,[152,153,154...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[152,153,154...|  [88.0,12.0]|[0.88,0.12]|       0.0|
|  0.0|(692,[153,154,155...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [11]:
gbt_pred.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[95,96,97,12...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[98,99,100,1...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[100,101,102...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[128,129,130...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[150,151,152...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[152

In [12]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [13]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [14]:
print("DTC Accuracy:")
acc_eval.evaluate(dtc_pred)

DTC Accuracy:


0.9428571428571428

In [15]:
print("RFC Accuracy:")
acc_eval.evaluate(rfc_pred)

RFC Accuracy:


0.9714285714285714

In [16]:
print("GBT Accuracy:")
acc_eval.evaluate(gbt_pred)

GBT Accuracy:


0.9428571428571428

In [18]:
#feature importance
#this will show the importance of features higher the number higher the importance
rfc_model.featureImportances

SparseVector(692, {119: 0.0014, 146: 0.0005, 149: 0.0006, 157: 0.0003, 183: 0.0006, 185: 0.0003, 186: 0.001, 216: 0.0085, 234: 0.0048, 235: 0.0073, 242: 0.0007, 243: 0.0027, 244: 0.0113, 272: 0.0073, 273: 0.0027, 288: 0.0015, 289: 0.0007, 290: 0.0088, 299: 0.0057, 300: 0.0189, 301: 0.0013, 317: 0.0061, 322: 0.0014, 323: 0.0073, 328: 0.0164, 329: 0.0212, 342: 0.0005, 344: 0.0175, 349: 0.0021, 350: 0.011, 351: 0.0134, 370: 0.0012, 372: 0.0061, 378: 0.0319, 379: 0.0389, 383: 0.0006, 386: 0.0149, 387: 0.0069, 398: 0.0027, 400: 0.0082, 405: 0.0221, 406: 0.0525, 407: 0.0686, 408: 0.0012, 412: 0.0016, 414: 0.0047, 424: 0.0007, 425: 0.0003, 426: 0.0031, 427: 0.0007, 430: 0.0016, 432: 0.0003, 433: 0.0184, 434: 0.0433, 435: 0.0075, 438: 0.0006, 442: 0.0075, 455: 0.0269, 461: 0.0681, 462: 0.05, 466: 0.0008, 467: 0.001, 468: 0.0075, 470: 0.0027, 481: 0.0025, 483: 0.0482, 484: 0.0086, 485: 0.0049, 487: 0.0008, 489: 0.0252, 490: 0.0392, 496: 0.0134, 511: 0.0319, 512: 0.0314, 517: 0.0189, 521: 0.0018