Skip to content

Commit

Permalink
Added Matthew's correlation coefficient calculation to binary classif…
Browse files Browse the repository at this point in the history
…ication metrics
  • Loading branch information
mandar2812 committed Oct 10, 2016
1 parent 56df6e0 commit 4c95c6f
Showing 1 changed file with 52 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ class BinaryClassificationMetrics(
* will be used to measure the variation
* in precision, recall False Positive
* and False Negative values.
* */

*/
private val thresholds = List.tabulate(100)(i => {
scoresAndLabels.map(_._1).min +
i.toDouble*((scoresAndLabels.map(_._1).max.toInt -
Expand All @@ -68,19 +67,19 @@ class BinaryClassificationMetrics(
/**
* Calculate the area under the Precision-Recall
* curve.
* */
*/
def areaUnderPR(): Double = {
  // Build the (Recall, Precision) curve first, then integrate it.
  val precisionRecallCurve = pr()
  areaUnderCurve(precisionRecallCurve)
}

/**
* Calculate the area under the Receiver
* Operating Characteristic curve.
* */
*/
def areaUnderROC(): Double = {
  // Build the (False Positive Rate, True Positive Rate) curve first, then integrate it.
  val rocCurve = roc()
  areaUnderCurve(rocCurve)
}

/**
* Calculate the F1 metric by threshold, for a
* beta value of 1.0
* */
*/
def fMeasureByThreshold(): List[(Double, Double)] = {
  // F1 is the F-beta measure with beta fixed to 1.0; delegate to the general overload.
  val beta = 1.0
  fMeasureByThreshold(beta)
}

/**
Expand All @@ -98,22 +97,22 @@ class BinaryClassificationMetrics(
/**
* Return the Precision-Recall curve, as a [[List]]
* of [[Tuple2]] (Recall, Precision).
* */
*/
def pr(): List[(Double, Double)] = {
  // Both lists are produced per threshold in the same order, so zipping
  // pairs the recall and precision values that belong to one threshold.
  val recallPrecisionPairs = recallByThreshold().zip(precisionByThreshold()).map {
    case ((_, recall), (_, precision)) => (recall, precision)
  }
  // Order the points by recall (then precision) so they form a curve.
  recallPrecisionPairs.sorted
}

/**
* Return the Recall-Threshold curve, as a [[List]]
* of [[Tuple2]] (Threshold, Recall).
* */
*/
def recallByThreshold(): List[(Double, Double)] =
  tpfpByThreshold().map {
    // Recall is the true positive rate; the false positive rate is dropped.
    case (threshold, (truePositiveRate, _)) => (threshold, truePositiveRate)
  }

/**
* Return the Precision-Threshold curve, as a [[List]]
* of [[Tuple2]] (Threshold, Precision).
* */
*/
def precisionByThreshold(): List[(Double, Double)] = {
  // Precision is TP / (TP + FP) on *counts*. tpfpByThreshold() yields rates,
  // so dividing the rates directly is only correct when both classes have the
  // same size. Scale the rates back to counts with the class sizes, the same
  // way accuracyByThreshold() does.
  // NOTE(review): assumes the false positive rate is normalised by
  // negatives.length, as accuracyByThreshold() also assumes -- confirm
  // against tpfpByThreshold().
  val numPositives = positives.length
  val numNegatives = negatives.length

  tpfpByThreshold().map { case (threshold, (truePositiveRate, falsePositiveRate)) =>
    val truePositives = truePositiveRate * numPositives
    val falsePositives = falsePositiveRate * numNegatives
    // Still NaN (0.0/0.0) when no example scores above the threshold,
    // exactly as before; precision is undefined in that case.
    (threshold, truePositives / (truePositives + falsePositives))
  }
}
Expand All @@ -122,15 +121,15 @@ class BinaryClassificationMetrics(
* Return the Receiver Operating Characteristic
* curve, as a [[List]] of [[Tuple2]]
* (False Positive Rate, True Positive Rate).
* */
*/
def roc(): List[(Double, Double)] = {
  // Drop the threshold and swap the rates so that each point reads
  // (False Positive Rate, True Positive Rate).
  val rocPoints = tpfpByThreshold().map {
    case (_, (truePositiveRate, falsePositiveRate)) => (falsePositiveRate, truePositiveRate)
  }
  // Order the points by false positive rate so they form a curve.
  rocPoints.sorted
}

/**
* Return the True Positive and False Positive Rate
* with respect to the threshold, as a [[List]]
* of [[Tuple2]] (Threshold, (True Positive rate, False Positive Rate)).
* */
*/
def tpfpByThreshold(): List[(Double, (Double, Double))] =
this.thresholds.map((th) => {
val true_positive = if(positives.nonEmpty) {
Expand All @@ -148,24 +147,64 @@ class BinaryClassificationMetrics(
(th, (true_positive, false_positive))
})

/**
* Return the True Positive, True Negative, False Positive
* and False Negative rates with respect to the threshold,
* as a [[List]] of [[Tuple2]]
* (Threshold, (True Positive Rate, True Negative Rate, False Positive Rate, False Negative Rate)).
*/
def tptn_fpfnByThreshold: List[(Double, (Double, Double, Double, Double))] = {
  // Class sizes do not depend on the threshold, so read them once.
  val numPositives = positives.length
  val numNegatives = negatives.length

  this.thresholds.map((th) => {
    // An example is predicted positive when its score lies strictly above the threshold.
    val (true_positive, false_negative) = if(positives.nonEmpty) {
      val hits = positives.count(p => math.signum(p._1 - th) == 1.0)
      (hits.toDouble/numPositives, (numPositives - hits).toDouble/numPositives)
    } else {(0.0, 0.0)}

    // Rates of the negative class are normalised by the number of negatives
    // (they were previously divided by positives.length by mistake).
    val (false_positive, true_negative) = if(negatives.nonEmpty) {
      val falseAlarms = negatives.count(p => math.signum(p._1 - th) == 1.0)
      (falseAlarms.toDouble/numNegatives, (numNegatives - falseAlarms).toDouble/numNegatives)
    } else {(0.0, 0.0)}

    (th, (true_positive, true_negative, false_positive, false_negative))
  })
}

/**
 * Returns the Matthews correlation coefficient
 * for every thresholding value, as a [[List]]
 * of [[Tuple2]] (Threshold, MCC).
 *
 * MCC = (TP*TN - FP*FN) / sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)),
 * computed on counts. When the denominator is zero the
 * coefficient is reported as 0.0 instead of NaN.
 */
def matthewsCCByThreshold: List[(Double, Double)] = {
  // Doubles keep the four-way product below from overflowing on large data sets.
  val numPositives = positives.length.toDouble
  val numNegatives = negatives.length.toDouble

  tptn_fpfnByThreshold.map { case (th, (tpRate, tnRate, fpRate, fnRate)) =>
    // Turn the per-class rates back into confusion matrix counts, so the
    // coefficient is also correct when the classes are imbalanced.
    val tp = tpRate * numPositives
    val fn = fnRate * numPositives
    val tn = tnRate * numNegatives
    val fp = fpRate * numNegatives

    val denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    // A zero denominator means a row or column of the confusion matrix is
    // empty; the usual convention is to report zero correlation.
    val mcc = if(denominator == 0.0) 0.0 else (tp*tn - fp*fn)/denominator
    (th, mcc)
  }
}

/**
 * Return the Accuracy-Threshold curve, as a [[List]]
 * of [[Tuple2]] (Threshold, Accuracy).
 */
def accuracyByThreshold(): List[(Double, Double)] =
  tpfpByThreshold().map { case (threshold, (truePositiveRate, falsePositiveRate)) =>
    // Scale each rate by its class size to recover the number of correct predictions.
    val correctPositives = truePositiveRate*positives.length
    val correctNegatives = (1.0-falsePositiveRate)*negatives.length
    (threshold, (correctPositives + correctNegatives)/scoresAndLabels.length)
  }

/**
* Generate the PR, ROC and F1 measure
* plots using Scala-Chart.
* */
*/
override def generatePlots(): Unit = {
val roccurve = this.roc()
val prcurve = this.pr()
val fm = this.fMeasureByThreshold()
logger.info("Generating ROC Plot")
val mtt = matthewsCCByThreshold

logger.info("Generating Matthew's correlation coefficient plot by thresholding value")
spline(mtt.map(_._1), mtt.map(_._2))
title("MCC vs Threshold Cutoff: "+name)
xAxis("Threshold")
yAxis("F Measure")

logger.info("Generating F1-measure plot by thresholding value")
spline(fm.map(_._1), fm.map(_._2))
title("F Measure vs Threshold Cutoff: "+name)
xAxis("Threshold")
yAxis("F Measure")

logger.info("Generating ROC Plot")
areaspline(roccurve.map(_._1), roccurve.map(_._2))
title("Receiver Operating Characteristic: "+name+
", Area under curve: "+areaUnderCurve(roccurve))
Expand All @@ -179,6 +218,7 @@ class BinaryClassificationMetrics(
logger.info("Accuracy: " + accuracyByThreshold().map((c) => c._2).max)
logger.info("Area under ROC: " + areaUnderROC())
logger.info("Maximum F Measure: " + fMeasureByThreshold().map(_._2).max)
logger.info("Maximum Matthew's Correlation Coefficient: " + matthewsCCByThreshold.map(_._2).max)
}

override def kpi() = DenseVector(accuracyByThreshold().map((c) => c._2).max,
Expand Down

0 comments on commit 4c95c6f

Please sign in to comment.