Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
Fix context, timeouts issues
Browse files Browse the repository at this point in the history
  • Loading branch information
yevgenypats committed Apr 1, 2021
1 parent 475638c commit 7e3c756
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type Config struct {
MaxRetries int `yaml:"max_retries" default:"5"`
MaxBackoff int `yaml:"max_backoff" default:"30"`
// context timeout in seconds
Timeout int `yaml:"timeout" default:"30"`
Timeout int `yaml:"timeout" default:"-1"`
Resources []struct {
Name string
Other map[string]interface{} `yaml:",inline"`
Expand Down Expand Up @@ -212,7 +212,7 @@ func (p *Provider) validateFetchConfig() error {
return nil
}

func (p *Provider) fetchAccount(accountID string, awsCfg aws.Config, svc *sts.Client) error {
func (p *Provider) fetchAccount(ctx context.Context, accountID string, awsCfg aws.Config, svc *sts.Client) error {
var ae smithy.APIError
resourceClients := map[string]resource.ClientInterface{}

Expand All @@ -224,12 +224,10 @@ func (p *Provider) fetchAccount(accountID string, awsCfg aws.Config, svc *sts.Cl
}
globalServicesFetched := map[string]bool{}
for _, region := range p.regions {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(p.config.Timeout)*time.Second)
//Find a better way in AWS SDK V2 to decide if a region is disabled.
_, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) {
o.Region = region
})
cancel()
if err != nil {
if errors.As(err, &ae) && (ae.ErrorCode() == "InvalidClientTokenId" || ae.ErrorCode() == "OptInRequired") {
p.Logger.Info("region disabled. skipping...", "region", region, "account_id", accountID)
Expand All @@ -245,7 +243,6 @@ func (p *Provider) fetchAccount(accountID string, awsCfg aws.Config, svc *sts.Cl
p.db, innerLog, accountID, region)
}

ctx, cancel = context.WithTimeout(context.Background(), time.Duration(p.config.Timeout)*time.Second + time.Second*10)
g, _ := errgroup.WithContext(ctx)
for _, r := range p.config.Resources {
resourcePath := strings.Split(r.Name, ".")
Expand All @@ -259,8 +256,6 @@ func (p *Provider) fetchAccount(accountID string, awsCfg aws.Config, svc *sts.Cl
globalServicesFetched[r.Name] = true
}
g.Go(func() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(p.config.Timeout)*time.Second)
defer cancel()
err := resourceClients[serviceName].CollectResource(ctx, resourceName, resourceConfig)
if err != nil {
var ae smithy.APIError
Expand All @@ -281,16 +276,15 @@ func (p *Provider) fetchAccount(accountID string, awsCfg aws.Config, svc *sts.Cl
})
}
if err := g.Wait(); err != nil {
cancel()
return err
}
cancel()
}
return nil
}

func (p *Provider) Fetch(data []byte) error {

var cancel context.CancelFunc
ctx := context.Background()
defaults.MustSet(&p.config)
if err := yaml.Unmarshal(data, &p.config); err != nil {
return err
Expand All @@ -317,11 +311,13 @@ func (p *Provider) Fetch(data []byte) error {
})
}
p.Logger.Info("Configuring SDK retryer", "retry_attempts", p.config.MaxRetries, "max_backoff", p.config.MaxBackoff)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(p.config.Timeout*len(p.config.Regions))*time.Second + time.Second*10)
defer cancel()
g, _ := errgroup.WithContext(ctx)
if p.config.Timeout != -1 {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(p.config.Timeout*len(p.config.Regions))*time.Second + time.Second*10)
defer cancel()
}

g, ctx := errgroup.WithContext(ctx)
for _, account := range p.config.Accounts {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(p.config.Timeout)*time.Second)
var err error
var awsCfg aws.Config
// This is a try to solve https://aws.amazon.com/premiumsupport/knowledge-center/iam-validate-access-credentials/
Expand All @@ -331,18 +327,15 @@ func (p *Provider) Fetch(data []byte) error {
if account.ID != "default" && account.RoleARN != "" {
// assume role if specified (SDK takes it from default or env var: AWS_PROFILE)
awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion))
cancel()
if err != nil {
_ = g.Wait()
return err
}
awsCfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(awsCfg), account.RoleARN)
} else if account.ID != "default" {
awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion), config.WithSharedConfigProfile(account.ID))
cancel()
} else {
awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion))
cancel()
}
if err != nil {
_ = g.Wait()
Expand All @@ -353,18 +346,16 @@ func (p *Provider) Fetch(data []byte) error {
}
awsCfg.Retryer = p.NewRetryer()
svc := sts.NewFromConfig(awsCfg)
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(p.config.Timeout)*time.Second)
output, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) {
o.Region = "us-east-1"
})
cancel()
if err != nil {
_ = g.Wait()
return err
}
accountID := *output.Account
g.Go(func() error {
return p.fetchAccount(accountID, awsCfg, svc)
return p.fetchAccount(ctx, accountID, awsCfg, svc)
})
}
if err := g.Wait(); err != nil {
Expand Down

0 comments on commit 7e3c756

Please sign in to comment.