Skip to content

Commit

Permalink
[7.x] Handle nested and aliased fields correctly when copying mapping. (
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Jan 14, 2020
1 parent f028ab0 commit 9c6ffdc
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final String NESTED_FIELD = "outer-field.inner-field";
private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
Expand Down Expand Up @@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);

}

public void testDependentVariableCardinalityTooHighError() throws Exception {
Expand Down Expand Up @@ -342,6 +344,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
assertProgress(jobId, 100, 100, 100, 100);
}

public void testDependentVariableIsNested() throws Exception {
initialize("dependent_variable_is_nested");
String predictedClassField = NESTED_FIELD + "_prediction";
indexData(sourceIndex, 100, 0, NESTED_FIELD);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(NESTED_FIELD));
registerAnalytics(config);
putAnalytics(config);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

public void testDependentVariableIsAliasToKeyword() throws Exception {
initialize("dependent_variable_is_alias");
String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction";
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

public void testDependentVariableIsAliasToNested() throws Exception {
initialize("dependent_variable_is_alias_to_nested");
String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction";
indexData(sourceIndex, 100, 0, NESTED_FIELD);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_NESTED_FIELD));
registerAnalytics(config);
putAnalytics(config);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
String dependentVariable = KEYWORD_FIELD;
Expand Down Expand Up @@ -433,7 +492,10 @@ private static void createIndex(String index) {
BOOLEAN_FIELD, "type=boolean",
NUMERICAL_FIELD, "type=double",
DISCRETE_NUMERICAL_FIELD, "type=integer",
KEYWORD_FIELD, "type=keyword")
KEYWORD_FIELD, "type=keyword",
NESTED_FIELD, "type=keyword",
ALIAS_TO_KEYWORD_FIELD, "type=alias,path=" + KEYWORD_FIELD,
ALIAS_TO_NESTED_FIELD, "type=alias,path=" + NESTED_FIELD)
.get();
}

Expand All @@ -445,7 +507,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
Expand All @@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(Arrays.asList(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (NESTED_FIELD.equals(dependentVariable) == false) {
source.addAll(Arrays.asList(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
Expand All @@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
}

/**
* Wrapper around extractValue with implicit casting to the appropriate type.
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(doc, path);
return (T)extractValue(String.join(".", path), doc);
}

private static <T> void assertTopClasses(Map<String, Object> resultsObject,
Expand Down Expand Up @@ -583,8 +651,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
.get(destIndex)
.get("_doc")
.sourceAsMap();
assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSortConfig;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
Expand All @@ -39,6 +41,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;

/**
Expand Down Expand Up @@ -160,23 +163,38 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse,
return maxValue;
}

@SuppressWarnings("unchecked")
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
Map<String, Object> properties = new HashMap<>();
Map<String, String> idCopyMapping = new HashMap<>();
idCopyMapping.put("type", "keyword");
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
properties.put(ID_COPY, idCopyMapping);
for (Map.Entry<String, String> entry
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
String destFieldPath = entry.getKey();
String sourceFieldPath = entry.getValue();
Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath);
if (sourceFieldMapping != null) {
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
if (sourceFieldMapping instanceof Map) {
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
// If the source field is an alias, fetch the concrete field that the alias points to.
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
sourceFieldMapping = extractMapping(path, mappingsProperties);
}
}
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again.
if (sourceFieldMapping instanceof Map) {
properties.put(destFieldPath, sourceFieldMapping);
}
}
return properties;
}

private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
}

private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
Map<String, Object> metadata = new HashMap<>();
metadata.put(CREATION_DATE_MILLIS, clock.millis());
Expand Down Expand Up @@ -239,4 +257,3 @@ private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalytics
}
}
}

0 comments on commit 9c6ffdc

Please sign in to comment.