In [2]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer,OneHotEncoder
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator,ParamGridBuilder

spark = SparkSession.builder.getOrCreate()
irisDF = spark.read.option('header','true').option('inferSchema','true').csv('datasets/iris-dataset.txt')
indexer = StringIndexer(inputCol='class',outputCol='label')
indexerModel = indexer.fit(irisDF)
irisDF = indexerModel.transform(irisDF)
vec = VectorAssembler(inputCols=irisDF.columns[0:4],outputCol='features')
irisDF = vec.transform(irisDF)
#irisDF.show()

trainDF, testDF = irisDF.randomSplit([0.75,0.25],seed=123) 

rfClassifier= RandomForestClassifier()
eva = MulticlassClassificationEvaluator(metricName='accuracy')

myParams = ParamGridBuilder().addGrid(rfClassifier.numTrees,[8,10,12])\
                            .addGrid(rfClassifier.maxDepth,[2,4,6])\
                            .addGrid(rfClassifier.impurity,['entropy','gini']).build()

validator = CrossValidator(estimator=rfClassifier,
                                 estimatorParamMaps=myParams,
                                 evaluator=eva,
                                 parallelism=4,
                                 numFolds=5
                                )

model = validator.fit(trainDF)
print('finished')

print("Num Trees : ",model.bestModel.getNumTrees)
print("Max Depth : ",model.bestModel._java_obj.getMaxDepth())
print("Impurtiy : ",model.bestModel._java_obj.getImpurity())

finished
Num Trees :  12
Max Depth :  6
Impurtiy :  entropy


In [3]:
resultDF = model.transform(testDF)
accuracy = eva.evaluate(resultDF)
print("Accuracy : ", accuracy)

Accuracy :  0.9696969696969697
