In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
ss = SparkSession.builder.getOrCreate()
sc = ss.sparkContext

# Create Data Frame

In [63]:
#Load the data and create an RDD (16 pixels and label)
pen_raw = sc.textFile("../Data/penbased.dat", 4)\
            .map(lambda x:  x.split(", "))\
            .map(lambda row: [float(x) for x in row])
pen_raw.take(1)

[[47.0,
  100.0,
  27.0,
  81.0,
  57.0,
  37.0,
  26.0,
  0.0,
  0.0,
  23.0,
  56.0,
  53.0,
  100.0,
  90.0,
  40.0,
  98.0,
  8.0]]

In [12]:
from pyspark.sql.types import *

penschema = StructType([
    StructField("pix1",DoubleType(),True),
    StructField("pix2",DoubleType(),True),
    StructField("pix3",DoubleType(),True),
    StructField("pix4",DoubleType(),True),
    StructField("pix5",DoubleType(),True),
    StructField("pix6",DoubleType(),True),
    StructField("pix7",DoubleType(),True),
    StructField("pix8",DoubleType(),True),
    StructField("pix9",DoubleType(),True),
    StructField("pix10",DoubleType(),True),
    StructField("pix11",DoubleType(),True),
    StructField("pix12",DoubleType(),True),
    StructField("pix13",DoubleType(),True),
    StructField("pix14",DoubleType(),True),
    StructField("pix15",DoubleType(),True),
    StructField("pix16",DoubleType(),True),
    StructField("label",DoubleType(),True)
])

dfpen = ss.createDataFrame(pen_raw, penschema)

In [14]:
dfpen.show(5)

+----+-----+----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+
|pix1| pix2|pix3| pix4| pix5| pix6| pix7| pix8|pix9|pix10|pix11|pix12|pix13|pix14|pix15|pix16|label|
+----+-----+----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+
|47.0|100.0|27.0| 81.0| 57.0| 37.0| 26.0|  0.0| 0.0| 23.0| 56.0| 53.0|100.0| 90.0| 40.0| 98.0|  8.0|
| 0.0| 89.0|27.0|100.0| 42.0| 75.0| 29.0| 45.0|15.0| 15.0| 37.0|  0.0| 69.0|  2.0|100.0|  6.0|  2.0|
| 0.0| 57.0|31.0| 68.0| 72.0| 90.0|100.0|100.0|76.0| 75.0| 50.0| 51.0| 28.0| 25.0| 16.0|  0.0|  1.0|
| 0.0|100.0| 7.0| 92.0|  5.0| 68.0| 19.0| 45.0|86.0| 34.0|100.0| 45.0| 74.0| 23.0| 67.0|  0.0|  4.0|
| 0.0| 67.0|49.0| 83.0|100.0|100.0| 81.0| 80.0|60.0| 60.0| 40.0| 40.0| 33.0| 20.0| 47.0|  0.0|  1.0|
+----+-----+----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+
only showing top 5 rows



# Vectorize and Create a data frame includes "feature" and "label"

In [32]:
from pyspark.ml.feature import VectorAssembler

va = VectorAssembler(inputCols=dfpen.columns[:-1], outputCol='features')

penlpoints = va.transform(dfpen).select('features', 'label')


In [33]:
penlpoints.show(5)

+--------------------+-----+
|            features|label|
+--------------------+-----+
|[47.0,100.0,27.0,...|  8.0|
|[0.0,89.0,27.0,10...|  2.0|
|[0.0,57.0,31.0,68...|  1.0|
|[0.0,100.0,7.0,92...|  4.0|
|[0.0,67.0,49.0,83...|  1.0|
+--------------------+-----+
only showing top 5 rows



# Create Training and Validation data

In [34]:
pendtset = penlpoints.randomSplit([0.8, 0.2], 1)
pentrain = pendtset[0].cache()
pendvalid = pendtset[1].cache()

# Train the Decision tree model 

In [35]:
from pyspark.ml.classification import DecisionTreeClassifier

dt = DecisionTreeClassifier(maxDepth=20, maxBins=32, minInfoGain=0, minInstancesPerNode=1)
dtmodel = dt.fit(pentrain)

In [38]:
# import dtreeviz
# dtmodel.dtreeviz

In [39]:
dtpredicts = dtmodel.transform(pendvalid)

In [40]:
dtpredicts.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[0.0,0.0,41.0,16....|  9.0|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       9.0|
|[0.0,20.0,47.0,42...|  1.0|[0.0,378.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[0.0,22.0,36.0,47...|  1.0|[0.0,378.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[0.0,23.0,63.0,46...|  8.0|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       8.0|
|[0.0,26.0,57.0,56...|  8.0|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       8.0|
|[0.0,39.0,2.0,62....|  0.0|[1.0,0.0,0.0,0.0,...|[1.0,0.0,0.0,0.0,...|       0.0|
|[0.0,39.0,42.0,52...|  1.0|[0.0,378.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[0.0,40.0,29.0,56...|  1.0|[0.0,378.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[0.0,43.0,26.0,65...|  1.0|[0.0,378.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[0.0,43.0,35.0,

# Evalaute the Model 

In [42]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

metric = 'f1'
metrics = MulticlassClassificationEvaluator().setLabelCol('label').setPredictionCol('prediction').setMetricName('f1')

In [45]:
metrics.evaluate(dtpredicts)

0.9585890253909967

# N-fold validation

In [50]:
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import ParamGridBuilder 
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator().setLabelCol('label').setPredictionCol('prediction')
dt = DecisionTreeClassifier()
paramGrid = ParamGridBuilder().addGrid(dt.maxDepth, [5,10,15,20,25,30]).build()

cv = CrossValidator(estimator=dt,
                   evaluator=evaluator,
                   numFolds=5,
                   estimatorParamMaps=paramGrid)

cvmodel = cv.fit(pentrain)


In [51]:
# then predict 

In [53]:
dtpredict = cvmodel.bestModel.transform(pendvalid)

print("Best Max Depth : %s" % cvmodel.bestModel.getMaxDepth)
print("Accuracy : %s" % evaluator.evaluate(dtpredicts))

Best Max Depth : <bound method _DecisionTreeParams.getMaxDepth of DecisionTreeClassificationModel: uid=DecisionTreeClassifier_d37adc370e0c, depth=15, numNodes=569, numClasses=10, numFeatures=16>
Accuracy : 0.9585890253909967


In [59]:
from pyspark.mllib.evaluation import MulticlassMetrics

prediciton_label = dtpredicts.select('prediction', 'label').rdd

metrics = MulticlassMetrics(prediciton_label)

confusionMetrics = metrics.confusionMatrix()

In [60]:
print(confusionMetrics)

DenseMatrix([[223.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   4.,   0.],
             [  0., 197.,  12.,   3.,   0.,   0.,   2.,   0.,   0.,   0.],
             [  0.,   4., 200.,   0.,   1.,   0.,   1.,   1.,   0.,   0.],
             [  1.,   0.,   1., 185.,   0.,   0.,   0.,   1.,   0.,   3.],
             [  0.,   1.,   0.,   0., 216.,   0.,   1.,   0.,   0.,   2.],
             [  0.,   0.,   0.,   3.,   1., 176.,   2.,   2.,   2.,   9.],
             [  0.,   1.,   0.,   0.,   3.,   1., 192.,   1.,   0.,   0.],
             [  0.,   1.,   3.,   1.,   0.,   2.,   0., 189.,   0.,   0.],
             [  4.,   0.,   0.,   0.,   0.,   0.,   0.,   1., 169.,   0.],
             [  0.,   0.,   0.,   3.,   0.,   3.,   0.,   1.,   0., 178.]])
