diff --git a/go.mod b/go.mod index c4d1c0b6d4..020c7383ca 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/CloudyKit/jet/v6 v6.2.0 // indirect github.com/Joker/jade v1.1.3 // indirect github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 // indirect + github.com/adrg/xdg v0.4.0 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/arrow/go/v13 v13.0.0-20230731205701-112f94971882 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index 427a27a662..e57d68e4e8 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 h1:KkH3I3sJuOLP3TjA/dfr4NAY8bghDwnXiU7cTKxQqo0= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06/go.mod h1:7erjKLwalezA0k99cWs5L11HWOAPNjdUZ6RxH1BXbbM= +github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= +github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= @@ -556,6 +558,7 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/premium/usage.go b/premium/usage.go new file mode 100644 index 0000000000..e8016949aa --- /dev/null +++ b/premium/usage.go @@ -0,0 +1,272 @@ +package premium + +import ( + "context" + "fmt" + cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "math/rand" + "net/http" + "sync/atomic" + "time" +) + +const ( + defaultBatchLimit = 1000 + defaultMaxRetries = 5 + defaultMaxWaitTime = 60 * time.Second + defaultMinTimeBetweenFlushes = 10 * time.Second + defaultMaxTimeBetweenFlushes = 30 * time.Second +) + +type UsageClient interface { + // Increase updates the usage by the given number of rows + Increase(context.Context, uint32) + // HasQuota returns true if the quota has not been exceeded + HasQuota(context.Context) (bool, error) + // Close flushes any remaining rows and closes the quota service + Close() error +} + +type UpdaterOptions func(updater *BatchUpdater) + +// WithBatchLimit sets the maximum number of rows to update in a single request +func WithBatchLimit(batchLimit uint32) UpdaterOptions { + return func(updater *BatchUpdater) { + updater.batchLimit = batchLimit + } +} + +// WithMaxTimeBetweenFlushes sets the flush duration - the time at which an update will be triggered even if the batch limit is not reached +func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UpdaterOptions { + return func(updater *BatchUpdater) { + updater.maxTimeBetweenFlushes = maxTimeBetweenFlushes + } +} + +// WithMinTimeBetweenFlushes sets the minimum time between updates +func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UpdaterOptions { + return func(updater *BatchUpdater) { + updater.minTimeBetweenFlushes = minTimeBetweenFlushes + } +} + +// WithMaxRetries sets the maximum number of retries to update the usage in case of an API error +func WithMaxRetries(maxRetries int) UpdaterOptions { + return func(updater *BatchUpdater) { + updater.maxRetries = maxRetries + } +} + +// WithMaxWaitTime sets the maximum time to wait before retrying a failed update +func WithMaxWaitTime(maxWaitTime time.Duration) UpdaterOptions { + return func(updater *BatchUpdater) { + updater.maxWaitTime = maxWaitTime + } +} + +type BatchUpdater struct { + apiClient *cqapi.ClientWithResponses + + // Plugin details + teamName string + pluginTeam string + pluginKind string + pluginName string + + // Configuration + batchLimit uint32 + maxRetries int + maxWaitTime time.Duration + minTimeBetweenFlushes time.Duration + maxTimeBetweenFlushes time.Duration + + // State + lastUpdateTime time.Time + rowsToUpdate atomic.Uint32 + triggerUpdate chan struct{} + done chan struct{} + closeError chan error + isClosed bool +} + +func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, teamName, pluginTeam, pluginKind, pluginName string, ops ...UpdaterOptions) *BatchUpdater { + u := &BatchUpdater{ + apiClient: apiClient, + + teamName: teamName, + pluginTeam: pluginTeam, + pluginKind: pluginKind, + pluginName: pluginName, + + batchLimit: defaultBatchLimit, + minTimeBetweenFlushes: defaultMinTimeBetweenFlushes, + maxTimeBetweenFlushes: defaultMaxTimeBetweenFlushes, + maxRetries: defaultMaxRetries, + maxWaitTime: defaultMaxWaitTime, + triggerUpdate: make(chan struct{}), + done: make(chan struct{}), + closeError: make(chan error), + } + for _, op := range ops { + op(u) + } + + u.backgroundUpdater(ctx) + + return u +} + +func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { + if rows <= 0 { + return fmt.Errorf("rows must be greater than zero got %d", rows) + } + + if u.isClosed { + return fmt.Errorf("usage updater is closed") + } + + u.rowsToUpdate.Add(rows) + + // Trigger an update unless an update is already in process + select { + case u.triggerUpdate <- struct{}{}: + default: + return nil + } + + return nil +} + +func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { + usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, cqapi.PluginKind(u.pluginKind), u.pluginName) + if err != nil { + return false, fmt.Errorf("failed to get usage: %w", err) + } + if usage.StatusCode() != http.StatusOK { + return false, fmt.Errorf("failed to get usage: %s", usage.Status()) + } + return *usage.JSON200.RemainingRows > 0, nil +} + +func (u *BatchUpdater) Close(_ context.Context) error { + u.isClosed = true + + close(u.done) + + return <-u.closeError +} + +func (u *BatchUpdater) backgroundUpdater(ctx context.Context) { + started := make(chan struct{}) + + flushDuration := time.NewTicker(u.maxTimeBetweenFlushes) + + go func() { + started <- struct{}{} + for { + select { + case <-u.triggerUpdate: + if time.Since(u.lastUpdateTime) < u.minTimeBetweenFlushes { + // Not enough time since last update + continue + } + + rowsToUpdate := u.rowsToUpdate.Load() + if rowsToUpdate < u.batchLimit { + // Not enough rows to update + continue + } + if err := u.updateUsageWithRetryAndBackoff(ctx, rowsToUpdate); err != nil { + log.Warn().Err(err).Msg("failed to update usage") + continue + } + u.rowsToUpdate.Add(-rowsToUpdate) + case <-flushDuration.C: + if time.Since(u.lastUpdateTime) < u.minTimeBetweenFlushes { + // Not enough time since last update + continue + } + rowsToUpdate := u.rowsToUpdate.Load() + if rowsToUpdate == 0 { + continue + } + if err := u.updateUsageWithRetryAndBackoff(ctx, rowsToUpdate); err != nil { + log.Warn().Err(err).Msg("failed to update usage") + continue + } + u.rowsToUpdate.Add(-rowsToUpdate) + case <-u.done: + remainingRows := u.rowsToUpdate.Load() + if remainingRows != 0 { + if err := u.updateUsageWithRetryAndBackoff(ctx, remainingRows); err != nil { + u.closeError <- err + return + } + u.rowsToUpdate.Add(-remainingRows) + } + u.closeError <- nil + return + } + } + }() + <-started +} + +func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numberToUpdate uint32) error { + for retry := 0; retry < u.maxRetries; retry++ { + queryStartTime := time.Now() + + resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, cqapi.IncreaseTeamPluginUsageJSONRequestBody{ + RequestId: uuid.New(), + PluginTeam: u.pluginTeam, + PluginKind: cqapi.PluginKind(u.pluginKind), + PluginName: u.pluginName, + Rows: int(numberToUpdate), + }) + if err != nil { + return fmt.Errorf("failed to update usage: %w", err) + } + if resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + u.lastUpdateTime = time.Now().UTC() + return nil + } + + retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry) + if err != nil { + return fmt.Errorf("failed to calculate retry duration: %w", err) + } + if retryDuration > 0 { + time.Sleep(retryDuration) + } + } + return fmt.Errorf("failed to update usage: max retries exceeded") +} + +// calculateRetryDuration calculates the duration to sleep relative to the query start time before retrying an update +func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) { + if !retryableStatusCode(statusCode) { + return 0, fmt.Errorf("non-retryable status code: %d", statusCode) + } + + // Check if we have a retry-after header + retryAfter := headers.Get("Retry-After") + if retryAfter != "" { + retryDelay, err := time.ParseDuration(retryAfter + "s") + if err != nil { + return 0, fmt.Errorf("failed to parse retry-after header: %w", err) + } + return retryDelay, nil + } + + // Calculate exponential backoff + baseRetry := min(time.Duration(1< batchSize, "minimum should be greater than batch size") +} + +func TestUsageService_WithFlushDuration(t *testing.T) { + ctx := context.Background() + batchSize := 2000 + + s := createTestServer(t) + defer s.server.Close() + + apiClient, err := cqapi.NewClientWithResponses(s.server.URL) + require.NoError(t, err) + + usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) + + for i := 0; i < 10; i++ { + err = usageClient.Increase(ctx, 10) + require.NoError(t, err) + time.Sleep(5 * time.Millisecond) + } + err = usageClient.Close(ctx) + require.NoError(t, err) + + assert.Equal(t, 100, s.sumOfUpdates(), "total should equal number of updated rows") + assert.True(t, s.minExcludingClose() < batchSize, "we should see updates less than batchsize if ticker is firing") +} + +func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { + ctx := context.Background() + + s := createTestServer(t) + defer s.server.Close() + + apiClient, err := cqapi.NewClientWithResponses(s.server.URL) + require.NoError(t, err) + + usageClient := newClient(ctx, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) + + for i := 0; i < 10000; i++ { + err = usageClient.Increase(ctx, 1) + require.NoError(t, err) + } + err = usageClient.Close(ctx) + require.NoError(t, err) + + assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") + assert.Equal(t, 2, s.numberOfUpdates(), "should only update first time and on close if minimum update duration is set") +} + +func TestUsageService_NoUpdates(t *testing.T) { + ctx := context.Background() + + s := createTestServer(t) + defer s.server.Close() + + apiClient, err := cqapi.NewClientWithResponses(s.server.URL) + require.NoError(t, err) + + usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + + err = usageClient.Close(ctx) + require.NoError(t, err) + + assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") +} + +func TestUsageService_UpdatesWithZeroRows(t *testing.T) { + ctx := context.Background() + + s := createTestServer(t) + defer s.server.Close() + + apiClient, err := cqapi.NewClientWithResponses(s.server.URL) + require.NoError(t, err) + + usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + + err = usageClient.Increase(ctx, 0) + require.Error(t, err, "should not be able to update with zero rows") + + err = usageClient.Close(ctx) + require.NoError(t, err) + + assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") +} + +func TestUsageService_ShouldNotUpdateClosedService(t *testing.T) { + ctx := context.Background() + + s := createTestServer(t) + defer s.server.Close() + + apiClient, err := cqapi.NewClientWithResponses(s.server.URL) + require.NoError(t, err) + + usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + + // Close the service first + err = usageClient.Close(ctx) + require.NoError(t, err) + + err = usageClient.Increase(ctx, 10) + require.Error(t, err, "should not be able to update closed service") + + assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") +} + +func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { + tests := []struct { + name string + statusCode int + headers http.Header + retry int + expectedSeconds int + ops func(client *BatchUpdater) + }{ + { + name: "first retry", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 0, + expectedSeconds: 1, + }, + { + name: "second retry", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 1, + expectedSeconds: 2, + }, + { + name: "third retry", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 2, + expectedSeconds: 4, + }, + { + name: "fourth retry", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 3, + expectedSeconds: 8, + }, + { + name: "should max out at max wait time", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 10, + expectedSeconds: 30, + ops: func(client *BatchUpdater) { + client.maxWaitTime = 30 * time.Second + }, + }, + } + + for _, tt := range tests { + usageClient := newClient(context.Background(), nil) + if tt.ops != nil { + tt.ops(usageClient) + } + t.Run(tt.name, func(t *testing.T) { + retryDuration, err := usageClient.calculateRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) + require.NoError(t, err) + + assert.InDeltaf(t, tt.expectedSeconds, retryDuration.Seconds(), 1, "retry duration should be %d seconds", tt.expectedSeconds) + }) + } +} + +func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { + tests := []struct { + name string + statusCode int + headers http.Header + retry int + expectedSeconds int + ops func(client *BatchUpdater) + wantErr error + }{ + { + name: "should use exponential backoff on 503 and no header", + statusCode: http.StatusServiceUnavailable, + headers: http.Header{}, + retry: 0, + expectedSeconds: 1, + }, + { + name: "should use exponential backoff on 429 if no retry-after header", + statusCode: http.StatusTooManyRequests, + headers: http.Header{}, + retry: 1, + expectedSeconds: 2, + }, + { + name: "should use retry-after header if present on 429", + statusCode: http.StatusTooManyRequests, + headers: http.Header{"Retry-After": []string{"5"}}, + retry: 0, + expectedSeconds: 5, + }, + } + + for _, tt := range tests { + usageClient := newClient(context.Background(), nil) + if tt.ops != nil { + tt.ops(usageClient) + } + t.Run(tt.name, func(t *testing.T) { + retryDuration, err := usageClient.calculateRetryDuration(tt.statusCode, tt.headers, time.Now(), tt.retry) + if tt.wantErr == nil { + require.NoError(t, err) + } else { + assert.Contains(t, err.Error(), tt.wantErr.Error()) + } + + assert.InDeltaf(t, tt.expectedSeconds, retryDuration.Seconds(), 1, "retry duration should be %d seconds", tt.expectedSeconds) + }) + } +} + +func newClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, ops ...UpdaterOptions) *BatchUpdater { + return NewUsageClient(ctx, apiClient, "myteam", "mnorbury-team", "source", "vault", ops...) +} + +func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testStage { + stage := testStage{ + remainingRows: remainingRows, + update: make([]int, 0), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + w.Header().Set("Content-Type", "application/json") + if _, err := fmt.Fprintf(w, `{"remaining_rows": %d}`, stage.remainingRows); err != nil { + t.Fatal(err) + } + w.WriteHeader(http.StatusOK) + return + } + if r.Method == "POST" { + dec := json.NewDecoder(r.Body) + var req cqapi.IncreaseTeamPluginUsageJSONRequestBody + err := dec.Decode(&req) + require.NoError(t, err) + + stage.update = append(stage.update, req.Rows) + + w.WriteHeader(http.StatusOK) + return + } + }) + + stage.server = httptest.NewServer(handler) + + return &stage +} + +func createTestServer(t *testing.T) *testStage { + return createTestServerWithRemainingRows(t, 0) +} + +type testStage struct { + server *httptest.Server + + remainingRows int + update []int +} + +func (s *testStage) numberOfUpdates() int { + return len(s.update) +} + +func (s *testStage) sumOfUpdates() int { + sum := 0 + for _, val := range s.update { + sum += val + } + return sum +} + +func (s *testStage) minExcludingClose() int { + m := math.MaxInt + for i := 0; i < len(s.update); i++ { + if s.update[i] < m { + m = s.update[i] + } + } + return m +}