In [1]:
import findspark
findspark.init('/home/fede/spark-2.1.0-bin-hadoop2.7')

In [3]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('mytree').getOrCreate()

In [4]:
from pyspark.ml import Pipeline

In [5]:
from pyspark.ml.classification import (RandomForestClassifier,
                                       GBTClassifier,
                                       DecisionTreeClassifier)

In [6]:
file_path = '/home/fede/sample_libsvm_data.txt'
data = spark.read.format('libsvm').load(file_path)

In [7]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [8]:
train_data,test_data = data.randomSplit([0.7,0.3])

In [10]:
dtc = DecisionTreeClassifier(featuresCol='features',labelCol='label')
rfc = RandomForestClassifier(featuresCol='features',labelCol='label',numTrees=100)
gbt = GBTClassifier(featuresCol='features',labelCol='label')

In [11]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [12]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [13]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[121,122,123...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [14]:
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[121,122,123...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[123,124,125...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[124,125,126...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|  [79.0,21.0]|[0.79,0.21]|       0.0|
|  0.0|(692,[124,125,126...|   [92.0,8.0]|[0.92,0.08]|       0.0|
|  0.0|(692,[125,126,127...|   [92.0,8.0]|[0.92,0.08]|       0.0|
|  0.0|(692,[126,127,128...|   [92.0,8.0]|[0.92,0.08]|       0.0|
|  0.0|(692,[126,127,128...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[126,127,128...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[126,127,128...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(69

In [15]:
gbt_preds.show()

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  0.0|(692,[95,96,97,12...|       0.0|
|  0.0|(692,[121,122,123...|       0.0|
|  0.0|(692,[123,124,125...|       0.0|
|  0.0|(692,[124,125,126...|       0.0|
|  0.0|(692,[124,125,126...|       0.0|
|  0.0|(692,[124,125,126...|       0.0|
|  0.0|(692,[125,126,127...|       0.0|
|  0.0|(692,[126,127,128...|       0.0|
|  0.0|(692,[126,127,128...|       0.0|
|  0.0|(692,[126,127,128...|       0.0|
|  0.0|(692,[126,127,128...|       0.0|
|  0.0|(692,[126,127,128...|       0.0|
|  0.0|(692,[127,128,129...|       0.0|
|  0.0|(692,[129,130,131...|       0.0|
|  0.0|(692,[150,151,152...|       0.0|
|  0.0|(692,[152,153,154...|       0.0|
|  0.0|(692,[153,154,155...|       0.0|
|  0.0|(692,[154,155,156...|       0.0|
|  1.0|(692,[100,101,102...|       1.0|
|  1.0|(692,[119,120,121...|       1.0|
+-----+--------------------+----------+
only showing top 20 rows



In [16]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [17]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [18]:
print('DTC ACCURACY:')
acc_eval.evaluate(dtc_preds)

DTC ACCURACY:


1.0

In [19]:
print('RFC ACCURACY:')
acc_eval.evaluate(rfc_preds)

RFC ACCURACY:


1.0

In [20]:
print('GBT ACCURACY:')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:


1.0

In [21]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0015, 131: 0.0007, 148: 0.0007, 157: 0.0001, 161: 0.0006, 181: 0.0014, 203: 0.0005, 209: 0.0002, 215: 0.0008, 233: 0.0009, 234: 0.0017, 235: 0.0074, 243: 0.0011, 261: 0.0011, 262: 0.0052, 263: 0.0284, 268: 0.0003, 270: 0.0006, 272: 0.0004, 273: 0.0361, 287: 0.0005, 289: 0.0014, 290: 0.0368, 291: 0.0018, 302: 0.0042, 303: 0.0006, 322: 0.0005, 323: 0.003, 324: 0.0026, 328: 0.0018, 329: 0.0156, 330: 0.0091, 331: 0.0016, 344: 0.006, 348: 0.0006, 350: 0.0176, 351: 0.0042, 355: 0.0033, 356: 0.0093, 358: 0.0091, 359: 0.0001, 360: 0.0012, 369: 0.0011, 370: 0.0008, 371: 0.0074, 372: 0.0134, 374: 0.0006, 375: 0.0007, 378: 0.0233, 380: 0.0031, 382: 0.0012, 383: 0.0011, 386: 0.0088, 387: 0.0098, 388: 0.0015, 399: 0.001, 402: 0.006, 403: 0.0034, 405: 0.0283, 406: 0.0657, 407: 0.027, 408: 0.0011, 409: 0.0004, 412: 0.0123, 415: 0.0006, 427: 0.0007, 428: 0.0291, 429: 0.0012, 432: 0.0004, 433: 0.028, 434: 0.0315, 435: 0.0168, 438: 0.0025, 440: 0.0085, 455: 0.0078, 456: 0.0169,