In [1]:
from pyspark.sql import SparkSession
# Start (or reuse) the Spark session that every later cell relies on.
spark = (
    SparkSession.builder
    .appName('mytree')
    .getOrCreate()
)

In [2]:
from pyspark.ml import Pipeline

In [10]:
from pyspark.ml.classification import (RandomForestClassifier,
                                       GBTClassifier,
                                       DecisionTreeClassifier)

In [5]:
# Load the LIBSVM-format sample file into a DataFrame with `label` and `features` columns.
data = spark.read.load('sample_libsvm_data.txt', format='libsvm')

In [7]:
# Preview the first 20 rows; long feature vectors are truncated in the display.
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [8]:
# 70/30 train/test split. The fixed seed makes the split -- and therefore every
# accuracy reported below -- reproducible across kernel restarts.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [12]:
# Three tree-based classifiers using the default 'label'/'features' column
# names, which match the loaded data. Random forests and gradient-boosted trees
# use randomness during training, so each estimator gets a fixed seed to make
# re-runs produce the same models and metrics.
dtc = DecisionTreeClassifier(seed=42)
rfc = RandomForestClassifier(numTrees=100, seed=42)
gbt = GBTClassifier(seed=42)

In [13]:
# Fit every estimator on the same training split, in the order dtc, rfc, gbt.
dtc_model, rfc_model, gbt_model = (
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
)

In [14]:
# Score the held-out split with each fitted model (adds prediction columns).
dtc_preds, rfc_preds, gbt_preds = (
    fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)
)

In [15]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [16]:
# Accuracy evaluator. The column names spelled out here are the evaluator's defaults.
acc_eval = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='label',
    predictionCol='prediction',
)

In [17]:
# Decision tree: fraction of test rows classified correctly.
acc_eval.evaluate(dataset=dtc_preds)

0.9666666666666667

In [18]:
# Random forest: fraction of test rows classified correctly.
acc_eval.evaluate(dataset=rfc_preds)

1.0

In [19]:
# Gradient-boosted trees: fraction of test rows classified correctly.
acc_eval.evaluate(dataset=gbt_preds)

0.9666666666666667

In [20]:
# Per-feature importance scores of the fitted random forest, displayed as a
# SparseVector indexed by feature position (692 features in this dataset).
rfc_model.featureImportances

SparseVector(692, {99: 0.0011, 100: 0.0011, 101: 0.0008, 102: 0.0004, 149: 0.0006, 156: 0.0004, 159: 0.0004, 177: 0.0006, 181: 0.0017, 182: 0.0046, 186: 0.0007, 187: 0.0008, 209: 0.0003, 217: 0.0018, 218: 0.0005, 234: 0.0008, 235: 0.0065, 239: 0.0005, 240: 0.0006, 243: 0.0013, 244: 0.0106, 245: 0.0085, 260: 0.0011, 262: 0.0076, 263: 0.0086, 271: 0.0006, 272: 0.0155, 273: 0.0141, 290: 0.0206, 291: 0.0027, 296: 0.0013, 298: 0.0003, 299: 0.0043, 300: 0.0279, 301: 0.0007, 315: 0.001, 317: 0.0011, 319: 0.0002, 322: 0.0029, 323: 0.0058, 324: 0.0005, 326: 0.0005, 329: 0.0057, 330: 0.0019, 341: 0.0011, 344: 0.0098, 351: 0.0102, 355: 0.0031, 356: 0.0149, 357: 0.0026, 358: 0.0078, 359: 0.0018, 360: 0.0005, 370: 0.0013, 373: 0.0076, 378: 0.0175, 379: 0.0222, 382: 0.0005, 384: 0.0077, 385: 0.0094, 386: 0.0017, 387: 0.0045, 400: 0.0199, 403: 0.0004, 405: 0.0015, 406: 0.0107, 407: 0.02, 408: 0.0014, 411: 0.0041, 412: 0.0075, 413: 0.0009, 426: 0.0223, 427: 0.0074, 428: 0.0001, 429: 0.0209, 430: 0.002