Fix connectivity state transitions when dialing (grpc#1596)
menghanl committed Oct 25, 2017
1 parent f164acd commit e0ce3fd
Showing 4 changed files with 53 additions and 59 deletions.
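Note on the change: resetTransport no longer forces the subchannel into TRANSIENT_FAILURE at the top of every attempt (the removed block in clientconn.go), so a fresh dial no longer detours through TRANSIENT_FAILURE before CONNECTING, and transportMonitor now reports TRANSIENT_FAILURE itself before calling resetTransport, since READY -> CONNECTING is not a valid transition (per the comment added in the diff). The following is a minimal sketch of the transition rules involved; the table is a simplified assumption for illustration, not code taken from grpc-go or from this commit.

package example

import (
	"fmt"

	"google.golang.org/grpc/connectivity"
)

// validNext is a simplified, assumed model of the allowed subchannel
// transitions; it is illustrative only.
var validNext = map[connectivity.State][]connectivity.State{
	connectivity.Idle:             {connectivity.Connecting, connectivity.Shutdown},
	connectivity.Connecting:       {connectivity.Ready, connectivity.TransientFailure, connectivity.Shutdown},
	connectivity.Ready:            {connectivity.Idle, connectivity.TransientFailure, connectivity.Shutdown},
	connectivity.TransientFailure: {connectivity.Connecting, connectivity.Shutdown},
}

func canMove(from, to connectivity.State) bool {
	for _, s := range validNext[from] {
		if s == to {
			return true
		}
	}
	return false
}

func demoTransitions() {
	// Why transportMonitor now sets TRANSIENT_FAILURE before resetTransport:
	fmt.Println(canMove(connectivity.Ready, connectivity.Connecting))            // false: not a valid transition
	fmt.Println(canMove(connectivity.Ready, connectivity.TransientFailure))      // true
	fmt.Println(canMove(connectivity.TransientFailure, connectivity.Connecting)) // true
	// Why resetTransport no longer forces TRANSIENT_FAILURE itself: a fresh
	// dial can go straight from IDLE to CONNECTING.
	fmt.Println(canMove(connectivity.Idle, connectivity.Connecting)) // true
}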
38 changes: 19 additions & 19 deletions balancer/roundrobin/roundrobin_test.go
@@ -108,13 +108,13 @@ func TestOneBackend(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
}
@@ -140,7 +140,7 @@ func TestBackendsRoundRobin(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

@@ -155,7 +155,7 @@ func TestBackendsRoundRobin(t *testing.T) {
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() == test.addresses[si] {
@@ -170,7 +170,7 @@ func TestBackendsRoundRobin(t *testing.T) {
}

for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
@@ -199,13 +199,13 @@ func TestAddressesRemoved(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

@@ -245,7 +245,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done()
// This RPC blocks until cc is closed.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) == codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded {
t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
}
cancel()
@@ -275,15 +275,15 @@ func TestNewAddressWhileBlocking(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
}

@@ -295,7 +295,7 @@ func TestNewAddressWhileBlocking(t *testing.T) {
go func() {
defer wg.Done()
// This RPC blocks until NewAddress is called.
testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false))
testc.EmptyCall(context.Background(), &testpb.Empty{})
}()
}
time.Sleep(50 * time.Millisecond)
@@ -324,7 +324,7 @@ func TestOneServerDown(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

@@ -339,7 +339,7 @@ func TestOneServerDown(t *testing.T) {
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() == test.addresses[si] {
@@ -354,7 +354,7 @@ func TestOneServerDown(t *testing.T) {
}

for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
@@ -368,7 +368,7 @@ func TestOneServerDown(t *testing.T) {
// Loop until see server[backendCount-1] twice without seeing server[backendCount].
var targetSeen int
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
switch p.Addr.String() {
@@ -387,7 +387,7 @@ func TestOneServerDown(t *testing.T) {
t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
@@ -417,7 +417,7 @@ func TestAllServersDown(t *testing.T) {
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}

@@ -432,7 +432,7 @@ func TestAllServersDown(t *testing.T) {
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() == test.addresses[si] {
@@ -447,7 +447,7 @@ func TestAllServersDown(t *testing.T) {
}

for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
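All of the call sites above also drop the explicit grpc.FailFast(false) option. For reference, a minimal sketch of what that option requests at a call site, assuming an already-dialed *grpc.ClientConn whose resolver has not reported any address yet; client and import names follow the test file, and this sketch is not part of the commit.

package example

import (
	"context"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	testpb "google.golang.org/grpc/test/grpc_testing"
)

// waitForReadyCall: grpc.FailFast(false) (wait-for-ready) asks the RPC to queue
// until the channel is READY instead of failing immediately while no backend
// address is available.
func waitForReadyCall(cc *grpc.ClientConn) bool {
	testc := testpb.NewTestServiceClient(cc)
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	_, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false))
	// With wait-for-ready and no address, the call blocks until the deadline expires.
	return grpc.Code(err) == codes.DeadlineExceeded
}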
4 changes: 2 additions & 2 deletions balancer_conn_wrappers.go
@@ -230,15 +230,15 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac = ac
ac.acbw = acbw
if acState != connectivity.Idle {
ac.connect(false)
ac.connect()
}
}
}

func (acbw *acBalancerWrapper) Connect() {
acbw.mu.Lock()
defer acbw.mu.Unlock()
acbw.ac.connect(false)
acbw.ac.connect()
}

func (acbw *acBalancerWrapper) getAddrConn() *addrConn {
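The wrapper now always calls ac.connect() with no blocking flag. From the balancer's side the contract looks unchanged: SubConn.Connect returns immediately and state updates arrive through the balancer callback. A rough sketch of that contract against the balancer API of this period follows; the type and helper names are assumptions for illustration, not code from the commit.

package example

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/resolver"
)

// sketchBalancer outlines the path acBalancerWrapper.Connect serves.
type sketchBalancer struct {
	cc balancer.ClientConn
	sc balancer.SubConn
}

func (b *sketchBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
	if err != nil || len(addrs) == 0 {
		return
	}
	sc, err := b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
	if err != nil {
		return
	}
	b.sc = sc
	sc.Connect() // returns immediately; ends up in acBalancerWrapper.Connect -> ac.connect()
}

func (b *sketchBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
	// CONNECTING / READY / TRANSIENT_FAILURE updates arrive here asynchronously.
}

func (b *sketchBalancer) Close() {}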
66 changes: 30 additions & 36 deletions clientconn.go
@@ -457,7 +457,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
waitC <- err
return
}
if err := ac.connect(cc.dopts.block); err != nil {
if err := ac.connect(); err != nil {
waitC <- err
return
}
@@ -659,7 +659,7 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) {
// It does nothing if the ac is not IDLE.
// TODO(bar) Move this to the addrConn section.
// This was part of resetAddrConn, keep it here to make the diff look clean.
func (ac *addrConn) connect(block bool) error {
func (ac *addrConn) connect() error {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
@@ -677,32 +677,18 @@ func (ac *addrConn) connect(block bool) error {
}
ac.mu.Unlock()

if block {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := ac.resetTransport(); err != nil {
grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err
return
}
// Start to monitor the error status of transport.
go ac.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := ac.resetTransport(); err != nil {
grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
ac.transportMonitor()
}()
}
ac.transportMonitor()
}()
return nil
}

@@ -869,19 +855,14 @@ func (ac *addrConn) errorf(format string, a ...interface{}) {

// resetTransport recreates a transport to the address for ac. The old
// transport will close itself on error or when the clientconn is closed.
//
// TODO(bar) make sure all state transitions are valid.
func (ac *addrConn) resetTransport() error {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return errConnClosing
}
ac.state = connectivity.TransientFailure
if ac.cc.balancerWrapper != nil {
ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
} else {
ac.cc.csMgr.updateState(ac.state)
}
if ac.ready != nil {
close(ac.ready)
ac.ready = nil
@@ -905,12 +886,14 @@ func (ac *addrConn) resetTransport() error {
return errConnClosing
}
ac.printf("connecting")
ac.state = connectivity.Connecting
// TODO(bar) remove condition once we always have a balancer.
if ac.cc.balancerWrapper != nil {
ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
} else {
ac.cc.csMgr.updateState(ac.state)
if ac.state != connectivity.Connecting {
ac.state = connectivity.Connecting
// TODO(bar) remove condition once we always have a balancer.
if ac.cc.balancerWrapper != nil {
ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
} else {
ac.cc.csMgr.updateState(ac.state)
}
}
// copy ac.addrs in case of race
addrsIter := make([]resolver.Address, len(ac.addrs))
@@ -1013,6 +996,17 @@ func (ac *addrConn) transportMonitor() {
ac.adjustParams(t.GetGoAwayReason())
default:
}
ac.mu.Lock()
// Set connectivity state to TransientFailure before calling
// resetTransport. Transition READY->CONNECTING is not valid.
ac.state = connectivity.TransientFailure
if ac.cc.balancerWrapper != nil {
ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
} else {
ac.cc.csMgr.updateState(ac.state)
}
ac.curAddr = resolver.Address{}
ac.mu.Unlock()
if err := ac.resetTransport(); err != nil {
ac.mu.Lock()
ac.printf("transport exiting: %v", err)
@@ -1084,7 +1078,7 @@ func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) {
ac.mu.Unlock()
// Trigger idle ac to connect.
if idle {
ac.connect(false)
ac.connect()
}
return nil, false
}
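These are the transitions an application observes on the channel. A small sketch of watching them with the GetState / WaitForStateChange API (experimental at the time); the function name and usage are illustrative, not part of this commit.

package example

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
)

// watchChannelState logs each channel state. After this fix, a fresh dial is
// expected to read IDLE -> CONNECTING -> READY rather than detouring through
// TRANSIENT_FAILURE, and a broken READY transport reports TRANSIENT_FAILURE
// before the client reconnects.
func watchChannelState(ctx context.Context, cc *grpc.ClientConn) {
	for {
		s := cc.GetState()
		log.Printf("channel state: %v", s)
		if s == connectivity.Shutdown {
			return
		}
		if !cc.WaitForStateChange(ctx, s) {
			return // ctx expired or was canceled
		}
	}
}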
4 changes: 2 additions & 2 deletions pickfirst_test.go
@@ -128,7 +128,7 @@ func TestNewAddressWhileBlockingPickfirst(t *testing.T) {
go func() {
defer wg.Done()
// This RPC blocks until NewAddress is called.
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
}()
}
time.Sleep(50 * time.Millisecond)
@@ -165,7 +165,7 @@ func TestCloseWithPendingRPCPickfirst(t *testing.T) {
go func() {
defer wg.Done()
// This RPC blocks until NewAddress is called.
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
}()
}
time.Sleep(50 * time.Millisecond)
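Both test files drive the client through a manual resolver that starts with no addresses, so the first RPCs block until NewAddress is called. A sketch of that setup outside the test harness, assuming the resolver/manual helpers of this period; the target name "test.server" and the function name are illustrative.

package example

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/resolver/manual"
)

// dialWithManualResolver dials against a freshly registered manual resolver,
// then pushes one backend address so pending RPCs can proceed.
func dialWithManualResolver(backendAddr string) (*grpc.ClientConn, error) {
	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()

	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure())
	if err != nil {
		return nil, err
	}
	// Until this call the channel has no backend to connect to.
	r.NewAddress([]resolver.Address{{Addr: backendAddr}})
	return cc, nil
}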
