Skip to content

Commit

Permalink
[ML] Get stats by deployment or model id (#95440)
Browse files Browse the repository at this point in the history
Trained model stats may be requested by deployment id or model id.
If a model id is used and the model has multiple deployments, stats
for each deployment are returned.
  • Loading branch information
davidkyle committed Apr 21, 2023
1 parent 9313830 commit ed955d5
Show file tree
Hide file tree
Showing 14 changed files with 474 additions and 93 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/95440.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 95440
summary: "[ML] Get trained model stats by deployment id or model id"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.ingest.IngestStats;
Expand All @@ -30,7 +29,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -236,7 +234,7 @@ protected Reader<Response.TrainedModelStats> getReader() {
public static class Builder {

private long totalModelCount;
private Map<String, Set<String>> expandedIdsWithAliases;
private Map<String, Set<String>> expandedModelIdsWithAliases;
private Map<String, TrainedModelSizeStats> modelSizeStatsMap;
private Map<String, IngestStats> ingestStatsMap;
private Map<String, InferenceStats> inferenceStatsMap;
Expand All @@ -247,13 +245,13 @@ public Builder setTotalModelCount(long totalModelCount) {
return this;
}

public Builder setExpandedIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
this.expandedIdsWithAliases = expandedIdsWithAliases;
public Builder setExpandedModelIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
this.expandedModelIdsWithAliases = expandedIdsWithAliases;
return this;
}

public Map<String, Set<String>> getExpandedIdsWithAliases() {
return this.expandedIdsWithAliases;
public Map<String, Set<String>> getExpandedModelIdsWithAliases() {
return this.expandedModelIdsWithAliases;
}

public Builder setModelSizeStatsByModelId(Map<String, TrainedModelSizeStats> modelSizeStatsByModelId) {
Expand All @@ -276,36 +274,86 @@ public Builder setInferenceStatsByModelId(Map<String, InferenceStats> inferenceS
* @param assignmentStatsMap map of model_id to assignment stats
* @return the builder with inference stats map updated and assignment stats map set
*/
public Builder setDeploymentStatsByModelId(Map<String, AssignmentStats> assignmentStatsMap) {
public Builder setDeploymentStatsByDeploymentId(Map<String, AssignmentStats> assignmentStatsMap) {
this.assignmentStatsMap = assignmentStatsMap;
if (inferenceStatsMap == null) {
inferenceStatsMap = Maps.newHashMapWithExpectedSize(assignmentStatsMap.size());
}
assignmentStatsMap.forEach(
(modelId, assignmentStats) -> inferenceStatsMap.put(modelId, assignmentStats.getOverallInferenceStats())
);
return this;
}

public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
expandedIdsWithAliases.keySet().forEach(id -> {
TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(id);
IngestStats ingestStats = ingestStatsMap.get(id);
InferenceStats inferenceStats = inferenceStatsMap.get(id);
AssignmentStats assignmentStats = assignmentStatsMap.get(id);
trainedModelStats.add(
new TrainedModelStats(
id,
modelSizeStats,
ingestStats,
ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
inferenceStats,
assignmentStats
)
);
/**
 * Builds the response: one {@code TrainedModelStats} entry per deployment for
 * deployed models, and a single entry (with no assignment stats) for models
 * that are not deployed. Results are sorted by model id, then deployment id.
 *
 * @param modelToDeploymentIds map of model_id to the ids of its deployments;
 *                             a model absent from this map is treated as not deployed
 * @return the assembled response
 */
public Response build(Map<String, Set<String>> modelToDeploymentIds) {
    int numResponses = expandedModelIdsWithAliases.size();
    // plus an extra response for every deployment after
    // the first per model
    for (var entry : modelToDeploymentIds.entrySet()) {
        assert expandedModelIdsWithAliases.containsKey(entry.getKey()); // model id
        assert entry.getValue().size() > 0; // must have a deployment
        numResponses += entry.getValue().size() - 1;
    }

    if (inferenceStatsMap == null) {
        inferenceStatsMap = Collections.emptyMap();
    }

    List<TrainedModelStats> trainedModelStats = new ArrayList<>(numResponses);
    expandedModelIdsWithAliases.keySet().forEach(modelId -> {
        if (modelToDeploymentIds.containsKey(modelId) == false) { // not deployed
            TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId);
            IngestStats ingestStats = ingestStatsMap.get(modelId);
            InferenceStats inferenceStats = inferenceStatsMap.get(modelId);
            trainedModelStats.add(
                new TrainedModelStats(
                    modelId,
                    modelSizeStats,
                    ingestStats,
                    ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
                    inferenceStats,
                    null // no assignment stats for undeployed models
                )
            );
        } else {
            for (var deploymentId : modelToDeploymentIds.get(modelId)) {
                AssignmentStats assignmentStats = assignmentStatsMap.get(deploymentId);
                if (assignmentStats == null) {
                    // stats for this deployment have not been collected; skip it
                    continue;
                }
                InferenceStats inferenceStats = assignmentStats.getOverallInferenceStats();
                IngestStats ingestStats = ingestStatsMap.get(deploymentId);
                if (ingestStats == null) {
                    // look up by model id
                    ingestStats = ingestStatsMap.get(modelId);
                }
                TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(modelId);
                trainedModelStats.add(
                    new TrainedModelStats(
                        modelId,
                        modelSizeStats,
                        ingestStats,
                        ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
                        inferenceStats,
                        assignmentStats
                    )
                );
            }
        }
    });

    // Sort first by model id then by deployment id
    trainedModelStats.sort((modelStats1, modelStats2) -> {
        var comparison = modelStats1.getModelId().compareTo(modelStats2.getModelId());
        if (comparison == 0) {
            var deploymentId1 = modelStats1.getDeploymentStats() == null
                ? null
                : modelStats1.getDeploymentStats().getDeploymentId();
            // BUG FIX: deploymentId2 was previously derived from modelStats1,
            // so the tie-break always compared a deployment id with itself.
            var deploymentId2 = modelStats2.getDeploymentStats() == null
                ? null
                : modelStats2.getDeploymentStats().getDeploymentId();

            assert deploymentId1 != null && deploymentId2 != null
                : "2 results for model " + modelStats1.getModelId() + " both should have deployment stats";

            comparison = deploymentId1.compareTo(deploymentId2);
        }
        return comparison;
    });
    return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.client.Response;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;

public class MultipleDeploymentsIT extends PyTorchModelRestTestCase {

@SuppressWarnings("unchecked")
public void testDeployModelMultipleTimes() throws IOException {
String baseModelId = "base-model";
createPassThroughModel(baseModelId);
putModelDefinition(baseModelId);
putVocabulary(List.of("these", "are", "my", "words"), baseModelId);
putAllModelParts(baseModelId);

String forSearch = "for-search";
startWithDeploymentId(baseModelId, forSearch);
Expand All @@ -35,12 +41,141 @@ public void testDeployModelMultipleTimes() throws IOException {
inference = infer("my words", forIngest);
assertOK(inference);

// TODO
// assertInferenceCount(1, forSearch);
// assertInferenceCount(2, forIngest);
assertInferenceCount(1, forSearch);
assertInferenceCount(2, forIngest);

stopDeployment(forSearch);
stopDeployment(forIngest);

Response statsResponse = getTrainedModelStats("_all");
Map<String, Object> stats = entityAsMap(statsResponse);
List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
assertThat(stats.toString(), trainedModelStats, hasSize(2));

for (var statsMap : trainedModelStats) {
// no deployment stats when the deployment is stopped
assertNull(stats.toString(), statsMap.get("deployment_stats"));
}
}

@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked")
public void testGetStats() throws IOException {
    // Two models that are never deployed.
    String undeployedModel1 = "undeployed_1";
    putAllModelParts(undeployedModel1);
    String undeployedModel2 = "undeployed_2";
    putAllModelParts(undeployedModel2);

    String modelWith1Deployment = "model-with-1-deployment";
    putAllModelParts(modelWith1Deployment);

    String modelWith2Deployments = "model-with-2-deployments";
    putAllModelParts(modelWith2Deployments);
    String forSearchDeployment = "for-search";
    startWithDeploymentId(modelWith2Deployments, forSearchDeployment);
    String forIngestDeployment = "for-ingest";
    startWithDeploymentId(modelWith2Deployments, forIngestDeployment);

    // deployment Id is the same as model
    startDeployment(modelWith1Deployment);

    {
        Map<String, Object> stats = entityAsMap(getTrainedModelStats("_all"));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(
            List.of(
                new Tuple<>(undeployedModel1, null),
                new Tuple<>(undeployedModel2, null),
                new Tuple<>(modelWith1Deployment, modelWith1Deployment),
                new Tuple<>(modelWith2Deployments, forSearchDeployment),
                new Tuple<>(modelWith2Deployments, forIngestDeployment)
            ),
            trainedModelStats,
            true
        );

        // check the sorted order: by model id, then by deployment id.
        // NOTE: assertEquals takes (expected, actual) in that order.
        assertEquals("lang_ident_model_1", trainedModelStats.get(0).get("model_id"));
        assertEquals(modelWith1Deployment, trainedModelStats.get(1).get("model_id"));
        assertEquals(modelWith1Deployment, MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(1)));
        assertEquals(modelWith2Deployments, trainedModelStats.get(2).get("model_id"));
        assertEquals(forIngestDeployment, MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(2)));
        assertEquals(modelWith2Deployments, trainedModelStats.get(3).get("model_id"));
        assertEquals(forSearchDeployment, MapHelper.dig("deployment_stats.deployment_id", trainedModelStats.get(3)));
        assertEquals(undeployedModel1, trainedModelStats.get(4).get("model_id"));
        assertEquals(undeployedModel2, trainedModelStats.get(5).get("model_id"));
    }
    {
        Map<String, Object> stats = entityAsMap(getTrainedModelStats(modelWith1Deployment));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(List.of(new Tuple<>(modelWith1Deployment, modelWith1Deployment)), trainedModelStats);
    }
    {
        Map<String, Object> stats = entityAsMap(getTrainedModelStats(modelWith2Deployments));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(
            List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)),
            trainedModelStats
        );
    }
    {
        Map<String, Object> stats = entityAsMap(getTrainedModelStats(forIngestDeployment));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(List.of(new Tuple<>(modelWith2Deployments, forIngestDeployment)), trainedModelStats);
    }
    {
        // wildcard model id matching
        Map<String, Object> stats = entityAsMap(getTrainedModelStats("model-with-*"));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(
            List.of(
                new Tuple<>(modelWith1Deployment, modelWith1Deployment),
                new Tuple<>(modelWith2Deployments, forSearchDeployment),
                new Tuple<>(modelWith2Deployments, forIngestDeployment)
            ),
            trainedModelStats
        );
    }
    {
        // wildcard deployment id matching
        Map<String, Object> stats = entityAsMap(getTrainedModelStats("for-*"));
        List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");
        checkExpectedStats(
            List.of(new Tuple<>(modelWith2Deployments, forSearchDeployment), new Tuple<>(modelWith2Deployments, forIngestDeployment)),
            trainedModelStats
        );
    }
}

// Convenience overload: checks the expected (model id, deployment id) pairs
// without the extra built-in lang_ident model that "_all" requests include.
private void checkExpectedStats(List<Tuple<String, String>> modelDeploymentPairs, List<Map<String, Object>> trainedModelStats) {
    checkExpectedStats(modelDeploymentPairs, trainedModelStats, false);
}

/**
 * Asserts that {@code trainedModelStats} contains exactly the expected
 * (model id, deployment id) pairs. A {@code null} deployment id in a pair
 * means the model is expected to have no deployment stats.
 *
 * @param modelDeploymentPairs expected pairs of model id (v1) and deployment id (v2)
 * @param trainedModelStats    the "trained_model_stats" maps from the response
 * @param plusOneForLangIdent  true when the response also contains the
 *                             built-in lang_ident model (e.g. for "_all")
 */
private void checkExpectedStats(
    List<Tuple<String, String>> modelDeploymentPairs,
    List<Map<String, Object>> trainedModelStats,
    boolean plusOneForLangIdent
) {
    // Key each expected pair by modelId + deploymentId; a null deployment id
    // concatenates as "null", matching the null extraction below.
    var concatenatedIds = new HashSet<String>();
    modelDeploymentPairs.forEach(t -> concatenatedIds.add(t.v1() + t.v2()));

    int expectedSize = modelDeploymentPairs.size();
    if (plusOneForLangIdent) {
        expectedSize++;
    }
    // BUG FIX: expected and actual were swapped (assertEquals(message, expected, actual)).
    assertEquals(trainedModelStats.toString(), expectedSize, trainedModelStats.size());
    for (var tmStats : trainedModelStats) {
        String modelId = (String) tmStats.get("model_id");
        String deploymentId = (String) XContentMapValues.extractValue("deployment_stats.deployment_id", tmStats);
        concatenatedIds.remove(modelId + deploymentId);
    }

    assertThat("Missing stats for " + concatenatedIds, concatenatedIds, empty());
}

// Creates a complete pass-through model: configuration first, then the model
// definition, then a small vocabulary (order matters — the model must exist
// before its parts can be uploaded).
private void putAllModelParts(String modelId) throws IOException {
    createPassThroughModel(modelId);
    putModelDefinition(modelId);
    putVocabulary(List.of("these", "are", "my", "words"), modelId);
}

private void putModelDefinition(String modelId) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public void setLogging() throws IOException {
"logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.action" : "DEBUG"
}}""");
client().performRequest(loggingSettings);
}
Expand Down Expand Up @@ -122,15 +123,24 @@ protected void assertAllocationCount(String modelId, int expectedAllocationCount

@SuppressWarnings("unchecked")
/**
 * Asserts that the deployment with the given id reports exactly
 * {@code expectedCount} inferences, summed across all nodes.
 * Fails if no stats entry for the deployment is found.
 */
protected void assertInferenceCount(int expectedCount, String deploymentId) throws IOException {
    Response statsResponse = getTrainedModelStats(deploymentId);
    Map<String, Object> stats = entityAsMap(statsResponse);
    List<Map<String, Object>> trainedModelStats = (List<Map<String, Object>>) stats.get("trained_model_stats");

    boolean deploymentFound = false;
    for (var statsMap : trainedModelStats) {
        var deploymentStats = (Map<String, Object>) XContentMapValues.extractValue("deployment_stats", statsMap);
        // find the matching deployment; an entry may lack deployment stats
        // entirely (e.g. the model is not deployed), so guard against null
        if (deploymentStats != null && deploymentId.equals(deploymentStats.get("deployment_id"))) {
            List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue("nodes", deploymentStats);
            int inferenceCount = sumInferenceCountOnNodes(nodes);
            assertEquals(stats.toString(), expectedCount, inferenceCount);
            deploymentFound = true;
            break;
        }
    }

    assertTrue("No deployment stats found for deployment [" + deploymentId + "]", deploymentFound);
}

protected int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public void testStoreModelViaChunkedPersisterWithNodeInfo() throws IOException {
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
Collections.emptySet(),
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
Expand Down Expand Up @@ -184,6 +185,7 @@ public void testStoreModelViaChunkedPersisterWithoutNodeInfo() throws IOExceptio
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
Collections.emptySet(),
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
Expand Down

0 comments on commit ed955d5

Please sign in to comment.