Skip to content

Commit

Permalink
Addressing reviewers comments mengxr
Browse files Browse the repository at this point in the history
  • Loading branch information
avulanov committed Jul 8, 2014
1 parent c3a77ad commit a7e8bf0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,27 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.SparkContext._

import scala.collection.Map

/**
* ::Experimental::
* Evaluator for multiclass classification.
*
* @param predictionsAndLabels an RDD of (prediction, label) pairs.
*/
@Experimental
class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {

private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
private lazy val labelCount = labelCountByClass.values.sum
private lazy val tpByClass = predictionsAndLabels
.map{ case (prediction, label) =>
(label, if (label == prediction) 1 else 0)
private lazy val labelCountByClass: Map[Double, Long] = predictionsAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionsAndLabels
.map { case (prediction, label) =>
(label, if (label == prediction) 1 else 0)
}.reduceByKey(_ + _)
.collectAsMap()
private lazy val fpByClass = predictionsAndLabels
.map{ case (prediction, label) =>
(prediction, if (prediction != label) 1 else 0)
private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
.map { case (prediction, label) =>
(prediction, if (prediction != label) 1 else 0)
}.reduceByKey(_ + _)
.collectAsMap()

Expand All @@ -63,35 +66,41 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
* Returns f-measure for a given label (category)
* @param label the label.
*/
def fMeasure(label: Double, beta:Double = 1.0): Double = {
def fMeasure(label: Double, beta: Double): Double = {
val p = precision(label)
val r = recall(label)
val betaSqrd = beta * beta
if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
}

/**
* Returns micro-averaged recall
* (equals to microPrecision and microF1measure for multiclass classifier)
* Returns f1-measure for a given label (category)
* @param label the label.
*/
def fMeasure(label: Double): Double = fMeasure(label, 1.0)

/**
* Returns precision
*/
lazy val recall: Double =
tpByClass.values.sum.toDouble / labelCount
lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount

/**
* Returns micro-averaged precision
* (equals to microPrecision and microF1measure for multiclass classifier)
* Returns recall
* (equals to precision for multiclass classifier
* because sum of all false positives is equal to sum
* of all false negatives)
*/
lazy val precision: Double = recall
lazy val recall: Double = precision

/**
* Returns micro-averaged f-measure
* (equals to microPrecision and microRecall for multiclass classifier)
* Returns f-measure
* (equals to precision and recall because precision equals recall)
*/
lazy val fMeasure: Double = recall
lazy val fMeasure: Double = precision

/**
* Returns weighted averaged recall
* (equals to micro-averaged precision, recall and f-measure)
* (equals to precision, recall and f-measure)
*/
lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
recall(category) * count.toDouble / labelCount
Expand All @@ -114,6 +123,5 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
/**
* Returns the sequence of labels in ascending order
*/
lazy val labels = tpByClass.unzip._1.toSeq.sorted

lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
* |0|0|1| true class2 (1 instance)
*
*/
val labels = Seq(0.0, 1.0, 2.0)
val labels = Array(0.0, 1.0, 2.0)
val scoreAndLabels = sc.parallelize(
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
Expand Down Expand Up @@ -65,6 +65,6 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
assert(math.abs(metrics.weightedF1Measure -
((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
assert(metrics.labels == labels)
assert(metrics.labels.sameElements(labels))
}
}

0 comments on commit a7e8bf0

Please sign in to comment.