Skip to content

Commit

Permalink
[ML] track inference model feature usage per node (#79752)
Browse files Browse the repository at this point in the history
This adds feature usage tracking for deployed inference models. The
models are tracked under the existing, inference feature and contain
context related to the model ID. I decided to track the feature via the
allocation task to keep the logic similar between allocation tasks and
licensed persistent tasks. closes:
#76452
  • Loading branch information
benwtrent committed Oct 26, 2021
1 parent 5f350af commit 5ffc4b5
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
"model-inference",
License.OperationMode.PLATINUM
);
public static final LicensedFeature.Persistent ML_PYTORCH_MODEL_INFERENCE_FEATURE = LicensedFeature.persistent(
MachineLearningField.ML_FEATURE_FAMILY,
"pytorch-model-inference",
License.OperationMode.PLATINUM
);

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand All @@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction(
ClusterService clusterService,
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver
IndexNameExpressionResolver indexNameExpressionResolver,
XPackLicenseState licenseState
) {
super(
CreateTrainedModelAllocationAction.NAME,
Expand All @@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction(
clusterService,
deploymentManager,
transportService.getTaskManager(),
threadPool
threadPool,
licenseState
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
Expand Down Expand Up @@ -52,6 +53,7 @@

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.elasticsearch.xpack.ml.MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE;

public class TrainedModelAllocationNodeService implements ClusterStateListener {

Expand All @@ -65,6 +67,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
private final Map<String, TrainedModelDeploymentTask> modelIdToTask;
private final ThreadPool threadPool;
private final Deque<TrainedModelDeploymentTask> loadingModels;
private final XPackLicenseState licenseState;
private volatile Scheduler.Cancellable scheduledFuture;
private volatile boolean stopped;
private volatile String nodeId;
Expand All @@ -74,14 +77,16 @@ public TrainedModelAllocationNodeService(
ClusterService clusterService,
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool
ThreadPool threadPool,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
this.taskManager = taskManager;
this.modelIdToTask = new ConcurrentHashMap<>();
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand All @@ -102,7 +107,8 @@ public void beforeStop() {
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool,
String nodeId
String nodeId,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
Expand All @@ -111,6 +117,7 @@ public void beforeStop() {
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.nodeId = nodeId;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand Down Expand Up @@ -265,7 +272,17 @@ public TaskId getParentTask() {

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService);
return new TrainedModelDeploymentTask(
id,
type,
action,
parentTaskId,
headers,
params,
trainedModelAllocationNodeService,
licenseState,
ML_PYTORCH_MODEL_INFERENCE_FEATURE
);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.MlTasks;
Expand All @@ -26,6 +28,7 @@
import java.util.Map;
import java.util.Optional;


public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher {

private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class);
Expand All @@ -35,6 +38,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
private volatile boolean stopped;
private final SetOnce<String> stoppedReason = new SetOnce<>();
private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
private final XPackLicenseState licenseState;
private final LicensedFeature.Persistent licensedFeature;

public TrainedModelDeploymentTask(
long id,
Expand All @@ -43,18 +48,23 @@ public TrainedModelDeploymentTask(
TaskId parentTask,
Map<String, String> headers,
TaskParams taskParams,
TrainedModelAllocationNodeService trainedModelAllocationNodeService
TrainedModelAllocationNodeService trainedModelAllocationNodeService,
XPackLicenseState licenseState,
LicensedFeature.Persistent licensedFeature
) {
super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers);
this.params = taskParams;
this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull(
trainedModelAllocationNodeService,
"trainedModelAllocationNodeService"
);
this.licenseState = licenseState;
this.licensedFeature = licensedFeature;
}

void init(InferenceConfig inferenceConfig) {
this.inferenceConfig.set(inferenceConfig);
licensedFeature.startTracking(licenseState, "model-" + params.getModelId());
}

public String getModelId() {
Expand All @@ -71,12 +81,14 @@ public TaskParams getParams() {

public void stop(String reason) {
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
stopped = true;
stoppedReason.trySet(reason);
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
}

public void stopWithoutNotification(String reason) {
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
stoppedReason.trySet(reason);
stopped = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
Expand Down Expand Up @@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() {
deploymentManager,
taskManager,
threadPool,
NODE_ID
NODE_ID,
mock(XPackLicenseState.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;

import java.util.Map;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class TrainedModelDeploymentTaskTests extends ESTestCase {

void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class);
TrainedModelDeploymentTask task = new TrainedModelDeploymentTask(
0,
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId,
TaskId.EMPTY_TASK_ID,
Map.of(),
new StartTrainedModelDeploymentAction.TaskParams(
modelId,
randomLongBetween(1, Long.MAX_VALUE),
randomInt(5),
randomInt(5),
randomInt(5)
),
mock(TrainedModelAllocationNodeService.class),
licenseState,
feature
);

task.init(new PassThroughConfig(null, null, null));
verify(feature, times(1)).startTracking(licenseState, "model-" + modelId);
method.accept(task);
verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId);
}

public void testOnStopWithoutNotification() {
assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10));
}

public void testOnStop() {
assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10));
}

public void testCancelled() {
assertTrackingComplete(TrainedModelDeploymentTask::onCancelled, randomAlphaOfLength(10));
}

}

0 comments on commit 5ffc4b5

Please sign in to comment.