Skip to content

Commit

Permalink
[ML] Fix bug with data frame analytics classification test data sampl…
Browse files Browse the repository at this point in the history
…ing when using custom feature processors (#64727) (#64864)

When using custom processors, the field names extracted from the documents are not the
same as the feature names used for training.

Consequently, it is possible for the stratified sampler to have an incorrect view of the feature rows.
This can lead to the wrong column being read for the class label, and thus throw errors on training
row extraction.

This commit changes the training row feature names used by the stratified sampler so that they match
the names (and their order) that are sent to the analytics process.
  • Loading branch information
benwtrent committed Nov 10, 2020
1 parent dafafd7 commit f0ff673
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand Down Expand Up @@ -293,20 +294,31 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
public void testWithCustomFeatureProcessors() throws Exception {
initialize("classification_with_custom_feature_processors");
String predictedClassField = KEYWORD_FIELD + "_prediction";
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);

DataFrameAnalyticsConfig config =
buildAnalytics(jobId, sourceIndex, destIndex, null,
new Classification(
KEYWORD_FIELD,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
null,
null,
null,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(0).build(),
null,
null,
2,
10.0,
42L,
Arrays.asList(
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
new OneHotEncoding(ALIAS_TO_KEYWORD_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom").map(), true),
new OneHotEncoding(ALIAS_TO_NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_1")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_1").map(), true),
new OneHotEncoding(NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_2")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_2").map(), true),
new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
)));
putAnalytics(config);

Expand All @@ -322,11 +334,7 @@ public void testWithCustomFeatureProcessors() throws Exception {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray, hasSize(greaterThan(0)));
}

assertProgressComplete(jobId);
Expand Down Expand Up @@ -354,9 +362,13 @@ public void testWithCustomFeatureProcessors() throws Exception {
TrainedModelConfig modelConfig = response.getResources().results().get(0);
modelConfig.ensureParsedDefinition(xContentRegistry());
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
for (int i = 0; i < 4; i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), is(true));
}
for (int i = 4; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), equalTo(i == 0));
assertThat(preProcessor.isCustom(), is(false));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,8 @@ public class DataFrameDataExtractor {
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
this.client = Objects.requireNonNull(client);
this.context = Objects.requireNonNull(context);
Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
this.organicFeatures = context.extractedFields.getAllFields()
.stream()
.map(ExtractedField::getName)
.filter(f -> processedFieldInputs.contains(f) == false)
.toArray(String[]::new);
this.processedFeatures = context.extractedFields.getProcessedFields()
.stream()
.map(ProcessedField::getOutputFieldNames)
.flatMap(List::stream)
.toArray(String[]::new);
this.organicFeatures = context.extractedFields.extractOrganicFeatureNames();
this.processedFeatures = context.extractedFields.extractProcessedFeatureNames();
this.extractedFieldsByName = new LinkedHashMap<>();
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
hasNext = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

import java.util.Arrays;
Expand All @@ -22,6 +21,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DataFrameDataExtractorFactory {

Expand Down Expand Up @@ -94,8 +94,13 @@ public static DataFrameDataExtractorFactory createForSourceIndices(Client client

private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new TrainTestSplitterFactory(client, config,
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
return new TrainTestSplitterFactory(
client,
config,
Stream.concat(
Arrays.stream(extractedFields.extractOrganicFeatureNames()),
Arrays.stream(extractedFields.extractProcessedFeatureNames())
).collect(Collectors.toList()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
return cardinalitiesForFieldsWithConstraints;
}

public String[] extractOrganicFeatureNames() {
    // A feature is "organic" when it is extracted directly from the document,
    // i.e. no processed field consumes it as an input.
    Set<String> consumedByProcessors = getProcessedFieldInputs();
    return allFields.stream()
        .map(ExtractedField::getName)
        .filter(name -> consumedByProcessors.contains(name) == false)
        .toArray(String[]::new);
}

public String[] extractProcessedFeatureNames() {
    // Flatten every processed field's output column names, preserving the
    // declaration order of the processed fields and of their outputs.
    return processedFields.stream()
        .flatMap(processed -> processed.getOutputFieldNames().stream())
        .toArray(String[]::new);
}

private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
    // Keep only the fields that are extracted via the requested method.
    return fields.stream()
        .filter(f -> method == f.getMethod())
        .collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeSet;

import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -133,6 +137,33 @@ public void testBuildGivenFieldWithoutMappings() {
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
}

public void testExtractFeatureOrganicAndProcessedNames() {
    // One field of each extraction flavour; two of them (doc1, src1) will be
    // consumed by feature processors below and must not show up as organic.
    ExtractedField doc1 = new DocValueField("doc1", Collections.singleton("keyword"));
    ExtractedField doc2 = new DocValueField("doc2", Collections.singleton("ip"));
    ExtractedField scripted1 = new ScriptField("scripted1");
    ExtractedField scripted2 = new ScriptField("scripted2");
    ExtractedField src1 = new SourceField("src1", Collections.singleton("text"));
    ExtractedField src2 = new SourceField("src2", Collections.singleton("text"));

    // LinkedHashMap: insertion order fixes the order of the one-hot output columns.
    Map<String, String> oneHotColumnsByValue = new LinkedHashMap<>();
    oneHotColumnsByValue.put("bar", "bar_column");
    oneHotColumnsByValue.put("foo", "foo_column");

    ExtractedFields extractedFields = new ExtractedFields(
        Arrays.asList(doc1, doc2, scripted1, scripted2, src1, src2),
        Arrays.asList(
            new ProcessedField(new NGram("doc1", "f", new int[] {1 , 2}, 0, 2, true)),
            new ProcessedField(new OneHotEncoding("src1", oneHotColumnsByValue, true))),
        Collections.emptyMap());

    // Organic names are the remaining fields, in declaration order.
    assertThat(extractedFields.extractOrganicFeatureNames(),
        arrayContaining("doc2", "scripted1", "scripted2", "src2"));

    // Processed names: the NGram outputs first, then the one-hot columns in map order.
    assertThat(extractedFields.extractProcessedFeatureNames(),
        arrayContaining("f.10", "f.11", "f.20", "bar_column", "foo_column"));
}

private static FieldCapabilities createFieldCaps(boolean isAggregatable) {
FieldCapabilities fieldCaps = mock(FieldCapabilities.class);
when(fieldCaps.isAggregatable()).thenReturn(isAggregatable);
Expand Down

0 comments on commit f0ff673

Please sign in to comment.