Skip to content
Merged
159 changes: 80 additions & 79 deletions premium/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
"github.com/cloudquery/cloudquery-api-go/config"
"github.com/cloudquery/plugin-sdk/v4/plugin"
"github.com/google/uuid"
"github.com/hashicorp/go-retryablehttp"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

const (
Expand All @@ -32,6 +32,9 @@ const (
defaultMaxWaitTime = 60 * time.Second
defaultMinTimeBetweenFlushes = 10 * time.Second
defaultMaxTimeBetweenFlushes = 30 * time.Second

marketplaceDuplicateWaitTime = 1 * time.Second
marketplaceMinRetries = 20
)

const (
Expand Down Expand Up @@ -109,7 +112,9 @@ func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientO
// WithMaxRetries sets the maximum number of retries to update the usage in case of an API error.
// Non-positive values are ignored, so the previously configured value is kept.
func WithMaxRetries(maxRetries int) UsageClientOptions {
return func(updater *BatchUpdater) {
updater.maxRetries = maxRetries
if maxRetries > 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that retries can never be turned off?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`maxRetries = 1` would mean "single request, no retries", but I didn't want to be pedantic about the loop and the descriptions.

updater.maxRetries = maxRetries
}
}
}

Expand Down Expand Up @@ -198,6 +203,9 @@ type BatchUpdater struct {
isClosed bool
dataOnClose bool
usageIncreaseMethod int

// Testing
timeFunc func() time.Time
}

func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, error) {
Expand All @@ -216,6 +224,7 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e
triggerUpdate: make(chan struct{}),
done: make(chan struct{}),
closeError: make(chan error),
timeFunc: time.Now,

tables: map[string]uint32{},
}
Expand Down Expand Up @@ -246,18 +255,26 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e

// Create a default api client if none was provided
if u.apiClient == nil {
ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
token, err := u.tokenClient.GetToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return nil
}))
retryClient := retryablehttp.NewClient()
retryClient.Logger = nil
retryClient.RetryMax = u.maxRetries
retryClient.RetryWaitMax = u.maxWaitTime

var err error
u.apiClient, err = cqapi.NewClientWithResponses(u.url,
cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
token, err := u.tokenClient.GetToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return nil
}),
cqapi.WithHTTPClient(retryClient.StandardClient()),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what the default `cqapi.NewClientWithResponses` does anyway; we just set the retry parameters beforehand.

)
if err != nil {
return nil, fmt.Errorf("failed to create api client: %w", err)
}
u.apiClient = ac
}

// Set team name from configuration if not provided
Expand Down Expand Up @@ -289,11 +306,12 @@ func (u *BatchUpdater) setupAWSMarketplace() error {
u.batchLimit = 1000000000

u.minTimeBetweenFlushes = 1 * time.Minute
u.maxRetries = max(u.maxRetries, marketplaceMinRetries)
u.backgroundUpdater()

_, err = u.awsMarketplaceClient.MeterUsage(ctx, &marketplacemetering.MeterUsageInput{
ProductCode: aws.String(awsMarketplaceProductCode()),
Timestamp: aws.Time(time.Now()),
Timestamp: aws.Time(u.timeFunc()),
UsageDimension: aws.String("rows"),
UsageQuantity: aws.Int32(int32(0)),
DryRun: aws.Bool(true),
Expand Down Expand Up @@ -486,13 +504,13 @@ func (u *BatchUpdater) backgroundUpdater() {
}
// If we are using AWS Marketplace, we need to round down to the nearest 1000
// Only on the last update, will we round up to the nearest 1000
// This will allow us to not over charge the customer by rounding on each batch
// This will allow us to not overcharge the customer by rounding on each batch
if u.awsMarketplaceClient != nil {
totals = roundDown(totals, 1000)
}

if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil {
log.Warn().Err(err).Msg("failed to update usage")
u.logger.Warn().Err(err).Msg("failed to update usage")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaning up some "untethered" logger usage.

continue
}
u.subtractTableUsage(tables, totals)
Expand All @@ -510,12 +528,12 @@ func (u *BatchUpdater) backgroundUpdater() {
}
// If we are using AWS Marketplace, we need to round down to the nearest 1000
// Only on the last update, will we round up to the nearest 1000
// This will allow us to not over charge the customer by rounding on each batch
// This will allow us to not overcharge the customer by rounding on each batch
if u.awsMarketplaceClient != nil {
totals = roundDown(totals, 1000)
}
if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil {
log.Warn().Err(err).Msg("failed to update usage")
u.logger.Warn().Err(err).Msg("failed to update usage")
continue
}
u.subtractTableUsage(tables, totals)
Expand Down Expand Up @@ -573,60 +591,70 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin
// Each product is given a unique product code when it is listed in AWS Marketplace
// in the future we can have multiple product codes for container or AMI based listings
ProductCode: aws.String(awsMarketplaceProductCode()),
Timestamp: aws.Time(time.Now()),
Timestamp: aws.Time(u.timeFunc()),
UsageDimension: aws.String("rows"),
UsageAllocations: usage,
UsageQuantity: aws.Int32(int32(rows)),
})
if err != nil {
return fmt.Errorf("failed to update usage with : %w", err)
return fmt.Errorf("failed to update usage: %w", err)
}
return nil
}

func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error {
func (u *BatchUpdater) updateMarketplaceUsage(ctx context.Context, rows uint32) error {
var lastErr error
for retry := 0; retry < u.maxRetries; retry++ {
u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage")
queryStartTime := time.Now()
u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage")

// If the AWS Marketplace client is set, use it to track usage
if u.awsMarketplaceClient != nil {
return u.reportUsageToAWSMarketplace(ctx, rows)
}
payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{
RequestId: uuid.New(),
PluginTeam: u.pluginMeta.Team,
PluginKind: u.pluginMeta.Kind,
PluginName: u.pluginMeta.Name,
Rows: int(rows),
lastErr = u.reportUsageToAWSMarketplace(ctx, rows)
if lastErr == nil {
u.logger.Debug().Int("try", retry).Uint32("rows", rows).Msg("usage updated")
return nil
}

if len(tables) > 0 {
payload.Tables = &tables
var de *types.DuplicateRequestException
if !errors.As(lastErr, &de) {
return fmt.Errorf("failed to update usage: %w", lastErr)
}
u.logger.Debug().Err(lastErr).Int("try", retry).Uint32("rows", rows).Msg("usage update failed due to duplicate request")

resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload)
if err != nil {
return fmt.Errorf("failed to update usage: %w", err)
}
if resp.StatusCode() >= 200 && resp.StatusCode() < 300 {
u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated")
u.lastUpdateTime = time.Now().UTC()
if resp.HTTPResponse != nil {
u.updateConfigurationFromHeaders(resp.HTTPResponse.Header)
}
return nil
}
jitter := time.Duration(rand.Intn(1000)) * time.Millisecond
time.Sleep(marketplaceDuplicateWaitTime + jitter)
}
return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr)
}

retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry)
if err != nil {
return fmt.Errorf("failed to calculate retry duration: %w", err)
}
if retryDuration > 0 {
time.Sleep(retryDuration)
// updateUsageWithRetryAndBackoff reports rows of usage, plus the optional per-table breakdown in tables.
//
// When an AWS Marketplace client is configured the call is delegated to updateMarketplaceUsage.
// Otherwise a single IncreaseTeamPluginUsage request is issued through u.apiClient; this method
// does not loop itself, so any retrying and backoff is left to that client.
//
// On a 2xx response it records the update time and applies configuration from the response
// headers. It returns an error if the request fails or if the API answers with a non-2xx status.
func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error {
	// If the AWS Marketplace client is set, use it to track usage
	if u.awsMarketplaceClient != nil {
		return u.updateMarketplaceUsage(ctx, rows)
	}

	u.logger.Debug().Str("url", u.url).Uint32("rows", rows).Msg("updating usage")
	payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{
		RequestId:  uuid.New(),
		PluginTeam: u.pluginMeta.Team,
		PluginKind: u.pluginMeta.Kind,
		PluginName: u.pluginMeta.Name,
		Rows:       int(rows),
	}

	if len(tables) > 0 {
		payload.Tables = &tables
	}

	resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload)
	if err != nil {
		return fmt.Errorf("failed to update usage: %w", err)
	}

	// A request can complete without a transport error and still be rejected by the API.
	// Report the status explicitly instead of wrapping a nil error.
	statusCode := resp.StatusCode()
	if statusCode < 200 || statusCode >= 300 {
		return fmt.Errorf("failed to update usage: unexpected status code %d", statusCode)
	}

	u.logger.Debug().Str("url", u.url).Int("status_code", statusCode).Uint32("rows", rows).Msg("usage updated")
	u.lastUpdateTime = u.timeFunc().UTC()
	if resp.HTTPResponse != nil {
		u.updateConfigurationFromHeaders(resp.HTTPResponse.Header)
	}
	return nil
}

// updateConfigurationFromHeaders updates the configuration based on the headers returned by the API
Expand Down Expand Up @@ -663,33 +691,6 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) {
}
}

// calculateRetryDuration calculates the duration to sleep relative to the query start time before retrying an update.
//
// It returns an error when statusCode is not retryable, or when a Retry-After header is present
// but is neither a number of seconds nor an HTTP-date.
//
// A Retry-After header takes precedence over the computed backoff. Without it, the delay is an
// exponential backoff of 2^retry seconds capped at u.maxWaitTime, plus up to one second of jitter,
// minus the time already elapsed since queryStartTime. The result may be zero or negative, which
// means no further waiting is needed.
func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) {
	if !retryableStatusCode(statusCode) {
		return 0, fmt.Errorf("non-retryable status code: %d", statusCode)
	}

	// Check if we have a retry-after header
	if retryAfter := headers.Get("Retry-After"); retryAfter != "" {
		retryDelay, err := time.ParseDuration(retryAfter + "s")
		if err == nil {
			return retryDelay, nil
		}
		// Retry-After may also carry an HTTP-date instead of delay-seconds.
		if retryAt, dateErr := http.ParseTime(retryAfter); dateErr == nil {
			return time.Until(retryAt), nil
		}
		return 0, fmt.Errorf("failed to parse retry-after header: %w", err)
	}

	// Calculate exponential backoff. Beyond maxBackoffShift the shifted value would overflow
	// time.Duration, so the cap is used directly.
	const maxBackoffShift = 32
	baseRetry := u.maxWaitTime
	if retry < maxBackoffShift {
		baseRetry = min(time.Second<<retry, u.maxWaitTime)
	}
	jitter := time.Duration(rand.Intn(1000)) * time.Millisecond
	retryDelay := baseRetry + jitter
	return retryDelay - time.Since(queryStartTime), nil
}

// retryableStatusCode reports whether statusCode is one of the HTTP statuses
// treated as retryable: 429 Too Many Requests or 503 Service Unavailable.
func retryableStatusCode(statusCode int) bool {
	switch statusCode {
	case http.StatusTooManyRequests, http.StatusServiceUnavailable:
		return true
	default:
		return false
	}
}

func (u *BatchUpdater) getTeamNameByTokenType(tokenType auth.TokenType) (string, error) {
switch tokenType {
case auth.BearerToken:
Expand Down
Loading
Loading