In [21]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer,OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Start (or reuse) a Spark session and load the iris CSV; the file has a
# header row and column types are inferred from the data.
spark = SparkSession.builder.getOrCreate()
irisDF = spark.read.option('header','true').option('inferSchema','true').csv('datasets/iris-dataset.txt')

# Encode the string 'class' column as a numeric 'label' column.
indexer = StringIndexer(inputCol='class', outputCol='label')
indexerModel = indexer.fit(irisDF)
irisDF = indexerModel.transform(irisDF)

# Pack the first four columns into a single 'features' vector column, then
# keep only the 'features' and 'label' columns for modelling.
vec = VectorAssembler(inputCols=irisDF.columns[0:4], outputCol='features')
irisDF = vec.transform(irisDF)
irisDF = irisDF.select('features', 'label')

# 75/25 train/test split with a fixed seed.
trainDF, testDF = irisDF.randomSplit([0.75, 0.25], seed=1234)

# The training data is used to fit the model; the held-out test data is
# then scored with it.
classifier = DecisionTreeClassifier()
model = classifier.fit(trainDF)
print(model.toDebugString)
resultDF = model.transform(testDF)

# The evaluator is configured for the F1 metric, so the printed value is
# labelled as F1 (it was previously reported as "Accuracy").
eva = MulticlassClassificationEvaluator(metricName='f1')
result = eva.evaluate(resultDF)
print("F1 score :", result)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_bc8c8a7afc4c, depth=5, numNodes=17, numClasses=3, numFeatures=4
  If (feature 2 <= 2.5999999999999996)
   Predict: 0.0
  Else (feature 2 > 2.5999999999999996)
   If (feature 3 <= 1.75)
    If (feature 0 <= 4.95)
     If (feature 1 <= 2.45)
      Predict: 1.0
     Else (feature 1 > 2.45)
      Predict: 2.0
    Else (feature 0 > 4.95)
     If (feature 2 <= 5.05)
      Predict: 1.0
     Else (feature 2 > 5.05)
      If (feature 0 <= 6.05)
       Predict: 1.0
      Else (feature 0 > 6.05)
       Predict: 2.0
   Else (feature 3 > 1.75)
    If (feature 2 <= 4.85)
     If (feature 0 <= 5.95)
      Predict: 1.0
     Else (feature 0 > 5.95)
      Predict: 2.0
    Else (feature 2 > 4.85)
     Predict: 2.0

Accuracy : 0.9771847507331378


In [20]:
# Print up to 200 rows of irisDF (the features/label frame built in the
# pipeline cell, which must have been run first).
irisDF.show(n=200)

+-----------------+-----+
|         features|label|
+-----------------+-----+
|[5.1,3.5,1.4,0.2]|  0.0|
|[4.9,3.0,1.4,0.2]|  0.0|
|[4.7,3.2,1.3,0.2]|  0.0|
|[4.6,3.1,1.5,0.2]|  0.0|
|[5.0,3.6,1.4,0.2]|  0.0|
|[5.4,3.9,1.7,0.4]|  0.0|
|[4.6,3.4,1.4,0.3]|  0.0|
|[5.0,3.4,1.5,0.2]|  0.0|
|[4.4,2.9,1.4,0.2]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[5.4,3.7,1.5,0.2]|  0.0|
|[4.8,3.4,1.6,0.2]|  0.0|
|[4.8,3.0,1.4,0.1]|  0.0|
|[4.3,3.0,1.1,0.1]|  0.0|
|[5.8,4.0,1.2,0.2]|  0.0|
|[5.7,4.4,1.5,0.4]|  0.0|
|[5.4,3.9,1.3,0.4]|  0.0|
|[5.1,3.5,1.4,0.3]|  0.0|
|[5.7,3.8,1.7,0.3]|  0.0|
|[5.1,3.8,1.5,0.3]|  0.0|
|[5.4,3.4,1.7,0.2]|  0.0|
|[5.1,3.7,1.5,0.4]|  0.0|
|[4.6,3.6,1.0,0.2]|  0.0|
|[5.1,3.3,1.7,0.5]|  0.0|
|[4.8,3.4,1.9,0.2]|  0.0|
|[5.0,3.0,1.6,0.2]|  0.0|
|[5.0,3.4,1.6,0.4]|  0.0|
|[5.2,3.5,1.5,0.2]|  0.0|
|[5.2,3.4,1.4,0.2]|  0.0|
|[4.7,3.2,1.6,0.2]|  0.0|
|[4.8,3.1,1.6,0.2]|  0.0|
|[5.4,3.4,1.5,0.4]|  0.0|
|[5.2,4.1,1.5,0.1]|  0.0|
|[5.5,4.2,1.4,0.2]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[5.0,3.2,1.