In [1]:
import findspark
findspark.init("/home/gorazda/spark-2.4.7-bin-hadoop2.7")
import pyspark
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("tree_methods_basics").getOrCreate()

In [3]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [4]:
data = spark.read.format("libsvm").load("sample_libsvm_data.txt")

In [5]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [6]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [7]:
dtc = DecisionTreeClassifier(labelCol="label", featuresCol="features")

In [9]:
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [10]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [11]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [13]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,124...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,148...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.25198218210507...|[0.92441926976158...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126

In [14]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [15]:
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")

In [19]:
print("DTC accuracy")
print(acc_eval.evaluate(dtc_preds))
print("\nRFC accuracy")
print(acc_eval.evaluate(rfc_preds))
print("\nGBT accuracy")
print(acc_eval.evaluate(gbt_preds))

DTC accuracy
0.9666666666666667

RFC accuracy
1.0

GBT accuracy
0.9666666666666667


In [20]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0005, 152: 0.0012, 154: 0.0005, 157: 0.0011, 178: 0.0007, 180: 0.0002, 182: 0.0007, 185: 0.0006, 186: 0.0005, 210: 0.0003, 214: 0.0006, 215: 0.0016, 216: 0.0079, 235: 0.0006, 243: 0.0015, 258: 0.0013, 262: 0.0214, 263: 0.0072, 267: 0.0023, 272: 0.0028, 273: 0.0015, 274: 0.0028, 289: 0.0077, 290: 0.0265, 296: 0.0037, 299: 0.0006, 300: 0.0006, 301: 0.0082, 302: 0.0009, 303: 0.0013, 317: 0.0062, 319: 0.0002, 320: 0.0006, 322: 0.0036, 323: 0.0124, 328: 0.0041, 329: 0.0014, 345: 0.0066, 346: 0.0073, 350: 0.0089, 351: 0.0101, 352: 0.0057, 355: 0.0023, 356: 0.0089, 357: 0.0094, 370: 0.0015, 378: 0.0058, 379: 0.0389, 380: 0.0012, 381: 0.003, 384: 0.0086, 386: 0.006, 399: 0.0085, 402: 0.0021, 403: 0.0026, 405: 0.032, 406: 0.056, 407: 0.043, 408: 0.0014, 412: 0.0159, 413: 0.0259, 414: 0.0014, 416: 0.0014, 429: 0.0078, 430: 0.0017, 432: 0.0004, 433: 0.0188, 434: 0.0043, 435: 0.0233, 438: 0.0027, 440: 0.016, 455: 0.0168, 456: 0.0099, 457: 0.0011, 461: 0.0199, 462: 0.0307,