[ML] PyTorchModelIT: investigate single processor mode failures (#91547)
Start the priority process worker after the model has been loaded. 
Adds debug logging and some renaming for clarity.
davidkyle committed Nov 14, 2022
1 parent fe0e5c4 commit fb10f12
Showing 3 changed files with 43 additions and 34 deletions.
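
Why the ordering matters: the priority process worker drains queued inference requests as soon as it starts, so starting it before loadModel has completed can let an early request run against a process whose model state is still streaming in. The sketch below is a self-contained toy in plain Java (not the Elasticsearch classes; the class name, queue, and simulated load are illustrative) showing the reordering the commit makes in startAndLoad: start the worker only once the load has finished.

import java.util.concurrent.*;

public class StartWorkerAfterLoad {
    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        BlockingQueue<Runnable> pending = new LinkedBlockingQueue<>();
        CompletableFuture<Void> modelLoaded = new CompletableFuture<>();

        // An inference request queued before the model is ready.
        pending.add(() -> System.out.println("inference request executed"));

        // Kick off the (simulated) model load.
        pool.submit(() -> {
            sleepQuietly(200);                 // stand-in for streaming model state
            modelLoaded.complete(null);
        });

        // Old order: submitting the queue-draining worker here, before the load
        // finishes, lets the queued request execute against an unloaded model.
        // New order: wait for the load to complete, then start the worker.
        modelLoaded.get();
        pool.submit(() -> drain(pending)).get();

        pool.shutdown();
    }

    private static void drain(BlockingQueue<Runnable> pending) {
        Runnable request;
        while ((request = pending.poll()) != null) {
            request.run();
        }
    }

    private static void sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
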
PyTorchModelIT.java
@@ -124,6 +124,7 @@ public void setLogging() throws IOException {
{"persistent" : {
"logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
"logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"
}}""");
client().performRequest(loggingSettings);
@@ -139,6 +140,7 @@ public void cleanup() throws Exception {
"logger.org.elasticsearch.xpack.ml.inference.assignment": null,
"logger.org.elasticsearch.xpack.ml.inference.deployment" : null,
"logger.org.elasticsearch.xpack.ml.process.logging" : null,
"logger.org.elasticsearch.xpack.ml.inference.pytorch" : null,
"xpack.ml.max_lazy_ml_nodes": null
}}""");
client().performRequest(loggingSettings);
@@ -293,14 +295,14 @@ public void testDeploymentStats() throws IOException {

@SuppressWarnings("unchecked")
public void testLiveDeploymentStats() throws IOException {
- String modelA = "model_a";
+ String modelId = "live_deployment_stats";

- createTrainedModel(modelA);
- putVocabulary(List.of("once", "twice"), modelA);
- putModelDefinition(modelA);
- startDeployment(modelA, AllocationStatus.State.FULLY_ALLOCATED.toString());
+ createTrainedModel(modelId);
+ putVocabulary(List.of("once", "twice"), modelId);
+ putModelDefinition(modelId);
+ startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
{
- Response noInferenceCallsStatsResponse = getTrainedModelStats(modelA);
+ Response noInferenceCallsStatsResponse = getTrainedModelStats(modelId);
List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(noInferenceCallsStatsResponse).get(
"trained_model_stats"
);
@@ -321,17 +323,17 @@ public void testLiveDeploymentStats() throws IOException {
}
}

infer("once", modelA);
infer("twice", modelA);
infer("once", modelId);
infer("twice", modelId);
// By making this request 3 times at least one of the responses must come from the cache because the cluster has 2 ML nodes
infer("three times", modelA);
infer("three times", modelA);
infer("three times", modelA);
infer("three times", modelId);
infer("three times", modelId);
infer("three times", modelId);
{
- Response postInferStatsResponse = getTrainedModelStats(modelA);
+ Response postInferStatsResponse = getTrainedModelStats(modelId);
List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(postInferStatsResponse).get("trained_model_stats");
assertThat(stats, hasSize(1));
assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelA));
assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelId));
assertThat(XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0)), equalTo((int) RAW_MODEL_SIZE));
List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
"deployment_stats.nodes",
@@ -748,12 +750,12 @@ public void testStoppingDeploymentShouldTriggerRebalance() throws Exception {
}}""");
client().performRequest(loggingSettings);

- String modelId1 = "model_1";
+ String modelId1 = "stopping_triggers_rebalance_1";
createTrainedModel(modelId1);
putModelDefinition(modelId1);
putVocabulary(List.of("these", "are", "my", "words"), modelId1);

- String modelId2 = "model_2";
+ String modelId2 = "stopping_triggers_rebalance_2";
createTrainedModel(modelId2);
putModelDefinition(modelId2);
putVocabulary(List.of("these", "are", "my", "words"), modelId2);
@@ -826,12 +828,12 @@ public void testStartDeployment_GivenNoProcessorsLeft_AndLazyStartEnabled() thro
}}""");
client().performRequest(loggingSettings);

- String modelId1 = "model_1";
+ String modelId1 = "start_no_processors_left_lazy_start_1";
createTrainedModel(modelId1);
putModelDefinition(modelId1);
putVocabulary(List.of("these", "are", "my", "words"), modelId1);

- String modelId2 = "model_2";
+ String modelId2 = "start_no_processors_left_lazy_start_2";
createTrainedModel(modelId2);
putModelDefinition(modelId2);
putVocabulary(List.of("these", "are", "my", "words"), modelId2);
DeploymentManager.java
@@ -107,7 +107,7 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
stats.timingStats().getAverage(),
stats.timingStatsExcludingCacheHits().getAverage(),
stats.lastUsed(),
- processContext.executorService.queueSize() + stats.numberOfPendingResults(),
+ processContext.priorityProcessWorker.queueSize() + stats.numberOfPendingResults(),
stats.errorCount(),
stats.cacheHitCount(),
processContext.rejectedExecutionCount.intValue(),
@@ -130,7 +130,7 @@ ProcessContext addProcessContext(Long id, ProcessContext processContext) {
private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
logger.info("[{}] Starting model deployment", task.getModelId());

- ProcessContext processContext = new ProcessContext(task, executorServiceForProcess);
+ ProcessContext processContext = new ProcessContext(task);
if (addProcessContext(task.getId(), processContext) != null) {
finalListener.onFailure(
ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getModelId())
@@ -232,7 +232,10 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
try {
processContext.startProcess();
- processContext.loadModel(modelLocation, loadedListener);
+ processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
+ processContext.startPriorityProcessWorker();
+ loadedListener.onResponse(success);
+ }, loadedListener::onFailure));
} catch (Exception e) {
loadedListener.onFailure(e);
}
@@ -332,13 +335,13 @@ public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, Actio
executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
}

- public void executePyTorchAction(
+ void executePyTorchAction(
ProcessContext processContext,
PriorityProcessWorkerExecutorService.RequestPriority priority,
AbstractPyTorchAction<?> action
) {
try {
- processContext.getExecutorService().executeWithPriority(action, priority, action.getRequestId());
+ processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
} catch (EsRejectedExecutionException e) {
processContext.getRejectedExecutionCount().incrementAndGet();
action.onFailure(e);
@@ -376,21 +379,21 @@ class ProcessContext {
private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
private final PyTorchResultProcessor resultProcessor;
private final PyTorchStateStreamer stateStreamer;
- private final PriorityProcessWorkerExecutorService executorService;
+ private final PriorityProcessWorkerExecutorService priorityProcessWorker;
private volatile Instant startTime;
private volatile Integer numThreadsPerAllocation;
private volatile Integer numAllocations;
private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
private final AtomicInteger timeoutCount = new AtomicInteger();

- ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
+ ProcessContext(TrainedModelDeploymentTask task) {
this.task = Objects.requireNonNull(task);
resultProcessor = new PyTorchResultProcessor(task.getModelId(), threadSettings -> {
this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
this.numAllocations = threadSettings.numAllocations();
});
- this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
- this.executorService = new PriorityProcessWorkerExecutorService(
+ this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
+ this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
threadPool.getThreadContext(),
"inference process",
task.getParams().getQueueCapacity()
@@ -404,12 +407,15 @@ PyTorchResultProcessor getResultProcessor() {
synchronized void startProcess() {
process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
startTime = Instant.now();
- executorServiceForProcess.submit(executorService::start);
}

+ void startPriorityProcessWorker() {
+ executorServiceForProcess.submit(priorityProcessWorker::start);
+ }

synchronized void stopProcess() {
resultProcessor.stop();
- executorService.shutdown();
+ priorityProcessWorker.shutdown();
try {
if (process.get() == null) {
return;
@@ -430,7 +436,7 @@ private Consumer<String> onProcessCrash() {
return reason -> {
logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
resultProcessor.stop();
- executorService.shutdownWithError(new IllegalStateException(reason));
+ priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
processContextByAllocation.remove(task.getId());
if (nlpTaskProcessor.get() != null) {
nlpTaskProcessor.get().close();
@@ -441,6 +447,7 @@ private Consumer<String> onProcessCrash() {

void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
if (modelLocation instanceof IndexLocation indexLocation) {
logger.debug("[{}] loading model state", task.getModelId());
process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
} else {
listener.onFailure(
@@ -455,8 +462,8 @@ AtomicInteger getTimeoutCount() {
}

// accessor used for mocking in tests
- PriorityProcessWorkerExecutorService getExecutorService() {
- return executorService;
+ PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
+ return priorityProcessWorker;
}

// accessor used for mocking in tests
DeploymentManagerTests.java
@@ -79,19 +79,19 @@ public void testRejectedExecution() {
mock(PyTorchProcessFactory.class)
);

- PriorityProcessWorkerExecutorService executorService = new PriorityProcessWorkerExecutorService(
+ PriorityProcessWorkerExecutorService priorityExecutorService = new PriorityProcessWorkerExecutorService(
tp.getThreadContext(),
"test reject",
10
);
- executorService.shutdown();
+ priorityExecutorService.shutdown();

AtomicInteger rejectedCount = new AtomicInteger();

DeploymentManager.ProcessContext context = mock(DeploymentManager.ProcessContext.class);
PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
when(context.getResultProcessor()).thenReturn(resultProcessor);
- when(context.getExecutorService()).thenReturn(executorService);
+ when(context.getPriorityProcessWorker()).thenReturn(priorityExecutorService);
when(context.getRejectedExecutionCount()).thenReturn(rejectedCount);

deploymentManager.addProcessContext(taskId, context);
