Skip to content

Commit

Permalink
Remote: Don't blocking-get when acquiring gRPC connections. (#14420)
Browse files Browse the repository at this point in the history
With recent change to limit the max number of gRPC connections by default, acquiring a connection could suspend a thread if there is no available connection.

gRPC calls are scheduled to a dedicated background thread pool. Workers in the thread pool are responsible to acquire the connection before starting the RPC call.

There could be a race condition that a worker thread handles some gRPC calls and then switches to a new call which will acquire new connections. If the number of connections reaches the max, the worker thread is suspended and doesn't have a chance to switch to previous calls. The connections held by previous calls are, hence, never released.

This PR changes to not use blocking get when acquiring gRPC connections.

Fixes #14363.

Closes #14416.

PiperOrigin-RevId: 416282883
  • Loading branch information
coeuvre committed Dec 14, 2021
1 parent bfc2413 commit 5aef53a
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 166 deletions.
3 changes: 2 additions & 1 deletion src/main/java/com/google/devtools/build/lib/remote/BUILD
Expand Up @@ -138,9 +138,10 @@ java_library(
],
deps = [
"//src/main/java/com/google/devtools/build/lib/remote/grpc",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//third_party:guava",
"//third_party:jsr305",
"//third_party:netty",
"//third_party:rxjava3",
"//third_party/grpc:grpc-jar",
],
)
Expand Down
Expand Up @@ -24,6 +24,7 @@
import com.google.bytestream.ByteStreamGrpc;
import com.google.bytestream.ByteStreamGrpc.ByteStreamFutureStub;
import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest;
import com.google.bytestream.ByteStreamProto.QueryWriteStatusResponse;
import com.google.bytestream.ByteStreamProto.WriteRequest;
import com.google.bytestream.ByteStreamProto.WriteResponse;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -374,7 +375,7 @@ public ReferenceCounted touch(Object o) {
private static class AsyncUpload {

private final RemoteActionExecutionContext context;
private final Channel channel;
private final ReferenceCountedChannel channel;
private final CallCredentialsProvider callCredentialsProvider;
private final long callTimeoutSecs;
private final Retrier retrier;
Expand All @@ -385,7 +386,7 @@ private static class AsyncUpload {

AsyncUpload(
RemoteActionExecutionContext context,
Channel channel,
ReferenceCountedChannel channel,
CallCredentialsProvider callCredentialsProvider,
long callTimeoutSecs,
Retrier retrier,
Expand Down Expand Up @@ -452,7 +453,7 @@ ListenableFuture<Void> start() {
MoreExecutors.directExecutor());
}

private ByteStreamFutureStub bsFutureStub() {
private ByteStreamFutureStub bsFutureStub(Channel channel) {
return ByteStreamGrpc.newFutureStub(channel)
.withInterceptors(
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()))
Expand All @@ -463,7 +464,10 @@ private ByteStreamFutureStub bsFutureStub() {
private ListenableFuture<Void> callAndQueryOnFailure(
AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
return Futures.catchingAsync(
call(committedOffset),
Futures.transform(
channel.withChannelFuture(channel -> call(committedOffset, channel)),
written -> null,
MoreExecutors.directExecutor()),
Exception.class,
(e) -> guardQueryWithSuppression(e, committedOffset, progressiveBackoff),
MoreExecutors.directExecutor());
Expand Down Expand Up @@ -500,10 +504,14 @@ private ListenableFuture<Void> query(
AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
ListenableFuture<Long> committedSizeFuture =
Futures.transform(
bsFutureStub()
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder().setResourceName(resourceName).build()),
(response) -> response.getCommittedSize(),
channel.withChannelFuture(
channel ->
bsFutureStub(channel)
.queryWriteStatus(
QueryWriteStatusRequest.newBuilder()
.setResourceName(resourceName)
.build())),
QueryWriteStatusResponse::getCommittedSize,
MoreExecutors.directExecutor());
ListenableFuture<Long> guardedCommittedSizeFuture =
Futures.catchingAsync(
Expand Down Expand Up @@ -533,14 +541,14 @@ private ListenableFuture<Void> query(
MoreExecutors.directExecutor());
}

private ListenableFuture<Void> call(AtomicLong committedOffset) {
private ListenableFuture<Long> call(AtomicLong committedOffset, Channel channel) {
CallOptions callOptions =
CallOptions.DEFAULT
.withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(callTimeoutSecs, SECONDS);
call = channel.newCall(ByteStreamGrpc.getWriteMethod(), callOptions);

SettableFuture<Void> uploadResult = SettableFuture.create();
SettableFuture<Long> uploadResult = SettableFuture.create();
ClientCall.Listener<WriteResponse> callListener =
new ClientCall.Listener<WriteResponse>() {

Expand Down Expand Up @@ -568,7 +576,7 @@ public void onMessage(WriteResponse response) {
@Override
public void onClose(Status status, Metadata trailers) {
if (status.isOk()) {
uploadResult.set(null);
uploadResult.set(committedOffset.get());
} else {
uploadResult.setException(status.asRuntimeException());
}
Expand Down
Expand Up @@ -35,12 +35,13 @@
import com.google.longrunning.Operation;
import com.google.longrunning.Operation.ResultCase;
import com.google.rpc.Status;
import io.grpc.Channel;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.reactivex.rxjava3.functions.Function;
import java.io.IOException;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -73,7 +74,7 @@ public ExperimentalGrpcRemoteExecutor(
this.retrier = retrier;
}

private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata) {
private ExecutionBlockingStub executionBlockingStub(RequestMetadata metadata, Channel channel) {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
.withCallCredentials(callCredentialsProvider.getCallCredentials())
Expand All @@ -90,7 +91,8 @@ private static class Execution {
// Count retry times for WaitExecution() calls and is reset when we receive any response from
// the server that is not an error.
private final ProgressiveBackoff waitExecutionBackoff;
private final Supplier<ExecutionBlockingStub> executionBlockingStubSupplier;
private final Function<ExecuteRequest, Iterator<Operation>> executeFunction;
private final Function<WaitExecutionRequest, Iterator<Operation>> waitExecutionFunction;

// Last response (without error) we received from server.
private Operation lastOperation;
Expand All @@ -100,14 +102,16 @@ private static class Execution {
OperationObserver observer,
RemoteRetrier retrier,
CallCredentialsProvider callCredentialsProvider,
Supplier<ExecutionBlockingStub> executionBlockingStubSupplier) {
Function<ExecuteRequest, Iterator<Operation>> executeFunction,
Function<WaitExecutionRequest, Iterator<Operation>> waitExecutionFunction) {
this.request = request;
this.observer = observer;
this.retrier = retrier;
this.callCredentialsProvider = callCredentialsProvider;
this.executeBackoff = this.retrier.newBackoff();
this.waitExecutionBackoff = new ProgressiveBackoff(this.retrier::newBackoff);
this.executionBlockingStubSupplier = executionBlockingStubSupplier;
this.executeFunction = executeFunction;
this.waitExecutionFunction = waitExecutionFunction;
}

ExecuteResponse start() throws IOException, InterruptedException {
Expand Down Expand Up @@ -168,9 +172,9 @@ ExecuteResponse execute() throws IOException {
Preconditions.checkState(lastOperation == null);

try {
Iterator<Operation> operationStream = executionBlockingStubSupplier.get().execute(request);
Iterator<Operation> operationStream = executeFunction.apply(request);
return handleOperationStream(operationStream);
} catch (StatusRuntimeException e) {
} catch (Throwable e) {
// If lastOperation is not null, we know the execution request is accepted by the server. In
// this case, we will fallback to WaitExecution() loop when the stream is broken.
if (lastOperation != null) {
Expand All @@ -188,17 +192,20 @@ ExecuteResponse waitExecution() throws IOException {
WaitExecutionRequest request =
WaitExecutionRequest.newBuilder().setName(lastOperation.getName()).build();
try {
Iterator<Operation> operationStream =
executionBlockingStubSupplier.get().waitExecution(request);
Iterator<Operation> operationStream = waitExecutionFunction.apply(request);
return handleOperationStream(operationStream);
} catch (StatusRuntimeException e) {
} catch (Throwable e) {
// A NOT_FOUND error means Operation was lost on the server, retry Execute().
//
// However, we only retry Execute() if executeBackoff should retry. Also increase the retry
// counter at the same time (done by nextDelayMillis()).
if (e.getStatus().getCode() == Code.NOT_FOUND && executeBackoff.nextDelayMillis(e) >= 0) {
lastOperation = null;
return null;
if (e instanceof StatusRuntimeException) {
StatusRuntimeException sre = (StatusRuntimeException) e;
if (sre.getStatus().getCode() == Code.NOT_FOUND
&& executeBackoff.nextDelayMillis(sre) >= 0) {
lastOperation = null;
return null;
}
}
throw new IOException(e);
}
Expand Down Expand Up @@ -321,7 +328,16 @@ public ExecuteResponse executeRemotely(
observer,
retrier,
callCredentialsProvider,
() -> this.executionBlockingStub(context.getRequestMetadata()));
(req) ->
channel.withChannelBlocking(
channel ->
this.executionBlockingStub(context.getRequestMetadata(), channel)
.execute(req)),
(req) ->
channel.withChannelBlocking(
channel ->
this.executionBlockingStub(context.getRequestMetadata(), channel)
.waitExecution(req)));
return execution.start();
}

Expand Down
Expand Up @@ -56,6 +56,7 @@
import com.google.devtools.build.lib.remote.zstd.ZstdDecompressingOutputStream;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
Expand Down Expand Up @@ -122,7 +123,8 @@ private int computeMaxMissingBlobsDigestsPerMessage() {
return (options.maxOutboundMessageSize - overhead) / digestSize;
}

private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionContext context) {
private ContentAddressableStorageFutureStub casFutureStub(
RemoteActionExecutionContext context, Channel channel) {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
Expand All @@ -131,7 +133,7 @@ private ContentAddressableStorageFutureStub casFutureStub(RemoteActionExecutionC
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}

private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) {
private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context, Channel channel) {
return ByteStreamGrpc.newStub(channel)
.withInterceptors(
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
Expand All @@ -140,7 +142,8 @@ private ByteStreamStub bsAsyncStub(RemoteActionExecutionContext context) {
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}

private ActionCacheFutureStub acFutureStub(RemoteActionExecutionContext context) {
private ActionCacheFutureStub acFutureStub(
RemoteActionExecutionContext context, Channel channel) {
return ActionCacheGrpc.newFutureStub(channel)
.withInterceptors(
TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()),
Expand Down Expand Up @@ -222,7 +225,11 @@ public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(
private ListenableFuture<FindMissingBlobsResponse> getMissingDigests(
RemoteActionExecutionContext context, FindMissingBlobsRequest request) {
return Utils.refreshIfUnauthenticatedAsync(
() -> retrier.executeAsync(() -> casFutureStub(context).findMissingBlobs(request)),
() ->
retrier.executeAsync(
() ->
channel.withChannelFuture(
channel -> casFutureStub(context, channel).findMissingBlobs(request))),
callCredentialsProvider);
}

Expand Down Expand Up @@ -254,7 +261,10 @@ public ListenableFuture<CachedActionResult> downloadActionResult(
return Utils.refreshIfUnauthenticatedAsync(
() ->
retrier.executeAsync(
() -> handleStatus(acFutureStub(context).getActionResult(request))),
() ->
handleStatus(
channel.withChannelFuture(
channel -> acFutureStub(context, channel).getActionResult(request)))),
callCredentialsProvider);
}

Expand All @@ -267,13 +277,15 @@ public ListenableFuture<Void> uploadActionResult(
retrier.executeAsync(
() ->
Futures.catchingAsync(
acFutureStub(context)
.updateActionResult(
UpdateActionResultRequest.newBuilder()
.setInstanceName(options.remoteInstanceName)
.setActionDigest(actionKey.getDigest())
.setActionResult(actionResult)
.build()),
channel.withChannelFuture(
channel ->
acFutureStub(context, channel)
.updateActionResult(
UpdateActionResultRequest.newBuilder()
.setInstanceName(options.remoteInstanceName)
.setActionDigest(actionKey.getDigest())
.setActionResult(actionResult)
.build())),
StatusRuntimeException.class,
(sre) -> Futures.immediateFailedFuture(new IOException(sre)),
MoreExecutors.directExecutor())),
Expand Down Expand Up @@ -317,18 +329,26 @@ private ListenableFuture<Void> downloadBlob(
@Nullable Supplier<Digest> digestSupplier) {
AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
ListenableFuture<Void> downloadFuture =
ListenableFuture<Long> downloadFuture =
Utils.refreshIfUnauthenticatedAsync(
() ->
retrier.executeAsync(
() ->
requestRead(
context, offset, progressiveBackoff, digest, out, digestSupplier),
channel.withChannelFuture(
channel ->
requestRead(
context,
offset,
progressiveBackoff,
digest,
out,
digestSupplier,
channel)),
progressiveBackoff),
callCredentialsProvider);

return Futures.catchingAsync(
downloadFuture,
Futures.transform(downloadFuture, bytesWritten -> null, MoreExecutors.directExecutor()),
StatusRuntimeException.class,
(e) -> Futures.immediateFailedFuture(new IOException(e)),
MoreExecutors.directExecutor());
Expand All @@ -343,17 +363,18 @@ public static String getResourceName(String instanceName, Digest digest, boolean
return resourceName + DigestUtil.toString(digest);
}

private ListenableFuture<Void> requestRead(
private ListenableFuture<Long> requestRead(
RemoteActionExecutionContext context,
AtomicLong offset,
ProgressiveBackoff progressiveBackoff,
Digest digest,
CountingOutputStream out,
@Nullable Supplier<Digest> digestSupplier) {
@Nullable Supplier<Digest> digestSupplier,
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
SettableFuture<Void> future = SettableFuture.create();
bsAsyncStub(context)
SettableFuture<Long> future = SettableFuture.create();
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
Expand Down Expand Up @@ -400,7 +421,7 @@ public void onCompleted() {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
out.flush();
future.set(null);
future.set(offset.get());
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
Expand Down
Expand Up @@ -30,6 +30,7 @@
import com.google.devtools.build.lib.remote.util.Utils;
import com.google.longrunning.Operation;
import com.google.rpc.Status;
import io.grpc.Channel;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
Expand Down Expand Up @@ -57,7 +58,7 @@ public GrpcRemoteExecutor(
this.retrier = retrier;
}

private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata) {
private ExecutionBlockingStub execBlockingStub(RequestMetadata metadata, Channel channel) {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataInterceptor(metadata))
.withCallCredentials(callCredentialsProvider.getCallCredentials());
Expand Down Expand Up @@ -152,9 +153,17 @@ public ExecuteResponse executeRemotely(
WaitExecutionRequest.newBuilder()
.setName(operation.get().getName())
.build();
replies = execBlockingStub(context.getRequestMetadata()).waitExecution(wr);
replies =
channel.withChannelBlocking(
channel ->
execBlockingStub(context.getRequestMetadata(), channel)
.waitExecution(wr));
} else {
replies = execBlockingStub(context.getRequestMetadata()).execute(request);
replies =
channel.withChannelBlocking(
channel ->
execBlockingStub(context.getRequestMetadata(), channel)
.execute(request));
}
try {
while (replies.hasNext()) {
Expand Down

0 comments on commit 5aef53a

Please sign in to comment.