### Spark MLlib - Decision Tree

**Description**

- Easy to understand and explain.
- Predictor variables are used to build a tree of decision rules that progressively narrows down the target value.
- Training data is used to build the tree: each split is chosen from the training examples.
- The fitted tree is the model, which is then used to make predictions on new data.

**Pros:** Easy to understand and explain, works with missing values, and is fast.

**Cons:** Limited accuracy, prone to bias, and does not work well with many predictor variables.

**Application:** Credit approval, preliminary categorization.
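
To make the idea of building a tree from predictor variables more concrete, here is a minimal pure-Python sketch of Gini impurity, which is the default impurity measure of Spark's `DecisionTreeClassifier`. A candidate split is better when it lowers the weighted impurity of the two resulting groups. The helper names `gini` and `split_impurity` and the toy rows are assumptions made for illustration; they are not part of the notebook or of the Spark API.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    total = len(labels)
    if total == 0:
        return 0.0
    return 1.0 - sum((n / total) ** 2 for n in Counter(labels).values())

def split_impurity(rows, feature, threshold):
    """Weighted Gini impurity after splitting rows on features[feature] <= threshold."""
    left = [label for features, label in rows if features[feature] <= threshold]
    right = [label for features, label in rows if features[feature] > threshold]
    total = len(rows)
    return (len(left) / total) * gini(left) + (len(right) / total) * gini(right)

# Toy rows: ((petal_length, petal_width), species). Values are made up for illustration.
toy_rows = [
    ((1.4, 0.2), 'setosa'),
    ((1.5, 0.2), 'setosa'),
    ((4.5, 1.5), 'versicolor'),
    ((4.7, 1.4), 'versicolor'),
    ((5.9, 2.1), 'virginica'),
    ((6.0, 2.5), 'virginica'),
]

impurity_before = gini([label for _, label in toy_rows])
impurity_after = split_impurity(toy_rows, feature=0, threshold=2.5)
print(impurity_before, impurity_after)
```

In this toy case, splitting on the first feature isolates one class completely, so the weighted impurity drops; a tree builder would repeat the same search inside each remaining group.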

### Classifying Iris Dataset Flower Species

In [1]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.ml.linalg import Vectors
from pyspark.sql import Row, SparkSession

In [2]:
spSession = SparkSession.builder.master('local').appName('IrisPrediction').getOrCreate()

In [3]:
rddIris01 = spSession.sparkContext.textFile('aux/datasets/iris.csv')

**We can cache the RDD to optimize performance.**

In [4]:
rddIris01.cache()

aux/datasets/iris.csv MapPartitionsRDD[1] at textFile at NativeMethodAccessorImpl.java:0

In [5]:
rddIris01.count()

151

In [6]:
rddIris01.take(5)

['Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species',
 '5.1,3.5,1.4,0.2,setosa',
 '4.9,3,1.4,0.2,setosa',
 '4.7,3.2,1.3,0.2,setosa',
 '4.6,3.1,1.5,0.2,setosa']

In [7]:
header = rddIris01.first()

In [8]:
rddIris02 = rddIris01.filter(lambda row: row != header)

In [9]:
rddIris02.count()

150

### Data Cleaning

In [10]:
def dataCleaning(strRow):
    listAttr = strRow.split(',')
    
    row = Row(
        SEPAL_LENGTH = float(listAttr[0]),
        SEPAL_WIDTH  = float(listAttr[1]),
        PETAL_LENGTH = float(listAttr[2]),
        PETAL_WIDTH  = float(listAttr[3]),
        SPECIE       = listAttr[4]
    )
    
    return row

In [11]:
rddIris03 = rddIris02.map(dataCleaning)

In [12]:
rddIris03.take(5)

[Row(SEPAL_LENGTH=5.1, SEPAL_WIDTH=3.5, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.9, SEPAL_WIDTH=3.0, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.7, SEPAL_WIDTH=3.2, PETAL_LENGTH=1.3, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.6, SEPAL_WIDTH=3.1, PETAL_LENGTH=1.5, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=5.0, SEPAL_WIDTH=3.6, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa')]

**Converting the RDD to a DataFrame**

In [13]:
dfIris = spSession.createDataFrame(rddIris03)

In [14]:
dfIris.cache()

DataFrame[SEPAL_LENGTH: double, SEPAL_WIDTH: double, PETAL_LENGTH: double, PETAL_WIDTH: double, SPECIE: string]

In [15]:
dfIris.take(5)

[Row(SEPAL_LENGTH=5.1, SEPAL_WIDTH=3.5, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.9, SEPAL_WIDTH=3.0, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.7, SEPAL_WIDTH=3.2, PETAL_LENGTH=1.3, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=4.6, SEPAL_WIDTH=3.1, PETAL_LENGTH=1.5, PETAL_WIDTH=0.2, SPECIE='setosa'),
 Row(SEPAL_LENGTH=5.0, SEPAL_WIDTH=3.6, PETAL_LENGTH=1.4, PETAL_WIDTH=0.2, SPECIE='setosa')]

**Creating a numeric index for the target (label) column**

In [16]:
stringIndexer = StringIndexer(inputCol = 'SPECIE', outputCol = 'IDX_SPECIE')

In [17]:
stringIndexerModel = stringIndexer.fit(dfIris)

In [18]:
dfIris = stringIndexerModel.transform(dfIris)

In [19]:
dfIris.select('SPECIE', 'IDX_SPECIE').distinct().collect()

[Row(SPECIE='setosa', IDX_SPECIE=0.0),
 Row(SPECIE='virginica', IDX_SPECIE=2.0),
 Row(SPECIE='versicolor', IDX_SPECIE=1.0)]
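
The mapping above is consistent with `StringIndexer`'s default ordering (`stringOrderType='frequencyDesc'`), where the most frequent string receives index 0.0. The sketch below imitates that ordering in plain Python. It assumes 50 rows per species, as in the standard Iris dataset, and it assumes ties are broken alphabetically; `index_labels` is a hypothetical helper, not a Spark function.

```python
from collections import Counter

def index_labels(values):
    """Map each distinct string to a float index: most frequent first, ties alphabetical."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(position) for position, label in enumerate(ordered)}

# Assumed class balance: 50 rows per species.
species = ['setosa'] * 50 + ['versicolor'] * 50 + ['virginica'] * 50
mapping = index_labels(species)
print(mapping)
```

Because the index only reflects frequency and name order, the numbers 0.0, 1.0 and 2.0 are identifiers, not quantities with a meaningful order.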

### Exploratory Data Analysis

In [20]:
dfIris.describe().show()

+-------+------------------+------------------+------------------+------------------+---------+------------------+
|summary|      SEPAL_LENGTH|       SEPAL_WIDTH|      PETAL_LENGTH|       PETAL_WIDTH|   SPECIE|        IDX_SPECIE|
+-------+------------------+------------------+------------------+------------------+---------+------------------+
|  count|               150|               150|               150|               150|      150|               150|
|   mean| 5.843333333333332|3.0573333333333337| 3.758000000000001|1.1993333333333331|     null|               1.0|
| stddev|0.8280661279778634|0.4358662849366978|1.7652982332594662|0.7622376689603467|     null|0.8192319205190406|
|    min|               4.3|               2.0|               1.0|               0.1|   setosa|               0.0|
|    max|               7.9|               4.4|               6.9|               2.5|virginica|               2.0|
+-------+------------------+------------------+------------------+------------------+---------+------------------+

In [21]:
for column in dfIris.columns:
    if not isinstance(dfIris.select(column).take(1)[0][0], str):
        print(f"IDX_SPECIE correlation with {column}: {dfIris.stat.corr('IDX_SPECIE', column)}")

IDX_SPECIE correlation with SEPAL_LENGTH: 0.7825612318100814
IDX_SPECIE correlation with SEPAL_WIDTH: -0.4266575607811232
IDX_SPECIE correlation with PETAL_LENGTH: 0.9490346990083887
IDX_SPECIE correlation with PETAL_WIDTH: 0.9565473328764027
IDX_SPECIE correlation with IDX_SPECIE: 1.0
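
`dfIris.stat.corr` computes the Pearson correlation coefficient. The sketch below reproduces the formula in plain Python and also illustrates a caveat: since `IDX_SPECIE` is an arbitrary index of a nominal category, the coefficient depends on which index each species happens to receive. The function `pearson` and the sample values are illustrative assumptions, not data from the notebook.

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length numeric sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# Made-up petal lengths for six flowers, two per class.
petal_length = [1.4, 1.5, 4.5, 4.7, 5.9, 6.0]

# Same classes, two different index assignments.
index_a = [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
index_b = [1.0, 1.0, 0.0, 0.0, 2.0, 2.0]

corr_a = pearson(petal_length, index_a)
corr_b = pearson(petal_length, index_b)
print(corr_a, corr_b)
```

The first assignment happens to follow the size of the petals and yields a high coefficient; the second describes the same classes but yields a much lower one. The values printed by the notebook should therefore be read as a rough screening aid only.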


### Data Pre-Processing

**Creating a LabeledPoint (target, Vector[features])**<br />
Columns that are not relevant to the model (or that correlate weakly with the target) would be dropped at this step; here all four measurements are kept as features.

In [22]:
def setLabeledPoint(row):
    labeledPoint = (
        row['SPECIE'],
        row['IDX_SPECIE'], 
        Vectors.dense([
            row['SEPAL_LENGTH'], 
            row['SEPAL_WIDTH'], 
            row['PETAL_LENGTH'],
            row['PETAL_WIDTH']
        ])
    )
    
    return labeledPoint

In [23]:
rddIris04 = dfIris.rdd.map(setLabeledPoint)

In [24]:
rddIris04.take(5)

[('setosa', 0.0, DenseVector([5.1, 3.5, 1.4, 0.2])),
 ('setosa', 0.0, DenseVector([4.9, 3.0, 1.4, 0.2])),
 ('setosa', 0.0, DenseVector([4.7, 3.2, 1.3, 0.2])),
 ('setosa', 0.0, DenseVector([4.6, 3.1, 1.5, 0.2])),
 ('setosa', 0.0, DenseVector([5.0, 3.6, 1.4, 0.2]))]

In [25]:
dfIris = spSession.createDataFrame(rddIris04, ['specie', 'label', 'features'])

In [26]:
dfIris.select('specie', 'label', 'features').show(5)

+------+-----+-----------------+
|specie|label|         features|
+------+-----+-----------------+
|setosa|  0.0|[5.1,3.5,1.4,0.2]|
|setosa|  0.0|[4.9,3.0,1.4,0.2]|
|setosa|  0.0|[4.7,3.2,1.3,0.2]|
|setosa|  0.0|[4.6,3.1,1.5,0.2]|
|setosa|  0.0|[5.0,3.6,1.4,0.2]|
+------+-----+-----------------+
only showing top 5 rows



### Machine Learning

In [27]:
(dataTraining, dataTest) = dfIris.randomSplit([.7, .3])

In [28]:
dataTraining.count()

104

In [29]:
dataTest.count()

46

In [30]:
dataTraining.count() + dataTest.count() == dfIris.count()

True
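
The split above gives 104 and 46 rows rather than exactly 105 and 45. One way to understand this is that each row is assigned to a partition by an independent random draw, so the weights are expected proportions, not exact ones. The plain-Python sketch below illustrates that idea and shows how a fixed seed makes the split repeatable (`randomSplit` also accepts a `seed` argument). `bernoulli_split` is a hypothetical helper and a simplification, not Spark's actual implementation.

```python
import random

def bernoulli_split(rows, train_fraction, seed):
    """Assign each row to train or test with an independent random draw."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        if rng.random() < train_fraction:
            train.append(row)
        else:
            test.append(row)
    return train, test

rows = list(range(150))  # stand-ins for the 150 iris rows

train_a, test_a = bernoulli_split(rows, 0.7, seed=42)
train_b, test_b = bernoulli_split(rows, 0.7, seed=42)

# Sizes are close to 70/30 but not guaranteed to be exact.
print(len(train_a), len(test_a))
# The same seed reproduces the same partition.
print(train_a == train_b)
```

Without a seed, rerunning the notebook generally produces different training and test sets, and therefore different counts and accuracy values than the ones shown here.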

In [31]:
decisionTreeClassifier = DecisionTreeClassifier(maxDepth = 2, labelCol = 'label', featuresCol = 'features')

In [32]:
model = decisionTreeClassifier.fit(dataTraining)

In [33]:
model

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_be652aa9570b, depth=2, numNodes=5, numClasses=3, numFeatures=4

In [34]:
print(f'Nodes number: {model.numNodes}')
print(f'Depth: {model.depth}')

Nodes number: 5
Depth: 2
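
A tree with depth 2 and 5 nodes has two decision nodes and three leaves, one per class in the simplest case. The sketch below builds such a tree as nested dictionaries and checks both numbers. The features and thresholds are hypothetical and are not read from the fitted model; the actual rules can be inspected with `model.toDebugString`.

```python
# Hypothetical tree: features and thresholds are illustrative only.
tree = {
    'feature': 2, 'threshold': 2.5,          # index 2 = petal length
    'left': {'prediction': 0.0},
    'right': {
        'feature': 3, 'threshold': 1.7,      # index 3 = petal width
        'left': {'prediction': 1.0},
        'right': {'prediction': 2.0},
    },
}

def num_nodes(node):
    """Count decision nodes and leaves."""
    if 'prediction' in node:
        return 1
    return 1 + num_nodes(node['left']) + num_nodes(node['right'])

def depth(node):
    """Number of edges on the longest path from this node to a leaf."""
    if 'prediction' in node:
        return 0
    return 1 + max(depth(node['left']), depth(node['right']))

def predict(node, features):
    """Follow the decision rules down to a leaf and return its class index."""
    while 'prediction' not in node:
        go_left = features[node['feature']] <= node['threshold']
        node = node['left'] if go_left else node['right']
    return node['prediction']

print(num_nodes(tree), depth(tree))
print(predict(tree, [5.1, 3.5, 1.4, 0.2]))
```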


In [35]:
predictions = model.transform(dataTest)

In [36]:
predictions.select('specie', 'features', 'prediction').show(5)

+------+-----------------+----------+
|specie|         features|prediction|
+------+-----------------+----------+
|setosa|[4.3,3.0,1.1,0.1]|       0.0|
|setosa|[4.6,3.2,1.4,0.2]|       0.0|
|setosa|[4.7,3.2,1.6,0.2]|       0.0|
|setosa|[4.8,3.0,1.4,0.3]|       0.0|
|setosa|[4.8,3.4,1.6,0.2]|       0.0|
+------+-----------------+----------+
only showing top 5 rows



In [37]:
evaluator = MulticlassClassificationEvaluator(
    predictionCol = 'prediction', 
    labelCol = 'label', 
    metricName = 'accuracy')

In [38]:
evaluator.evaluate(predictions)

0.9347826086956522

**Confusion Matrix - Summing Up Predictions**

In [39]:
predictions.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|   11|
|  2.0|       2.0|   14|
|  2.0|       1.0|    2|
|  1.0|       2.0|    1|
|  0.0|       0.0|   18|
+-----+----------+-----+
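
The accuracy reported by the evaluator can be recomputed from these counts: correct predictions are the rows where `label` equals `prediction`. The plain-Python sketch below does that and also derives per-class recall and precision. The counts are copied from the table above; the helper names `recall` and `precision` are illustrative.

```python
# (label, prediction) -> count, copied from the grouped output above.
counts = {
    (1.0, 1.0): 11,
    (2.0, 2.0): 14,
    (2.0, 1.0): 2,
    (1.0, 2.0): 1,
    (0.0, 0.0): 18,
}

total = sum(counts.values())
correct = sum(n for (label, pred), n in counts.items() if label == pred)
accuracy = correct / total

def recall(cls):
    """Share of rows with true class cls that were predicted as cls."""
    actual = sum(n for (label, _), n in counts.items() if label == cls)
    return counts.get((cls, cls), 0) / actual

def precision(cls):
    """Share of rows predicted as cls whose true class is cls."""
    predicted = sum(n for (_, pred), n in counts.items() if pred == cls)
    return counts.get((cls, cls), 0) / predicted

print(total, correct, accuracy)
for cls in (0.0, 1.0, 2.0):
    print(cls, recall(cls), precision(cls))
```

All errors are confusions between classes 1.0 and 2.0 (versicolor and virginica); class 0.0 (setosa) is classified without error in this test set.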

