diff --git a/test/goaway_test.go b/test/goaway_test.go index 61936008f0e5..61446c1f7583 100644 --- a/test/goaway_test.go +++ b/test/goaway_test.go @@ -770,36 +770,20 @@ func (s) TestClientSendsAGoAway(t *testing.T) { if err != nil { t.Fatalf("error listening: %v", err) } - ctCh := testutils.NewChannel() + defer lis.Close() + goAwayReceived := make(chan struct{}) + errCh := make(chan error) go func() { conn, err := lis.Accept() if err != nil { t.Errorf("error in lis.Accept(): %v", err) } ct := newClientTester(t, conn) - ctCh.Send(ct) - }() - defer lis.Close() - - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - t.Fatalf("error dialing: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - val, err := ctCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout waiting for client transport (should be given after http2 creation)") - } - ct := val.(*clientTester) - goAwayReceived := make(chan struct{}) - errCh := make(chan error) - go func() { + defer ct.conn.Close() for { f, err := ct.fr.ReadFrame() if err != nil { + errCh <- fmt.Errorf("error reading frame: %v", err) return } switch fr := f.(type) { @@ -808,6 +792,7 @@ func (s) TestClientSendsAGoAway(t *testing.T) { if fr.ErrCode == http2.ErrCodeNo { t.Logf("GoAway received from client") close(goAwayReceived) + return } default: t.Errorf("server tester received unexpected frame type %T", f) @@ -816,8 +801,18 @@ func (s) TestClientSendsAGoAway(t *testing.T) { } } }() + + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + cc.Connect() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + testutils.AwaitState(ctx, t, cc, connectivity.Ready) cc.Close() - defer ct.conn.Close() select { case <-goAwayReceived: case err := <-errCh: