From 5f5ededb0faa1941c757b1ad18975ae604bd9ce5 Mon Sep 17 00:00:00 2001 From: donalevans Date: Fri, 3 Oct 2025 12:24:32 -0700 Subject: [PATCH 1/2] Ensure queued AbstractRunnables are notified when executor stops AbstractProcessWorkerExecutorService.notifyQueueRunnables() was making an incorrect assumption that all AbstractRunnables that were submitted for execution would be queued as AbstractRunnables. However, PriorityProcessWorkerExecutorService wraps AbstractRunnables in OrderedRunnable before queueing them, and since OrderedRunnable is not an AbstractRunnable, these were skipped when notifyQueueRunnables() drained the queue, leading to potential hangs. - Refactor notifyQueueRunnables() to allow PriorityProcessWorkerExecutorService to notify the AbstractRunnable contained within queued OrderedRunnables - Ensure that notifyQueueRunnables() is called and the executor marked as shut down if an exception is thrown from start() - Add unit tests Closes #134651 --- .../PriorityProcessWorkerExecutorService.java | 13 ++- .../AbstractProcessWorkerExecutorService.java | 29 ++--- .../process/ProcessWorkerExecutorService.java | 7 ++ ...rityProcessWorkerExecutorServiceTests.java | 73 +++++++++++++ .../ProcessWorkerExecutorServiceTests.java | 103 ++++++++++++++++++ 5 files changed, 209 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java index c309aba299de7..c7089a74ff8c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.pytorch; +import 
org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.SuppressForbidden; @@ -35,7 +36,7 @@ public enum RequestPriority { * A Runnable sorted first by RequestPriority then a tie breaker which in * most cases will be the insertion order */ - public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable) + protected record OrderedRunnable(RequestPriority priority, long tieBreaker, AbstractRunnable runnable) implements Comparable, Runnable { @@ -53,7 +54,7 @@ public int compareTo(OrderedRunnable o) { public void run() { runnable.run(); } - }; + } private final int queueCapacity; @@ -93,7 +94,7 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma } // PriorityBlockingQueue::offer always returns true - queue.offer(new OrderedRunnable(priority, tieBreaker, contextHolder.preserveContext(command))); + queue.offer(new OrderedRunnable(priority, tieBreaker, (AbstractRunnable) contextHolder.preserveContext(command))); if (isShutdown()) { // the worker shutdown during this function notifyQueueRunnables(); @@ -104,4 +105,10 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma public synchronized void execute(Runnable command) { throw new UnsupportedOperationException("use executeWithPriority"); } + + @Override + protected void notifyIfAbstractRunnable(OrderedRunnable orderedRunnable, Exception ex, String msg) { + // The runnable contained within OrderedRunnable is always an AbstractRunnable, so no need to check the type + notifyAbstractRunnable(ex, msg, orderedRunnable.runnable()); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java index 66a39bde0fe6a..c26eccdce3269 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java @@ -125,11 +125,12 @@ public void start() { running.set(false); } } - - notifyQueueRunnables(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { + // If we're throwing an exception, shutdown() may not have been called, so call it here + shutdown(); + notifyQueueRunnables(); Runnable onComplete = onCompletion.get(); if (onComplete != null) { onComplete.run(); @@ -155,20 +156,22 @@ public synchronized void notifyQueueRunnables() { format("[%s] notifying [%d] queued requests that have not been processed before shutdown", processName, queue.size()) ); - List<Runnable> notExecuted = new ArrayList<>(); + List<T> notExecuted = new ArrayList<>(); queue.drainTo(notExecuted); - String msg = "unable to process as " + processName + " worker service has shutdown"; - Exception ex = error.get(); - for (Runnable runnable : notExecuted) { - if (runnable instanceof AbstractRunnable ar) { - if (ex != null) { - ar.onFailure(ex); - } else { - ar.onRejection(new EsRejectedExecutionException(msg, true)); - } - } + for (T runnable : notExecuted) { + notifyIfAbstractRunnable(runnable, error.get(), "unable to process as " + processName + " worker service has shutdown"); } } } + + protected static void notifyAbstractRunnable(Exception ex, String msg, AbstractRunnable ar) { + if (ex != null) { + ar.onFailure(ex); + } else { + ar.onRejection(new EsRejectedExecutionException(msg, true)); + } + } + + protected abstract void notifyIfAbstractRunnable(T runnable, Exception ex, String msg); } diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java index 12eddf98dec09..2194fd2a33252 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java @@ -51,4 +51,11 @@ public synchronized void execute(Runnable command) { throw new EsRejectedExecutionException(processName + " queue is full. Unable to execute command", false); } } + + @Override + protected void notifyIfAbstractRunnable(Runnable runnable, Exception ex, String msg) { + if (runnable instanceof AbstractRunnable ar) { + notifyAbstractRunnable(ex, msg, ar); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java index d1923ca999063..c1dd36acd4705 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java @@ -22,7 +22,9 @@ import static org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService.RequestPriority; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.not; public class PriorityProcessWorkerExecutorServiceTests extends ESTestCase { @@ -177,6 +179,49 @@ public void testOrderedRunnables_MixedPriorities() { } } + public void testNotifyQueueRunnables_notifiesAllQueuedRunnables() throws InterruptedException { + 
notifyQueueRunnables(false); + } + + public void testNotifyQueueRunnables_notifiesAllQueuedRunnables_withError() throws InterruptedException { + notifyQueueRunnables(true); + } + + private void notifyQueueRunnables(boolean withError) { + int queueSize = 10; + var executor = createProcessWorkerExecutorService(queueSize); + + List<QueueDrainingRunnable> runnables = new ArrayList<>(queueSize); + // First fill the queue + for (int i = 0; i < queueSize; ++i) { + QueueDrainingRunnable runnable = new QueueDrainingRunnable(); + runnables.add(runnable); + executor.executeWithPriority(runnable, RequestPriority.NORMAL, i); + } + + assertThat(executor.queueSize(), is(queueSize)); + + // Set the executor to be stopped + if (withError) { + executor.shutdownNowWithError(new Exception()); + } else { + executor.shutdownNow(); + } + + // Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped + executor.start(); + + // Confirm that all the runnables were notified + for (QueueDrainingRunnable runnable : runnables) { + assertThat(runnable.initialized, is(true)); + assertThat(runnable.hasBeenRun, is(false)); + assertThat(runnable.hasBeenRejected, not(withError)); + assertThat(runnable.hasBeenFailed, is(withError)); + } + + assertThat(executor.queueSize(), is(0)); + } + private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(int queueSize) { return new PriorityProcessWorkerExecutorService( threadPool.getThreadContext(), @@ -244,4 +289,32 @@ public void init() { // do nothing } } + + private static class QueueDrainingRunnable extends AbstractInitializableRunnable { + + private boolean initialized = false; + private boolean hasBeenRun = false; + private boolean hasBeenRejected = false; + private boolean hasBeenFailed = false; + + @Override + public void init() { + initialized = true; + } + + @Override + public void onRejection(Exception e) { + hasBeenRejected = true; + } + + @Override + public void onFailure(Exception e) 
{ + hasBeenFailed = true; + } + + @Override + protected void doRun() { + hasBeenRun = true; + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java index 096d0b7105ce5..71b17641b1d7d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java @@ -14,15 +14,20 @@ import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.not; public class ProcessWorkerExecutorServiceTests extends ESTestCase { @@ -137,7 +142,105 @@ public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() { assertThat(e.getMessage(), containsString("future error")); } + public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables() throws InterruptedException { + notifyQueueRunnables(false); + } + + public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables_withError() throws InterruptedException { + notifyQueueRunnables(true); + } + + private void notifyQueueRunnables(boolean withError) { + int entries = 10; + var executor = createExecutorService(); + + List<QueueDrainingRunnable> abstractRunnables = new ArrayList<>(); + // First fill the queue with both AbstractRunnable and Runnable + 
for (int i = 0; i < entries; ++i) { + QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable(); + abstractRunnables.add(abstractRunnable); + executor.execute(abstractRunnable); + Runnable runnable = () -> fail("Should not be invoked"); + executor.execute(runnable); + } + + assertThat(executor.queueSize(), is(entries * 2)); + + // Set the executor to be stopped + if (withError) { + executor.shutdownNowWithError(new Exception()); + } else { + executor.shutdownNow(); + } + + // Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped + executor.start(); + + // Confirm that all the abstract runnables were notified + for (QueueDrainingRunnable runnable : abstractRunnables) { + assertThat(runnable.initialized, is(true)); + assertThat(runnable.hasBeenRun, is(false)); + assertThat(runnable.hasBeenRejected, not(withError)); + assertThat(runnable.hasBeenFailed, is(withError)); + } + + assertThat(executor.queueSize(), is(0)); + } + + public void testQueuedAbstractRunnablesAreNotified_whenRunnableFutureEncountersError() { + var executor = createExecutorService(); + + // First queue a RunnableFuture that will stop the executor then throw an Exception wrapping an error when it completes + Error expectedError = new Error("Expected"); + FutureTask runnableFuture = new FutureTask<>(() -> { throw new Exception(expectedError); }); + executor.execute(runnableFuture); + + // Then queue an AbstractRunnable that should be notified if it's still in the queue when the start() method returns + QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable(); + executor.execute(abstractRunnable); + + // Start the executor and expect the error to be thrown and the executor to be marked as shut down + Error error = assertThrows(Error.class, executor::start); + assertThat(error, is(expectedError)); + assertThat(executor.isShutdown(), is(true)); + assertThat(executor.isTerminated(), is(true)); + + // Confirm that 
the abstract runnable was notified + assertThat(abstractRunnable.initialized, is(true)); + assertThat(abstractRunnable.hasBeenRun, is(false)); + assertThat(abstractRunnable.hasBeenRejected, is(true)); + assertThat(abstractRunnable.hasBeenFailed, is(false)); + } + private ProcessWorkerExecutorService createExecutorService() { return new ProcessWorkerExecutorService(threadPool.getThreadContext(), TEST_PROCESS, QUEUE_SIZE); } + + private static class QueueDrainingRunnable extends AbstractInitializableRunnable { + + private boolean initialized = false; + private boolean hasBeenRun = false; + private boolean hasBeenRejected = false; + private boolean hasBeenFailed = false; + + @Override + public void init() { + initialized = true; + } + + @Override + public void onRejection(Exception e) { + hasBeenRejected = true; + } + + @Override + public void onFailure(Exception e) { + hasBeenFailed = true; + } + + @Override + protected void doRun() { + hasBeenRun = true; + } + } } From eb9f19b9e2d9ea87de8bb1a223d7ac4222c55bfa Mon Sep 17 00:00:00 2001 From: Donal Evans Date: Fri, 3 Oct 2025 12:35:53 -0700 Subject: [PATCH 2/2] Update docs/changelog/135966.yaml --- docs/changelog/135966.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/135966.yaml diff --git a/docs/changelog/135966.yaml b/docs/changelog/135966.yaml new file mode 100644 index 0000000000000..455e41864d339 --- /dev/null +++ b/docs/changelog/135966.yaml @@ -0,0 +1,6 @@ +pr: 135966 +summary: Ensure queued `AbstractRunnables` are notified when executor stops +area: Machine Learning +type: bug +issues: + - 134651