In [35]:
import os
import findspark
findspark.init(os.getenv('SPARK_HOME'))
from pyspark.sql import SparkSession
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier, GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator as mce

In [2]:
spark = SparkSession.builder.appName('decision_tree').getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/22 10:18:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')
data.show()

22/08/22 10:18:35 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [4]:
train, test = data.randomSplit([0.8, 0.2])

In [5]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [18]:
dtc_model = dtc.fit(train)
rfc_model = rfc.fit(train)
gbt_model = gbt.fit(train)

In [24]:
results_dtc = dtc_model.transform(test)
results_rfc = rfc_model.transform(test)
results_gbt = gbt_model.transform(test)

In [21]:
results_dtc.show()
results_rfc.show()
results_gbt.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|   [0.0,43.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[100,101,102...|   [0.0,43.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,43.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [29]:
acc_eval = mce(metricName='accuracy')

In [33]:
print("DTC ACC", acc_eval.evaluate(results_dtc))
print("RFC ACC", acc_eval.evaluate(results_rfc))
print("GBT ACC", acc_eval.evaluate(results_gbt))

DTC ACC 1.0
RFC ACC 1.0
GBT ACC 1.0


In [34]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0005, 101: 0.0011, 128: 0.0004, 156: 0.0003, 176: 0.0005, 182: 0.0006, 183: 0.0002, 185: 0.0004, 186: 0.0003, 207: 0.0065, 209: 0.0007, 211: 0.002, 214: 0.0007, 216: 0.001, 235: 0.0007, 236: 0.001, 243: 0.0104, 244: 0.0075, 245: 0.0005, 261: 0.0009, 262: 0.0067, 268: 0.0004, 270: 0.0006, 271: 0.0069, 272: 0.0104, 273: 0.003, 289: 0.0127, 291: 0.0005, 293: 0.0012, 295: 0.0009, 296: 0.0007, 298: 0.0008, 299: 0.0065, 300: 0.0149, 301: 0.0079, 317: 0.006, 318: 0.0009, 319: 0.0011, 323: 0.0015, 327: 0.0037, 342: 0.0037, 344: 0.0104, 347: 0.0003, 350: 0.0224, 351: 0.0186, 352: 0.0006, 354: 0.0006, 355: 0.0012, 356: 0.0012, 357: 0.0379, 369: 0.0006, 371: 0.0016, 372: 0.0063, 373: 0.0083, 378: 0.0214, 379: 0.0192, 380: 0.0072, 381: 0.0008, 382: 0.0046, 383: 0.0018, 385: 0.0186, 387: 0.0022, 400: 0.0308, 402: 0.0006, 405: 0.0252, 406: 0.0254, 407: 0.0297, 408: 0.001, 409: 0.0005, 413: 0.0146, 415: 0.0055, 425: 0.0041, 427: 0.0069, 429: 0.0011, 430: 0.0012, 431: 0.0001,