# Imports

In [17]:
import os

import findspark

# Locate the Spark installation before pyspark is imported.
# SPARK_HOME (if set) wins so the notebook runs on other machines;
# 'C:/spark' is kept as the fallback for the original author's setup.
SPARK_HOME = os.environ.get('SPARK_HOME', 'C:/spark')
findspark.init(SPARK_HOME)

In [32]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Create a session

In [33]:
# Reuse an existing session if one is already running in this kernel;
# otherwise start a single-machine ('local') session named 'treemethods'.
spark = (
    SparkSession.builder
    .master('local')
    .appName('treemethods')
    .getOrCreate()
)

# Read data

In [34]:
# LIBSVM text file -> DataFrame with a `label` column and a sparse `features`
# vector column (see the preview below).
LIBSVM_PATH = '../../data/sample_libsvm_data.txt'
data = spark.read.load(LIBSVM_PATH, format='libsvm')

In [35]:
# Preview the first 20 rows; long feature vectors are truncated for display.
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



# Split data

In [36]:
# 70/30 train/test split. The seed is fixed so the split, and every metric
# computed from it, is the same on each Restart & Run All; without a seed
# randomSplit draws a different split per run.
SPLIT_SEED = 42
train_data, test_data = data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

# Modeling

In [37]:
# Three tree-based classifiers to compare on the same split.
# Each is seeded so refits are reproducible: RandomForest draws random row and
# feature subsets per tree, and the other two also expose a `seed` param.
MODEL_SEED = 42
dtc = DecisionTreeClassifier(seed=MODEL_SEED)
rfc = RandomForestClassifier(numTrees=100, seed=MODEL_SEED)
# NOTE: GBTClassifier supports binary labels only (labels here are 0.0 / 1.0).
gbt = GBTClassifier(seed=MODEL_SEED)

In [38]:
# Train every estimator on the same training split, in declaration order.
dtc_model, rfc_model, gbt_model = [
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
]

In [39]:
# Score the held-out split with each fitted model; each result is test_data
# with the model's prediction columns appended.
dtc_preds, rfc_preds, gbt_preds = [
    fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)
]

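# Model evaluation

Accuracy is measured on the 30% held-out test split produced by `randomSplit`. In the recorded run all three models score 1.0. A perfect score on a single random split is weak evidence of generalization, so treat these numbers as a smoke test rather than a ranking of the models.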

In [45]:
# Accuracy evaluator. 'label' and 'prediction' are Spark's default column
# names; they are spelled out so the compared columns are visible.
acc_eval = MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='prediction',
    metricName='accuracy',
)

In [46]:
# Decision tree: share of test rows whose prediction equals the label.
print('DTC Accuracy:')
dtc_accuracy = acc_eval.evaluate(dtc_preds)
dtc_accuracy

DTC Accuracy:


1.0

In [47]:
# Random forest: share of test rows whose prediction equals the label.
print('RFC Accuracy:')
rfc_accuracy = acc_eval.evaluate(rfc_preds)
rfc_accuracy

RFC Accuracy:


1.0

In [48]:
# Gradient-boosted trees: share of test rows whose prediction equals the label.
print('GBT Accuracy:')
gbt_accuracy = acc_eval.evaluate(gbt_preds)
gbt_accuracy

GBT Accuracy:


1.0