// Patch overview (explanatory header placed before the first `diff --git` line; not part of any hunk).
//
// Purpose, per the changelog entry this patch adds: "Ensure queued `AbstractRunnables` are notified
// when executor stops" (pr 135966, area Machine Learning, type bug, issue 134651).
//
// Files touched:
//  * docs/changelog/135966.yaml - new changelog entry.
//  * PriorityProcessWorkerExecutorService.java - OrderedRunnable becomes `protected` and now wraps an
//    AbstractRunnable instead of a Runnable; the enqueue site casts the result of
//    contextHolder.preserveContext(command) to AbstractRunnable; a notifyIfAbstractRunnable override
//    forwards orderedRunnable.runnable() to notifyAbstractRunnable without a type check.
//  * AbstractProcessWorkerExecutorService.java - in start(), the notifyQueueRunnables() call moves out of
//    the try body into the finally block, preceded by a shutdown() call; notifyQueueRunnables() drains
//    the queue and delegates each element to the new abstract notifyIfAbstractRunnable(T, Exception, String);
//    the new static helper notifyAbstractRunnable calls onFailure(ex) when ex is non-null, otherwise
//    onRejection(new EsRejectedExecutionException(msg, true)).
//  * ProcessWorkerExecutorService.java - notifyIfAbstractRunnable override that only notifies elements
//    that are instances of AbstractRunnable.
//  * PriorityProcessWorkerExecutorServiceTests.java / ProcessWorkerExecutorServiceTests.java - new tests
//    that queue runnables, stop the executor (with or without an error), call start(), and assert each
//    queued QueueDrainingRunnable was initialized, not run, and rejected (no error) or failed (error);
//    one further test queues a FutureTask that throws an Exception wrapping an Error and asserts the
//    Error is rethrown from start() and the queued AbstractRunnable is rejected.
//
// NOTE(review): generic type arguments look stripped in this rendering (e.g. `implements Comparable, Runnable`,
//   `List notExecuted`, `FutureTask runnableFuture`) - presumably an extraction artifact; confirm against
//   the original patch before applying.
// NOTE(review): the added `(AbstractRunnable)` cast assumes preserveContext returns an AbstractRunnable for
//   this input; that method is not shown here - confirm.
// NOTE(review): in notifyQueueRunnables the added loop evaluates error.get() and builds the message string
//   on every iteration, whereas the removed code computed both once before the loop - confirm this is intended.
// NOTE(review): the new public test methods declare `throws InterruptedException` although the private helper
//   they call declares no checked exception in this diff - likely unnecessary; confirm.
diff --git a/docs/changelog/135966.yaml b/docs/changelog/135966.yaml new file mode 100644 index 0000000000000..455e41864d339 --- /dev/null +++ b/docs/changelog/135966.yaml @@ -0,0 +1,6 @@ +pr: 135966 +summary: Ensure queued `AbstractRunnables` are notified when executor stops +area: Machine Learning +type: bug +issues: + - 134651 diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java index c309aba299de7..c7089a74ff8c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.pytorch; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.SuppressForbidden; @@ -35,7 +36,7 @@ public enum RequestPriority { * A Runnable sorted first by RequestPriority then a tie breaker which in * most cases will be the insertion order */ - public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable) + protected record OrderedRunnable(RequestPriority priority, long tieBreaker, AbstractRunnable runnable) implements Comparable, Runnable { @@ -53,7 +54,7 @@ public int compareTo(OrderedRunnable o) { public void run() { runnable.run(); } - }; + } private final int queueCapacity; @@ -93,7 +94,7 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma } // PriorityBlockingQueue::offer always returns true - queue.offer(new OrderedRunnable(priority, tieBreaker, contextHolder.preserveContext(command))); + 
queue.offer(new OrderedRunnable(priority, tieBreaker, (AbstractRunnable) contextHolder.preserveContext(command))); if (isShutdown()) { // the worker shutdown during this function notifyQueueRunnables(); @@ -104,4 +105,10 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma public synchronized void execute(Runnable command) { throw new UnsupportedOperationException("use executeWithPriority"); } + + @Override + protected void notifyIfAbstractRunnable(OrderedRunnable orderedRunnable, Exception ex, String msg) { + // The runnable contained within OrderedRunnable is always an AbstractRunnable, so no need to check the type + notifyAbstractRunnable(ex, msg, orderedRunnable.runnable()); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java index 66a39bde0fe6a..c26eccdce3269 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java @@ -125,11 +125,12 @@ public void start() { running.set(false); } } - - notifyQueueRunnables(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { + // If we're throwing an exception, shutdown() may not have been called, so call it here + shutdown(); + notifyQueueRunnables(); Runnable onComplete = onCompletion.get(); if (onComplete != null) { onComplete.run(); @@ -155,20 +156,22 @@ public synchronized void notifyQueueRunnables() { format("[%s] notifying [%d] queued requests that have not been processed before shutdown", processName, queue.size()) ); - List notExecuted = new ArrayList<>(); + List notExecuted = new ArrayList<>(); queue.drainTo(notExecuted); - String msg = "unable to process as " + processName + " 
worker service has shutdown"; - Exception ex = error.get(); - for (Runnable runnable : notExecuted) { - if (runnable instanceof AbstractRunnable ar) { - if (ex != null) { - ar.onFailure(ex); - } else { - ar.onRejection(new EsRejectedExecutionException(msg, true)); - } - } + for (T runnable : notExecuted) { + notifyIfAbstractRunnable(runnable, error.get(), "unable to process as " + processName + " worker service has shutdown"); } } } + + protected static void notifyAbstractRunnable(Exception ex, String msg, AbstractRunnable ar) { + if (ex != null) { + ar.onFailure(ex); + } else { + ar.onRejection(new EsRejectedExecutionException(msg, true)); + } + } + + protected abstract void notifyIfAbstractRunnable(T runnable, Exception ex, String msg); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java index 12eddf98dec09..2194fd2a33252 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java @@ -51,4 +51,11 @@ public synchronized void execute(Runnable command) { throw new EsRejectedExecutionException(processName + " queue is full. 
Unable to execute command", false); } } + + @Override + protected void notifyIfAbstractRunnable(Runnable runnable, Exception ex, String msg) { + if (runnable instanceof AbstractRunnable ar) { + notifyAbstractRunnable(ex, msg, ar); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java index d1923ca999063..c1dd36acd4705 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java @@ -22,7 +22,9 @@ import static org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService.RequestPriority; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.not; public class PriorityProcessWorkerExecutorServiceTests extends ESTestCase { @@ -177,6 +179,49 @@ public void testOrderedRunnables_MixedPriorities() { } } + public void testNotifyQueueRunnables_notifiesAllQueuedRunnables() throws InterruptedException { + notifyQueueRunnables(false); + } + + public void testNotifyQueueRunnables_notifiesAllQueuedRunnables_withError() throws InterruptedException { + notifyQueueRunnables(true); + } + + private void notifyQueueRunnables(boolean withError) { + int queueSize = 10; + var executor = createProcessWorkerExecutorService(queueSize); + + List runnables = new ArrayList<>(queueSize); + // First fill the queue + for (int i = 0; i < queueSize; ++i) { + QueueDrainingRunnable runnable = new QueueDrainingRunnable(); + runnables.add(runnable); + executor.executeWithPriority(runnable, RequestPriority.NORMAL, i); + } + + assertThat(executor.queueSize(), 
is(queueSize)); + + // Set the executor to be stopped + if (withError) { + executor.shutdownNowWithError(new Exception()); + } else { + executor.shutdownNow(); + } + + // Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped + executor.start(); + + // Confirm that all the runnables were notified + for (QueueDrainingRunnable runnable : runnables) { + assertThat(runnable.initialized, is(true)); + assertThat(runnable.hasBeenRun, is(false)); + assertThat(runnable.hasBeenRejected, not(withError)); + assertThat(runnable.hasBeenFailed, is(withError)); + } + + assertThat(executor.queueSize(), is(0)); + } + private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(int queueSize) { return new PriorityProcessWorkerExecutorService( threadPool.getThreadContext(), @@ -244,4 +289,32 @@ public void init() { // do nothing } } + + private static class QueueDrainingRunnable extends AbstractInitializableRunnable { + + private boolean initialized = false; + private boolean hasBeenRun = false; + private boolean hasBeenRejected = false; + private boolean hasBeenFailed = false; + + @Override + public void init() { + initialized = true; + } + + @Override + public void onRejection(Exception e) { + hasBeenRejected = true; + } + + @Override + public void onFailure(Exception e) { + hasBeenFailed = true; + } + + @Override + protected void doRun() { + hasBeenRun = true; + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java index 096d0b7105ce5..71b17641b1d7d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java @@ -14,15 +14,20 @@ import 
org.elasticsearch.threadpool.ThreadPool; import org.junit.After; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.not; public class ProcessWorkerExecutorServiceTests extends ESTestCase { @@ -137,7 +142,105 @@ public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() { assertThat(e.getMessage(), containsString("future error")); } + public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables() throws InterruptedException { + notifyQueueRunnables(false); + } + + public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables_withError() throws InterruptedException { + notifyQueueRunnables(true); + } + + private void notifyQueueRunnables(boolean withError) { + int entries = 10; + var executor = createExecutorService(); + + List abstractRunnables = new ArrayList<>(); + // First fill the queue with both AbstractRunnable and Runnable + for (int i = 0; i < entries; ++i) { + QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable(); + abstractRunnables.add(abstractRunnable); + executor.execute(abstractRunnable); + Runnable runnable = () -> fail("Should not be invoked"); + executor.execute(runnable); + } + + assertThat(executor.queueSize(), is(entries * 2)); + + // Set the executor to be stopped + if (withError) { + executor.shutdownNowWithError(new Exception()); + } else { + executor.shutdownNow(); + } + + // Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped + executor.start(); 
+ + // Confirm that all the abstract runnables were notified + for (QueueDrainingRunnable runnable : abstractRunnables) { + assertThat(runnable.initialized, is(true)); + assertThat(runnable.hasBeenRun, is(false)); + assertThat(runnable.hasBeenRejected, not(withError)); + assertThat(runnable.hasBeenFailed, is(withError)); + } + + assertThat(executor.queueSize(), is(0)); + } + + public void testQueuedAbstractRunnablesAreNotified_whenRunnableFutureEncountersError() { + var executor = createExecutorService(); + + // First queue a RunnableFuture that will stop the executor then throw an Exception wrapping an error when it completes + Error expectedError = new Error("Expected"); + FutureTask runnableFuture = new FutureTask<>(() -> { throw new Exception(expectedError); }); + executor.execute(runnableFuture); + + // Then queue an AbstractRunnable that should be notified if it's still in the queue when the start() method returns + QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable(); + executor.execute(abstractRunnable); + + // Start the executor and expect the error to be thrown and the executor to be marked as shut down + Error error = assertThrows(Error.class, executor::start); + assertThat(error, is(expectedError)); + assertThat(executor.isShutdown(), is(true)); + assertThat(executor.isTerminated(), is(true)); + + // Confirm that the abstract runnable was notified + assertThat(abstractRunnable.initialized, is(true)); + assertThat(abstractRunnable.hasBeenRun, is(false)); + assertThat(abstractRunnable.hasBeenRejected, is(true)); + assertThat(abstractRunnable.hasBeenFailed, is(false)); + } + private ProcessWorkerExecutorService createExecutorService() { return new ProcessWorkerExecutorService(threadPool.getThreadContext(), TEST_PROCESS, QUEUE_SIZE); } + + private static class QueueDrainingRunnable extends AbstractInitializableRunnable { + + private boolean initialized = false; + private boolean hasBeenRun = false; + private boolean hasBeenRejected = 
false; + private boolean hasBeenFailed = false; + + @Override + public void init() { + initialized = true; + } + + @Override + public void onRejection(Exception e) { + hasBeenRejected = true; + } + + @Override + public void onFailure(Exception e) { + hasBeenFailed = true; + } + + @Override + protected void doRun() { + hasBeenRun = true; + } + } }