Skip to content

Commit

Permalink
Remote: Use parameters instead of thread-local storage to provide tra…
Browse files Browse the repository at this point in the history
…cing metadata. (Part 5)

Change MissingDigestsFinder#findMissingDigests and RemoteExecutionClient#executeRemotely to use RemoteActionExecutionContext.

Removed all the usages of io.grpc.Context in the client code.

Fixed the regression about NetworkTime introduced by bazelbuild@bc54c64.

PiperOrigin-RevId: 354479787
  • Loading branch information
Googler authored and philwo committed Mar 15, 2021
1 parent c378d9d commit 71e35b1
Show file tree
Hide file tree
Showing 37 changed files with 660 additions and 775 deletions.
Expand Up @@ -30,13 +30,10 @@
import com.google.devtools.build.lib.buildeventstream.PathConverter;
import com.google.devtools.build.lib.collect.ImmutableIterable;
import com.google.devtools.build.lib.remote.common.MissingDigestsFinder;
import com.google.devtools.build.lib.remote.common.NetworkTime;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContextImpl;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.vfs.Path;
import io.grpc.Context;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
import java.io.IOException;
Expand Down Expand Up @@ -161,7 +158,9 @@ private static List<PathMetadata> processQueryResult(
*/
private ListenableFuture<ImmutableIterable<PathMetadata>> queryRemoteCache(
ImmutableList<ListenableFuture<PathMetadata>> allPaths) throws Exception {
Context ctx = TracingMetadataUtils.contextWithMetadata(buildRequestId, commandId, "bes-upload");
RequestMetadata metadata =
TracingMetadataUtils.buildMetadata(buildRequestId, commandId, "bes-upload");
RemoteActionExecutionContext context = RemoteActionExecutionContext.create(metadata);

List<PathMetadata> knownRemotePaths = new ArrayList<>(allPaths.size());
List<PathMetadata> filesToQuery = new ArrayList<>();
Expand All @@ -181,7 +180,7 @@ private ListenableFuture<ImmutableIterable<PathMetadata>> queryRemoteCache(
return Futures.immediateFuture(ImmutableIterable.from(knownRemotePaths));
}
return Futures.transform(
ctx.call(() -> missingDigestsFinder.findMissingDigests(digestsToQuery)),
missingDigestsFinder.findMissingDigests(context, digestsToQuery),
(missingDigests) -> {
List<PathMetadata> filesToQueryUpdated = processQueryResult(missingDigests, filesToQuery);
return ImmutableIterable.from(Iterables.concat(knownRemotePaths, filesToQueryUpdated));
Expand All @@ -197,8 +196,7 @@ private ListenableFuture<List<PathMetadata>> uploadLocalFiles(
ImmutableIterable<PathMetadata> allPaths) {
RequestMetadata metadata =
TracingMetadataUtils.buildMetadata(buildRequestId, commandId, "bes-upload");
RemoteActionExecutionContext context =
new RemoteActionExecutionContextImpl(metadata, new NetworkTime());
RemoteActionExecutionContext context = RemoteActionExecutionContext.create(metadata);

ImmutableList.Builder<ListenableFuture<PathMetadata>> allPathsUploaded =
ImmutableList.builder();
Expand Down
Expand Up @@ -208,7 +208,6 @@ void shutdown() {
* boolean)} instead.
*/
@Deprecated
@VisibleForTesting
public ListenableFuture<Void> uploadBlobAsync(
RemoteActionExecutionContext context, HashCode hash, Chunker chunker, boolean forceUpload) {
Digest digest =
Expand Down
Expand Up @@ -19,13 +19,15 @@
import build.bazel.remote.execution.v2.ExecuteResponse;
import build.bazel.remote.execution.v2.ExecutionGrpc;
import build.bazel.remote.execution.v2.ExecutionGrpc.ExecutionBlockingStub;
import build.bazel.remote.execution.v2.RequestMetadata;
import build.bazel.remote.execution.v2.WaitExecutionRequest;
import com.google.common.base.Preconditions;
import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
import com.google.devtools.build.lib.remote.Retrier.Backoff;
import com.google.devtools.build.lib.remote.common.OperationObserver;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.common.RemoteExecutionClient;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
Expand Down Expand Up @@ -71,9 +73,9 @@ public ExperimentalGrpcRemoteExecutor(
this.retrier = retrier;
}

private ExecutionBlockingStub executionBlockingStub() {
private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata) {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
.withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(remoteOptions.remoteTimeout.getSeconds(), SECONDS);
}
Expand Down Expand Up @@ -310,11 +312,16 @@ static ExecuteResponse extractResponseOrThrowIfError(Operation operation) throws
}

@Override
public ExecuteResponse executeRemotely(ExecuteRequest request, OperationObserver observer)
public ExecuteResponse executeRemotely(
RemoteActionExecutionContext context, ExecuteRequest request, OperationObserver observer)
throws IOException, InterruptedException {
Execution execution =
new Execution(
request, observer, retrier, callCredentialsProvider, this::executionBlockingStub);
request,
observer,
retrier,
callCredentialsProvider,
() -> this.executionBlockingStub(context.getRequestMetadata()));
return execution.start();
}

Expand Down
Expand Up @@ -56,7 +56,6 @@
import com.google.devtools.build.lib.remote.util.Utils;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
Expand Down Expand Up @@ -122,9 +121,11 @@ private int computeMaxMissingBlobsDigestsPerMessage() {
return (options.maxOutboundMessageSize - overhead) / digestSize;
}

private ContentAddressableStorageFutureStub casFutureStub() {
private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionContext context) {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
new NetworkTimeInterceptor(context::getNetworkTime))
.withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}
Expand Down Expand Up @@ -176,7 +177,8 @@ public static boolean isRemoteCacheOptions(RemoteOptions options) {
}

@Override
public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(Iterable<Digest> digests) {
public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(
RemoteActionExecutionContext context, Iterable<Digest> digests) {
if (Iterables.isEmpty(digests)) {
return Futures.immediateFuture(ImmutableSet.of());
}
Expand All @@ -187,13 +189,13 @@ public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(Iterable<Digest
for (Digest digest : digests) {
requestBuilder.addBlobDigests(digest);
if (requestBuilder.getBlobDigestsCount() == maxMissingBlobsDigestsPerMessage) {
getMissingDigestCalls.add(getMissingDigests(requestBuilder.build()));
getMissingDigestCalls.add(getMissingDigests(context, requestBuilder.build()));
requestBuilder.clearBlobDigests();
}
}

if (requestBuilder.getBlobDigestsCount() > 0) {
getMissingDigestCalls.add(getMissingDigests(requestBuilder.build()));
getMissingDigestCalls.add(getMissingDigests(context, requestBuilder.build()));
}

ListenableFuture<ImmutableSet<Digest>> success =
Expand All @@ -209,7 +211,7 @@ public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(Iterable<Digest
},
MoreExecutors.directExecutor());

RequestMetadata requestMetadata = TracingMetadataUtils.fromCurrentContext();
RequestMetadata requestMetadata = context.getRequestMetadata();
return Futures.catchingAsync(
success,
RuntimeException.class,
Expand All @@ -226,10 +228,9 @@ public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(Iterable<Digest
}

private ListenableFuture<FindMissingBlobsResponse> getMissingDigests(
FindMissingBlobsRequest request) {
Context ctx = Context.current();
RemoteActionExecutionContext context, FindMissingBlobsRequest request) {
return Utils.refreshIfUnauthenticatedAsync(
() -> retrier.executeAsync(() -> ctx.call(() -> casFutureStub().findMissingBlobs(request))),
() -> retrier.executeAsync(() -> casFutureStub(context).findMissingBlobs(request)),
callCredentialsProvider);
}

Expand Down
Expand Up @@ -18,11 +18,13 @@
import build.bazel.remote.execution.v2.ExecuteResponse;
import build.bazel.remote.execution.v2.ExecutionGrpc;
import build.bazel.remote.execution.v2.ExecutionGrpc.ExecutionBlockingStub;
import build.bazel.remote.execution.v2.RequestMetadata;
import build.bazel.remote.execution.v2.WaitExecutionRequest;
import com.google.common.base.Preconditions;
import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.remote.common.OperationObserver;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.common.RemoteExecutionClient;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.remote.util.Utils;
Expand Down Expand Up @@ -55,9 +57,9 @@ public GrpcRemoteExecutor(
this.retrier = retrier;
}

private ExecutionBlockingStub execBlockingStub() {
private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata) {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
.withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
.withCallCredentials(callCredentialsProvider.getCallCredentials());
}

Expand Down Expand Up @@ -105,7 +107,8 @@ private ExecuteResponse getOperationResponse(Operation op) throws IOException {
* trigger a retry of the Execute call, resulting in a new Operation.
* */
@Override
public ExecuteResponse executeRemotely(ExecuteRequest request, OperationObserver observer)
public ExecuteResponse executeRemotely(
RemoteActionExecutionContext context, ExecuteRequest request, OperationObserver observer)
throws IOException, InterruptedException {
// Execute has two components: the Execute call and (optionally) the WaitExecution call.
// This is the simple flow without any errors:
Expand Down Expand Up @@ -149,9 +152,9 @@ public ExecuteResponse executeRemotely(ExecuteRequest request, OperationObserver
WaitExecutionRequest.newBuilder()
.setName(operation.get().getName())
.build();
replies = execBlockingStub().waitExecution(wr);
replies = execBlockingStub(context.getRequestMetadata()).waitExecution(wr);
} else {
replies = execBlockingStub().execute(request);
replies = execBlockingStub(context.getRequestMetadata()).execute(request);
}
try {
while (replies.hasNext()) {
Expand Down
Expand Up @@ -19,7 +19,6 @@
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Context;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
Expand All @@ -30,7 +29,6 @@
/** The ClientInterceptor used to track network time. */
public class NetworkTimeInterceptor implements ClientInterceptor {

public static final Context.Key<NetworkTime> CONTEXT_KEY = Context.key("remote-network-time");
private final Supplier<NetworkTime> networkTimeSupplier;

public NetworkTimeInterceptor(Supplier<NetworkTime> networkTimeSupplier) {
Expand Down
Expand Up @@ -33,15 +33,12 @@
import com.google.devtools.build.lib.profiler.ProfilerTask;
import com.google.devtools.build.lib.profiler.SilentCloseable;
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
import com.google.devtools.build.lib.remote.common.NetworkTime;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContextImpl;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.remote.util.Utils;
import com.google.devtools.build.lib.sandbox.SandboxHelpers;
import com.google.devtools.build.lib.vfs.Path;
import io.grpc.Context;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -167,48 +164,42 @@ private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue me
if (download == null) {
RequestMetadata requestMetadata =
TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId());
RemoteActionExecutionContext remoteActionExecutionContext =
new RemoteActionExecutionContextImpl(requestMetadata, new NetworkTime());
Context ctx = TracingMetadataUtils.contextWithMetadata(requestMetadata);
Context prevCtx = ctx.attach();
try {
Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
download = remoteCache.downloadFile(remoteActionExecutionContext, path, digest);
downloadsInProgress.put(path, download);
Futures.addCallback(
download,
new FutureCallback<Void>() {
@Override
public void onSuccess(Void v) {
synchronized (lock) {
downloadsInProgress.remove(path);
downloadedPaths.add(path);
}

try {
path.chmod(0755);
} catch (IOException e) {
logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path);
}
RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata);

Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
download = remoteCache.downloadFile(context, path, digest);
downloadsInProgress.put(path, download);
Futures.addCallback(
download,
new FutureCallback<Void>() {
@Override
public void onSuccess(Void v) {
synchronized (lock) {
downloadsInProgress.remove(path);
downloadedPaths.add(path);
}

@Override
public void onFailure(Throwable t) {
synchronized (lock) {
downloadsInProgress.remove(path);
}
try {
path.delete();
} catch (IOException e) {
logger.atWarning().withCause(e).log(
"Failed to delete output file after incomplete download: %s", path);
}
try {
path.chmod(0755);
} catch (IOException e) {
logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path);
}
},
MoreExecutors.directExecutor());
} finally {
ctx.detach(prevCtx);
}
}

@Override
public void onFailure(Throwable t) {
synchronized (lock) {
downloadsInProgress.remove(path);
}
try {
path.delete();
} catch (IOException e) {
logger.atWarning().withCause(e).log(
"Failed to delete output file after incomplete download: %s", path);
}
}
},
MoreExecutors.directExecutor());
}
return download;
}
Expand Down
Expand Up @@ -188,7 +188,8 @@ private void uploadOutputs(
digests.addAll(digestToFile.keySet());
digests.addAll(digestToBlobs.keySet());

ImmutableSet<Digest> digestsToUpload = getFromFuture(cacheProtocol.findMissingDigests(digests));
ImmutableSet<Digest> digestsToUpload =
getFromFuture(cacheProtocol.findMissingDigests(context, digests));
ImmutableList.Builder<ListenableFuture<Void>> uploads = ImmutableList.builder();
for (Digest digest : digestsToUpload) {
Path file = digestToFile.get(digest);
Expand Down
Expand Up @@ -62,7 +62,7 @@ public void ensureInputsPresent(
Iterable<Digest> allDigests =
Iterables.concat(merkleTree.getAllDigests(), additionalInputs.keySet());
ImmutableSet<Digest> missingDigests =
getFromFuture(cacheProtocol.findMissingDigests(allDigests));
getFromFuture(cacheProtocol.findMissingDigests(context, allDigests));

List<ListenableFuture<Void>> uploadFutures = new ArrayList<>();
for (Digest missingDigest : missingDigests) {
Expand Down

0 comments on commit 71e35b1

Please sign in to comment.