diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index 7b953a520e5..c8a30753142 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -138,7 +138,7 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions { // and server options (server options struct does not exist now. When // caller can provide endpoints, it should be created. -// altsHandshaker is used to complete a ALTS handshaking between client and +// altsHandshaker is used to complete an ALTS handshake between client and // server. This handshaker talks to the ALTS handshaker service in the metadata // server. type altsHandshaker struct { @@ -146,6 +146,8 @@ type altsHandshaker struct { stream altsgrpc.HandshakerService_DoHandshakeClient // the connection to the peer. conn net.Conn + // a virtual connection to the ALTS handshaker service. + clientConn *grpc.ClientConn // client handshake options. clientOpts *ClientHandshakerOptions // server handshake options. @@ -154,39 +156,33 @@ type altsHandshaker struct { side core.Side } -// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC -// stub created using the passed conn and used to talk to the ALTS Handshaker +// NewClientHandshaker creates a core.Handshaker that performs a client-side +// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker // service in the metadata server. func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) { - stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx) - if err != nil { - return nil, err - } return &altsHandshaker{ - stream: stream, + stream: nil, conn: c, + clientConn: conn, clientOpts: opts, side: core.ClientSide, }, nil } -// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC -// stub created using the passed conn and used to talk to the ALTS Handshaker +// NewServerHandshaker creates a core.Handshaker that performs a server-side +// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker // service in the metadata server. func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) { - stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx) - if err != nil { - return nil, err - } return &altsHandshaker{ - stream: stream, + stream: nil, conn: c, + clientConn: conn, serverOpts: opts, side: core.ServerSide, }, nil } -// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once +// ClientHandshake starts and completes a client ALTS handshake for GCP. Once // done, ClientHandshake returns a secure connection. func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { if !acquire() { @@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker") } + // TODO(matthewstevenson88): Change unit tests to use public APIs so + // that h.stream can unconditionally be set based on h.clientConn. + if h.stream == nil { + stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err) + } + h.stream = stream + } + // Create target identities from service account list. targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts)) for _, account := range h.clientOpts.TargetServiceAccounts { @@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent return conn, authInfo, nil } -// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once +// ServerHandshake starts and completes a server ALTS handshake for GCP. Once // done, ServerHandshake returns a secure connection. func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { if !acquire() { @@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker") } + // TODO(matthewstevenson88): Change unit tests to use public APIs so + // that h.stream can unconditionally be set based on h.clientConn. + if h.stream == nil { + stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err) + } + h.stream = stream + } + p := make([]byte, frameLimit) n, err := h.conn.Read(p) if err != nil { diff --git a/credentials/alts/internal/handshaker/handshaker_test.go b/credentials/alts/internal/handshaker/handshaker_test.go index 14a0721054f..53aee642315 100644 --- a/credentials/alts/internal/handshaker/handshaker_test.go +++ b/credentials/alts/internal/handshaker/handshaker_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" grpc "google.golang.org/grpc" core "google.golang.org/grpc/credentials/alts/internal" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" @@ -283,3 +285,65 @@ func (s) TestPeerNotResponding(t *testing.T) { t.Errorf("ClientHandshake() = %v, want %v", got, want) } } + +func (s) TestNewClientHandshaker(t *testing.T) { + conn := testutil.NewTestConn(nil, nil) + clientConn := &grpc.ClientConn{} + opts := &ClientHandshakerOptions{} + hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts) + if err != nil { + t.Errorf("NewClientHandshaker returned unexpected error: %v", err) + } + expectedHs := &altsHandshaker{ + stream: nil, + conn: conn, + clientConn: clientConn, + clientOpts: opts, + serverOpts: nil, + side: core.ClientSide, + } + cmpOpts := []cmp.Option{ + cmp.AllowUnexported(altsHandshaker{}), + cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), + } + if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { + t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) + } + if hs.(*altsHandshaker).stream != nil { + t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream") + } + if hs.(*altsHandshaker).clientConn != clientConn { + t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn") + } +} + +func (s) TestNewServerHandshaker(t *testing.T) { + conn := testutil.NewTestConn(nil, nil) + clientConn := &grpc.ClientConn{} + opts := &ServerHandshakerOptions{} + hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts) + if err != nil { + t.Errorf("NewServerHandshaker returned unexpected error: %v", err) + } + expectedHs := &altsHandshaker{ + stream: nil, + conn: conn, + clientConn: clientConn, + clientOpts: nil, + serverOpts: opts, + side: core.ServerSide, + } + cmpOpts := []cmp.Option{ + cmp.AllowUnexported(altsHandshaker{}), + cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), + } + if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { + t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) + } + if hs.(*altsHandshaker).stream != nil { + t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream") + } + if hs.(*altsHandshaker).clientConn != clientConn { + t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn") + } +}