From 3b4a248d4dc2cd3f60a85762b075671f5c982fa6 Mon Sep 17 00:00:00 2001 From: Martin Norbury Date: Fri, 27 Oct 2023 11:04:07 +0100 Subject: [PATCH 1/6] feat: Adding quota monitoring for premium plugins This adds a background monitoring process that will periodically check the remaining quota for a premium plugin. If the quota is exceeded then a context cancellation is triggered, forcing the sync process to stop. fixes: https://github.com/cloudquery/cloudquery-issues/issues/749 --- go.mod | 1 + go.sum | 3 + premium/monitor.go | 80 +++++++++++++++++++++++++ premium/monitor_test.go | 52 +++++++++++++++++ premium/tables.go | 27 +++++++++ premium/tables_test.go | 39 +++++++++++++ premium/usage.go | 125 ++++++++++++++++++++++++++++++++-------- premium/usage_test.go | 108 ++++++++++++++++++++++------------ 8 files changed, 374 insertions(+), 61 deletions(-) create mode 100644 premium/monitor.go create mode 100644 premium/monitor_test.go create mode 100644 premium/tables.go create mode 100644 premium/tables_test.go diff --git a/go.mod b/go.mod index 94cf67f755..af10deba82 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/CloudyKit/jet/v6 v6.2.0 // indirect github.com/Joker/jade v1.1.3 // indirect github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 // indirect + github.com/adrg/xdg v0.4.0 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/arrow/go/v13 v13.0.0-20230731205701-112f94971882 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index d7fe03b424..55ca10e7c2 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 h1:KkH3I3sJuOLP3TjA/dfr4NAY8bghDwnXiU7cTKxQqo0= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06/go.mod h1:7erjKLwalezA0k99cWs5L11HWOAPNjdUZ6RxH1BXbbM= +github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= +github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= @@ -556,6 +558,7 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/premium/monitor.go b/premium/monitor.go new file mode 100644 index 0000000000..f5c707ce93 --- /dev/null +++ b/premium/monitor.go @@ -0,0 +1,80 @@ +package premium + +import ( + "context" + "errors" + "time" +) + +var ErrNoQuota = errors.New("no remaining quota for the month, please increase your usage limit if you want to continue syncing this plugin") + +const DefaultQuotaCheckInterval = 30 * time.Second + +type quotaChecker struct { + qm QuotaMonitor + duration time.Duration +} + +type QuotaCheckOption func(*quotaChecker) + +// WithQuotaCheckPeriod the time interval between quota checks +func WithQuotaCheckPeriod(duration time.Duration) QuotaCheckOption { + return func(m *quotaChecker) { + m.duration = duration + } +} + +// WithCancelOnQuotaExceeded monitors the quota usage at intervals defined by duration and cancels the context if the quota is exceeded +func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...QuotaCheckOption) (context.Context, func(), error) { + m := quotaChecker{ + qm: qm, + duration: DefaultQuotaCheckInterval, + } + for _, op := range ops { + op(&m) + } + + if err := m.checkInitialQuota(ctx); err != nil { + return ctx, nil, err + } + + ctx, cancel := m.startQuotaMonitor(ctx) + + return ctx, cancel, nil +} + +func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { + hasQuota, err := qc.qm.HasQuota(ctx) + if err != nil { + return err + } + + if !hasQuota { + return ErrNoQuota + } + + return nil +} + +func (qc quotaChecker) startQuotaMonitor(ctx context.Context) (context.Context, func()) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + ticker := time.NewTicker(qc.duration) + for { + select { + case <-newCtx.Done(): + return + case <-ticker.C: + hasQuota, err := qc.qm.HasQuota(newCtx) + if err != nil { + continue + } + if !hasQuota { + return + } + } + } + }() + return newCtx, cancel +} diff --git a/premium/monitor_test.go b/premium/monitor_test.go new file mode 100644 index 0000000000..68b1a68ce5 --- /dev/null +++ b/premium/monitor_test.go @@ -0,0 +1,52 @@ +package premium + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func newFakeQuotaMonitor(hasQuota ...bool) *fakeQuotaMonitor { + return &fakeQuotaMonitor{hasQuota: hasQuota} +} + +type fakeQuotaMonitor struct { + hasQuota []bool + calls int +} + +func (f *fakeQuotaMonitor) HasQuota(_ context.Context) (bool, error) { + hasQuota := f.hasQuota[f.calls] + if f.calls < len(f.hasQuota)-1 { + f.calls++ + } + return hasQuota, nil +} + +func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { + ctx := context.Background() + + _, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(false)) + + require.Error(t, err) +} + +func TestWithCancelOnQuotaExceeded_NoQuota(t *testing.T) { + ctx := context.Background() + + ctx, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, false), WithQuotaCheckPeriod(1*time.Millisecond)) + require.NoError(t, err) + + <-ctx.Done() +} + +func TestWithCancelOnQuotaExceeded_HasQuotaCanceled(t *testing.T) { + ctx := context.Background() + + ctx, cancel, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, true, true), WithQuotaCheckPeriod(1*time.Millisecond)) + require.NoError(t, err) + cancel() + + <-ctx.Done() +} diff --git a/premium/tables.go b/premium/tables.go new file mode 100644 index 0000000000..2f9a8217ba --- /dev/null +++ b/premium/tables.go @@ -0,0 +1,27 @@ +package premium + +import "github.com/cloudquery/plugin-sdk/v4/schema" + +// ContainsPaidTables returns true if any of the tables are paid +func ContainsPaidTables(tables schema.Tables) bool { + for _, t := range tables { + if t.IsPaid { + return true + } + } + return false +} + +// MakeAllTablesPaid sets all tables to paid +func MakeAllTablesPaid(tables schema.Tables) schema.Tables { + for _, table := range tables { + MakeTablePaid(table) + } + return tables +} + +// MakeTablePaid sets the table to paid +func MakeTablePaid(table *schema.Table) *schema.Table { + table.IsPaid = true + return table +} diff --git a/premium/tables_test.go b/premium/tables_test.go new file mode 100644 index 0000000000..251be02d0b --- /dev/null +++ b/premium/tables_test.go @@ -0,0 +1,39 @@ +package premium + +import ( + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestContainsPaidTables(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: true}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + assert.False(t, ContainsPaidTables(noPaidTables), "no paid tables") + assert.True(t, ContainsPaidTables(paidTables), "paid tables") +} + +func TestMakeAllTablesPaid(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := MakeAllTablesPaid(noPaidTables) + + assert.Equal(t, 3, len(paidTables)) + for _, table := range paidTables { + assert.True(t, table.IsPaid) + } +} diff --git a/premium/usage.go b/premium/usage.go index e8016949aa..f41e7be318 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -4,7 +4,10 @@ import ( "context" "fmt" cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/auth" + "github.com/cloudquery/cloudquery-api-go/config" "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "math/rand" "net/http" @@ -13,6 +16,7 @@ import ( ) const ( + defaultAPIURL = "https://api.cloudquery.io" defaultBatchLimit = 1000 defaultMaxRetries = 5 defaultMaxWaitTime = 60 * time.Second @@ -20,60 +24,108 @@ const ( defaultMaxTimeBetweenFlushes = 30 * time.Second ) -type UsageClient interface { - // Increase updates the usage by the given number of rows - Increase(context.Context, uint32) +type QuotaMonitor interface { // HasQuota returns true if the quota has not been exceeded HasQuota(context.Context) (bool, error) +} + +type UsageClient interface { + QuotaMonitor + // Increase updates the usage by the given number of rows + Increase(uint32) error // Close flushes any remaining rows and closes the quota service Close() error } -type UpdaterOptions func(updater *BatchUpdater) +type UsageClientOptions func(updater *BatchUpdater) // WithBatchLimit sets the maximum number of rows to update in a single request -func WithBatchLimit(batchLimit uint32) UpdaterOptions { +func WithBatchLimit(batchLimit uint32) UsageClientOptions { return func(updater *BatchUpdater) { updater.batchLimit = batchLimit } } // WithMaxTimeBetweenFlushes sets the flush duration - the time at which an update will be triggered even if the batch limit is not reached -func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxTimeBetweenFlushes = maxTimeBetweenFlushes } } // WithMinTimeBetweenFlushes sets the minimum time between updates -func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.minTimeBetweenFlushes = minTimeBetweenFlushes } } // WithMaxRetries sets the maximum number of retries to update the usage in case of an API error -func WithMaxRetries(maxRetries int) UpdaterOptions { +func WithMaxRetries(maxRetries int) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxRetries = maxRetries } } // WithMaxWaitTime sets the maximum time to wait before retrying a failed update -func WithMaxWaitTime(maxWaitTime time.Duration) UpdaterOptions { +func WithMaxWaitTime(maxWaitTime time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxWaitTime = maxWaitTime } } +// WithLogger sets the logger to use - defaults to a no-op logger +func WithLogger(logger zerolog.Logger) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.logger = logger + } +} + +// WithURL sets the API URL to use - defaults to https://api.cloudquery.io +func WithURL(url string) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.url = url + } +} + +// WithTeamName sets the team name to use - defaults to the team name from the configuration +func WithTeamName(teamName cqapi.TeamName) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.teamName = teamName + } +} + +// WithAPIClient sets the API client to use - defaults to a client using a bearer token generated from the refresh token stored in the configuration +func WithAPIClient(apiClient *cqapi.ClientWithResponses) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.apiClient = apiClient + } +} + +func WithPluginTeam(pluginTeam string) cqapi.PluginTeam { + return pluginTeam +} + +func WithPluginKind(pluginKind string) cqapi.PluginKind { + return cqapi.PluginKind(pluginKind) +} + +func WithPluginName(pluginName string) cqapi.PluginName { + return pluginName +} + +var _ UsageClient = (*BatchUpdater)(nil) + type BatchUpdater struct { + logger zerolog.Logger + url string apiClient *cqapi.ClientWithResponses // Plugin details - teamName string - pluginTeam string - pluginKind string - pluginName string + teamName cqapi.TeamName + pluginTeam cqapi.PluginTeam + pluginKind cqapi.PluginKind + pluginName cqapi.PluginName // Configuration batchLimit uint32 @@ -91,11 +143,11 @@ type BatchUpdater struct { isClosed bool } -func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, teamName, pluginTeam, pluginKind, pluginName string, ops ...UpdaterOptions) *BatchUpdater { +func NewUsageClient(pluginTeam cqapi.PluginTeam, pluginKind cqapi.PluginKind, pluginName cqapi.PluginName, ops ...UsageClientOptions) (*BatchUpdater, error) { u := &BatchUpdater{ - apiClient: apiClient, + logger: zerolog.Nop(), + url: defaultAPIURL, - teamName: teamName, pluginTeam: pluginTeam, pluginKind: pluginKind, pluginName: pluginName, @@ -113,12 +165,38 @@ func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, t op(u) } - u.backgroundUpdater(ctx) + // Set team name from configuration if not provided + if u.teamName == "" { + teamName, err := config.GetValue("team") + if err != nil { + return nil, fmt.Errorf("failed to get team name from config: %w", err) + } + u.teamName = teamName + } + + // Create a default api client if none was provided + if u.apiClient == nil { + tokenClient := auth.NewTokenClient() + ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { + token, err := tokenClient.GetToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + })) + if err != nil { + return nil, fmt.Errorf("failed to create api client: %w", err) + } + u.apiClient = ac + } + + u.backgroundUpdater() - return u + return u, nil } -func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { +func (u *BatchUpdater) Increase(rows uint32) error { if rows <= 0 { return fmt.Errorf("rows must be greater than zero got %d", rows) } @@ -140,7 +218,7 @@ func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { } func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { - usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, cqapi.PluginKind(u.pluginKind), u.pluginName) + usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, u.pluginKind, u.pluginName) if err != nil { return false, fmt.Errorf("failed to get usage: %w", err) } @@ -150,7 +228,7 @@ func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { return *usage.JSON200.RemainingRows > 0, nil } -func (u *BatchUpdater) Close(_ context.Context) error { +func (u *BatchUpdater) Close() error { u.isClosed = true close(u.done) @@ -158,7 +236,8 @@ func (u *BatchUpdater) Close(_ context.Context) error { return <-u.closeError } -func (u *BatchUpdater) backgroundUpdater(ctx context.Context) { +func (u *BatchUpdater) backgroundUpdater() { + ctx := context.Background() started := make(chan struct{}) flushDuration := time.NewTicker(u.maxTimeBetweenFlushes) @@ -221,7 +300,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numbe resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, cqapi.IncreaseTeamPluginUsageJSONRequestBody{ RequestId: uuid.New(), PluginTeam: u.pluginTeam, - PluginKind: cqapi.PluginKind(u.pluginKind), + PluginKind: u.pluginKind, PluginName: u.pluginName, Rows: int(numberToUpdate), }) diff --git a/premium/usage_test.go b/premium/usage_test.go index 236b35c219..033bd55791 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/config" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "math" @@ -14,6 +16,45 @@ import ( "time" ) +func TestUsageService_NewUsageClient_Defaults(t *testing.T) { + err := config.SetConfigHome(t.TempDir()) + require.NoError(t, err) + + err = config.SetValue("team", "config-team") + require.NoError(t, err) + + uc, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + ) + require.NoError(t, err) + + assert.NotNil(t, uc.apiClient) + assert.Equal(t, "config-team", uc.teamName) + assert.Equal(t, zerolog.Nop(), uc.logger) + assert.Equal(t, 5, uc.maxRetries) + assert.Equal(t, 60*time.Second, uc.maxWaitTime) + assert.Equal(t, 30*time.Second, uc.maxTimeBetweenFlushes) +} + +func TestUsageService_NewUsageClient_Override(t *testing.T) { + ac, err := cqapi.NewClientWithResponses("http://localhost") + require.NoError(t, err) + + logger := zerolog.New(zerolog.NewTestWriter(t)) + + uc, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), WithLogger(logger), WithAPIClient(ac), WithTeamName("override-team-name"), WithMaxRetries(10), WithMaxWaitTime(120*time.Second), WithMaxTimeBetweenFlushes(10*time.Second)) + require.NoError(t, err) + + assert.Equal(t, ac, uc.apiClient) + assert.Equal(t, "override-team-name", uc.teamName) + assert.Equal(t, logger, uc.logger) + assert.Equal(t, 10, uc.maxRetries) + assert.Equal(t, 120*time.Second, uc.maxWaitTime) + assert.Equal(t, 10*time.Second, uc.maxTimeBetweenFlushes) +} + func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { ctx := context.Background() @@ -23,7 +64,7 @@ func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -40,7 +81,7 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -49,29 +90,26 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { } func TestUsageService_ZeroBatchSize(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") } func TestUsageService_WithBatchSize(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -80,13 +118,13 @@ func TestUsageService_WithBatchSize(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize))) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize))) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -94,7 +132,6 @@ func TestUsageService_WithBatchSize(t *testing.T) { } func TestUsageService_WithFlushDuration(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -103,14 +140,14 @@ func TestUsageService_WithFlushDuration(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) for i := 0; i < 10; i++ { - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.NoError(t, err) time.Sleep(5 * time.Millisecond) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 100, s.sumOfUpdates(), "total should equal number of updated rows") @@ -118,21 +155,19 @@ func TestUsageService_WithFlushDuration(t *testing.T) { } func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) + usageClient := newClient(t, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -140,58 +175,52 @@ func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { } func TestUsageService_NoUpdates(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_UpdatesWithZeroRows(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Increase(ctx, 0) + err = usageClient.Increase(0) require.Error(t, err, "should not be able to update with zero rows") - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_ShouldNotUpdateClosedService(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) // Close the service first - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.Error(t, err, "should not be able to update closed service") assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") @@ -247,7 +276,7 @@ func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -294,7 +323,7 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -311,8 +340,11 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } } -func newClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, ops ...UpdaterOptions) *BatchUpdater { - return NewUsageClient(ctx, apiClient, "myteam", "mnorbury-team", "source", "vault", ops...) +func newClient(t *testing.T, apiClient *cqapi.ClientWithResponses, ops ...UsageClientOptions) *BatchUpdater { + client, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), append(ops, WithTeamName("team-name"), WithAPIClient(apiClient))...) + require.NoError(t, err) + + return client } func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testStage { From 46490bb1fef6d136ea40ca060451f22550051c30 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Mon, 30 Oct 2023 11:45:37 +0000 Subject: [PATCH 2/6] Add handling for errors during checks and cancellation cause --- premium/monitor.go | 46 +++++++++++++++++++++++++---------- premium/monitor_test.go | 53 ++++++++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/premium/monitor.go b/premium/monitor.go index f5c707ce93..b9f3c150a9 100644 --- a/premium/monitor.go +++ b/premium/monitor.go @@ -9,38 +9,48 @@ import ( var ErrNoQuota = errors.New("no remaining quota for the month, please increase your usage limit if you want to continue syncing this plugin") const DefaultQuotaCheckInterval = 30 * time.Second +const DefaultMaxQuotaFailures = 10 // 5 minutes type quotaChecker struct { - qm QuotaMonitor - duration time.Duration + qm QuotaMonitor + duration time.Duration + maxConsecutiveFailures int } type QuotaCheckOption func(*quotaChecker) -// WithQuotaCheckPeriod the time interval between quota checks +// WithQuotaCheckPeriod controls the time interval between quota checks func WithQuotaCheckPeriod(duration time.Duration) QuotaCheckOption { return func(m *quotaChecker) { m.duration = duration } } +// WithQuotaMaxConsecutiveFailures controls the number of consecutive failed quota checks before the context is cancelled +func WithQuotaMaxConsecutiveFailures(n int) QuotaCheckOption { + return func(m *quotaChecker) { + m.maxConsecutiveFailures = n + } +} + // WithCancelOnQuotaExceeded monitors the quota usage at intervals defined by duration and cancels the context if the quota is exceeded -func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...QuotaCheckOption) (context.Context, func(), error) { +func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...QuotaCheckOption) (context.Context, error) { m := quotaChecker{ - qm: qm, - duration: DefaultQuotaCheckInterval, + qm: qm, + duration: DefaultQuotaCheckInterval, + maxConsecutiveFailures: DefaultMaxQuotaFailures, } for _, op := range ops { op(&m) } if err := m.checkInitialQuota(ctx); err != nil { - return ctx, nil, err + return ctx, err } - ctx, cancel := m.startQuotaMonitor(ctx) + newCtx := m.startQuotaMonitor(ctx) - return ctx, cancel, nil + return newCtx, nil } func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { @@ -56,11 +66,12 @@ func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { return nil } -func (qc quotaChecker) startQuotaMonitor(ctx context.Context) (context.Context, func()) { - newCtx, cancel := context.WithCancel(ctx) +func (qc quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { + newCtx, cancelWithCause := context.WithCancelCause(ctx) go func() { - defer cancel() ticker := time.NewTicker(qc.duration) + consecutiveFailures := 0 + var hasQuotaErrors error for { select { case <-newCtx.Done(): @@ -68,13 +79,22 @@ func (qc quotaChecker) startQuotaMonitor(ctx context.Context) (context.Context, case <-ticker.C: hasQuota, err := qc.qm.HasQuota(newCtx) if err != nil { + consecutiveFailures++ + hasQuotaErrors = errors.Join(hasQuotaErrors, err) + if consecutiveFailures >= qc.maxConsecutiveFailures { + cancelWithCause(hasQuotaErrors) + return + } continue } + consecutiveFailures = 0 + hasQuotaErrors = nil if !hasQuota { + cancelWithCause(ErrNoQuota) return } } } }() - return newCtx, cancel + return newCtx } diff --git a/premium/monitor_test.go b/premium/monitor_test.go index 68b1a68ce5..dd6b5beb4b 100644 --- a/premium/monitor_test.go +++ b/premium/monitor_test.go @@ -2,32 +2,42 @@ package premium import ( "context" - "github.com/stretchr/testify/require" + "errors" "testing" "time" + + "github.com/stretchr/testify/require" ) -func newFakeQuotaMonitor(hasQuota ...bool) *fakeQuotaMonitor { - return &fakeQuotaMonitor{hasQuota: hasQuota} +type quotaResponse struct { + hasQuota bool + err error +} + +func newFakeQuotaMonitor(hasQuota ...quotaResponse) *fakeQuotaMonitor { + return &fakeQuotaMonitor{responses: hasQuota} } type fakeQuotaMonitor struct { - hasQuota []bool - calls int + responses []quotaResponse + calls int } func (f *fakeQuotaMonitor) HasQuota(_ context.Context) (bool, error) { - hasQuota := f.hasQuota[f.calls] - if f.calls < len(f.hasQuota)-1 { + resp := f.responses[f.calls] + if f.calls < len(f.responses)-1 { f.calls++ } - return hasQuota, nil + return resp.hasQuota, resp.err } func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { ctx := context.Background() - _, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(false)) + responses := []quotaResponse{ + {false, nil}, + } + _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...)) require.Error(t, err) } @@ -35,18 +45,33 @@ func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { func TestWithCancelOnQuotaExceeded_NoQuota(t *testing.T) { ctx := context.Background() - ctx, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, false), WithQuotaCheckPeriod(1*time.Millisecond)) + responses := []quotaResponse{ + {true, nil}, + {false, nil}, + } + ctx, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...), WithQuotaCheckPeriod(1*time.Millisecond)) require.NoError(t, err) <-ctx.Done() + cause := context.Cause(ctx) + require.Equal(t, ErrNoQuota, cause) } -func TestWithCancelOnQuotaExceeded_HasQuotaCanceled(t *testing.T) { +func TestWithCancelOnQuotaCheckConsecutiveFailures(t *testing.T) { ctx := context.Background() - ctx, cancel, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, true, true), WithQuotaCheckPeriod(1*time.Millisecond)) + responses := []quotaResponse{ + {true, nil}, + {false, errors.New("test2")}, + {false, errors.New("test3")}, + } + ctx, err := WithCancelOnQuotaExceeded(ctx, + newFakeQuotaMonitor(responses...), + WithQuotaCheckPeriod(1*time.Millisecond), + WithQuotaMaxConsecutiveFailures(2), + ) require.NoError(t, err) - cancel() - <-ctx.Done() + cause := context.Cause(ctx) + require.Equal(t, "test2\ntest3", cause.Error()) } From 983bf127dce26a942fa0ab97c4a6aecb060b5431 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Mon, 30 Oct 2023 15:02:39 +0000 Subject: [PATCH 3/6] Add some debug logs --- premium/usage.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index f41e7be318..9a0150f1ef 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -3,16 +3,17 @@ package premium import ( "context" "fmt" + "math/rand" + "net/http" + "sync/atomic" + "time" + cqapi "github.com/cloudquery/cloudquery-api-go" "github.com/cloudquery/cloudquery-api-go/auth" "github.com/cloudquery/cloudquery-api-go/config" "github.com/google/uuid" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "math/rand" - "net/http" - "sync/atomic" - "time" ) const ( @@ -218,6 +219,7 @@ func (u *BatchUpdater) Increase(rows uint32) error { } func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { + u.logger.Debug().Str("url", u.url).Str("team", u.teamName).Str("pluginTeam", u.pluginTeam).Str("pluginKind", string(u.pluginKind)).Str("pluginName", string(u.pluginName)).Msg("checking quota") usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, u.pluginKind, u.pluginName) if err != nil { return false, fmt.Errorf("failed to get usage: %w", err) @@ -295,6 +297,7 @@ func (u *BatchUpdater) backgroundUpdater() { func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numberToUpdate uint32) error { for retry := 0; retry < u.maxRetries; retry++ { + u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", numberToUpdate).Msg("updating usage") queryStartTime := time.Now() resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, cqapi.IncreaseTeamPluginUsageJSONRequestBody{ @@ -308,6 +311,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numbe return fmt.Errorf("failed to update usage: %w", err) } if resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", numberToUpdate).Msg("usage updated") u.lastUpdateTime = time.Now().UTC() return nil } From dad7135893e412f41004a7a7f2116ddf8da2731f Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Mon, 30 Oct 2023 15:12:14 +0000 Subject: [PATCH 4/6] Fix linting --- premium/usage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/premium/usage.go b/premium/usage.go index 9a0150f1ef..122602e07e 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -219,7 +219,7 @@ func (u *BatchUpdater) Increase(rows uint32) error { } func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { - u.logger.Debug().Str("url", u.url).Str("team", u.teamName).Str("pluginTeam", u.pluginTeam).Str("pluginKind", string(u.pluginKind)).Str("pluginName", string(u.pluginName)).Msg("checking quota") + u.logger.Debug().Str("url", u.url).Str("team", u.teamName).Str("pluginTeam", u.pluginTeam).Str("pluginKind", string(u.pluginKind)).Str("pluginName", u.pluginName).Msg("checking quota") usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, u.pluginKind, u.pluginName) if err != nil { return false, fmt.Errorf("failed to get usage: %w", err) From 2a87d6eabe94c00c03f498b6aeeef94dc6fed4f8 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Mon, 30 Oct 2023 15:40:16 +0000 Subject: [PATCH 5/6] make team name option private --- premium/usage.go | 4 ++-- premium/usage_test.go | 29 ++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/premium/usage.go b/premium/usage.go index 122602e07e..7e2fdcedb6 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -89,8 +89,8 @@ func WithURL(url string) UsageClientOptions { } } -// WithTeamName sets the team name to use - defaults to the team name from the configuration -func WithTeamName(teamName cqapi.TeamName) UsageClientOptions { +// withTeamName sets the team name to use - defaults to the team name from the configuration +func withTeamName(teamName cqapi.TeamName) UsageClientOptions { return func(updater *BatchUpdater) { updater.teamName = teamName } diff --git a/premium/usage_test.go b/premium/usage_test.go index 033bd55791..0493337715 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -4,16 +4,17 @@ import ( "context" "encoding/json" "fmt" - cqapi "github.com/cloudquery/cloudquery-api-go" - "github.com/cloudquery/cloudquery-api-go/config" - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "math" "net/http" "net/http/httptest" "testing" "time" + + cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/config" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUsageService_NewUsageClient_Defaults(t *testing.T) { @@ -44,7 +45,17 @@ func TestUsageService_NewUsageClient_Override(t *testing.T) { logger := zerolog.New(zerolog.NewTestWriter(t)) - uc, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), WithLogger(logger), WithAPIClient(ac), WithTeamName("override-team-name"), WithMaxRetries(10), WithMaxWaitTime(120*time.Second), WithMaxTimeBetweenFlushes(10*time.Second)) + uc, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + WithLogger(logger), + WithAPIClient(ac), + withTeamName("override-team-name"), + WithMaxRetries(10), + WithMaxWaitTime(120*time.Second), + WithMaxTimeBetweenFlushes(10*time.Second), + ) require.NoError(t, err) assert.Equal(t, ac, uc.apiClient) @@ -341,7 +352,11 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } func newClient(t *testing.T, apiClient *cqapi.ClientWithResponses, ops ...UsageClientOptions) *BatchUpdater { - client, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), append(ops, WithTeamName("team-name"), WithAPIClient(apiClient))...) + client, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + append(ops, withTeamName("team-name"), WithAPIClient(apiClient))...) require.NoError(t, err) return client From 18b0c1a623b9b86094052520b97169b37a713322 Mon Sep 17 00:00:00 2001 From: Herman Schaaf Date: Mon, 30 Oct 2023 16:10:20 +0000 Subject: [PATCH 6/6] Make sure context error is returned, if any --- scheduler/scheduler.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 87f33097d3..5e564c6492 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/apache/arrow/go/v14/arrow" "runtime/debug" "sync/atomic" "time" + "github.com/apache/arrow/go/v14/arrow" + "github.com/apache/arrow/go/v14/arrow/array" "github.com/apache/arrow/go/v14/arrow/memory" "github.com/cloudquery/plugin-sdk/v4/caser" @@ -186,10 +187,11 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s select { case res <- &message.SyncInsert{Record: resourceToRecord(resource)}: case <-ctx.Done(): - return ctx.Err() + s.logger.Debug().Msg("sync context cancelled") + return context.Cause(ctx) } } - return nil + return context.Cause(ctx) } func resourceToRecord(resource *schema.Resource) arrow.Record {