From 5b34e155f9b542569651996ec87ac8847ad40edc Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Fri, 4 Oct 2019 11:03:08 -0700 Subject: [PATCH] clientconn: override authority with address's ServerName, if set --- clientconn.go | 8 +++++++- test/end2end_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/clientconn.go b/clientconn.go index d4fadfd68bbf..43cc37ad751d 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1135,10 +1135,16 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne onCloseCalled := make(chan struct{}) reconnect := grpcsync.NewEvent() + authority := ac.cc.authority + // addr.ServerName takes precedent over ClientConn authority, if present. + if addr.ServerName != "" { + authority = addr.ServerName + } + target := transport.TargetInfo{ Addr: addr.Addr, Metadata: addr.Metadata, - Authority: ac.cc.authority, + Authority: authority, } once := sync.Once{} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3f..74a141fe621e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4968,6 +4968,49 @@ func (s) TestCredsHandshakeAuthority(t *testing.T) { } } +// This test makes sure that the authority client handshake gets is the endpoint +// of the ServerName of the address when it is set. +func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { + const testAuthority = "test.auth.ori.ty" + const testServerName = "test.server.name" + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + cred := &authorityCheckCreds{} + s := grpc.NewServer() + go s.Serve(lis) + defer s.Stop() + + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + // ctx got timeout or canceled. + t.Fatalf("ClientConn is not ready after 100 ms") + } + } + + if cred.got != testServerName { + t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) + } +} + type clientFailCreds struct{} func (c *clientFailCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {