diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go index cdd103ef7dc..02bf19c7651 100644 --- a/xds/internal/resolver/xds_resolver.go +++ b/xds/internal/resolver/xds_resolver.go @@ -20,12 +20,12 @@ package resolver import ( - "context" "fmt" "google.golang.org/grpc" "google.golang.org/grpc/attributes" "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/resolver" xdsinternal "google.golang.org/grpc/xds/internal" @@ -62,6 +62,7 @@ func (b *xdsResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, rb r := &xdsResolver{ target: t, cc: cc, + closed: grpcsync.NewEvent(), updateCh: make(chan suWithError, 1), } r.logger = prefixLogger((r)) @@ -86,7 +87,8 @@ func (b *xdsResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, rb return nil, fmt.Errorf("xds: failed to create xds-client: %v", err) } r.client = client - r.ctx, r.cancelCtx = context.WithCancel(context.Background()) + + // Register a watch on the xdsClient for the user's dial target. cancelWatch := r.client.WatchService(r.target.Endpoint, r.handleServiceUpdate) r.logger.Infof("Watch started on resource name %v with xds-client %p", r.target.Endpoint, r.client) r.cancelWatch = func() { @@ -145,10 +147,9 @@ type suWithError struct { // (which performs LDS/RDS queries for the same), and passes the received // updates to the ClientConn. type xdsResolver struct { - ctx context.Context - cancelCtx context.CancelFunc - target resolver.Target - cc resolver.ClientConn + target resolver.Target + cc resolver.ClientConn + closed *grpcsync.Event logger *grpclog.PrefixLogger @@ -176,7 +177,8 @@ type xdsResolver struct { func (r *xdsResolver) run() { for { select { - case <-r.ctx.Done(): + case <-r.closed.Done(): + return case update := <-r.updateCh: if update.err != nil { r.logger.Warningf("Watch error on resource %v from xds-client %p, %v", r.target.Endpoint, r.client, update.err) @@ -214,7 +216,7 @@ func (r *xdsResolver) run() { // the received update to the update channel, which is picked by the run // goroutine. func (r *xdsResolver) handleServiceUpdate(su xdsclient.ServiceUpdate, err error) { - if r.ctx.Err() != nil { + if r.closed.HasFired() { // Do not pass updates to the ClientConn once the resolver is closed. return } @@ -228,6 +230,6 @@ func (*xdsResolver) ResolveNow(o resolver.ResolveNowOptions) {} func (r *xdsResolver) Close() { r.cancelWatch() r.client.Close() - r.cancelCtx() + r.closed.Fire() r.logger.Infof("Shutdown") } diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index 98c6d6a2e5b..9d04e3c9546 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" @@ -58,7 +59,15 @@ var ( target = resolver.Target{Endpoint: targetStr} ) -func TestRegister(t *testing.T) { +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestRegister(t *testing.T) { b := resolver.Get(xdsScheme) if b == nil { t.Errorf("scheme %v is not registered", xdsScheme) @@ -119,7 +128,7 @@ func errorDialer(_ context.Context, _ string) (net.Conn, error) { // TestResolverBuilder tests the xdsResolverBuilder's Build method with // different parameters. -func TestResolverBuilder(t *testing.T) { +func (s) TestResolverBuilder(t *testing.T) { tests := []struct { name string rbo resolver.BuildOptions @@ -262,7 +271,7 @@ func waitForWatchService(t *testing.T, xdsC *fakeclient.Client, wantTarget strin // TestXDSResolverWatchCallbackAfterClose tests the case where a service update // from the underlying xdsClient is received after the resolver is closed. -func TestXDSResolverWatchCallbackAfterClose(t *testing.T) { +func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ config: &validConfig, @@ -286,7 +295,7 @@ func TestXDSResolverWatchCallbackAfterClose(t *testing.T) { // TestXDSResolverBadServiceUpdate tests the case the xdsClient returns a bad // service update. -func TestXDSResolverBadServiceUpdate(t *testing.T) { +func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ config: &validConfig, @@ -313,7 +322,7 @@ func TestXDSResolverBadServiceUpdate(t *testing.T) { // TestXDSResolverGoodServiceUpdate tests the happy case where the resolver // gets a good service update from the xdsClient. -func TestXDSResolverGoodServiceUpdate(t *testing.T) { +func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ config: &validConfig, @@ -372,7 +381,7 @@ func TestXDSResolverGoodServiceUpdate(t *testing.T) { // TestXDSResolverUpdates tests the cases where the resolver gets a good update // after an error, and an error after the good update. -func TestXDSResolverGoodUpdateAfterError(t *testing.T) { +func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ config: &validConfig, @@ -423,7 +432,7 @@ func TestXDSResolverGoodUpdateAfterError(t *testing.T) { // TestXDSResolverResourceNotFoundError tests the cases where the resolver gets // a ResourceNotFoundError. It should generate a service config picking // weighted_target, but no child balancers. -func TestXDSResolverResourceNotFoundError(t *testing.T) { +func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ config: &validConfig,