Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add a model memory estimation endpoint for anomaly detection #53507

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.FindFileStructureRequest;
Expand Down Expand Up @@ -593,6 +594,17 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven
return new Request(HttpDelete.METHOD_NAME, endpoint);
}

/**
 * Converts an {@link EstimateModelMemoryRequest} into a low level {@code POST} request aimed at
 * {@code /_ml/anomaly_detectors/_estimate_model_memory}, with the request serialized as the entity.
 */
static Request estimateModelMemory(EstimateModelMemoryRequest estimateModelMemoryRequest) throws IOException {
    final String path = new EndpointBuilder()
        .addPathPartAsIs("_ml", "anomaly_detectors", "_estimate_model_memory")
        .build();
    final Request httpRequest = new Request(HttpPost.METHOD_NAME, path);
    // The whole request object is sent as the body of the call.
    httpRequest.setEntity(createEntity(estimateModelMemoryRequest, REQUEST_BODY_CONTENT_TYPE));
    return httpRequest;
}

static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "data_frame", "analytics")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.CloseJobResponse;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
Expand Down Expand Up @@ -1951,6 +1953,48 @@ public Cancellable setUpgradeModeAsync(SetUpgradeModeRequest request, RequestOpt
Collections.emptySet());
}

/**
 * Estimates the model memory that an analysis config is likely to require, based on the field cardinalities supplied in the request
 * <p>
 * For additional info
 * see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
 *
 * @param request The {@link EstimateModelMemoryRequest} holding the analysis config and cardinalities
 * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
 * @return the {@link EstimateModelMemoryResponse} parsed from the response body
 * @throws IOException propagated from the underlying high level client call
 */
public EstimateModelMemoryResponse estimateModelMemory(EstimateModelMemoryRequest request,
                                                       RequestOptions options) throws IOException {
    return restHighLevelClient.performRequestAndParseEntity(
        request,
        MLRequestConverters::estimateModelMemory,
        options,
        EstimateModelMemoryResponse::fromXContent,
        Collections.emptySet()
    );
}

/**
 * Asynchronously estimates the model memory that an analysis config is likely to require, based on the field cardinalities supplied
 * in the request, and notifies the listener upon completion
 * <p>
 * For additional info
 * see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
 *
 * @param request The {@link EstimateModelMemoryRequest} holding the analysis config and cardinalities
 * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
 * @param listener Listener to be notified upon request completion
 * @return cancellable that may be used to cancel the request
 */
public Cancellable estimateModelMemoryAsync(EstimateModelMemoryRequest request,
                                            RequestOptions options,
                                            ActionListener<EstimateModelMemoryResponse> listener) {
    return restHighLevelClient.performRequestAsyncAndParseEntity(
        request,
        MLRequestConverters::estimateModelMemory,
        options,
        EstimateModelMemoryResponse::fromXContent,
        listener,
        Collections.emptySet()
    );
}

/**
* Creates a new Data Frame Analytics config
* <p>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Request to estimate the model memory an analysis config is likely to need given supplied field cardinalities.
 * <p>
 * The analysis config is mandatory. The two cardinality maps are optional: they default to empty maps and
 * are only written to the request body when they are non-empty.
 */
public class EstimateModelMemoryRequest implements Validatable, ToXContentObject {

    public static final String ANALYSIS_CONFIG = "analysis_config";
    public static final String OVERALL_CARDINALITY = "overall_cardinality";
    public static final String MAX_BUCKET_CARDINALITY = "max_bucket_cardinality";

    private final AnalysisConfig analysisConfig;
    private Map<String, Long> overallCardinality = Collections.emptyMap();
    private Map<String, Long> maxBucketCardinality = Collections.emptyMap();

    /**
     * @param analysisConfig the analysis config to estimate memory for; must not be {@code null}
     * @throws NullPointerException if {@code analysisConfig} is {@code null}
     */
    public EstimateModelMemoryRequest(AnalysisConfig analysisConfig) {
        this.analysisConfig = Objects.requireNonNull(analysisConfig, "[" + ANALYSIS_CONFIG + "] must not be null");
    }

    /**
     * No client side validation is performed for this request.
     *
     * @return always {@link Optional#empty()}
     */
    @Override
    public Optional<ValidationException> validate() {
        return Optional.empty();
    }

    public AnalysisConfig getAnalysisConfig() {
        return analysisConfig;
    }

    /**
     * @return an unmodifiable map of the overall cardinalities; empty when none have been set
     */
    public Map<String, Long> getOverallCardinality() {
        return overallCardinality;
    }

    /**
     * Sets the overall cardinalities. The supplied map is copied, so later changes made to it by the
     * caller do not affect this request.
     *
     * @param overallCardinality the cardinalities to use; must not be {@code null}
     * @throws NullPointerException if {@code overallCardinality} is {@code null}
     */
    public void setOverallCardinality(Map<String, Long> overallCardinality) {
        this.overallCardinality = immutableCopy(overallCardinality, OVERALL_CARDINALITY);
    }

    /**
     * @return an unmodifiable map of the max bucket cardinalities; empty when none have been set
     */
    public Map<String, Long> getMaxBucketCardinality() {
        return maxBucketCardinality;
    }

    /**
     * Sets the max bucket cardinalities. The supplied map is copied, so later changes made to it by the
     * caller do not affect this request.
     *
     * @param maxBucketCardinality the cardinalities to use; must not be {@code null}
     * @throws NullPointerException if {@code maxBucketCardinality} is {@code null}
     */
    public void setMaxBucketCardinality(Map<String, Long> maxBucketCardinality) {
        this.maxBucketCardinality = immutableCopy(maxBucketCardinality, MAX_BUCKET_CARDINALITY);
    }

    /**
     * Returns an unmodifiable snapshot of {@code source} that keeps its iteration order.
     * A plain {@link Collections#unmodifiableMap(Map)} wrapper would only be a read-only view, so the
     * request body, {@link #equals(Object)} and {@link #hashCode()} could change behind our back if the
     * caller kept mutating the original map.
     */
    private static Map<String, Long> immutableCopy(Map<String, Long> source, String fieldName) {
        Objects.requireNonNull(source, "[" + fieldName + "] must not be null");
        return Collections.unmodifiableMap(new LinkedHashMap<>(source));
    }

    /**
     * Writes the analysis config, followed by each cardinality map that is non-empty.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ANALYSIS_CONFIG, analysisConfig);
        if (overallCardinality.isEmpty() == false) {
            builder.field(OVERALL_CARDINALITY, overallCardinality);
        }
        if (maxBucketCardinality.isEmpty() == false) {
            builder.field(MAX_BUCKET_CARDINALITY, maxBucketCardinality);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public int hashCode() {
        return Objects.hash(analysisConfig, overallCardinality, maxBucketCardinality);
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }

        if (other == null || getClass() != other.getClass()) {
            return false;
        }

        EstimateModelMemoryRequest that = (EstimateModelMemoryRequest) other;
        return Objects.equals(analysisConfig, that.analysisConfig) &&
            Objects.equals(overallCardinality, that.overallCardinality) &&
            Objects.equals(maxBucketCardinality, that.maxBucketCardinality);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;

import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

/**
 * Response of the estimate model memory call, carrying the estimate as a {@link ByteSizeValue}.
 */
public class EstimateModelMemoryResponse {

    public static final ParseField MODEL_MEMORY_ESTIMATE = new ParseField("model_memory_estimate");

    static final ConstructingObjectParser<EstimateModelMemoryResponse, Void> PARSER = new ConstructingObjectParser<>(
        "estimate_model_memory",
        true,
        fields -> new EstimateModelMemoryResponse((String) fields[0])
    );

    static {
        // The estimate arrives as a string and is the single constructor argument.
        PARSER.declareString(constructorArg(), MODEL_MEMORY_ESTIMATE);
    }

    private final ByteSizeValue modelMemoryEstimate;

    /**
     * @param modelMemoryEstimate the estimate in its string form; it is parsed into a {@link ByteSizeValue}
     */
    public EstimateModelMemoryResponse(String modelMemoryEstimate) {
        final String settingName = MODEL_MEMORY_ESTIMATE.getPreferredName();
        this.modelMemoryEstimate = ByteSizeValue.parseBytesSizeValue(modelMemoryEstimate, settingName);
    }

    /**
     * Builds a response from the supplied parser using {@link #PARSER}.
     */
    public static EstimateModelMemoryResponse fromXContent(final XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @return An estimate of the model memory the supplied analysis config is likely to need given the supplied field cardinalities.
     */
    public ByteSizeValue getModelMemoryEstimate() {
        return modelMemoryEstimate;
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelMemoryEstimate);
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        final EstimateModelMemoryResponse that = (EstimateModelMemoryResponse) o;
        return Objects.equals(modelMemoryEstimate, that.modelMemoryEstimate);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
Expand Down Expand Up @@ -107,6 +108,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
Expand Down Expand Up @@ -695,6 +697,25 @@ public void testDeleteCalendarEvent() {
assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint());
}

// Checks that the converter produces a POST to the estimate endpoint whose body is the serialized request.
public void testEstimateModelMemory() throws Exception {
    final String byField = randomAlphaOfLength(10);
    final String influencerField = randomAlphaOfLength(10);

    final Detector countDetector = Detector.builder().setFunction("count").setByFieldName(byField).build();
    final AnalysisConfig config = AnalysisConfig.builder(Collections.singletonList(countDetector))
        .setInfluencers(Collections.singletonList(influencerField))
        .build();

    final EstimateModelMemoryRequest estimateRequest = new EstimateModelMemoryRequest(config);
    estimateRequest.setOverallCardinality(Collections.singletonMap(byField, randomNonNegativeLong()));
    estimateRequest.setMaxBucketCardinality(Collections.singletonMap(influencerField, randomNonNegativeLong()));

    final Request converted = MLRequestConverters.estimateModelMemory(estimateRequest);
    assertEquals(HttpPost.METHOD_NAME, converted.getMethod());
    assertEquals("/_ml/anomaly_detectors/_estimate_model_memory", converted.getEndpoint());

    // The entity must match what the request serializes to on its own.
    final XContentBuilder expectedBody = estimateRequest.toXContent(JsonXContent.contentBuilder(), ToXContent.EMPTY_PARAMS);
    assertEquals(Strings.toString(expectedBody), requestEntityToString(converted));
}

public void testPutDataFrameAnalytics() throws IOException {
PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig());
Request request = MLRequestConverters.putDataFrameAnalytics(putRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.elasticsearch.client.ml.DeleteJobResponse;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
Expand Down Expand Up @@ -1244,6 +1246,27 @@ public void testDeleteCalendarEvent() throws IOException {
assertThat(remainingIds, not(hasItem(deletedEvent)));
}

// Runs the estimate call through the sync and async client methods and checks the estimate is at least 10000000 bytes.
public void testEstimateModelMemory() throws Exception {
    final MachineLearningClient mlClient = highLevelClient().machineLearning();

    final String byField = randomAlphaOfLength(10);
    final String influencerField = randomAlphaOfLength(10);

    final Detector countDetector = Detector.builder().setFunction("count").setByFieldName(byField).build();
    final AnalysisConfig config = AnalysisConfig.builder(Collections.singletonList(countDetector))
        .setInfluencers(Collections.singletonList(influencerField))
        .build();

    final EstimateModelMemoryRequest estimateRequest = new EstimateModelMemoryRequest(config);
    estimateRequest.setOverallCardinality(Collections.singletonMap(byField, randomNonNegativeLong()));
    estimateRequest.setMaxBucketCardinality(Collections.singletonMap(influencerField, randomNonNegativeLong()));

    final EstimateModelMemoryResponse estimateResponse =
        execute(estimateRequest, mlClient::estimateModelMemory, mlClient::estimateModelMemoryAsync);

    assertThat(estimateResponse.getModelMemoryEstimate().getBytes(), greaterThanOrEqualTo(10_000_000L));
}

public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "test-put-df-analytics-outlier-detection";
Expand Down
Loading