Skip to content

Commit

Permalink
[7.11][ML] Ensure feature importance is not empty in RegressionIT (#6…
Browse files Browse the repository at this point in the history
…8074) (#68104)

This commit addresses the possibility that feature importance
is empty for the row with values 2, 20, 20. This may happen
because that row's values are the average of the values of the
other two rows, so the model computes a feature importance that
is close to zero.

The test adjusts the values so that no row has values
that are the average of the other rows' values.

Fixes #59413
Fixes #59664

Backport of #68074
  • Loading branch information
dimitris-athanasiou committed Jan 28, 2021
1 parent 43beb40 commit 9c2346a
Showing 1 changed file with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0, 3.0);
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(10L, 20L, 30L);
private static final List<Double> DEPENDENT_VARIABLE_VALUES = org.elasticsearch.common.collect.List.of(10.0, 20.0, 30.0);

// It's important that the values here are not arranged such that
// one row's values are the average of the other rows' values, as that
// may result in empty feature importance and we want to assert it gets
// written out correctly.
private static final List<Double> NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(5.0, 2.0, 3.0);
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(50L, 20L, 30L);
private static final List<Double> DEPENDENT_VARIABLE_VALUES = org.elasticsearch.common.collect.List.of(500.0, 200.0, 300.0);

private String jobId;
private String sourceIndex;
Expand Down Expand Up @@ -141,6 +146,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");

assertThat(importanceArray, hasSize(greaterThan(0)));
assertThat(
importanceArray.stream().filter(m -> NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))
Expand Down Expand Up @@ -434,7 +440,6 @@ public void testDependentVariableIsLong() throws Exception {
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/59664")
public void testWithDatastream() throws Exception {
initialize("regression_with_datastream");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
Expand Down

0 comments on commit 9c2346a

Please sign in to comment.