[ML] add new cache_size parameter to trained_model deployments API (#88450)

With elastic/ml-cpp#2305 we now support caching PyTorch inference responses per node per model.

By default, the cache will be the same size as the model's size on disk. This is because our current best estimate for the memory used when deploying a model is 2*model_size + constant_overhead.

This is due to the model having to be loaded into memory twice when serializing to the native process.

But once the model is in memory and accepting requests, its actual memory usage is lower than what we have "reserved" for it on the node.

Consequently, having a cache layer that takes advantage of that unused (but reserved) memory is effectively free. When used in production, especially in search scenarios, caching inference results is critical for decreasing latency.
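
To make the sizing above concrete, here is a minimal, illustrative sketch (not part of this change) of the arithmetic, assuming a hypothetical 1 GiB model and a hypothetical fixed overhead constant:

class CacheSizingSketch {
    public static void main(String[] args) {
        long modelBytes = 1L << 30;                                      // hypothetical 1 GiB model on disk
        long constantOverhead = 240L * (1L << 20);                       // hypothetical constant overhead (240 MiB)
        long reservedForDeployment = 2 * modelBytes + constantOverhead;  // 2*model_size + constant_overhead
        long defaultCacheSize = modelBytes;                              // default cache == on-disk model size
        // Once the model is loaded and serving, the second copy needed while serializing to the
        // native process is no longer in use, so a cache of defaultCacheSize fits inside the
        // memory already reserved for the deployment.
        System.out.println("reserved=" + reservedForDeployment + ", default cache=" + defaultCacheSize);
    }
}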
benwtrent committed Jul 18, 2022
1 parent 5c11a81 commit afa28d4
Showing 28 changed files with 376 additions and 32 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/88450.yaml
@@ -0,0 +1,5 @@
pr: 88450
summary: Add new `cache_size` parameter to `trained_model` deployments API
area: Machine Learning
type: enhancement
issues: []
@@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration.
(integer)
The current number of nodes where the model is allocated.
`cache_size`:::
(<<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
`state`:::
(string)
The detailed allocation state related to the nodes.
@@ -34,7 +34,7 @@ Increasing `threads_per_allocation` means more threads are used when
an inference request is processed on a node. This can improve inference speed
for certain models. It may also improve throughput.

Increasing `number_of_allocations` means more threads are used to
process multiple inference requests in parallel resulting in throughput
improvement. Each model allocation uses a number of threads defined by
`threads_per_allocation`.
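
For illustration only (not part of this change), a hypothetical request that starts a deployment with four allocations of two threads each, using a made-up model ID, could look like:

POST _ml/trained_models/my_model/deployment/_start?number_of_allocations=4&threads_per_allocation=2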
@@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
[[start-trained-model-deployment-query-params]]
== {api-query-parms-title}

`cache_size`::
(Optional, <<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the model.
The default value is the same size as `model_size_bytes`. To disable the cache, `0b` can be provided.
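
As an illustration (not part of this change), a hypothetical request that limits the per-node inference cache to 256mb, again with a made-up model ID, might be:

POST _ml/trained_models/my_model/deployment/_start?cache_size=256mb

Passing `cache_size=0b` instead would disable the cache entirely.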

`number_of_allocations`::
(Optional, integer)
The total number of allocations this model is assigned across {ml} nodes.
@@ -28,6 +28,11 @@
]
},
"params":{
"cache_size": {
"type": "string",
"description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
"required": false
},
"number_of_allocations":{
"type":"int",
"description": "The number of model allocations on each node where the model is deployed.",
@@ -18,6 +18,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -34,8 +35,10 @@

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;

public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
@@ -75,6 +78,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;

public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

@@ -85,6 +89,12 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
PARSER.declareField(
Request::setCacheSize,
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
}

public static Request parseRequest(String modelId, XContentParser parser) {
@@ -102,6 +112,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
private String modelId;
private TimeValue timeout = DEFAULT_TIMEOUT;
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private ByteSizeValue cacheSize;
private int numberOfAllocations = 1;
private int threadsPerAllocation = 1;
private int queueCapacity = 1024;
@@ -120,6 +131,9 @@ public Request(StreamInput in) throws IOException {
numberOfAllocations = in.readVInt();
threadsPerAllocation = in.readVInt();
queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
}
}

public final void setModelId(String modelId) {
@@ -171,6 +185,14 @@ public void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
}

public ByteSizeValue getCacheSize() {
return cacheSize;
}

public void setCacheSize(ByteSizeValue cacheSize) {
this.cacheSize = cacheSize;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@@ -180,6 +202,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(numberOfAllocations);
out.writeVInt(threadsPerAllocation);
out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
}

@Override
Expand All @@ -191,6 +216,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
}
builder.endObject();
return builder;
}
@@ -229,7 +257,7 @@ private static boolean isPowerOf2(int value) {

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
}

@Override
@@ -244,6 +272,7 @@ public boolean equals(Object obj) {
return Objects.equals(modelId, other.modelId)
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& Objects.equals(cacheSize, other.cacheSize)
&& numberOfAllocations == other.numberOfAllocations
&& threadsPerAllocation == other.threadsPerAllocation
&& queueCapacity == other.queueCapacity;
@@ -273,11 +302,21 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
// threads_per_allocation was previously named inference_threads
public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
public static final ParseField CACHE_SIZE = new ParseField("cache_size");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
a -> new TaskParams(
(String) a[0],
(Long) a[1],
(Integer) a[2],
(Integer) a[3],
(int) a[4],
(ByteSizeValue) a[5],
(Integer) a[6],
(Integer) a[7]
)
);

static {
@@ -286,6 +325,12 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
CACHE_SIZE,
ObjectParser.ValueType.VALUE
);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
}
@@ -295,6 +340,7 @@ public static TaskParams fromXContent(XContentParser parser) {
}

private final String modelId;
private final ByteSizeValue cacheSize;
private final long modelBytes;
// How many threads are used by the model during inference. Used to increase inference speed.
private final int threadsPerAllocation;
@@ -308,6 +354,7 @@ private TaskParams(
Integer numberOfAllocations,
Integer threadsPerAllocation,
int queueCapacity,
ByteSizeValue cacheSizeValue,
Integer legacyModelThreads,
Integer legacyInferenceThreads
) {
@@ -316,16 +363,25 @@
modelBytes,
threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
queueCapacity
queueCapacity,
cacheSizeValue
);
}

public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
public TaskParams(
String modelId,
long modelBytes,
int threadsPerAllocation,
int numberOfAllocations,
int queueCapacity,
@Nullable ByteSizeValue cacheSize
) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
this.threadsPerAllocation = threadsPerAllocation;
this.numberOfAllocations = numberOfAllocations;
this.queueCapacity = queueCapacity;
this.cacheSize = cacheSize;
}

public TaskParams(StreamInput in) throws IOException {
@@ -334,13 +390,23 @@ public TaskParams(StreamInput in) throws IOException {
this.threadsPerAllocation = in.readVInt();
this.numberOfAllocations = in.readVInt();
this.queueCapacity = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
} else {
this.cacheSize = null;
}
}

public String getModelId() {
return modelId;
}

public long estimateMemoryUsageBytes() {
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
// we need to take it into account when returning the estimate.
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
}
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
}

@@ -355,6 +421,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(threadsPerAllocation);
out.writeVInt(numberOfAllocations);
out.writeVInt(queueCapacity);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(cacheSize);
}
}

@Override
@@ -365,13 +434,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
if (cacheSize != null) {
builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
}
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
}

@Override
@@ -384,6 +456,7 @@ public boolean equals(Object o) {
&& modelBytes == other.modelBytes
&& threadsPerAllocation == other.threadsPerAllocation
&& numberOfAllocations == other.numberOfAllocations
&& Objects.equals(cacheSize, other.cacheSize)
&& queueCapacity == other.queueCapacity;
}

@@ -408,6 +481,14 @@ public int getQueueCapacity() {
return queueCapacity;
}

public Optional<ByteSizeValue> getCacheSize() {
return Optional.ofNullable(cacheSize);
}

public long getCacheSizeBytes() {
return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
}

@Override
public String toString() {
return Strings.toString(this);
