diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 70169ce09b278..acdb9cccca1eb 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -21,6 +21,7 @@ import org.elasticsearch.client.ml.NodeAttributes; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -47,6 +48,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws static final ParseField STATE = new ParseField("state"); static final ParseField FAILURE_REASON = new ParseField("failure_reason"); static final ParseField PROGRESS = new ParseField("progress"); + static final ParseField DATA_COUNTS = new ParseField("data_counts"); static final ParseField MEMORY_USAGE = new ParseField("memory_usage"); static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats"); static final ParseField NODE = new ParseField("node"); @@ -60,10 +62,11 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws (DataFrameAnalyticsState) args[1], (String) args[2], (List) args[3], - (MemoryUsage) args[4], - (AnalysisStats) args[5], - (NodeAttributes) args[6], - (String) args[7])); + (DataCounts) args[4], + (MemoryUsage) args[5], + (AnalysisStats) args[6], + (NodeAttributes) args[7], + (String) args[8])); static { PARSER.declareString(constructorArg(), ID); @@ -75,6 +78,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws }, STATE, ObjectParser.ValueType.STRING); PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); + PARSER.declareObject(optionalConstructorArg(), DataCounts.PARSER, DATA_COUNTS); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE); PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); @@ -93,19 +97,21 @@ private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IO private final DataFrameAnalyticsState state; private final String failureReason; private final List progress; + private final DataCounts dataCounts; private final MemoryUsage memoryUsage; private final AnalysisStats analysisStats; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, - @Nullable List progress, @Nullable MemoryUsage memoryUsage, - @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, + @Nullable List progress, @Nullable DataCounts dataCounts, + @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; this.progress = progress; + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -128,6 +134,11 @@ public List getProgress() { return progress; } + @Nullable + public DataCounts getDataCounts() { + return dataCounts; + } + @Nullable public MemoryUsage getMemoryUsage() { return memoryUsage; @@ -156,6 +167,7 @@ public boolean equals(Object o) { && Objects.equals(state, other.state) && Objects.equals(failureReason, other.failureReason) && Objects.equals(progress, other.progress) + && Objects.equals(dataCounts, other.dataCounts) && Objects.equals(memoryUsage, other.memoryUsage) && Objects.equals(analysisStats, other.analysisStats) && Objects.equals(node, other.node) @@ -164,7 +176,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, assignmentExplanation); } @Override @@ -174,6 +186,7 @@ public String toString() { .add("state", state) .add("failureReason", failureReason) .add("progress", progress) + .add("dataCounts", dataCounts) .add("memoryUsage", memoryUsage) .add("analysisStats", analysisStats) .add("node", node) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java new file mode 100644 index 0000000000000..b7a90b1f0b5c6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCounts.java @@ -0,0 +1,119 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class DataCounts implements ToXContentObject { + + public static final String TYPE_VALUE = "analytics_data_counts"; + + public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count"); + public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count"); + public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE_VALUE, true, + a -> { + Long trainingDocsCount = (Long) a[0]; + Long testDocsCount = (Long) a[1]; + Long skippedDocsCount = (Long) a[2]; + return new DataCounts( + getOrDefault(trainingDocsCount, 0L), + getOrDefault(testDocsCount, 0L), + getOrDefault(skippedDocsCount, 0L) + ); + }); + + static { + PARSER.declareLong(optionalConstructorArg(), TRAINING_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), TEST_DOCS_COUNT); + PARSER.declareLong(optionalConstructorArg(), SKIPPED_DOCS_COUNT); + } + + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; + return trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(trainingDocsCount, testDocsCount, skippedDocsCount); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount) + .add(TEST_DOCS_COUNT.getPreferredName(), testDocsCount) + .add(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount) + .toString(); + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } + + private static T getOrDefault(@Nullable T value, T defaultValue) { + return value != null ? value : defaultValue; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index 25345181982af..d251f568dfa79 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests; import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests; @@ -68,6 +69,7 @@ public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : createRandomProgress(), + randomBoolean() ? null : DataCountsTests.createRandom(), randomBoolean() ? null : MemoryUsageTests.createRandom(), analysisStats, randomBoolean() ? null : NodeAttributesTests.createRandom(), @@ -93,6 +95,9 @@ public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder bui if (stats.getProgress() != null) { builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress()); } + if (stats.getDataCounts() != null) { + builder.field(DataFrameAnalyticsStats.DATA_COUNTS.getPreferredName(), stats.getDataCounts()); + } if (stats.getMemoryUsage() != null) { builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 0000000000000..5e877e2d40f7b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe.stats.common; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class DataCountsTests extends AbstractXContentTestCase { + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return DataCounts.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index 209058e00468b..3f65a917a7427 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -165,6 +166,9 @@ public static class Stats implements ToXContentObject, Writeable { */ private final List progress; + @Nullable + private final DataCounts dataCounts; + @Nullable private final MemoryUsage memoryUsage; @@ -177,12 +181,13 @@ public static class Stats implements ToXContentObject, Writeable { private final String assignmentExplanation; public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List progress, - @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node, - @Nullable String assignmentExplanation) { + @Nullable DataCounts dataCounts, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, + @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; this.progress = Objects.requireNonNull(progress); + this.dataCounts = dataCounts; this.memoryUsage = memoryUsage; this.analysisStats = analysisStats; this.node = node; @@ -198,6 +203,11 @@ public Stats(StreamInput in) throws IOException { } else { progress = in.readList(PhaseProgress::new); } + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + dataCounts = in.readOptionalWriteable(DataCounts::new); + } else { + dataCounts = null; + } if (in.getVersion().onOrAfter(Version.V_7_7_0)) { memoryUsage = in.readOptionalWriteable(MemoryUsage::new); } else { @@ -261,6 +271,11 @@ public List getProgress() { return progress; } + @Nullable + public DataCounts getDataCounts() { + return dataCounts; + } + @Nullable public MemoryUsage getMemoryUsage() { return memoryUsage; @@ -293,6 +308,9 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc if (progress != null) { builder.field("progress", progress); } + if (dataCounts != null) { + builder.field("data_counts", dataCounts); + } if (memoryUsage != null) { builder.field("memory_usage", memoryUsage); } @@ -331,6 +349,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeList(progress); } + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeOptionalWriteable(dataCounts); + } if (out.getVersion().onOrAfter(Version.V_7_7_0)) { out.writeOptionalWriteable(memoryUsage); } @@ -369,7 +390,8 @@ private void writeProgressToLegacy(StreamOutput out) throws IOException { @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, + assignmentExplanation); } @Override @@ -385,6 +407,7 @@ public boolean equals(Object obj) { && Objects.equals(this.state, other.state) && Objects.equals(this.failureReason, other.failureReason) && Objects.equals(this.progress, other.progress) + && Objects.equals(this.dataCounts, other.dataCounts) && Objects.equals(this.memoryUsage, other.memoryUsage) && Objects.equals(this.analysisStats, other.analysisStats) && Objects.equals(this.node, other.node) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java new file mode 100644 index 0000000000000..f77cc781c746a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCounts.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Objects; + +public class DataCounts implements ToXContentObject, Writeable { + + public static final String TYPE_VALUE = "analytics_data_counts"; + + public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count"); + public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count"); + public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count"); + + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(TYPE_VALUE, ignoreUnknownFields, + a -> new DataCounts((String) a[0], (long) a[1], (long) a[2], (long) a[3])); + + parser.declareString((bucket, s) -> {}, Fields.TYPE); + parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID); + parser.declareLong(ConstructingObjectParser.constructorArg(), TRAINING_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), TEST_DOCS_COUNT); + parser.declareLong(ConstructingObjectParser.constructorArg(), SKIPPED_DOCS_COUNT); + return parser; + } + + private final String jobId; + private final long trainingDocsCount; + private final long testDocsCount; + private final long skippedDocsCount; + + public DataCounts(String jobId, long trainingDocsCount, long testDocsCount, long skippedDocsCount) { + this.jobId = Objects.requireNonNull(jobId); + this.trainingDocsCount = trainingDocsCount; + this.testDocsCount = testDocsCount; + this.skippedDocsCount = skippedDocsCount; + } + + public DataCounts(StreamInput in) throws IOException { + this.jobId = in.readString(); + this.trainingDocsCount = in.readVLong(); + this.testDocsCount = in.readVLong(); + this.skippedDocsCount = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(jobId); + out.writeVLong(trainingDocsCount); + out.writeVLong(testDocsCount); + out.writeVLong(skippedDocsCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(Fields.TYPE.getPreferredName(), TYPE_VALUE); + builder.field(Fields.JOB_ID.getPreferredName(), jobId); + } + builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount); + builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount); + builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataCounts that = (DataCounts) o; + return Objects.equals(jobId, that.jobId) + && trainingDocsCount == that.trainingDocsCount + && testDocsCount == that.testDocsCount + && skippedDocsCount == that.skippedDocsCount; + } + + @Override + public int hashCode() { + return Objects.hash(jobId, trainingDocsCount, testDocsCount, skippedDocsCount); + } + + public static String documentId(String jobId) { + return TYPE_VALUE + "_" + jobId; + } + + public String getJobId() { + return jobId; + } + + public long getTrainingDocsCount() { + return trainingDocsCount; + } + + public long getTestDocsCount() { + return testDocsCount; + } + + public long getSkippedDocsCount() { + return skippedDocsCount; + } +} diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json index 5b742c9f6d91e..7a846f5743b91 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/stats_index_mappings.json @@ -85,6 +85,9 @@ "peak_usage_bytes" : { "type" : "long" }, + "skipped_docs_count": { + "type": "long" + }, "timestamp" : { "type" : "date" }, @@ -98,6 +101,12 @@ } } }, + "test_docs_count": { + "type": "long" + }, + "training_docs_count": { + "type": "long" + }, "type" : { "type" : "keyword" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 5cb2b3fef5450..3f957f95a902c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCountsTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsageTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; @@ -42,6 +44,7 @@ public static Response randomResponse(int listSize) { List progress = new ArrayList<>(progressSize); IntStream.of(progressSize).forEach(progressIndex -> progress.add( new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); + DataCounts dataCounts = randomBoolean() ? null : DataCountsTests.createRandom(); MemoryUsage memoryUsage = randomBoolean() ? null : MemoryUsageTests.createRandom(); AnalysisStats analysisStats = randomBoolean() ? null : randomFrom( @@ -50,7 +53,7 @@ public static Response randomResponse(int listSize) { RegressionStatsTests.createRandom() ); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, memoryUsage, analysisStats, null, + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, dataCounts, memoryUsage, analysisStats, null, randomAlphaOfLength(20)); analytics.add(stats); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java new file mode 100644 index 0000000000000..84033d49de655 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/DataCountsTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.stats.common; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; + +public class DataCountsTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + @Before + public void chooseLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected DataCounts mutateInstanceForVersion(DataCounts instance, Version version) { + return instance; + } + + @Override + protected DataCounts doParseInstance(XContentParser parser) throws IOException { + return lenient ? DataCounts.LENIENT_PARSER.apply(parser, null) : DataCounts.STRICT_PARSER.apply(parser, null); + } + + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataCounts::new; + } + + @Override + protected DataCounts createTestInstance() { + return createRandom(); + } + + public static DataCounts createRandom() { + return new DataCounts( + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 9c4511527eeba..a96f6e8b19b9e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -48,6 +49,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.startsWith; @@ -157,6 +159,12 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); @@ -224,6 +232,14 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index e32d923852038..ced8c987fd685 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; @@ -79,6 +80,12 @@ public void testMissingFields() throws Exception { startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(2L)); + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); for (SearchHit hit : sourceData.getHits()) { GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index fc20bbb5afd7b..07752468a743f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -31,6 +32,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -141,6 +143,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(350L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); @@ -197,6 +206,14 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId); + assertThat(stats.getDataCounts().getJobId(), equalTo(jobId)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getTestDocsCount(), greaterThan(0L)); + assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); + assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 4ffe948c9f5b7..3f7a5a7156637 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -85,6 +85,11 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { startAnalytics(id); waitUntilAnalyticsIsStopped(id); + GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id); + assertThat(stats.getDataCounts().getJobId(), equalTo(id)); + assertThat(stats.getDataCounts().getTrainingDocsCount(), equalTo(5L)); + assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L)); + assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L)); SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); double scoreOfOutlier = 0.0; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 9ee28ba29cf3a..cb8d7bcb2be4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; @@ -108,6 +109,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D Stats stats = buildStats( task.getParams().getId(), statsHolder.getProgressTracker().report(), + statsHolder.getDataCountsTracker().report(task.getParams().getId()), statsHolder.getMemoryUsage(), statsHolder.getAnalysisStats() ); @@ -200,6 +202,7 @@ private void searchStats(String configId, ActionListener listener) { MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); multiSearchRequest.add(buildStoredProgressSearch(configId)); + multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE)); multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE)); @@ -224,6 +227,7 @@ private void searchStats(String configId, ActionListener listener) { } listener.onResponse(buildStats(configId, retrievedStatsHolder.progress.get(), + retrievedStatsHolder.dataCounts, retrievedStatsHolder.memoryUsage, retrievedStatsHolder.analysisStats )); @@ -258,6 +262,8 @@ private static void parseHit(SearchHit hit, String configId, RetrievedStatsHolde String hitId = hit.getId(); if (StoredProgress.documentId(configId).equals(hitId)) { retrievedStatsHolder.progress = MlParserUtils.parse(hit, StoredProgress.PARSER); + } else if (DataCounts.documentId(configId).equals(hitId)) { + retrievedStatsHolder.dataCounts = MlParserUtils.parse(hit, DataCounts.LENIENT_PARSER); } else if (hitId.startsWith(MemoryUsage.documentIdPrefix(configId))) { retrievedStatsHolder.memoryUsage = MlParserUtils.parse(hit, MemoryUsage.LENIENT_PARSER); } else if (hitId.startsWith(OutlierDetectionStats.documentIdPrefix(configId))) { @@ -273,6 +279,7 @@ private static void parseHit(SearchHit hit, String configId, RetrievedStatsHolde private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List progress, + DataCounts dataCounts, MemoryUsage memoryUsage, AnalysisStats analysisStats) { ClusterState clusterState = clusterService.state(); @@ -295,6 +302,7 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre analyticsState, failureReason, progress, + dataCounts, memoryUsage, analysisStats, node, @@ -305,6 +313,7 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre private static class RetrievedStatsHolder { private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report()); + private volatile DataCounts dataCounts; private volatile MemoryUsage memoryUsage; private volatile AnalysisStats analysisStats; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index ba1bdf9a6a5d0..aad06b71c075a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -19,6 +19,9 @@ import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; @@ -187,7 +190,7 @@ private Row createRow(SearchHit hit) { if (values.length == 1 && (values[0] instanceof Number || values[0] instanceof String)) { extractedValues[i] = Objects.toString(values[0]); } else { - if (values.length == 0 && context.includeRowsWithMissingValues) { + if (values.length == 0 && context.supportsRowsWithMissingValues) { // if values is empty then it means it's a missing value extractedValues[i] = NULL_VALUE; } else { @@ -263,13 +266,29 @@ public void collectDataSummaryAsync(ActionListener dataSummaryActio } private SearchRequestBuilder buildDataSummarySearchRequestBuilder() { + + QueryBuilder summaryQuery = context.query; + if (context.supportsRowsWithMissingValues == false) { + summaryQuery = QueryBuilders.boolQuery() + .filter(summaryQuery) + .filter(allExtractedFieldsExistQuery()); + } + return new SearchRequestBuilder(client, SearchAction.INSTANCE) .setIndices(context.indices) .setSize(0) - .setQuery(context.query) + .setQuery(summaryQuery) .setTrackTotalHits(true); } + private QueryBuilder allExtractedFieldsExistQuery() { + BoolQueryBuilder query = QueryBuilders.boolQuery(); + for (ExtractedField field : context.extractedFields.getAllFields()) { + query.filter(QueryBuilders.existsQuery(field.getName())); + } + return query; + } + public Set getCategoricalFields(DataFrameAnalysis analysis) { return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java index 0cf391bc33b2e..64ad4bed452e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java @@ -21,10 +21,10 @@ public class DataFrameDataExtractorContext { final int scrollSize; final Map headers; final boolean includeSource; - final boolean includeRowsWithMissingValues; + final boolean supportsRowsWithMissingValues; DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, - Map headers, boolean includeSource, boolean includeRowsWithMissingValues) { + Map headers, boolean includeSource, boolean supportsRowsWithMissingValues) { this.jobId = Objects.requireNonNull(jobId); this.extractedFields = Objects.requireNonNull(extractedFields); this.indices = indices.toArray(new String[indices.size()]); @@ -32,6 +32,6 @@ public class DataFrameDataExtractorContext { this.scrollSize = scrollSize; this.headers = headers; this.includeSource = includeSource; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 3243d92bf77b6..a699e16a7d602 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -7,11 +7,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.Arrays; @@ -28,18 +26,18 @@ public class DataFrameDataExtractorFactory { private final QueryBuilder sourceQuery; private final ExtractedFields extractedFields; private final Map headers; - private final boolean includeRowsWithMissingValues; + private final boolean supportsRowsWithMissingValues; private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, ExtractedFields extractedFields, Map headers, - boolean includeRowsWithMissingValues) { + boolean supportsRowsWithMissingValues) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); this.sourceQuery = Objects.requireNonNull(sourceQuery); this.extractedFields = Objects.requireNonNull(extractedFields); this.headers = headers; - this.includeRowsWithMissingValues = includeRowsWithMissingValues; + this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; } public DataFrameDataExtractor newExtractor(boolean includeSource) { @@ -47,11 +45,11 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { analyticsId, extractedFields, indices, - createQuery(), + QueryBuilders.boolQuery().filter(sourceQuery), 1000, headers, includeSource, - includeRowsWithMissingValues + supportsRowsWithMissingValues ); return new DataFrameDataExtractor(client, context); } @@ -60,23 +58,6 @@ public ExtractedFields getExtractedFields() { return extractedFields; } - private QueryBuilder createQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - query.filter(sourceQuery); - if (includeRowsWithMissingValues == false) { - query.filter(allExtractedFieldsExistQuery()); - } - return query; - } - - private QueryBuilder allExtractedFieldsExistQuery() { - BoolQueryBuilder query = QueryBuilders.boolQuery(); - for (ExtractedField field : extractedFields.getAllFields()) { - query.filter(QueryBuilders.existsQuery(field.getName())); - } - return query; - } - /** * Create a new extractor factory * @@ -109,6 +90,7 @@ public static void createForDestinationIndex(Client client, extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap( extractedFieldsDetector -> { ExtractedFields extractedFields = extractedFieldsDetector.detect().v1(); + DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(), Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index c7baad202d2f6..c1f09ff3fb34f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; @@ -22,8 +23,10 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,7 +37,9 @@ import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter; import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -156,7 +161,10 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); try { writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker()); + writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(), + task.getStatsHolder().getDataCountsTracker()); + processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()), + DataCounts::documentId); process.writeEndOfDataMessage(); process.flushStream(); @@ -205,8 +213,8 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont } } - private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, - DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException { + private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, DataFrameAnalysis analysis, + ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException { CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames()) .create(analysis); @@ -223,11 +231,14 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces Optional> rows = dataExtractor.next(); if (rows.isPresent()) { for (DataFrameDataExtractor.Row row : rows.get()) { - if (row.shouldSkip() == false) { + if (row.shouldSkip()) { + dataCountsTracker.incrementSkippedDocsCount(); + } else { String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); record[record.length - 2] = String.valueOf(row.getChecksum()); - crossValidationSplitter.process(record); + crossValidationSplitter.process(record, dataCountsTracker::incrementTrainingDocsCount, + dataCountsTracker::incrementTestDocsCount); process.writeRecord(record); } } @@ -253,6 +264,10 @@ private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsPr process.writeRecord(headerRecord); } + private void indexDataCounts(DataCounts dataCounts) { + IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias()); + } + private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state, AnalyticsProcess process) { if (config.getAnalysis().persistsState() == false) { @@ -353,9 +368,11 @@ class ProcessContext { private final SetOnce dataExtractor = new SetOnce<>(); private final SetOnce resultProcessor = new SetOnce<>(); private final SetOnce failureReason = new SetOnce<>(); + private final StatsPersister statsPersister; ProcessContext(DataFrameAnalyticsConfig config) { this.config = Objects.requireNonNull(config); + this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor); } String getFailureReason() { @@ -378,6 +395,7 @@ synchronized void stop() { if (resultProcessor.get() != null) { resultProcessor.get().cancel(); } + statsPersister.cancel(); if (process.get() != null) { try { process.get().kill(); @@ -434,7 +452,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); return new AnalyticsResultProcessor( - config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, resultsPersisterService, + config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister, dataExtractor.get().getAllExtractedFields()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index eaa202df5fc4a..d27fa31b41193 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -11,14 +11,10 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.License; -import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; @@ -31,18 +27,16 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; -import java.io.IOException; import java.time.Instant; import java.util.Collections; import java.util.Iterator; @@ -51,7 +45,6 @@ import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.function.Function; import java.util.stream.Collectors; import static java.util.stream.Collectors.toList; @@ -77,7 +70,7 @@ public class AnalyticsResultProcessor { private final StatsHolder statsHolder; private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; - private final ResultsPersisterService resultsPersisterService; + private final StatsPersister statsPersister; private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; @@ -85,14 +78,13 @@ public class AnalyticsResultProcessor { public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, StatsHolder statsHolder, TrainedModelProvider trainedModelProvider, - DataFrameAnalyticsAuditor auditor, ResultsPersisterService resultsPersisterService, - List fieldNames) { + DataFrameAnalyticsAuditor auditor, StatsPersister statsPersister, List fieldNames) { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.statsHolder = Objects.requireNonNull(statsHolder); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); - this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.statsPersister = Objects.requireNonNull(statsPersister); this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); } @@ -112,6 +104,7 @@ public void awaitForCompletion() { public void cancel() { dataFrameRowsJoiner.cancel(); + statsPersister.cancel(); isCancelled = true; } @@ -177,22 +170,22 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { statsHolder.setMemoryUsage(memoryUsage); - indexStatsResult(memoryUsage, memoryUsage::documentId); + statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId); } OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats(); if (outlierDetectionStats != null) { statsHolder.setAnalysisStats(outlierDetectionStats); - indexStatsResult(outlierDetectionStats, outlierDetectionStats::documentId); + statsPersister.persistWithRetry(outlierDetectionStats, outlierDetectionStats::documentId); } ClassificationStats classificationStats = result.getClassificationStats(); if (classificationStats != null) { statsHolder.setAnalysisStats(classificationStats); - indexStatsResult(classificationStats, classificationStats::documentId); + statsPersister.persistWithRetry(classificationStats, classificationStats::documentId); } RegressionStats regressionStats = result.getRegressionStats(); if (regressionStats != null) { statsHolder.setAnalysisStats(regressionStats); - indexStatsResult(regressionStats, regressionStats::documentId); + statsPersister.persistWithRetry(regressionStats, regressionStats::documentId); } } @@ -275,23 +268,4 @@ private void setAndReportFailure(Exception e) { failure = "error processing results; " + e.getMessage(); auditor.error(analytics.getId(), "Error processing results; " + e.getMessage()); } - - private void indexStatsResult(ToXContentObject result, Function docIdSupplier) { - try { - resultsPersisterService.indexWithRetry(analytics.getId(), - MlStatsIndex.writeAlias(), - result, - new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), - WriteRequest.RefreshPolicy.IMMEDIATE, - docIdSupplier.apply(analytics.getId()), - () -> isCancelled == false, - errorMsg -> auditor.error(analytics.getId(), - "failed to persist result with id [" + docIdSupplier.apply(analytics.getId()) + "]; " + errorMsg) - ); - } catch (IOException ioe) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", analytics.getId()), ioe); - } catch (Exception e) { - LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", analytics.getId()), e); - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java index 5d12a2a81a607..fce602b28e2c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java @@ -10,5 +10,5 @@ */ public interface CrossValidationSplitter { - void process(String[] row); + void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java index 47c052dd0bf84..986633aaa3705 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java @@ -31,6 +31,6 @@ public CrossValidationSplitter create(DataFrameAnalysis analysis) { return new RandomCrossValidationSplitter( fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } - return row -> {}; + return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java index 0afc59628e7de..e4e343083ee25 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java @@ -40,22 +40,25 @@ private static int findDependentVariableIndex(List fieldNames, String de } @Override - public void process(String[] row) { - if (canBeUsedForTraining(row)) { - if (isFirstRow) { - // Let's make sure we have at least one training row - isFirstRow = false; - } else if (isRandomlyExcludedFromTraining()) { - row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; - } + public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { + if (canBeUsedForTraining(row) && isPickedForTraining()) { + incrementTrainingDocs.run(); + } else { + row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; + incrementTestDocs.run(); } } private boolean canBeUsedForTraining(String[] row) { - return row[dependentVariableIndex].length() > 0; + return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; } - private boolean isRandomlyExcludedFromTraining() { - return random.nextDouble() * 100 > trainingPercent; + private boolean isPickedForTraining() { + if (isFirstRow) { + // Let's make sure we have at least one training row + isFirstRow = false; + return true; + } + return random.nextDouble() * 100 <= trainingPercent; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java new file mode 100644 index 0000000000000..bed9f52b448cf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; + +public class DataCountsTracker { + + private volatile long trainingDocsCount; + private volatile long testDocsCount; + private volatile long skippedDocsCount; + + public void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + public void incrementTestDocsCount() { + testDocsCount++; + } + + public void incrementSkippedDocsCount() { + skippedDocsCount++; + } + + public DataCounts report(String jobId) { + return new DataCounts( + jobId, + trainingDocsCount, + testDocsCount, + skippedDocsCount + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index ff6b9ec7bcfef..d01eee6a3a3d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -19,11 +19,13 @@ public class StatsHolder { private final ProgressTracker progressTracker; private final AtomicReference memoryUsageHolder; private final AtomicReference analysisStatsHolder; + private final DataCountsTracker dataCountsTracker; public StatsHolder() { progressTracker = new ProgressTracker(); memoryUsageHolder = new AtomicReference<>(); analysisStatsHolder = new AtomicReference<>(); + dataCountsTracker = new DataCountsTracker(); } public ProgressTracker getProgressTracker() { @@ -45,4 +47,8 @@ public void setAnalysisStats(AnalysisStats analysisStats) { public AnalysisStats getAnalysisStats() { return analysisStatsHolder.get(); } + + public DataCountsTracker getDataCountsTracker() { + return dataCountsTracker; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java new file mode 100644 index 0000000000000..eeb8924928ce4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; + +import java.io.IOException; +import java.util.Collections; +import java.util.Objects; +import java.util.function.Function; + +public class StatsPersister { + + private static final Logger LOGGER = LogManager.getLogger(StatsPersister.class); + + private final String jobId; + private final ResultsPersisterService resultsPersisterService; + private final DataFrameAnalyticsAuditor auditor; + private volatile boolean isCancelled; + + public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) { + this.jobId = Objects.requireNonNull(jobId); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); + this.auditor = Objects.requireNonNull(auditor); + } + + public void persistWithRetry(ToXContentObject result, Function docIdSupplier) { + if (isCancelled) { + return; + } + + try { + resultsPersisterService.indexWithRetry(jobId, + MlStatsIndex.writeAlias(), + result, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), + WriteRequest.RefreshPolicy.IMMEDIATE, + docIdSupplier.apply(jobId), + () -> isCancelled == false, + errorMsg -> auditor.error(jobId, + "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg) + ); + } catch (IOException ioe) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed serializing stats result", jobId), ioe); + } catch (Exception e) { + LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e); + } + } + + public void cancel() { + isCancelled = true; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index b75392e03c2bb..01661f6ec82a6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -324,9 +324,46 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOExceptio assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}")); } - public void testMissingValues_GivenShouldNotInclude() throws IOException { + public void testCollectDataSummary_GivenAnalysisSupportsMissingFields() { + TestExtractor dataExtractor = createExtractor(true, true); + + // First and only batch + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); + + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + } + + public void testCollectDataSummary_GivenAnalysisDoesNotSupportMissingFields() { TestExtractor dataExtractor = createExtractor(true, false); + // First and only batch + SearchResponse response = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); + dataExtractor.setNextResponse(response); + + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + + assertThat(dataSummary.rows, equalTo(2L)); + assertThat(dataSummary.cols, equalTo(2)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString( + "\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"bool\":{\"filter\":" + + "[{\"exists\":{\"field\":\"field_1\",\"boost\":1.0}},{\"exists\":{\"field\":\"field_2\",\"boost\":1.0}}]," + + "\"boost\":1.0}}],\"boost\":1.0}")); + } + + public void testMissingValues_GivenSupported() throws IOException { + TestExtractor dataExtractor = createExtractor(true, true); + // First and only batch SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); dataExtractor.setNextResponse(response1); @@ -343,11 +380,12 @@ public void testMissingValues_GivenShouldNotInclude() throws IOException { assertThat(rows.get().size(), equalTo(3)); assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(rows.get().get(1).getValues(), is(nullValue())); + assertThat(rows.get().get(1).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); + assertThat(rows.get().get(1).getValues()[1], equalTo("22")); assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(true)); + assertThat(rows.get().get(1).shouldSkip(), is(false)); assertThat(rows.get().get(2).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -358,8 +396,8 @@ public void testMissingValues_GivenShouldNotInclude() throws IOException { assertThat(dataExtractor.hasNext(), is(false)); } - public void testMissingValues_GivenShouldInclude() throws IOException { - TestExtractor dataExtractor = createExtractor(true, true); + public void testMissingValues_GivenNotSupported() throws IOException { + TestExtractor dataExtractor = createExtractor(true, false); // First and only batch SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); @@ -377,12 +415,11 @@ public void testMissingValues_GivenShouldInclude() throws IOException { assertThat(rows.get().size(), equalTo(3)); assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(rows.get().get(1).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); - assertThat(rows.get().get(1).getValues()[1], equalTo("22")); + assertThat(rows.get().get(1).getValues(), is(nullValue())); assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(false)); + assertThat(rows.get().get(1).shouldSkip(), is(true)); assertThat(rows.get().get(2).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -424,9 +461,9 @@ public void testGetCategoricalFields() { containsInAnyOrder("field_keyword", "field_text", "field_boolean")); } - private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) { + private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( - JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues); + JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues); return new TestExtractor(client, context); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index da7bbff71d5a7..ee2b399ef2d48 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -24,13 +24,13 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; @@ -66,7 +66,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private StatsHolder statsHolder = new StatsHolder(); private TrainedModelProvider trainedModelProvider; private DataFrameAnalyticsAuditor auditor; - private ResultsPersisterService resultsPersisterService; + private StatsPersister statsPersister; private DataFrameAnalyticsConfig analyticsConfig; @Before @@ -76,7 +76,7 @@ public void setUpMocks() { dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class); trainedModelProvider = mock(TrainedModelProvider.class); auditor = mock(DataFrameAnalyticsAuditor.class); - resultsPersisterService = mock(ResultsPersisterService.class); + statsPersister = mock(StatsPersister.class); analyticsConfig = new DataFrameAnalyticsConfig.Builder() .setId(JOB_ID) .setDescription(JOB_DESCRIPTION) @@ -251,7 +251,7 @@ private AnalyticsResultProcessor createResultProcessor(List fiel statsHolder, trainedModelProvider, auditor, - resultsPersisterService, + statsPersister, fieldNames); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java index eea102e673893..0bbc9d75d8be5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java @@ -26,6 +26,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { private int dependentVariableIndex; private String dependentVariable; private long randomizeSeed; + private long trainingDocsCount; + private long testDocsCount; @Before public void setUpTests() { @@ -40,47 +42,48 @@ public void setUpTests() { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(50.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10); + String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10); row[fieldIndex] = value; } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // As all these rows have no dependent variable value, they're not for training and should be unaffected assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(0L)); + assertThat(testDocsCount, equalTo(100L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 100.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(100.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { - String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10); - row[fieldIndex] = value; + row[fieldIndex] = randomAlphaOfLength(10); } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); // We should pick them all as training percent is 100 assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(100L)); + assertThat(testDocsCount, equalTo(0L)); } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, trainingPercent, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent); int runCount = 20; int rowsCount = 1000; @@ -94,7 +97,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (fieldIndex != dependentVariableIndex) { @@ -126,8 +129,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( - fields, dependentVariable, 1.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = createSplitter(1.0); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -135,16 +137,30 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (i < 9 && fieldIndex == dependentVariableIndex) { - row[fieldIndex] = ""; + row[fieldIndex] = DataFrameDataExtractor.NULL_VALUE; } else { row[fieldIndex] = randomAlphaOfLength(10); } } String[] processedRow = Arrays.copyOf(row, row.length); - crossValidationSplitter.process(processedRow); + crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount); assertThat(Arrays.equals(processedRow, row), is(true)); } + assertThat(trainingDocsCount, equalTo(1L)); + assertThat(testDocsCount, equalTo(9L)); + } + + private RandomCrossValidationSplitter createSplitter(double trainingPercent) { + return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed); + } + + private void incrementTrainingDocsCount() { + trainingDocsCount++; + } + + private void incrementTestDocsCount() { + testDocsCount++; } }