In [10]:
## Import Libraries
from pyspark.sql import SparkSession
from pyspark.ml import pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

## Set seed
seed = 42

In [11]:
## Create Spark Session
spark = SparkSession.builder.appName('dtRfExample').getOrCreate()

In [12]:
## Load Data
df = spark.read.format('libsvm').load('gs://spark-training-data/datasets/sample_libsvm_data.txt')
df.show(5)

21/11/30 21:53:56 WARN org.apache.spark.ml.source.libsvm.LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.


+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 5 rows



In [13]:
## Split into train, test
train_data, test_data = df.randomSplit([0.7,0.3], seed=seed)

In [14]:
## Create model(s) instance and fit
dtc = DecisionTreeClassifier(labelCol='label', featuresCol='features') # Check out max_depth, max_trees, and info_gain
dtc_model = dtc.fit(train_data)

rfc = RandomForestClassifier(labelCol='label', featuresCol='features')
rfc_model = rfc.fit(train_data)

gbt = GBTClassifier(labelCol='label', featuresCol='features')
gbt_model = gbt.fit(train_data)

In [15]:
## Make Predictions
dtc_preds = dtc_model.transform(test_data)
dtc_preds.show(5)

rfc_preds = rfc_model.transform(test_data)
rfc_preds.show(5)

gbt_preds = gbt_model.transform(test_data)
gbt_preds.show(5)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [24.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [24.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [24.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [24.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [24.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 5 rows

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [15.0,5.0]|[0.75,0.25]|       0.0|
|  0.0|(692,[123,124,125...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [20.0,0.0]|  [1.0,0.

In [19]:
## Evaluate the model using test data
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

print(f'DTC Accuracy: {acc_eval.evaluate(dtc_preds)}')
print(f'RFC Accuracy: {acc_eval.evaluate(rfc_preds)}')
print(f'GBT Accuracy: {acc_eval.evaluate(gbt_preds)}')

DTC Accuracy: 0.9428571428571428
RFC Accuracy: 1.0
GBT Accuracy: 0.9428571428571428


In [22]:
## Look into RandomForestModel
rfc_model.featureImportances # Higher the number, the more important - See documentation

SparseVector(692, {101: 0.0079, 233: 0.0013, 238: 0.0022, 242: 0.0076, 245: 0.0338, 260: 0.0028, 264: 0.0162, 271: 0.0019, 272: 0.0467, 290: 0.036, 301: 0.0465, 350: 0.05, 356: 0.0059, 378: 0.0963, 379: 0.051, 405: 0.0438, 407: 0.0424, 414: 0.0127, 425: 0.0035, 434: 0.1438, 454: 0.0033, 455: 0.0441, 462: 0.05, 466: 0.01, 489: 0.0928, 497: 0.0067, 517: 0.0544, 541: 0.0037, 549: 0.0063, 551: 0.0344, 569: 0.001, 606: 0.0396, 629: 0.0012})