Skip to content

Commit

Permalink
[ML] Rename evaluation metric result fields to value (#63809)
Browse files Browse the repository at this point in the history
Renames data frame analytics _evaluate API results as follows:

  - per class accuracy renamed from `accuracy` to `value`
  - per class precision renamed from `precision` to `value`
  - per class recall renamed from `recall` to `value`
  - auc_roc `score` renamed to `value` for both outlier detection and classification
  • Loading branch information
dimitris-athanasiou committed Oct 20, 2020
1 parent 7f2930e commit 03ed7de
Show file tree
Hide file tree
Showing 38 changed files with 686 additions and 763 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
Expand Down Expand Up @@ -122,10 +123,9 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
// Evaluation metrics results
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
new ParseField(registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
AucRocResult::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(
Expand All @@ -145,7 +145,7 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
AucRocMetric.Result::fromXContent),
AucRocResult::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

Expand Down Expand Up @@ -99,10 +97,10 @@ public static class Result implements EvaluationMetric.Result {

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassSingleValue>) a[0], (double) a[1]));

static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}

Expand All @@ -111,11 +109,11 @@ public static Result fromXContent(XContentParser parser) {
}

/** List of per-class results. */
private final List<PerClassResult> classes;
private final List<PerClassSingleValue> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;

public Result(List<PerClassResult> classes, double overallAccuracy) {
public Result(List<PerClassSingleValue> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy;
}
Expand All @@ -125,7 +123,7 @@ public String getMetricName() {
return NAME;
}

public List<PerClassResult> getClasses() {
public List<PerClassSingleValue> getClasses() {
return classes;
}

Expand Down Expand Up @@ -156,65 +154,4 @@ public int hashCode() {
return Objects.hash(classes, overallAccuracy);
}
}

public static class PerClassResult implements ToXContentObject {

private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));

static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}

/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;

public PerClassResult(String className, double accuracy) {
this.className = Objects.requireNonNull(className);
this.accuracy = accuracy;
}

public String getClassName() {
return className;
}

public double getAccuracy() {
return accuracy;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}

@Override
public int hashCode() {
return Objects.hash(className, accuracy);
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
Expand All @@ -43,12 +39,11 @@
*/
public class AucRocMetric implements EvaluationMetric {

public static final String NAME = "auc_roc";
public static final String NAME = AucRocResult.NAME;

public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));

Expand Down Expand Up @@ -106,149 +101,4 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(className, includeCurve);
}

public static class Result implements EvaluationMetric.Result {

public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final ParseField SCORE = new ParseField("score");
private static final ParseField CURVE = new ParseField("curve");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));

static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}

private final double score;
private final List<AucRocPoint> curve;

public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.curve = curve;
}

@Override
public String getMetricName() {
return NAME;
}

public double getScore() {
return score;
}

public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, curve);
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public static final class AucRocPoint implements ToXContentObject {

public static AucRocPoint fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final ParseField TPR = new ParseField("tpr");
private static final ParseField FPR = new ParseField("fpr");
private static final ParseField THRESHOLD = new ParseField("threshold");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_point",
true,
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));

static {
PARSER.declareDouble(constructorArg(), TPR);
PARSER.declareDouble(constructorArg(), FPR);
PARSER.declareDouble(constructorArg(), THRESHOLD);
}

private final double tpr;
private final double fpr;
private final double threshold;

public AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
this.fpr = fpr;
this.threshold = threshold;
}

public double getTruePositiveRate() {
return tpr;
}

public double getFalsePositiveRate() {
return fpr;
}

public double getThreshold() {
return threshold;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder
.startObject()
.field(TPR.getPreferredName(), tpr)
.field(FPR.getPreferredName(), fpr)
.field(THRESHOLD.getPreferredName(), threshold)
.endObject();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRocPoint that = (AucRocPoint) o;
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
}

@Override
public int hashCode() {
return Objects.hash(tpr, fpr, threshold);
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}

0 comments on commit 03ed7de

Please sign in to comment.