Skip to content

Commit

Permalink
feat: Use aws:username for IAM initiated federated console sessions. (#…
Browse files Browse the repository at this point in the history
…626)

* feat: Use aws:username for IAM initiated federated console sessions.

The changes refactor the way federation token ID is used for AWS IAM credentials.
Instead of relying on the userID which was previously parsed, the code now uses the
userName which is more easily attributable to the IAM user name in the Cloudtrail
events list view.

In the Cloudtrail console's event history view, the IAM user name will now display
in the `user name` column. Previously, the `user id` would display (e.g. AIDA.....).

Signed-off-by: Matthew Hembree <47449406+matthewhembree@users.noreply.github.com>

* add unit test around behaviour

---------

Signed-off-by: Matthew Hembree <47449406+matthewhembree@users.noreply.github.com>
Co-authored-by: chrnorm <17420369+chrnorm@users.noreply.github.com>
  • Loading branch information
matthewhembree and chrnorm committed Mar 29, 2024
1 parent 1c0922f commit 3bfb958
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 13 deletions.
59 changes: 46 additions & 13 deletions pkg/cfaws/assumer_aws_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
Expand Down Expand Up @@ -244,22 +245,14 @@ func getFederationToken(ctx context.Context, c *Profile) (aws.Credentials, error
if err != nil {
return aws.Credentials{}, err
}
name := *caller.UserId

// for an iam credential, the userID might be something like abcd:test@example.com or in might just be an id
// the idea here is to use the name portion as the federation token id
parts := strings.SplitN(*caller.UserId, ":", 2)
if len(parts) > 1 {
name = parts[1]
}
tags, userName := getSessionTags(caller)

// name is truncated to ensure it meets the maximum length requirements for the AWS api
out, err := client.GetFederationToken(ctx, &sts.GetFederationTokenInput{Name: aws.String(truncateString(name, 32)), Policy: aws.String(allowAllPolicy),
out, err := client.GetFederationToken(ctx, &sts.GetFederationTokenInput{Name: aws.String(truncateString(userName, 32)), Policy: aws.String(allowAllPolicy),
// tags are added to the federation token
Tags: []types.Tag{
{Key: aws.String("userID"), Value: caller.UserId},
{Key: aws.String("account"), Value: caller.Account},
{Key: aws.String("principalArn"), Value: caller.Arn},
}})
Tags: tags,
})
if err != nil {
return aws.Credentials{}, err
}
Expand Down Expand Up @@ -311,3 +304,43 @@ func truncateString(s string, length int) string {
}
return s[:length]
}

func getSessionTags(caller *sts.GetCallerIdentityOutput) (tags []types.Tag, userName string) {
if caller == nil {
return
}

tags = []types.Tag{
{Key: aws.String("userID"), Value: caller.UserId},
{Key: aws.String("account"), Value: caller.Account},
{Key: aws.String("principalArn"), Value: caller.Arn},
}

if caller.UserId != nil {
userName = *caller.UserId
}

callerArn, err := arn.Parse(*caller.Arn)
if err != nil {
clio.Debugw("could not parse caller arn", "error", err)
return
}

// for an iam credential, the caller ARN.Resource will be user/<username>
// the idea here is to use the username portion as the federation token id
parts := strings.Split(callerArn.Resource, "/")

if len(parts) < 2 {
clio.Debugw("could not split caller resource", "resource", callerArn.Resource)
return
}

userName = parts[1]

tags = append(tags, types.Tag{
Key: aws.String("userName"),
Value: aws.String(userName),
})

return
}
66 changes: 66 additions & 0 deletions pkg/cfaws/assumer_aws_iam_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package cfaws

import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/assert"
)

func Test_getSessionTags(t *testing.T) {

tests := []struct {
name string

caller *sts.GetCallerIdentityOutput
wantTags []types.Tag
wantUserName string
}{
{
name: "ok",
caller: &sts.GetCallerIdentityOutput{
Account: aws.String("123456789012"),
Arn: aws.String("arn:aws:iam::123456789012:user/example-user"),
UserId: aws.String("XXXYYYZZZ"),
},
wantUserName: "example-user",
wantTags: []types.Tag{
{Key: aws.String("userID"), Value: aws.String("XXXYYYZZZ")},
{Key: aws.String("account"), Value: aws.String("123456789012")},
{Key: aws.String("principalArn"), Value: aws.String("arn:aws:iam::123456789012:user/example-user")},
{Key: aws.String("userName"), Value: aws.String("example-user")},
},
},
{
name: "falls_back_to_user_id_if_arn_invalid",
caller: &sts.GetCallerIdentityOutput{
Account: aws.String("123456789012"),
Arn: aws.String("invalid arn"),
UserId: aws.String("XXXYYYZZZ"),
},
wantUserName: "XXXYYYZZZ",
wantTags: []types.Tag{
{Key: aws.String("userID"), Value: aws.String("XXXYYYZZZ")},
{Key: aws.String("account"), Value: aws.String("123456789012")},
{Key: aws.String("principalArn"), Value: aws.String("invalid arn")},
},
},
{
name: "wont_panic",
caller: nil,
wantUserName: "",
wantTags: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotTags, gotUserName := getSessionTags(tt.caller)
assert.Equal(t, tt.wantTags, gotTags)
if gotUserName != tt.wantUserName {
t.Errorf("getSessionTags() gotUserName = %v, want %v", gotUserName, tt.wantUserName)
}
})
}
}

0 comments on commit 3bfb958

Please sign in to comment.