Skip to content

Commit

Permalink
[ML] Add queue_capacity to deployment stats (#79905)
Browse files Browse the repository at this point in the history
Adds the `queue_capacity` to the response of the
get trained model deployment stats API.
  • Loading branch information
dimitris-athanasiou committed Oct 27, 2021
1 parent 46b91f5 commit ee97aa3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -233,19 +233,22 @@ public int hashCode() {
@Nullable private final ByteSizeValue modelSize;
@Nullable private final Integer inferenceThreads;
@Nullable private final Integer modelThreads;
@Nullable private final Integer queueCapacity;
private final List<NodeStats> nodeStats;

public AllocationStats(
String modelId,
@Nullable ByteSizeValue modelSize,
@Nullable Integer inferenceThreads,
@Nullable Integer modelThreads,
@Nullable Integer queueCapacity,
List<NodeStats> nodeStats
) {
this.modelId = modelId;
this.modelSize = modelSize;
this.inferenceThreads = inferenceThreads;
this.modelThreads = modelThreads;
this.queueCapacity = queueCapacity;
this.nodeStats = nodeStats;
this.state = null;
this.reason = null;
Expand All @@ -256,6 +259,7 @@ public AllocationStats(StreamInput in) throws IOException {
modelSize = in.readOptionalWriteable(ByteSizeValue::new);
inferenceThreads = in.readOptionalVInt();
modelThreads = in.readOptionalVInt();
queueCapacity = in.readOptionalVInt();
nodeStats = in.readList(NodeStats::new);
state = in.readOptionalEnum(AllocationState.class);
reason = in.readOptionalString();
Expand All @@ -280,6 +284,11 @@ public Integer getModelThreads() {
return modelThreads;
}

/**
 * Returns the queue capacity held by these allocation stats.
 *
 * @return the queue capacity, or {@code null} when none was supplied
 */
@Nullable
public Integer getQueueCapacity() {
    return this.queueCapacity;
}

/**
 * Returns the node stats held by these allocation stats.
 *
 * @return the stored list of {@code NodeStats}; the same list instance is returned, not a copy
 */
public List<NodeStats> getNodeStats() {
    return this.nodeStats;
}
Expand Down Expand Up @@ -320,6 +329,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelThreads != null) {
builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
}
if (queueCapacity != null) {
builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
}
if (state != null) {
builder.field("state", state);
}
Expand All @@ -344,6 +356,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(modelSize);
out.writeOptionalVInt(inferenceThreads);
out.writeOptionalVInt(modelThreads);
out.writeOptionalVInt(queueCapacity);
out.writeList(nodeStats);
out.writeOptionalEnum(state);
out.writeOptionalString(reason);
Expand All @@ -359,6 +372,7 @@ public boolean equals(Object o) {
Objects.equals(modelSize, that.modelSize) &&
Objects.equals(inferenceThreads, that.inferenceThreads) &&
Objects.equals(modelThreads, that.modelThreads) &&
Objects.equals(queueCapacity, that.queueCapacity) &&
Objects.equals(state, that.state) &&
Objects.equals(reason, that.reason) &&
Objects.equals(allocationStatus, that.allocationStatus) &&
Expand All @@ -367,7 +381,8 @@ public boolean equals(Object o) {

/**
 * Computes the hash code from the fields of these allocation stats.
 *
 * {@code queueCapacity} is part of the hash so that it stays consistent with
 * {@code equals}, which compares that field as well.
 *
 * @return the combined hash of all hashed fields
 */
@Override
public int hashCode() {
    // Single return only: an earlier return that omitted queueCapacity would make
    // this statement unreachable and leave the hash out of step with equals.
    return Objects.hash(
        modelId,
        modelSize,
        inferenceThreads,
        modelThreads,
        queueCapacity,
        nodeStats,
        state,
        reason,
        allocationStatus
    );
}
}

Expand Down Expand Up @@ -482,6 +497,7 @@ public static GetDeploymentStatsAction.Response addFailedRoutes(
stat.getModelSize(),
stat.getInferenceThreads(),
stat.getModelThreads(),
stat.getQueueCapacity(),
updatedNodeStats
)
);
Expand Down Expand Up @@ -510,7 +526,7 @@ public static GetDeploymentStatsAction.Response addFailedRoutes(
nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));

updatedAllocationStats.add(new GetDeploymentStatsAction.Response.AllocationStats(
modelId, null, null, null, nodeStats)
modelId, null, null, null, null, nodeStats)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostExceptio
ByteSizeValue.ofBytes(randomNonNegativeLong()),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
nodeStatsList);

Map<String, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
Expand Down Expand Up @@ -145,6 +146,7 @@ public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostExce
ByteSizeValue.ofBytes(randomNonNegativeLong()),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
nodeStatsList);
var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(),
List.of(model1), 1);
Expand Down Expand Up @@ -202,6 +204,7 @@ node, randomFrom(RoutingState.values()), randomBoolean() ? null : "a good reason
ByteSizeValue.ofBytes(randomNonNegativeLong()),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 8),
randomBoolean() ? null : randomIntBetween(1, 10000),
nodeStatsList);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ protected void taskOperation(GetDeploymentStatsAction.Request request, TrainedMo
ByteSizeValue.ofBytes(task.getParams().getModelBytes()),
task.getParams().getInferenceThreads(),
task.getParams().getModelThreads(),
task.getParams().getQueueCapacity(),
nodeStats)
);
}
Expand Down

0 comments on commit ee97aa3

Please sign in to comment.