In [None]:
import pyspark
# Start a local SparkContext using all available cores ('local[*]').
# getOrCreate returns the already-running context when this cell is executed
# again in the same kernel, instead of failing because only one SparkContext
# may be active at a time.
sc = pyspark.SparkContext.getOrCreate(pyspark.SparkConf().setMaster('local[*]'))

In [None]:
# Remove any *.lck files from the local metastore_db directory before
# creating the SQL context.
# NOTE(review): presumably these are stale lock files left behind by a
# previous session — confirm.
!rm -rf metastore_db/*.lck

from pyspark.sql import SQLContext
# `sqlc` is not referenced again in the visible cells.
# NOTE(review): presumably created for its side effect of making RDD.toDF()
# available to the data-loading cell — confirm before removing.
sqlc = SQLContext(sc)

In [None]:
from pyspark.mllib.util import MLUtils

# Fetch the sample LIBSVM dataset only when it is not already on disk.
# A bare `wget` downloads again on every run and leaves numbered duplicates
# (sample_libsvm_data.txt.1, .2, ...) beside the notebook; it also requires
# the wget binary to be installed.
import os
import urllib.request

DATA_URL = "https://raw.githubusercontent.com/apache/spark/master/data/mllib/sample_libsvm_data.txt"
DATA_FILE = "sample_libsvm_data.txt"

if not os.path.exists(DATA_FILE):
    # Download to a temporary name first so an interrupted transfer is never
    # mistaken for a complete cached file on the next run.
    partial_file = DATA_FILE + ".part"
    urllib.request.urlretrieve(DATA_URL, partial_file)
    os.replace(partial_file, DATA_FILE)

# Load the file, turn the result into a DataFrame, and convert its vector
# columns to the pyspark.ml vector type used by the pipeline stages below.
data = MLUtils.convertVectorColumnsToML(
    MLUtils.loadLibSVMFile(sc, DATA_FILE).toDF()
)

In [None]:
# Preview the first 5 rows of the loaded DataFrame.
data.show(5)

## Decision Trees

### Classification

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, IndexToString, VectorIndexer
from pyspark.ml.classification import DecisionTreeClassifier

# Index the raw "label" column into "indexedLabel". The indexer is fitted on
# the full dataset (before the split) so every label value gets an index.
labelIndexer = StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)

# Map the numeric "prediction" column back to the original label strings,
# using the label ordering learned by labelIndexer.
labelConverter = IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

# Index the feature vector into "indexedFeatures"; maxCategories=4 is the
# cut-off VectorIndexer uses to decide which features are categorical.
featureIndexer = VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data)

# Decision-tree classifier operating on the indexed label and features.
dtC = DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

# Stage order matters: later cells read the fitted tree from stages[2].
pipelineClass = Pipeline().setStages([labelIndexer, featureIndexer, dtC, labelConverter])

# 70/30 train/test split. The fixed seed makes the split — and therefore the
# fitted trees and predictions shown below — identical across
# Restart & Run All; without it every run produces a different split.
trainingData, testData = data.randomSplit([0.7, 0.3], seed=42)

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel

# Fit every stage of the classification pipeline on the training split.
modelClassifier = pipelineClass.fit(trainingData)

# Score the held-out split with the fitted pipeline.
predictionsClass = modelClassifier.transform(testData)

# The pipeline stages are [labelIndexer, featureIndexer, dtC, labelConverter],
# so index 2 is the fitted decision-tree model.
treeModel = modelClassifier.stages[2]

In [None]:
# List the fitted stages of the classification pipeline model.
modelClassifier.stages

In [None]:
# Print the debug-string description of `treeModel`; print() is used so the
# multi-line string renders as text rather than as an escaped repr.
print(treeModel.toDebugString)

In [None]:
# Preview 5 prediction rows as a pandas DataFrame. Limiting in Spark first
# means only those rows are collected to the driver, instead of converting
# the entire prediction set to pandas and then slicing it.
predictionsClass.limit(5).toPandas()

### Regression

In [None]:
from pyspark.ml.regression import DecisionTreeRegressor, DecisionTreeRegressionModel

# Decision-tree regressor fitted against the raw "label" column, using the
# indexed feature vector produced by featureIndexer.
dtR = DecisionTreeRegressor(labelCol="label", featuresCol="indexedFeatures")

# Two-stage pipeline: feature indexing followed by the regressor.
pipelineReg = Pipeline(stages=[featureIndexer, dtR])

In [None]:
# Fit the regression pipeline on the training split.
modelRegressor = pipelineReg.fit(trainingData)

# Stage 1 is the fitted regression tree (stage 0 is featureIndexer).
# It gets its own name so the classification `treeModel` defined earlier is
# not overwritten; otherwise re-running the classification tree printout
# after this cell would silently show the regression tree instead.
treeModelReg = modelRegressor.stages[1]

print(treeModelReg.toDebugString)

In [None]:
# Apply the fitted regression pipeline to the held-out split.
predictionsReg = modelRegressor.transform(testData)

In [None]:
# Preview 5 regression prediction rows as a pandas DataFrame. Limiting in
# Spark first means only those rows are collected to the driver, instead of
# converting the entire prediction set to pandas and then slicing it.
predictionsReg.limit(5).toPandas()

In [None]:
# Shut down the SparkContext created in the first cell.
sc.stop()