Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added missing calculate FAR for the EvaluationBinary #7320

Merged
merged 8 commits into from Mar 21, 2019
Expand Up @@ -56,6 +56,9 @@
@EqualsAndHashCode(callSuper = true)
@Data
public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {

public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR}

public static final int DEFAULT_PRECISION = 4;
public static final double DEFAULT_EDGE_VALUE = 0.0;

Expand Down Expand Up @@ -218,6 +221,12 @@ public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArra
}
}

/**
* Merge the other evaluation object into this one. The result is that this {@link #EvaluationBinary} instance contains the counts
* etc from both
*
* @param other EvaluationBinary object to merge into this one.
*/
@Override
public void merge(EvaluationBinary other) {
if (other.countTruePositive == null) {
Expand Down Expand Up @@ -507,6 +516,34 @@ private void assertIndex(int outputNum) {
}
}

/**
* Average False Alarm Rate (FAR) (see {@link #falseAlarmRate(int)}) for all labels.
*
* @return The FAR for all labels.
*/
public double averageFalseAlarmRate() {
double ret = 0.0;
for (int i = 0; i < numLabels(); i++) {
ret += falseAlarmRate(i);
}

ret /= (double) numLabels();
return ret;
}

/**
* False Alarm Rate (FAR) reflects rate of misclassified to classified records
* <a href="http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw">http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw</a><br>
*
* @param outputNum Class index to calculate False Alarm Rate (FAR)
* @return The FAR for the outcomes
*/
public double falseAlarmRate(int outputNum) {
assertIndex(outputNum);

return (falsePositiveRate(outputNum) + falseNegativeRate(outputNum)) / 2.0;
}

/**
* Get a String representation of the EvaluationBinary class, using the default precision
*/
Expand Down Expand Up @@ -591,6 +628,35 @@ public String stats(int printPrecision) {
return sb.toString();
}

/**
* Calculate specific metric (see {@link Metric}) for a given label.
*
* @param metric The Metric to calculate.
* @param outputNum Class index to calculate.
*
* @return Calculated metric.
*/
public double scoreForMetric(Metric metric, int outputNum){
switch (metric){
case ACCURACY:
return accuracy(outputNum);
case F1:
return f1(outputNum);
case PRECISION:
return precision(outputNum);
case RECALL:
return recall(outputNum);
case GMEASURE:
return gMeasure(outputNum);
case MCC:
return matthewsCorrelation(outputNum);
case FAR:
return falseAlarmRate(outputNum);
default:
throw new IllegalStateException("Unknown metric: " + metric);
}
}

public static EvaluationBinary fromJson(String json) {
return fromJson(json, EvaluationBinary.class);
}
Expand Down
Expand Up @@ -27,7 +27,7 @@
import org.nd4j.linalg.indexing.NDArrayIndex;

import static org.junit.Assert.assertEquals;

import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
/**
* Created by Alex on 20/03/2017.
*/
Expand Down Expand Up @@ -87,12 +87,14 @@ public void testEvaluationBinary() {
e.eval(lCol, pCol);

assertEquals(acc, eb.accuracy(i), eps);
assertEquals(e.accuracy(), eb.accuracy(i), eps);
assertEquals(e.precision(1), eb.precision(i), eps);
assertEquals(e.recall(1), eb.recall(i), eps);
assertEquals(e.f1(1), eb.f1(i), eps);
assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps);
assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps);
assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps);
assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps);
assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps);
assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps);


assertEquals(tpCount, eb.truePositives(i));
assertEquals(tnCount, eb.trueNegatives(i));

Expand Down