Skip to content

Commit

Permalink
Add rate limiting when calling STS assume role API
Browse files Browse the repository at this point in the history
  • Loading branch information
prateekgogia authored and nckturner committed Jul 23, 2022
1 parent d1cebb8 commit 4be6568
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 6 deletions.
10 changes: 4 additions & 6 deletions pkg/providers/v1/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -1229,15 +1229,13 @@ func init() {
var creds *credentials.Credentials
if cfg.Global.RoleARN != "" {
klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN)
provider := &stscreds.AssumeRoleProvider{
Client: sts.New(sess),
RoleARN: cfg.Global.RoleARN,
}

creds = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
provider,
assumeRoleProvider(&stscreds.AssumeRoleProvider{
Client: sts.New(sess),
RoleARN: cfg.Global.RoleARN,
}),
})
}

Expand Down
62 changes: 62 additions & 0 deletions pkg/providers/v1/aws_assumerole_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"sync"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
)

const (
invalidateCredsAfter = 1 * time.Second
)

// assumeRoleProviderWithRateLimiting makes sure we call the underlying provider only
// once after `invalidateCredsAfter` period
type assumeRoleProviderWithRateLimiting struct {
provider credentials.Provider
invalidateCredsAfter time.Duration
sync.RWMutex
lastError error
lastValue credentials.Value
lastRetrieveTime time.Time
}

func assumeRoleProvider(provider credentials.Provider) credentials.Provider {
return &assumeRoleProviderWithRateLimiting{provider: provider,
invalidateCredsAfter: invalidateCredsAfter}
}

func (l *assumeRoleProviderWithRateLimiting) Retrieve() (credentials.Value, error) {
l.Lock()
defer l.Unlock()
if time.Since(l.lastRetrieveTime) < l.invalidateCredsAfter {
if l.lastError != nil {
return credentials.Value{}, l.lastError
}
return l.lastValue, nil
}
l.lastValue, l.lastError = l.provider.Retrieve()
l.lastRetrieveTime = time.Now()
return l.lastValue, l.lastError
}

func (l *assumeRoleProviderWithRateLimiting) IsExpired() bool {
return l.provider.IsExpired()
}
132 changes: 132 additions & 0 deletions pkg/providers/v1/aws_assumerole_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"fmt"
"reflect"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
)

func Test_assumeRoleProviderWithRateLimiting_Retrieve(t *testing.T) {
type fields struct {
provider credentials.Provider
invalidateCredsAfter time.Duration
RWMutex sync.RWMutex
lastError error
lastValue credentials.Value
lastRetrieveTime time.Time
}
tests := []struct {
name string
fields *fields
want credentials.Value
wantProviderCalled bool
sleepBeforeCallingProvider time.Duration
wantErr bool
wantErrString string
}{{
name: "Call assume role provider and verify access ID returned",
fields: &fields{provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}},
want: credentials.Value{AccessKeyID: "fakeID"},
wantProviderCalled: true,
}, {
name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last value",
fields: &fields{
provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"},
invalidateCredsAfter: 100 * time.Millisecond,
lastValue: credentials.Value{AccessKeyID: "fakeID1"},
lastRetrieveTime: time.Now(),
},
want: credentials.Value{AccessKeyID: "fakeID1"},
wantProviderCalled: false,
sleepBeforeCallingProvider: 10 * time.Millisecond,
}, {
name: "Assume role provider returns an error when trying to assume a role",
fields: &fields{
provider: &fakeAssumeRoleProvider{err: fmt.Errorf("can't assume fake role")},
invalidateCredsAfter: 10 * time.Millisecond,
lastRetrieveTime: time.Now(),
},
wantProviderCalled: true,
wantErr: true,
wantErrString: "can't assume fake role",
sleepBeforeCallingProvider: 15 * time.Millisecond,
}, {
name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last error value",
fields: &fields{
provider: &fakeAssumeRoleProvider{},
invalidateCredsAfter: 100 * time.Millisecond,
lastRetrieveTime: time.Now(),
},
want: credentials.Value{},
wantProviderCalled: false,
wantErr: true,
wantErrString: "can't assume fake role",
}, {
name: "Delayed call to assume role API, should call the underlying provider",
fields: &fields{
provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID2"},
invalidateCredsAfter: 20 * time.Millisecond,
lastRetrieveTime: time.Now(),
},
want: credentials.Value{AccessKeyID: "fakeID2"},
wantProviderCalled: true,
sleepBeforeCallingProvider: 25 * time.Millisecond,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &assumeRoleProviderWithRateLimiting{
provider: tt.fields.provider,
invalidateCredsAfter: tt.fields.invalidateCredsAfter,
lastError: tt.fields.lastError,
lastValue: tt.fields.lastValue,
lastRetrieveTime: tt.fields.lastRetrieveTime,
}
time.Sleep(tt.sleepBeforeCallingProvider)
got, err := l.Retrieve()
if (err != nil) != tt.wantErr && (tt.wantErr && reflect.DeepEqual(err, tt.wantErrString)) {
t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() got = %v, want %v", got, tt.want)
return
}
if tt.wantProviderCalled != tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled {
t.Errorf("provider called %v, want %v", tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled, tt.wantProviderCalled)
}
})
}
}

type fakeAssumeRoleProvider struct {
accesskeyID string
err error
providerCalled bool
}

func (f *fakeAssumeRoleProvider) Retrieve() (credentials.Value, error) {
f.providerCalled = true
return credentials.Value{AccessKeyID: f.accesskeyID}, f.err
}

func (f *fakeAssumeRoleProvider) IsExpired() bool { return true }

0 comments on commit 4be6568

Please sign in to comment.