In [2]:
from pyspark.sql import SparkSession
# Obtain the notebook's Spark session (reuses an existing one if present),
# registered under the application name 'Trees_intro'.
spark = (
    SparkSession.builder
    .appName('Trees_intro')
    .getOrCreate()
)

In [3]:
from pyspark.ml import Pipeline

In [13]:
from pyspark.ml.classification import DecisionTreeClassifier, GBTClassifier, RandomForestClassifier

In [6]:
# Read the sample dataset with the 'libsvm' reader; the resulting DataFrame
# exposes a `label` column and a sparse `features` vector column.
data = (
    spark.read
    .format('libsvm')
    .load('sample_libsvm_data.txt')
)

In [7]:
# Preview the loaded DataFrame (prints the top 20 rows with truncated cells).
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [8]:
# Split the rows roughly 70% / 30% into training and test sets.
# A fixed seed keeps the split stable across reruns (given the same input
# partitioning); without it every metric below changes on each execution.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [9]:
# Single decision tree estimator, constructed with no arguments so every
# hyperparameter is left at the library default.
dtc = DecisionTreeClassifier()

In [14]:
# Random forest estimator with an ensemble of 100 trees.
# The forest samples rows/features at random, so pin the seed to make the
# fitted model (and its accuracy / feature importances) repeatable.
rfc = RandomForestClassifier(numTrees=100, seed=42)

In [15]:
# Gradient-boosted trees estimator, constructed with no arguments so every
# hyperparameter is left at the library default.
gbt = GBTClassifier()

In [17]:
# Fit each of the three estimators on the same training split.
# Evaluation order is unchanged: decision tree, random forest, then GBT.
dtc_model, rfc_model, gbt_model = (
    dtc.fit(train_data),
    rfc.fit(train_data),
    gbt.fit(train_data),
)

In [18]:
# Score the held-out test split with every fitted model.
# NOTE: the `gbt_predts` spelling is kept on purpose because later cells
# refer to the variable by that exact name.
dtc_preds, rfc_preds, gbt_predts = (
    dtc_model.transform(test_data),
    rfc_model.transform(test_data),
    gbt_model.transform(test_data),
)

In [20]:
# Inspect the random-forest output: the original label/features columns plus
# the rawPrediction, probability and prediction columns added by the model.
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[123,124,125...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[126,127,128...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[129,130,131...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[150,151,152...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[153,154,155...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  1.0|(692,[97,98,99,12...|  [22.0,78.0]|[0.22,0.78]|       1.0|
|  1.0|(692,[123,124,125...|   [2.0,98.0]|[0.02,0.98]|       1.0|
|  1.0|(692,[123,124,125...|   [2.0,98.0]|[0.02,0.98]|       1.0|
|  1.0|(69

In [22]:
# Inspect the gradient-boosted-trees output: the original label/features
# columns plus the rawPrediction, probability and prediction columns.
gbt_predts.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[-0.7624430468667...|[0.17874314380715...|       1.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.17677494257156...|[0.91321597802429...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[129,130,131...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[150,151,152...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[153,154,155...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  1.0|(692,[97,

In [23]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [24]:
# One evaluator, configured to report accuracy, shared by all three models
# so their scores are computed the same way.
acc_eval = MulticlassClassificationEvaluator(
    metricName='accuracy',
)

In [27]:
# Accuracy of the decision tree on the test-split predictions.
# The print only labels the output; the cell's value is the accuracy score.
print('DTC_Predictions')
acc_eval.evaluate(dtc_preds)

DTC_Predictions


0.9629629629629629

In [28]:
# Accuracy of the random forest on the test-split predictions.
acc_eval.evaluate(rfc_preds)

1.0

In [29]:
# Accuracy of the gradient-boosted trees on the test-split predictions.
acc_eval.evaluate(gbt_predts)

0.9629629629629629

In [30]:
# Feature importances of the fitted random forest, returned as a sparse
# vector keyed by feature index (only the listed indices carry a value).
rfc_model.featureImportances

SparseVector(692, {99: 0.0005, 101: 0.0006, 147: 0.0018, 153: 0.0004, 158: 0.0003, 160: 0.0007, 175: 0.0007, 179: 0.0015, 181: 0.0003, 182: 0.0009, 183: 0.0013, 184: 0.0005, 185: 0.0003, 204: 0.0007, 208: 0.0003, 213: 0.0004, 214: 0.0009, 215: 0.0108, 216: 0.0017, 217: 0.0031, 234: 0.0127, 235: 0.0055, 238: 0.0015, 239: 0.0003, 243: 0.0059, 244: 0.0214, 262: 0.007, 263: 0.0133, 267: 0.0008, 270: 0.0006, 271: 0.0003, 274: 0.0038, 288: 0.0006, 290: 0.0092, 291: 0.0012, 297: 0.0003, 300: 0.017, 301: 0.0028, 302: 0.0027, 303: 0.0039, 314: 0.0021, 316: 0.0006, 317: 0.0132, 318: 0.0093, 319: 0.0032, 320: 0.0006, 322: 0.0008, 323: 0.0011, 324: 0.0014, 326: 0.0019, 327: 0.0019, 329: 0.0048, 346: 0.0008, 347: 0.0015, 350: 0.0134, 351: 0.007, 353: 0.0007, 358: 0.0105, 372: 0.0058, 373: 0.0062, 374: 0.001, 378: 0.0106, 379: 0.001, 383: 0.0005, 384: 0.0035, 385: 0.0154, 399: 0.0068, 400: 0.0086, 401: 0.0059, 402: 0.0014, 403: 0.0005, 404: 0.0005, 405: 0.0208, 406: 0.0578, 407: 0.0401, 408: 0.0011,