Skip to content

Commit

Permalink
[ML] Fix missing deployment stats peak throughput field #85436
Browse files Browse the repository at this point in the history
In an edge case peak_throughput_per_minute was not being returned 
even if the stat could be calculated for the last bucket
  • Loading branch information
davidkyle committed Mar 29, 2022
1 parent 4887756 commit 7a22f39
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,20 @@ public void testLiveDeploymentStats() throws IOException {
assertThat(nodes, hasSize(2));
for (var node : nodes) {
assertThat(node.get("number_of_pending_requests"), notNullValue());
// last_access and average_inference_time_ms may be null if inference wasn't performed on this node
}
// last_access and average_inference_time_ms may be null if inference wasn't performed on this node
assertAtLeastOneOfTheseIsNotNull("last_access", nodes);
assertAtLeastOneOfTheseIsNotNull("average_inference_time_ms", nodes);

int inferenceCount = sumInferenceCountOnNodes(nodes);
assertThat(inferenceCount, equalTo(2));
}
}

private void assertAtLeastOneOfTheseIsNotNull(String name, List<Map<String, Object>> nodes) {
assertTrue("all nodes have null value for [" + name + "]", nodes.stream().anyMatch(n -> n.get(name) != null));
}

@SuppressWarnings("unchecked")
public void testGetDeploymentStats_WithWildcard() throws IOException {
String modelFoo = "foo";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public synchronized ResultStats getResultStats() {
// in this period to close off the last period stats.
// The stats are valid return them here
rs = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage());
peakThroughput = Math.max(peakThroughput, lastPeriodSummaryStats.getCount());
}

if (rs == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.hamcrest.Matchers.comparesEqualTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -150,7 +151,7 @@ public void testsStats() {
assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L));
}

public void testsRecentStats() {
public void testsTimeDependentStats() {

long start = System.currentTimeMillis();
// the first value is used in the ctor to set the start time.
Expand Down Expand Up @@ -211,12 +212,14 @@ public void testsRecentStats() {
// first call has no results as is in the same period
var stats = processor.getResultStats();
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
assertThat(stats.recentStats().avgInferenceTime(), nullValue());
// 2nd time in the next period
stats = processor.getResultStats();
assertNotNull(stats.recentStats());
assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
assertThat(stats.recentStats().avgInferenceTime(), closeTo(200.0, 0.00001));
assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[3])));
assertThat(stats.peakThroughput(), equalTo(3L));

// 2nd period
processor.processInferenceResult(new PyTorchInferenceResult("foo", null, 100L, null));
Expand All @@ -225,6 +228,7 @@ public void testsRecentStats() {
assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
assertThat(stats.recentStats().avgInferenceTime(), closeTo(100.0, 0.00001));
assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[6])));
assertThat(stats.peakThroughput(), equalTo(3L));

stats = processor.getResultStats();
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
Expand All @@ -242,6 +246,7 @@ public void testsRecentStats() {
processor.processInferenceResult(new PyTorchInferenceResult("foo", null, 390L, null));
stats = processor.getResultStats();
assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
assertThat(stats.recentStats().avgInferenceTime(), nullValue());
stats = processor.getResultStats(); // called in the next period
assertNotNull(stats.recentStats());
assertThat(stats.recentStats().requestsProcessed(), equalTo(2L));
Expand All @@ -257,6 +262,7 @@ public void testsRecentStats() {
assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
assertThat(stats.recentStats().avgInferenceTime(), closeTo(500.0, 0.00001));
assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[17])));
assertThat(stats.peakThroughput(), equalTo(3L));
}

private static class TimeSupplier implements LongSupplier {
Expand Down

0 comments on commit 7a22f39

Please sign in to comment.