From 1052ee65efd7b8d7b1ba0d8ad3b9c6d9699c225b Mon Sep 17 00:00:00 2001 From: Anas Sulaiman Date: Sat, 20 Apr 2024 04:44:29 +0000 Subject: [PATCH 1/2] add a round robin grpc balancer This replaces the current load balancer which was copied over from the gcp repo which uses experimental APIs from grpc-go. Depending on experimental APIs marks the SDK as experimental as well. It also complicates importing the SDK to google3 where everything must be compatbile at head. This major difference between this simple implementation and the existing one is the maximum number of streams allowed on a single connection. The existing balancer limits streams to 3 by default and allows configuring that limit. The new simpler implementation does not enforce any limit. I have tested the simple balancer by building chromium and android with and without it enabled. I observed no difference in build latency. --- go/pkg/balancer/BUILD.bazel | 1 + go/pkg/balancer/roundrobin.go | 56 +++++++++++++++++++++++++++++++++++ go/pkg/cas/client.go | 6 ++-- go/pkg/client/client.go | 43 ++++++++++++++++++++------- go/pkg/fakes/server.go | 2 +- go/pkg/flags/flags.go | 7 +++++ 6 files changed, 101 insertions(+), 14 deletions(-) create mode 100644 go/pkg/balancer/roundrobin.go diff --git a/go/pkg/balancer/BUILD.bazel b/go/pkg/balancer/BUILD.bazel index 93390ae8..ae7a4a01 100644 --- a/go/pkg/balancer/BUILD.bazel +++ b/go/pkg/balancer/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "gcp_balancer.go", "gcp_interceptor.go", "gcp_picker.go", + "roundrobin.go", ], importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer", visibility = ["//visibility:public"], diff --git a/go/pkg/balancer/roundrobin.go b/go/pkg/balancer/roundrobin.go new file mode 100644 index 00000000..b4dd8627 --- /dev/null +++ b/go/pkg/balancer/roundrobin.go @@ -0,0 +1,56 @@ +package balancer + +import ( + "context" + "errors" + "io" + "sync/atomic" + + "google.golang.org/grpc" +) + +type roundRobinConnPool struct { + grpc.ClientConnInterface + io.Closer + + conns []*grpc.ClientConn + idx uint32 // access via sync/atomic +} + +func (p *roundRobinConnPool) Conn() *grpc.ClientConn { + i := atomic.AddUint32(&p.idx, 1) + return p.conns[i%uint32(len(p.conns))] +} + +func (p *roundRobinConnPool) Close() error { + var errs error + for _, conn := range p.conns { + if err := conn.Close(); err != nil { + errs = errors.Join(errs, err) + } + } + return errs +} + +func (p *roundRobinConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + return p.Conn().Invoke(ctx, method, args, reply, opts...) +} + +func (p *roundRobinConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return p.Conn().NewStream(ctx, desc, method, opts...) +} + +type DialFunc func(ctx context.Context) (*grpc.ClientConn, error) + +func NewRoundRobinBalancer(ctx context.Context, poolSize int, dialFn DialFunc) (grpc.ClientConnInterface, error) { + pool := &roundRobinConnPool{} + for i := 0; i < poolSize; i++ { + conn, err := dialFn(ctx) + if err != nil { + defer pool.Close() + return nil, err + } + pool.conns = append(pool.conns, conn) + } + return pool, nil +} diff --git a/go/pkg/cas/client.go b/go/pkg/cas/client.go index 2bfc4ea5..487297e7 100644 --- a/go/pkg/cas/client.go +++ b/go/pkg/cas/client.go @@ -28,7 +28,7 @@ import ( // // All fields are considered immutable, and should not be changed. type Client struct { - conn *grpc.ClientConn + conn grpc.ClientConnInterface // InstanceName is the full name of the RBE instance. InstanceName string @@ -234,12 +234,12 @@ func (c *RPCConfig) validate() error { // NewClient creates a new client with the default configuration. // Use client.Dial to create a connection. -func NewClient(ctx context.Context, conn *grpc.ClientConn, instanceName string) (*Client, error) { +func NewClient(ctx context.Context, conn grpc.ClientConnInterface, instanceName string) (*Client, error) { return NewClientWithConfig(ctx, conn, instanceName, DefaultClientConfig()) } // NewClientWithConfig creates a new client and accepts a configuration. -func NewClientWithConfig(ctx context.Context, conn *grpc.ClientConn, instanceName string, config ClientConfig) (*Client, error) { +func NewClientWithConfig(ctx context.Context, conn grpc.ClientConnInterface, instanceName string, config ClientConfig) (*Client, error) { switch err := config.Validate(); { case err != nil: return nil, errors.Wrap(err, "invalid config") diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go index cfe518a0..3d6401b4 100644 --- a/go/pkg/client/client.go +++ b/go/pkg/client/client.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "net/http" "os" "os/user" @@ -128,8 +129,8 @@ type Client struct { // // These fields are logically "protected" and are intended for use by extensions of Client. Retrier *Retrier - Connection *grpc.ClientConn - CASConnection *grpc.ClientConn // Can be different from Connection a separate CAS endpoint is provided. + Connection grpc.ClientConnInterface + CASConnection grpc.ClientConnInterface // Can be different from Connection a separate CAS endpoint is provided. // StartupCapabilities denotes whether to load ServerCapabilities on startup. StartupCapabilities StartupCapabilities // LegacyExecRootRelativeOutputs denotes whether outputs are relative to the exec root. @@ -213,12 +214,16 @@ func (c *Client) Close() error { // Close the channels & stop background operations. UnifiedUploads(false).Apply(c) UnifiedDownloads(false).Apply(c) - err := c.Connection.Close() - if err != nil { - return err + if closer, ok := c.Connection.(io.Closer); ok { + err := closer.Close() + if err != nil { + return err + } } if c.CASConnection != c.Connection { - return c.CASConnection.Close() + if closer, ok := c.CASConnection.(io.Closer); ok { + return closer.Close() + } } return nil } @@ -537,6 +542,12 @@ type DialParams struct { // // If this is specified, TLSClientAuthCert must also be specified. TLSClientAuthKey string + + // RoundRobinBalancer enables the simplified gRPC balancer instead of the default one. + RoundRobinBalancer bool + + // RoundRobinPoolSize specifies the pool size for the round robin load balancer. + RoundRobinPoolSize int } func createGRPCInterceptor(p DialParams) *balancer.GCPInterceptor { @@ -593,7 +604,7 @@ func createTLSConfig(params DialParams) (*tls.Config, error) { } // Dial dials a given endpoint and returns the grpc connection that is established. -func Dial(ctx context.Context, endpoint string, params DialParams) (*grpc.ClientConn, AuthType, error) { +func Dial(ctx context.Context, endpoint string, params DialParams) (grpc.ClientConnInterface, AuthType, error) { var authUsed AuthType var opts []grpc.DialOption @@ -661,6 +672,18 @@ func Dial(ctx context.Context, endpoint string, params DialParams) (*grpc.Client } opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) } + + if params.RoundRobinBalancer { + dialFn := func(ctx context.Context) (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, endpoint, opts...) + } + conn, err := balancer.NewRoundRobinBalancer(ctx, params.RoundRobinPoolSize, dialFn) + if err != nil { + return nil, authUsed, fmt.Errorf("couldn't create round robin load balancer: %w", err) + } + return conn, authUsed, nil + } + grpcInt := createGRPCInterceptor(params) opts = append(opts, grpc.WithDisableServiceConfig()) opts = append(opts, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, balancer.Name))) @@ -676,7 +699,7 @@ func Dial(ctx context.Context, endpoint string, params DialParams) (*grpc.Client // DialRaw dials a remote execution service and returns the grpc connection that is established. // TODO(olaola): remove this overload when all clients use Dial. -func DialRaw(ctx context.Context, params DialParams) (*grpc.ClientConn, AuthType, error) { +func DialRaw(ctx context.Context, params DialParams) (grpc.ClientConnInterface, AuthType, error) { if params.Service == "" { return nil, UnknownAuth, fmt.Errorf("service needs to be specified") } @@ -712,7 +735,7 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts } // NewClientFromConnection creates a client from gRPC connections to a remote execution service and a cas service. -func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn *grpc.ClientConn, opts ...Opt) (*Client, error) { +func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn grpc.ClientConnInterface, opts ...Opt) (*Client, error) { if conn == nil { return nil, fmt.Errorf("connection to remote execution service may not be nil") } @@ -1022,7 +1045,7 @@ func (c *Client) WaitExecution(ctx context.Context, req *repb.WaitExecutionReque // GetBackendCapabilities returns the capabilities for a specific server connection // (either the main connection or the CAS connection). -func (c *Client) GetBackendCapabilities(ctx context.Context, conn *grpc.ClientConn, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) { +func (c *Client) GetBackendCapabilities(ctx context.Context, conn grpc.ClientConnInterface, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) { opts := c.RPCOpts() err = c.Retrier.Do(ctx, func() (e error) { return c.CallWithTimeout(ctx, "GetCapabilities", func(ctx context.Context) (e error) { diff --git a/go/pkg/fakes/server.go b/go/pkg/fakes/server.go index 080f05f9..e8a9a8d6 100644 --- a/go/pkg/fakes/server.go +++ b/go/pkg/fakes/server.go @@ -83,7 +83,7 @@ func (s *Server) NewTestClient(ctx context.Context) (*rc.Client, error) { } // NewClientConn returns a gRPC client connction to the server. -func (s *Server) NewClientConn(ctx context.Context) (*grpc.ClientConn, error) { +func (s *Server) NewClientConn(ctx context.Context) (grpc.ClientConnInterface, error) { p := s.dialParams() conn, _, err := client.Dial(ctx, p.Service, p) return conn, err diff --git a/go/pkg/flags/flags.go b/go/pkg/flags/flags.go index d7acc01b..f5038664 100644 --- a/go/pkg/flags/flags.go +++ b/go/pkg/flags/flags.go @@ -74,6 +74,11 @@ var ( KeepAliveTimeout = flag.Duration("grpc_keepalive_timeout", 20*time.Second, "After having pinged for keepalive check, the client waits for a duration of Timeout and if no activity is seen even after that the connection is closed. Default is 20s.") // KeepAlivePermitWithoutStream specifies gRPCs keepalive permitWithoutStream parameter. KeepAlivePermitWithoutStream = flag.Bool("grpc_keepalive_permit_without_stream", false, "If true, client sends keepalive pings even with no active RPCs; otherwise, doesn't send pings even if time and timeout are set. Default is false.") + // UseRoundRobinBalancer is a temporary feature flag to rollout a simplified load balancer. + // See http://go/remote-apis-sdks/issues/499 + UseRoundRobinBalancer = flag.Bool("use_simple_balancer", false, "If true, a simple round-robin connection bool is used for gRPC. Otherwise, the existing load balancer is used.") + // RoundRobinBalancerPoolSize specifies the pool size for the round robin balancer. + RoundRobinBalancerPoolSize = flag.Int("round_robin_balancer_pool_size", client.DefaultMaxConcurrentRequests, "pool size for round robin grpc balacner") ) func init() { @@ -145,5 +150,7 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client TLSClientAuthKey: *TLSClientAuthKey, MaxConcurrentRequests: uint32(*MaxConcurrentRequests), MaxConcurrentStreams: uint32(*MaxConcurrentStreams), + RoundRobinBalancer: *UseRoundRobinBalancer, + RoundRobinPoolSize: *RoundRobinBalancerPoolSize, }, opts...) } From ca0195544a93e86d7b1d9995ad13f2ccdaa63865 Mon Sep 17 00:00:00 2001 From: Anas Sulaiman Date: Mon, 29 Apr 2024 21:25:48 +0000 Subject: [PATCH 2/2] add a round robin grpc balancer This replaces the current load balancer which was copied over from the gcp repo which uses experimental APIs from grpc-go. Depending on experimental APIs marks the SDK as experimental as well. It also complicates importing the SDK to google3 where everything must be compatbile at head. This major difference between this simple implementation and the existing one is the maximum number of streams allowed on a single connection. The existing balancer limits streams to 3 by default and allows configuring that limit. The new simpler implementation does not enforce any limit. I have tested the simple balancer by building chromium and android with and without it enabled. I observed no difference in build latency. --- go/pkg/balancer/roundrobin.go | 22 ++++-- go/pkg/cas/client.go | 6 +- go/pkg/client/capabilities.go | 6 +- go/pkg/client/client.go | 123 +++++++++++++++++----------------- go/pkg/client/client_test.go | 29 ++------ go/pkg/fakes/server.go | 9 ++- go/pkg/flags/flags.go | 2 +- go/pkg/tool/tool.go | 2 +- 8 files changed, 99 insertions(+), 100 deletions(-) diff --git a/go/pkg/balancer/roundrobin.go b/go/pkg/balancer/roundrobin.go index b4dd8627..9c8d0d75 100644 --- a/go/pkg/balancer/roundrobin.go +++ b/go/pkg/balancer/roundrobin.go @@ -9,7 +9,8 @@ import ( "google.golang.org/grpc" ) -type roundRobinConnPool struct { +// RRConnPool is a pool of *grpc.ClientConn that are selected in a round-robin fashion. +type RRConnPool struct { grpc.ClientConnInterface io.Closer @@ -17,12 +18,14 @@ type roundRobinConnPool struct { idx uint32 // access via sync/atomic } -func (p *roundRobinConnPool) Conn() *grpc.ClientConn { +// Conn picks the next connection from the pool in a round-robin fasion. +func (p *RRConnPool) Conn() *grpc.ClientConn { i := atomic.AddUint32(&p.idx, 1) return p.conns[i%uint32(len(p.conns))] } -func (p *roundRobinConnPool) Close() error { +// Close closes all connections in the bool. +func (p *RRConnPool) Close() error { var errs error for _, conn := range p.conns { if err := conn.Close(); err != nil { @@ -32,18 +35,23 @@ func (p *roundRobinConnPool) Close() error { return errs } -func (p *roundRobinConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { +// Invoke picks up a connection from the pool and delegates the call to it. +func (p *RRConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { return p.Conn().Invoke(ctx, method, args, reply, opts...) } -func (p *roundRobinConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { +// NewStream picks up a connection from the pool and delegates the call to it. +func (p *RRConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { return p.Conn().NewStream(ctx, desc, method, opts...) } +// DialFunc defines the dial function used in creating the pool. type DialFunc func(ctx context.Context) (*grpc.ClientConn, error) -func NewRoundRobinBalancer(ctx context.Context, poolSize int, dialFn DialFunc) (grpc.ClientConnInterface, error) { - pool := &roundRobinConnPool{} +// NewRRConnPool makes a new instance of the round-robin connection pool and dials as many as poolSize connections +// using the provided dialFn. +func NewRRConnPool(ctx context.Context, poolSize int, dialFn DialFunc) (*RRConnPool, error) { + pool := &RRConnPool{} for i := 0; i < poolSize; i++ { conn, err := dialFn(ctx) if err != nil { diff --git a/go/pkg/cas/client.go b/go/pkg/cas/client.go index 487297e7..2bfc4ea5 100644 --- a/go/pkg/cas/client.go +++ b/go/pkg/cas/client.go @@ -28,7 +28,7 @@ import ( // // All fields are considered immutable, and should not be changed. type Client struct { - conn grpc.ClientConnInterface + conn *grpc.ClientConn // InstanceName is the full name of the RBE instance. InstanceName string @@ -234,12 +234,12 @@ func (c *RPCConfig) validate() error { // NewClient creates a new client with the default configuration. // Use client.Dial to create a connection. -func NewClient(ctx context.Context, conn grpc.ClientConnInterface, instanceName string) (*Client, error) { +func NewClient(ctx context.Context, conn *grpc.ClientConn, instanceName string) (*Client, error) { return NewClientWithConfig(ctx, conn, instanceName, DefaultClientConfig()) } // NewClientWithConfig creates a new client and accepts a configuration. -func NewClientWithConfig(ctx context.Context, conn grpc.ClientConnInterface, instanceName string, config ClientConfig) (*Client, error) { +func NewClientWithConfig(ctx context.Context, conn *grpc.ClientConn, instanceName string, config ClientConfig) (*Client, error) { switch err := config.Validate(); { case err != nil: return nil, errors.Wrap(err, "invalid config") diff --git a/go/pkg/client/capabilities.go b/go/pkg/client/capabilities.go index 68dcc558..48fa2005 100644 --- a/go/pkg/client/capabilities.go +++ b/go/pkg/client/capabilities.go @@ -65,12 +65,12 @@ func (c *Client) GetCapabilities(ctx context.Context) (res *repb.ServerCapabilit // be determined from that; ExecutionCapabilities will always come from the main URL. func (c *Client) GetCapabilitiesForInstance(ctx context.Context, instance string) (res *repb.ServerCapabilities, err error) { req := &repb.GetCapabilitiesRequest{InstanceName: instance} - caps, err := c.GetBackendCapabilities(ctx, c.Connection, req) + caps, err := c.GetBackendCapabilities(ctx, c.Connection(), req) if err != nil { return nil, err } - if c.CASConnection != c.Connection { - casCaps, err := c.GetBackendCapabilities(ctx, c.CASConnection, req) + if c.CASConnection() != c.Connection() { + casCaps, err := c.GetBackendCapabilities(ctx, c.CASConnection(), req) if err != nil { return nil, err } diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go index 3d6401b4..00c76c7b 100644 --- a/go/pkg/client/client.go +++ b/go/pkg/client/client.go @@ -108,6 +108,13 @@ func (ce *InitError) Error() string { return fmt.Sprintf("%v, authentication type (identity) used=%q", ce.Err.Error(), ce.AuthUsed) } +// Temporary interface definition until the gcp balancer is removed in favour of +// the round-robin balancer. +type grpcClientConn interface { + grpc.ClientConnInterface + io.Closer +} + // Client is a client to several services, including remote execution and services used in // conjunction with remote execution. A Client must be constructed by calling Dial() or NewClient() // rather than attempting to assemble it directly. @@ -129,8 +136,8 @@ type Client struct { // // These fields are logically "protected" and are intended for use by extensions of Client. Retrier *Retrier - Connection grpc.ClientConnInterface - CASConnection grpc.ClientConnInterface // Can be different from Connection a separate CAS endpoint is provided. + connection grpcClientConn + casConnection grpcClientConn // StartupCapabilities denotes whether to load ServerCapabilities on startup. StartupCapabilities StartupCapabilities // LegacyExecRootRelativeOutputs denotes whether outputs are relative to the exec root. @@ -209,21 +216,31 @@ const ( DefaultRegularMode = 0644 ) +func (c *Client) Connection() *grpc.ClientConn { + if conn, ok := c.connection.(*grpc.ClientConn); ok { + return conn + } + return c.connection.(*balancer.RRConnPool).Conn() +} + +func (c *Client) CASConnection() *grpc.ClientConn { + if conn, ok := c.casConnection.(*grpc.ClientConn); ok { + return conn + } + return c.casConnection.(*balancer.RRConnPool).Conn() +} + // Close closes the underlying gRPC connection(s). func (c *Client) Close() error { // Close the channels & stop background operations. UnifiedUploads(false).Apply(c) UnifiedDownloads(false).Apply(c) - if closer, ok := c.Connection.(io.Closer); ok { - err := closer.Close() - if err != nil { - return err - } + err := c.connection.Close() + if err != nil { + return err } - if c.CASConnection != c.Connection { - if closer, ok := c.CASConnection.(io.Closer); ok { - return closer.Close() - } + if c.casConnection != c.connection { + return c.casConnection.Close() } return nil } @@ -546,7 +563,7 @@ type DialParams struct { // RoundRobinBalancer enables the simplified gRPC balancer instead of the default one. RoundRobinBalancer bool - // RoundRobinPoolSize specifies the pool size for the round robin load balancer. + // RoundRobinPoolSize specifies the pool size for the round-robin load balancer. RoundRobinPoolSize int } @@ -603,8 +620,8 @@ func createTLSConfig(params DialParams) (*tls.Config, error) { return c, nil } -// Dial dials a given endpoint and returns the grpc connection that is established. -func Dial(ctx context.Context, endpoint string, params DialParams) (grpc.ClientConnInterface, AuthType, error) { +// OptsFromParams prepares a set of grpc dial options based on the provided dial params. +func OptsFromParams(ctx context.Context, params DialParams) ([]grpc.DialOption, AuthType, error) { var authUsed AuthType var opts []grpc.DialOption @@ -673,38 +690,13 @@ func Dial(ctx context.Context, endpoint string, params DialParams) (grpc.ClientC opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) } - if params.RoundRobinBalancer { - dialFn := func(ctx context.Context) (*grpc.ClientConn, error) { - return grpc.DialContext(ctx, endpoint, opts...) - } - conn, err := balancer.NewRoundRobinBalancer(ctx, params.RoundRobinPoolSize, dialFn) - if err != nil { - return nil, authUsed, fmt.Errorf("couldn't create round robin load balancer: %w", err) - } - return conn, authUsed, nil - } - grpcInt := createGRPCInterceptor(params) opts = append(opts, grpc.WithDisableServiceConfig()) opts = append(opts, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, balancer.Name))) opts = append(opts, grpc.WithUnaryInterceptor(grpcInt.GCPUnaryClientInterceptor)) opts = append(opts, grpc.WithStreamInterceptor(grpcInt.GCPStreamClientInterceptor)) - conn, err := grpc.Dial(endpoint, opts...) - if err != nil { - return nil, authUsed, fmt.Errorf("couldn't dial gRPC %q: %v", endpoint, err) - } - return conn, authUsed, nil -} - -// DialRaw dials a remote execution service and returns the grpc connection that is established. -// TODO(olaola): remove this overload when all clients use Dial. -func DialRaw(ctx context.Context, params DialParams) (grpc.ClientConnInterface, AuthType, error) { - if params.Service == "" { - return nil, UnknownAuth, fmt.Errorf("service needs to be specified") - } - log.Infof("Connecting to remote execution service %s", params.Service) - return Dial(ctx, params.Service, params) + return opts, authUsed, nil } // NewClient connects to a remote execution service and returns a client suitable for higher-level @@ -718,30 +710,40 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts } log.Infof("Connecting to remote execution instance %s", instanceName) log.Infof("Connecting to remote execution service %s", params.Service) - conn, authUsed, err := Dial(ctx, params.Service, params) - casConn := conn - if params.CASService != "" && params.CASService != params.Service { - log.Infof("Connecting to CAS service %s", params.CASService) - casConn, authUsed, err = Dial(ctx, params.CASService, params) - } + dialOpts, authUsed, err := OptsFromParams(ctx, params) if err != nil { - return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed} + return nil, fmt.Errorf("failed to prepare gRPC dial options: %v", err) + } + + var conn, casConn grpcClientConn + if params.RoundRobinBalancer { + dial := func(ctx context.Context) (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, params.Service, dialOpts...) + } + conn, err = balancer.NewRRConnPool(ctx, params.RoundRobinPoolSize, dial) + } else { + conn, err = grpc.Dial(params.Service, dialOpts...) } - client, err := NewClientFromConnection(ctx, instanceName, conn, casConn, opts...) if err != nil { - return nil, &InitError{Err: err, AuthUsed: authUsed} + return nil, fmt.Errorf("couldn't dial gRPC %q: %v", params.Service, err) } - return client, nil -} -// NewClientFromConnection creates a client from gRPC connections to a remote execution service and a cas service. -func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn grpc.ClientConnInterface, opts ...Opt) (*Client, error) { - if conn == nil { - return nil, fmt.Errorf("connection to remote execution service may not be nil") + casConn = conn + if params.CASService != "" && params.CASService != params.Service { + log.Infof("Connecting to CAS service %s", params.CASService) + if params.RoundRobinBalancer { + dial := func(ctx context.Context) (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, params.CASService, dialOpts...) + } + casConn, err = balancer.NewRRConnPool(ctx, params.RoundRobinPoolSize, dial) + } else { + casConn, err = grpc.Dial(params.CASService, dialOpts...) + } } - if casConn == nil { - return nil, fmt.Errorf("connection to CAS service may not be nil") + if err != nil { + return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed} } + client := &Client{ InstanceName: instanceName, actionCache: regrpc.NewActionCacheClient(casConn), @@ -750,8 +752,8 @@ func NewClientFromConnection(ctx context.Context, instanceName string, conn, cas execution: regrpc.NewExecutionClient(conn), operations: opgrpc.NewOperationsClient(conn), rpcTimeouts: DefaultRPCTimeouts, - Connection: conn, - CASConnection: casConn, + connection: conn, + casConnection: casConn, CompressedBytestreamThreshold: DefaultCompressedBytestreamThreshold, ChunkMaxSize: chunker.DefaultChunkSize, MaxBatchDigests: DefaultMaxBatchDigests, @@ -785,6 +787,7 @@ func NewClientFromConnection(ctx context.Context, instanceName string, conn, cas return nil, fmt.Errorf("CASConcurrency should be at least 1") } client.RunBackgroundTasks(ctx) + return client, nil } @@ -1045,7 +1048,7 @@ func (c *Client) WaitExecution(ctx context.Context, req *repb.WaitExecutionReque // GetBackendCapabilities returns the capabilities for a specific server connection // (either the main connection or the CAS connection). -func (c *Client) GetBackendCapabilities(ctx context.Context, conn grpc.ClientConnInterface, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) { +func (c *Client) GetBackendCapabilities(ctx context.Context, conn *grpc.ClientConn, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) { opts := c.RPCOpts() err = c.Retrier.Do(ctx, func() (e error) { return c.CallWithTimeout(ctx, "GetCapabilities", func(ctx context.Context) (e error) { diff --git a/go/pkg/client/client_test.go b/go/pkg/client/client_test.go index 44f743ce..cb7308d8 100644 --- a/go/pkg/client/client_test.go +++ b/go/pkg/client/client_test.go @@ -3,14 +3,12 @@ package client import ( "context" "errors" - "net" "os" "path" "testing" repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" svpb "github.com/bazelbuild/remote-apis/build/bazel/semver" - "google.golang.org/grpc" ) const ( @@ -204,33 +202,20 @@ func TestNewClient(t *testing.T) { defer c.Close() } -func TestNewClientFromConnection(t *testing.T) { +func TestNewClientRR(t *testing.T) { t.Parallel() ctx := context.Background() - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("Cannot listen: %v", err) - } - conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure()) - if err != nil { - t.Fatalf("Cannot establish gRPC connection: %v", err) - } - c, err := NewClientFromConnection(ctx, instance, conn, conn, StartupCapabilities(false)) + c, err := NewClient(ctx, instance, DialParams{ + Service: "server", + NoSecurity: true, + RoundRobinBalancer: true, + RoundRobinPoolSize: 25, + }, StartupCapabilities(false)) if err != nil { t.Fatalf("Error creating client: %v", err) } defer c.Close() - - _, err = NewClientFromConnection(ctx, instance, nil, conn, StartupCapabilities(false)) - if err == nil { - t.Fatalf("Expected error got nil") - } - - _, err = NewClientFromConnection(ctx, instance, conn, nil, StartupCapabilities(false)) - if err == nil { - t.Fatalf("Expected error got nil") - } } func TestResourceName(t *testing.T) { diff --git a/go/pkg/fakes/server.go b/go/pkg/fakes/server.go index e8a9a8d6..989fcea6 100644 --- a/go/pkg/fakes/server.go +++ b/go/pkg/fakes/server.go @@ -83,10 +83,13 @@ func (s *Server) NewTestClient(ctx context.Context) (*rc.Client, error) { } // NewClientConn returns a gRPC client connction to the server. -func (s *Server) NewClientConn(ctx context.Context) (grpc.ClientConnInterface, error) { +func (s *Server) NewClientConn(ctx context.Context) (*grpc.ClientConn, error) { p := s.dialParams() - conn, _, err := client.Dial(ctx, p.Service, p) - return conn, err + opts, _, err := client.OptsFromParams(ctx, p) + if err != nil { + return nil, err + } + return grpc.Dial(p.Service, opts...) } func (s *Server) dialParams() rc.DialParams { diff --git a/go/pkg/flags/flags.go b/go/pkg/flags/flags.go index f5038664..75a64b7c 100644 --- a/go/pkg/flags/flags.go +++ b/go/pkg/flags/flags.go @@ -76,7 +76,7 @@ var ( KeepAlivePermitWithoutStream = flag.Bool("grpc_keepalive_permit_without_stream", false, "If true, client sends keepalive pings even with no active RPCs; otherwise, doesn't send pings even if time and timeout are set. Default is false.") // UseRoundRobinBalancer is a temporary feature flag to rollout a simplified load balancer. // See http://go/remote-apis-sdks/issues/499 - UseRoundRobinBalancer = flag.Bool("use_simple_balancer", false, "If true, a simple round-robin connection bool is used for gRPC. Otherwise, the existing load balancer is used.") + UseRoundRobinBalancer = flag.Bool("use_round_robin_balancer", false, "If true, a round-robin connection bool is used for gRPC. Otherwise, the existing load balancer is used.") // RoundRobinBalancerPoolSize specifies the pool size for the round robin balancer. RoundRobinBalancerPoolSize = flag.Int("round_robin_balancer_pool_size", client.DefaultMaxConcurrentRequests, "pool size for round robin grpc balacner") ) diff --git a/go/pkg/tool/tool.go b/go/pkg/tool/tool.go index 85f8f797..2e202f41 100644 --- a/go/pkg/tool/tool.go +++ b/go/pkg/tool/tool.go @@ -287,7 +287,7 @@ func (c *Client) UploadBlob(ctx context.Context, path string) error { // UploadBlobV2 uploads a blob from the specified path into the remote cache using newer cas implementation. func (c *Client) UploadBlobV2(ctx context.Context, path string) error { - casC, err := cas.NewClient(ctx, c.GrpcClient.Connection, c.GrpcClient.InstanceName) + casC, err := cas.NewClient(ctx, c.GrpcClient.Connection(), c.GrpcClient.InstanceName) if err != nil { return errors.WithStack(err) }