Skip to content

Commit

Permalink
Remove usage of gRPC Context cancellation in the remote execution cli…
Browse files Browse the repository at this point in the history
…ent.

The gRPC remote execution client frequently "converts" gRPC calls into `ListenableFuture`s by setting a `SettableFuture` in the `onCompleted` or `onError` gRPC stub callbacks. If the future has direct executor callbacks, those callbacks will execute with the gRPC Context of the freshly completed call. That is problematic if the `Context` was canceled (canceling the call `Context` is good hygiene after completing a gRPC call), and the future callback goes to make further gRPC calls.

Therefore, this change removes all usage of gRPC `Context` cancellation. It would be nice if there was instead some way to avoid leaking `Context`s between calls instead of having totally forswear `Context` cancellation. However, I can't see a good way to enforce proper isolation.

Fixes #17298.

Closes #17426.

PiperOrigin-RevId: 507730469
Change-Id: Iea74acad4592952700e41d34672f6478de509d5e
  • Loading branch information
benjaminp authored and Copybara-Service committed Feb 7, 2023
1 parent 7a9a2f8 commit ba9e2f8
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 97 deletions.
Expand Up @@ -42,8 +42,6 @@
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.remote.util.Utils;
import io.grpc.Channel;
import io.grpc.Context;
import io.grpc.Context.CancellableContext;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
Expand Down Expand Up @@ -231,7 +229,6 @@ private ListenableFuture<Void> startAsyncUpload(
ListenableFuture<Void> currUpload = newUpload.start();
currUpload.addListener(
() -> {
newUpload.cancel();
if (openedFilePermits != null) {
openedFilePermits.release();
}
Expand All @@ -249,7 +246,6 @@ private static final class AsyncUpload implements AsyncCallable<Long> {
private final String resourceName;
private final Chunker chunker;
private final ProgressiveBackoff progressiveBackoff;
private final CancellableContext grpcContext;

private long lastCommittedOffset = -1;

Expand All @@ -269,7 +265,6 @@ private static final class AsyncUpload implements AsyncCallable<Long> {
this.progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
this.resourceName = resourceName;
this.chunker = chunker;
this.grpcContext = Context.current().withCancellation();
}

ListenableFuture<Void> start() {
Expand Down Expand Up @@ -369,13 +364,11 @@ private ListenableFuture<Long> query() {
Futures.transform(
channel.withChannelFuture(
channel ->
grpcContext.call(
() ->
bsFutureStub(channel)
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder()
.setResourceName(resourceName)
.build()))),
bsFutureStub(channel)
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder()
.setResourceName(resourceName)
.build())),
QueryWriteStatusResponse::getCommittedSize,
MoreExecutors.directExecutor());
return Futures.catchingAsync(
Expand All @@ -397,18 +390,10 @@ private ListenableFuture<Long> upload(long pos) {
return channel.withChannelFuture(
channel -> {
SettableFuture<Long> uploadResult = SettableFuture.create();
grpcContext.run(
() ->
bsAsyncStub(channel)
.write(new Writer(resourceName, chunker, pos, uploadResult)));
bsAsyncStub(channel).write(new Writer(resourceName, chunker, pos, uploadResult));
return uploadResult;
});
}

void cancel() {
grpcContext.cancel(
Status.CANCELLED.withDescription("Cancelled by user").asRuntimeException());
}
}

private static final class Writer
Expand All @@ -432,6 +417,13 @@ private Writer(
@Override
public void beforeStart(ClientCallStreamObserver<WriteRequest> requestObserver) {
this.requestObserver = requestObserver;
uploadResult.addListener(
() -> {
if (uploadResult.isCancelled()) {
requestObserver.cancel("cancelled by user", null);
}
},
MoreExecutors.directExecutor());
requestObserver.setOnReadyHandler(this);
}

Expand Down
Expand Up @@ -34,6 +34,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ascii;
import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
Expand All @@ -58,11 +59,11 @@
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
Expand Down Expand Up @@ -371,81 +372,87 @@ private ListenableFuture<Long> requestRead(
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
Context.CancellableContext grpcContext = Context.current().withCancellation();
future.addListener(() -> grpcContext.cancel(null), MoreExecutors.directExecutor());
grpcContext.run(
() ->
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {
@Override
public void onNext(ReadResponse readResponse) {
ByteString data = readResponse.getData();
try {
data.writeTo(out);
} catch (IOException e) {
// Cancel the call.
throw new RuntimeException(e);
}
// reset the stall backoff because we've made progress or been kept alive
progressiveBackoff.reset();
}

@Override
public void onError(Throwable t) {
if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an
// error at
// the end of the stream.
logger.atInfo().withCause(t).log(
"ignoring error because file was fully received");
onCompleted();
return;
}
releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
} else {
future.setException(t);
}
}

@Override
public void onCompleted() {
try {
try {
out.flush();
} finally {
releaseOut();
}
if (digestSupplier != null) {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
}
future.set(rawOut.getCount());
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(rawOut.getCount())
.build(),
new ClientResponseObserver<ReadRequest, ReadResponse>() {
@Override
public void beforeStart(ClientCallStreamObserver<ReadRequest> requestStream) {
future.addListener(
() -> {
if (future.isCancelled()) {
requestStream.cancel("canceled by user", null);
}

private void releaseOut() {
if (out instanceof ZstdDecompressingOutputStream) {
try {
((ZstdDecompressingOutputStream) out).closeShallow();
} catch (IOException e) {
logger.atWarning().withCause(e).log(
"failed to cleanly close output stream");
}
}
}
}));
},
MoreExecutors.directExecutor());
}

@Override
public void onNext(ReadResponse readResponse) {
ByteString data = readResponse.getData();
try {
data.writeTo(out);
} catch (IOException e) {
// Cancel the call.
throw new VerifyException(e);
}
// reset the stall backoff because we've made progress or been kept alive
progressiveBackoff.reset();
}

@Override
public void onError(Throwable t) {
if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an
// error at
// the end of the stream.
logger.atInfo().withCause(t).log(
"ignoring error because file was fully received");
onCompleted();
return;
}
releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
} else {
future.setException(t);
}
}

@Override
public void onCompleted() {
try {
try {
out.flush();
} finally {
releaseOut();
}
if (digestSupplier != null) {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
}
future.set(rawOut.getCount());
}

private void releaseOut() {
if (out instanceof ZstdDecompressingOutputStream) {
try {
((ZstdDecompressingOutputStream) out).closeShallow();
} catch (IOException e) {
logger.atWarning().withCause(e).log("failed to cleanly close output stream");
}
}
}
});
return future;
}

Expand Down

0 comments on commit ba9e2f8

Please sign in to comment.