Skip to content

Commit

Permalink
[ML] Add throughput stats for Trained Model Deployments (#84628)
Browse files Browse the repository at this point in the history
Throughput is measured as the number of inference requests
processed per minute. The node-level stats peak_throughput_per_minute,
throughput_last_minute and average_inference_time_ms_last_minute are
added, along with a deployment-level stat peak_throughput_per_minute,
which is the summed peak throughput of all nodes.
  • Loading branch information
davidkyle committed Mar 8, 2022
1 parent aaf66f9 commit 27ae821
Show file tree
Hide file tree
Showing 12 changed files with 573 additions and 30 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/84628.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 84628
summary: Add throughput stats for Trained Model Deployments
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ The deployment stats for each node that currently has the model allocated.
`average_inference_time_ms`:::
(double)
The average time for each inference call to complete on this node.
The average is calculated over the lifetime of the deployment.
`average_inference_time_ms_last_minute`:::
(double)
The average time for each inference call to complete on this node
in the last minute.
`error_count`:::
(integer)
Expand Down Expand Up @@ -198,6 +204,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-transport-address]
(integer)
The number of inference requests queued to be processed.
`peak_throughput_per_minute`:::
(integer)
The peak number of requests processed in a 1-minute period.
`routing_state`:::
(object)
The current routing state and reason for the current routing state for this allocation.
Expand Down Expand Up @@ -233,7 +243,16 @@ The epoch timestamp when the allocation started.
`timeout_count`:::
(integer)
The number of inference requests that timed out before being processed.
`throughput_last_minute`:::
(integer)
The number of requests processed in the last minute.
======
`peak_throughput_per_minute`:::
(integer)
The peak number of requests processed in a 1-minute period across
all nodes in the deployment. This is calculated as the sum of
each node's `peak_throughput_per_minute` value.

`rejected_execution_count`:::
(integer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -38,6 +39,9 @@ public static class NodeStats implements ToXContentObject, Writeable {
private final Instant startTime;
private final Integer inferenceThreads;
private final Integer modelThreads;
private final long peakThroughput;
private final long throughputLastPeriod;
private final Double avgInferenceTimeLastPeriod;

public static AllocationStats.NodeStats forStartedState(
DiscoveryNode node,
Expand All @@ -50,7 +54,10 @@ public static AllocationStats.NodeStats forStartedState(
Instant lastAccess,
Instant startTime,
Integer inferenceThreads,
Integer modelThreads
Integer modelThreads,
long peakThroughput,
long throughputLastPeriod,
Double avgInferenceTimeLastPeriod
) {
return new AllocationStats.NodeStats(
node,
Expand All @@ -64,7 +71,10 @@ public static AllocationStats.NodeStats forStartedState(
new RoutingStateAndReason(RoutingState.STARTED, null),
Objects.requireNonNull(startTime),
inferenceThreads,
modelThreads
modelThreads,
peakThroughput,
throughputLastPeriod,
avgInferenceTimeLastPeriod
);
}

Expand All @@ -81,6 +91,9 @@ public static AllocationStats.NodeStats forNotStartedState(DiscoveryNode node, R
new RoutingStateAndReason(state, reason),
null,
null,
null,
0,
0,
null
);
}
Expand All @@ -97,7 +110,10 @@ public NodeStats(
RoutingStateAndReason routingState,
@Nullable Instant startTime,
@Nullable Integer inferenceThreads,
@Nullable Integer modelThreads
@Nullable Integer modelThreads,
long peakThroughput,
long throughputLastPeriod,
Double avgInferenceTimeLastPeriod
) {
this.node = node;
this.inferenceCount = inferenceCount;
Expand All @@ -111,6 +127,9 @@ public NodeStats(
this.startTime = startTime;
this.inferenceThreads = inferenceThreads;
this.modelThreads = modelThreads;
this.peakThroughput = peakThroughput;
this.throughputLastPeriod = throughputLastPeriod;
this.avgInferenceTimeLastPeriod = avgInferenceTimeLastPeriod;

// if lastAccess time is null there have been no inferences
assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
Expand All @@ -137,6 +156,15 @@ public NodeStats(StreamInput in) throws IOException {
this.rejectedExecutionCount = 0;
this.timeoutCount = 0;
}
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.peakThroughput = in.readVLong();
this.throughputLastPeriod = in.readVLong();
this.avgInferenceTimeLastPeriod = in.readOptionalDouble();
} else {
this.peakThroughput = 0;
this.throughputLastPeriod = 0;
this.avgInferenceTimeLastPeriod = null;
}
}

public DiscoveryNode getNode() {
Expand Down Expand Up @@ -179,6 +207,26 @@ public Instant getStartTime() {
return startTime;
}

/**
 * Number of threads used for inference on this node.
 * May be {@code null} (the constructor parameter is {@code @Nullable}),
 * e.g. when the deployment has not started on this node.
 */
public Integer getInferenceThreads() {
return inferenceThreads;
}

/**
 * Number of threads allocated to the model on this node.
 * May be {@code null} (the constructor parameter is {@code @Nullable}),
 * e.g. when the deployment has not started on this node.
 */
public Integer getModelThreads() {
return modelThreads;
}

/**
 * Peak number of inference requests processed in a 1-minute period on
 * this node; serialized as {@code peak_throughput_per_minute}.
 * Defaults to 0 when read from a pre-8.2.0 node.
 */
public long getPeakThroughput() {
return peakThroughput;
}

/**
 * Number of inference requests processed in the last period (one
 * minute); serialized as {@code throughput_last_minute}.
 * Defaults to 0 when read from a pre-8.2.0 node.
 */
public long getThroughputLastPeriod() {
return throughputLastPeriod;
}

/**
 * Average inference time in milliseconds over the last period (one
 * minute); serialized as {@code average_inference_time_ms_last_minute}.
 * {@code null} when unavailable — it is written as an optional double
 * and omitted from the XContent output when null.
 */
public Double getAvgInferenceTimeLastPeriod() {
return avgInferenceTimeLastPeriod;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -219,6 +267,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelThreads != null) {
builder.field("model_threads", modelThreads);
}
builder.field("peak_throughput_per_minute", peakThroughput);
builder.field("throughput_last_minute", throughputLastPeriod);
if (avgInferenceTimeLastPeriod != null) {
builder.field("average_inference_time_ms_last_minute", avgInferenceTimeLastPeriod);
}

builder.endObject();
return builder;
}
Expand All @@ -239,6 +293,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(rejectedExecutionCount);
out.writeVInt(timeoutCount);
}
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeVLong(peakThroughput);
out.writeVLong(throughputLastPeriod);
out.writeOptionalDouble(avgInferenceTimeLastPeriod);
}
}

@Override
Expand All @@ -257,7 +316,10 @@ public boolean equals(Object o) {
&& Objects.equals(routingState, that.routingState)
&& Objects.equals(startTime, that.startTime)
&& Objects.equals(inferenceThreads, that.inferenceThreads)
&& Objects.equals(modelThreads, that.modelThreads);
&& Objects.equals(modelThreads, that.modelThreads)
&& Objects.equals(peakThroughput, that.peakThroughput)
&& Objects.equals(throughputLastPeriod, that.throughputLastPeriod)
&& Objects.equals(avgInferenceTimeLastPeriod, that.avgInferenceTimeLastPeriod);
}

@Override
Expand All @@ -274,7 +336,10 @@ public int hashCode() {
routingState,
startTime,
inferenceThreads,
modelThreads
modelThreads,
peakThroughput,
throughputLastPeriod,
avgInferenceTimeLastPeriod
);
}
}
Expand Down Expand Up @@ -403,6 +468,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.filter(n -> n.getInferenceCount().isPresent())
.mapToLong(n -> n.getInferenceCount().get())
.sum();
long peakThroughput = nodeStats.stream().mapToLong(NodeStats::getPeakThroughput).sum();

if (totalErrorCount > 0) {
builder.field("error_count", totalErrorCount);
Expand All @@ -416,6 +482,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (totalInferenceCount > 0) {
builder.field("inference_count", totalInferenceCount);
}
builder.field("peak_throughput_per_minute", peakThroughput);

builder.startArray("nodes");
for (AllocationStats.NodeStats nodeStat : nodeStats) {
Expand Down Expand Up @@ -459,4 +526,9 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(modelId, inferenceThreads, modelThreads, queueCapacity, startTime, nodeStats, state, reason, allocationStatus);
}

/**
 * Returns this object rendered via {@code Strings.toString(this)} —
 * presumably its XContent/JSON form, as is conventional for
 * {@code ToXContentObject} implementations; useful for logging and debugging.
 */
@Override
public String toString() {
return Strings.toString(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,61 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
null,
null,
0L,
0L,
null
)
)
.toList()
)
)
)
.collect(Collectors.toList()),
instance.getResources().count(),
RESULTS_FIELD
)
);
} else if (version.before(Version.V_8_2_0)) {
return new Response(
new QueryPage<>(
instance.getResources()
.results()
.stream()
.map(
stats -> new Response.TrainedModelStats(
stats.getModelId(),
stats.getModelSizeStats(),
stats.getIngestStats(),
stats.getPipelineCount(),
stats.getInferenceStats(),
stats.getDeploymentStats() == null
? null
: new AllocationStats(
stats.getDeploymentStats().getModelId(),
stats.getDeploymentStats().getInferenceThreads(),
stats.getDeploymentStats().getModelThreads(),
stats.getDeploymentStats().getQueueCapacity(),
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
.stream()
.map(
nodeStats -> new AllocationStats.NodeStats(
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
nodeStats.getRejectedExecutionCount(),
nodeStats.getTimeoutCount(),
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
nodeStats.getInferenceThreads(),
nodeStats.getModelThreads(),
0L,
0L,
null
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ public static AllocationStats.NodeStats randomNodeStats(DiscoveryNode node) {
Instant.now(),
Instant.now(),
randomIntBetween(1, 16),
randomIntBetween(1, 16)
randomIntBetween(1, 16),
randomIntBetween(0, 100),
randomIntBetween(0, 100),
randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,19 +292,23 @@ protected void taskOperation(
List<AllocationStats.NodeStats> nodeStats = new ArrayList<>();

if (stats.isPresent()) {
var presentValue = stats.get();
nodeStats.add(
AllocationStats.NodeStats.forStartedState(
clusterService.localNode(),
stats.get().timingStats().getCount(),
stats.get().timingStats().getAverage(),
stats.get().pendingCount(),
stats.get().errorCount(),
stats.get().rejectedExecutionCount(),
stats.get().timeoutCount(),
stats.get().lastUsed(),
stats.get().startTime(),
stats.get().inferenceThreads(),
stats.get().modelThreads()
presentValue.timingStats().getCount(),
presentValue.timingStats().getAverage(),
presentValue.pendingCount(),
presentValue.errorCount(),
presentValue.rejectedExecutionCount(),
presentValue.timeoutCount(),
presentValue.lastUsed(),
presentValue.startTime(),
presentValue.inferenceThreads(),
presentValue.modelThreads(),
presentValue.peakThroughput(),
presentValue.throughputLastPeriod(),
presentValue.avgInferenceTimeLastPeriod()
)
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
processContext.rejectedExecutionCount.intValue(),
processContext.timeoutCount.intValue(),
processContext.inferenceThreads,
processContext.modelThreads
processContext.modelThreads,
stats.peakThroughput(),
stats.recentStats().requestsProcessed(),
stats.recentStats().avgInferenceTime()
);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ public record ModelStats(
int rejectedExecutionCount,
int timeoutCount,
Integer inferenceThreads,
Integer modelThreads
Integer modelThreads,
long peakThroughput,
long throughputLastPeriod,
Double avgInferenceTimeLastPeriod
) {}

0 comments on commit 27ae821

Please sign in to comment.