Skip to content

Commit

Permalink
[8.0][ML] Improve reporting of trained model size stats (#82000) (#82041)
Browse files Browse the repository at this point in the history

This improves reporting of trained model size in the response of the stats API.

In particular, it removes the `model_size_bytes` from the `deployment_stats` section and
replaces it with a top-level `model_size_stats` object that contains:

- `model_size_bytes`: the actual model size
- `required_native_memory_bytes`: the amount of memory required to load a model

In addition, these are now reported for PyTorch models regardless of their deployment state.

Backport of #82000
  • Loading branch information
dimitris-athanasiou committed Dec 22, 2021
1 parent 3c21d01 commit c73d651
Show file tree
Hide file tree
Showing 14 changed files with 342 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,6 @@ The desired number of nodes for model allocation.
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

`model_size`:::
(<<byte-units,byte value>>)
The size of the loaded model in bytes.

`nodes`:::
(array of objects)
The deployment stats for each node that currently has the model allocated.
Expand Down Expand Up @@ -249,6 +245,23 @@ section in <<cluster-nodes-stats>>.
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
`model_size_stats`:::
(object)
A collection of model size stats fields.
+
.Properties of model size stats
[%collapsible%open]
=====

`model_size_bytes`:::
(integer)
The size of the model in bytes.

`required_native_memory_bytes`:::
(integer)
The amount of memory required to load the model in bytes.
=====
`pipeline_count`:::
(integer)
The number of ingest pipelines that currently refer to the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -41,6 +42,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField MODEL_SIZE_STATS = new ParseField("model_size_stats");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
public static final ParseField DEPLOYMENT_STATS = new ParseField("deployment_stats");
Expand Down Expand Up @@ -77,6 +79,7 @@ public static class Response extends AbstractGetResourcesResponse<Response.Train

public static class TrainedModelStats implements ToXContentObject, Writeable {
private final String modelId;
private final TrainedModelSizeStats modelSizeStats;
private final IngestStats ingestStats;
private final InferenceStats inferenceStats;
private final AllocationStats deploymentStats;
Expand All @@ -90,12 +93,14 @@ public static class TrainedModelStats implements ToXContentObject, Writeable {

public TrainedModelStats(
String modelId,
TrainedModelSizeStats modelSizeStats,
IngestStats ingestStats,
int pipelineCount,
InferenceStats inferenceStats,
AllocationStats deploymentStats
) {
this.modelId = Objects.requireNonNull(modelId);
this.modelSizeStats = modelSizeStats;
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
if (pipelineCount < 0) {
throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
Expand All @@ -107,6 +112,11 @@ public TrainedModelStats(

public TrainedModelStats(StreamInput in) throws IOException {
modelId = in.readString();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
modelSizeStats = in.readOptionalWriteable(TrainedModelSizeStats::new);
} else {
modelSizeStats = null;
}
ingestStats = new IngestStats(in);
pipelineCount = in.readVInt();
inferenceStats = in.readOptionalWriteable(InferenceStats::new);
Expand All @@ -121,6 +131,10 @@ public String getModelId() {
return modelId;
}

/**
 * @return the model size stats, or {@code null} when they were not sent
 *         (e.g. the stats came from a node older than 8.0.0)
 */
public TrainedModelSizeStats getModelSizeStats() {
    return this.modelSizeStats;
}

/**
 * @return ingest stats for the pipelines referencing this model; never {@code null}
 *         (the constructor substitutes an empty placeholder when none are provided)
 */
public IngestStats getIngestStats() {
    return this.ingestStats;
}
Expand All @@ -141,6 +155,9 @@ public AllocationStats getDeploymentStats() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
if (modelSizeStats != null) {
builder.field(MODEL_SIZE_STATS.getPreferredName(), modelSizeStats);
}
builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
if (pipelineCount > 0) {
// Ingest stats is a fragment
Expand All @@ -159,6 +176,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeOptionalWriteable(modelSizeStats);
}
ingestStats.writeTo(out);
out.writeVInt(pipelineCount);
out.writeOptionalWriteable(inferenceStats);
Expand All @@ -169,7 +189,7 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats, deploymentStats);
return Objects.hash(modelId, modelSizeStats, ingestStats, pipelineCount, inferenceStats, deploymentStats);
}

@Override
Expand All @@ -182,6 +202,7 @@ public boolean equals(Object obj) {
}
TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.modelSizeStats, other.modelSizeStats)
&& Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount)
&& Objects.equals(this.deploymentStats, other.deploymentStats)
Expand All @@ -208,6 +229,7 @@ public static class Builder {

private long totalModelCount;
private Map<String, Set<String>> expandedIdsWithAliases;
private Map<String, TrainedModelSizeStats> modelSizeStatsMap;
private Map<String, IngestStats> ingestStatsMap;
private Map<String, InferenceStats> inferenceStatsMap;
private Map<String, AllocationStats> allocationStatsMap;
Expand All @@ -226,6 +248,11 @@ public Map<String, Set<String>> getExpandedIdsWithAliases() {
return this.expandedIdsWithAliases;
}

public Builder setModelSizeStatsByModelId(Map<String, TrainedModelSizeStats> modelSizeStatsByModelId) {
this.modelSizeStatsMap = modelSizeStatsByModelId;
return this;
}

public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
this.ingestStatsMap = ingestStatsByModelId;
return this;
Expand All @@ -244,12 +271,14 @@ public Builder setDeploymentStatsByModelId(Map<String, AllocationStats> allocati
public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
expandedIdsWithAliases.keySet().forEach(id -> {
TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(id);
IngestStats ingestStats = ingestStatsMap.get(id);
InferenceStats inferenceStats = inferenceStatsMap.get(id);
AllocationStats allocationStats = allocationStatsMap.get(id);
trainedModelStats.add(
new TrainedModelStats(
id,
modelSizeStats,
ingestStats,
ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
inferenceStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM

public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS);

/**
 * This has been found to be approximately 300MB on Linux by manual testing.
 * We also subtract 30MB that we always add as overhead (see MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD).
 * TODO Check if it is substantially different on other platforms.
 */
private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);

/**
 * Registers the action under {@code NAME}; responses are deserialized via the
 * {@code CreateTrainedModelAllocationAction.Response} stream constructor.
 */
public StartTrainedModelDeploymentAction() {
    super(NAME, CreateTrainedModelAllocationAction.Response::new);
}
Expand Down Expand Up @@ -265,13 +272,6 @@ public static TaskParams fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

/**
* This has been found to be approximately 300MB on linux by manual testing.
* We also subtract 30MB that we always add as overhead (see MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD).
* TODO Check if it is substantially different in other platforms.
*/
private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);

private final String modelId;
private final long modelBytes;
// How many threads are used by the model during inference. Used to increase inference speed.
Expand Down Expand Up @@ -301,8 +301,7 @@ public String getModelId() {
}

public long estimateMemoryUsageBytes() {
// While loading the model in the process we need twice the model size.
return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
}

public Version getMinimalSupportedVersion() {
Expand Down Expand Up @@ -388,4 +387,9 @@ static boolean match(Task task, String expectedId) {
return false;
}
}

/**
 * Estimates the native memory needed to deploy a model of the given definition
 * length: a fixed process overhead plus twice the definition size.
 *
 * @param totalDefinitionLength total size of the model definition in bytes
 * @return the estimated memory requirement in bytes
 */
public static long estimateMemoryUsageBytes(long totalDefinitionLength) {
    // While loading the model in the process we need twice the model size.
    long loadingFootprint = totalDefinitionLength * 2;
    return loadingFootprint + MEMORY_OVERHEAD.getBytes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -168,8 +167,6 @@ public int hashCode() {
private AllocationStatus allocationStatus;
private String reason;
@Nullable
private final ByteSizeValue modelSize;
@Nullable
private final Integer inferenceThreads;
@Nullable
private final Integer modelThreads;
Expand All @@ -180,15 +177,13 @@ public int hashCode() {

public AllocationStats(
String modelId,
@Nullable ByteSizeValue modelSize,
@Nullable Integer inferenceThreads,
@Nullable Integer modelThreads,
@Nullable Integer queueCapacity,
Instant startTime,
List<AllocationStats.NodeStats> nodeStats
) {
this.modelId = modelId;
this.modelSize = modelSize;
this.inferenceThreads = inferenceThreads;
this.modelThreads = modelThreads;
this.queueCapacity = queueCapacity;
Expand All @@ -200,7 +195,6 @@ public AllocationStats(

public AllocationStats(StreamInput in) throws IOException {
modelId = in.readString();
modelSize = in.readOptionalWriteable(ByteSizeValue::new);
inferenceThreads = in.readOptionalVInt();
modelThreads = in.readOptionalVInt();
queueCapacity = in.readOptionalVInt();
Expand All @@ -215,10 +209,6 @@ public String getModelId() {
return modelId;
}

public ByteSizeValue getModelSize() {
return modelSize;
}

@Nullable
public Integer getInferenceThreads() {
return inferenceThreads;
Expand Down Expand Up @@ -269,9 +259,6 @@ public AllocationStats setReason(String reason) {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("model_id", modelId);
if (modelSize != null) {
builder.humanReadableField("model_size_bytes", "model_size", modelSize);
}
if (inferenceThreads != null) {
builder.field(StartTrainedModelDeploymentAction.TaskParams.INFERENCE_THREADS.getPreferredName(), inferenceThreads);
}
Expand Down Expand Up @@ -303,7 +290,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalWriteable(modelSize);
out.writeOptionalVInt(inferenceThreads);
out.writeOptionalVInt(modelThreads);
out.writeOptionalVInt(queueCapacity);
Expand All @@ -320,7 +306,6 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
AllocationStats that = (AllocationStats) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(modelSize, that.modelSize)
&& Objects.equals(inferenceThreads, that.inferenceThreads)
&& Objects.equals(modelThreads, that.modelThreads)
&& Objects.equals(queueCapacity, that.queueCapacity)
Expand All @@ -333,17 +318,6 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(
modelId,
modelSize,
inferenceThreads,
modelThreads,
queueCapacity,
startTime,
nodeStats,
state,
reason,
allocationStatus
);
return Objects.hash(modelId, inferenceThreads, modelThreads, queueCapacity, startTime, nodeStats, state, reason, allocationStatus);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
 * Size statistics for a trained model: the size of the model definition itself
 * and the native memory required to load it. Rendered as the
 * {@code model_size_stats} object in the trained model stats response.
 */
public class TrainedModelSizeStats implements ToXContentObject, Writeable {

    private static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes");
    private static final ParseField REQUIRED_NATIVE_MEMORY_BYTES = new ParseField("required_native_memory_bytes");

    // Raw size of the model definition, in bytes.
    private final long modelSizeBytes;
    // Native memory needed to load the model, in bytes.
    private final long requiredNativeMemoryBytes;

    public TrainedModelSizeStats(long modelSizeBytes, long requiredNativeMemoryBytes) {
        this.modelSizeBytes = modelSizeBytes;
        this.requiredNativeMemoryBytes = requiredNativeMemoryBytes;
    }

    public TrainedModelSizeStats(StreamInput in) throws IOException {
        // Wire order must mirror writeTo: model size first, then required memory.
        this.modelSizeBytes = in.readLong();
        this.requiredNativeMemoryBytes = in.readLong();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeLong(modelSizeBytes);
        out.writeLong(requiredNativeMemoryBytes);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // humanReadableField additionally emits e.g. "model_size": "1mb" when human-readable output is requested.
        builder.humanReadableField(MODEL_SIZE_BYTES.getPreferredName(), "model_size", ByteSizeValue.ofBytes(modelSizeBytes));
        builder.humanReadableField(
            REQUIRED_NATIVE_MEMORY_BYTES.getPreferredName(),
            "required_native_memory",
            ByteSizeValue.ofBytes(requiredNativeMemoryBytes)
        );
        return builder.endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        TrainedModelSizeStats other = (TrainedModelSizeStats) o;
        return modelSizeBytes == other.modelSizeBytes && requiredNativeMemoryBytes == other.requiredNativeMemoryBytes;
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelSizeBytes, requiredNativeMemoryBytes);
    }
}

0 comments on commit c73d651

Please sign in to comment.