Skip to content

Commit

Permalink
Add channel pool for remote execution to overcome gRPC connections li…
Browse files Browse the repository at this point in the history
…mitation.

This PR add a `ReferenceCountedChannelPool` which will create `poolSize` number of channels and round-robin across them for gRPC requests.

The `poolSize` is calculated as `jobs / 100`.

Fixes #11801.

Closes #11937.

PiperOrigin-RevId: 326619592
  • Loading branch information
coeuvre authored and Copybara-Service committed Aug 14, 2020
1 parent d65d09a commit 8d656cf
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 17 deletions.
8 changes: 7 additions & 1 deletion src/main/java/com/google/devtools/build/lib/remote/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ java_library(
exclude = [
"ExecutionStatusException.java",
"ReferenceCountedChannel.java",
"ReferenceCountedChannelPool.java",
"RemoteRetrier.java",
"RemoteRetrierUtils.java",
"Retrier.java",
Expand All @@ -40,6 +41,7 @@ java_library(
":ExecutionStatusException",
":ReferenceCountedChannel",
":Retrier",
"//src/main/java/com/google/devtools/build/lib:build-request-options",
"//src/main/java/com/google/devtools/build/lib:runtime",
"//src/main/java/com/google/devtools/build/lib/actions",
"//src/main/java/com/google/devtools/build/lib/actions:artifacts",
Expand Down Expand Up @@ -121,8 +123,12 @@ java_library(

java_library(
name = "ReferenceCountedChannel",
srcs = ["ReferenceCountedChannel.java"],
srcs = [
"ReferenceCountedChannel.java",
"ReferenceCountedChannelPool.java",
],
deps = [
"//third_party:guava",
"//third_party:netty",
"//third_party/grpc:grpc-jar",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,28 @@
public class ReferenceCountedChannel extends ManagedChannel implements ReferenceCounted {

private final ManagedChannel channel;
private final AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() {
@Override
protected void deallocate() {
channel.shutdown();
}

@Override
public ReferenceCounted touch(Object o) {
return this;
}
};
private final AbstractReferenceCounted referenceCounted;

public ReferenceCountedChannel(ManagedChannel channel) {
this(
channel,
new AbstractReferenceCounted() {
@Override
protected void deallocate() {
channel.shutdown();
}

@Override
public ReferenceCounted touch(Object o) {
return this;
}
});
}

protected ReferenceCountedChannel(
ManagedChannel channel, AbstractReferenceCounted referenceCounted) {
this.channel = channel;
this.referenceCounted = referenceCounted;
}

@Override
Expand All @@ -70,8 +78,8 @@ public ManagedChannel shutdownNow() {
}

@Override
public boolean awaitTermination(long l, TimeUnit timeUnit) throws InterruptedException {
return channel.awaitTermination(l, timeUnit);
public boolean awaitTermination(long timeout, TimeUnit timeUnit) throws InterruptedException {
return channel.awaitTermination(timeout, timeUnit);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote;

import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A wrapper around a {@link io.grpc.ManagedChannel} exposing a reference count and performing a
* round-robin load balance across a list of channels. When instantiated the reference count is 1.
* {@link ManagedChannel#shutdown()} will be called on the wrapped channel when the reference count
* reaches 0.
*
* <p>See {@link ReferenceCounted} for more information about reference counting.
*/
public class ReferenceCountedChannelPool extends ReferenceCountedChannel {

private final AtomicInteger indexTicker = new AtomicInteger();
private final ImmutableList<ManagedChannel> channels;

public ReferenceCountedChannelPool(ImmutableList<ManagedChannel> channels) {
super(
channels.get(0),
new AbstractReferenceCounted() {
@Override
protected void deallocate() {
for (ManagedChannel channel : channels) {
channel.shutdown();
}
}

@Override
public ReferenceCounted touch(Object o) {
return null;
}
});
this.channels = channels;
}

@Override
public boolean isShutdown() {
for (ManagedChannel channel : channels) {
if (!channel.isShutdown()) {
return false;
}
}
return true;
}

@Override
public boolean isTerminated() {
for (ManagedChannel channel : channels) {
if (!channel.isTerminated()) {
return false;
}
}
return true;
}

@Override
public boolean awaitTermination(long timeout, TimeUnit timeUnit) throws InterruptedException {
long endTimeNanos = System.nanoTime() + timeUnit.toNanos(timeout);
for (ManagedChannel channel : channels) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
if (awaitTimeNanos <= 0) {
break;
}
channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}

@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return getNextChannel().newCall(methodDescriptor, callOptions);
}

@Override
public String authority() {
// Assume all channels have the same authority.
return channels.get(0).authority();
}

/**
* Performs a simple round robin on the list of {@link ManagedChannel}s in the {@code channels}
* list.
*
* @see <a href="https://github.com/grpc/grpc/issues/21386#issuecomment-564742173">Suggestion from
* gRPC team.</a>
* @return A {@link ManagedChannel} that can be used for a single RPC call.
*/
private ManagedChannel getNextChannel() {
return getChannel(indexTicker.getAndIncrement());
}

private ManagedChannel getChannel(int affinity) {
int index = affinity % channels.size();
index = Math.abs(index);
// If index is the most negative int, abs(index) is still negative.

This comment has been minimized.

Copy link
@ulfjack

ulfjack Aug 15, 2020

Contributor

Nitpick: while it's true that Math.abs can return a negative result, index cannot be the most negative int here because you first perform a modulo operation two lines above. (Or can it? Maybe consider extracting this to a static method and adding some unit tests?)

if (index < 0) {
index = 0;
}
return channels.get(index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.PathFragment;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.netty.channel.unix.DomainSocketAddress;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -66,6 +68,20 @@ public static ReferenceCountedChannel createGrpcChannel(
GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors));
}

public static ReferenceCountedChannel createGrpcChannelPool(
int poolSize,
String target,
String proxyUri,
AuthAndTLSOptions authOptions,
@Nullable List<ClientInterceptor> interceptors)
throws IOException {
List<ManagedChannel> channels = new ArrayList<>();
for (int i = 0; i < poolSize; i++) {
channels.add(GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors));
}
return new ReferenceCountedChannelPool(ImmutableList.copyOf(channels));
}

public static RemoteCacheClient create(
RemoteOptions options,
@Nullable Credentials creds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploader;
import com.google.devtools.build.lib.buildeventstream.LocalFilesArtifactUploader;
import com.google.devtools.build.lib.buildtool.BuildRequest;
import com.google.devtools.build.lib.buildtool.BuildRequestOptions;
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.Reporter;
Expand Down Expand Up @@ -255,6 +256,21 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
ReferenceCountedChannel execChannel = null;
ReferenceCountedChannel cacheChannel = null;
ReferenceCountedChannel downloaderChannel = null;

int poolSize = 1;
BuildRequestOptions buildRequestOptions =
env.getOptions().getOptions(BuildRequestOptions.class);
if (buildRequestOptions != null) {
// The following calculation is based on the suggestion from comment
// https://github.com/bazelbuild/bazel/issues/11801#issuecomment-672973245
//
// The number of concurrent requests for one connection to a gRPC server is limited by
// MAX_CONCURRENT_STREAMS which is normally being 100+. We assume 50 concurrent requests for
// each connection should be fairly well. The number of connections opened by one channel is
// based on the resolved IPs of that server. We assume servers normally have 2 IPs. So the
// number of required channels is calculated as: ceil(jobs / 100).
poolSize = (int) Math.ceil((double) buildRequestOptions.jobs / 100.0);
}
if (enableRemoteExecution) {
ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
interceptors.add(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions));
Expand All @@ -264,7 +280,8 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
interceptors.add(new NetworkTime.Interceptor());
try {
execChannel =
RemoteCacheClientFactory.createGrpcChannel(
RemoteCacheClientFactory.createGrpcChannelPool(
poolSize,
remoteOptions.remoteExecutor,
remoteOptions.remoteProxy,
authAndTlsOptions,
Expand All @@ -290,7 +307,8 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
interceptors.add(new NetworkTime.Interceptor());
try {
cacheChannel =
RemoteCacheClientFactory.createGrpcChannel(
RemoteCacheClientFactory.createGrpcChannelPool(
poolSize,
remoteOptions.remoteCache,
remoteOptions.remoteProxy,
authAndTlsOptions,
Expand All @@ -313,7 +331,8 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
}
try {
downloaderChannel =
RemoteCacheClientFactory.createGrpcChannel(
RemoteCacheClientFactory.createGrpcChannelPool(
poolSize,
remoteOptions.remoteDownloader,
remoteOptions.remoteProxy,
authAndTlsOptions,
Expand Down

0 comments on commit 8d656cf

Please sign in to comment.