In [None]:
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
// Read the iris CSV as an RDD with one element per text line.
// NOTE(review): `sc` is not defined in this file — presumably the SparkContext
// provided by the notebook/shell session; confirm.
val input = sc.textFile("/iris-multiclass.csv")
// Pull the first five lines back to the driver as a quick sanity check of the raw input.
input.take(5)

In [None]:
// Map each CSV row to a LabeledPoint: the first four columns are the numeric
// features and the fifth column is the species name, encoded via classMap.
val classMap = Map("Iris-setosa"-> 0.0, "Iris-versicolor"-> 1.0, "Iris-virginica"-> 2.0)
// Drop blank lines (e.g. a trailing newline at the end of the file); they would
// otherwise fail in toDouble or when reading the fifth column.
val nonBlankLines = input.filter(_.trim.nonEmpty)
val data = nonBlankLines.map { line =>
    // Trim every field so stray whitespace around the species name does not
    // break the classMap lookup.
    val lineSplit = line.split(',').map(_.trim)
    val values = Vectors.dense(lineSplit.take(4).map(_.toDouble))
    // An unknown species name still fails loudly here rather than being mislabeled.
    LabeledPoint(classMap(lineSplit(4)), values)
}.persist() // kept in memory because the data is reused for both training and evaluation

In [None]:
// Split the labeled points roughly 70/30 into training and test sets (fixed seed
// so the split is reproducible), then fit a three-class logistic regression.
val allData = data.randomSplit(Array(0.7, 0.3), seed = 11L)
val Array(training, test) = allData
val model = {
  val trainer = new LogisticRegressionWithLBFGS().setNumClasses(3)
  trainer.run(training)
}

In [None]:
// Score every held-out point, pairing the model's predicted class with the true label.
val predictionAndLabels = test.map(point => (model.predict(point.features), point.label))

In [None]:
// Compute the share of correctly classified test points. This ratio is the
// overall accuracy, although it is stored under the name `precision`.
// Invert the label encoding once (class index -> species name) so each record
// is a direct map lookup instead of a linear scan over classMap.
val classNames = classMap.map(_.swap)
// Translate the numeric (prediction, label) pairs back to species names.
val mappedPredLabs = predictionAndLabels.map { case (predicted, actual) =>
  (classNames(predicted), classNames(actual))
}
val countWrong = mappedPredLabs.filter { case (predicted, actual) => predicted != actual }.count
val countCorrect = mappedPredLabs.filter { case (predicted, actual) => predicted == actual }.count
val precision = countCorrect.toDouble / (countCorrect + countWrong).toDouble

In [None]:
// check the confusion matrix
// Build MLlib's multiclass evaluator from the (prediction, label) pairs.
val metrics = new MulticlassMetrics(predictionAndLabels) 
// NOTE(review): the no-argument `recall` appears to be deprecated in Spark 2.x in
// favor of `accuracy` and absent from Spark 3.x — confirm against the Spark version in use.
val metricsRecall = metrics.recall
// Confusion matrix of the test predictions; as the last expression it is what the cell displays.
// NOTE(review): presumably rows are actual classes and columns predicted classes — verify in the API docs.
val cf = metrics.confusionMatrix

In [None]:
// Overall precision as reported by MulticlassMetrics, for comparison with the
// hand-computed `precision` value.
// NOTE(review): the no-argument `precision` appears to be deprecated in Spark 2.x in
// favor of `accuracy` and absent from Spark 3.x — confirm against the Spark version in use.
val metricsPrecision = metrics.precision 