In [90]:
import org.apache.spark.sql.types.{StructField, StructType, DoubleType, IntegerType, StringType}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.DataFrameNaFunctions
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import java.io._

val schema = new StructType(Array(
    new StructField("classification", StringType, true),
    new StructField("capshape", StringType, true),
    new StructField("capsurface", StringType, true),
    new StructField("capcolor", StringType, true),
    new StructField("bruises", StringType, true),
    new StructField("odor", StringType, true),
    new StructField("gillattachment", StringType, true),
    new StructField("gillspacing", StringType, true),
    new StructField("gillsize", StringType, true),
    new StructField("gillcolor", StringType, true),
    new StructField("stalkshape", StringType, true),
    new StructField("stalkroot", StringType, true),
    new StructField("stalksurfaceabovering", StringType, true),
    new StructField("stalksurfacebelowring", StringType, true),
    new StructField("stalkcolorabovering", StringType, true),
    new StructField("stalkcolorbelowring", StringType, true),
    new StructField("veiltype", StringType, true),
    new StructField("veilcolor", StringType, true),
    new StructField("ringnumber", StringType, true),
    new StructField("ringtype", StringType, true),
    new StructField("sporeprintcolor", StringType, true),
    new StructField("population", StringType, true),
    new StructField("habitat", StringType, true)
))

val data = spark.read.format("csv").schema(schema).option("header",false).load("mushrooms.csv")

In [77]:
//removed veiltype because all mushrooms have same value for this feature
val dropDF = data.na.drop("any")
val params = "classification ~ capshape + capsurface + capcolor + bruises + odor + gillattachment + gillspacing + gillsize + gillcolor + stalkshape + stalkroot + stalksurfaceabovering + stalksurfacebelowring + stalkcolorabovering + stalkcolorbelowring + veilcolor + ringnumber + ringtype + sporeprintcolor + population + habitat"
val myFormala = new RFormula().setFormula(params) 
val fittedRF = myFormala.fit(dropDF)
val preparedDF = fittedRF.transform(dropDF)

In [78]:
val randSeed = 5043
val Array(train, test) = preparedDF.randomSplit(Array(0.7, 0.3), randSeed)

val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(randSeed)
val model = classifier.fit(train)

In [79]:
val predictions = model.transform(test)

In [80]:
val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")
val accuracy = evaluator.evaluate(predictions)
println(f"accuracy: $accuracy")

accuracy: 0.9997153559413391


In [41]:
println("Random Forest Model:\n" + model.toDebugString)

Random Forest Model:
RandomForestClassificationModel (uid=rfc_78e88275df8e) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 29 in {1.0})
     Predict: 1.0
    Else (feature 29 not in {1.0})
     If (feature 45 in {1.0})
      If (feature 25 in {1.0})
       Predict: 1.0
      Else (feature 25 not in {1.0})
       Predict: 0.0
     Else (feature 45 not in {1.0})
      If (feature 47 in {1.0})
       Predict: 0.0
      Else (feature 47 not in {1.0})
       Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 84 in {0.0})
     If (feature 46 in {0.0})
      If (feature 90 in {1.0})
       Predict: 0.0
      Else (feature 90 not in {1.0})
       Predict: 0.0
     Else (feature 46 not in {0.0})
      If (feature 60 in {1.0})
       Predict: 1.0
      Else (feature 60 not in {1.0})
       Predict: 1.0
    Else (feature 84 not in {0.0})
     If (feature 72 in {1.0})
      If (feature 6 in {0.0})
       Predict: 0.0
      Else (feature 6 not in {0.0})
       Predict: 1.0
     Else (feature

In [42]:
model.write.save("mushroomModel")

Name: java.io.IOException
Message: Path mushroomModel already exists. To overwrite it, please use write.overwrite().save(path) for Scala and use write().overwrite().save(path) for Java and Python.
StackTrace:   at org.apache.spark.ml.util.MLWriter.save(ReadWrite.scala:109)

In [81]:
val pw = new PrintWriter(new File("mushroomModel.txt" ))
pw.write(model.toDebugString)
pw.close

In [82]:
val wrongPred = predictions.where(expr("label != prediction"))
val countErrors = wrongPred.groupBy("label").agg(count("prediction").alias("Errors"))
countErrors.show

+-----+------+
|label|Errors|
+-----+------+
|  1.0|    72|
+-----+------+



In [85]:
val rightPred = predictions.where(expr("label == prediction"))
val countCorrect = rightPred.groupBy("label").agg(count("prediction").alias("Correct"))
countCorrect.show

+-----+-------+
|label|Correct|
+-----+-------+
|  0.0|   1359|
|  1.0|   1143|
+-----+-------+



In [86]:
val falsevec = countErrors.select("Errors").collect()
val falsepos = falseval(0)
val falseneg = 0.0
val truevec = countCorrect.select("Correct").collect()
val trueneg = truevals(0)
val truepos = truevals(1)
println(f"$trueneg $falsepos\n[$falseneg]  $truepos")

[1359] [96]
[0.0]  [1131]


In [73]:
val cm = new PrintWriter(new File("mushroomConfusion.txt" ))
cm.write(f"$trueneg $falsepos\n[$falseneg]  $truepos")
cm.close

In [53]:
//countErrors.write.save("mushroomConfusion")