
[ML] Changes default destination index field mapping and adds scripted_metric agg #40750

Merged
@@ -314,6 +314,59 @@ public void testPivotWithMaxOnDateField() throws Exception {
assertThat(actual, containsString("2017-01-15T"));
}

public void testPivotWithScriptedMetricAgg() throws Exception {
String transformId = "scriptedMetricPivot";
String dataFrameIndex = "scripted_metric_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";

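// the pivot groups reviews by user_id and, per reviewer, computes the average star
// rating plus the sum of squared star ratings via a scripted_metric aggregation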
config += " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } },"
+ " \"squared_sum\": {"
+ " \"scripted_metric\": {"
+ " \"init_script\": \"state.reviews_sqrd = []\","
+ " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
+ " \"combine_script\": \"state.reviews_sqrd\","
+ " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
+ " } }"
+ " } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
assertTrue(indexExists(dataFrameIndex));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

// we expect 27 documents, one for each distinct user_id
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));

// get and check some users
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", searchResult)).get(0);
assertEquals(711.0, actual.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

@@ -95,6 +95,5 @@ private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> l
},
listener::onFailure
));

}
}
@@ -13,6 +13,7 @@
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
@@ -73,6 +74,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
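// the reduce_script result is stored verbatim; it can be of any type (number, list, map),
// which is why the destination field is left to dynamic mapping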
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
@@ -12,6 +12,11 @@
import java.util.stream.Stream;

public final class Aggregations {

// the field mapping should not be set explicitly; ES determines the mapping dynamically from the data.
private static final String DYNAMIC = "_dynamic";
// the field mapping should be determined explicitly from the source field mapping if possible.
private static final String SOURCE = "_source";
private Aggregations() {}

/**
@@ -27,9 +32,10 @@ enum AggregationType {
AVG("avg", "double"),
CARDINALITY("cardinality", "long"),
VALUE_COUNT("value_count", "long"),
MAX("max", null),
MIN("min", null),
SUM("sum", null);
MAX("max", SOURCE),
MIN("min", SOURCE),
SUM("sum", SOURCE),
SCRIPTED_METRIC("scripted_metric", DYNAMIC);

private final String aggregationType;
private final String targetMapping;
@@ -55,8 +61,12 @@ public static boolean isSupportedByDataframe(String aggregationType) {
return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT));
}

public static boolean isDynamicMapping(String targetMapping) {
return DYNAMIC.equals(targetMapping);
}

public static String resolveTargetMapping(String aggregationType, String sourceType) {
AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT));
return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping();
return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping();
}
}
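For illustration only (not part of this change set): a minimal sketch of how the two sentinel values behave, assuming the Aggregations class shown above is on the classpath (import omitted). The expected results follow from the enum entries above and from the resolveTargetMapping unit test further down.

// Hypothetical usage sketch, mirroring AggregationsTests#testResolveTargetMapping below.
public class AggregationsMappingSketch {
    public static void main(String[] args) {
        // "_source" sentinel: max/min/sum inherit the mapping type of the source field.
        String maxMapping = Aggregations.resolveTargetMapping("max", "half_float");
        System.out.println(maxMapping); // prints: half_float

        // "_dynamic" sentinel: a scripted_metric result has no predictable type, so the
        // destination field is deliberately left out of the explicit index mapping.
        String scriptedMapping = Aggregations.resolveTargetMapping("scripted_metric", null);
        System.out.println(Aggregations.isDynamicMapping(scriptedMapping)); // prints: true
    }
}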
@@ -15,6 +15,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@@ -75,6 +76,8 @@ public static void deduceMappings(final Client client,
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if (agg instanceof ScriptedMetricAggregationBuilder) {
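// scripted_metric has no single source field, so only the aggregation type is recorded;
// it later resolves to a dynamic target mapping in resolveMappings()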
aggregationTypes.put(agg.getName(), agg.getType());
} else {
// execution should not reach this point
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
@@ -127,15 +130,17 @@ private static Map<String, String> resolveMappings(Map<String, String> aggregati

aggregationTypes.forEach((targetFieldName, aggregationName) -> {
String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName));
String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping);

logger.debug(
"Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]");
if (destinationMapping != null) {
if (Aggregations.isDynamicMapping(destinationMapping)) {
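// intentionally no entry is added to targetMapping: the destination index will map
// this field dynamically from the first document that contains it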
logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]");
} else if (destinationMapping != null) {
targetMapping.put(targetFieldName, destinationMapping);
} else {
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
targetMapping.put(targetFieldName, "double");
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
}
});

@@ -35,9 +35,11 @@
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
@@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
@@ -409,6 +412,92 @@ aggTypedName2, asMap(
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

GroupConfig groupBy = parseGroupConfig("{"
+ "\"" + targetField + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } },"
+ "\"" + targetField2 + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } }"
+ "}");

String aggName = randomAlphaOfLengthBetween(5, 10);
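// typed-keys name format ("<aggregation type>#<aggregation name>"), as produced when parsing aggregation results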
String aggTypedName = "scripted_metric#" + aggName;

Collection<AggregationBuilder> aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName));

Map<String, Object> input = asMap(
"buckets",
asList(
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 123.0)),
DOC_COUNT, 1),
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 1.0)),
DOC_COUNT, 2),
asMap(
KEY, asMap(
targetField, "ID2",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 2.13)),
DOC_COUNT, 3),
asMap(
KEY, asMap(
targetField, "ID3",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 12.0)),
DOC_COUNT, 4)
));

List<Map<String, Object>> expected = asList(
asMap(
targetField, "ID1",
targetField2, "ID1_2",
aggName, asMap("field", 123.0)
),
asMap(
targetField, "ID1",
targetField2, "ID2_2",
aggName, asMap("field", 1.0)
),
asMap(
targetField, "ID2",
targetField2, "ID1_2",
aggName, asMap("field", 2.13)
),
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, asMap("field", 12.0)
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
@@ -15,9 +15,31 @@ public void testResolveTargetMapping() {
assertEquals("double", Aggregations.resolveTargetMapping("avg", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("avg", "double"));

// cardinality
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double"));

// value_count
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double"));

// max
assertEquals("int", Aggregations.resolveTargetMapping("max", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("max", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float"));

// min
assertEquals("int", Aggregations.resolveTargetMapping("min", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("min", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float"));

// sum
assertEquals("int", Aggregations.resolveTargetMapping("sum", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));

// scripted_metric
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
}
}
@@ -37,9 +37,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@@ -176,14 +174,20 @@ private AggregationConfig getValidAggregationConfig() throws IOException {
}

private AggregationConfig getAggregationConfig(String agg) throws IOException {
if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
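// scripted_metric takes no single "field", so it cannot reuse the generic template below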
return parseAggregations("{\"pivot_scripted_metric\": {\n" +
"\"scripted_metric\": {\n" +
" \"init_script\" : \"state.transactions = []\",\n" +
" \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
" \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
" \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
" }\n" +
"}}");
}
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
+ " }\n" + " }" + "}");
}

private Map<String, String> getFieldMappings() {
return Collections.singletonMap("values", "double");
}

private AggregationConfig parseAggregations(String json) throws IOException {
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);