Skip to content

Commit

Permalink
grpc: fix data race in balancer registration
Browse files Browse the repository at this point in the history
Registering gRPC balancers is thread-unsafe because they are stored in a
global map variable that is accessed without holding a lock. Therefore,
it's expected that balancers are registered _once_ at the beginning of
your program (e.g. in a package `init` function) and certainly not after
you've started dialing connections, etc.

> NOTE: this function must only be called during initialization time
> (i.e. in an init() function), and is not thread-safe.

While this is fine for us in production, it's challenging for tests that
spin up multiple agents in-memory. We currently register a balancer per
agent, which holds agent-specific state that cannot safely be shared.

This commit introduces our own registry that _is_ thread-safe, and
implements the Builder interface such that we can call gRPC's `Register`
method once, on start-up. It uses the same pattern as our resolver
registry where we use the dial target's host (aka "authority"), which is
unique per-agent, to determine which builder to use.
  • Loading branch information
boxofrad committed Feb 10, 2023
1 parent 78a4b5f commit 9108aac
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 103 deletions.
3 changes: 2 additions & 1 deletion agent/acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ func NewTestACLAgent(t *testing.T, name string, hcl string, resolveAuthz authzRe
}
return result, err
}
bd, err := NewBaseDeps(loader, logBuffer, logger)
bd, cleanup, err := NewBaseDeps(loader, logBuffer, logger)
t.Cleanup(cleanup)
require.NoError(t, err)

bd.MetricsConfig = &lib.MetricsConfig{
Expand Down
5 changes: 4 additions & 1 deletion agent/consul/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {

resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
resolver.Register(resolverBuilder)
t.Cleanup(func() {
resolver.Deregister(resolverBuilder.Authority())
})

balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

r := router.NewRouter(
logger,
Expand Down Expand Up @@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
BalancerBuilder: balancerBuilder,
}),
LeaderForwarder: resolverBuilder,
NewRequestRecorderFunc: middleware.NewRequestRecorder,
Expand Down
1 change: 0 additions & 1 deletion agent/consul/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
Servers: resolverBuilder,
DialingFromServer: false,
DialingFromDatacenter: "dc2",
BalancerBuilder: balancerBuilder,
})

conn, err = pool.ClientConn("dc2")
Expand Down
5 changes: 1 addition & 4 deletions agent/consul/subscribe_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
Expand Down Expand Up @@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
Expand Down Expand Up @@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
Expand Down Expand Up @@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
BalancerBuilder: balancerBuilder,
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
Expand Down Expand Up @@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re

balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

deps := newDefaultDeps(t, config)
deps.Router = router.NewRouter(
Expand Down
40 changes: 20 additions & 20 deletions agent/grpc-internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,25 @@ import (
"google.golang.org/grpc/status"
)

// NewBuilder constructs a new Builder with the given name.
func NewBuilder(name string, logger hclog.Logger) *Builder {
// NewBuilder constructs a new Builder. Calling Register will add the Builder
// to our global registry under the given "authority" such that it will be used
// when dialing targets in the form "consul-internal://<authority>/...", this
// allows us to add and remove balancers for different in-memory agents during
// tests.
func NewBuilder(authority string, logger hclog.Logger) *Builder {
return &Builder{
name: name,
logger: logger,
byTarget: make(map[string]*list.List),
shuffler: randomShuffler(),
authority: authority,
logger: logger,
byTarget: make(map[string]*list.List),
shuffler: randomShuffler(),
}
}

// Builder implements gRPC's balancer.Builder interface to construct balancers.
type Builder struct {
name string
logger hclog.Logger
shuffler shuffler
authority string
logger hclog.Logger
shuffler shuffler

mu sync.Mutex
byTarget map[string]*list.List
Expand Down Expand Up @@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
}
}

// Name implements the gRPC Balancer interface by returning its given name.
func (b *Builder) Name() string { return b.name }

// gRPC's balancer.Register method is not thread-safe, so we guard our calls
// with a global lock (as it may be called from parallel tests).
var registerLock sync.Mutex

// Register the Builder in gRPC's global registry using its given name.
// Register the Builder in our global registry. Users should call Deregister
// when finished using the Builder to clean-up global state.
func (b *Builder) Register() {
registerLock.Lock()
defer registerLock.Unlock()
globalRegistry.register(b.authority, b)
}

gbalancer.Register(b)
// Deregister the Builder from our global registry to clean up state.
func (b *Builder) Deregister() {
globalRegistry.deregister(b.authority)
}

// Rebalance randomizes the priority order of servers for the given target to
Expand Down
31 changes: 20 additions & 11 deletions agent/grpc-internal/balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"

"github.com/hashicorp/go-uuid"

"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
Expand All @@ -34,10 +36,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")

target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)

balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
Expand Down Expand Up @@ -78,10 +81,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")

target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)

balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
Expand Down Expand Up @@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")

target, _ := stubResolver(t, server1, server2)
target, authority, _ := stubResolver(t, server1, server2)

balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

// Provide a custom prioritizer that causes Rebalance to choose whichever
// server didn't get our first request.
Expand Down Expand Up @@ -177,10 +182,11 @@ func TestBalancer(t *testing.T) {
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")

target, res := stubResolver(t, server1, server2)
target, authority, res := stubResolver(t, server1, server2)

balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)

conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
Expand Down Expand Up @@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
})
}

func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
t.Helper()

addresses := make([]resolver.Address, len(servers))
Expand All @@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
resolver.Register(r)
t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })

return fmt.Sprintf("%s://", scheme), r
authority, err := uuid.GenerateUUID()
require.NoError(t, err)

return fmt.Sprintf("%s://%s", scheme, authority), authority, r
}

func runServer(t *testing.T, name string) *server {
Expand Down Expand Up @@ -314,7 +323,7 @@ func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
),
)
t.Cleanup(func() {
Expand Down
69 changes: 69 additions & 0 deletions agent/grpc-internal/balancer/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package balancer

import (
"fmt"
"sync"

gbalancer "google.golang.org/grpc/balancer"
)

// BuilderName should be given in gRPC service configuration to enable our
// custom balancer. It refers to this package's global registry, rather than
// an instance of Builder, to enable us to add and remove builders at runtime,
// specifically during tests.
const BuilderName = "consul-internal"

// gRPC's balancer.Register method is thread-unsafe because it mutates a global
// map without holding a lock. As such, it's expected that you register custom
// balancers once at the start of your program (e.g. a package init function).
//
// In production, this is fine. Agents register a single instance of our builder
// and use it for the duration. Tests are where this becomes problematic, as we
// spin up several agents in-memory and register/deregister a builder for each,
// with its own agent-specific state, logger, etc.
//
// To avoid data races, we call gRPC's Register method once, on package init,
// with a global registry struct that implements the Builder interface but
// delegates the building to N instances of our Builder that are registered and
// deregistered at runtime. We use the dial target's host (aka "authority"),
// which is unique per-agent, to pick the correct builder.
func init() {
	gbalancer.Register(globalRegistry)
}

// globalRegistry is the single balancer builder registered with gRPC (under
// BuilderName). It fans Build calls out to the per-agent Builder matching the
// dial target's authority.
var globalRegistry = &registry{
	byAuthority: make(map[string]*Builder),
}

// registry is a thread-safe mapping from a dial target's authority to the
// Builder that should construct balancers for that agent.
type registry struct {
	// mu guards byAuthority.
	mu          sync.RWMutex
	byAuthority map[string]*Builder
}

// Build implements gRPC's balancer.Builder interface by looking up the
// Builder registered for the dial target's authority (its URL host) and
// delegating construction to it. It panics if no Builder has been registered
// for that authority, as this indicates a programming error.
func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
	r.mu.RLock()
	defer r.mu.RUnlock()

	authority := opts.Target.URL.Host
	if builder, ok := r.byAuthority[authority]; ok {
		return builder.Build(cc, opts)
	}
	panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", authority))
}

func (r *registry) Name() string { return BuilderName }

// register associates the given Builder with an authority, so that subsequent
// Build calls for dial targets with that authority are delegated to it.
func (r *registry) register(auth string, builder *Builder) {
	r.mu.Lock()
	r.byAuthority[auth] = builder
	r.mu.Unlock()
}

// deregister removes the Builder associated with the given authority, if any,
// cleaning up the registry's global state.
func (r *registry) deregister(auth string) {
	r.mu.Lock()
	delete(r.byAuthority, auth)
	r.mu.Unlock()
}
34 changes: 12 additions & 22 deletions agent/grpc-internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"time"

"google.golang.org/grpc"
gbalancer "google.golang.org/grpc/balancer"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"

"github.com/armon/go-metrics"

"github.com/hashicorp/consul/agent/grpc-internal/balancer"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
Expand All @@ -22,8 +22,8 @@ import (

// grpcServiceConfig is provided as the default service config.
//
// It configures our custom balancer (via the %s directive to interpolate its
// name) which will automatically switch servers on error.
// It configures our custom balancer which will automatically switch servers
// on error.
//
// It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED
// errors *only*, as this is the status code servers will return for an
Expand All @@ -41,7 +41,7 @@ import (
// but we're working on generating them automatically from the protobuf files
const grpcServiceConfig = `
{
"loadBalancingConfig": [{"%s":{}}],
"loadBalancingConfig": [{"` + balancer.BuilderName + `":{}}],
"methodConfig": [
{
"name": [{}],
Expand Down Expand Up @@ -131,12 +131,11 @@ const grpcServiceConfig = `

// ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct {
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
balancerBuilder gbalancer.Builder
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
}

type ServerLocator interface {
Expand Down Expand Up @@ -198,21 +197,14 @@ type ClientConnPoolConfig struct {
// DialingFromDatacenter is the datacenter of the consul agent using this
// pool.
DialingFromDatacenter string

// BalancerBuilder is a builder for the gRPC balancer that will be used.
BalancerBuilder gbalancer.Builder
}

// NewClientConnPool create new GRPC client pool to connect to servers using
// GRPC over RPC.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
if cfg.BalancerBuilder == nil {
panic("missing required BalancerBuilder")
}
c := &ClientConnPool{
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn),
balancerBuilder: cfg.BalancerBuilder,
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn),
}
c.dialer = newDialer(cfg, &c.gwResolverDep)
return c
Expand Down Expand Up @@ -251,9 +243,7 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(c.dialer),
grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)),
grpc.WithDefaultServiceConfig(
fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()),
),
grpc.WithDefaultServiceConfig(grpcServiceConfig),
// Keep alive parameters are based on the same default ones we used for
// Yamux. These are somewhat arbitrary but we did observe in scale testing
// that the gRPC defaults (servers send keepalives only every 2 hours,
Expand Down
Loading

0 comments on commit 9108aac

Please sign in to comment.