In [3]:
# findspark.init() locates the local Spark installation and adds its Python
# bindings to sys.path, so it has to run before the pyspark import below.
import findspark
findspark.init()
from pyspark.sql import SparkSession

In [4]:
# Start (or reuse) a Spark session that runs locally on four worker threads.
spark = (
    SparkSession.builder
    .master('local[4]')
    .appName('firsttree')
    .getOrCreate()
)

In [6]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (
    RandomForestClassifier,
    GBTClassifier,
    DecisionTreeClassifier)

### Load Data

In [7]:
# Read the LIBSVM-formatted sample file into a DataFrame.
data = spark.read.load('../data/sample_libsvm_data.txt', format='libsvm')

In [8]:
# Print the column names and types to confirm what the loader produced.
data.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [16]:
# Peek at the feature vector of a single row (a one-element list of Rows).
data.select('features').limit(1).collect()

[Row(features=SparseVector(692, {127: 51.0, 128: 159.0, 129: 253.0, 130: 159.0, 131: 50.0, 154: 48.0, 155: 238.0, 156: 252.0, 157: 252.0, 158: 252.0, 159: 237.0, 181: 54.0, 182: 227.0, 183: 253.0, 184: 252.0, 185: 239.0, 186: 233.0, 187: 252.0, 188: 57.0, 189: 6.0, 207: 10.0, 208: 60.0, 209: 224.0, 210: 252.0, 211: 253.0, 212: 252.0, 213: 202.0, 214: 84.0, 215: 252.0, 216: 253.0, 217: 122.0, 235: 163.0, 236: 252.0, 237: 252.0, 238: 252.0, 239: 253.0, 240: 252.0, 241: 252.0, 242: 96.0, 243: 189.0, 244: 253.0, 245: 167.0, 262: 51.0, 263: 238.0, 264: 253.0, 265: 253.0, 266: 190.0, 267: 114.0, 268: 253.0, 269: 228.0, 270: 47.0, 271: 79.0, 272: 255.0, 273: 168.0, 289: 48.0, 290: 238.0, 291: 252.0, 292: 252.0, 293: 179.0, 294: 12.0, 295: 75.0, 296: 121.0, 297: 21.0, 300: 253.0, 301: 243.0, 302: 50.0, 316: 38.0, 317: 165.0, 318: 253.0, 319: 233.0, 320: 208.0, 321: 84.0, 328: 253.0, 329: 252.0, 330: 165.0, 343: 7.0, 344: 178.0, 345: 252.0, 346: 240.0, 347: 71.0, 348: 19.0, 349: 28.0, 356: 253.

### Train/Test Split

In [17]:
# Hold out 20% of the rows for evaluation. The fixed seed makes the split
# repeatable, so the accuracy numbers further down do not change from run to
# run merely because different rows landed in the test set.
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

### Three Tree Models (Mostly Default Settings)

In [26]:
# One estimator per tree-based algorithm. Everything is left at the library
# defaults except the forest size and an explicit seed; fixing the seed keeps
# the randomised parts of training (e.g. the forest's row/feature sampling)
# repeatable across kernel restarts.
MODEL_SEED = 42
dtc = DecisionTreeClassifier(seed=MODEL_SEED)
rfc = RandomForestClassifier(numTrees=100, seed=MODEL_SEED)
gbt = GBTClassifier(seed=MODEL_SEED)

In [27]:
# Fit every estimator on the training split (decision tree, forest, GBT).
dtc_model, rfc_model, gbt_model = [
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
]

In [29]:
# Score the held-out rows with each fitted model; transform() appends the
# prediction columns to the test DataFrame.
dtc_preds, rfc_preds, gbt_preds = [
    fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)
]

In [31]:
# Display the first 20 decision-tree predictions with long cells truncated.
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[121,122,123...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[155,156,180...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[234,235,237...|   [36.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[119,120,121...|   [0.0,47.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,47.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,47.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|   [0.0,47.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[125,126,127...|   [0.0,47.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

### Evaluate

In [32]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [33]:
# Accuracy evaluator comparing the 'prediction' column against 'label'.
acc_eval = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='accuracy',
)

In [36]:
# Report test-set accuracy for each model, one line per model.
for model_label, model_preds in (
    ('DT Accuracy', dtc_preds),
    ('RFC Accuracy', rfc_preds),
    ('GBT Accuracy', gbt_preds),
):
    print(model_label, acc_eval.evaluate(model_preds))

DT Accuracy 0.9375
RFC Accuracy 1.0
GBT Accuracy 0.9375


### Analyzing Feature Importance

In [37]:
# Show the fitted random forest's per-feature importance scores; the result
# is a sparse vector with one slot per input feature.
rfc_model.featureImportances

SparseVector(692, {102: 0.001, 130: 0.0004, 183: 0.0001, 185: 0.0008, 186: 0.0001, 203: 0.0005, 207: 0.0006, 212: 0.0006, 213: 0.0008, 216: 0.0015, 232: 0.0018, 233: 0.0006, 234: 0.0017, 236: 0.0014, 244: 0.0162, 262: 0.0078, 263: 0.0158, 264: 0.0004, 271: 0.0054, 272: 0.0012, 273: 0.0063, 274: 0.0024, 287: 0.0013, 289: 0.0057, 290: 0.0081, 291: 0.0064, 294: 0.0004, 295: 0.0015, 296: 0.0006, 298: 0.001, 299: 0.0034, 300: 0.0143, 301: 0.003, 317: 0.0081, 318: 0.0004, 319: 0.0075, 320: 0.0004, 323: 0.0157, 328: 0.0173, 329: 0.0011, 330: 0.0072, 347: 0.0053, 351: 0.0281, 355: 0.0016, 356: 0.0001, 357: 0.0075, 358: 0.0068, 359: 0.0049, 370: 0.0005, 377: 0.0011, 378: 0.0242, 379: 0.0339, 380: 0.0006, 381: 0.0005, 383: 0.0005, 386: 0.0088, 387: 0.0012, 398: 0.0027, 399: 0.0076, 400: 0.0238, 401: 0.0084, 402: 0.0005, 403: 0.0015, 405: 0.0353, 406: 0.0223, 407: 0.0281, 408: 0.0006, 409: 0.0004, 410: 0.0026, 411: 0.0005, 413: 0.0015, 414: 0.0093, 425: 0.0023, 427: 0.0073, 428: 0.0074, 429: 0.01