Skip to content

Commit

Permalink
[ML] fix custom feature processor extraction bugs around boolean fields and custom one_hot feature output order (#64937) (#65009)
Browse files Browse the repository at this point in the history

This commit fixes two problems:

- When extracting a doc value, we allow boolean scalars to be used as input
- The output order of processed feature names is now deterministic. Previously, custom one_hot fields were emitted in a non-deterministic order, which could produce inconsistent results between runs.
  • Loading branch information
benwtrent committed Nov 12, 2020
1 parent e40d7e0 commit b888f36
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ public NGram(String field, String featurePrefix, int[] nGrams, int start, int le
if (length > MAX_LENGTH) {
throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
}
if (Arrays.stream(this.nGrams).anyMatch(i -> i > length)) {
throw ExceptionsHelper.badRequestException(
"[{}] and [{}] are invalid; all ngrams must be shorter than or equal to length [{}]",
NGRAMS.getPreferredName(),
LENGTH.getPreferredName(),
length);
}
this.custom = custom;
}

Expand Down Expand Up @@ -293,6 +300,9 @@ private List<String> allPossibleNGramOutputFeatureNames() {
for (int nGram : nGrams) {
totalNgrams += (length - (nGram - 1));
}
if (totalNgrams <= 0) {
return Collections.emptyList();
}
List<String> ngramOutputs = new ArrayList<>(totalNgrams);

for (int nGram : nGrams) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -68,13 +69,13 @@ public static OneHotEncoding fromXContentLenient(XContentParser parser, PreProce

public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
this.custom = custom == null ? false : custom;
this.hotMap = Collections.unmodifiableMap(new TreeMap<>(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)));
this.custom = custom != null && custom;
}

public OneHotEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
this.hotMap = Collections.unmodifiableMap(new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readString)));
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ public static NGram createRandom() {
}

public static NGram createRandom(Boolean isCustom) {
int possibleLength = randomIntBetween(1, 10);
return new NGram(
randomAlphaOfLength(10),
IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
IntStream.generate(() -> randomIntBetween(1, Math.min(possibleLength, 5))).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10),
randomBoolean() ? null : possibleLength,
isCustom,
randomBoolean() ? null : randomAlphaOfLength(10));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import org.hamcrest.Matcher;
import org.junit.Before;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.hamcrest.Matchers.equalTo;

public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
Expand Down Expand Up @@ -41,6 +43,22 @@ void testProcess(PreProcessor preProcessor, Map<String, Object> fieldValues, Map
);
}

// Round-trips random pre-processor instances through XContent (using the
// standard XContent tester, which shuffles fields and may inject unknown
// ones) and checks — via assertFieldConsistency — that the parsed copy
// reports its input and output fields in the same order as the original.
// Guards against pre-processors whose field order is backed by an
// unordered map.
public void testInputOutputFieldOrderConsistency() throws IOException {
xContentTester(this::createParser, this::createXContextTestInstance, getToXContentParams(), this::doParseInstance)
.numberOfTestRuns(NUMBER_OF_TEST_RUNS)
.supportsUnknownFields(supportsUnknownFields())
.shuffleFieldsExceptions(getShuffleFieldsExceptions())
.randomFieldsExcludeFilter(getRandomFieldsExcludeFilter())
// Only input/output field consistency is asserted here; full XContent
// equivalence is intentionally not checked for this test.
.assertEqualsConsumer(this::assertFieldConsistency)
.assertToXContentEquivalence(false)
.test();
}

// Asserts that two pre-processor instances expose identical input and
// output field lists, including their ordering.
private void assertFieldConsistency(T expected, T actual) {
assertThat(expected.inputFields(), equalTo(actual.inputFields()));
assertThat(expected.outputFields(), equalTo(actual.outputFields()));
}

public void testWithMissingField() {
Map<String, Object> fields = randomFieldValues();
PreProcessor preProcessor = this.createTestInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
}

private static boolean isValidValue(Object value) {
public static boolean isValidValue(Object value) {
// We should allow a number, string or a boolean.
// It is possible for a field to be categorical and have a `keyword` mapping, but be any of these
// three types, in the same index.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.util.Set;
import java.util.function.Function;

import static org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor.isValidValue;

public class ProcessedField {
private final PreProcessor preProcessor;

Expand All @@ -36,8 +38,9 @@ public Set<String> getOutputFieldType(String outputField) {
}

public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
for (String field : preProcessor.inputFields()) {
List<String> inputFields = getInputFieldNames();
Map<String, Object> inputs = new HashMap<>(inputFields.size(), 1.0f);
for (String field : inputFields) {
ExtractedField extractedField = fieldExtractor.apply(field);
if (extractedField == null) {
return new Object[0];
Expand All @@ -47,7 +50,7 @@ public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtra
continue;
}
final Object value = values[0];
if (values.length == 1 && (value instanceof String || value instanceof Number)) {
if (values.length == 1 && (isValidValue(value))) {
inputs.put(field, value);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
*/
package org.elasticsearch.xpack.ml.extractor;

import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;

import java.util.Arrays;
import java.util.Collections;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.LinkedHashMap;
import java.util.Map;

import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.emptyArray;
Expand All @@ -30,7 +34,7 @@ public class ProcessedFieldTests extends ESTestCase {

public void testOneHotGetters() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(inputField, "bar", "baz"));
assertThat(processedField.getInputFieldNames(), hasItems(inputField));
assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
Expand All @@ -39,28 +43,92 @@ public void testOneHotGetters() {
}

public void testMissingExtractor() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
}

public void testMissingInputValues() {
String inputField = "foo";
ExtractedField extractedField = makeExtractedField(new Object[0]);
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
}

public void testProcessedField() {
ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
// Frequency encoding should replace each known input value ("bar", 1,
// false — matched against its string form) with the configured frequency.
public void testProcessedFieldFrequencyEncoding() {
Map<String, Double> frequencyMap = new LinkedHashMap<>();
frequencyMap.put("bar", 1.0);
frequencyMap.put("1", 0.5);
frequencyMap.put("false", 0.0);
FrequencyEncoding encoding = new FrequencyEncoding(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
frequencyMap,
randomBoolean());
Object[] inputs = {"bar", 1, false};
Object[][] expected = {
{1.0},
{0.5},
{0.0},
};
testProcessedField(encoding, inputs, expected);
}

// Target-mean encoding should map known values ("bar", 1, false) to their
// configured means and fall back to the default mean (0.8) for an unseen
// value ("unknown").
public void testProcessedFieldTargetMeanEncoding() {
Map<String, Double> meanMap = new LinkedHashMap<>();
meanMap.put("bar", 1.0);
meanMap.put("1", 0.5);
meanMap.put("false", 0.0);
TargetMeanEncoding encoding = new TargetMeanEncoding(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
meanMap,
0.8,
randomBoolean());
Object[] inputs = {"bar", 1, false, "unknown"};
Object[][] expected = {
{1.0},
{0.5},
{0.0},
{0.8},
};
testProcessedField(encoding, inputs, expected);
}

// A 1-gram processor with start 0 and length 3 should emit the first three
// characters of the input's string form, padding with nulls when the value
// is shorter than the window (e.g. "1").
public void testProcessedFieldNGramEncoding() {
NGram oneGram = new NGram(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
new int[]{1},
0,
3,
randomBoolean());
Object[] inputs = {"bar", 1, false};
Object[][] expected = {
{"b", "a", "r"},
{"1", null, null},
{"f", "a", "l"},
};
testProcessedField(oneGram, inputs, expected);
}

private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
return new OneHotEncoding(inputField,
Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
true);
// One-hot encoding over hot values {"bar", "1", "false"}: each input should
// light up exactly one output slot. Expected column order follows the
// encoder's sorted hot-map ("1", "bar", "false" — see the TreeMap wrapping
// in OneHotEncoding introduced by this change).
public void testProcessedFieldOneHot() {
PreProcessor oneHot = makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "1", "false");
Object[] inputs = {"bar", 1, false};
Object[][] expected = {
{0, 1, 0},
{1, 0, 0},
{0, 0, 1},
};
testProcessedField(oneHot, inputs, expected);
}

/**
 * Runs each input through a {@link ProcessedField} wrapping the given pre-processor
 * and verifies the produced feature values.
 *
 * @param preProcessor    the pre-processor under test
 * @param inputs          raw document values to feed to the processor, one per test case
 * @param expectedOutputs expected processed values, parallel to {@code inputs}
 */
public void testProcessedField(PreProcessor preProcessor, Object[] inputs, Object[][] expectedOutputs) {
ProcessedField processedField = new ProcessedField(preProcessor);
// Use a real assertion instead of the `assert` keyword so a mismatched
// fixture fails even when the JVM runs without -ea.
assertThat("inputs and expectedOutputs must be parallel arrays", inputs.length, equalTo(expectedOutputs.length));
for (int i = 0; i < inputs.length; i++) {
Object input = inputs[i];
Object[] result = processedField.value(makeHit(input), (s) -> makeExtractedField(new Object[] { input }));
assertThat(
"Input [" + input + "] Expected " + Arrays.toString(expectedOutputs[i]) + " but received " + Arrays.toString(result),
result,
equalTo(expectedOutputs[i]));
}
}

// Builds a custom one-hot encoder whose hot-map maps each value v to the
// column name "v_column", preserving the argument order via an
// insertion-ordered map.
private static PreProcessor makeOneHotPreProcessor(String inputField, String... expectedExtractedValues) {
Map<String, String> hotMap = Arrays.stream(expectedExtractedValues)
.collect(Collectors.toMap(Function.identity(), v -> v + "_column", (a, b) -> a, LinkedHashMap::new));
return new OneHotEncoding(inputField, hotMap, true);
}

private static ExtractedField makeExtractedField(Object[] value) {
Expand All @@ -70,7 +138,11 @@ private static ExtractedField makeExtractedField(Object[] value) {
}

private static SearchHit makeHit() {
return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
return makeHit("bar");
}

// Builds a single-field search hit (doc id 42) carrying the given value
// under the "a_keyword" field.
private static SearchHit makeHit(Object value) {
SearchHitBuilder hitBuilder = new SearchHitBuilder(42).addField("a_keyword", value);
return hitBuilder.build();
}

}

0 comments on commit b888f36

Please sign in to comment.