diff --git a/go.mod b/go.mod index e0af45a066..0254be4e18 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/CloudyKit/jet/v6 v6.2.0 // indirect github.com/Joker/jade v1.1.3 // indirect github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 // indirect + github.com/adrg/xdg v0.4.0 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/arrow/go/v13 v13.0.0-20230731205701-112f94971882 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index 749f715153..f49f02f6ac 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 h1:KkH3I3sJuOLP3TjA/dfr4NAY8bghDwnXiU7cTKxQqo0= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06/go.mod h1:7erjKLwalezA0k99cWs5L11HWOAPNjdUZ6RxH1BXbbM= +github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= +github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= @@ -556,6 +558,7 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/premium/monitor.go b/premium/monitor.go new file mode 100644 index 0000000000..b9f3c150a9 --- /dev/null +++ b/premium/monitor.go @@ -0,0 +1,100 @@ +package premium + +import ( + "context" + "errors" + "time" +) + +var ErrNoQuota = errors.New("no remaining quota for the month, please increase your usage limit if you want to continue syncing this plugin") + +const DefaultQuotaCheckInterval = 30 * time.Second +const DefaultMaxQuotaFailures = 10 // 5 minutes + +type quotaChecker struct { + qm QuotaMonitor + duration time.Duration + maxConsecutiveFailures int +} + +type QuotaCheckOption func(*quotaChecker) + +// WithQuotaCheckPeriod controls the time interval between quota checks +func WithQuotaCheckPeriod(duration time.Duration) QuotaCheckOption { + return func(m *quotaChecker) { + m.duration = duration + } +} + +// WithQuotaMaxConsecutiveFailures controls the number of consecutive failed quota checks before the context is cancelled +func WithQuotaMaxConsecutiveFailures(n int) QuotaCheckOption { + return func(m *quotaChecker) { + m.maxConsecutiveFailures = n + } +} + +// WithCancelOnQuotaExceeded monitors the quota usage at intervals defined by duration and cancels the context if the quota is exceeded +func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...QuotaCheckOption) (context.Context, error) { + m := quotaChecker{ + qm: qm, + duration: DefaultQuotaCheckInterval, + maxConsecutiveFailures: DefaultMaxQuotaFailures, + } + for _, op := range ops { + op(&m) + } + + if err := m.checkInitialQuota(ctx); err != nil { + return ctx, err + } + + newCtx := m.startQuotaMonitor(ctx) + + return newCtx, nil +} + +func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { + hasQuota, err := qc.qm.HasQuota(ctx) + if err != nil { + return err + } + + if !hasQuota { + return ErrNoQuota + } + + return nil +} + +func (qc quotaChecker) startQuotaMonitor(ctx context.Context) context.Context { + newCtx, cancelWithCause := context.WithCancelCause(ctx) + go func() { + ticker := time.NewTicker(qc.duration) + consecutiveFailures := 0 + var hasQuotaErrors error + for { + select { + case <-newCtx.Done(): + return + case <-ticker.C: + hasQuota, err := qc.qm.HasQuota(newCtx) + if err != nil { + consecutiveFailures++ + hasQuotaErrors = errors.Join(hasQuotaErrors, err) + if consecutiveFailures >= qc.maxConsecutiveFailures { + cancelWithCause(hasQuotaErrors) + return + } + continue + } + consecutiveFailures = 0 + hasQuotaErrors = nil + if !hasQuota { + cancelWithCause(ErrNoQuota) + return + } + } + } + }() + return newCtx +} diff --git a/premium/monitor_test.go b/premium/monitor_test.go new file mode 100644 index 0000000000..dd6b5beb4b --- /dev/null +++ b/premium/monitor_test.go @@ -0,0 +1,77 @@ +package premium + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type quotaResponse struct { + hasQuota bool + err error +} + +func newFakeQuotaMonitor(hasQuota ...quotaResponse) *fakeQuotaMonitor { + return &fakeQuotaMonitor{responses: hasQuota} +} + +type fakeQuotaMonitor struct { + responses []quotaResponse + calls int +} + +func (f *fakeQuotaMonitor) HasQuota(_ context.Context) (bool, error) { + resp := f.responses[f.calls] + if f.calls < len(f.responses)-1 { + f.calls++ + } + return resp.hasQuota, resp.err +} + +func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { + ctx := context.Background() + + responses := []quotaResponse{ + {false, nil}, + } + _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...)) + + require.Error(t, err) +} + +func TestWithCancelOnQuotaExceeded_NoQuota(t *testing.T) { + ctx := context.Background() + + responses := []quotaResponse{ + {true, nil}, + {false, nil}, + } + ctx, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(responses...), WithQuotaCheckPeriod(1*time.Millisecond)) + require.NoError(t, err) + + <-ctx.Done() + cause := context.Cause(ctx) + require.Equal(t, ErrNoQuota, cause) +} + +func TestWithCancelOnQuotaCheckConsecutiveFailures(t *testing.T) { + ctx := context.Background() + + responses := []quotaResponse{ + {true, nil}, + {false, errors.New("test2")}, + {false, errors.New("test3")}, + } + ctx, err := WithCancelOnQuotaExceeded(ctx, + newFakeQuotaMonitor(responses...), + WithQuotaCheckPeriod(1*time.Millisecond), + WithQuotaMaxConsecutiveFailures(2), + ) + require.NoError(t, err) + <-ctx.Done() + cause := context.Cause(ctx) + require.Equal(t, "test2\ntest3", cause.Error()) +} diff --git a/premium/tables.go b/premium/tables.go new file mode 100644 index 0000000000..2f9a8217ba --- /dev/null +++ b/premium/tables.go @@ -0,0 +1,27 @@ +package premium + +import "github.com/cloudquery/plugin-sdk/v4/schema" + +// ContainsPaidTables returns true if any of the tables are paid +func ContainsPaidTables(tables schema.Tables) bool { + for _, t := range tables { + if t.IsPaid { + return true + } + } + return false +} + +// MakeAllTablesPaid sets all tables to paid +func MakeAllTablesPaid(tables schema.Tables) schema.Tables { + for _, table := range tables { + MakeTablePaid(table) + } + return tables +} + +// MakeTablePaid sets the table to paid +func MakeTablePaid(table *schema.Table) *schema.Table { + table.IsPaid = true + return table +} diff --git a/premium/tables_test.go b/premium/tables_test.go new file mode 100644 index 0000000000..251be02d0b --- /dev/null +++ b/premium/tables_test.go @@ -0,0 +1,39 @@ +package premium + +import ( + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestContainsPaidTables(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: true}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + assert.False(t, ContainsPaidTables(noPaidTables), "no paid tables") + assert.True(t, ContainsPaidTables(paidTables), "paid tables") +} + +func TestMakeAllTablesPaid(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := MakeAllTablesPaid(noPaidTables) + + assert.Equal(t, 3, len(paidTables)) + for _, table := range paidTables { + assert.True(t, table.IsPaid) + } +} diff --git a/premium/usage.go b/premium/usage.go index e8016949aa..7e2fdcedb6 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -3,16 +3,21 @@ package premium import ( "context" "fmt" - cqapi "github.com/cloudquery/cloudquery-api-go" - "github.com/google/uuid" - "github.com/rs/zerolog/log" "math/rand" "net/http" "sync/atomic" "time" + + cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/auth" + "github.com/cloudquery/cloudquery-api-go/config" + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) const ( + defaultAPIURL = "https://api.cloudquery.io" defaultBatchLimit = 1000 defaultMaxRetries = 5 defaultMaxWaitTime = 60 * time.Second @@ -20,60 +25,108 @@ const ( defaultMaxTimeBetweenFlushes = 30 * time.Second ) -type UsageClient interface { - // Increase updates the usage by the given number of rows - Increase(context.Context, uint32) +type QuotaMonitor interface { // HasQuota returns true if the quota has not been exceeded HasQuota(context.Context) (bool, error) +} + +type UsageClient interface { + QuotaMonitor + // Increase updates the usage by the given number of rows + Increase(uint32) error // Close flushes any remaining rows and closes the quota service Close() error } -type UpdaterOptions func(updater *BatchUpdater) +type UsageClientOptions func(updater *BatchUpdater) // WithBatchLimit sets the maximum number of rows to update in a single request -func WithBatchLimit(batchLimit uint32) UpdaterOptions { +func WithBatchLimit(batchLimit uint32) UsageClientOptions { return func(updater *BatchUpdater) { updater.batchLimit = batchLimit } } // WithMaxTimeBetweenFlushes sets the flush duration - the time at which an update will be triggered even if the batch limit is not reached -func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxTimeBetweenFlushes = maxTimeBetweenFlushes } } // WithMinTimeBetweenFlushes sets the minimum time between updates -func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.minTimeBetweenFlushes = minTimeBetweenFlushes } } // WithMaxRetries sets the maximum number of retries to update the usage in case of an API error -func WithMaxRetries(maxRetries int) UpdaterOptions { +func WithMaxRetries(maxRetries int) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxRetries = maxRetries } } // WithMaxWaitTime sets the maximum time to wait before retrying a failed update -func WithMaxWaitTime(maxWaitTime time.Duration) UpdaterOptions { +func WithMaxWaitTime(maxWaitTime time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxWaitTime = maxWaitTime } } +// WithLogger sets the logger to use - defaults to a no-op logger +func WithLogger(logger zerolog.Logger) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.logger = logger + } +} + +// WithURL sets the API URL to use - defaults to https://api.cloudquery.io +func WithURL(url string) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.url = url + } +} + +// withTeamName sets the team name to use - defaults to the team name from the configuration +func withTeamName(teamName cqapi.TeamName) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.teamName = teamName + } +} + +// WithAPIClient sets the API client to use - defaults to a client using a bearer token generated from the refresh token stored in the configuration +func WithAPIClient(apiClient *cqapi.ClientWithResponses) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.apiClient = apiClient + } +} + +func WithPluginTeam(pluginTeam string) cqapi.PluginTeam { + return pluginTeam +} + +func WithPluginKind(pluginKind string) cqapi.PluginKind { + return cqapi.PluginKind(pluginKind) +} + +func WithPluginName(pluginName string) cqapi.PluginName { + return pluginName +} + +var _ UsageClient = (*BatchUpdater)(nil) + type BatchUpdater struct { + logger zerolog.Logger + url string apiClient *cqapi.ClientWithResponses // Plugin details - teamName string - pluginTeam string - pluginKind string - pluginName string + teamName cqapi.TeamName + pluginTeam cqapi.PluginTeam + pluginKind cqapi.PluginKind + pluginName cqapi.PluginName // Configuration batchLimit uint32 @@ -91,11 +144,11 @@ type BatchUpdater struct { isClosed bool } -func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, teamName, pluginTeam, pluginKind, pluginName string, ops ...UpdaterOptions) *BatchUpdater { +func NewUsageClient(pluginTeam cqapi.PluginTeam, pluginKind cqapi.PluginKind, pluginName cqapi.PluginName, ops ...UsageClientOptions) (*BatchUpdater, error) { u := &BatchUpdater{ - apiClient: apiClient, + logger: zerolog.Nop(), + url: defaultAPIURL, - teamName: teamName, pluginTeam: pluginTeam, pluginKind: pluginKind, pluginName: pluginName, @@ -113,12 +166,38 @@ func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, t op(u) } - u.backgroundUpdater(ctx) + // Set team name from configuration if not provided + if u.teamName == "" { + teamName, err := config.GetValue("team") + if err != nil { + return nil, fmt.Errorf("failed to get team name from config: %w", err) + } + u.teamName = teamName + } + + // Create a default api client if none was provided + if u.apiClient == nil { + tokenClient := auth.NewTokenClient() + ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { + token, err := tokenClient.GetToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + })) + if err != nil { + return nil, fmt.Errorf("failed to create api client: %w", err) + } + u.apiClient = ac + } + + u.backgroundUpdater() - return u + return u, nil } -func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { +func (u *BatchUpdater) Increase(rows uint32) error { if rows <= 0 { return fmt.Errorf("rows must be greater than zero got %d", rows) } @@ -140,7 +219,8 @@ func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { } func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { - usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, cqapi.PluginKind(u.pluginKind), u.pluginName) + u.logger.Debug().Str("url", u.url).Str("team", u.teamName).Str("pluginTeam", u.pluginTeam).Str("pluginKind", string(u.pluginKind)).Str("pluginName", u.pluginName).Msg("checking quota") + usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, u.pluginKind, u.pluginName) if err != nil { return false, fmt.Errorf("failed to get usage: %w", err) } @@ -150,7 +230,7 @@ func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { return *usage.JSON200.RemainingRows > 0, nil } -func (u *BatchUpdater) Close(_ context.Context) error { +func (u *BatchUpdater) Close() error { u.isClosed = true close(u.done) @@ -158,7 +238,8 @@ func (u *BatchUpdater) Close(_ context.Context) error { return <-u.closeError } -func (u *BatchUpdater) backgroundUpdater(ctx context.Context) { +func (u *BatchUpdater) backgroundUpdater() { + ctx := context.Background() started := make(chan struct{}) flushDuration := time.NewTicker(u.maxTimeBetweenFlushes) @@ -216,12 +297,13 @@ func (u *BatchUpdater) backgroundUpdater(ctx context.Context) { func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numberToUpdate uint32) error { for retry := 0; retry < u.maxRetries; retry++ { + u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", numberToUpdate).Msg("updating usage") queryStartTime := time.Now() resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, cqapi.IncreaseTeamPluginUsageJSONRequestBody{ RequestId: uuid.New(), PluginTeam: u.pluginTeam, - PluginKind: cqapi.PluginKind(u.pluginKind), + PluginKind: u.pluginKind, PluginName: u.pluginName, Rows: int(numberToUpdate), }) @@ -229,6 +311,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numbe return fmt.Errorf("failed to update usage: %w", err) } if resp.StatusCode() >= 200 && resp.StatusCode() < 300 { + u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", numberToUpdate).Msg("usage updated") u.lastUpdateTime = time.Now().UTC() return nil } diff --git a/premium/usage_test.go b/premium/usage_test.go index 236b35c219..0493337715 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -4,16 +4,68 @@ import ( "context" "encoding/json" "fmt" - cqapi "github.com/cloudquery/cloudquery-api-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "math" "net/http" "net/http/httptest" "testing" "time" + + cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/config" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestUsageService_NewUsageClient_Defaults(t *testing.T) { + err := config.SetConfigHome(t.TempDir()) + require.NoError(t, err) + + err = config.SetValue("team", "config-team") + require.NoError(t, err) + + uc, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + ) + require.NoError(t, err) + + assert.NotNil(t, uc.apiClient) + assert.Equal(t, "config-team", uc.teamName) + assert.Equal(t, zerolog.Nop(), uc.logger) + assert.Equal(t, 5, uc.maxRetries) + assert.Equal(t, 60*time.Second, uc.maxWaitTime) + assert.Equal(t, 30*time.Second, uc.maxTimeBetweenFlushes) +} + +func TestUsageService_NewUsageClient_Override(t *testing.T) { + ac, err := cqapi.NewClientWithResponses("http://localhost") + require.NoError(t, err) + + logger := zerolog.New(zerolog.NewTestWriter(t)) + + uc, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + WithLogger(logger), + WithAPIClient(ac), + withTeamName("override-team-name"), + WithMaxRetries(10), + WithMaxWaitTime(120*time.Second), + WithMaxTimeBetweenFlushes(10*time.Second), + ) + require.NoError(t, err) + + assert.Equal(t, ac, uc.apiClient) + assert.Equal(t, "override-team-name", uc.teamName) + assert.Equal(t, logger, uc.logger) + assert.Equal(t, 10, uc.maxRetries) + assert.Equal(t, 120*time.Second, uc.maxWaitTime) + assert.Equal(t, 10*time.Second, uc.maxTimeBetweenFlushes) +} + func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { ctx := context.Background() @@ -23,7 +75,7 @@ func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -40,7 +92,7 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -49,29 +101,26 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { } func TestUsageService_ZeroBatchSize(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") } func TestUsageService_WithBatchSize(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -80,13 +129,13 @@ func TestUsageService_WithBatchSize(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize))) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize))) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -94,7 +143,6 @@ func TestUsageService_WithBatchSize(t *testing.T) { } func TestUsageService_WithFlushDuration(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -103,14 +151,14 @@ func TestUsageService_WithFlushDuration(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) for i := 0; i < 10; i++ { - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.NoError(t, err) time.Sleep(5 * time.Millisecond) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 100, s.sumOfUpdates(), "total should equal number of updated rows") @@ -118,21 +166,19 @@ func TestUsageService_WithFlushDuration(t *testing.T) { } func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) + usageClient := newClient(t, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -140,58 +186,52 @@ func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { } func TestUsageService_NoUpdates(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_UpdatesWithZeroRows(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Increase(ctx, 0) + err = usageClient.Increase(0) require.Error(t, err, "should not be able to update with zero rows") - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_ShouldNotUpdateClosedService(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) // Close the service first - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.Error(t, err, "should not be able to update closed service") assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") @@ -247,7 +287,7 @@ func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -294,7 +334,7 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -311,8 +351,15 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } } -func newClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, ops ...UpdaterOptions) *BatchUpdater { - return NewUsageClient(ctx, apiClient, "myteam", "mnorbury-team", "source", "vault", ops...) +func newClient(t *testing.T, apiClient *cqapi.ClientWithResponses, ops ...UsageClientOptions) *BatchUpdater { + client, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + append(ops, withTeamName("team-name"), WithAPIClient(apiClient))...) + require.NoError(t, err) + + return client } func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testStage { diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 87f33097d3..5e564c6492 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/apache/arrow/go/v14/arrow" "runtime/debug" "sync/atomic" "time" + "github.com/apache/arrow/go/v14/arrow" + "github.com/apache/arrow/go/v14/arrow/array" "github.com/apache/arrow/go/v14/arrow/memory" "github.com/cloudquery/plugin-sdk/v4/caser" @@ -186,10 +187,11 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s select { case res <- &message.SyncInsert{Record: resourceToRecord(resource)}: case <-ctx.Done(): - return ctx.Err() + s.logger.Debug().Msg("sync context cancelled") + return context.Cause(ctx) } } - return nil + return context.Cause(ctx) } func resourceToRecord(resource *schema.Resource) arrow.Record {