Skip to content

Commit

Permalink
fix(auth): default gRPC token type to Bearer if not set (#9800)
Browse files Browse the repository at this point in the history
As documented on auth.Token.Type, if the value of Type is "" it should be treated as a Bearer token. Added a similar helper method as we have in the httptransport package to default this.
  • Loading branch information
codyoss committed Apr 18, 2024
1 parent da245fa commit 5284066
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
2 changes: 1 addition & 1 deletion auth/grpctransport/dial_socketopt_test.go
Expand Up @@ -109,7 +109,7 @@ func TestDialWithDirectPathEnabled(t *testing.T) {

pool, err := Dial(ctx, true, &Options{
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: staticTP("hey"),
TokenProvider: &staticTP{tok: &auth.Token{Value: "hey"}},
}),
GRPCDialOpts: []grpc.DialOption{userDialer},
Endpoint: "example.google.com:443",
Expand Down
15 changes: 12 additions & 3 deletions auth/grpctransport/grpctransport.go
Expand Up @@ -287,15 +287,24 @@ func (c *grpcCredentialsProvider) GetRequestMetadata(ctx context.Context, uri ..
return nil, fmt.Errorf("unable to transfer credentials PerRPCCredentials: %v", err)
}
}
metadata := map[string]string{
"authorization": token.Type + " " + token.Value,
}
metadata := make(map[string]string, len(c.metadata)+1)
setAuthMetadata(token, metadata)
for k, v := range c.metadata {
metadata[k] = v
}
return metadata, nil
}

// setAuthMetadata uses the provided token to set the Authorization metadata.
// If the token.Type is empty, the type is assumed to be Bearer.
func setAuthMetadata(token *auth.Token, m map[string]string) {
typ := token.Type
if typ == "" {
typ = internal.TokenTypeBearer
}
m["authorization"] = typ + " " + token.Value
}

func (c *grpcCredentialsProvider) RequireTransportSecurity() bool {
return c.secure
}
Expand Down
51 changes: 45 additions & 6 deletions auth/grpctransport/grpctransport_test.go
Expand Up @@ -17,6 +17,7 @@ package grpctransport
import (
"context"
"errors"
"log"
"net"
"testing"

Expand Down Expand Up @@ -83,7 +84,7 @@ func TestDial_FailsValidation(t *testing.T) {
opts: &Options{
DisableAuthentication: true,
Credentials: auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: staticTP("fakeToken"),
TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
}),
},
},
Expand Down Expand Up @@ -272,6 +273,44 @@ func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) {
}
}

func TestGrpcCredentialsProvider_TokenType(t *testing.T) {
tests := []struct {
name string
tok *auth.Token
want string
}{
{
name: "type set",
tok: &auth.Token{
Value: "token",
Type: "Basic",
},
want: "Basic token",
},
{
name: "type set",
tok: &auth.Token{
Value: "token",
},
want: "Bearer token",
},
}
for _, tc := range tests {
cp := grpcCredentialsProvider{
creds: &auth.Credentials{
TokenProvider: &staticTP{tok: tc.tok},
},
}
m, err := cp.GetRequestMetadata(context.Background(), "")
if err != nil {
log.Fatalf("cp.GetRequestMetadata() = %v, want nil", err)
}
if got := m["authorization"]; got != tc.want {
t.Fatalf("got %q, want %q", got, tc.want)
}
}
}

func TestNewClient_DetectedServiceAccount(t *testing.T) {
testQuota := "testquota"
wantHeader := "bar"
Expand Down Expand Up @@ -329,12 +368,12 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) {
}
}

type staticTP string
type staticTP struct {
tok *auth.Token
}

func (tp staticTP) Token(context.Context) (*auth.Token, error) {
return &auth.Token{
Value: string(tp),
}, nil
func (tp *staticTP) Token(context.Context) (*auth.Token, error) {
return tp.tok, nil
}

type fakeEchoService struct {
Expand Down

0 comments on commit 5284066

Please sign in to comment.