From 39785ebdbb276b365027cb4401e68acf029f6463 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 23 Mar 2020 23:29:01 +0200 Subject: [PATCH] [ML] Data frame analytics data counts (#53998) This commit instruments data frame analytics with stats for the data that are being analyzed. In particular, we count training docs, test docs, and skipped docs. In order to count docs with missing values as skipped docs for analyses that do not support missing values, this commit changes the extractor so that it only ignores docs with missing values when it collects the data summary, which is used to estimate memory usage. --- .../ml/dataframe/DataFrameAnalyticsStats.java | 27 +++- .../ml/dataframe/stats/common/DataCounts.java | 119 +++++++++++++++++ .../DataFrameAnalyticsStatsTests.java | 5 + .../stats/common/DataCountsTests.java | 51 ++++++++ .../GetDataFrameAnalyticsStatsAction.java | 29 ++++- .../ml/dataframe/stats/common/DataCounts.java | 120 ++++++++++++++++++ .../xpack/core/ml/stats_index_mappings.json | 9 ++ ...rameAnalyticsStatsActionResponseTests.java | 5 +- .../stats/common/DataCountsTests.java | 66 ++++++++++ .../ml/integration/ClassificationIT.java | 16 +++ .../OutlierDetectionWithMissingFieldsIT.java | 7 + .../xpack/ml/integration/RegressionIT.java | 17 +++ .../integration/RunDataFrameAnalyticsIT.java | 5 + ...sportGetDataFrameAnalyticsStatsAction.java | 9 ++ .../extractor/DataFrameDataExtractor.java | 23 +++- .../DataFrameDataExtractorContext.java | 6 +- .../DataFrameDataExtractorFactory.java | 30 +---- .../process/AnalyticsProcessManager.java | 30 ++++- .../process/AnalyticsResultProcessor.java | 44 ++----- .../CrossValidationSplitter.java | 2 +- .../CrossValidationSplitterFactory.java | 2 +- .../RandomCrossValidationSplitter.java | 25 ++-- .../ml/dataframe/stats/DataCountsTracker.java | 37 ++++++ .../xpack/ml/dataframe/stats/StatsHolder.java | 6 + .../ml/dataframe/stats/StatsPersister.java | 66 ++++++++++ .../DataFrameDataExtractorTests.java | 57 +++++++-- .../AnalyticsResultProcessorTests.java | 8 +- .../RandomCrossValidationSplitterTests.java | 46 ++++--- 28 files changed, 744 insertions(+), 123 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 70169ce09b278..acdb9cccca1eb 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -21,6 +21,7 @@ import org.elasticsearch.client.ml.NodeAttributes; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import 
org.elasticsearch.client.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -47,6 +48,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws static final ParseField STATE = new ParseField("state"); static final ParseField FAILURE_REASON = new ParseField("failure_reason"); static final ParseField PROGRESS = new ParseField("progress"); + static final ParseField DATA_COUNTS = new ParseField("data_counts"); static final ParseField MEMORY_USAGE = new ParseField("memory_usage"); static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats"); static final ParseField NODE = new ParseField("node"); @@ -60,10 +62,11 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws (DataFrameAnalyticsState) args[1], (String) args[2], (List) args[3], - (MemoryUsage) args[4], - (AnalysisStats) args[5], - (NodeAttributes) args[6], - (String) args[7])); + (DataCounts) args[4], + (MemoryUsage) args[5], + (AnalysisStats) args[6], + (NodeAttributes) args[7], + (String) args[8])); static { PARSER.declareString(constructorArg(), ID); @@ -75,6 +78,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws }, STATE, ObjectParser.ValueType.STRING); PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); + PARSER.declareObject(optionalConstructorArg(), DataCounts.PARSER, DATA_COUNTS); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE); PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); @@ -93,19 +97,21 @@ private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IO private final DataFrameAnalyticsState state; private final String failureReason; private final List progress; + private final DataCounts dataCounts; private final MemoryUsage memoryUsage; private final AnalysisStats analysisStats; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, - @Nullable List progress, @Nullable MemoryUsage memoryUsage, - @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, + @Nullable List progress, @Nullable DataCounts dataCounts, + @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; this.progress = progress; + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -128,6 +134,11 @@ public List getProgress() { return progress; } + @Nullable + public DataCounts getDataCounts() { + return dataCounts; + } + @Nullable public MemoryUsage getMemoryUsage() { return memoryUsage; @@ -156,6 +167,7 @@ public boolean equals(Object o) { && Objects.equals(state, other.state) && Objects.equals(failureReason, other.failureReason) && Objects.equals(progress, other.progress) + && Objects.equals(dataCounts, other.dataCounts) && Objects.equals(memoryUsage, other.memoryUsage) && Objects.equals(analysisStats, other.analysisStats) && Objects.equals(node, other.node) @@ -164,7 +176,7 @@ 
public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, assignmentExplanation); } @Override @@ -174,6 +186,7 @@ public String toString() { .add("state", state) .add("failureReason", failureReason) .add("progress", progress) + .add("dataCounts", dataCounts) .add("memoryUsage", memoryUsage) .add("analysisStats", analysisStats) .add("node", node) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java new file mode 100644 index 0000000000000..b7a90b1f0b5c6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java @@ -0,0 +1,119 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class DataCounts implements ToXContentObject { + + public static final String TYPE_VALUE = "analytics_data_counts"; + + public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count"); + public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count"); + public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE_VALUE, true, + a -> { + Long trainingDocsCount = (Long) a[0]; + Long testDocsCount = (Long) a[1]; + Long skippedDocsCount = (Long) a[2]; + return new DataCounts( + getOrDefault(trainingDocsCount, 0L), + getOrDefault(testDocsCount, 0L), + getOrDefault(skippedDocsCount, 0L) + ); + }); + + static { + PARSER.declareLong(optionalConstructorArg(), TRAINING_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), TEST_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), SKIPPED_DOCS_COUNT); + } + + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; + return trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(trainingDocsCount, testDocsCount, skippedDocsCount); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount) + .add(TEST_DOCS_COUNT.getPreferredName(), testDocsCount) + .add(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount) + .toString(); + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } + + private static T getOrDefault(@Nullable T value, T defaultValue) { + return value != null ? 
value : defaultValue; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index 25345181982af..d251f568dfa79 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests; import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests; @@ -68,6 +69,7 @@ public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : createRandomProgress(), + randomBoolean() ? null : DataCountsTests.createRandom(), randomBoolean() ? null : MemoryUsageTests.createRandom(), analysisStats, randomBoolean() ? null : NodeAttributesTests.createRandom(), @@ -93,6 +95,9 @@ public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder bui if (stats.getProgress() != null) { builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress()); } + if (stats.getDataCounts() != null) { + builder.field(DataFrameAnalyticsStats.DATA_COUNTS.getPreferredName(), stats.getDataCounts()); + } if (stats.getMemoryUsage() != null) { builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 0000000000000..5e877e2d40f7b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class DataCountsTests extends AbstractXContentTestCase { + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return DataCounts.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index 209058e00468b..3f65a917a7427 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -165,6 +166,9 @@ public static class Stats implements ToXContentObject, Writeable { */ private final List progress; + @Nullable + private final DataCounts dataCounts; + @Nullable private final MemoryUsage memoryUsage; @@ -177,12 +181,13 @@ public static class Stats implements ToXContentObject, Writeable { private final String assignmentExplanation; public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List progress, - @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node, - @Nullable String assignmentExplanation) { + @Nullable DataCounts dataCounts, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, + @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; this.progress = Objects.requireNonNull(progress); + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -198,6 +203,11 @@ public Stats(StreamInput in) throws IOException { } else { progress = in.readList(PhaseProgress::new); } + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + dataCounts = in.readOptionalWriteable(DataCounts::new); + } else { + dataCounts = null; + } if (in.getVersion().onOrAfter(Version.V_7_7_0)) { memoryUsage = in.readOptionalWriteable(MemoryUsage::new); } else { @@ -261,6 +271,11 @@ public List getProgress() { return progress; } + @Nullable + public DataCounts getDataCounts() { + return dataCounts; + } + @Nullable public MemoryUsage getMemoryUsage() { return memoryUsage; @@ -293,6 +308,9 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc 
if (progress != null) { builder.field("progress", progress); } + if (dataCounts != null) { + builder.field("data_counts", dataCounts); + } if (memoryUsage != null) { builder.field("memory_usage", memoryUsage); } @@ -331,6 +349,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeList(progress); } + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeOptionalWriteable(dataCounts); + } if (out.getVersion().onOrAfter(Version.V_7_7_0)) { out.writeOptionalWriteable(memoryUsage); } @@ -369,7 +390,8 @@ private void writeProgressToLegacy(StreamOutput out) throws IOException { @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, + assignmentExplanation); } @Override @@ -385,6 +407,7 @@ public boolean equals(Object obj) { && Objects.equals(this.state, other.state) && Objects.equals(this.failureReason, other.failureReason) && Objects.equals(this.progress, other.progress) + && Objects.equals(this.dataCounts, other.dataCounts) && Objects.equals(this.memoryUsage, other.memoryUsage) && Objects.equals(this.analysisStats, other.analysisStats) && Objects.equals(this.node, other.node) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java new file mode 100644 index 0000000000000..f77cc781c746a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Objects; + +public class DataCounts implements ToXContentObject, Writeable { + + public static final String TYPE_VALUE = "analytics_data_counts"; + + public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count"); + public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count"); + public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count"); + + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new DataCounts((String) a[0], (long) a[1], (long) a[2], (long) a[3])); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareLong(ConstructingObjectParser.constructorArg(), TRAINING_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), TEST_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), SKIPPED_DOCS_COUNT); + return parser; + } + + private final String jobId; + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(String jobId, long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.jobId = Objects.requireNonNull(jobId); + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + public DataCounts(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.trainingDocsCount = in.readVLong(); + this.testDocsCount = in.readVLong(); + this.skippedDocsCount = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeVLong(trainingDocsCount); + out.writeVLong(testDocsCount); + out.writeVLong(skippedDocsCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; 
+ return Objects.equals(jobId, that.jobId) + && trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(jobId, trainingDocsCount, testDocsCount, skippedDocsCount); + } + + public static String documentId(String jobId) { + return TYPE_VALUE + "_" + jobId; + } + + public String getJobId() { + return jobId; + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } +} diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json index 5b742c9f6d91e..7a846f5743b91 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json @@ -85,6 +85,9 @@ "peak_usage_bytes" : { "type" : "long" }, + "skipped_docs_count": { + "type": "long" + }, "timestamp" : { "type" : "date" }, @@ -98,6 +101,12 @@ } } }, + "test_docs_count": { + "type": "long" + }, + "training_docs_count": { + "type": "long" + }, "type" : { "type" : "keyword" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 5cb2b3fef5450..3f957f95a902c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; @@ -42,6 +44,7 @@ public static Response randomResponse(int listSize) { List progress = new ArrayList<>(progressSize); IntStream.of(progressSize).forEach(progressIndex -> progress.add( new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); + DataCounts dataCounts = randomBoolean() ? null : DataCountsTests.createRandom(); MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom(); AnalysisStats analysisStats = randomBoolean() ? 
null : randomFrom( @@ -50,7 +53,7 @@ public static Response randomResponse(int listSize) { RegressionStatsTests.createRandom() ); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null, + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, dataCounts, memoryUsage, analysisStats, null, randomAlphaOfLength(20)); analytics.add(stats); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 0000000000000..84033d49de655 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; + +public class DataCountsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected DataCounts mutateInstanceForVersion(DataCounts instance, Version version) { + return instance; + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return lenient ? 
DataCounts.LENIENT_PARSER.apply(parser, null) : DataCounts.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataCounts::new; + } + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 9c4511527eeba..a96f6e8b19b9e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -48,6 +49,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.startsWith; @@ -157,6 +159,12 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); @@ -224,6 +232,14 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); 
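// Side note, not part of the original patch: the data-counts assertions above show that both the
// training and test counts are positive and below the total document count, and that no docs are
// skipped. Since every extracted row is routed to exactly one of the two sets, a test like this
// could additionally assert that the split covers all source docs, e.g.:
//     assertThat(stats.getDataCounts().getTrainingDocsCount() + stats.getDataCounts().getTestDocsCount(),
//         equalTo(300L));
// The 300L total is an assumption taken from the hundred-percent variant above, which expects
// training_docs_count to equal 300 when every row is used for training.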
assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index e32d923852038..ced8c987fd685 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; @@ -79,6 +80,12 @@ public void testMissingFields() throws Exception { startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(2L)); + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); for (SearchHit hit : sourceData.getHits()) { GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index fc20bbb5afd7b..07752468a743f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -31,6 +32,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -141,6 +143,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(350L)); + 
assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); @@ -197,6 +206,14 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 4ffe948c9f5b7..3f7a5a7156637 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -85,6 +85,11 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); double scoreOfOutlier = 0.0; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 9ee28ba29cf3a..cb8d7bcb2be4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; @@ -108,6 +109,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D Stats stats = buildStats( 
task.getParams().getId(), statsHolder.getProgressTracker().report(), + statsHolder.getDataCountsTracker().report(task.getParams().getId()), statsHolder.getMemoryUsage(), statsHolder.getAnalysisStats() ); @@ -200,6 +202,7 @@ private void searchStats(String configId, ActionListener listener) { MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); multiSearchRequest.add(buildStoredProgressSearch(configId)); + multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE)); @@ -224,6 +227,7 @@ private void searchStats(String configId, ActionListener listener) { } listener.onResponse(buildStats(configId, retrievedStatsHolder.progress.get(), + retrievedStatsHolder.dataCounts, retrievedStatsHolder.memoryUsage, retrievedStatsHolder.analysisStats )); @@ -258,6 +262,8 @@ private static void parseHit(SearchHit hit, String configId, RetrievedStatsHolde String hitId = hit.getId(); if (StoredProgress.documentId(configId).equals(hitId)) { retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER); + } else if (DataCounts.documentId(configId).equals(hitId)) { + retrievedStatsHolder.dataCounts = MlParserUtils.parse(hit, DataCounts.LENIENT_PARSER); } else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) { retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER); } else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) { @@ -273,6 +279,7 @@ private static void parseHit(SearchHit hit, String configId, RetrievedStatsHolde private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List progress, + DataCounts dataCounts, MemoryUsage memoryUsage, AnalysisStats analysisStats) { ClusterState clusterState = clusterService.state(); @@ -295,6 +302,7 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre analyticsState, failureReason, progress, + dataCounts, memoryUsage, analysisStats, node, @@ -305,6 +313,7 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre private static class RetrievedStatsHolder { private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report()); + private volatile DataCounts dataCounts; private volatile MemoryUsage memoryUsage; private volatile AnalysisStats analysisStats; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index ba1bdf9a6a5d0..aad06b71c075a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -19,6 +19,9 @@ import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; @@ 
-187,7 +190,7 @@ private Row createRow(SearchHit hit) { if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { extractedValues[i] = Objects.toString(values[0]); } else { - if (values.length == 0 && context.includeRowsWithMissingValues) { + if (values.length == 0 && context.supportsRowsWithMissingValues) { // if values is empty then it means it's a missing value extractedValues[i] = NULL_VALUE; } else { @@ -263,13 +266,29 @@ public void collectDataSummaryAsync(ActionListener dataSummaryActio } private SearchRequestBuilder buildDataSummarySearchRequestBuilder() { + + QueryBuilder summaryQuery = context.query; + if (context.supportsRowsWithMissingValues == false) { + summaryQuery = QueryBuilders.boolQuery() + .filter(summaryQuery) + .filter(allExtractedFieldsExistQuery()); + } + return new SearchRequestBuilder(client, SearchAction.INSTANCE) .setIndices(context.indices) .setSize(0) - .setQuery(context.query) + .setQuery(summaryQuery) .setTrackTotalHits(true); } + private QueryBuilder allExtractedFieldsExistQuery() { + BoolQueryBuilder query = QueryBuilders.boolQuery(); + for (ExtractedField field : context.extractedFields.getAllFields()) { + query.filter(QueryBuilders.existsQuery(field.getName())); + } + return query; + } + public Set getCategoricalFields(DataFrameAnalysis analysis) { return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java index 0cf391bc33b2e..64ad4bed452e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java @@ -21,10 +21,10 @@ public class DataFrameDataExtractorContext { final int scrollSize; final Map headers; final boolean includeSource; - final boolean includeRowsWithMissingValues; + final boolean supportsRowsWithMissingValues; DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, - Map headers, boolean includeSource, boolean includeRowsWithMissingValues) { + Map headers, boolean includeSource, boolean supportsRowsWithMissingValues) { this.jobId = Objects.requireNonNull(jobId); this.extractedFields = Objects.requireNonNull(extractedFields); this.indices = indices.toArray(new String[indices.size()]); @@ -32,6 +32,6 @@ public class DataFrameDataExtractorContext { this.scrollSize = scrollSize; this.headers = headers; this.includeSource = includeSource; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 3243d92bf77b6..a699e16a7d602 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -7,11 +7,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; 
-import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.Arrays; @@ -28,18 +26,18 @@ public class DataFrameDataExtractorFactory { private final QueryBuilder sourceQuery; private final ExtractedFields extractedFields; private final Map headers; - private final boolean includeRowsWithMissingValues; + private final boolean supportsRowsWithMissingValues; private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, ExtractedFields extractedFields, Map headers, - boolean includeRowsWithMissingValues) { + boolean supportsRowsWithMissingValues) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); this.sourceQuery = Objects.requireNonNull(sourceQuery); this.extractedFields = Objects.requireNonNull(extractedFields); this.headers = headers; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } public DataFrameDataExtractor newExtractor(boolean includeSource) { @@ -47,11 +45,11 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { analyticsId, extractedFields, indices, - createQuery(), + QueryBuilders.boolQuery().filter(sourceQuery), 1000, headers, includeSource, - includeRowsWithMissingValues + supportsRowsWithMissingValues ); return new DataFrameDataExtractor(client, context); } @@ -60,23 +58,6 @@ public ExtractedFields getExtractedFields() { return extractedFields; } - private QueryBuilder createQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - query.filter(sourceQuery); - if (includeRowsWithMissingValues == false) { - query.filter(allExtractedFieldsExistQuery()); - } - return query; - } - - private QueryBuilder allExtractedFieldsExistQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - for (ExtractedField field : extractedFields.getAllFields()) { - query.filter(QueryBuilders.existsQuery(field.getName())); - } - return query; - } - /** * Create a new extractor factory * @@ -109,6 +90,7 @@ public static void createForDestinationIndex(Client client, extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap( extractedFieldsDetector -> { ExtractedFields extractedFields = extractedFieldsDetector.detect().v1(); + DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(), Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index c7baad202d2f6..c1f09ff3fb34f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; 
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; @@ -22,8 +23,10 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,7 +37,9 @@ import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter; import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -156,7 +161,10 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); try { writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker()); + writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(), + task.getStatsHolder().getDataCountsTracker()); + processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()), + DataCounts::documentId); process.writeEndOfDataMessage(); process.flushStream(); @@ -205,8 +213,8 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont } } - private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, - DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException { + private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, DataFrameAnalysis analysis, + ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException { CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames()) .create(analysis); @@ -223,11 +231,14 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces Optional> rows = dataExtractor.next(); if (rows.isPresent()) { for (DataFrameDataExtractor.Row row : rows.get()) { - if (row.shouldSkip() == false) { + if (row.shouldSkip()) { + dataCountsTracker.incrementSkippedDocsCount(); + } else { String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); record[record.length - 2] = String.valueOf(row.getChecksum()); - crossValidationSplitter.process(record); + crossValidationSplitter.process(record, 
dataCountsTracker::incrementTrainingDocsCount, + dataCountsTracker::incrementTestDocsCount); process.writeRecord(record); } } @@ -253,6 +264,10 @@ private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsPr process.writeRecord(headerRecord); } + private void indexDataCounts(DataCounts dataCounts) { + IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias()); + } + private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state, AnalyticsProcess process) { if (config.getAnalysis().persistsState() == false) { @@ -353,9 +368,11 @@ class ProcessContext { private final SetOnce dataExtractor = new SetOnce<>(); private final SetOnce resultProcessor = new SetOnce<>(); private final SetOnce failureReason = new SetOnce<>(); + private final StatsPersister statsPersister; ProcessContext(DataFrameAnalyticsConfig config) { this.config = Objects.requireNonNull(config); + this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor); } String getFailureReason() { @@ -378,6 +395,7 @@ synchronized void stop() { if (resultProcessor.get() != null) { resultProcessor.get().cancel(); } + statsPersister.cancel(); if (process.get() != null) { try { process.get().kill(); @@ -434,7 +452,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); return new AnalyticsResultProcessor( - config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, resultsPersisterService, + config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister, dataExtractor.get().getAllExtractedFields()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index eaa202df5fc4a..d27fa31b41193 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -11,14 +11,10 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.License; -import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; @@ -31,18 +27,16 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import 
org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; -import java.io.IOException; import java.time.Instant; import java.util.Collections; import java.util.Iterator; @@ -51,7 +45,6 @@ import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.function.Function; import java.util.stream.Collectors; import static java.util.stream.Collectors.toList; @@ -77,7 +70,7 @@ public class AnalyticsResultProcessor { private final StatsHolder statsHolder; private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; - private final ResultsPersisterService resultsPersisterService; + private final StatsPersister statsPersister; private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; @@ -85,14 +78,13 @@ public class AnalyticsResultProcessor { public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, StatsHolder statsHolder, TrainedModelProvider trainedModelProvider, - DataFrameAnalyticsAuditor auditor, ResultsPersisterService resultsPersisterService, - List fieldNames) { + DataFrameAnalyticsAuditor auditor, StatsPersister statsPersister, List fieldNames) { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.statsHolder = Objects.requireNonNull(statsHolder); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); - this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.statsPersister = Objects.requireNonNull(statsPersister); this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); } @@ -112,6 +104,7 @@ public void awaitForCompletion() { public void cancel() { dataFrameRowsJoiner.cancel(); + statsPersister.cancel(); isCancelled = true; } @@ -177,22 +170,22 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { statsHolder.setMemoryUsage(memoryUsage); - indexStatsResult(memoryUsage, memoryUsage::documentId); + statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId); } OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats(); if (outlierDetectionStats != null) { statsHolder.setAnalysisStats(outlierDetectionStats); - indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId); + statsPersister.persistWithRetry(outlierDetectionStats, outlierDetectionStats::documentId); } ClassificationStats classificationStats = result.getClassificationStats(); if (classificationStats != null) { statsHolder.setAnalysisStats(classificationStats); - indexStatsResult(classificationStats, classificationStats::documentId); + statsPersister.persistWithRetry(classificationStats, classificationStats::documentId); } RegressionStats regressionStats = 
result.getRegressionStats(); if (regressionStats != null) { statsHolder.setAnalysisStats(regressionStats); - indexStatsResult(regressionStats, regressionStats::documentId); + statsPersister.persistWithRetry(regressionStats, regressionStats::documentId); } } @@ -275,23 +268,4 @@ private void setAndReportFailure(Exception e) { failure = "error processing results; " + e.getMessage(); auditor.error(analytics.getId(), "Error processing results; " + e.getMessage()); } - - private void indexStatsResult(ToXContentObject result, Function docIdSupplier) { - try { - resultsPersisterService.indexWithRetry(analytics.getId(), - MlStatsIndex.writeAlias(), - result, - new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), - WriteRequest.RefreshPolicy.IMMEDIATE, - docIdSupplier.apply(analytics.getId()), - () -> isCancelled == false, - errorMsg -> auditor.error(analytics.getId(), - "failed to persist result with id [" + docIdSupplier.apply(analytics.getId()) + "]; " + errorMsg) - ); - } catch (IOException ioe) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", analytics.getId()), ioe); - } catch (Exception e) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", analytics.getId()), e); - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java index 5d12a2a81a607..fce602b28e2c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java @@ -10,5 +10,5 @@ */ public interface CrossValidationSplitter { - void process(String[] row); + void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java index 47c052dd0bf84..986633aaa3705 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java @@ -31,6 +31,6 @@ public CrossValidationSplitter create(DataFrameAnalysis analysis) { return new RandomCrossValidationSplitter( fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } - return row -> {}; + return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java index 0afc59628e7de..e4e343083ee25 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java @@ -40,22 +40,25 @@ private static int findDependentVariableIndex(List fieldNames, String de } @Override - public void process(String[] row) { - if (canBeUsedForTraining(row)) { - if (isFirstRow) { - // Let's make sure we have at least one training row - isFirstRow = false; - } else if (isRandomlyExcludedFromTraining()) { - row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; - } + public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { + if (canBeUsedForTraining(row) && isPickedForTraining()) { + incrementTrainingDocs.run(); + } else { + row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; + incrementTestDocs.run(); } } private boolean canBeUsedForTraining(String[] row) { - return row[dependentVariableIndex].length() > 0; + return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; } - private boolean isRandomlyExcludedFromTraining() { - return random.nextDouble() * 100 > trainingPercent; + private boolean isPickedForTraining() { + if (isFirstRow) { + // Let's make sure we have at least one training row + isFirstRow = false; + return true; + } + return random.nextDouble() * 100 <= trainingPercent; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java new file mode 100644 index 0000000000000..bed9f52b448cf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; + +public class DataCountsTracker { + + private volatile long trainingDocsCount; + private volatile long testDocsCount; + private volatile long skippedDocsCount; + + public void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + public void incrementTestDocsCount() { + testDocsCount++; + } + + public void incrementSkippedDocsCount() { + skippedDocsCount++; + } + + public DataCounts report(String jobId) { + return new DataCounts( + jobId, + trainingDocsCount, + testDocsCount, + skippedDocsCount + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index ff6b9ec7bcfef..d01eee6a3a3d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -19,11 +19,13 @@ public class StatsHolder { private final ProgressTracker progressTracker; private final AtomicReference memoryUsageHolder; private final AtomicReference analysisStatsHolder; + private final DataCountsTracker dataCountsTracker; public StatsHolder() { progressTracker = new ProgressTracker(); memoryUsageHolder = new AtomicReference<>(); analysisStatsHolder = new AtomicReference<>(); + dataCountsTracker = new DataCountsTracker(); } public ProgressTracker getProgressTracker() { @@ -45,4 +47,8 @@ public void setAnalysisStats(AnalysisStats analysisStats) { public AnalysisStats getAnalysisStats() { return analysisStatsHolder.get(); } + + public DataCountsTracker getDataCountsTracker() { + return dataCountsTracker; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java new file mode 100644 index 0000000000000..eeb8924928ce4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; + +import java.io.IOException; +import java.util.Collections; +import java.util.Objects; +import java.util.function.Function; + +public class StatsPersister { + + private static final Logger LOGGER = LogManager.getLogger(StatsPersister.class); + + private final String jobId; + private final ResultsPersisterService resultsPersisterService; + private final DataFrameAnalyticsAuditor auditor; + private volatile boolean isCancelled; + + public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) { + this.jobId = Objects.requireNonNull(jobId); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.auditor = Objects.requireNonNull(auditor); + } + + public void persistWithRetry(ToXContentObject result, Function docIdSupplier) { + if (isCancelled) { + return; + } + + try { + resultsPersisterService.indexWithRetry(jobId, + MlStatsIndex.writeAlias(), + result, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), + WriteRequest.RefreshPolicy.IMMEDIATE, + docIdSupplier.apply(jobId), + () -> isCancelled == false, + errorMsg -> auditor.error(jobId, + "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg) + ); + } catch (IOException ioe) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", jobId), ioe); + } catch (Exception e) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e); + } + } + + public void cancel() { + isCancelled = true; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index b75392e03c2bb..01661f6ec82a6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -324,9 +324,46 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOExceptio assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}")); } - public void testMissingValues_GivenShouldNotInclude() throws IOException { + public void testCollectDataSummary_GivenAnalysisSupportsMissingFields() { + TestExtractor dataExtractor = createExtractor(true, true); + + // First and only batch + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); + + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); + + 
assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + } + + public void testCollectDataSummary_GivenAnalysisDoesNotSupportMissingFields() { TestExtractor dataExtractor = createExtractor(true, false); + // First and only batch + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); + + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString( + "\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"bool\":{\"filter\":" + + "[{\"exists\":{\"field\":\"field_1\",\"boost\":1.0}},{\"exists\":{\"field\":\"field_2\",\"boost\":1.0}}]," + + "\"boost\":1.0}}],\"boost\":1.0}")); + } + + public void testMissingValues_GivenSupported() throws IOException { + TestExtractor dataExtractor = createExtractor(true, true); + // First and only batch SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); dataExtractor.setNextResponse(response1); @@ -343,11 +380,12 @@ public void testMissingValues_GivenShouldNotInclude() throws IOException { assertThat(rows.get().size(), equalTo(3)); assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(rows.get().get(1).getValues(), is(nullValue())); + assertThat(rows.get().get(1).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); + assertThat(rows.get().get(1).getValues()[1], equalTo("22")); assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(true)); + assertThat(rows.get().get(1).shouldSkip(), is(false)); assertThat(rows.get().get(2).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -358,8 +396,8 @@ public void testMissingValues_GivenShouldNotInclude() throws IOException { assertThat(dataExtractor.hasNext(), is(false)); } - public void testMissingValues_GivenShouldInclude() throws IOException { - TestExtractor dataExtractor = createExtractor(true, true); + public void testMissingValues_GivenNotSupported() throws IOException { + TestExtractor dataExtractor = createExtractor(true, false); // First and only batch SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); @@ -377,12 +415,11 @@ public void testMissingValues_GivenShouldInclude() throws IOException { assertThat(rows.get().size(), equalTo(3)); assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(rows.get().get(1).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); - assertThat(rows.get().get(1).getValues()[1], equalTo("22")); + assertThat(rows.get().get(1).getValues(), is(nullValue())); assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(false)); + 
assertThat(rows.get().get(1).shouldSkip(), is(true)); assertThat(rows.get().get(2).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -424,9 +461,9 @@ public void testGetCategoricalFields() { containsInAnyOrder("field_keyword", "field_text", "field_boolean")); } - private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) { + private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( - JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues); + JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues); return new TestExtractor(client, context); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index da7bbff71d5a7..ee2b399ef2d48 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -24,13 +24,13 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; @@ -66,7 +66,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private StatsHolder statsHolder = new StatsHolder(); private TrainedModelProvider trainedModelProvider; private DataFrameAnalyticsAuditor auditor; - private ResultsPersisterService resultsPersisterService; + private StatsPersister statsPersister; private DataFrameAnalyticsConfig analyticsConfig; @Before @@ -76,7 +76,7 @@ public void setUpMocks() { dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class); trainedModelProvider = mock(TrainedModelProvider.class); auditor = mock(DataFrameAnalyticsAuditor.class); - resultsPersisterService = mock(ResultsPersisterService.class); + statsPersister = mock(StatsPersister.class); analyticsConfig = new DataFrameAnalyticsConfig.Builder() .setId(JOB_ID) .setDescription(JOB_DESCRIPTION) @@ -251,7 +251,7 @@ private AnalyticsResultProcessor createResultProcessor(List fiel statsHolder, trainedModelProvider, auditor, - resultsPersisterService, + statsPersister, fieldNames); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java index eea102e673893..0bbc9d75d8be5 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java @@ -26,6 +26,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { private int dependentVariableIndex; private String dependentVariable; private long randomizeSeed; + private long trainingDocsCount; + private long testDocsCount; @Before public void setUpTests() { @@ -40,47 +42,48 @@ public void setUpTests() { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(50.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10); + String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10); row[fieldIndex] = value; } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // As all these rows have no dependent variable value, they're not for training and should be unaffected assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(0L)); + assertThat(testDocsCount, equalTo(100L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 100.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(100.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? 
"" : randomAlphaOfLength(10); - row[fieldIndex] = value; + row[fieldIndex] = randomAlphaOfLength(10); } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // We should pick them all as training percent is 100 assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(100L)); + assertThat(testDocsCount, equalTo(0L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, trainingPercent, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent); int runCount = 20; int rowsCount = 1000; @@ -94,7 +97,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (fieldIndex != dependentVariableIndex) { @@ -126,8 +129,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 1.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(1.0); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -135,16 +137,30 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (i < 9 && fieldIndex == dependentVariableIndex) { - row[fieldIndex] = ""; + row[fieldIndex] = DataFrameDataExtractor.NULL_VALUE; } else { row[fieldIndex] = randomAlphaOfLength(10); } } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(1L)); + assertThat(testDocsCount, equalTo(9L)); + } + + private RandomCrossValidationSplitter createSplitter(double trainingPercent) { + return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed); + } + + private void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + private void incrementTestDocsCount() { + testDocsCount++; } }