Skip to content

Commit

Permalink
Remote: Use AsyncTaskCache inside RemoteActionInputFetcher.
Browse files Browse the repository at this point in the history
When using the dynamic scheduler, local actions may get interrupted or
cancelled when the remote strategy is faster (e.g., remote cache hit).
Ordinarily this isn't a problem, except when the local action is sharing
a file download future with another local action. The interrupted thread
of the local action cancels the future, and causes a CancellationException
when the other local action thread tries to retrieve it.

This resolves that problem by not letting threads/callers share the same
future instance. The shared download future is only cancelled if all
callers have requested cancellation.

Fixes bazelbuild#12927.

PiperOrigin-RevId: 362009791
  • Loading branch information
coeuvre authored and philwo committed Mar 15, 2021
1 parent ccad56c commit 9d0c732
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 114 deletions.
Expand Up @@ -19,8 +19,6 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.GoogleLogger;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.actions.ActionInput;
Expand All @@ -34,17 +32,17 @@
import com.google.devtools.build.lib.profiler.SilentCloseable;
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.util.AsyncTaskCache;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.RxFutures;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.remote.util.Utils;
import com.google.devtools.build.lib.sandbox.SandboxHelpers;
import com.google.devtools.build.lib.vfs.Path;
import io.reactivex.rxjava3.core.Completable;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.concurrent.GuardedBy;

/**
* Stages output files that are stored remotely to the local filesystem.
Expand All @@ -55,17 +53,10 @@
class RemoteActionInputFetcher implements ActionInputPrefetcher {

private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
private final AsyncTaskCache.NoResult<Path> downloadCache = AsyncTaskCache.NoResult.create();

private final Object lock = new Object();

/** Set of successfully downloaded output files. */
@GuardedBy("lock")
private final Set<Path> downloadedPaths = new HashSet<>();

@VisibleForTesting
@GuardedBy("lock")
final Map<Path, ListenableFuture<Void>> downloadsInProgress = new HashMap<>();

private final String buildRequestId;
private final String commandId;
private final RemoteCache remoteCache;
Expand Down Expand Up @@ -110,11 +101,8 @@ public void prefetchFiles(

Path path = execRoot.getRelative(input.getExecPath());
synchronized (lock) {
if (downloadedPaths.contains(path)) {
continue;
}
ListenableFuture<Void> download = downloadFileAsync(path, metadata);
downloadsToWaitFor.putIfAbsent(path, download);
downloadsToWaitFor.computeIfAbsent(
path, key -> RxFutures.toListenableFuture(downloadFileAsync(path, metadata)));
}
}
}
Expand Down Expand Up @@ -143,65 +131,59 @@ public void prefetchFiles(
}

ImmutableSet<Path> downloadedFiles() {
synchronized (lock) {
return ImmutableSet.copyOf(downloadedPaths);
}
return downloadCache.getFinishedTasks();
}

/** Returns the paths whose downloads are still in progress, as reported by the download cache. */
ImmutableSet<Path> downloadsInProgress() {
return downloadCache.getInProgressTasks();
}

/** Exposes the underlying download cache so that tests can inspect its state. */
@VisibleForTesting
AsyncTaskCache.NoResult<Path> getDownloadCache() {
return downloadCache;
}

void downloadFile(Path path, FileArtifactValue metadata)
throws IOException, InterruptedException {
Utils.getFromFuture(downloadFileAsync(path, metadata));
Utils.getFromFuture(RxFutures.toListenableFuture(downloadFileAsync(path, metadata)));
}

private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
throws IOException {
synchronized (lock) {
if (downloadedPaths.contains(path)) {
return Futures.immediateFuture(null);
}
/**
 * Starts, or joins, the download of {@code path} from the remote cache.
 *
 * <p>The download is registered in {@code downloadCache} under {@code path} via {@code
 * executeIfNot}. When the download completes, {@code finalizeDownload} is run on the file; when
 * it fails or is disposed, {@code deletePartialDownload} is run instead.
 *
 * @param path local destination of the downloaded file
 * @param metadata supplies the digest, size and action id used to build the request
 * @return the {@link Completable} handed back by the download cache for this path
 */
private Completable downloadFileAsync(Path path, FileArtifactValue metadata) {
  return downloadCache.executeIfNot(
      path,
      RxFutures.toCompletable(
              () -> {
                // The request context is built lazily, i.e. only when the download is started.
                RequestMetadata tracingMetadata =
                    TracingMetadataUtils.buildMetadata(
                        buildRequestId, commandId, metadata.getActionId());
                RemoteActionExecutionContext executionContext =
                    RemoteActionExecutionContext.create(tracingMetadata);

                Digest fileDigest =
                    DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());

                return remoteCache.downloadFile(executionContext, path, fileDigest);
              },
              MoreExecutors.directExecutor())
          .doOnComplete(() -> finalizeDownload(path))
          .doOnError(failure -> deletePartialDownload(path))
          .doOnDispose(() -> deletePartialDownload(path)));
}

ListenableFuture<Void> download = downloadsInProgress.get(path);
if (download == null) {
RequestMetadata requestMetadata =
TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId());
RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata);

Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
download = remoteCache.downloadFile(context, path, digest);
downloadsInProgress.put(path, download);
Futures.addCallback(
download,
new FutureCallback<Void>() {
@Override
public void onSuccess(Void v) {
synchronized (lock) {
downloadsInProgress.remove(path);
downloadedPaths.add(path);
}

try {
path.chmod(0755);
} catch (IOException e) {
logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path);
}
}

@Override
public void onFailure(Throwable t) {
synchronized (lock) {
downloadsInProgress.remove(path);
}
try {
path.delete();
} catch (IOException e) {
logger.atWarning().withCause(e).log(
"Failed to delete output file after incomplete download: %s", path);
}
}
},
MoreExecutors.directExecutor());
}
return download;
/**
 * Marks a freshly downloaded file with mode 0755.
 *
 * <p>A failure to change the mode is logged as a warning and otherwise ignored.
 *
 * @param path the downloaded file
 */
private void finalizeDownload(Path path) {
  try {
    path.chmod(0755);
  } catch (IOException chmodFailure) {
    logger.atWarning().withCause(chmodFailure).log("Failed to chmod 755 on %s", path);
  }
}

/**
 * Removes whatever was written to {@code path} by a download that did not finish.
 *
 * <p>A failure to delete the file is logged as a warning and otherwise ignored.
 *
 * @param path the file left behind by the incomplete download
 */
private void deletePartialDownload(Path path) {
  try {
    path.delete();
  } catch (IOException deleteFailure) {
    logger
        .atWarning()
        .withCause(deleteFailure)
        .log("Failed to delete output file after incomplete download: %s", path);
  }
}
}
Expand Up @@ -13,15 +13,21 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.util;

import com.google.common.base.Preconditions;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.collect.ImmutableSet;
import io.reactivex.rxjava3.annotations.NonNull;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleObserver;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.subjects.AsyncSubject;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

Expand All @@ -42,11 +48,13 @@
*/
@ThreadSafe
public final class AsyncTaskCache<KeyT, ValueT> {
@GuardedBy("this")
private final Object lock = new Object();

@GuardedBy("lock")
private final Map<KeyT, ValueT> finished;

@GuardedBy("this")
private final Map<KeyT, Observable<ValueT>> inProgress;
@GuardedBy("lock")
private final Map<KeyT, Execution> inProgress;

public static <KeyT, ValueT> AsyncTaskCache<KeyT, ValueT> create() {
return new AsyncTaskCache<>();
Expand All @@ -59,14 +67,14 @@ private AsyncTaskCache() {

/** Returns the set of keys for tasks that have finished. */
public ImmutableSet<KeyT> getFinishedTasks() {
synchronized (this) {
synchronized (lock) {
return ImmutableSet.copyOf(finished.keySet());
}
}

/** Returns the set of keys for tasks that are still executing. */
public ImmutableSet<KeyT> getInProgressTasks() {
synchronized (this) {
synchronized (lock) {
return ImmutableSet.copyOf(inProgress.keySet());
}
}
Expand All @@ -82,6 +90,65 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
return execute(key, task, false);
}

/**
 * Tracks a single run of {@code task} and fans its outcome out to every caller of {@link
 * #start()}.
 *
 * <p>Each caller gets its own {@link Single} backed by a shared {@link AsyncSubject}. The number
 * of live subscriptions is counted; the underlying task is disposed only when a disposal brings
 * that count back to zero.
 */
private class Execution {
private final Single<ValueT> task;
// Relays the task's terminal event (value or error) to the Singles handed out by start().
private final AsyncSubject<ValueT> asyncSubject = AsyncSubject.create();
// Number of Singles returned by start() that have been subscribed and not yet disposed.
private final AtomicInteger subscriberCount = new AtomicInteger(0);
// Disposable of the subscription to `task`; null until the task has been subscribed.
private final AtomicReference<Disposable> taskDisposable = new AtomicReference<>(null);

Execution(Single<ValueT> task) {
this.task = task;
}

/**
 * Subscribes to the task if no subscription has been recorded yet, and returns a {@link Single}
 * that observes the shared result.
 *
 * <p>Disposing the returned Single decrements the subscriber count. When the count reaches
 * zero, the task subscription is disposed and the shared subject is terminated with a {@link
 * CancellationException}.
 */
public Single<ValueT> start() {
// NOTE(review): this check-then-subscribe is not atomic on its own. The call site visible in
// execute() invokes start() while holding `lock` -- confirm there are no unguarded callers.
if (taskDisposable.get() == null) {
task.subscribe(
new SingleObserver<ValueT>() {
@Override
public void onSubscribe(@NonNull Disposable d) {
// Keep the first Disposable so the task can be cancelled later.
taskDisposable.compareAndSet(null, d);
}

@Override
public void onSuccess(@NonNull ValueT value) {
// AsyncSubject only emits once completed, so publish the value and then complete.
asyncSubject.onNext(value);
asyncSubject.onComplete();
}

@Override
public void onError(@NonNull Throwable e) {
asyncSubject.onError(e);
}
});
}

return Single.fromObservable(asyncSubject)
.doOnSubscribe(d -> subscriberCount.incrementAndGet())
.doOnDispose(
() -> {
// Only the disposal that brings the count to zero cancels the shared task.
if (subscriberCount.decrementAndGet() == 0) {
Disposable d = taskDisposable.get();
if (d != null) {
d.dispose();
}
// NOTE(review): a caller that obtained a Single from start() but had not yet subscribed
// when this runs would presumably observe this CancellationException -- confirm intended.
asyncSubject.onError(new CancellationException("disposed"));
}
});
}
}

/**
 * Returns the current number of subscribers of the in-progress task registered under {@code
 * key}, or 0 when no task is in progress for that key.
 */
public int getSubscriberCount(KeyT key) {
  synchronized (lock) {
    Execution running = inProgress.get(key);
    return running == null ? 0 : running.subscriberCount.get();
  }
}

/**
* Executes a task.
*
Expand All @@ -93,50 +160,47 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
return Single.defer(
() -> {
synchronized (this) {
synchronized (lock) {
if (!force && finished.containsKey(key)) {
return Single.just(finished.get(key));
}

finished.remove(key);

Observable<ValueT> execution =
Execution execution =
inProgress.computeIfAbsent(
key,
missingKey -> {
AtomicInteger subscribeTimes = new AtomicInteger(0);
return Single.defer(
() -> {
int times = subscribeTimes.incrementAndGet();
Preconditions.checkState(
times == 1, "Subscribed more than once to the task");
return task;
})
.doOnSuccess(
value -> {
synchronized (this) {
finished.put(key, value);
inProgress.remove(key);
}
})
.doOnError(
error -> {
synchronized (this) {
inProgress.remove(key);
}
})
.doOnDispose(
() -> {
synchronized (this) {
inProgress.remove(key);
}
})
.toObservable()
.publish()
.refCount();
return new Execution(
Single.defer(
() -> {
int times = subscribeTimes.incrementAndGet();
checkState(times == 1, "Subscribed more than once to the task");
return task;
})
.doOnSuccess(
value -> {
synchronized (lock) {
finished.put(key, value);
inProgress.remove(key);
}
})
.doOnError(
error -> {
synchronized (lock) {
inProgress.remove(key);
}
})
.doOnDispose(
() -> {
synchronized (lock) {
inProgress.remove(key);
}
}));
});

return Single.fromObservable(execution);
return execution.start();
}
});
}
Expand Down Expand Up @@ -174,5 +238,10 @@ public ImmutableSet<KeyT> getFinishedTasks() {
public ImmutableSet<KeyT> getInProgressTasks() {
return cache.getInProgressTasks();
}

/**
 * Returns count of subscribers for a task.
 *
 * <p>Delegates to {@code cache.getSubscriberCount(key)}; presumably {@code cache} is the wrapped
 * {@code AsyncTaskCache}, whose implementation returns 0 when no task is in progress for {@code
 * key}.
 */
public int getSubscriberCount(KeyT key) {
return cache.getSubscriberCount(key);
}
}
}

0 comments on commit 9d0c732

Please sign in to comment.