From 3d5f45172f6c463cf5c090f07560b1841ddfb003 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Tue, 29 Oct 2024 18:34:22 +0000 Subject: [PATCH 01/11] fix: Clean up usage retry logic --- premium/usage.go | 73 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index a9c5d5a844..83fe5a7d34 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -12,6 +12,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/marketplacemetering" "github.com/aws/aws-sdk-go-v2/service/marketplacemetering/types" @@ -31,6 +32,9 @@ const ( defaultMaxWaitTime = 60 * time.Second defaultMinTimeBetweenFlushes = 10 * time.Second defaultMaxTimeBetweenFlushes = 30 * time.Second + + marketplaceDuplicateWaitTime = 1 * time.Second + marketplaceMinRetries = 20 ) const ( @@ -99,7 +103,9 @@ func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientO // WithMaxRetries sets the maximum number of retries to update the usage in case of an API error func WithMaxRetries(maxRetries int) UsageClientOptions { return func(updater *BatchUpdater) { - updater.maxRetries = maxRetries + if maxRetries > 0 { + updater.maxRetries = maxRetries + } } } @@ -278,6 +284,9 @@ func (u *BatchUpdater) setupAWSMarketplace() error { u.batchLimit = 1000000000 u.minTimeBetweenFlushes = 1 * time.Minute + if u.maxRetries < marketplaceMinRetries { + u.maxRetries = marketplaceMinRetries + } u.backgroundUpdater() return nil } @@ -543,20 +552,60 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin UsageQuantity: aws.Int32(int32(rows)), }) if err != nil { - return fmt.Errorf("failed to update usage with : %w", err) + return fmt.Errorf("failed to update usage with: %w", err) } return nil } +func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Context, rows uint32) error { + var lastErr error + for retry := 0; retry < u.maxRetries; retry++ { + u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") + queryStartTime := time.Now() + + lastErr = u.reportUsageToAWSMarketplace(ctx, rows) + if lastErr == nil { + u.logger.Debug().Int("try", retry).Uint32("rows", rows).Msg("usage updated") + return nil + } + + var de *types.DuplicateRequestException + if errors.As(lastErr, &de) { + jitter := time.Duration(rand.Intn(1000)) * time.Millisecond + time.Sleep(marketplaceDuplicateWaitTime + jitter) + continue + } + + var ( + statusCode = -1 + rerr *awshttp.ResponseError + ) + if errors.As(lastErr, &rerr) { + statusCode = rerr.HTTPStatusCode() + } + + retryDuration, err := u.calculateRetryDuration(statusCode, http.Header{}, queryStartTime, retry) + if err != nil { + return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) + } + if retryDuration > 0 { + time.Sleep(retryDuration) + } + } + return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr) +} + func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error { + // If the AWS Marketplace client is set, use it to track usage + if u.awsMarketplaceClient != nil { + return u.updateMarketplaceUsageWithRetryAndBackoff(ctx, rows) + } + + var lastErr error for retry := 0; retry < u.maxRetries; retry++ { u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") queryStartTime := time.Now() - // If the AWS Marketplace client is set, use it to track usage - if u.awsMarketplaceClient != nil { - return u.reportUsageToAWSMarketplace(ctx, rows) - } payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{ RequestId: uuid.New(), PluginTeam: u.pluginMeta.Team, @@ -570,10 +619,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows } resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) - if err != nil { - return fmt.Errorf("failed to update usage: %w", err) - } - if resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") u.lastUpdateTime = time.Now().UTC() if resp.HTTPResponse != nil { @@ -582,15 +628,16 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows return nil } + lastErr = fmt.Errorf("failed to update usage: %w", err) retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry) if err != nil { - return fmt.Errorf("failed to calculate retry duration: %w", err) + return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) } if retryDuration > 0 { time.Sleep(retryDuration) } } - return fmt.Errorf("failed to update usage: max retries exceeded") + return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr) } // updateConfigurationFromHeaders updates the configuration based on the headers returned by the API @@ -651,7 +698,7 @@ func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Heade } func retryableStatusCode(statusCode int) bool { - return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable + return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable || statusCode == -1 } func (u *BatchUpdater) getTeamNameByTokenType(tokenType auth.TokenType) (string, error) { From 2de8ca8804f6fe6d642dd0548aaecae6687fb2d0 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Tue, 29 Oct 2024 19:25:53 +0000 Subject: [PATCH 02/11] More clean up: Utilize built-in go-retryablehttp in apiClient for api requests --- premium/usage.go | 97 ++++++++++++++++++++----------------------- premium/usage_test.go | 83 ++++++++++++++++++++++++++++++++---- 2 files changed, 120 insertions(+), 60 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index 83fe5a7d34..b08f660b55 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -21,8 +21,8 @@ import ( "github.com/cloudquery/cloudquery-api-go/config" "github.com/cloudquery/plugin-sdk/v4/plugin" "github.com/google/uuid" + "github.com/hashicorp/go-retryablehttp" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) const ( @@ -242,18 +242,26 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e // Create a default api client if none was provided if u.apiClient == nil { - ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { - token, err := u.tokenClient.GetToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - return nil - })) + retryClient := retryablehttp.NewClient() + retryClient.Logger = nil + retryClient.RetryMax = u.maxRetries + retryClient.RetryWaitMax = u.maxWaitTime + + var err error + u.apiClient, err = cqapi.NewClientWithResponses(u.url, + cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { + token, err := u.tokenClient.GetToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + }), + cqapi.WithHTTPClient(retryClient.StandardClient()), + ) if err != nil { return nil, fmt.Errorf("failed to create api client: %w", err) } - u.apiClient = ac } // Set team name from configuration if not provided @@ -465,7 +473,7 @@ func (u *BatchUpdater) backgroundUpdater() { } if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil { - log.Warn().Err(err).Msg("failed to update usage") + u.logger.Warn().Err(err).Msg("failed to update usage") continue } u.subtractTableUsage(tables, totals) @@ -488,7 +496,7 @@ func (u *BatchUpdater) backgroundUpdater() { totals = roundDown(totals, 1000) } if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil { - log.Warn().Err(err).Msg("failed to update usage") + u.logger.Warn().Err(err).Msg("failed to update usage") continue } u.subtractTableUsage(tables, totals) @@ -552,7 +560,7 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin UsageQuantity: aws.Int32(int32(rows)), }) if err != nil { - return fmt.Errorf("failed to update usage with: %w", err) + return fmt.Errorf("failed to update usage: %w", err) } return nil } @@ -584,7 +592,7 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con statusCode = rerr.HTTPStatusCode() } - retryDuration, err := u.calculateRetryDuration(statusCode, http.Header{}, queryStartTime, retry) + retryDuration, err := u.calculateMarketplaceRetryDuration(statusCode, http.Header{}, queryStartTime, retry) if err != nil { return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) } @@ -601,43 +609,30 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows return u.updateMarketplaceUsageWithRetryAndBackoff(ctx, rows) } - var lastErr error - for retry := 0; retry < u.maxRetries; retry++ { - u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - queryStartTime := time.Now() - - payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{ - RequestId: uuid.New(), - PluginTeam: u.pluginMeta.Team, - PluginKind: u.pluginMeta.Kind, - PluginName: u.pluginMeta.Name, - Rows: int(rows), - } - - if len(tables) > 0 { - payload.Tables = &tables - } + u.logger.Debug().Str("url", u.url).Uint32("rows", rows).Msg("updating usage") + payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{ + RequestId: uuid.New(), + PluginTeam: u.pluginMeta.Team, + PluginKind: u.pluginMeta.Kind, + PluginName: u.pluginMeta.Name, + Rows: int(rows), + } - resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) - if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { - u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") - u.lastUpdateTime = time.Now().UTC() - if resp.HTTPResponse != nil { - u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) - } - return nil - } + if len(tables) > 0 { + payload.Tables = &tables + } - lastErr = fmt.Errorf("failed to update usage: %w", err) - retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry) - if err != nil { - return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) - } - if retryDuration > 0 { - time.Sleep(retryDuration) + resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) + if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + u.logger.Debug().Str("url", u.url).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") + u.lastUpdateTime = time.Now().UTC() + if resp.HTTPResponse != nil { + u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) } + return nil } - return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr) + + return fmt.Errorf("failed to update usage: %w", err) } // updateConfigurationFromHeaders updates the configuration based on the headers returned by the API @@ -674,9 +669,9 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) { } } -// calculateRetryDuration calculates the duration to sleep relative to the query start time before retrying an update -func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) { - if !retryableStatusCode(statusCode) { +// calculateMarketplaceRetryDuration calculates the duration to sleep relative to the query start time before retrying an update +func (u *BatchUpdater) calculateMarketplaceRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) { + if statusCode > -1 && !retryableStatusCode(statusCode) { return 0, fmt.Errorf("non-retryable status code: %d", statusCode) } @@ -698,7 +693,7 @@ func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Heade } func retryableStatusCode(statusCode int) bool { - return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable || statusCode == -1 + return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable } func (u *BatchUpdater) getTeamNameByTokenType(tokenType auth.TokenType) (string, error) { diff --git a/premium/usage_test.go b/premium/usage_test.go index 9a3a17b3a6..a75739f908 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -105,7 +105,7 @@ func TestUsageService_NewUsageClient_Override(t *testing.T) { func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { ctx := context.Background() - s := createTestServerWithRemainingRows(t, 0) + s := createTestServerWithRemainingRows(t, 0, nil) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) @@ -122,7 +122,7 @@ func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { ctx := context.Background() - s := createTestServerWithRemainingRows(t, 100) + s := createTestServerWithRemainingRows(t, 100, nil) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) @@ -362,13 +362,13 @@ func TestUsageService_IncreaseForTable_CorrectByTable(t *testing.T) { func TestUsageService_AWSMarketplaceDone(t *testing.T) { var err error ctrl := gomock.NewController(t) - m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) + t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") out := marketplacemetering.MeterUsageOutput{} in := meteringInput{ MeterUsageInput: marketplacemetering.MeterUsageInput{ - ProductCode: aws.String("2a8bdkarwqrp0tmo4errl65s7"), + ProductCode: aws.String(awsMarketplaceProductCode()), UsageDimension: aws.String("rows"), UsageQuantity: aws.Int32(20), UsageAllocations: []types.UsageAllocation{{ @@ -392,7 +392,6 @@ func TestUsageService_AWSMarketplaceDone(t *testing.T) { } assert.NoError(t, faker.FakeObject(&out)) m.EXPECT().MeterUsage(gomock.Any(), in).Return(&out, nil) - t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") usageClient := newClient(t, nil, WithBatchLimit(50), WithAWSMarketplaceClient(m)) // This will generate 19,998 rows @@ -541,6 +540,62 @@ func TestUsageService_ShouldNotUpdateClosedService(t *testing.T) { assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } +func TestUsageService_RetryOnRetryableError(t *testing.T) { + s := createTestServerWithRemainingRows(t, 0, []int{http.StatusServiceUnavailable, http.StatusTooManyRequests}) + defer s.server.Close() + + usageClient, err := NewUsageClient( + plugin.Meta{ + Team: "plugin-team", + Kind: cqapi.PluginKindSource, + Name: "vault", + }, + WithURL(s.server.URL), + WithMaxRetries(2), + WithMaxWaitTime(time.Millisecond), + WithBatchLimit(0), + WithLogger(zerolog.Nop()), + // WithLogger(zerolog.New(zerolog.NewTestWriter(t)).Level(zerolog.DebugLevel)), + withTeamName("team-name"), + withTokenClient(newMockTokenClient(auth.BearerToken)), + ) + require.NoError(t, err) + + err = usageClient.Increase(100) + require.NoError(t, err) + + err = usageClient.Close() + require.NoError(t, err) + + assert.Equal(t, 1, s.numberOfUpdates(), "total number of updates should be one") +} + +func TestUsageService_RetryOnRetryableErrorExhaustRetries(t *testing.T) { + s := createTestServerWithRemainingRows(t, 0, []int{http.StatusServiceUnavailable, http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusServiceUnavailable}) + defer s.server.Close() + + usageClient, err := NewUsageClient( + plugin.Meta{ + Team: "plugin-team", + Kind: cqapi.PluginKindSource, + Name: "vault", + }, + WithURL(s.server.URL), + WithMaxRetries(1), + WithMaxWaitTime(time.Millisecond), + WithBatchLimit(0), + withTeamName("team-name"), + withTokenClient(newMockTokenClient(auth.BearerToken)), + ) + require.NoError(t, err) + + err = usageClient.Increase(100) + require.NoError(t, err) + + err = usageClient.Close() + require.Error(t, err) +} + func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { tests := []struct { name string @@ -596,7 +651,7 @@ func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { tt.ops(usageClient) } t.Run(tt.name, func(t *testing.T) { - retryDuration, err := usageClient.calculateRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) + retryDuration, err := usageClient.calculateMarketplaceRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) require.NoError(t, err) assert.InDeltaf(t, tt.expectedSeconds, retryDuration.Seconds(), 1, "retry duration should be %d seconds", tt.expectedSeconds) @@ -643,7 +698,7 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { tt.ops(usageClient) } t.Run(tt.name, func(t *testing.T) { - retryDuration, err := usageClient.calculateRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) + retryDuration, err := usageClient.calculateMarketplaceRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) if tt.wantErr == nil { require.NoError(t, err) } else { @@ -668,7 +723,7 @@ func newClient(t *testing.T, apiClient *cqapi.ClientWithResponses, ops ...UsageC return client.(*BatchUpdater) } -func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testStage { +func createTestServerWithRemainingRows(t *testing.T, remainingRows int, responseCodes []int) *testStage { stage := testStage{ remainingRows: remainingRows, headers: make(map[string]string), @@ -689,6 +744,16 @@ func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testSta return } if r.Method == "POST" { + if len(responseCodes) > 0 { + code := responseCodes[0] + responseCodes = responseCodes[1:] + w.WriteHeader(code) + for k, v := range stage.headers { + w.Header().Set(k, v) + } + return + } + dec := json.NewDecoder(r.Body) var req cqapi.IncreaseTeamPluginUsageJSONRequestBody err := dec.Decode(&req) @@ -731,7 +796,7 @@ func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testSta } func createTestServer(t *testing.T) *testStage { - return createTestServerWithRemainingRows(t, 0) + return createTestServerWithRemainingRows(t, 0, nil) } type testStage struct { From d20a7f3f4f18c11c74ff9b9c8e286425d47cbe27 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Tue, 29 Oct 2024 19:54:25 +0000 Subject: [PATCH 03/11] Remove redundant `Retry-After` handling (as AWS probably doesn't respond with those) --- premium/usage.go | 14 ++------------ premium/usage_test.go | 23 ++++------------------- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index b08f660b55..588beb406b 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -592,7 +592,7 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con statusCode = rerr.HTTPStatusCode() } - retryDuration, err := u.calculateMarketplaceRetryDuration(statusCode, http.Header{}, queryStartTime, retry) + retryDuration, err := u.calculateMarketplaceRetryDuration(statusCode, queryStartTime, retry) if err != nil { return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) } @@ -670,21 +670,11 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) { } // calculateMarketplaceRetryDuration calculates the duration to sleep relative to the query start time before retrying an update -func (u *BatchUpdater) calculateMarketplaceRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) { +func (u *BatchUpdater) calculateMarketplaceRetryDuration(statusCode int, queryStartTime time.Time, retry int) (time.Duration, error) { if statusCode > -1 && !retryableStatusCode(statusCode) { return 0, fmt.Errorf("non-retryable status code: %d", statusCode) } - // Check if we have a retry-after header - retryAfter := headers.Get("Retry-After") - if retryAfter != "" { - retryDelay, err := time.ParseDuration(retryAfter + "s") - if err != nil { - return 0, fmt.Errorf("failed to parse retry-after header: %w", err) - } - return retryDelay, nil - } - // Calculate exponential backoff baseRetry := min(time.Duration(1< Date: Wed, 30 Oct 2024 10:44:00 +0000 Subject: [PATCH 04/11] CR: use `max`, doh --- premium/usage.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index 588beb406b..b83d041cce 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -292,9 +292,7 @@ func (u *BatchUpdater) setupAWSMarketplace() error { u.batchLimit = 1000000000 u.minTimeBetweenFlushes = 1 * time.Minute - if u.maxRetries < marketplaceMinRetries { - u.maxRetries = marketplaceMinRetries - } + u.maxRetries = max(u.maxRetries, marketplaceMinRetries) u.backgroundUpdater() return nil } From fd71b5e889f736c120f1c8db7b1c324ab1ea9180 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Wed, 30 Oct 2024 14:13:02 +0000 Subject: [PATCH 05/11] "over charge" -> "overcharge" --- premium/usage.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index b83d041cce..d08a68ac09 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -465,7 +465,7 @@ func (u *BatchUpdater) backgroundUpdater() { } // If we are using AWS Marketplace, we need to round down to the nearest 1000 // Only on the last update, will we round up to the nearest 1000 - // This will allow us to not over charge the customer by rounding on each batch + // This will allow us to not overcharge the customer by rounding on each batch if u.awsMarketplaceClient != nil { totals = roundDown(totals, 1000) } @@ -489,7 +489,7 @@ func (u *BatchUpdater) backgroundUpdater() { } // If we are using AWS Marketplace, we need to round down to the nearest 1000 // Only on the last update, will we round up to the nearest 1000 - // This will allow us to not over charge the customer by rounding on each batch + // This will allow us to not overcharge the customer by rounding on each batch if u.awsMarketplaceClient != nil { totals = roundDown(totals, 1000) } From c3b9b5613e2aee5bb70f34dcc7776e3affc7e4ee Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Wed, 30 Oct 2024 14:45:11 +0000 Subject: [PATCH 06/11] Add mock test for duplicate marketplace requests --- premium/usage.go | 6 +-- premium/usage_test.go | 120 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index d08a68ac09..56e3da3dd0 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -552,7 +552,7 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin // Each product is given a unique product code when it is listed in AWS Marketplace // in the future we can have multiple product codes for container or AMI based listings ProductCode: aws.String(awsMarketplaceProductCode()), - Timestamp: aws.Time(time.Now()), + Timestamp: aws.Time(timeFunc()), UsageDimension: aws.String("rows"), UsageAllocations: usage, UsageQuantity: aws.Int32(int32(rows)), @@ -567,7 +567,7 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con var lastErr error for retry := 0; retry < u.maxRetries; retry++ { u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - queryStartTime := time.Now() + queryStartTime := timeFunc() lastErr = u.reportUsageToAWSMarketplace(ctx, rows) if lastErr == nil { @@ -623,7 +623,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { u.logger.Debug().Str("url", u.url).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") - u.lastUpdateTime = time.Now().UTC() + u.lastUpdateTime = timeFunc().UTC() if resp.HTTPResponse != nil { u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) } diff --git a/premium/usage_test.go b/premium/usage_test.go index b846006edf..2472c1c435 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -360,7 +360,6 @@ func TestUsageService_IncreaseForTable_CorrectByTable(t *testing.T) { } func TestUsageService_AWSMarketplaceDone(t *testing.T) { - var err error ctrl := gomock.NewController(t) m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") @@ -394,6 +393,116 @@ func TestUsageService_AWSMarketplaceDone(t *testing.T) { m.EXPECT().MeterUsage(gomock.Any(), in).Return(&out, nil) usageClient := newClient(t, nil, WithBatchLimit(50), WithAWSMarketplaceClient(m)) + // This will generate 19,998 rows + // We expect that there will be 20 rows reported to AWS Marketplace + rows := 9999 + for i := 0; i < rows; i++ { + err := usageClient.IncreaseForTable("table", 2) + require.NoError(t, err) + } + + err := usageClient.Close() + require.NoError(t, err) +} + +func TestUsageService_AWSMarketplace_DuplicateRowsRetry(t *testing.T) { + ctrl := gomock.NewController(t) + m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) + t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") + + pmeta := plugin.Meta{ + Team: "plugin-team", + Kind: cqapi.PluginKindSource, + Name: "vault", + } + + tn := time.Now() + timeList := []time.Time{ + tn, + tn, + tn.Add(time.Second), + tn.Add(time.Second), + tn.Add(2 * time.Second), + tn.Add(2 * time.Second), + } + timeFunc = func() time.Time { + if len(timeList) == 0 { + panic("timeFunc called too many times") + } + t := timeList[0] + timeList = timeList[1:] + return t + } + + out := marketplacemetering.MeterUsageOutput{} + in := meteringInput{ + MeterUsageInput: marketplacemetering.MeterUsageInput{ + ProductCode: aws.String(awsMarketplaceProductCode()), + UsageDimension: aws.String("rows"), + UsageQuantity: aws.Int32(20), + Timestamp: aws.Time(tn.Round(time.Second)), + UsageAllocations: []types.UsageAllocation{{ + AllocatedUsageQuantity: aws.Int32(int32(20)), + Tags: []types.Tag{ + { + Key: aws.String("plugin_name"), + Value: aws.String(pmeta.Name), + }, + { + Key: aws.String("plugin_team"), + Value: aws.String(pmeta.Team), + }, + { + Key: aws.String("plugin_kind"), + Value: aws.String(string(pmeta.Kind)), + }, + }, + }}, + }, + } + assert.NoError(t, faker.FakeObject(&out)) + + // logger := zerolog.New(zerolog.NewTestWriter(t)).Level(zerolog.DebugLevel) + logger := zerolog.Nop() + + type meteringKey struct { + Dimension string + Quantity int32 + Timestamp time.Time + } + dupes := make(map[meteringKey]struct{}) + // Add two duplicates + dupes[meteringKey{Dimension: *in.UsageDimension, Quantity: *in.UsageQuantity, Timestamp: in.Timestamp.Round(time.Second)}] = struct{}{} + dupes[meteringKey{Dimension: *in.UsageDimension, Quantity: *in.UsageQuantity, Timestamp: in.Timestamp.Add(time.Second).Round(time.Second)}] = struct{}{} + existingRows := int32(0) + for v := range dupes { + existingRows += v.Quantity + } + + duplicateRequests, validRequests := 0, 0 + m.EXPECT().MeterUsage(gomock.Any(), in).DoAndReturn(func(_ context.Context, in *marketplacemetering.MeterUsageInput, _ ...any) (*marketplacemetering.MeterUsageOutput, error) { + k := meteringKey{Dimension: *in.UsageDimension, Quantity: *in.UsageQuantity, Timestamp: in.Timestamp.Round(time.Second)} + if _, ok := dupes[k]; ok { + logger.Debug().Any("key", k).Msg("got duplicate request") + duplicateRequests++ + return nil, &types.DuplicateRequestException{Message: aws.String("duplicate request")} + } + logger.Debug().Any("key", k).Msg("got valid request") + validRequests++ + dupes[k] = struct{}{} + return &out, nil + }).MinTimes(1) + + usageClient, err := NewUsageClient( + pmeta, + WithMaxWaitTime(time.Millisecond), + WithBatchLimit(50), + WithLogger(logger), + withTeamName("team-name"), + WithAWSMarketplaceClient(m), + ) + require.NoError(t, err) + // This will generate 19,998 rows // We expect that there will be 20 rows reported to AWS Marketplace rows := 9999 @@ -404,6 +513,15 @@ func TestUsageService_AWSMarketplaceDone(t *testing.T) { err = usageClient.Close() require.NoError(t, err) + + require.Equal(t, 2, duplicateRequests, "should have 2 duplicate requests") + require.Equal(t, 1, validRequests, "should have 1 valid request") + + totalRows := int32(0) + for v := range dupes { + totalRows += v.Quantity + } + assert.Equal(t, int32(20), totalRows-existingRows, "should have 20 rows reported to AWS Marketplace") } func TestUsageService_Increase_ErrorOnMixingMethods(t *testing.T) { From 877a1efb8c62892fbc563dc300c5bd39e8f6ec79 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Wed, 30 Oct 2024 15:05:44 +0000 Subject: [PATCH 07/11] Use separate timeFunc for BatchUpdater --- premium/usage.go | 10 +++++++--- premium/usage_test.go | 32 ++++++++++++++++---------------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index 56e3da3dd0..6a946741c6 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -194,6 +194,9 @@ type BatchUpdater struct { isClosed bool dataOnClose bool usageIncreaseMethod int + + // Testing + timeFunc func() time.Time } func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, error) { @@ -212,6 +215,7 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e triggerUpdate: make(chan struct{}), done: make(chan struct{}), closeError: make(chan error), + timeFunc: time.Now, tables: map[string]uint32{}, } @@ -552,7 +556,7 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin // Each product is given a unique product code when it is listed in AWS Marketplace // in the future we can have multiple product codes for container or AMI based listings ProductCode: aws.String(awsMarketplaceProductCode()), - Timestamp: aws.Time(timeFunc()), + Timestamp: aws.Time(u.timeFunc()), UsageDimension: aws.String("rows"), UsageAllocations: usage, UsageQuantity: aws.Int32(int32(rows)), @@ -567,7 +571,7 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con var lastErr error for retry := 0; retry < u.maxRetries; retry++ { u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - queryStartTime := timeFunc() + queryStartTime := u.timeFunc() lastErr = u.reportUsageToAWSMarketplace(ctx, rows) if lastErr == nil { @@ -623,7 +627,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload) if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 { u.logger.Debug().Str("url", u.url).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated") - u.lastUpdateTime = timeFunc().UTC() + u.lastUpdateTime = u.timeFunc().UTC() if resp.HTTPResponse != nil { u.updateConfigurationFromHeaders(resp.HTTPResponse.Header) } diff --git a/premium/usage_test.go b/premium/usage_test.go index 2472c1c435..0bc28304b9 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -405,7 +405,7 @@ func TestUsageService_AWSMarketplaceDone(t *testing.T) { require.NoError(t, err) } -func TestUsageService_AWSMarketplace_DuplicateRowsRetry(t *testing.T) { +func TestUsageService_AWSMarketpgolanglace_DuplicateRowsRetry(t *testing.T) { ctrl := gomock.NewController(t) m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") @@ -425,9 +425,22 @@ func TestUsageService_AWSMarketplace_DuplicateRowsRetry(t *testing.T) { tn.Add(2 * time.Second), tn.Add(2 * time.Second), } - timeFunc = func() time.Time { + + // logger := zerolog.New(zerolog.NewTestWriter(t)).Level(zerolog.DebugLevel) + logger := zerolog.Nop() + + usageClient, err := NewUsageClient( + pmeta, + WithMaxWaitTime(time.Millisecond), + WithBatchLimit(50), + WithLogger(logger), + withTeamName("team-name"), + WithAWSMarketplaceClient(m), + ) + require.NoError(t, err) + usageClient.(*BatchUpdater).timeFunc = func() time.Time { if len(timeList) == 0 { - panic("timeFunc called too many times") + panic("BatchUpdater.timeFunc called too many times") } t := timeList[0] timeList = timeList[1:] @@ -462,9 +475,6 @@ func TestUsageService_AWSMarketplace_DuplicateRowsRetry(t *testing.T) { } assert.NoError(t, faker.FakeObject(&out)) - // logger := zerolog.New(zerolog.NewTestWriter(t)).Level(zerolog.DebugLevel) - logger := zerolog.Nop() - type meteringKey struct { Dimension string Quantity int32 @@ -493,16 +503,6 @@ func TestUsageService_AWSMarketplace_DuplicateRowsRetry(t *testing.T) { return &out, nil }).MinTimes(1) - usageClient, err := NewUsageClient( - pmeta, - WithMaxWaitTime(time.Millisecond), - WithBatchLimit(50), - WithLogger(logger), - withTeamName("team-name"), - WithAWSMarketplaceClient(m), - ) - require.NoError(t, err) - // This will generate 19,998 rows // We expect that there will be 20 rows reported to AWS Marketplace rows := 9999 From 47ca61df9d9cd2384dc4a93ff80747d1966a74f0 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Wed, 30 Oct 2024 19:30:23 +0000 Subject: [PATCH 08/11] Simplify more --- premium/usage.go | 46 +++---------------- premium/usage_test.go | 102 ------------------------------------------ 2 files changed, 6 insertions(+), 142 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index 6a946741c6..aaadf14d94 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -12,7 +12,6 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" awsConfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/marketplacemetering" "github.com/aws/aws-sdk-go-v2/service/marketplacemetering/types" @@ -567,11 +566,10 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin return nil } -func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Context, rows uint32) error { +func (u *BatchUpdater) updateMarketplaceUsage(ctx context.Context, rows uint32) error { var lastErr error for retry := 0; retry < u.maxRetries; retry++ { u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage") - queryStartTime := u.timeFunc() lastErr = u.reportUsageToAWSMarketplace(ctx, rows) if lastErr == nil { @@ -580,27 +578,12 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con } var de *types.DuplicateRequestException - if errors.As(lastErr, &de) { - jitter := time.Duration(rand.Intn(1000)) * time.Millisecond - time.Sleep(marketplaceDuplicateWaitTime + jitter) - continue - } - - var ( - statusCode = -1 - rerr *awshttp.ResponseError - ) - if errors.As(lastErr, &rerr) { - statusCode = rerr.HTTPStatusCode() + if !errors.As(lastErr, &de) { + return fmt.Errorf("failed to update usage: %w", lastErr) } - retryDuration, err := u.calculateMarketplaceRetryDuration(statusCode, queryStartTime, retry) - if err != nil { - return fmt.Errorf("failed to calculate retry duration: %v: %w", err.Error(), lastErr) - } - if retryDuration > 0 { - time.Sleep(retryDuration) - } + jitter := time.Duration(rand.Intn(1000)) * time.Millisecond + time.Sleep(marketplaceDuplicateWaitTime + jitter) } return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr) } @@ -608,7 +591,7 @@ func (u *BatchUpdater) updateMarketplaceUsageWithRetryAndBackoff(ctx context.Con func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error { // If the AWS Marketplace client is set, use it to track usage if u.awsMarketplaceClient != nil { - return u.updateMarketplaceUsageWithRetryAndBackoff(ctx, rows) + return u.updateMarketplaceUsage(ctx, rows) } u.logger.Debug().Str("url", u.url).Uint32("rows", rows).Msg("updating usage") @@ -671,23 +654,6 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) { } } -// calculateMarketplaceRetryDuration calculates the duration to sleep relative to the query start time before retrying an update -func (u *BatchUpdater) calculateMarketplaceRetryDuration(statusCode int, queryStartTime time.Time, retry int) (time.Duration, error) { - if statusCode > -1 && !retryableStatusCode(statusCode) { - return 0, fmt.Errorf("non-retryable status code: %d", statusCode) - } - - // Calculate exponential backoff - baseRetry := min(time.Duration(1< Date: Wed, 30 Oct 2024 21:03:38 +0000 Subject: [PATCH 09/11] CR: Log error --- premium/usage.go | 1 + 1 file changed, 1 insertion(+) diff --git a/premium/usage.go b/premium/usage.go index aaadf14d94..b8a90de3cc 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -581,6 +581,7 @@ func (u *BatchUpdater) updateMarketplaceUsage(ctx context.Context, rows uint32) if !errors.As(lastErr, &de) { return fmt.Errorf("failed to update usage: %w", lastErr) } + u.logger.Debug().Err(lastErr).Int("try", retry).Uint32("rows", rows).Msg("usage update failed due to duplicate request") jitter := time.Duration(rand.Intn(1000)) * time.Millisecond time.Sleep(marketplaceDuplicateWaitTime + jitter) From bb8b56e68336b169274a8af77b20e6671f2cc7e7 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Thu, 31 Oct 2024 13:34:39 +0000 Subject: [PATCH 10/11] Extract dry run mock to helper --- premium/usage_test.go | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/premium/usage_test.go b/premium/usage_test.go index cfae2f0166..5a12b22d8b 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -360,20 +360,28 @@ func TestUsageService_IncreaseForTable_CorrectByTable(t *testing.T) { } } -func TestUsageService_AWSMarketplaceDone(t *testing.T) { - ctrl := gomock.NewController(t) - m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) - t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") +func usageMarketplaceDryRunHelper(t *testing.T, m *mocks.MockAWSMarketplaceClientInterface) *gomock.Call { + t.Helper() - out := marketplacemetering.MeterUsageOutput{} inTest := meteringInput{ marketplacemetering.MeterUsageInput{ - ProductCode: aws.String("2a8bdkarwqrp0tmo4errl65s7"), + ProductCode: aws.String(awsMarketplaceProductCode()), UsageDimension: aws.String("rows"), UsageQuantity: aws.Int32(int32(0)), DryRun: aws.Bool(true)}, } errTest := smithy.GenericAPIError{Code: "DryRunOperation", Message: "No errors detected in dry run"} + out := marketplacemetering.MeterUsageOutput{} + + return m.EXPECT().MeterUsage(gomock.Any(), inTest).Return(&out, &errTest) +} + +func TestUsageService_AWSMarketplaceDone(t *testing.T) { + ctrl := gomock.NewController(t) + m := mocks.NewMockAWSMarketplaceClientInterface(ctrl) + t.Setenv("CQ_AWS_MARKETPLACE_CONTAINER", "true") + + out := marketplacemetering.MeterUsageOutput{} in := meteringInput{ MeterUsageInput: marketplacemetering.MeterUsageInput{ @@ -402,7 +410,7 @@ func TestUsageService_AWSMarketplaceDone(t *testing.T) { assert.NoError(t, faker.FakeObject(&out)) gomock.InOrder( - m.EXPECT().MeterUsage(gomock.Any(), inTest).Return(&out, &errTest), + usageMarketplaceDryRunHelper(t, m), m.EXPECT().MeterUsage(gomock.Any(), in).Return(&out, nil), ) @@ -441,6 +449,7 @@ func TestUsageService_AWSMarketpgolanglace_DuplicateRowsRetry(t *testing.T) { // logger := zerolog.New(zerolog.NewTestWriter(t)).Level(zerolog.DebugLevel) logger := zerolog.Nop() + usageMarketplaceDryRunHelper(t, m) usageClient, err := NewUsageClient( pmeta, WithMaxWaitTime(time.Millisecond), @@ -502,6 +511,7 @@ func TestUsageService_AWSMarketpgolanglace_DuplicateRowsRetry(t *testing.T) { } duplicateRequests, validRequests := 0, 0 + m.EXPECT().MeterUsage(gomock.Any(), in).DoAndReturn(func(_ context.Context, in *marketplacemetering.MeterUsageInput, _ ...any) (*marketplacemetering.MeterUsageOutput, error) { k := meteringKey{Dimension: *in.UsageDimension, Quantity: *in.UsageQuantity, Timestamp: in.Timestamp.Round(time.Second)} if _, ok := dupes[k]; ok { From fde2bd916751138175c65a46aa881c56560d49c8 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Thu, 31 Oct 2024 13:37:36 +0000 Subject: [PATCH 11/11] Use timeFunc instead of time.Now --- premium/usage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/premium/usage.go b/premium/usage.go index 8e9146f90a..facaa83624 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -311,7 +311,7 @@ func (u *BatchUpdater) setupAWSMarketplace() error { _, err = u.awsMarketplaceClient.MeterUsage(ctx, &marketplacemetering.MeterUsageInput{ ProductCode: aws.String(awsMarketplaceProductCode()), - Timestamp: aws.Time(time.Now()), + Timestamp: aws.Time(u.timeFunc()), UsageDimension: aws.String("rows"), UsageQuantity: aws.Int32(int32(0)), DryRun: aws.Bool(true),