Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a round robin grpc balancer #554

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/pkg/balancer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
"gcp_balancer.go",
"gcp_interceptor.go",
"gcp_picker.go",
"roundrobin.go",
],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer",
visibility = ["//visibility:public"],
Expand Down
64 changes: 64 additions & 0 deletions go/pkg/balancer/roundrobin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package balancer

import (
"context"
"errors"
"io"
"sync/atomic"

"google.golang.org/grpc"
)

// RRConnPool is a pool of *grpc.ClientConn that are selected in a round-robin fashion.
//
// The pool defines its own Invoke and NewStream methods, which delegate each call to the
// connection returned by Conn, and its own Close method, which closes every pooled connection.
type RRConnPool struct {
	// NOTE(review): NewRRConnPool constructs the pool without assigning these embedded
	// interface values; the methods declared on *RRConnPool are the ones that get called.
	grpc.ClientConnInterface
	io.Closer

	// conns holds the pooled connections in the order they were dialed.
	conns []*grpc.ClientConn
	// idx is the round-robin counter that Conn increments on every call.
	idx uint32 // access via sync/atomic
}

// Conn returns the next connection from the pool in a round-robin fashion.
//
// The shared counter is advanced atomically, so concurrent callers each observe a
// distinct counter value. The pool must hold at least one connection: the modulo
// below panics on an empty pool.
func (p *RRConnPool) Conn() *grpc.ClientConn {
	size := uint32(len(p.conns))
	next := atomic.AddUint32(&p.idx, 1)
	return p.conns[next%size]
}

// Close closes all connections in the pool.
func (p *RRConnPool) Close() error {
var errs error
for _, conn := range p.conns {
if err := conn.Close(); err != nil {
errs = errors.Join(errs, err)
}
}
return errs

Check failure on line 35 in go/pkg/balancer/roundrobin.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func errors.Join(errs ...error) error (wrapcheck)
}

// Invoke picks up a connection from the pool and delegates the call to it.
func (p *RRConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
return p.Conn().Invoke(ctx, method, args, reply, opts...)

Check failure on line 40 in go/pkg/balancer/roundrobin.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...google.golang.org/grpc.CallOption) error (wrapcheck)
}

// NewStream picks up a connection from the pool and delegates the call to it.
func (p *RRConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return p.Conn().NewStream(ctx, desc, method, opts...)

Check failure on line 45 in go/pkg/balancer/roundrobin.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).NewStream(ctx context.Context, desc *google.golang.org/grpc.StreamDesc, method string, opts ...google.golang.org/grpc.CallOption) (google.golang.org/grpc.ClientStream, error) (wrapcheck)
}

// DialFunc defines the dial function used in creating the pool. NewRRConnPool calls it
// once for every connection it adds to the pool, passing along the context it was given.
type DialFunc func(ctx context.Context) (*grpc.ClientConn, error)

// NewRRConnPool makes a new instance of the round-robin connection pool and dials
// poolSize connections using the provided dialFn.
//
// poolSize must be at least 1. An empty pool cannot serve any call (Conn would panic
// with a division by zero when selecting a connection), so smaller values are rejected
// with an error instead of producing an unusable pool.
//
// If any dial fails, the connections dialed so far are closed and the dial error is
// returned unchanged.
func NewRRConnPool(ctx context.Context, poolSize int, dialFn DialFunc) (*RRConnPool, error) {
	if poolSize < 1 {
		return nil, errors.New("round-robin connection pool size must be at least 1")
	}
	pool := &RRConnPool{conns: make([]*grpc.ClientConn, 0, poolSize)}
	for i := 0; i < poolSize; i++ {
		conn, err := dialFn(ctx)
		if err != nil {
			// Best-effort cleanup of the connections opened so far; the dial error is
			// the one worth reporting, so a failure to close is deliberately ignored.
			_ = pool.Close()
			return nil, err
		}
		pool.conns = append(pool.conns, conn)
	}
	return pool, nil
}
6 changes: 3 additions & 3 deletions go/pkg/client/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ func (c *Client) GetCapabilities(ctx context.Context) (res *repb.ServerCapabilit
// be determined from that; ExecutionCapabilities will always come from the main URL.
func (c *Client) GetCapabilitiesForInstance(ctx context.Context, instance string) (res *repb.ServerCapabilities, err error) {
req := &repb.GetCapabilitiesRequest{InstanceName: instance}
caps, err := c.GetBackendCapabilities(ctx, c.Connection, req)
caps, err := c.GetBackendCapabilities(ctx, c.Connection(), req)
if err != nil {
return nil, err
}
if c.CASConnection != c.Connection {
casCaps, err := c.GetBackendCapabilities(ctx, c.CASConnection, req)
if c.CASConnection() != c.Connection() {
casCaps, err := c.GetBackendCapabilities(ctx, c.CASConnection(), req)
if err != nil {
return nil, err
}
Expand Down
108 changes: 67 additions & 41 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"os"
"os/user"
Expand Down Expand Up @@ -107,6 +108,13 @@
return fmt.Sprintf("%v, authentication type (identity) used=%q", ce.Err.Error(), ce.AuthUsed)
}

// grpcClientConn is the connection type stored by Client: a value that can issue gRPC
// calls (grpc.ClientConnInterface) and can be closed (io.Closer).
//
// This is a temporary interface definition until the gcp balancer is removed in favour of
// the round-robin balancer.
type grpcClientConn interface {
	grpc.ClientConnInterface
	io.Closer
}

// Client is a client to several services, including remote execution and services used in
// conjunction with remote execution. A Client must be constructed by calling Dial() or NewClient()
// rather than attempting to assemble it directly.
Expand All @@ -128,8 +136,8 @@
//
// These fields are logically "protected" and are intended for use by extensions of Client.
Retrier *Retrier
Connection *grpc.ClientConn
CASConnection *grpc.ClientConn // Can be different from Connection a separate CAS endpoint is provided.
connection grpcClientConn
casConnection grpcClientConn
// StartupCapabilities denotes whether to load ServerCapabilities on startup.
StartupCapabilities StartupCapabilities
// LegacyExecRootRelativeOutputs denotes whether outputs are relative to the exec root.
Expand Down Expand Up @@ -208,17 +216,31 @@
DefaultRegularMode = 0644
)

// Connection returns a *grpc.ClientConn for the client's main connection.
//
// When the client holds a plain *grpc.ClientConn it is returned directly. Otherwise the
// stored connection is asserted to be a *balancer.RRConnPool and the next connection
// picked by the pool is returned; the assertion is unchecked and panics for any other
// type.
func (c *Client) Connection() *grpc.ClientConn {
	switch direct := c.connection.(type) {
	case *grpc.ClientConn:
		return direct
	default:
		rrPool := c.connection.(*balancer.RRConnPool)
		return rrPool.Conn()
	}
}

// CASConnection returns a *grpc.ClientConn for the client's CAS connection.
//
// When the client holds a plain *grpc.ClientConn it is returned directly. Otherwise the
// stored connection is asserted to be a *balancer.RRConnPool and the next connection
// picked by the pool is returned; the assertion is unchecked and panics for any other
// type.
func (c *Client) CASConnection() *grpc.ClientConn {
	switch direct := c.casConnection.(type) {
	case *grpc.ClientConn:
		return direct
	default:
		rrPool := c.casConnection.(*balancer.RRConnPool)
		return rrPool.Conn()
	}
}

// Close closes the underlying gRPC connection(s).
func (c *Client) Close() error {
// Close the channels & stop background operations.
UnifiedUploads(false).Apply(c)
UnifiedDownloads(false).Apply(c)
err := c.Connection.Close()
err := c.connection.Close()
if err != nil {
return err
}
if c.CASConnection != c.Connection {
return c.CASConnection.Close()
if c.casConnection != c.connection {
return c.casConnection.Close()

Check failure on line 243 in go/pkg/client/client.go

View workflow job for this annotation

GitHub Actions / lint

error returned from interface method should be wrapped: sig: func (io.Closer).Close() error (wrapcheck)
}
return nil
}
Expand Down Expand Up @@ -537,6 +559,12 @@
//
// If this is specified, TLSClientAuthCert must also be specified.
TLSClientAuthKey string

// RoundRobinBalancer enables the simplified gRPC balancer instead of the default one.
RoundRobinBalancer bool

// RoundRobinPoolSize specifies the pool size for the round-robin load balancer.
RoundRobinPoolSize int
}

func createGRPCInterceptor(p DialParams) *balancer.GCPInterceptor {
Expand Down Expand Up @@ -592,8 +620,8 @@
return c, nil
}

// Dial dials a given endpoint and returns the grpc connection that is established.
func Dial(ctx context.Context, endpoint string, params DialParams) (*grpc.ClientConn, AuthType, error) {
// OptsFromParams prepares a set of grpc dial options based on the provided dial params.
func OptsFromParams(ctx context.Context, params DialParams) ([]grpc.DialOption, AuthType, error) {
var authUsed AuthType

var opts []grpc.DialOption
Expand Down Expand Up @@ -661,27 +689,14 @@
}
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
}

grpcInt := createGRPCInterceptor(params)
opts = append(opts, grpc.WithDisableServiceConfig())
opts = append(opts, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, balancer.Name)))
opts = append(opts, grpc.WithUnaryInterceptor(grpcInt.GCPUnaryClientInterceptor))
opts = append(opts, grpc.WithStreamInterceptor(grpcInt.GCPStreamClientInterceptor))

conn, err := grpc.Dial(endpoint, opts...)
if err != nil {
return nil, authUsed, fmt.Errorf("couldn't dial gRPC %q: %v", endpoint, err)
}
return conn, authUsed, nil
}

// DialRaw dials a remote execution service and returns the grpc connection that is established.
// TODO(olaola): remove this overload when all clients use Dial.
func DialRaw(ctx context.Context, params DialParams) (*grpc.ClientConn, AuthType, error) {
if params.Service == "" {
return nil, UnknownAuth, fmt.Errorf("service needs to be specified")
}
log.Infof("Connecting to remote execution service %s", params.Service)
return Dial(ctx, params.Service, params)
return opts, authUsed, nil
}

// NewClient connects to a remote execution service and returns a client suitable for higher-level
Expand All @@ -695,30 +710,40 @@
}
log.Infof("Connecting to remote execution instance %s", instanceName)
log.Infof("Connecting to remote execution service %s", params.Service)
conn, authUsed, err := Dial(ctx, params.Service, params)
casConn := conn
if params.CASService != "" && params.CASService != params.Service {
log.Infof("Connecting to CAS service %s", params.CASService)
casConn, authUsed, err = Dial(ctx, params.CASService, params)
}
dialOpts, authUsed, err := OptsFromParams(ctx, params)
if err != nil {
return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
return nil, fmt.Errorf("failed to prepare gRPC dial options: %v", err)
}

var conn, casConn grpcClientConn
if params.RoundRobinBalancer {
dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, params.Service, dialOpts...)
}
conn, err = balancer.NewRRConnPool(ctx, params.RoundRobinPoolSize, dial)
} else {
conn, err = grpc.Dial(params.Service, dialOpts...)
}
client, err := NewClientFromConnection(ctx, instanceName, conn, casConn, opts...)
if err != nil {
return nil, &InitError{Err: err, AuthUsed: authUsed}
return nil, fmt.Errorf("couldn't dial gRPC %q: %v", params.Service, err)
}
return client, nil
}

// NewClientFromConnection creates a client from gRPC connections to a remote execution service and a cas service.
func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn *grpc.ClientConn, opts ...Opt) (*Client, error) {
if conn == nil {
return nil, fmt.Errorf("connection to remote execution service may not be nil")
casConn = conn
if params.CASService != "" && params.CASService != params.Service {
log.Infof("Connecting to CAS service %s", params.CASService)
if params.RoundRobinBalancer {
dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, params.CASService, dialOpts...)
}
casConn, err = balancer.NewRRConnPool(ctx, params.RoundRobinPoolSize, dial)
} else {
casConn, err = grpc.Dial(params.CASService, dialOpts...)
}
}
if casConn == nil {
return nil, fmt.Errorf("connection to CAS service may not be nil")
if err != nil {
return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
}

client := &Client{
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
Expand All @@ -727,8 +752,8 @@
execution: regrpc.NewExecutionClient(conn),
operations: opgrpc.NewOperationsClient(conn),
rpcTimeouts: DefaultRPCTimeouts,
Connection: conn,
CASConnection: casConn,
connection: conn,
casConnection: casConn,
CompressedBytestreamThreshold: DefaultCompressedBytestreamThreshold,
ChunkMaxSize: chunker.DefaultChunkSize,
MaxBatchDigests: DefaultMaxBatchDigests,
Expand Down Expand Up @@ -762,6 +787,7 @@
return nil, fmt.Errorf("CASConcurrency should be at least 1")
}
client.RunBackgroundTasks(ctx)

return client, nil
}

Expand Down
29 changes: 7 additions & 22 deletions go/pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ package client
import (
"context"
"errors"
"net"
"os"
"path"
"testing"

repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
svpb "github.com/bazelbuild/remote-apis/build/bazel/semver"
"google.golang.org/grpc"
)

const (
Expand Down Expand Up @@ -204,33 +202,20 @@ func TestNewClient(t *testing.T) {
defer c.Close()
}

func TestNewClientFromConnection(t *testing.T) {
func TestNewClientRR(t *testing.T) {
t.Parallel()
ctx := context.Background()
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot listen: %v", err)
}
conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("Cannot establish gRPC connection: %v", err)
}

c, err := NewClientFromConnection(ctx, instance, conn, conn, StartupCapabilities(false))
c, err := NewClient(ctx, instance, DialParams{
Service: "server",
NoSecurity: true,
RoundRobinBalancer: true,
RoundRobinPoolSize: 25,
}, StartupCapabilities(false))
if err != nil {
t.Fatalf("Error creating client: %v", err)
}
defer c.Close()

_, err = NewClientFromConnection(ctx, instance, nil, conn, StartupCapabilities(false))
if err == nil {
t.Fatalf("Expected error got nil")
}

_, err = NewClientFromConnection(ctx, instance, conn, nil, StartupCapabilities(false))
if err == nil {
t.Fatalf("Expected error got nil")
}
}

func TestResourceName(t *testing.T) {
Expand Down
7 changes: 5 additions & 2 deletions go/pkg/fakes/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@
// NewClientConn returns a gRPC client connection to the server.
func (s *Server) NewClientConn(ctx context.Context) (*grpc.ClientConn, error) {
p := s.dialParams()
conn, _, err := client.Dial(ctx, p.Service, p)
return conn, err
opts, _, err := client.OptsFromParams(ctx, p)
if err != nil {
return nil, err

Check failure on line 90 in go/pkg/fakes/server.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func github.com/bazelbuild/remote-apis-sdks/go/pkg/client.OptsFromParams(ctx context.Context, params github.com/bazelbuild/remote-apis-sdks/go/pkg/client.DialParams) ([]google.golang.org/grpc.DialOption, github.com/bazelbuild/remote-apis-sdks/go/pkg/client.AuthType, error) (wrapcheck)
}
return grpc.Dial(p.Service, opts...)

Check failure on line 92 in go/pkg/fakes/server.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func google.golang.org/grpc.Dial(target string, opts ...google.golang.org/grpc.DialOption) (*google.golang.org/grpc.ClientConn, error) (wrapcheck)
}

func (s *Server) dialParams() rc.DialParams {
Expand Down
7 changes: 7 additions & 0 deletions go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ var (
KeepAliveTimeout = flag.Duration("grpc_keepalive_timeout", 20*time.Second, "After having pinged for keepalive check, the client waits for a duration of Timeout and if no activity is seen even after that the connection is closed. Default is 20s.")
// KeepAlivePermitWithoutStream specifies gRPCs keepalive permitWithoutStream parameter.
KeepAlivePermitWithoutStream = flag.Bool("grpc_keepalive_permit_without_stream", false, "If true, client sends keepalive pings even with no active RPCs; otherwise, doesn't send pings even if time and timeout are set. Default is false.")
// UseRoundRobinBalancer is a temporary feature flag to rollout a simplified load balancer.
// See http://go/remote-apis-sdks/issues/499
UseRoundRobinBalancer = flag.Bool("use_round_robin_balancer", false, "If true, a round-robin connection bool is used for gRPC. Otherwise, the existing load balancer is used.")
// RoundRobinBalancerPoolSize specifies the pool size for the round robin balancer.
RoundRobinBalancerPoolSize = flag.Int("round_robin_balancer_pool_size", client.DefaultMaxConcurrentRequests, "pool size for round robin grpc balacner")
)

func init() {
Expand Down Expand Up @@ -145,5 +150,7 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
TLSClientAuthKey: *TLSClientAuthKey,
MaxConcurrentRequests: uint32(*MaxConcurrentRequests),
MaxConcurrentStreams: uint32(*MaxConcurrentStreams),
RoundRobinBalancer: *UseRoundRobinBalancer,
RoundRobinPoolSize: *RoundRobinBalancerPoolSize,
}, opts...)
}
2 changes: 1 addition & 1 deletion go/pkg/tool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func (c *Client) UploadBlob(ctx context.Context, path string) error {

// UploadBlobV2 uploads a blob from the specified path into the remote cache using newer cas implementation.
func (c *Client) UploadBlobV2(ctx context.Context, path string) error {
casC, err := cas.NewClient(ctx, c.GrpcClient.Connection, c.GrpcClient.InstanceName)
casC, err := cas.NewClient(ctx, c.GrpcClient.Connection(), c.GrpcClient.InstanceName)
if err != nil {
return errors.WithStack(err)
}
Expand Down
Loading