In [4]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer,OneHotEncoder
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

spark = SparkSession.builder.getOrCreate()
irisDF = spark.read.option('header','true').option('inferSchema','true').csv('datasets/iris-dataset.txt')
indexer = StringIndexer(inputCol='class',outputCol='label')
indexerModel = indexer.fit(irisDF)
irisDF = indexerModel.transform(irisDF)
#print(irisDF.columns[0:4])
vec = VectorAssembler(inputCols=irisDF.columns[0:4],outputCol='features')
irisDF = vec.transform(irisDF)
irisDF =irisDF.select('features','label')

trainDF, testDF = irisDF.randomSplit([0.75,0.25],seed=1234)

classifier = MultilayerPerceptronClassifier(layers=[4,2,3])
model = classifier.fit(trainDF) #Train data should be used to create model

resultDF = model.transform(testDF)
resultDF.show()
eva = MulticlassClassificationEvaluator(metricName='accuracy')
accuracy = eva.evaluate(resultDF)
print("Accuracy : ",accuracy)

+-----------------+-----+--------------------+--------------------+----------+
|         features|label|       rawPrediction|         probability|prediction|
+-----------------+-----+--------------------+--------------------+----------+
|[4.4,2.9,1.4,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[4.5,2.3,1.3,0.3]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[4.9,3.1,1.5,0.1]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.0,1.6,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.2,1.2,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.3,1.4,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.4,1.5,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.5,1.3,0.3]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.0,3.6,1.4,0.2]|  0.0|[161.205433561884...|[1.0,1.8312358597...|       0.0|
|[5.1,2.5,3.0,1.1]|  1.0|[161.205433561884...|[1.0,1