diff --git a/broker/protocol/dispatcher.go b/broker/protocol/dispatcher.go index f38270a9..3e7cad3d 100644 --- a/broker/protocol/dispatcher.go +++ b/broker/protocol/dispatcher.go @@ -110,7 +110,7 @@ func (d *dispatcher) UpdateClientConnState(_ balancer.ClientConnState) error { // implements its own resolution and selection of an appropriate A record. func (d *dispatcher) ResolverError(_ error) {} -func (d *dispatcher) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { +func (d *dispatcher) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { d.mu.Lock() var id, ok = d.connID[sc] if !ok { @@ -152,6 +152,10 @@ func (d *dispatcher) UpdateSubConnState(sc balancer.SubConn, state balancer.SubC }) } +func (d *dispatcher) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + d.updateSubConnState(sc, state) +} + // markedSubConn tracks the last mark associated with a SubConn. // SubConns not used for a complete sweep interval are closed. type markedSubConn struct { @@ -187,7 +191,11 @@ func (d *dispatcher) Pick(info balancer.PickInfo) (balancer.PickResult, error) { []resolver.Address{{ Addr: d.idToAddr(dr.route, dispatchID), }}, - balancer.NewSubConnOptions{}, + balancer.NewSubConnOptions{ + StateListener: func(state balancer.SubConnState) { + d.updateSubConnState(msc.subConn, state) + }, + }, ); err != nil { return balancer.PickResult{}, err } diff --git a/broker/protocol/dispatcher_test.go b/broker/protocol/dispatcher_test.go index 29e8280e..4c96baed 100644 --- a/broker/protocol/dispatcher_test.go +++ b/broker/protocol/dispatcher_test.go @@ -45,6 +45,7 @@ func (s *DispatcherSuite) TestContextAdapters(c *gc.C) { func (s *DispatcherSuite) TestDispatchCases(c *gc.C) { var cc mockClientConn var disp = dispatcherBuilder{zone: "local"}.Build(&cc, balancer.BuildOptions{}).(*dispatcher) + cc.disp = disp close(disp.sweepDoneCh) // Disable async sweeping. // Case: Called without a dispatchRoute. Expect it panics. @@ -58,16 +59,16 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) { // SubConn to the default service address is started. var _, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"default.addr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name:"default.addr", disp: disp}}) cc.created = nil // Case: Default connection transitions to Ready. Expect it's now returned. - disp.UpdateSubConnState(mockSubConn("default.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name: "default.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) result, err := disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.IsNil) c.Check(result.Done, gc.IsNil) - c.Check(result.SubConn, gc.Equals, mockSubConn("default.addr")) + c.Check(result.SubConn, gc.Equals, mockSubConn{Name: "default.addr", disp: disp}) // Case: Specific remote peer is dispatched to. ctx = WithDispatchRoute(context.Background(), @@ -75,58 +76,58 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) { result, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"remote.addr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "remote.addr", disp: disp}}) cc.created = nil - disp.UpdateSubConnState(mockSubConn("remote.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name:"remote.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) result, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.IsNil) c.Check(result.Done, gc.IsNil) - c.Check(result.SubConn, gc.Equals, mockSubConn("remote.addr")) + c.Check(result.SubConn, gc.Equals, mockSubConn{Name: "remote.addr", disp: disp }) // Case: Route allows for multiple members. A local one is now dialed. ctx = WithDispatchRoute(context.Background(), buildRouteFixture(), ProcessSpec_ID{}) _, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"local.addr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name:"local.addr", disp: disp}}) cc.created = nil - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name:"local.addr",disp:disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) result, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.IsNil) c.Check(result.Done, gc.IsNil) - c.Check(result.SubConn, gc.Equals, mockSubConn("local.addr")) + c.Check(result.SubConn, gc.Equals, mockSubConn{Name:"local.addr", disp: disp}) // Case: One local addr is marked as failed. Another is dialed. - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + mockSubConn{Name: "local.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) _, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"local.otherAddr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "local.otherAddr", disp: disp}}) cc.created = nil - disp.UpdateSubConnState(mockSubConn("local.otherAddr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name: "local.otherAddr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) result, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.IsNil) c.Check(result.Done, gc.IsNil) - c.Check(result.SubConn, gc.Equals, mockSubConn("local.otherAddr")) + c.Check(result.SubConn, gc.Equals, mockSubConn{Name:"local.otherAddr", disp: disp}) // Case: otherAddr is also failed. Expect that an error is returned, // rather than dispatch to remote addr. (Eg we prefer to wait for a // local replica to recover or the route to change, vs using a remote // endpoint which incurs more networking cost). - disp.UpdateSubConnState(mockSubConn("local.otherAddr"), balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + mockSubConn{Name: "local.otherAddr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) _, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.Equals, balancer.ErrTransientFailure) // Case: local.addr is Ready again. However, primary is required and has failed. - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) - disp.UpdateSubConnState(mockSubConn("remote.addr"), balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + mockSubConn{Name: "local.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name: "remote.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) ctx = WithDispatchRoute(context.Background(), buildRouteFixture(), ProcessSpec_ID{Zone: "remote", Suffix: "primary"}) @@ -150,7 +151,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) { result, err = disp.Pick(balancer.PickInfo{Ctx: ctx}) c.Check(err, gc.IsNil) c.Check(result.Done, gc.NotNil) - c.Check(result.SubConn, gc.Equals, mockSubConn("local.addr")) + c.Check(result.SubConn, gc.Equals, mockSubConn{Name:"local.addr", disp: disp}) // Closure callback with an Unavailable error (only) will trigger an invalidation. result.Done(balancer.DoneInfo{Err: nil}) @@ -164,6 +165,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) { func (s *DispatcherSuite) TestDispatchMarkAndSweep(c *gc.C) { var cc mockClientConn var disp = dispatcherBuilder{zone: "local"}.Build(&cc, balancer.BuildOptions{}).(*dispatcher) + cc.disp = disp defer disp.Close() var err error @@ -177,11 +179,11 @@ func (s *DispatcherSuite) TestDispatchMarkAndSweep(c *gc.C) { _, err = disp.Pick(balancer.PickInfo{Ctx: localCtx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"remote.addr", "local.addr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "remote.addr", disp: disp}, mockSubConn{Name: "local.addr", disp: disp}}) cc.created = nil - disp.UpdateSubConnState(mockSubConn("remote.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + mockSubConn{Name: "remote.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name: "local.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting}) disp.sweep() c.Check(cc.removed, gc.IsNil) @@ -205,14 +207,14 @@ func (s *DispatcherSuite) TestDispatchMarkAndSweep(c *gc.C) { // This time, expect that local.addr is swept. disp.sweep() - c.Check(cc.removed, gc.DeepEquals, []mockSubConn{"local.addr"}) + c.Check(cc.removed, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "local.addr", disp: disp }}) cc.removed = nil - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + mockSubConn{Name: "local.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) disp.sweep() // Now remote.addr is swept. - c.Check(cc.removed, gc.DeepEquals, []mockSubConn{"remote.addr"}) + c.Check(cc.removed, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "remote.addr", disp: disp}}) cc.removed = nil - disp.UpdateSubConnState(mockSubConn("remote.addr"), balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + mockSubConn{Name: "remote.addr", disp: disp }.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) // No connections remain. c.Check(disp.idConn, gc.HasLen, 0) @@ -223,10 +225,10 @@ func (s *DispatcherSuite) TestDispatchMarkAndSweep(c *gc.C) { _, err = disp.Pick(balancer.PickInfo{Ctx: localCtx}) c.Check(err, gc.Equals, balancer.ErrNoSubConnAvailable) - c.Check(cc.created, gc.DeepEquals, []mockSubConn{"local.addr"}) + c.Check(cc.created, gc.DeepEquals, []mockSubConn{mockSubConn{Name: "local.addr", disp: disp}}) cc.created = nil - disp.UpdateSubConnState(mockSubConn("local.addr"), balancer.SubConnState{ConnectivityState: connectivity.Ready}) + mockSubConn{Name: "local.addr", disp: disp}.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) _, err = disp.Pick(balancer.PickInfo{Ctx: localCtx}) c.Check(err, gc.IsNil) } @@ -235,31 +237,42 @@ type mockClientConn struct { err error created []mockSubConn removed []mockSubConn + disp *dispatcher } -type mockSubConn string +type mockSubConn struct { + Name string + disp *dispatcher +} + +func (s1 mockSubConn) Equal(s2 mockSubConn) bool { + return s1.Name == s2.Name +} -func (s mockSubConn) UpdateAddresses([]resolver.Address) {} +func (s mockSubConn) UpdateAddresses([]resolver.Address) { panic("deprecated") } +func (s mockSubConn) UpdateState(state balancer.SubConnState) { s.disp.updateSubConnState(s, state) } func (s mockSubConn) Connect() {} func (s mockSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) { return nil, func() {} } -func (s mockSubConn) Shutdown() {} +func (s mockSubConn) Shutdown() { + var c = s.disp.cc.(*mockClientConn) + c.removed = append(c.removed, s) +} func (c *mockClientConn) NewSubConn(a []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { - var sc = mockSubConn(a[0].Addr) + var sc = mockSubConn{Name: a[0].Addr, disp: c.disp} c.created = append(c.created, sc) return sc, c.err } -func (c *mockClientConn) RemoveSubConn(sc balancer.SubConn) { - c.removed = append(c.removed, sc.(mockSubConn)) -} - func (c *mockClientConn) UpdateAddresses(balancer.SubConn, []resolver.Address) {} func (c *mockClientConn) UpdateState(balancer.State) {} func (c *mockClientConn) ResolveNow(resolver.ResolveNowOptions) {} func (c *mockClientConn) Target() string { return "default.addr" } +func (c *mockClientConn) RemoveSubConn(balancer.SubConn) { + panic("deprecated") +} type mockRouter struct{ invalidated string }