Skip to content

Commit

Permalink
[ML][Inference] don't return inflated definition when storing trained…
Browse files Browse the repository at this point in the history
… models (#52573) (#52583)

When `PUT` is called to store a trained model, it is useful to return the newly create model config. But, it is NOT useful to return the inflated definition.

These definitions can be large and returning the inflated definition causes undo work on the server and client side.
  • Loading branch information
benwtrent committed Feb 20, 2020
1 parent 1fe4fde commit 03884a7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/java-rest/high-level/ml/put-trained-model.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ include::../execution.asciidoc[]
==== Response

The returned +{response}+ contains the newly created trained model.
The +{response}+ will omit the model definition as a precaution against
streaming large model definitions back to the client.

["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition);
} else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
Expand Down Expand Up @@ -370,6 +370,9 @@ public Builder(TrainedModelConfig config) {
this.tags = config.getTags();
this.metadata = config.getMetadata();
this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
}

public Builder setModelId(String modelId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,21 @@ public void testToXContentWithParams() throws IOException {
"platinum");

BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("\"definition\""));
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));

reference = XContentHelper.toXContent(config,
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
false);
assertThat(reference.utf8ToString(), not(containsString("definition")));
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));

reference = XContentHelper.toXContent(config,
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")),
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
false);
assertThat(reference.utf8ToString(), not(containsString("\"definition\"")));
assertThat(reference.utf8ToString(), containsString("compressed_definition"));
assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
assertThat(reference.utf8ToString(), containsString("\"definition\""));
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
}

public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
Expand All @@ -179,7 +179,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();

objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString());
objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());

try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
XContentParser parser = XContentType.JSON
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException {
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, not(containsString("\"compressed_definition\"")));
assertThat(response, containsString("\"count\":1"));

getModel = client().performRequest(new Request("GET",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ protected void masterOperation(Request request, ClusterState state, ActionListen

ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
bool -> {
TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
},
listener::onFailure
)),
listener::onFailure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
Expand All @@ -19,6 +25,8 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.rest.RestRequest.Method.GET;
Expand All @@ -32,6 +40,8 @@ public RestGetTrainedModelsAction(RestController controller) {
controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this);
}

private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
@Override
public String getName() {
return "ml_get_trained_models_action";
Expand All @@ -53,12 +63,33 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
}
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
request,
new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
}

@Override
protected Set<String> responseParams() {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
}

private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
private final Map<String, String> defaultToXContentParamValues;

private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
super(channel);
this.defaultToXContentParamValues = defaultToXContentParamValues;
}

@Override
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
assert response.isFragment() == false; //would be nice if we could make default methods final
Map<String, String> params = new HashMap<>(channel.request().params());
defaultToXContentParamValues.forEach((k, v) ->
params.computeIfAbsent(k, defaultToXContentParamValues::get)
);
response.toXContent(builder, new ToXContent.MapParams(params));
return new BytesRestResponse(getStatus(response), builder);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,53 @@ setup:
}
}
}
---
"Test put model":
- do:
ml.put_trained_model:
model_id: my-regression-model
body: >
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"ensemble": {
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
},
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
}
]
}
}
}
}
- match: { model_id: my-regression-model }
- match: { estimated_operations: 6 }
- is_false: definition
- is_false: compressed_definition
- is_true: license_level
- is_true: create_time
- is_true: version
- is_true: estimated_heap_memory_usage_bytes

0 comments on commit 03884a7

Please sign in to comment.