Skip to content

Commit

Permalink
[7.x] Fix accuracy metric (#50310) (#50433)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 20, 2019
1 parent 14d95aa commit 3e3a930
Show file tree
Hide file tree
Showing 14 changed files with 475 additions and 292 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
Expand All @@ -35,10 +36,25 @@
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

/**
* {@link AccuracyMetric} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
* {@link AccuracyMetric} is a metric that answers the following two questions:
*
* equation: accuracy = 1/n * Σ(y == y´)
* 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/
public class AccuracyMetric implements EvaluationMetric {

Expand Down Expand Up @@ -78,29 +94,29 @@ public int hashCode() {

public static class Result implements EvaluationMetric.Result {

private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));

static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}

/**
 * Reads an accuracy {@link Result} from the supplied parser.
 *
 * @param parser the parser handed to {@code PARSER}; no parsing context is passed ({@code null})
 * @return the {@link Result} produced by {@code PARSER}
 */
public static Result fromXContent(XContentParser parser) {
    final Result parsed = PARSER.apply(parser, null);
    return parsed;
}

/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;

public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
public Result(List<PerClassResult> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy;
}

Expand All @@ -109,8 +125,8 @@ public String getMetricName() {
return NAME;
}

public List<ActualClass> getActualClasses() {
return actualClasses;
public List<PerClassResult> getClasses() {
return classes;
}

public double getOverallAccuracy() {
Expand All @@ -120,7 +136,7 @@ public double getOverallAccuracy() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
Expand All @@ -131,52 +147,42 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy;
}

@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
return Objects.hash(classes, overallAccuracy);
}
}

public static class ActualClass implements ToXContentObject {
public static class PerClassResult implements ToXContentObject {

private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));

static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}

/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {@code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {@code actualClass} class predicted correctly. */
/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;

public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = Objects.requireNonNull(actualClass);
this.actualClassDocCount = actualClassDocCount;
public PerClassResult(String className, double accuracy) {
this.className = Objects.requireNonNull(className);
this.accuracy = accuracy;
}

public String getActualClass() {
return actualClass;
}

public long getActualClassDocCount() {
return actualClassDocCount;
public String getClassName() {
return className;
}

public double getAccuracy() {
Expand All @@ -186,8 +192,7 @@ public double getAccuracy() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
Expand All @@ -197,15 +202,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}

@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
return Objects.hash(className, accuracy);
}

/** Returns the string form of this object as produced by {@code Strings.toString(this)}. */
@Override
public String toString() {
    final String rendered = Strings.toString(this);
    return rendered;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1849,15 +1849,15 @@ public void testEvaluateDataFrame_Classification() throws IOException {
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(
accuracyResult.getActualClasses(),
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("cat", 5, 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75),
// no examples labeled as "ant" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
// 9 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("ant", 0.9),
// 6 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("cat", 0.6),
// 8 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("dog", 0.8))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // Precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;

import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
Expand All @@ -41,13 +41,13 @@ protected NamedXContentRegistry xContentRegistry() {
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
classes.add(new PerClassResult(classNames.get(i), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
return new Result(classes, overallAccuracy);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/
Optional<EvaluationMetricResult> getResult();
Optional<? extends EvaluationMetricResult> getResult();
}
Loading

0 comments on commit 3e3a930

Please sign in to comment.