Skip to content

Commit

Permalink
[ML] Validate that AucRoc has the data necessary to be calculated (#6…
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Oct 8, 2020
1 parent f453058 commit bd761cc
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,23 @@ public static Result fromXContent(XContentParser parser) {
}

private static final ParseField SCORE = new ParseField("score");
private static final ParseField DOC_COUNT = new ParseField("doc_count");
private static final ParseField CURVE = new ParseField("curve");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));

static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareLong(constructorArg(), DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}

private final double score;
private final long docCount;
private final List<AucRocPoint> curve;

public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = curve;
}

Expand All @@ -147,10 +143,6 @@ public double getScore() {
return score;
}

public long getDocCount() {
return docCount;
}

/**
 * Returns a read-only view of the curve points.
 *
 * @return an unmodifiable list wrapping the curve, or {@code null} when this result holds no curve
 */
public List<AucRocPoint> getCurve() {
    if (curve == null) {
        return null;
    }
    return Collections.unmodifiableList(curve);
}
Expand All @@ -159,7 +151,6 @@ public List<AucRocPoint> getCurve() {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
builder.field(DOC_COUNT.getPreferredName(), docCount);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
Expand All @@ -173,13 +164,12 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& docCount == that.docCount
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
Expand Down Expand Up @@ -1931,18 +1932,17 @@ public void testEvaluateDataFrame_Classification() throws IOException {
createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat", 0.9))
.add(docForClassification(indexName, "cat", "cat", 0.85))
.add(docForClassification(indexName, "cat", "cat", 0.95))
.add(docForClassification(indexName, "cat", "dog", 0.4))
.add(docForClassification(indexName, "cat", "fish", 0.35))
.add(docForClassification(indexName, "dog", "cat", 0.5))
.add(docForClassification(indexName, "dog", "dog", 0.4))
.add(docForClassification(indexName, "dog", "dog", 0.35))
.add(docForClassification(indexName, "dog", "dog", 0.6))
.add(docForClassification(indexName, "ant", "cat", 0.1));
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);

MachineLearningClient machineLearningClient = highLevelClient().machineLearning();

{ // AucRoc
Expand All @@ -1957,8 +1957,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {

AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
assertThat(aucRocResult.getDocCount(), equalTo(5L));
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
assertNotNull(aucRocResult.getCurve());
}
{ // Accuracy
Expand Down Expand Up @@ -2173,21 +2172,22 @@ private static XContentBuilder mappingForClassification() throws IOException {
.endObject();
}

private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
private static IndexRequest docForClassification(String indexName,
String actualClass,
String... topPredictedClasses) {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON,
actualClassField, actualClass,
predictedClassField, predictedClass,
topClassesField, Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
predictedClassField, topPredictedClasses[0],
topClassesField, IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> new HashMap<String, Object>() {{
put("class_name", topPredictedClasses[i]);
put("class_probability", 1.0 / (2 << i));
}})
.collect(Collectors.toList()));
}

private static final String actualRegression = "regression_actual";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@
import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand Down Expand Up @@ -229,8 +228,11 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
Expand Down Expand Up @@ -3463,34 +3465,33 @@ public void testEvaluateDataFrame_Classification() throws Exception {
.endObject()
.endObject()
.endObject());
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.source(XContentType.JSON,
"actual_class", actualClass,
"predicted_class", predictedClass,
"ml.top_classes", Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
"predicted_class", topPredictedClasses[0],
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> new HashMap<String, Object>() {{
put("class_name", topPredictedClasses[i]);
put("class_probability", 1.0 / (2 << i));
}})
.collect(toList()));
};
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
Expand Down Expand Up @@ -3530,7 +3531,6 @@ public void testEvaluateDataFrame_Classification() throws Exception {

AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
double aucRocScore = aucRocResult.getScore(); // <11>
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
// end::evaluate-data-frame-results-classification

assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
Expand Down Expand Up @@ -3565,8 +3565,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
assertThat(otherClassesCount, equalTo(0L));

assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, equalTo(0.2625));
assertThat(aucRocDocCount, equalTo(5L));
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result(
randomDouble(),
randomLong(),
Stream
.generate(AucRocMetricAucRocPointTests::randomPoint)
.limit(randomIntBetween(1, 10))
Expand Down
1 change: 0 additions & 1 deletion docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<9> Fetching the number of classes that were not included in the matrix
<10> Fetching AucRoc metric by name
<11> Fetching the actual AucRoc score
<12> Fetching the number of documents that were used in order to calculate AucRoc score

===== Regression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,8 @@ belongs.
`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for
evaluation. The number of documents taken into account is returned in the
evaluation result (`auc_roc.doc_count` field).
negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
in the list of their top classes.

`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -231,37 +230,25 @@ private static double interpolate(double x, double x1, double y1, double x2, dou
public static class Result implements EvaluationMetricResult {

private static final String SCORE = "score";
private static final String DOC_COUNT = "doc_count";
private static final String CURVE = "curve";

private final double score;
private final Long docCount;
private final List<AucRocPoint> curve;

public Result(double score, Long docCount, List<AucRocPoint> curve) {
public Result(double score, List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = Objects.requireNonNull(curve);
}

public Result(StreamInput in) throws IOException {
this.score = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.docCount = in.readOptionalLong();
} else {
this.docCount = null;
}
this.curve = in.readList(AucRocPoint::new);
}

/**
 * Returns the score stored in this result.
 *
 * @return the {@code score} value this result was constructed or deserialized with
 */
public double getScore() {
    return this.score;
}

public Long getDocCount() {
return docCount;
}

/**
 * Returns a read-only view of the curve points held by this result.
 *
 * @return an unmodifiable list wrapping the internal curve list
 */
public List<AucRocPoint> getCurve() {
    final List<AucRocPoint> readOnlyCurve = Collections.unmodifiableList(curve);
    return readOnlyCurve;
}
Expand All @@ -279,19 +266,13 @@ public String getMetricName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeOptionalLong(docCount);
}
out.writeList(curve);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SCORE, score);
if (docCount != null) {
builder.field(DOC_COUNT, docCount);
}
if (curve.isEmpty() == false) {
builder.field(CURVE, curve);
}
Expand All @@ -305,13 +286,12 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(docCount, that.docCount)
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}
}
}

0 comments on commit bd761cc

Please sign in to comment.