Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/135966.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 135966
summary: Ensure queued `AbstractRunnables` are notified when executor stops
area: Machine Learning
type: bug
issues:
- 134651
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.ml.inference.pytorch;

import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.SuppressForbidden;
Expand Down Expand Up @@ -35,7 +36,7 @@ public enum RequestPriority {
* A Runnable sorted first by RequestPriority then a tie breaker which in
* most cases will be the insertion order
*/
public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
protected record OrderedRunnable(RequestPriority priority, long tieBreaker, AbstractRunnable runnable)
implements
Comparable<OrderedRunnable>,
Runnable {
Expand All @@ -53,7 +54,7 @@ public int compareTo(OrderedRunnable o) {
public void run() {
runnable.run();
}
};
}

private final int queueCapacity;

Expand Down Expand Up @@ -93,7 +94,7 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma
}

// PriorityBlockingQueue::offer always returns true
queue.offer(new OrderedRunnable(priority, tieBreaker, contextHolder.preserveContext(command)));
queue.offer(new OrderedRunnable(priority, tieBreaker, (AbstractRunnable) contextHolder.preserveContext(command)));
if (isShutdown()) {
// the worker shutdown during this function
notifyQueueRunnables();
Expand All @@ -104,4 +105,10 @@ public synchronized void executeWithPriority(AbstractInitializableRunnable comma
public synchronized void execute(Runnable command) {
throw new UnsupportedOperationException("use executeWithPriority");
}

@Override
protected void notifyIfAbstractRunnable(OrderedRunnable orderedRunnable, Exception ex, String msg) {
// The runnable contained within OrderedRunnable is always an AbstractRunnable, so no need to check the type
notifyAbstractRunnable(ex, msg, orderedRunnable.runnable());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,12 @@ public void start() {
running.set(false);
}
}

notifyQueueRunnables();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
// If we're throwing an exception, shutdown() may not have been called, so call it here
shutdown();
notifyQueueRunnables();
Runnable onComplete = onCompletion.get();
if (onComplete != null) {
onComplete.run();
Expand All @@ -155,20 +156,22 @@ public synchronized void notifyQueueRunnables() {
format("[%s] notifying [%d] queued requests that have not been processed before shutdown", processName, queue.size())
);

List<Runnable> notExecuted = new ArrayList<>();
List<T> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);

String msg = "unable to process as " + processName + " worker service has shutdown";
Exception ex = error.get();
for (Runnable runnable : notExecuted) {
if (runnable instanceof AbstractRunnable ar) {
if (ex != null) {
ar.onFailure(ex);
} else {
ar.onRejection(new EsRejectedExecutionException(msg, true));
}
}
for (T runnable : notExecuted) {
notifyIfAbstractRunnable(runnable, error.get(), "unable to process as " + processName + " worker service has shutdown");
}
}
}

protected static void notifyAbstractRunnable(Exception ex, String msg, AbstractRunnable ar) {
if (ex != null) {
ar.onFailure(ex);
} else {
ar.onRejection(new EsRejectedExecutionException(msg, true));
}
}

protected abstract void notifyIfAbstractRunnable(T runnable, Exception ex, String msg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about having AbstractProcessWorkerExecutorService contain the logic for notifyIfAbstractRunnable and relying on an abstract method like getAsAbstractRunnable which either returns the AbstractRunnable or null. That way child classes don't need to call back into the parent to do the notification, they would only return the abstract runnable if it was one.

Something like:

    protected abstract AbstractRunnable getAsAbstractRunnable(T runnable);

    private void notifyIfAbstractRunnable(T runnable, Exception ex, String msg) {
        var abstractRunnable = getAsAbstractRunnable(runnable);
        if (abstractRunnable != null) {
            notifyAbstractRunnable(ex, msg, abstractRunnable);
        }
    }

Then PriorityProcessWorkerExecutorService would have something like:

    @Override
    protected AbstractRunnable getAsAbstractRunnable(OrderedRunnable orderedRunnable, Exception ex, String msg) {
        // The runnable contained within OrderedRunnable is always an AbstractRunnable, so no need to check the type
        return orderedRunnable.runnable();
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,11 @@ public synchronized void execute(Runnable command) {
throw new EsRejectedExecutionException(processName + " queue is full. Unable to execute command", false);
}
}

@Override
protected void notifyIfAbstractRunnable(Runnable runnable, Exception ex, String msg) {
if (runnable instanceof AbstractRunnable ar) {
notifyAbstractRunnable(ex, msg, ar);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

import static org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService.RequestPriority;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;

public class PriorityProcessWorkerExecutorServiceTests extends ESTestCase {

Expand Down Expand Up @@ -177,6 +179,49 @@ public void testOrderedRunnables_MixedPriorities() {
}
}

public void testNotifyQueueRunnables_notifiesAllQueuedRunnables() throws InterruptedException {
notifyQueueRunnables(false);
}

public void testNotifyQueueRunnables_notifiesAllQueuedRunnables_withError() throws InterruptedException {
notifyQueueRunnables(true);
}

private void notifyQueueRunnables(boolean withError) {
int queueSize = 10;
var executor = createProcessWorkerExecutorService(queueSize);

List<QueueDrainingRunnable> runnables = new ArrayList<>(queueSize);
// First fill the queue
for (int i = 0; i < queueSize; ++i) {
QueueDrainingRunnable runnable = new QueueDrainingRunnable();
runnables.add(runnable);
executor.executeWithPriority(runnable, RequestPriority.NORMAL, i);
}

assertThat(executor.queueSize(), is(queueSize));

// Set the executor to be stopped
if (withError) {
executor.shutdownNowWithError(new Exception());
} else {
executor.shutdownNow();
}

// Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped
executor.start();

// Confirm that all the runnables were notified
for (QueueDrainingRunnable runnable : runnables) {
assertThat(runnable.initialized, is(true));
assertThat(runnable.hasBeenRun, is(false));
assertThat(runnable.hasBeenRejected, not(withError));
assertThat(runnable.hasBeenFailed, is(withError));
}

assertThat(executor.queueSize(), is(0));
}

private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(int queueSize) {
return new PriorityProcessWorkerExecutorService(
threadPool.getThreadContext(),
Expand Down Expand Up @@ -244,4 +289,32 @@ public void init() {
// do nothing
}
}

private static class QueueDrainingRunnable extends AbstractInitializableRunnable {

private boolean initialized = false;
private boolean hasBeenRun = false;
private boolean hasBeenRejected = false;
private boolean hasBeenFailed = false;

@Override
public void init() {
initialized = true;
}

@Override
public void onRejection(Exception e) {
hasBeenRejected = true;
}

@Override
public void onFailure(Exception e) {
hasBeenFailed = true;
}

@Override
protected void doRun() {
hasBeenRun = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.not;

public class ProcessWorkerExecutorServiceTests extends ESTestCase {

Expand Down Expand Up @@ -137,7 +142,105 @@ public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
assertThat(e.getMessage(), containsString("future error"));
}

public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables() throws InterruptedException {
notifyQueueRunnables(false);
}

public void testNotifyQueueRunnables_notifiesAllQueuedAbstractRunnables_withError() throws InterruptedException {
notifyQueueRunnables(true);
}

private void notifyQueueRunnables(boolean withError) {
int entries = 10;
var executor = createExecutorService();

List<QueueDrainingRunnable> abstractRunnables = new ArrayList<>();
// First fill the queue with both AbstractRunnable and Runnable
for (int i = 0; i < entries; ++i) {
QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable();
abstractRunnables.add(abstractRunnable);
executor.execute(abstractRunnable);
Runnable runnable = () -> fail("Should not be invoked");
executor.execute(runnable);
}

assertThat(executor.queueSize(), is(entries * 2));

// Set the executor to be stopped
if (withError) {
executor.shutdownNowWithError(new Exception());
} else {
executor.shutdownNow();
}

// Start the executor, which will cause notifyQueueRunnables() to be called immediately since the executor is already stopped
executor.start();

// Confirm that all the abstract runnables were notified
for (QueueDrainingRunnable runnable : abstractRunnables) {
assertThat(runnable.initialized, is(true));
assertThat(runnable.hasBeenRun, is(false));
assertThat(runnable.hasBeenRejected, not(withError));
assertThat(runnable.hasBeenFailed, is(withError));
}

assertThat(executor.queueSize(), is(0));
}

public void testQueuedAbstractRunnablesAreNotified_whenRunnableFutureEncountersError() {
var executor = createExecutorService();

// First queue a RunnableFuture that will stop the executor then throw an Exception wrapping an error when it completes
Error expectedError = new Error("Expected");
FutureTask<Void> runnableFuture = new FutureTask<>(() -> { throw new Exception(expectedError); });
executor.execute(runnableFuture);

// Then queue an AbstractRunnable that should be notified if it's still in the queue when the start() method returns
QueueDrainingRunnable abstractRunnable = new QueueDrainingRunnable();
executor.execute(abstractRunnable);

// Start the executor and expect the error to be thrown and the executor to be marked as shut down
Error error = assertThrows(Error.class, executor::start);
assertThat(error, is(expectedError));
assertThat(executor.isShutdown(), is(true));
assertThat(executor.isTerminated(), is(true));

// Confirm that all the abstract runnable was notified
assertThat(abstractRunnable.initialized, is(true));
assertThat(abstractRunnable.hasBeenRun, is(false));
assertThat(abstractRunnable.hasBeenRejected, is(true));
assertThat(abstractRunnable.hasBeenFailed, is(false));
}

private ProcessWorkerExecutorService createExecutorService() {
return new ProcessWorkerExecutorService(threadPool.getThreadContext(), TEST_PROCESS, QUEUE_SIZE);
}

private static class QueueDrainingRunnable extends AbstractInitializableRunnable {

private boolean initialized = false;
private boolean hasBeenRun = false;
private boolean hasBeenRejected = false;
private boolean hasBeenFailed = false;

@Override
public void init() {
initialized = true;
}

@Override
public void onRejection(Exception e) {
hasBeenRejected = true;
}

@Override
public void onFailure(Exception e) {
hasBeenFailed = true;
}

@Override
protected void doRun() {
hasBeenRun = true;
}
}
}