From 5284066670b6fe65d79089cfe0199c9660f87fc7 Mon Sep 17 00:00:00 2001 From: Cody Oss <6331106+codyoss@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:34:18 -0500 Subject: [PATCH] fix(auth): default gRPC token type to Bearer if not set (#9800) As documented on auth.Token.Type, if the value of Type is "" it should be treated as a Bearer token. Added a similar helper method as we have in the httptransport package to default this. --- auth/grpctransport/dial_socketopt_test.go | 2 +- auth/grpctransport/grpctransport.go | 15 +++++-- auth/grpctransport/grpctransport_test.go | 51 ++++++++++++++++++++--- 3 files changed, 58 insertions(+), 10 deletions(-) diff --git a/auth/grpctransport/dial_socketopt_test.go b/auth/grpctransport/dial_socketopt_test.go index 9ee7053dd59..3ad3118c879 100644 --- a/auth/grpctransport/dial_socketopt_test.go +++ b/auth/grpctransport/dial_socketopt_test.go @@ -109,7 +109,7 @@ func TestDialWithDirectPathEnabled(t *testing.T) { pool, err := Dial(ctx, true, &Options{ Credentials: auth.NewCredentials(&auth.CredentialsOptions{ - TokenProvider: staticTP("hey"), + TokenProvider: &staticTP{tok: &auth.Token{Value: "hey"}}, }), GRPCDialOpts: []grpc.DialOption{userDialer}, Endpoint: "example.google.com:443", diff --git a/auth/grpctransport/grpctransport.go b/auth/grpctransport/grpctransport.go index 5cfa0a1fe03..06db948c42e 100644 --- a/auth/grpctransport/grpctransport.go +++ b/auth/grpctransport/grpctransport.go @@ -287,15 +287,24 @@ func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri .. return nil, fmt.Errorf("unable to transfer credentials PerRPCCredentials: %v", err) } } - metadata := map[string]string{ - "authorization": token.Type + " " + token.Value, - } + metadata := make(map[string]string, len(c.metadata)+1) + setAuthMetadata(token, metadata) for k, v := range c.metadata { metadata[k] = v } return metadata, nil } +// setAuthMetadata uses the provided token to set the Authorization metadata. +// If the token.Type is empty, the type is assumed to be Bearer. +func setAuthMetadata(token *auth.Token, m map[string]string) { + typ := token.Type + if typ == "" { + typ = internal.TokenTypeBearer + } + m["authorization"] = typ + " " + token.Value +} + func (c *grpcCredentialsProvider) RequireTransportSecurity() bool { return c.secure } diff --git a/auth/grpctransport/grpctransport_test.go b/auth/grpctransport/grpctransport_test.go index d60c3f6854f..46265f9deb9 100644 --- a/auth/grpctransport/grpctransport_test.go +++ b/auth/grpctransport/grpctransport_test.go @@ -17,6 +17,7 @@ package grpctransport import ( "context" "errors" + "log" "net" "testing" @@ -83,7 +84,7 @@ func TestDial_FailsValidation(t *testing.T) { opts: &Options{ DisableAuthentication: true, Credentials: auth.NewCredentials(&auth.CredentialsOptions{ - TokenProvider: staticTP("fakeToken"), + TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}}, }), }, }, @@ -272,6 +273,44 @@ func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) { } } +func TestGrpcCredentialsProvider_TokenType(t *testing.T) { + tests := []struct { + name string + tok *auth.Token + want string + }{ + { + name: "type set", + tok: &auth.Token{ + Value: "token", + Type: "Basic", + }, + want: "Basic token", + }, + { + name: "type set", + tok: &auth.Token{ + Value: "token", + }, + want: "Bearer token", + }, + } + for _, tc := range tests { + cp := grpcCredentialsProvider{ + creds: &auth.Credentials{ + TokenProvider: &staticTP{tok: tc.tok}, + }, + } + m, err := cp.GetRequestMetadata(context.Background(), "") + if err != nil { + log.Fatalf("cp.GetRequestMetadata() = %v, want nil", err) + } + if got := m["authorization"]; got != tc.want { + t.Fatalf("got %q, want %q", got, tc.want) + } + } +} + func TestNewClient_DetectedServiceAccount(t *testing.T) { testQuota := "testquota" wantHeader := "bar" @@ -329,12 +368,12 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) { } } -type staticTP string +type staticTP struct { + tok *auth.Token +} -func (tp staticTP) Token(context.Context) (*auth.Token, error) { - return &auth.Token{ - Value: string(tp), - }, nil +func (tp *staticTP) Token(context.Context) (*auth.Token, error) { + return tp.tok, nil } type fakeEchoService struct {