diff --git a/balancer_wrapper.go b/balancer_wrapper.go
index 2c760e623f63..4c247f2ab773 100644
--- a/balancer_wrapper.go
+++ b/balancer_wrapper.go
@@ -282,6 +282,10 @@ type acBalancerWrapper struct {
 	// dropped or updated. This is required as closures can't be compared for
 	// equality.
 	healthData *healthData
+
+	shutdownMu    sync.Mutex
+	shutdownCh    chan struct{}
+	activeGofuncs sync.WaitGroup
 }
 
 // healthData holds data related to health state reporting.
@@ -347,16 +351,45 @@ func (acbw *acBalancerWrapper) String() string {
 }
 
 func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
-	acbw.ac.updateAddrs(addrs)
+	acbw.goFunc(func(shutdown <-chan struct{}) {
+		acbw.ac.updateAddrs(shutdown, addrs)
+	})
 }
 
 func (acbw *acBalancerWrapper) Connect() {
-	go acbw.ac.connect()
+	acbw.goFunc(acbw.ac.connect)
+}
+
+func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) {
+	acbw.shutdownMu.Lock()
+	defer acbw.shutdownMu.Unlock()
+
+	shutdown := acbw.shutdownCh
+	if shutdown == nil {
+		shutdown = make(chan struct{})
+		acbw.shutdownCh = shutdown
+	}
+
+	acbw.activeGofuncs.Add(1)
+	go func() {
+		defer acbw.activeGofuncs.Done()
+		fn(shutdown)
+	}()
 }
 
 func (acbw *acBalancerWrapper) Shutdown() {
 	acbw.closeProducers()
 	acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
+
+	acbw.shutdownMu.Lock()
+	defer acbw.shutdownMu.Unlock()
+
+	shutdown := acbw.shutdownCh
+	acbw.shutdownCh = nil
+	if shutdown != nil {
+		close(shutdown)
+		acbw.activeGofuncs.Wait()
+	}
 }
 
 // NewStream begins a streaming RPC on the addrConn. If the addrConn is not
diff --git a/clientconn.go b/clientconn.go
index c0c2c9a76abf..2dfd6cd3aad9 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -925,25 +925,24 @@ func (cc *ClientConn) incrCallsFailed() {
 // connect starts creating a transport.
 // It does nothing if the ac is not IDLE.
 // TODO(bar) Move this to the addrConn section.
-func (ac *addrConn) connect() error {
+func (ac *addrConn) connect(abort <-chan struct{}) {
 	ac.mu.Lock()
 	if ac.state == connectivity.Shutdown {
 		if logger.V(2) {
 			logger.Infof("connect called on shutdown addrConn; ignoring.")
 		}
 		ac.mu.Unlock()
-		return errConnClosing
+		return
 	}
 	if ac.state != connectivity.Idle {
 		if logger.V(2) {
 			logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
 		}
 		ac.mu.Unlock()
-		return nil
+		return
 	}
 
-	ac.resetTransportAndUnlock()
-	return nil
+	ac.resetTransportAndUnlock(abort)
 }
 
 // equalAddressIgnoringBalAttributes returns true is a and b are considered equal.
@@ -962,7 +961,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
 
 // updateAddrs updates ac.addrs with the new addresses list and handles active
 // connections or connection attempts.
-func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
+func (ac *addrConn) updateAddrs(abort <-chan struct{}, addrs []resolver.Address) {
 	addrs = copyAddresses(addrs)
 	limit := len(addrs)
 	if limit > 5 {
@@ -1018,7 +1017,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
 
 	// Since we were connecting/connected, we should start a new connection
 	// attempt.
-	go ac.resetTransportAndUnlock()
+	ac.resetTransportAndUnlock(abort)
 }
 
 // getServerName determines the serverName to be used in the connection
@@ -1249,9 +1248,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
 // resetTransportAndUnlock unconditionally connects the addrConn.
 //
 // ac.mu must be held by the caller, and this function will guarantee it is released.
-func (ac *addrConn) resetTransportAndUnlock() {
-	acCtx := ac.ctx
-	if acCtx.Err() != nil {
+func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) {
+	ctx, cancel := context.WithCancel(ac.ctx)
+	go func() {
+		select {
+		case <-abort:
+			cancel()
+		case <-ctx.Done():
+		}
+	}()
+
+	if ctx.Err() != nil {
 		ac.mu.Unlock()
 		return
 	}
@@ -1279,12 +1286,12 @@
 	ac.updateConnectivityState(connectivity.Connecting, nil)
 	ac.mu.Unlock()
 
-	if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
+	if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil {
 		// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
 		// to ensure one resolution request per pass instead of per subconn failure.
 		ac.cc.resolveNow(resolver.ResolveNowOptions{})
 		ac.mu.Lock()
-		if acCtx.Err() != nil {
+		if ctx.Err() != nil {
 			// addrConn was torn down.
 			ac.mu.Unlock()
 			return
@@ -1305,13 +1312,13 @@
 			ac.mu.Unlock()
 		case <-b:
 			timer.Stop()
-		case <-acCtx.Done():
+		case <-ctx.Done():
 			timer.Stop()
 			return
 		}
 
 		ac.mu.Lock()
-		if acCtx.Err() == nil {
+		if ctx.Err() == nil {
 			ac.updateConnectivityState(connectivity.Idle, err)
 		}
 		ac.mu.Unlock()
@@ -1366,6 +1373,9 @@
 // new transport.
 func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
 	addr.ServerName = ac.cc.getServerName(addr)
+
+	var healthCheckStarted atomic.Bool
+	healthCheckDone := make(chan struct{})
 	hctx, hcancel := context.WithCancel(ctx)
 	onClose := func(r transport.GoAwayReason) {
@@ -1394,6 +1404,9 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
 		// Always go idle and wait for the LB policy to initiate a new
 		// connection attempt.
 		ac.updateConnectivityState(connectivity.Idle, nil)
+		if healthCheckStarted.Load() {
+			<-healthCheckDone
+		}
 	}
 
 	connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
@@ -1406,29 +1419,35 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
 			logger.Infof("Creating new client transport to %q: %v", addr, err)
 		}
 		// newTr is either nil, or closed.
-		hcancel()
 		channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
 		return err
 	}
 
-	ac.mu.Lock()
-	defer ac.mu.Unlock()
+	acMu := &ac.mu
+	acMu.Lock()
+	defer func() {
+		if acMu != nil {
+			acMu.Unlock()
+		}
+	}()
 	if ctx.Err() != nil {
 		// This can happen if the subConn was removed while in `Connecting`
 		// state. tearDown() would have set the state to `Shutdown`, but
 		// would not have closed the transport since ac.transport would not
 		// have been set at that point.
-		//
-		// We run this in a goroutine because newTr.Close() calls onClose()
+
+		// We unlock ac.mu because newTr.Close() calls onClose()
 		// inline, which requires locking ac.mu.
-		//
+		acMu.Unlock()
+		acMu = nil
+
 		// The error we pass to Close() is immaterial since there are no open
 		// streams at this point, so no trailers with error details will be sent
 		// out. We just need to pass a non-nil error.
 		//
 		// This can also happen when updateAddrs is called during a connection
 		// attempt.
-		go newTr.Close(transport.ErrConnClosing)
+		newTr.Close(transport.ErrConnClosing)
 		return nil
 	}
 	if hctx.Err() != nil {
@@ -1440,7 +1459,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
 	}
 	ac.curAddr = addr
 	ac.transport = newTr
-	ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
+	healthCheckStarted.Store(ac.startHealthCheck(hctx, healthCheckDone)) // Will set state to READY if appropriate.
 	return nil
 }
 
@@ -1456,7 +1475,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
 // It sets addrConn to READY if the health checking stream is not started.
 //
 // Caller must hold ac.mu.
-func (ac *addrConn) startHealthCheck(ctx context.Context) {
+func (ac *addrConn) startHealthCheck(ctx context.Context, done chan<- struct{}) bool {
 	var healthcheckManagingState bool
 	defer func() {
 		if !healthcheckManagingState {
@@ -1465,14 +1484,14 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
 	}()
 
 	if ac.cc.dopts.disableHealthCheck {
-		return
+		return false
 	}
 	healthCheckConfig := ac.cc.healthCheckConfig()
 	if healthCheckConfig == nil {
-		return
+		return false
 	}
 	if !ac.scopts.HealthCheckEnabled {
-		return
+		return false
 	}
 	healthCheckFunc := internal.HealthCheckFunc
 	if healthCheckFunc == nil {
@@ -1480,7 +1499,7 @@
 		//
 		// TODO: add a link to the health check doc in the error message.
 		channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
-		return
+		return false
 	}
 
 	healthcheckManagingState = true
@@ -1506,6 +1525,7 @@
 	}
 	// Start the health checking stream.
 	go func() {
+		defer close(done)
 		err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
 		if err != nil {
 			if status.Code(err) == codes.Unimplemented {
@@ -1515,6 +1535,7 @@
 			}
 		}
 	}()
+	return true
 }
 
 func (ac *addrConn) resetConnectBackoff() {
diff --git a/internal/balancer/gracefulswitch/gracefulswitch.go b/internal/balancer/gracefulswitch/gracefulswitch.go
index ba25b8988718..45dea2472f36 100644
--- a/internal/balancer/gracefulswitch/gracefulswitch.go
+++ b/internal/balancer/gracefulswitch/gracefulswitch.go
@@ -67,6 +67,8 @@ type Balancer struct {
 	// balancerCurrent before the UpdateSubConnState is called on the
 	// balancerCurrent.
 	currentMu sync.Mutex
+
+	pendingSwaps sync.WaitGroup
 }
 
 // swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +78,9 @@ func (gsb *Balancer) swap() {
 	cur := gsb.balancerCurrent
 	gsb.balancerCurrent = gsb.balancerPending
 	gsb.balancerPending = nil
+	gsb.pendingSwaps.Add(1)
 	go func() {
+		defer gsb.pendingSwaps.Done()
 		gsb.currentMu.Lock()
 		defer gsb.currentMu.Unlock()
 		cur.Close()
@@ -274,6 +278,7 @@ func (gsb *Balancer) Close() {
 
 	currentBalancerToClose.Close()
 	pendingBalancerToClose.Close()
+	gsb.pendingSwaps.Wait()
 }
 
 // balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
diff --git a/internal/testutils/pipe_listener.go b/internal/testutils/pipe_listener.go
index 6bd3bc0bea12..70ebeafe4ac1 100644
--- a/internal/testutils/pipe_listener.go
+++ b/internal/testutils/pipe_listener.go
@@ -20,6 +20,7 @@ package testutils
 
 import (
+	"context"
 	"errors"
 	"net"
 	"time"
 )
@@ -81,11 +82,20 @@ func (p *PipeListener) Addr() net.Addr {
 // Dialer dials a connection.
 func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
 	return func(string, time.Duration) (net.Conn, error) {
+		return p.ContextDialer()(context.Background(), "")
+	}
+}
+
+// ContextDialer dials a connection using a context.
+func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) {
+	return func(ctx context.Context, _ string) (net.Conn, error) {
 		connChan := make(chan net.Conn)
 		select {
 		case p.c <- connChan:
 		case <-p.done:
 			return nil, errClosed
+		case <-ctx.Done():
+			return nil, context.Cause(ctx)
 		}
 		conn, ok := <-connChan
 		if !ok {
diff --git a/test/clientconn_state_transition_test.go b/test/clientconn_state_transition_test.go
index 1706a81a257d..321062a5626d 100644
--- a/test/clientconn_state_transition_test.go
+++ b/test/clientconn_state_transition_test.go
@@ -166,7 +166,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
 	client, err := grpc.NewClient("passthrough:///",
 		grpc.WithTransportCredentials(insecure.NewCredentials()),
 		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
-		grpc.WithDialer(pl.Dialer()),
+		grpc.WithContextDialer(pl.ContextDialer()),
 		grpc.WithConnectParams(grpc.ConnectParams{
 			Backoff:           backoff.Config{},
 			MinConnectTimeout: 100 * time.Millisecond,
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 9157c525c094..935c73c97e18 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -408,7 +408,7 @@ type env struct {
 	security     string // The security protocol such as TLS, SSH, etc.
 	httpHandler  bool   // whether to use the http.Handler ServerTransport; requires TLS
 	balancer     string // One of "round_robin", "pick_first", or "".
-	customDialer func(string, string, time.Duration) (net.Conn, error)
+	customDialer func(context.Context, string, string) (net.Conn, error)
 }
 
 func (e env) runnable() bool {
@@ -418,11 +418,12 @@ func (e env) runnable() bool {
 	return true
 }
 
-func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) {
+func (e env) dialer(ctx context.Context, addr string) (net.Conn, error) {
 	if e.customDialer != nil {
-		return e.customDialer(e.network, addr, timeout)
+		return e.customDialer(ctx, e.network, addr)
 	}
-	return net.DialTimeout(e.network, addr, timeout)
+	d := net.Dialer{}
+	return d.DialContext(ctx, e.network, addr)
 }
 
 var (
@@ -759,7 +760,7 @@ func (d *nopDecompressor) Type() string {
 }
 
 func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) {
-	opts = append(opts, grpc.WithDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent))
+	opts = append(opts, grpc.WithContextDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent))
 
 	if te.clientCompression {
 		opts = append(opts,
@@ -868,7 +869,9 @@ func (te *test) declareLogNoise(phrases ...string) {
 }
 
 func (te *test) withServerTester(fn func(st *serverTester)) {
-	c, err := te.e.dialer(te.srvAddr, 10*time.Second)
+	ctx, cancel := context.WithTimeout(te.ctx, 10*time.Second)
+	defer cancel()
+	c, err := te.e.dialer(ctx, te.srvAddr)
 	if err != nil {
 		te.t.Fatal(err)
 	}
@@ -925,8 +928,9 @@ func (l *lazyConn) Write(b []byte) (int, error) {
 func (s) TestContextDeadlineNotIgnored(t *testing.T) {
 	e := noBalancerEnv
 	var lc *lazyConn
-	e.customDialer = func(network, addr string, timeout time.Duration) (net.Conn, error) {
-		conn, err := net.DialTimeout(network, addr, timeout)
+	e.customDialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
+		d := net.Dialer{}
+		conn, err := d.DialContext(ctx, network, addr)
 		if err != nil {
 			return nil, err
 		}