Skip to content

Commit

Permalink
[ML] better handle empty results when evaluating regression (#45745)
Browse files Browse the repository at this point in the history
* [ML] better handle empty results when evaluating regression

* adding new failure test to ml_security black list

* fixing equality check for regression results
  • Loading branch information
benwtrent committed Aug 20, 2019
1 parent fad06ff commit 2202d00
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public List<AggregationBuilder> aggs(String actualField, String predictedField)
/**
 * Builds the metric result from the single-value aggregation looked up under {@code AGG_NAME}.
 *
 * NOTE(review): this span is a rendered diff hunk, so two {@code return} statements appear
 * back to back below. They look like the before/after pair of a one-line change; only one of
 * them can exist in any real revision of the file.
 */
@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
// Presumably the line removed by this commit: a missing aggregation yielded a null result.
return value == null ? null : new Result(value.value());
// Presumably the line added by this commit: a missing aggregation now yields Result(0.0).
return value == null ? new Result(0.0) : new Result(value.value());
}

@Override
Expand Down Expand Up @@ -137,5 +137,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

/**
 * Compares this result with another object.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} when {@code o} is a result of exactly the same class holding the
 *         same error value
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    Result other = (Result) o;
    // Double.compare instead of ==: it treats NaN as equal to NaN and 0.0 as different from
    // -0.0, which is the same notion of equality the hash code of the value is based on.
    // With == two results could be equal (0.0 vs -0.0) yet hash differently.
    return Double.compare(error, other.error) == 0;
}

/**
 * Hash code derived solely from the error value.
 *
 * @return the hash of the error value
 */
@Override
public int hashCode() {
    // Same value as hashing the boxed Double, without the boxing step.
    return Double.hashCode(error);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public EvaluationMetricResult evaluate(Aggregations aggs) {
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
null :
new Result(0.0) :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
}

Expand Down Expand Up @@ -148,5 +148,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

/**
 * Compares this result with another object.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} when {@code o} is a result of exactly the same class holding the
 *         same value
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    Result other = (Result) o;
    // Double.compare instead of ==: it treats NaN as equal to NaN and 0.0 as different from
    // -0.0, which is the same notion of equality the hash code of the value is based on.
    // With == two distinct results holding NaN would never be equal, and 0.0 vs -0.0 would
    // be equal yet hash differently.
    return Double.compare(value, other.value) == 0;
}

/**
 * Hash code derived solely from the stored value.
 *
 * @return the hash of the stored value
 */
@Override
public int hashCode() {
    // Same value as hashing the boxed Double, without the boxing step.
    return Double.hashCode(value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ public SearchSourceBuilder buildSearch() {
@Override
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
if (searchResponse.getHits().getTotalHits().value == 0) {
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
actualField,
predictedField));
return;
}
for (RegressionMetric metric : metrics) {
results.add(metric.evaluate(searchResponse.getAggregations()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
for (int i = 0; i < recalls.length; i++) {
double threshold = thresholds[i];
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN));
Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
long tp = tpAgg.getDocCount();
long fn = fnAgg.getDocCount();
recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
import java.util.Arrays;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -64,7 +62,7 @@ public void testEvaluate_GivenMissingAggs() {

MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
}

private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
import java.util.Arrays;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,33 +68,34 @@ public void testEvaluateWithZeroCount() {

RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(new RSquared.Result(0.0)));
}

/**
 * Exercises {@code RSquared.evaluate} against three aggregation sets, each of which lacks at
 * least one of the "residual_sum_of_squares" / "extended_stats_actual" aggregations
 * (presumably the two inputs the metric needs; confirm against RSquared).
 *
 * NOTE(review): this span is a rendered diff hunk. Every check appears twice: the
 * {@code is(nullValue())} assertion looks like the removed line and the
 * {@code equalTo(zeroResult)} assertion like its replacement. Only one of each pair exists
 * in any real revision of the file.
 */
public void testEvaluate_GivenMissingAggs() {
// Expected outcome shared by all three scenarios below.
EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
// Scenario 1: only an unrelated single-value aggregation is present.
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));

RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));

// Scenario 2: "residual_sum_of_squares" is present, "extended_stats_actual" is not.
aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
));

result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));

// Scenario 3: "extended_stats_actual" is present, "residual_sum_of_squares" is not.
aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createExtendedStatsAgg("extended_stats_actual",100, 50)
));

result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));
}

private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ integTest.runner {
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
'ml/evaluate_data_frame/Test regression given missing actual_field',
'ml/evaluate_data_frame/Test regression given missing predicted_field',
'ml/delete_job_force/Test cannot force delete a non-existent job',
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
'ml/delete_model_snapshot/Test delete snapshot missing job_id',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,34 @@ setup:
- match: { regression.mean_squared_error.error: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 }
---
"Test regression given missing actual_field":
- do:
catch: /No documents found containing both \[missing, regression_field_pred\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "missing",
"predicted_field": "regression_field_pred"
}
}
}
---
"Test regression given missing predicted_field":
- do:
catch: /No documents found containing both \[regression_field_act, missing\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "missing"
}
}
}

0 comments on commit 2202d00

Please sign in to comment.