In [4]:
# Imports
from pyspark.sql import Row
from pyspark.ml.feature import StringIndexer
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [5]:
# Spark Session
sparkSession = SparkSession.builder.master("local").appName("irisClassificationDecisionTree").getOrCreate()

In [14]:
irisRDD = sc.textFile("iris.csv")
irisRDD.cache()
irisRDD.count()

151

In [15]:
irisRDD.take(5)

['Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species',
 '1,5.1,3.5,1.4,0.2,Iris-setosa',
 '2,4.9,3.0,1.4,0.2,Iris-setosa',
 '3,4.7,3.2,1.3,0.2,Iris-setosa',
 '4,4.6,3.1,1.5,0.2,Iris-setosa']

In [16]:
# Removing header
irisRDD = irisRDD.filter(lambda x: "Sepal" not in x)

In [17]:
# Mapping columns
irisRDD2 = irisRDD.map(lambda l: l.split(","))
irisRDD2 = irisRDD2.map(lambda l: Row(ID = int(l[0]), SEPAL_LENGTH = float(l[1]), SEPAL_WIDTH = float(l[2]),
                                     PETAL_LENGTH = float(l[3]), PETAL_WIDTH = float(l[4]), SPECIES = l[5]))

In [21]:
# Creating DF
irisDF = sparkSession.createDataFrame(irisRDD2)
irisDF.cache()

DataFrame[ID: bigint, SEPAL_LENGTH: double, SEPAL_WIDTH: double, PETAL_LENGTH: double, PETAL_WIDTH: double, SPECIES: string]

In [22]:
# Creating a Numerical Indexer to target label
stringIndexer = StringIndexer(inputCol = "SPECIES", outputCol = "IDX_SPECIES")
stringIndexerModel = stringIndexer.fit(irisDF)
irisDF_norm = stringIndexerModel.transform(irisDF)

In [23]:
irisDF_norm.select("SPECIES","IDX_SPECIES").distinct().collect()

[Row(SPECIES='Iris-setosa', IDX_SPECIES=0.0),
 Row(SPECIES='Iris-versicolor', IDX_SPECIES=1.0),
 Row(SPECIES='Iris-virginica', IDX_SPECIES=2.0)]

In [26]:
irisDF_norm.describe().show()

+-------+------------------+------------------+------------------+------------------+------------------+--------------+------------------+
|summary|                ID|      SEPAL_LENGTH|       SEPAL_WIDTH|      PETAL_LENGTH|       PETAL_WIDTH|       SPECIES|       IDX_SPECIES|
+-------+------------------+------------------+------------------+------------------+------------------+--------------+------------------+
|  count|               150|               150|               150|               150|               150|           150|               150|
|   mean|              75.5| 5.843333333333332|3.0540000000000007| 3.758666666666668|1.1986666666666663|          null|               1.0|
| stddev|43.445367992456916|0.8280661279778633|0.4335943113621735|1.7644204199522615|0.7631607417008414|          null|0.8192319205190407|
|    min|                 1|               4.3|               2.0|               1.0|               0.1|   Iris-setosa|               0.0|
|    max|               150

In [32]:
for i in irisDF_norm.columns:
    if not(isinstance(irisDF_norm.select(i).take(1)[0][0], str)):
        print("Corr IDX_SPECIES with ", i, irisDF_norm.stat.corr("IDX_SPECIES", i))

Corr IDX_SPECIES with  ID 0.9428299935925015
Corr IDX_SPECIES with  SEPAL_LENGTH 0.7825612318100816
Corr IDX_SPECIES with  SEPAL_WIDTH -0.419446200260027
Corr IDX_SPECIES with  PETAL_LENGTH 0.9490425448523337
Corr IDX_SPECIES with  PETAL_WIDTH 0.9564638238016175
Corr IDX_SPECIES with  IDX_SPECIES 1.0


In [33]:
def transformation(row):
    obj = (row["SPECIES"], row["IDX_SPECIES"], Vectors.dense([row["SEPAL_LENGTH"], row["SEPAL_WIDTH"], 
                                                              row["PETAL_LENGTH"], row["PETAL_WIDTH"]]))
    return obj

In [35]:
irisRDD3 = irisDF_norm.rdd.map(transformation)

In [38]:
irisDF = sparkSession.createDataFrame(irisRDD3, ["species", "label", "features"])
irisDF.select("species", "label", "features").show(10)
irisDF.cache()

+-----------+-----+-----------------+
|    species|label|         features|
+-----------+-----+-----------------+
|Iris-setosa|  0.0|[5.1,3.5,1.4,0.2]|
|Iris-setosa|  0.0|[4.9,3.0,1.4,0.2]|
|Iris-setosa|  0.0|[4.7,3.2,1.3,0.2]|
|Iris-setosa|  0.0|[4.6,3.1,1.5,0.2]|
|Iris-setosa|  0.0|[5.0,3.6,1.4,0.2]|
|Iris-setosa|  0.0|[5.4,3.9,1.7,0.4]|
|Iris-setosa|  0.0|[4.6,3.4,1.4,0.3]|
|Iris-setosa|  0.0|[5.0,3.4,1.5,0.2]|
|Iris-setosa|  0.0|[4.4,2.9,1.4,0.2]|
|Iris-setosa|  0.0|[4.9,3.1,1.5,0.1]|
+-----------+-----+-----------------+
only showing top 10 rows



DataFrame[species: string, label: double, features: vector]

In [39]:
(trainData, testData) = irisDF.randomSplit([0.7, 0.3])

In [51]:
decisionTreeClassifier = DecisionTreeClassifier(maxDepth = 20, labelCol = "label", featuresCol = "features")
model = decisionTreeClassifier.fit(trainData)

In [52]:
model.numNodes
model.depth

6

In [53]:
predictions = model.transform(testData)
predictions.select("prediction", "species", "label").collect()

[Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', label=0.0),
 Row(prediction=0.0, species='Iris-setosa', labe

In [54]:
evaluator = MulticlassClassificationEvaluator (predictionCol = "prediction", labelCol = "label", metricName = "accuracy")
evaluator.evaluate(predictions)

0.975

In [55]:
predictions.groupBy("label", "prediction").count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|    8|
|  1.0|       2.0|    1|
|  0.0|       0.0|   20|
|  2.0|       2.0|   11|
+-----+----------+-----+

