From 9fa86f7da1542a3d8fea4435bb8c53086130e6b9 Mon Sep 17 00:00:00 2001 From: Nick Turner Date: Thu, 16 Nov 2023 16:53:11 -0800 Subject: [PATCH] Fix federated user ID parsing * refactor the parsing logic to make it easier --- pkg/arn/arn.go | 48 +++++++---- pkg/arn/arn_test.go | 2 +- pkg/mapper/crd/controller/controller.go | 2 +- pkg/mapper/dynamicfile/dynamicfile.go | 4 +- pkg/mapper/file/mapper.go | 7 +- pkg/token/token.go | 43 ++++++---- pkg/token/token_test.go | 106 +++++++++++++++++++++++- 7 files changed, 172 insertions(+), 40 deletions(-) diff --git a/pkg/arn/arn.go b/pkg/arn/arn.go index 55c936862..609b789ba 100644 --- a/pkg/arn/arn.go +++ b/pkg/arn/arn.go @@ -8,23 +8,35 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" ) +type PrincipalType int + +const ( + // Supported principals + NONE PrincipalType = iota + ROLE + USER + ROOT + FEDERATED_USER + ASSUMED_ROLE +) + // Canonicalize validates IAM resources are appropriate for the authenticator // and converts STS assumed roles into the IAM role resource. // // Supported IAM resources are: -// * AWS account: arn:aws:iam::123456789012:root -// * IAM user: arn:aws:iam::123456789012:user/Bob -// * IAM role: arn:aws:iam::123456789012:role/S3Access -// * IAM Assumed role: arn:aws:sts::123456789012:assumed-role/Accounting-Role/Mary (converted to IAM role) -// * Federated user: arn:aws:sts::123456789012:federated-user/Bob -func Canonicalize(arn string) (string, error) { +// - AWS root user: arn:aws:iam::123456789012:root +// - IAM user: arn:aws:iam::123456789012:user/Bob +// - IAM role: arn:aws:iam::123456789012:role/S3Access +// - IAM Assumed role: arn:aws:sts::123456789012:assumed-role/Accounting-Role/Mary (converted to IAM role) +// - Federated user: arn:aws:sts::123456789012:federated-user/Bob +func Canonicalize(arn string) (PrincipalType, string, error) { parsed, err := awsarn.Parse(arn) if err != nil { - return "", fmt.Errorf("arn '%s' is invalid: '%v'", arn, err) + return NONE, "", fmt.Errorf("arn '%s' is invalid: '%v'", arn, err) } if err := checkPartition(parsed.Partition); err != nil { - return "", fmt.Errorf("arn '%s' does not have a recognized partition", arn) + return NONE, "", fmt.Errorf("arn '%s' does not have a recognized partition", arn) } parts := strings.Split(parsed.Resource, "/") @@ -34,27 +46,31 @@ func Canonicalize(arn string) (string, error) { case "sts": switch resource { case "federated-user": - return arn, nil + return FEDERATED_USER, arn, nil case "assumed-role": if len(parts) < 3 { - return "", fmt.Errorf("assumed-role arn '%s' does not have a role", arn) + return NONE, "", fmt.Errorf("assumed-role arn '%s' does not have a role", arn) } // IAM ARNs can contain paths, part[0] is resource, parts[len(parts)] is the SessionName. role := strings.Join(parts[1:len(parts)-1], "/") - return fmt.Sprintf("arn:%s:iam::%s:role/%s", parsed.Partition, parsed.AccountID, role), nil + return ASSUMED_ROLE, fmt.Sprintf("arn:%s:iam::%s:role/%s", parsed.Partition, parsed.AccountID, role), nil default: - return "", fmt.Errorf("unrecognized resource %s for service sts", parsed.Resource) + return NONE, "", fmt.Errorf("unrecognized resource %s for service sts", parsed.Resource) } case "iam": switch resource { - case "role", "user", "root": - return arn, nil + case "role": + return ROLE, arn, nil + case "user": + return USER, arn, nil + case "root": + return ROOT, arn, nil default: - return "", fmt.Errorf("unrecognized resource %s for service iam", parsed.Resource) + return NONE, "", fmt.Errorf("unrecognized resource %s for service iam", parsed.Resource) } } - return "", fmt.Errorf("service %s in arn %s is not a valid service for identities", parsed.Service, arn) + return NONE, "", fmt.Errorf("service %s in arn %s is not a valid service for identities", parsed.Service, arn) } func checkPartition(partition string) error { diff --git a/pkg/arn/arn_test.go b/pkg/arn/arn_test.go index 7558ee9d0..708b3e943 100644 --- a/pkg/arn/arn_test.go +++ b/pkg/arn/arn_test.go @@ -23,7 +23,7 @@ var arnTests = []struct { func TestUserARN(t *testing.T) { for _, tc := range arnTests { - actual, err := Canonicalize(tc.arn) + _, actual, err := Canonicalize(tc.arn) if err != nil && tc.err == nil || err == nil && tc.err != nil { t.Errorf("Canoncialize(%s) expected err: %v, actual err: %v", tc.arn, tc.err, err) continue diff --git a/pkg/mapper/crd/controller/controller.go b/pkg/mapper/crd/controller/controller.go index 433d1ebd6..c35103025 100644 --- a/pkg/mapper/crd/controller/controller.go +++ b/pkg/mapper/crd/controller/controller.go @@ -207,7 +207,7 @@ func (c *Controller) syncHandler(key string) (err error) { if iamIdentityMapping.Spec.ARN != "" { iamIdentityMappingCopy := iamIdentityMapping.DeepCopy() - canonicalizedARN, err := arn.Canonicalize(strings.ToLower(iamIdentityMapping.Spec.ARN)) + _, canonicalizedARN, err := arn.Canonicalize(strings.ToLower(iamIdentityMapping.Spec.ARN)) if err != nil { return err } diff --git a/pkg/mapper/dynamicfile/dynamicfile.go b/pkg/mapper/dynamicfile/dynamicfile.go index b6081c9b5..c63dc9261 100644 --- a/pkg/mapper/dynamicfile/dynamicfile.go +++ b/pkg/mapper/dynamicfile/dynamicfile.go @@ -63,14 +63,14 @@ func (ms *DynamicFileMapStore) saveMap( ms.awsAccounts = make(map[string]interface{}) for _, user := range userMappings { - key, _ := arn.Canonicalize(strings.ToLower(user.UserARN)) + _, key, _ := arn.Canonicalize(strings.ToLower(user.UserARN)) if ms.userIDStrict { key = user.UserId } ms.users[key] = user } for _, role := range roleMappings { - key, _ := arn.Canonicalize(strings.ToLower(role.RoleARN)) + _, key, _ := arn.Canonicalize(strings.ToLower(role.RoleARN)) if ms.userIDStrict { key = role.UserId } diff --git a/pkg/mapper/file/mapper.go b/pkg/mapper/file/mapper.go index 5a181ff97..6fb5adb3c 100644 --- a/pkg/mapper/file/mapper.go +++ b/pkg/mapper/file/mapper.go @@ -2,9 +2,10 @@ package file import ( "fmt" - "sigs.k8s.io/aws-iam-authenticator/pkg/token" "strings" + "sigs.k8s.io/aws-iam-authenticator/pkg/token" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/mapper" @@ -32,7 +33,7 @@ func NewFileMapper(cfg config.Config) (*FileMapper, error) { return nil, err } if m.RoleARN != "" { - canonicalizedARN, err := arn.Canonicalize(m.RoleARN) + _, canonicalizedARN, err := arn.Canonicalize(m.RoleARN) if err != nil { return nil, err } @@ -47,7 +48,7 @@ func NewFileMapper(cfg config.Config) (*FileMapper, error) { } var key string if m.UserARN != "" { - canonicalizedARN, err := arn.Canonicalize(strings.ToLower(m.UserARN)) + _, canonicalizedARN, err := arn.Canonicalize(strings.ToLower(m.UserARN)) if err != nil { return nil, fmt.Errorf("error canonicalizing ARN: %v", err) } diff --git a/pkg/token/token.go b/pkg/token/token.go index 99b286074..fff02ec89 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -600,29 +600,42 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { return nil, NewSTSError(err.Error()) } - // parse the response into an Identity id := &Identity{ - ARN: callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.Arn, - AccountID: callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.Account, AccessKeyID: accessKeyID, } - id.CanonicalARN, err = arn.Canonicalize(id.ARN) + return getIdentityFromSTSResponse(id, callerIdentity) +} + +func getIdentityFromSTSResponse(id *Identity, wrapper getCallerIdentityWrapper) (*Identity, error) { + var err error + result := wrapper.GetCallerIdentityResponse.GetCallerIdentityResult + + id.ARN = result.Arn + id.AccountID = result.Account + + var principalType arn.PrincipalType + principalType, id.CanonicalARN, err = arn.Canonicalize(id.ARN) if err != nil { return nil, NewSTSError(err.Error()) } - // The user ID is either UserID:SessionName (for assumed roles) or just - // UserID (for IAM User principals). - userIDParts := strings.Split(callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID, ":") - if len(userIDParts) == 2 { - id.UserID = userIDParts[0] - id.SessionName = userIDParts[1] - } else if len(userIDParts) == 1 { - id.UserID = userIDParts[0] + // The user ID is one of: + // 1. UserID:SessionName (for assumed roles) + // 2. UserID (for IAM User principals). + // 3. AWSAccount:CallerSpecifiedName (for federated users) + // We want the entire UserID for federated users because otherwise, + // its just the account ID and is indistinguishable from the UserID + // of the root user. + if principalType == arn.FEDERATED_USER || principalType == arn.USER || principalType == arn.ROOT { + id.UserID = result.UserID } else { - return nil, STSError{fmt.Sprintf( - "malformed UserID %q", - callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)} + userIDParts := strings.Split(result.UserID, ":") + if len(userIDParts) == 2 { + id.UserID = userIDParts[0] + id.SessionName = userIDParts[1] + } else { + return nil, NewSTSError(fmt.Sprintf("malformed UserID %q", result.UserID)) + } } return id, nil diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index c85e7a157..11df461f7 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" @@ -278,7 +279,7 @@ func TestVerifyInvalidCanonicalARNError(t *testing.T) { } func TestVerifyInvalidUserIDError(t *testing.T) { - _, err := newVerifier("aws", 200, jsonResponse("arn:aws:iam::123456789012:user/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken) + _, err := newVerifier("aws", 200, jsonResponse("arn:aws:iam::123456789012:role/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken) errorContains(t, err, "malformed UserID") assertSTSError(t, err) } @@ -307,7 +308,7 @@ func TestVerifyNoSession(t *testing.T) { } func TestVerifySessionName(t *testing.T) { - arn := "arn:aws:iam::123456789012:user/Alice" + arn := "arn:aws:iam::123456789012:role/Alice" account := "123456789012" userID := "Alice" session := "session-name" @@ -413,3 +414,104 @@ func TestFormatJson(t *testing.T) { }) } } + +func TestGetIdentityFromSTSResponse(t *testing.T) { + var ( + accessKeyID = "AKIAVVVVVVVVVVVAGAVA" + defaultID = Identity{ + AccessKeyID: accessKeyID, + } + defaultAccount = "123456789012" + rootUserARN = "arn:aws:iam::123456789012:root" + userARN = "arn:aws:iam::123456789012:user/Alice" + userID = "AIDAIYCCCMMMMMMMMGGDA" + fedUserID = "123456789012:Alice" + fedUserARN = "arn:aws:sts::123456789012:federated-user/Alice" + roleARN = "arn:aws:iam::123456789012:role/Alice" + roleID = "AROAZZCCCNNNNNNNNFFFA" + ) + + cases := []struct { + name string + inputID Identity + inputResponse getCallerIdentityWrapper + expectedErr bool + want Identity + }{ + { + name: "Root User", + inputID: defaultID, + inputResponse: response(defaultAccount, defaultAccount, rootUserARN), + expectedErr: false, + want: Identity{ + ARN: rootUserARN, + CanonicalARN: rootUserARN, + AccountID: defaultAccount, + UserID: defaultAccount, + AccessKeyID: accessKeyID, + }, + }, + { + name: "User", + inputID: defaultID, + inputResponse: response(defaultAccount, userID, userARN), + expectedErr: false, + want: Identity{ + ARN: userARN, + CanonicalARN: userARN, + AccountID: defaultAccount, + UserID: userID, + AccessKeyID: accessKeyID, + }, + }, + { + name: "Role", + inputID: defaultID, + inputResponse: response(defaultAccount, roleID, roleARN), + expectedErr: false, + want: Identity{ + ARN: roleARN, + CanonicalARN: roleARN, + AccountID: defaultAccount, + UserID: roleID, + AccessKeyID: accessKeyID, + }, + }, + { + name: "Federated User", + inputID: defaultID, + inputResponse: response(defaultAccount, fedUserID, fedUserARN), + expectedErr: false, + want: Identity{ + ARN: fedUserARN, + CanonicalARN: fedUserARN, + AccountID: defaultAccount, + UserID: fedUserID, + AccessKeyID: accessKeyID, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + + if got, err := getIdentityFromSTSResponse(&c.inputID, c.inputResponse); err == nil { + if c.expectedErr { + t.Errorf("expected err to be nil but was %s", err) + } + + if diff := cmp.Diff(c.want, *got); diff != "" { + t.Errorf("getIdentityFromSTSResponse() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} + +func response(account, userID, arn string) getCallerIdentityWrapper { + wrapper := getCallerIdentityWrapper{} + wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.Account = account + wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.Arn = arn + wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.UserID = userID + wrapper.GetCallerIdentityResponse.ResponseMetadata.RequestID = "id1234" + return wrapper +}