[ML] Add error counts to trained model stats (#82705)
Adds inference_count, timeout_count, rejected_execution_count
and error_count fields to trained model stats.
davidkyle committed Jan 27, 2022
1 parent 1504c93 commit c1fbf87
Showing 12 changed files with 301 additions and 43 deletions.
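At a glance, the commit adds four per-node counters that are also summed to deployment level. A minimal sketch of the kind of counters involved; the class and method names here are illustrative assumptions, not the actual Elasticsearch internals:

import java.util.concurrent.atomic.LongAdder;

// Illustrative only: the four counters behind the new stats fields;
// the names mirror the JSON fields the stats API reports.
class InferenceCounters {
    private final LongAdder inferenceCount = new LongAdder();         // inference calls made
    private final LongAdder errorCount = new LongAdder();             // calls that failed with an error
    private final LongAdder timeoutCount = new LongAdder();           // requests that timed out waiting
    private final LongAdder rejectedExecutionCount = new LongAdder(); // requests rejected by a full queue

    void onInference() { inferenceCount.increment(); }
    void onError() { errorCount.increment(); }
    void onTimeout() { timeoutCount.increment(); }
    void onRejected() { rejectedExecutionCount.increment(); }

    long inferences() { return inferenceCount.sum(); }
    long errors() { return errorCount.sum(); }
    long timeouts() { return timeoutCount.sum(); }
    long rejections() { return rejectedExecutionCount.sum(); }
}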
5 changes: 5 additions & 0 deletions docs/changelog/82705.yaml
@@ -0,0 +1,5 @@
pr: 82705
summary: Add error counts to trained model stats
area: Machine Learning
type: enhancement
issues: []
@@ -112,10 +112,26 @@ The detailed allocation state related to the nodes.
The desired number of nodes for model allocation.
======

`error_count`:::
(integer)
The sum of `error_count` for all nodes in the deployment.

`inference_count`:::
(integer)
The sum of `inference_count` for all nodes in the deployment.

`inference_threads`:::
(integer)
The number of threads used by the inference process.

`model_id`:::
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

`model_threads`:::
(integer)
The number of threads used when sending inference requests to the model.

`nodes`:::
(array of objects)
The deployment stats for each node that currently has the model allocated.
@@ -127,14 +143,30 @@ The deployment stats for each node that currently has the model allocated.
(double)
The average time for each inference call to complete on this node.
`error_count`:::
(integer)
The number of errors when evaluating the trained model.
`inference_count`:::
(integer)
The total number of inference calls made against this node for this model.
`inference_threads`:::
(integer)
The number of threads used by the inference process.
This value is limited by the number of hardware threads on the node;
it might therefore differ from the `inference_threads` value in the <<start-trained-model-deployment>> API.
`last_access`:::
(long)
The epoch timestamp of the last inference call for the model on this node.
`model_threads`:::
(integer)
The number of threads used when sending inference requests to the model.
This value is limited by the number of hardware threads on the node;
it might therefore differ from the `model_threads` value in the <<start-trained-model-deployment>> API.
`node`:::
(object)
Information pertaining to the node.
@@ -162,28 +194,59 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-id]
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-transport-address]
========
`reason`:::
(string)
The reason for the current state. Usually only populated when the `routing_state` is `failed`.
`number_of_pending_requests`:::
(integer)
The number of inference requests queued to be processed.
`routing_state`:::
(object)
The current routing state for this allocation and the reason for that state.
+
.Properties of routing_state
[%collapsible%open]
========
`reason`:::
(string)
The reason for the current state. Usually only populated when the `routing_state` is `failed`.

`routing_state`:::
(string)
The current routing state.
--
* `starting`: The model is attempting to allocate on this node; inference calls are not yet accepted.
* `started`: The model is allocated and ready to accept inference requests.
* `stopping`: The model is being deallocated from this node.
* `stopped`: The model is fully deallocated from this node.
* `failed`: The allocation attempt failed, see `reason` field for the potential cause.
--
========
`rejected_execution_count`:::
(integer)
The number of inference requests that were not processed because the
queue was full.
`start_time`:::
(long)
The epoch timestamp when the allocation started.
`timeout_count`:::
(integer)
The number of inference requests that timed out before being processed.
======

`rejected_execution_count`:::
(integer)
The sum of `rejected_execution_count` for all nodes in the deployment.
Individual nodes reject an inference request if the inference queue is full.
The queue size is controlled by the `queue_capacity` setting in the
<<start-trained-model-deployment>> API.

`reason`:::
(string)
The reason for the current deployment state.
Usually only populated when the model is not deployed to a node.

`start_time`:::
(long)
The epoch timestamp when the deployment started.
@@ -198,6 +261,15 @@ The overall state of the deployment. The values may be:
* `stopping`: The deployment is preparing to stop and deallocate the model from the relevant nodes.
--

`timeout_count`:::
(integer)
The sum of `timeout_count` for all nodes in the deployment.

`queue_capacity`:::
(integer)
The number of inference requests that may be queued before new requests are
rejected.

=====
`inference_stats`:::
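The `inference_threads` and `model_threads` descriptions above note that the per-node value is capped by the node's hardware threads. A one-method sketch of that clamping, assuming a hypothetical helper rather than the actual ES code:

// Illustrative: why a node may report fewer threads than were
// requested in the start-trained-model-deployment API.
static int effectiveThreads(int requestedThreads) {
    int hardwareThreads = Runtime.getRuntime().availableProcessors();
    return Math.min(requestedThreads, hardwareThreads);
}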
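Likewise, `rejected_execution_count` and `queue_capacity` describe a bounded queue that refuses work once full. A self-contained sketch of that behaviour using a plain ThreadPoolExecutor; the class and its wiring are assumptions, not the ES inference executor:

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative: one worker thread in front of a bounded queue. When
// the queue is full, the rejection handler increments a counter
// analogous to rejected_execution_count; queueCapacity plays the
// role of the queue_capacity setting.
class BoundedInferenceExecutor {
    final AtomicInteger rejectedExecutionCount = new AtomicInteger();
    final ThreadPoolExecutor executor;

    BoundedInferenceExecutor(int queueCapacity) {
        executor = new ThreadPoolExecutor(
            1, 1, 0L, TimeUnit.MILLISECONDS,
            new ArrayBlockingQueue<>(queueCapacity),
            (task, pool) -> rejectedExecutionCount.incrementAndGet() // full queue: count and drop
        );
    }

    void submit(Runnable inferenceRequest) {
        executor.execute(inferenceRequest);
    }
}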
@@ -31,6 +31,9 @@ public static class NodeStats implements ToXContentObject, Writeable {
private final Double avgInferenceTime;
private final Instant lastAccess;
private final Integer pendingCount;
private final int errorCount;
private final int rejectedExecutionCount;
private final int timeoutCount;
private final RoutingStateAndReason routingState;
private final Instant startTime;
private final Integer inferenceThreads;
@@ -41,6 +44,9 @@ public static AllocationStats.NodeStats forStartedState(
long inferenceCount,
Double avgInferenceTime,
int pendingCount,
int errorCount,
int rejectedExecutionCount,
int timeoutCount,
Instant lastAccess,
Instant startTime,
Integer inferenceThreads,
@@ -52,6 +58,9 @@ public static AllocationStats.NodeStats forStartedState(
avgInferenceTime,
lastAccess,
pendingCount,
errorCount,
rejectedExecutionCount,
timeoutCount,
new RoutingStateAndReason(RoutingState.STARTED, null),
Objects.requireNonNull(startTime),
inferenceThreads,
@@ -60,7 +69,20 @@ public static AllocationStats.NodeStats forStartedState(
}

public static AllocationStats.NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
return new AllocationStats.NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null, null, null);
return new AllocationStats.NodeStats(
node,
null,
null,
null,
null,
0,
0,
0,
new RoutingStateAndReason(state, reason),
null,
null,
null
);
}

public NodeStats(
@@ -69,6 +91,9 @@ public NodeStats(
Double avgInferenceTime,
Instant lastAccess,
Integer pendingCount,
int errorCount,
int rejectedExecutionCount,
int timeoutCount,
RoutingStateAndReason routingState,
@Nullable Instant startTime,
@Nullable Integer inferenceThreads,
@@ -79,6 +104,9 @@ public NodeStats(
this.avgInferenceTime = avgInferenceTime;
this.lastAccess = lastAccess;
this.pendingCount = pendingCount;
this.errorCount = errorCount;
this.rejectedExecutionCount = rejectedExecutionCount;
this.timeoutCount = timeoutCount;
this.routingState = routingState;
this.startTime = startTime;
this.inferenceThreads = inferenceThreads;
@@ -96,13 +124,18 @@ public NodeStats(StreamInput in) throws IOException {
this.pendingCount = in.readOptionalVInt();
this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
this.startTime = in.readOptionalInstant();

if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
this.inferenceThreads = in.readOptionalVInt();
this.modelThreads = in.readOptionalVInt();
this.errorCount = in.readVInt();
this.rejectedExecutionCount = in.readVInt();
this.timeoutCount = in.readVInt();
} else {
this.inferenceThreads = null;
this.modelThreads = null;
this.errorCount = 0;
this.rejectedExecutionCount = 0;
this.timeoutCount = 0;
}
}
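The constructor above follows the usual wire backwards-compatibility pattern: the new counters are only read when the sending node is on a version that writes them, and default to 0 otherwise (the `writeTo` hunk below applies the same version gate). A standalone sketch of the idea using plain java.io streams; the version constant and class are invented for illustration, not the ES Version/StreamInput machinery:

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

// Illustrative BWC pattern: gate new wire fields on the peer version,
// defaulting to 0 when the peer is too old to have sent them.
class CountersWire {
    static final int V_NEW_FIELDS = 8_10_0; // invented version encoding for the sketch

    int errorCount;
    int rejectedExecutionCount;
    int timeoutCount;

    void writeTo(DataOutput out, int peerVersion) throws IOException {
        if (peerVersion >= V_NEW_FIELDS) { // older peers would not understand the extra ints
            out.writeInt(errorCount);
            out.writeInt(rejectedExecutionCount);
            out.writeInt(timeoutCount);
        }
    }

    void readFrom(DataInput in, int peerVersion) throws IOException {
        if (peerVersion >= V_NEW_FIELDS) {
            errorCount = in.readInt();
            rejectedExecutionCount = in.readInt();
            timeoutCount = in.readInt();
        } else { // the peer never sent the fields
            errorCount = 0;
            rejectedExecutionCount = 0;
            timeoutCount = 0;
        }
    }
}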

@@ -130,6 +163,18 @@ public Integer getPendingCount() {
return pendingCount;
}

public int getErrorCount() {
return errorCount;
}

public int getRejectedExecutionCount() {
return rejectedExecutionCount;
}

public int getTimeoutCount() {
return timeoutCount;
}

public Instant getStartTime() {
return startTime;
}
@@ -146,7 +191,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (inferenceCount != null) {
builder.field("inference_count", inferenceCount);
}
if (avgInferenceTime != null) {
// avoid reporting the average time as 0 if count < 1
if (avgInferenceTime != null && (inferenceCount != null && inferenceCount > 0)) {
builder.field("average_inference_time_ms", avgInferenceTime);
}
if (lastAccess != null) {
@@ -155,6 +201,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (pendingCount != null) {
builder.field("number_of_pending_requests", pendingCount);
}
if (errorCount > 0) {
builder.field("error_count", errorCount);
}
if (rejectedExecutionCount > 0) {
builder.field("rejected_execution_count", rejectedExecutionCount);
}
if (timeoutCount > 0) {
builder.field("timeout_count", timeoutCount);
}
if (startTime != null) {
builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
}
@@ -180,6 +235,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
out.writeOptionalVInt(inferenceThreads);
out.writeOptionalVInt(modelThreads);
out.writeVInt(errorCount);
out.writeVInt(rejectedExecutionCount);
out.writeVInt(timeoutCount);
}
}

@@ -193,6 +251,9 @@ public boolean equals(Object o) {
&& Objects.equals(node, that.node)
&& Objects.equals(lastAccess, that.lastAccess)
&& Objects.equals(pendingCount, that.pendingCount)
&& Objects.equals(errorCount, that.errorCount)
&& Objects.equals(rejectedExecutionCount, that.rejectedExecutionCount)
&& Objects.equals(timeoutCount, that.timeoutCount)
&& Objects.equals(routingState, that.routingState)
&& Objects.equals(startTime, that.startTime)
&& Objects.equals(inferenceThreads, that.inferenceThreads)
@@ -207,6 +268,9 @@ public int hashCode() {
avgInferenceTime,
lastAccess,
pendingCount,
errorCount,
rejectedExecutionCount,
timeoutCount,
routingState,
startTime,
inferenceThreads,
@@ -331,6 +395,28 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("allocation_status", allocationStatus);
}
builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());

int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
int totalRejectedExecutionCount = nodeStats.stream().mapToInt(NodeStats::getRejectedExecutionCount).sum();
int totalTimeoutCount = nodeStats.stream().mapToInt(NodeStats::getTimeoutCount).sum();
long totalInferenceCount = nodeStats.stream()
.filter(n -> n.getInferenceCount().isPresent())
.mapToLong(n -> n.getInferenceCount().get())
.sum();

if (totalErrorCount > 0) {
builder.field("error_count", totalErrorCount);
}
if (totalRejectedExecutionCount > 0) {
builder.field("rejected_execution_count", totalRejectedExecutionCount);
}
if (totalTimeoutCount > 0) {
builder.field("timeout_count", totalTimeoutCount);
}
if (totalInferenceCount > 0) {
builder.field("inference_count", totalInferenceCount);
}

builder.startArray("nodes");
for (AllocationStats.NodeStats nodeStat : nodeStats) {
nodeStat.toXContent(builder, params);
@@ -126,6 +126,9 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
nodeStats.getAvgInferenceTime().orElse(null),
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
0,
0,
0,
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
null,
@@ -57,6 +57,9 @@ public static AllocationStats.NodeStats randomNodeStats(DiscoveryNode node) {
randomNonNegativeLong(),
randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
randomIntBetween(0, 100),
randomIntBetween(0, 100),
randomIntBetween(0, 100),
randomIntBetween(0, 100),
Instant.now(),
Instant.now(),
randomIntBetween(1, 16),
@@ -297,9 +297,11 @@ protected void taskOperation(
AllocationStats.NodeStats.forStartedState(
clusterService.localNode(),
stats.get().timingStats().getCount(),
// avoid reporting the average time as 0 if count < 1
(stats.get().timingStats().getCount() > 0) ? stats.get().timingStats().getAverage() : null,
stats.get().timingStats().getAverage(),
stats.get().pendingCount(),
stats.get().errorCount(),
stats.get().rejectedExecutionCount(),
stats.get().timeoutCount(),
stats.get().lastUsed(),
stats.get().startTime(),
stats.get().inferenceThreads(),
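For a sense of how a consumer might combine the counters exposed here, a hypothetical helper; every name is invented, and the accounting assumptions are called out in the comment:

// Illustrative: rough fraction of requests that produced no result.
// Assumes inference_count covers executed calls and that rejected and
// timed-out requests never executed; that accounting is an assumption.
static double failureRatio(long inferenceCount, int errorCount, int rejectedExecutionCount, int timeoutCount) {
    long attempted = inferenceCount + rejectedExecutionCount + timeoutCount;
    return attempted == 0 ? 0.0 : (double) (errorCount + rejectedExecutionCount + timeoutCount) / attempted;
}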
