diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
index f6f70a643764c4..3ea14306e94f82 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
@@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableSet;
import io.reactivex.rxjava3.annotations.NonNull;
import io.reactivex.rxjava3.core.Completable;
+import io.reactivex.rxjava3.core.CompletableEmitter;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleObserver;
import io.reactivex.rxjava3.disposables.Disposable;
@@ -27,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
@@ -45,26 +47,38 @@
* re-execute a finished task.
*
*
Dispose the {@link Single} to cancel to task execution.
+ *
+ *
Use {@link #shutdown} to shuts the cache down. Any in progress tasks will continue running
+ * while new tasks will be injected with {@link CancellationException}. Use {@link
+ * #awaitTermination()} after {@link #shutdown} to wait for the in progress tasks finished.
+ *
+ *
Use {@link #shutdownNow} to cancel all in progress and new tasks with exception {@link
+ * CancellationException}.
*/
@ThreadSafe
public final class AsyncTaskCache {
private final Object lock = new Object();
+ private static final int STATE_ACTIVE = 0;
+ private static final int STATE_PENDING_SHUTDOWN = 1;
+ private static final int STATE_SHUTDOWN = 2;
+
+ @GuardedBy("lock")
+ private int state = STATE_ACTIVE;
+
+ @GuardedBy("lock")
+ private final List terminationSubscriber = new ArrayList<>();
+
@GuardedBy("lock")
- private final Map finished;
+ private final Map finished = new HashMap<>();
@GuardedBy("lock")
- private final Map inProgress;
+ private final Map inProgress = new HashMap<>();
public static AsyncTaskCache create() {
return new AsyncTaskCache<>();
}
- private AsyncTaskCache() {
- this.finished = new HashMap<>();
- this.inProgress = new HashMap<>();
- }
-
/** Returns a set of keys for tasks which is finished. */
public ImmutableSet getFinishedTasks() {
synchronized (lock) {
@@ -165,6 +179,8 @@ public void onSuccess(@NonNull ValueT value) {
for (SingleObserver super ValueT> observer : ImmutableList.copyOf(observers)) {
observer.onSuccess(value);
}
+
+ maybeNotifyTermination();
}
}
}
@@ -179,6 +195,8 @@ public void onError(@NonNull Throwable error) {
for (SingleObserver super ValueT> observer : ImmutableList.copyOf(observers)) {
observer.onError(error);
}
+
+ maybeNotifyTermination();
}
}
}
@@ -197,6 +215,18 @@ void remove(SingleObserver super ValueT> observer) {
}
}
}
+
+ void cancel() {
+ synchronized (lock) {
+ if (!terminated) {
+ if (upstreamDisposable != null) {
+ upstreamDisposable.dispose();
+ }
+
+ onError(new CancellationException("cancelled"));
+ }
+ }
+ }
}
class ExecutionDisposable implements Disposable {
@@ -225,6 +255,8 @@ public boolean isDisposed() {
/**
* Executes a task.
*
+ * If the cache is already shutdown, a {@link CancellationException} will be emitted.
+ *
* @param key identifies the task.
* @param force re-execute a finished task if set to {@code true}.
* @return a {@link Single} which turns to completed once the task is finished or propagates the
@@ -234,6 +266,11 @@ public Single execute(KeyT key, Single task, boolean force) {
return Single.create(
emitter -> {
synchronized (lock) {
+ if (state != STATE_ACTIVE) {
+ emitter.onError(new CancellationException("already shutdown"));
+ return;
+ }
+
if (!force && finished.containsKey(key)) {
emitter.onSuccess(finished.get(key));
return;
@@ -273,6 +310,72 @@ public void onError(@NonNull Throwable e) {
});
}
+ /**
+ * Shuts the cache down. Any in progress tasks will continue running while new tasks will be
+ * injected with {@link CancellationException}.
+ */
+ public void shutdown() {
+ synchronized (lock) {
+ if (state == STATE_ACTIVE) {
+ state = STATE_PENDING_SHUTDOWN;
+ maybeNotifyTermination();
+ }
+ }
+ }
+
+ /** Returns a {@link Completable} which will complete once all the in progress tasks finished. */
+ public Completable awaitTermination() {
+ return Completable.create(
+ emitter -> {
+ synchronized (lock) {
+ if (state == STATE_SHUTDOWN) {
+ emitter.onComplete();
+ } else {
+ terminationSubscriber.add(emitter);
+
+ emitter.setCancellable(
+ () -> {
+ synchronized (lock) {
+ if (state != STATE_SHUTDOWN) {
+ terminationSubscriber.remove(emitter);
+ }
+ }
+ });
+ }
+ }
+ });
+ }
+
+ /**
+ * Shuts the cache down. All in progress and new tasks will be cancelled with {@link
+ * CancellationException}.
+ */
+ public void shutdownNow() {
+ shutdown();
+
+ synchronized (lock) {
+ if (state == STATE_PENDING_SHUTDOWN) {
+ for (Execution execution : ImmutableList.copyOf(inProgress.values())) {
+ execution.cancel();
+ }
+ }
+ }
+
+ awaitTermination().blockingAwait();
+ }
+
+ @GuardedBy("lock")
+ private void maybeNotifyTermination() {
+ if (state == STATE_PENDING_SHUTDOWN && inProgress.isEmpty()) {
+ state = STATE_SHUTDOWN;
+
+ for (CompletableEmitter emitter : terminationSubscriber) {
+ emitter.onComplete();
+ }
+ terminationSubscriber.clear();
+ }
+ }
+
/** An {@link AsyncTaskCache} without result. */
public static final class NoResult {
private final AsyncTaskCache> cache;
@@ -311,5 +414,28 @@ public ImmutableSet getInProgressTasks() {
public int getSubscriberCount(KeyT key) {
return cache.getSubscriberCount(key);
}
+
+ /**
+ * Shuts the cache down. Any in progress tasks will continue running while new tasks will be
+ * injected with {@link CancellationException}.
+ */
+ public void shutdown() {
+ cache.shutdown();
+ }
+
+ /**
+ * Returns a {@link Completable} which will complete once all the in progress tasks finished.
+ */
+ public Completable awaitTermination() {
+ return cache.awaitTermination();
+ }
+
+ /**
+ * Shuts the cache down. All in progress and active tasks will be cancelled with {@link
+ * CancellationException}.
+ */
+ public void shutdownNow() {
+ cache.shutdownNow();
+ }
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java
index 279c342baa7f94..b40f6f28be22ef 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java
@@ -22,6 +22,7 @@
import io.reactivex.rxjava3.observers.TestObserver;
import java.io.IOException;
import java.util.Random;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@@ -386,4 +387,82 @@ public void execute_executeWithFutureAndCancelLoop_noErrors() throws Throwable {
throw error.get();
}
}
+
+ @Test
+ public void execute_pendingShutdown_getCancellationError() {
+ AsyncTaskCache cache = AsyncTaskCache.create();
+ cache
+ .executeIfNot(
+ "key1",
+ Single.create(
+ emitter -> {
+ // never complete
+ }))
+ .test()
+ .assertNotComplete();
+ cache.shutdown();
+ cache.awaitTermination().test().assertNotComplete();
+
+ TestObserver ob = cache.executeIfNot("key2", Single.just("value2")).test();
+
+ ob.assertError(e -> e instanceof CancellationException);
+ }
+
+ @Test
+ public void execute_afterShutdown_getCancellationError() {
+ AsyncTaskCache cache = AsyncTaskCache.create();
+ cache.shutdown();
+ cache.awaitTermination().blockingAwait();
+
+ TestObserver ob = cache.executeIfNot("key", Single.just("value")).test();
+
+ ob.assertError(e -> e instanceof CancellationException);
+ }
+
+ @Test
+ public void shutdownNow_cancelInProgressTasks() {
+ AsyncTaskCache cache = AsyncTaskCache.create();
+ TestObserver ob =
+ cache
+ .executeIfNot(
+ "key",
+ Single.create(
+ emitter -> {
+ // never complete
+ }))
+ .test();
+ cache.shutdown();
+ cache.awaitTermination().test().assertNotComplete();
+ ob.assertNotComplete();
+
+ cache.shutdownNow();
+
+ ob.assertError(e -> e instanceof CancellationException);
+ cache.awaitTermination().test().assertComplete();
+ }
+
+ @Test
+ public void awaitTermination_pendingShutdown_completeAfterTaskFinished() {
+ AsyncTaskCache cache = AsyncTaskCache.create();
+ AtomicReference> emitterRef = new AtomicReference<>(null);
+ cache.executeIfNot("key", Single.create(emitterRef::set)).test().assertNotComplete();
+ assertThat(emitterRef.get()).isNotNull();
+ cache.shutdown();
+
+ TestObserver ob = cache.awaitTermination().test();
+ ob.assertNotComplete();
+ emitterRef.get().onSuccess("value");
+
+ ob.assertComplete();
+ }
+
+ @Test
+ public void awaitTermination_afterShutdown_complete() {
+ AsyncTaskCache cache = AsyncTaskCache.create();
+ cache.shutdownNow();
+
+ TestObserver ob = cache.awaitTermination().test();
+
+ ob.assertComplete();
+ }
}