diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index f6f70a643764c4..3ea14306e94f82 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableSet; import io.reactivex.rxjava3.annotations.NonNull; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableEmitter; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.disposables.Disposable; @@ -27,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -45,26 +47,38 @@ * re-execute a finished task. * *

Dispose the {@link Single} to cancel to task execution. + * + *

Use {@link #shutdown} to shuts the cache down. Any in progress tasks will continue running + * while new tasks will be injected with {@link CancellationException}. Use {@link + * #awaitTermination()} after {@link #shutdown} to wait for the in progress tasks finished. + * + *

Use {@link #shutdownNow} to cancel all in progress and new tasks with exception {@link + * CancellationException}. */ @ThreadSafe public final class AsyncTaskCache { private final Object lock = new Object(); + private static final int STATE_ACTIVE = 0; + private static final int STATE_PENDING_SHUTDOWN = 1; + private static final int STATE_SHUTDOWN = 2; + + @GuardedBy("lock") + private int state = STATE_ACTIVE; + + @GuardedBy("lock") + private final List terminationSubscriber = new ArrayList<>(); + @GuardedBy("lock") - private final Map finished; + private final Map finished = new HashMap<>(); @GuardedBy("lock") - private final Map inProgress; + private final Map inProgress = new HashMap<>(); public static AsyncTaskCache create() { return new AsyncTaskCache<>(); } - private AsyncTaskCache() { - this.finished = new HashMap<>(); - this.inProgress = new HashMap<>(); - } - /** Returns a set of keys for tasks which is finished. */ public ImmutableSet getFinishedTasks() { synchronized (lock) { @@ -165,6 +179,8 @@ public void onSuccess(@NonNull ValueT value) { for (SingleObserver observer : ImmutableList.copyOf(observers)) { observer.onSuccess(value); } + + maybeNotifyTermination(); } } } @@ -179,6 +195,8 @@ public void onError(@NonNull Throwable error) { for (SingleObserver observer : ImmutableList.copyOf(observers)) { observer.onError(error); } + + maybeNotifyTermination(); } } } @@ -197,6 +215,18 @@ void remove(SingleObserver observer) { } } } + + void cancel() { + synchronized (lock) { + if (!terminated) { + if (upstreamDisposable != null) { + upstreamDisposable.dispose(); + } + + onError(new CancellationException("cancelled")); + } + } + } } class ExecutionDisposable implements Disposable { @@ -225,6 +255,8 @@ public boolean isDisposed() { /** * Executes a task. * + *

If the cache is already shutdown, a {@link CancellationException} will be emitted. + * * @param key identifies the task. * @param force re-execute a finished task if set to {@code true}. * @return a {@link Single} which turns to completed once the task is finished or propagates the @@ -234,6 +266,11 @@ public Single execute(KeyT key, Single task, boolean force) { return Single.create( emitter -> { synchronized (lock) { + if (state != STATE_ACTIVE) { + emitter.onError(new CancellationException("already shutdown")); + return; + } + if (!force && finished.containsKey(key)) { emitter.onSuccess(finished.get(key)); return; @@ -273,6 +310,72 @@ public void onError(@NonNull Throwable e) { }); } + /** + * Shuts the cache down. Any in progress tasks will continue running while new tasks will be + * injected with {@link CancellationException}. + */ + public void shutdown() { + synchronized (lock) { + if (state == STATE_ACTIVE) { + state = STATE_PENDING_SHUTDOWN; + maybeNotifyTermination(); + } + } + } + + /** Returns a {@link Completable} which will complete once all the in progress tasks finished. */ + public Completable awaitTermination() { + return Completable.create( + emitter -> { + synchronized (lock) { + if (state == STATE_SHUTDOWN) { + emitter.onComplete(); + } else { + terminationSubscriber.add(emitter); + + emitter.setCancellable( + () -> { + synchronized (lock) { + if (state != STATE_SHUTDOWN) { + terminationSubscriber.remove(emitter); + } + } + }); + } + } + }); + } + + /** + * Shuts the cache down. All in progress and new tasks will be cancelled with {@link + * CancellationException}. + */ + public void shutdownNow() { + shutdown(); + + synchronized (lock) { + if (state == STATE_PENDING_SHUTDOWN) { + for (Execution execution : ImmutableList.copyOf(inProgress.values())) { + execution.cancel(); + } + } + } + + awaitTermination().blockingAwait(); + } + + @GuardedBy("lock") + private void maybeNotifyTermination() { + if (state == STATE_PENDING_SHUTDOWN && inProgress.isEmpty()) { + state = STATE_SHUTDOWN; + + for (CompletableEmitter emitter : terminationSubscriber) { + emitter.onComplete(); + } + terminationSubscriber.clear(); + } + } + /** An {@link AsyncTaskCache} without result. */ public static final class NoResult { private final AsyncTaskCache> cache; @@ -311,5 +414,28 @@ public ImmutableSet getInProgressTasks() { public int getSubscriberCount(KeyT key) { return cache.getSubscriberCount(key); } + + /** + * Shuts the cache down. Any in progress tasks will continue running while new tasks will be + * injected with {@link CancellationException}. + */ + public void shutdown() { + cache.shutdown(); + } + + /** + * Returns a {@link Completable} which will complete once all the in progress tasks finished. + */ + public Completable awaitTermination() { + return cache.awaitTermination(); + } + + /** + * Shuts the cache down. All in progress and active tasks will be cancelled with {@link + * CancellationException}. + */ + public void shutdownNow() { + cache.shutdownNow(); + } } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java index 279c342baa7f94..b40f6f28be22ef 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java @@ -22,6 +22,7 @@ import io.reactivex.rxjava3.observers.TestObserver; import java.io.IOException; import java.util.Random; +import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -386,4 +387,82 @@ public void execute_executeWithFutureAndCancelLoop_noErrors() throws Throwable { throw error.get(); } } + + @Test + public void execute_pendingShutdown_getCancellationError() { + AsyncTaskCache cache = AsyncTaskCache.create(); + cache + .executeIfNot( + "key1", + Single.create( + emitter -> { + // never complete + })) + .test() + .assertNotComplete(); + cache.shutdown(); + cache.awaitTermination().test().assertNotComplete(); + + TestObserver ob = cache.executeIfNot("key2", Single.just("value2")).test(); + + ob.assertError(e -> e instanceof CancellationException); + } + + @Test + public void execute_afterShutdown_getCancellationError() { + AsyncTaskCache cache = AsyncTaskCache.create(); + cache.shutdown(); + cache.awaitTermination().blockingAwait(); + + TestObserver ob = cache.executeIfNot("key", Single.just("value")).test(); + + ob.assertError(e -> e instanceof CancellationException); + } + + @Test + public void shutdownNow_cancelInProgressTasks() { + AsyncTaskCache cache = AsyncTaskCache.create(); + TestObserver ob = + cache + .executeIfNot( + "key", + Single.create( + emitter -> { + // never complete + })) + .test(); + cache.shutdown(); + cache.awaitTermination().test().assertNotComplete(); + ob.assertNotComplete(); + + cache.shutdownNow(); + + ob.assertError(e -> e instanceof CancellationException); + cache.awaitTermination().test().assertComplete(); + } + + @Test + public void awaitTermination_pendingShutdown_completeAfterTaskFinished() { + AsyncTaskCache cache = AsyncTaskCache.create(); + AtomicReference> emitterRef = new AtomicReference<>(null); + cache.executeIfNot("key", Single.create(emitterRef::set)).test().assertNotComplete(); + assertThat(emitterRef.get()).isNotNull(); + cache.shutdown(); + + TestObserver ob = cache.awaitTermination().test(); + ob.assertNotComplete(); + emitterRef.get().onSuccess("value"); + + ob.assertComplete(); + } + + @Test + public void awaitTermination_afterShutdown_complete() { + AsyncTaskCache cache = AsyncTaskCache.create(); + cache.shutdownNow(); + + TestObserver ob = cache.awaitTermination().test(); + + ob.assertComplete(); + } }