Skip to content

Commit

Permalink
Adding Rate limiting ec2:DescribeInstances API along with Batching fo…
Browse files Browse the repository at this point in the history
…r high TPS
  • Loading branch information
bhks committed Feb 21, 2020
1 parent aaeee87 commit 978fa3f
Show file tree
Hide file tree
Showing 13 changed files with 646 additions and 123 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,8 @@ cscope.*
*.pyc

# local dot files
.envrc
.envrc

# coverage.out file
coverage.out
coverage.html
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ endif
$(GORELEASER) --skip-publish --rm-dist --snapshot

test:
go test -v -cover -race $(GITHUB_REPO)/...
go test -v -coverprofile=coverage.out -race $(GITHUB_REPO)/...
go tool cover -html=coverage.out -o coverage.html

format:
test -z "$$(find . -path ./vendor -prune -type f -o -name '*.go' -exec gofmt -d {} + | tee /dev/stderr)" || \
Expand Down
2 changes: 2 additions & 0 deletions cmd/aws-iam-authenticator/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ func getConfig() (config.Config, error) {
Kubeconfig: viper.GetString("server.kubeconfig"),
Master: viper.GetString("server.master"),
BackendMode: viper.GetStringSlice("server.backendMode"),
EC2DescribeInstancesQps: viper.GetInt("server.ec2DescribeInstancesQps"),
EC2DescribeInstancesBurst: viper.GetInt("server.ec2DescribeInstancesBurst"),
}
if err := viper.UnmarshalKey("server.mapRoles", &cfg.RoleMappings); err != nil {
return cfg, fmt.Errorf("invalid server role mappings: %v", err)
Expand Down
21 changes: 19 additions & 2 deletions cmd/aws-iam-authenticator/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@ import (
"github.com/spf13/viper"
)

// DefaultPort is the default localhost port (chosen randomly).
const DefaultPort = 21362
const (
	// DefaultPort is the default localhost port (chosen randomly).
	DefaultPort = 21362

	// DefaultEC2DescribeInstancesQps is the default value of the
	// ec2-describeInstances-qps server flag.
	DefaultEC2DescribeInstancesQps = 15

	// DefaultEC2DescribeInstancesBurst is the default value of the
	// ec2-describeInstances-burst server flag.
	DefaultEC2DescribeInstancesBurst = 5
)

// serverCmd represents the server command
var serverCmd = &cobra.Command{
Expand Down Expand Up @@ -102,6 +107,18 @@ func init() {
"Port to bind the server to listen to")
viper.BindPFlag("server.port", serverCmd.Flags().Lookup("port"))

serverCmd.Flags().Int(
"ec2-describeInstances-qps",
DefaultEC2DescribeInstancesQps,
"AWS EC2 rate limiting with qps")
viper.BindPFlag("server.ec2DescribeInstancesQps", serverCmd.Flags().Lookup("ec2-describeInstances-qps"))

serverCmd.Flags().Int(
"ec2-describeInstances-burst",
DefaultEC2DescribeInstancesBurst,
"AWS EC2 rate Limiting with burst")
viper.BindPFlag("server.ec2DescribeInstancesBurst", serverCmd.Flags().Lookup("ec2-describeInstances-burst"))

fs := flag.NewFlagSet("", flag.ContinueOnError)
_ = fs.Parse([]string{})
flag.CommandLine = fs
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module sigs.k8s.io/aws-iam-authenticator

go 1.12
go 1.13

require (
github.com/aws/aws-sdk-go v1.26.7
Expand All @@ -10,6 +10,7 @@ require (
github.com/spf13/cobra v0.0.5
github.com/spf13/viper v1.4.0
go.hein.dev/go-version v0.1.0
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
gopkg.in/yaml.v2 v2.2.2
k8s.io/api v0.0.0-20190425012535-181e1f9c52c1
k8s.io/apimachinery v0.0.0-20190612125636-6a5db36e93ad
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/aws/aws-sdk-go v1.23.11 h1:fTq1xdeDdCwUfBA64QHk1b5HJfWauac36LvtWlk0pEw=
github.com/aws/aws-sdk-go v1.23.11/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.26.7 h1:ObjEnmzvSdYy8KVd3me7v/UMyCn81inLy2SyoIPoBkg=
github.com/aws/aws-sdk-go v1.26.7/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
Expand Down
5 changes: 5 additions & 0 deletions pkg/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,9 @@ type Config struct {

// BackendMode is an ordered list of backends to get mappings from. Comma-delimited list of: File,ConfigMap,CRD
BackendMode []string

// EC2 DescribeInstances rate-limiting settings. They are initially set to their defaults
// and stay that way until we fully understand whether they ever need to change.
EC2DescribeInstancesQps int
EC2DescribeInstancesBurst int
}
275 changes: 275 additions & 0 deletions pkg/ec2provider/ec2provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package ec2provider

import (
"errors"
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/sirupsen/logrus"
"sigs.k8s.io/aws-iam-authenticator/pkg/httputil"
)

const (
	// maxChannelSize is the buffer size of the instance ID channel; it
	// matches the maximum number of k8s nodes supported.
	maxChannelSize = 8000

	// maxAllowedInflightRequest is the number of concurrent, non-batched
	// ec2:DescribeInstances requests allowed before lookups are routed to
	// the batch channel instead.
	maxAllowedInflightRequest = 5

	// maxInstancesBatchSize is the maximum number of instance IDs sent in a
	// single batched ec2:DescribeInstances call.
	maxInstancesBatchSize = 100

	// maxWaitIntervalForBatch is the longest time, in milliseconds, that a
	// non-empty batch waits before being sent. A batch that reaches its size
	// limit first is sent without waiting for this interval.
	maxWaitIntervalForBatch = 200

	// defaultWaitInterval is how long a caller sleeps between cache checks
	// while another request for the same instance ID is already in flight.
	defaultWaitInterval = 50 * time.Millisecond

	// totalIterationForWaitInterval caps the number of such sleeps, so a
	// waiting caller gives up after about 5 seconds (100 * 50ms).
	totalIterationForWaitInterval = 100
)

// EC2Provider resolves EC2 instance IDs to node names.
type EC2Provider interface {
	// GetPrivateDNSName returns the private DNS name for the given
	// instance ID, or an error if it cannot be determined.
	GetPrivateDNSName(id string) (string, error)
	// StartEc2DescribeBatchProcessing runs the loop that groups queued
	// instance IDs into batched ec2:DescribeInstances calls.
	StartEc2DescribeBatchProcessing()
}

// ec2PrivateDNSCache maps EC2 instance IDs to their private DNS names.
type ec2PrivateDNSCache struct {
	// lock guards cache.
	lock sync.RWMutex
	// cache is keyed by instance ID.
	cache map[string]string
}

// ec2Requests tracks the instance IDs that currently have a lookup in flight.
type ec2Requests struct {
	// lock guards set.
	lock sync.RWMutex
	// set holds one entry per in-flight instance ID.
	set map[string]bool
}

// ec2ProviderImpl is the EC2Provider implementation backed by the EC2 API,
// an in-memory DNS name cache and a batching channel.
type ec2ProviderImpl struct {
	// ec2 is the client used for ec2:DescribeInstances calls.
	ec2 ec2iface.EC2API
	// instanceIdsChannel carries instance IDs waiting to be resolved by the
	// batch processing loop.
	instanceIdsChannel chan string
	// privateDNSCache holds the instance ID to private DNS name results.
	privateDNSCache ec2PrivateDNSCache
	// ec2Requests records which instance IDs have a lookup in flight.
	ec2Requests ec2Requests
}

// New builds an EC2Provider whose EC2 client is created from a session for
// the given role ARN, passing qps and burst through to that session setup.
// The returned provider starts with an empty DNS cache, an empty in-flight
// set and a channel buffered for maxChannelSize instance IDs.
func New(roleARN string, qps int, burst int) EC2Provider {
	provider := &ec2ProviderImpl{
		instanceIdsChannel: make(chan string, maxChannelSize),
	}
	// The RWMutex fields are usable as zero values; only the maps need
	// explicit initialisation.
	provider.privateDNSCache.cache = make(map[string]string)
	provider.ec2Requests.set = make(map[string]bool)
	provider.ec2 = ec2.New(newSession(roleARN, qps, burst))
	return provider
}

// newSession creates the AWS session used by the EC2 client.
//
// Initial credentials are loaded from the SDK's default credential chain,
// such as the environment, shared credentials (~/.aws/credentials), or EC2
// Instance Role. If the session has no region configured, the region is read
// from instance metadata; failure to find one is fatal.
//
// An HTTP client rate limited by qps and burst is installed on the session
// so that every client built from it, including the EC2 client issuing
// ec2:DescribeInstances, is throttled. If the rate limited client cannot be
// created the error is logged and the session keeps the SDK's default client.
//
// When roleARN is non-empty the session's credentials are replaced by ones
// obtained by assuming that role through the regional STS endpoint.
func newSession(roleARN string, qps int, burst int) *session.Session {
	sess := session.Must(session.NewSession())
	if aws.StringValue(sess.Config.Region) == "" {
		metadataClient := ec2metadata.New(sess)
		regionFound, err := metadataClient.Region()
		if err != nil {
			logrus.WithError(err).Fatal("Region not found in shared credentials, environment variable, or instance metadata.")
		}
		sess.Config.Region = aws.String(regionFound)
	}

	rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst)
	if err != nil {
		logrus.Errorf("Getting error = %s while creating rate limited client ", err)
	} else if rateLimitedClient != nil {
		// Install the limiter on the session itself: previously it was only
		// handed to the STS client below, so EC2 calls were never throttled
		// and qps/burst were ignored entirely when no role was configured.
		sess.Config.HTTPClient = rateLimitedClient
	}

	if roleARN != "" {
		logrus.WithFields(logrus.Fields{
			"roleARN": roleARN,
		}).Infof("Using assumed role for EC2 API")

		ap := &stscreds.AssumeRoleProvider{
			Client:   sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)),
			RoleARN:  roleARN,
			Duration: time.Duration(60) * time.Minute,
		}

		sess.Config.Credentials = credentials.NewCredentials(ap)
	}
	return sess
}

// setPrivateDNSNameCache stores privateDNSName for the instance id in the
// DNS cache while holding the cache's write lock.
func (p *ec2ProviderImpl) setPrivateDNSNameCache(id string, privateDNSName string) {
	dnsCache := &p.privateDNSCache
	dnsCache.lock.Lock()
	defer dnsCache.lock.Unlock()
	dnsCache.cache[id] = privateDNSName
}

// setRequestInFlightForInstanceId marks the instance id as having a lookup
// in flight, under the in-flight set's write lock.
func (p *ec2ProviderImpl) setRequestInFlightForInstanceId(id string) {
	inFlight := &p.ec2Requests
	inFlight.lock.Lock()
	defer inFlight.lock.Unlock()
	inFlight.set[id] = true
}

// unsetRequestInFlightForInstanceId removes the instance id from the
// in-flight set, under the set's write lock. Removing an id that is not
// present is a no-op.
func (p *ec2ProviderImpl) unsetRequestInFlightForInstanceId(id string) {
	inFlight := &p.ec2Requests
	inFlight.lock.Lock()
	defer inFlight.lock.Unlock()
	delete(inFlight.set, id)
}

// getRequestInFlightForInstanceId reports whether the instance id is present
// in the in-flight set, reading it under the set's read lock.
func (p *ec2ProviderImpl) getRequestInFlightForInstanceId(id string) bool {
	inFlight := &p.ec2Requests
	inFlight.lock.RLock()
	defer inFlight.lock.RUnlock()
	_, present := inFlight.set[id]
	return present
}

// getRequestInFlightSize returns how many instance IDs are currently in the
// in-flight set, reading it under the set's read lock.
func (p *ec2ProviderImpl) getRequestInFlightSize() int {
	inFlight := &p.ec2Requests
	inFlight.lock.RLock()
	defer inFlight.lock.RUnlock()
	return len(inFlight.set)
}

// getPrivateDNSNameCache returns the cached private DNS name for the
// instance id. It only consults the in-memory cache (under its read lock)
// and returns an error when the id has no entry.
func (p *ec2ProviderImpl) getPrivateDNSNameCache(id string) (string, error) {
	dnsCache := &p.privateDNSCache
	dnsCache.lock.RLock()
	defer dnsCache.lock.RUnlock()
	cached, found := dnsCache.cache[id]
	if !found {
		return "", errors.New("instance id not found")
	}
	return cached, nil
}

// GetPrivateDNSName returns the private DNS name of the EC2 instance with
// the given instance ID, calling the EC2 API only when the name is not
// already cached.
//
// On a cache miss one of three paths is taken:
//   - A lookup for id is already in flight: the cache is polled every
//     defaultWaitInterval, at most totalIterationForWaitInterval times, and
//     an error is returned if the name never shows up.
//   - More than maxAllowedInflightRequest lookups are in flight: id is
//     written to instanceIdsChannel for the batch processor, and this call
//     then waits for the result as described above.
//   - Otherwise ec2:DescribeInstances is called for this single id.
//
// An error is returned when the EC2 call fails, when the response yields no
// non-empty private DNS name for id, or when waiting on another lookup
// times out.
func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) {
	if cachedName, err := p.getPrivateDNSNameCache(id); err == nil {
		return cachedName, nil
	}
	logrus.Debugf("Missed the cache for the InstanceId = %s Verifying if its already in requestQueue ", id)
	// check if the request for instanceId already in queue.
	if p.getRequestInFlightForInstanceId(id) {
		logrus.Debugf("Found the InstanceId:= %s request In Queue waiting in 5 seconds loop ", id)
		for i := 0; i < totalIterationForWaitInterval; i++ {
			time.Sleep(defaultWaitInterval)
			if cachedName, err := p.getPrivateDNSNameCache(id); err == nil {
				return cachedName, nil
			}
		}
		return "", fmt.Errorf("failed to find node %s in PrivateDNSNameCache returning from loop", id)
	}
	logrus.Debugf("Missed the requestQueue cache for the InstanceId = %s", id)
	p.setRequestInFlightForInstanceId(id)
	// When more lookups are in flight than the direct path allows, hand the
	// id to the batch processor. The batch processor clears the in-flight
	// marker once its ec2:DescribeInstances call returns, so the recursive
	// call below ends up in the wait loop above.
	if p.getRequestInFlightSize() > maxAllowedInflightRequest {
		logrus.Debugf("Writing to buffered channel for instance Id %s ", id)
		p.instanceIdsChannel <- id
		return p.GetPrivateDNSName(id)
	}

	// This call owns the lookup from here on, so the in-flight marker must
	// be cleared on every exit path. Previously it was left behind when EC2
	// answered successfully but without the instance, which made every later
	// lookup of the same id wait out the full polling loop and then fail.
	defer p.unsetRequestInFlightForInstanceId(id)

	logrus.Infof("Calling ec2:DescribeInstances for the InstanceId = %s ", id)
	// Look up instance from EC2 API
	output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{
		InstanceIds: aws.StringSlice([]string{id}),
	})
	if err != nil {
		return "", fmt.Errorf("failed querying private DNS from EC2 API for node %s: %w ", id, err)
	}

	var privateDNSName string
	for _, reservation := range output.Reservations {
		for _, instance := range reservation.Instances {
			if aws.StringValue(instance.InstanceId) != id {
				continue
			}
			// NOTE(review): an instance reported with an empty
			// PrivateDnsName is still cached here, so later cache hits
			// return ("", nil) while this call returns an error - confirm
			// whether that is intended.
			privateDNSName = aws.StringValue(instance.PrivateDnsName)
			p.setPrivateDNSNameCache(id, privateDNSName)
		}
	}

	if privateDNSName == "" {
		return "", fmt.Errorf("failed to find node %s ", id)
	}
	return privateDNSName, nil
}

// StartEc2DescribeBatchProcessing runs the batching loop for queued
// instance IDs. It never returns, so callers are expected to run it in its
// own goroutine.
//
// IDs read from instanceIdsChannel are accumulated into a pending batch. The
// batch is handed to getPrivateDnsAndPublishToCache, in a new goroutine, as
// soon as either
//   - it is non-empty and more than maxWaitIntervalForBatch milliseconds
//     have passed since the previous batch was dispatched, or
//   - it holds maxInstancesBatchSize IDs.
//
// An empty batch never triggers an ec2:DescribeInstances call.
func (p *ec2ProviderImpl) StartEc2DescribeBatchProcessing() {
	startTime := time.Now()
	var instanceIdList []string
	for {
		select {
		case instanceId := <-p.instanceIdsChannel:
			logrus.Debugf("Received the Instance Id := %s from buffered Channel for batch processing ", instanceId)
			instanceIdList = append(instanceIdList, instanceId)
		default:
			// Waiting for more elements to get added to the buffered Channel
			// And to support the for select loop.
			time.Sleep(20 * time.Millisecond)
		}

		if len(instanceIdList) == 0 {
			continue
		}
		waitExpired := time.Since(startTime).Milliseconds() > maxWaitIntervalForBatch
		// Use >= so a batch is dispatched at exactly maxInstancesBatchSize
		// IDs; the previous > comparison let batches grow to one more than
		// the documented maximum.
		batchFull := len(instanceIdList) >= maxInstancesBatchSize
		if !waitExpired && !batchFull {
			continue
		}

		startTime = time.Now()
		// The slice is handed over as-is: resetting the local to nil makes
		// the next append allocate a fresh backing array, so the goroutine
		// never shares storage with the next batch.
		go p.getPrivateDnsAndPublishToCache(instanceIdList)
		instanceIdList = nil
	}
}

// getPrivateDnsAndPublishToCache resolves a batch of instance IDs with one
// ec2:DescribeInstances call and stores every returned instance's private
// DNS name in the cache. A failed call is only logged. In both cases every
// ID in instanceIdList is removed from the in-flight set afterwards.
func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string) {
	logrus.Infof("Making Batch Query to DescribeInstances for %v instances ", len(instanceIdList))
	request := &ec2.DescribeInstancesInput{
		InstanceIds: aws.StringSlice(instanceIdList),
	}
	output, err := p.ec2.DescribeInstances(request)
	switch {
	case err != nil:
		logrus.Errorf("Batch call failed querying private DNS from EC2 API for nodes [%s] : with error = []%s ", instanceIdList, err.Error())
	default:
		if token := output.NextToken; token == nil {
			logrus.Debugf("Successfully got the batch result , output.NextToken is nil ")
		} else {
			logrus.Debugf("Successfully got the batch result , output.NextToken = %s ", *token)
		}
		// Publish each returned instance's name to the private DNS cache.
		for _, reservation := range output.Reservations {
			for _, instance := range reservation.Instances {
				p.setPrivateDNSNameCache(
					aws.StringValue(instance.InstanceId),
					aws.StringValue(instance.PrivateDnsName),
				)
			}
		}
	}

	// Clear the in-flight marker for the whole batch, whether or not the
	// call succeeded.
	logrus.Debugf("Removing instances from request Queue after getting response from Ec2")
	for i := range instanceIdList {
		p.unsetRequestInFlightForInstanceId(instanceIdList[i])
	}
}
Loading

0 comments on commit 978fa3f

Please sign in to comment.