# Decision Tree Example

## 1. Import spark modules

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import Row

In [2]:
spark = SparkSession\
        .builder\
        .appName("PythonPi")\
        .getOrCreate()

**Setup spark configuration and create a spark context**

In [3]:
sc = spark.sparkContext

## 2. Load and Inspect the data

**Load the data**

In [4]:
irisData = sc.textFile("./input/iris.csv")

**Cache the data**

In [5]:
irisData.cache()

./input/iris.csv MapPartitionsRDD[1] at textFile at <unknown>:0

**Count the data**

In [6]:
irisData.count()

151

# 3. Cleanup data

### Remove the header line

In [7]:
dataLines = irisData.filter(lambda x: "Sepal" not in x)

In [8]:
dataLines.count()

150

### Create dataframe from RDD

In [9]:
from pyspark.sql import Row

In [10]:
# split each line by the commas
parts = dataLines.map(lambda l: l.split(","))

In [11]:
# map the RDD into row format
irisMap = parts.map(lambda p: Row(SEPAL_LENGTH = float(p[0]),
                                  SEPAL_WIDTH = float(p[1]),
                                  PETAL_LENGTH = float(p[2]),
                                  PETAL_WIDTH = float(p[3]),
                                  SPECIES = p[4]))

In [12]:
# create a dataframe from the rdd
irisDf = spark.createDataFrame(irisMap)

In [13]:
# cache the datafram
irisDf.cache()

DataFrame[PETAL_LENGTH: double, PETAL_WIDTH: double, SEPAL_LENGTH: double, SEPAL_WIDTH: double, SPECIES: string]

In [14]:
irisDf.show()

+------------+-----------+------------+-----------+-------+
|PETAL_LENGTH|PETAL_WIDTH|SEPAL_LENGTH|SEPAL_WIDTH|SPECIES|
+------------+-----------+------------+-----------+-------+
|         1.4|        0.2|         5.1|        3.5| setosa|
|         1.4|        0.2|         4.9|        3.0| setosa|
|         1.3|        0.2|         4.7|        3.2| setosa|
|         1.5|        0.2|         4.6|        3.1| setosa|
|         1.4|        0.2|         5.0|        3.6| setosa|
|         1.7|        0.4|         5.4|        3.9| setosa|
|         1.4|        0.3|         4.6|        3.4| setosa|
|         1.5|        0.2|         5.0|        3.4| setosa|
|         1.4|        0.2|         4.4|        2.9| setosa|
|         1.5|        0.1|         4.9|        3.1| setosa|
|         1.5|        0.2|         5.4|        3.7| setosa|
|         1.6|        0.2|         4.8|        3.4| setosa|
|         1.4|        0.1|         4.8|        3.0| setosa|
|         1.1|        0.1|         4.3| 

### Transform the labels into numeric values

In [15]:
from pyspark.ml.feature import StringIndexer

In [16]:
stringIndexer = StringIndexer(inputCol = "SPECIES", outputCol="IND_SPECIES")

In [17]:
si_model = stringIndexer.fit(irisDf)
irisNormDf = si_model.transform(irisDf)

In [18]:
irisNormDf.select("Species", "IND_SPECIES").distinct().collect()

[Row(Species='versicolor', IND_SPECIES=0.0),
 Row(Species='setosa', IND_SPECIES=2.0),
 Row(Species='virginica', IND_SPECIES=1.0)]

In [19]:
irisNormDf.cache()

DataFrame[PETAL_LENGTH: double, PETAL_WIDTH: double, SEPAL_LENGTH: double, SEPAL_WIDTH: double, SPECIES: string, IND_SPECIES: double]

In [20]:
irisNormDf.show()

+------------+-----------+------------+-----------+-------+-----------+
|PETAL_LENGTH|PETAL_WIDTH|SEPAL_LENGTH|SEPAL_WIDTH|SPECIES|IND_SPECIES|
+------------+-----------+------------+-----------+-------+-----------+
|         1.4|        0.2|         5.1|        3.5| setosa|        2.0|
|         1.4|        0.2|         4.9|        3.0| setosa|        2.0|
|         1.3|        0.2|         4.7|        3.2| setosa|        2.0|
|         1.5|        0.2|         4.6|        3.1| setosa|        2.0|
|         1.4|        0.2|         5.0|        3.6| setosa|        2.0|
|         1.7|        0.4|         5.4|        3.9| setosa|        2.0|
|         1.4|        0.3|         4.6|        3.4| setosa|        2.0|
|         1.5|        0.2|         5.0|        3.4| setosa|        2.0|
|         1.4|        0.2|         4.4|        2.9| setosa|        2.0|
|         1.5|        0.1|         4.9|        3.1| setosa|        2.0|
|         1.5|        0.2|         5.4|        3.7| setosa|     

## 4. Perform data analytics

**Describe the data**

In [21]:
irisNormDf.describe().show()

+-------+------------------+------------------+------------------+------------------+---------+------------------+
|summary|      PETAL_LENGTH|       PETAL_WIDTH|      SEPAL_LENGTH|       SEPAL_WIDTH|  SPECIES|       IND_SPECIES|
+-------+------------------+------------------+------------------+------------------+---------+------------------+
|  count|               150|               150|               150|               150|      150|               150|
|   mean| 3.758000000000001|1.1993333333333331| 5.843333333333332|3.0573333333333337|     null|               1.0|
| stddev|1.7652982332594662|0.7622376689603467|0.8280661279778634|0.4358662849366978|     null|0.8192319205190404|
|    min|               1.0|               0.1|               4.3|               2.0|   setosa|               0.0|
|    max|               6.9|               2.5|               7.9|               4.4|virginica|               2.0|
+-------+------------------+------------------+------------------+--------------

**Correlation between the target variables and the feature variables**

In [22]:
# iterate through each column in the dataframe
for i in irisNormDf.columns:
    # if data is not an instance of string
    if not(isinstance(irisNormDf.select(i).take(1)[0][0], str)):
        print("Correlation to IND_SPECIES for", i, irisNormDf.stat.corr("IND_SPECIES", i))

Correlation to IND_SPECIES for PETAL_LENGTH -0.649241830764174
Correlation to IND_SPECIES for PETAL_WIDTH -0.5803770334306263
Correlation to IND_SPECIES for SEPAL_LENGTH -0.46003915650023686
Correlation to IND_SPECIES for SEPAL_WIDTH 0.6183715308237434
Correlation to IND_SPECIES for IND_SPECIES 1.0


## 5. Prepare data for machine learning

In [23]:
from pyspark.ml.linalg import Vectors

**A function to transform the RDD into labelled points**

In [24]:
def transformToLabeledPoint(row):
    '''a function to transform data into labeled point'''
    lp = (row["SPECIES"],
          row["IND_SPECIES"],
          Vectors.dense([row['SEPAL_LENGTH'], row["SEPAL_WIDTH"], row["PETAL_LENGTH"], row["PETAL_WIDTH"]]))
    return lp

In [25]:
irisLp = irisNormDf.rdd.map(transformToLabeledPoint)

In [26]:
irisLpDf = spark.createDataFrame(irisLp, ["Species", "label", "features"])

**Display the dataframe containing labelled points**

In [27]:
irisLpDf.select("species", "label", "features").show(10)

+-------+-----+-----------------+
|species|label|         features|
+-------+-----+-----------------+
| setosa|  2.0|[5.1,3.5,1.4,0.2]|
| setosa|  2.0|[4.9,3.0,1.4,0.2]|
| setosa|  2.0|[4.7,3.2,1.3,0.2]|
| setosa|  2.0|[4.6,3.1,1.5,0.2]|
| setosa|  2.0|[5.0,3.6,1.4,0.2]|
| setosa|  2.0|[5.4,3.9,1.7,0.4]|
| setosa|  2.0|[4.6,3.4,1.4,0.3]|
| setosa|  2.0|[5.0,3.4,1.5,0.2]|
| setosa|  2.0|[4.4,2.9,1.4,0.2]|
| setosa|  2.0|[4.9,3.1,1.5,0.1]|
+-------+-----+-----------------+
only showing top 10 rows



In [28]:
irisDf.cache()

DataFrame[PETAL_LENGTH: double, PETAL_WIDTH: double, SEPAL_LENGTH: double, SEPAL_WIDTH: double, SPECIES: string]

## 6. Train & Performance Access the Model

**Split the data train and test set**

In [29]:
(trainingData, testData) = irisLpDf.randomSplit([0.8,0.1])

In [30]:
trainingData.count()

132

In [31]:
testData.count()

18

**Create and fit the model**

In [32]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [33]:
dtClassifier = DecisionTreeClassifier(maxDepth=5, labelCol="label", featuresCol="features")
dtModel = dtClassifier.fit(trainingData)

In [34]:
dtModel.numNodes

17

In [35]:
dtModel.depth

5

**Compute predictions**

In [36]:
predictions = dtModel.transform(testData)

In [37]:
predictions.show()

+----------+-----+-----------------+--------------+-------------+----------+
|   Species|label|         features| rawPrediction|  probability|prediction|
+----------+-----+-----------------+--------------+-------------+----------+
|    setosa|  2.0|[4.3,3.0,1.1,0.1]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[4.4,3.2,1.3,0.2]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[4.7,3.2,1.6,0.2]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[4.8,3.0,1.4,0.3]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[4.9,3.6,1.4,0.1]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[5.0,3.5,1.3,0.3]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[5.1,3.5,1.4,0.3]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[5.1,3.8,1.5,0.3]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|    setosa|  2.0|[5.8,4.0,1.2,0.2]|[0.0,0.0,41.0]|[0.0,0.0,1.0]|       2.0|
|versicolor|  0.0|[5.6,2.9,3.6,1.3]|[43.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|

**Evaluate classifier**

In [38]:
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label", metricName="accuracy")

In [39]:
evaluator.evaluate(predictions)

1.0

**Confusion Matrix**

In [40]:
predictions.groupby("label", "prediction").count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|    5|
|  2.0|       2.0|    9|
|  0.0|       0.0|    4|
+-----+----------+-----+

