diff --git a/premium/usage.go b/premium/usage.go index 0be20665b6..facaa83624 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -21,8 +21,8 @@ import ( "github.com/cloudquery/cloudquery-api-go/config" "github.com/cloudquery/plugin-sdk/v4/plugin" "github.com/google/uuid" + "github.com/hashicorp/go-retryablehttp" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) const ( @@ -32,6 +32,9 @@ const ( defaultMaxWaitTime = 60 * time.Second defaultMinTimeBetweenFlushes = 10 * time.Second defaultMaxTimeBetweenFlushes = 30 * time.Second + + marketplaceDuplicateWaitTime = 1 * time.Second + marketplaceMinRetries = 20 ) const ( @@ -109,7 +112,9 @@ func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientO // WithMaxRetries sets the maximum number of retries to update the usage in case of an API error func WithMaxRetries(maxRetries int) UsageClientOptions { return func(updater *BatchUpdater) { - updater.maxRetries = maxRetries + if maxRetries > 0 { + updater.maxRetries = maxRetries + } } } @@ -198,6 +203,9 @@ type BatchUpdater struct { isClosed bool dataOnClose bool usageIncreaseMethod int + + // Testing + timeFunc func() time.Time } func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, error) { @@ -216,6 +224,7 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e triggerUpdate: make(chan struct{}), done: make(chan struct{}), closeError: make(chan error), + timeFunc: time.Now, tables: map[string]uint32{}, } @@ -246,18 +255,26 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e // Create a default api client if none was provided if u.apiClient == nil { - ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { - token, err := u.tokenClient.GetToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - return nil - })) + retryClient := retryablehttp.NewClient() + retryClient.Logger = nil + retryClient.RetryMax = u.maxRetries + retryClient.RetryWaitMax = u.maxWaitTime + + var err error + u.apiClient, err = cqapi.NewClientWithResponses(u.url, + cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { + token, err := u.tokenClient.GetToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + }), + cqapi.WithHTTPClient(retryClient.StandardClient()), + ) if err != nil { return nil, fmt.Errorf("failed to create api client: %w", err) } - u.apiClient = ac } // Set team name from configuration if not provided @@ -289,11 +306,12 @@ func (u *BatchUpdater) setupAWSMarketplace() error { u.batchLimit = 1000000000 u.minTimeBetweenFlushes = 1 * time.Minute + u.maxRetries = max(u.maxRetries, marketplaceMinRetries) u.backgroundUpdater() _, err = u.awsMarketplaceClient.MeterUsage(ctx, &marketplacemetering.MeterUsageInput{ ProductCode: aws.String(awsMarketplaceProductCode()), - Timestamp: aws.Time(time.Now()), + Timestamp: aws.Time(u.timeFunc()), UsageDimension: aws.String("rows"), UsageQuantity: aws.Int32(int32(0)), DryRun: aws.Bool(true), @@ -486,13 +504,13 @@ func (u *BatchUpdater) backgroundUpdater() { } // If we are using AWS Marketplace, we need to round down to the nearest 1000 // Only on the last update, will we round up to the nearest 1000 - // This will allow us to not over charge the customer by rounding on each batch + // This will allow us to not overcharge the customer by rounding on each batch if u.awsMarketplaceClient != nil { totals = roundDown(totals, 1000) } if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil { - log.Warn().Err(err).Msg("failed to update usage") + u.logger.Warn().Err(err).Msg("failed to update usage") continue } u.subtractTableUsage(tables, totals) @@ -510,12 +528,12 @@ func (u *BatchUpdater) backgroundUpdater() { } // If we are using AWS Marketplace, we need to round down to the nearest 1000 // Only on the last update, will we round up to the nearest 1000 - // This will allow us to not over charge the customer by rounding on each batch + // This will allow us to not overcharge the customer by rounding on each batch if u.awsMarketplaceClient != nil { totals = roundDown(totals, 1000) } if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil { - log.Warn().Err(err).Msg("failed to update usage") + u.logger.Warn().Err(err).Msg("failed to update usage") continue } u.subtractTableUsage(tables, totals) @@ -573,60 +591,70 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin // Each product is given a unique product code when it is listed in AWS Marketplace // in the future we can have multiple product codes for container or AMI based listings ProductCode: aws.String(awsMarketplaceProductCode()), - Timestamp: aws.Time(time.Now()), + Timestamp: aws.Time(u.timeFunc()), UsageDimension: aws.String("rows"), UsageAllocations: usage, UsageQuantity: aws.Int32(int32(rows)), }) if err != nil { - return fmt.Errorf("failed to update usage with : %w", err) + return fmt.Errorf("failed to update usage: %w", err) } return nil } -func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error { +func (u *BatchUpdater) updateMarketplaceUsage(ctx context.Context, rows uint32) error { + var lastErr error for retry := 0; retry < u.maxRetries; retry++ { - u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - queryStartTime := time.Now() + u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - // If the AWS Marketplace client is set, use it to track usage - if u.awsMarketplaceClient != nil { - return u.reportUsageToAWSMarketplace(ctx, rows) - } - payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{ - RequestId: uuid.New(), - PluginTeam: u.pluginMeta.Team, - PluginKind: u.pluginMeta.Kind, - PluginName: u.pluginMeta.Name, - Rows: int(rows), + lastErr = u.reportUsageToAWSMarketplace(ctx, rows) + if lastErr == nil { + u.logger.Debug().Int("try", retry).Uint32("rows", rows).Msg("usage updated") + return nil } - if len(tables) > 0 { - payload.Tables = &tables + var de *types.DuplicateRequestException + if !errors.As(lastErr, &de) { + return fmt.Errorf("failed to update usage: %w", lastErr) } + u.logger.Debug().Err(lastErr).Int("try", retry).Uint32("rows", rows).Msg("usage update failed due to duplicate request") - resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) - if err != nil { - return fmt.Errorf("failed to update usage: %w", err) - } - if resp.StatusCode() >= 200 && resp.StatusCode() < 300 { - u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") - u.lastUpdateTime = time.Now().UTC() - if resp.HTTPResponse != nil { - u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) - } - return nil - } + jitter := time.Duration(rand.Intn(1000)) * time.Millisecond + time.Sleep(marketplaceDuplicateWaitTime + jitter) + } + return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr) +} - retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry) - if err != nil { - return fmt.Errorf("failed to calculate retry duration: %w", err) - } - if retryDuration > 0 { - time.Sleep(retryDuration) +func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error { + // If the AWS Marketplace client is set, use it to track usage + if u.awsMarketplaceClient != nil { + return u.updateMarketplaceUsage(ctx, rows) + } + + u.logger.Debug().Str("url", u.url).Uint32("rows", rows).Msg("updating usage") + payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{ + RequestId: uuid.New(), + PluginTeam: u.pluginMeta.Team, + PluginKind: u.pluginMeta.Kind, + PluginName: u.pluginMeta.Name, + Rows: int(rows), + } + + if len(tables) > 0 { + payload.Tables = &tables + } + + resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) + if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + u.logger.Debug().Str("url", u.url).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") + u.lastUpdateTime = u.timeFunc().UTC() + if resp.HTTPResponse != nil { + u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) } + return nil } - return fmt.Errorf("failed to update usage: max retries exceeded") + + return fmt.Errorf("failed to update usage: %w", err) } // updateConfigurationFromHeaders updates the configuration based on the headers returned by the API @@ -663,33 +691,6 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) { } } -// calculateRetryDuration calculates the duration to sleep relative to the query start time before retrying an update -func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) { - if !retryableStatusCode(statusCode) { - return 0, fmt.Errorf("non-retryable status code: %d", statusCode) - } - - // Check if we have a retry-after header - retryAfter := headers.Get("Retry-After") - if retryAfter != "" { - retryDelay, err := time.ParseDuration(retryAfter + "s") - if err != nil { - return 0, fmt.Errorf("failed to parse retry-after header: %w", err) - } - return retryDelay, nil - } - - // Calculate exponential backoff - baseRetry := min(time.Duration(1< 0 { + code := responseCodes[0] + responseCodes = responseCodes[1:] + w.WriteHeader(code) + for k, v := range stage.headers { + w.Header().Set(k, v) + } + return + } + dec := json.NewDecoder(r.Body) var req cqapi.IncreaseTeamPluginUsageJSONRequestBody err := dec.Decode(&req) @@ -746,7 +822,7 @@ func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testSta } func createTestServer(t *testing.T) *testStage { - return createTestServerWithRemainingRows(t, 0) + return createTestServerWithRemainingRows(t, 0, nil) } type testStage struct {