diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 7a70117f36d4c..8f25a3d0020d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -22,7 +22,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.SparkContext._ +import scala.collection.Map + /** + * ::Experimental:: * Evaluator for multiclass classification. * * @param predictionsAndLabels an RDD of (prediction, label) pairs. @@ -30,16 +33,16 @@ import org.apache.spark.SparkContext._ @Experimental class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging { - private lazy val labelCountByClass = predictionsAndLabels.values.countByValue() - private lazy val labelCount = labelCountByClass.values.sum - private lazy val tpByClass = predictionsAndLabels - .map{ case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Long] = predictionsAndLabels.values.countByValue() + private lazy val labelCount: Long = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Int] = predictionsAndLabels + .map { case (prediction, label) => + (label, if (label == prediction) 1 else 0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass = predictionsAndLabels - .map{ case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels + .map { case (prediction, label) => + (prediction, if (prediction != label) 1 else 0) }.reduceByKey(_ + _) .collectAsMap() @@ -63,7 +66,7 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log * Returns f-measure for a given label (category) * @param label the label. */ - def fMeasure(label: Double, beta:Double = 1.0): Double = { + def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) val r = recall(label) val betaSqrd = beta * beta @@ -71,27 +74,33 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log } /** - * Returns micro-averaged recall - * (equals to microPrecision and microF1measure for multiclass classifier) + * Returns f1-measure for a given label (category) + * @param label the label. + */ + def fMeasure(label: Double): Double = fMeasure(label, 1.0) + + /** + * Returns precision */ - lazy val recall: Double = - tpByClass.values.sum.toDouble / labelCount + lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount /** - * Returns micro-averaged precision - * (equals to microPrecision and microF1measure for multiclass classifier) + * Returns recall + * (equals to precision for multiclass classifier + * because sum of all false positives is equal to sum + * of all false negatives) */ - lazy val precision: Double = recall + lazy val recall: Double = precision /** - * Returns micro-averaged f-measure - * (equals to microPrecision and microRecall for multiclass classifier) + * Returns f-measure + * (equals to precision and recall because precision equals recall) */ - lazy val fMeasure: Double = recall + lazy val fMeasure: Double = precision /** * Returns weighted averaged recall - * (equals to micro-averaged precision, recall and f-measure) + * (equals to precision, recall and f-measure) */ lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => recall(category) * count.toDouble / labelCount @@ -114,6 +123,5 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log /** * Returns the sequence of labels in ascending order */ - lazy val labels = tpByClass.unzip._1.toSeq.sorted - + lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 4b959b2d542ac..9bdd5745677aa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -30,7 +30,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { * |0|0|1| true class2 (1 instance) * */ - val labels = Seq(0.0, 1.0, 2.0) + val labels = Array(0.0, 1.0, 2.0) val scoreAndLabels = sc.parallelize( Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) @@ -65,6 +65,6 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta) assert(math.abs(metrics.weightedF1Measure - ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta) - assert(metrics.labels == labels) + assert(metrics.labels.sameElements(labels)) } }