diff --git a/go.mod b/go.mod index 94cf67f755..af10deba82 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/CloudyKit/jet/v6 v6.2.0 // indirect github.com/Joker/jade v1.1.3 // indirect github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 // indirect + github.com/adrg/xdg v0.4.0 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/arrow/go/v13 v13.0.0-20230731205701-112f94971882 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index d7fe03b424..55ca10e7c2 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 h1:KkH3I3sJuOLP3TjA/dfr4NAY8bghDwnXiU7cTKxQqo0= github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06/go.mod h1:7erjKLwalezA0k99cWs5L11HWOAPNjdUZ6RxH1BXbbM= +github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= +github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= @@ -556,6 +558,7 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go index c2ad4578bc..bd08a4e79f 100644 --- a/internal/servers/plugin/v3/plugin.go +++ b/internal/servers/plugin/v3/plugin.go @@ -210,6 +210,10 @@ func (s *Server) Sync(req *pb.Sync_Request, stream pb.Plugin_SyncServer) error { } } + if err := s.Plugin.OnSyncFinish(ctx); err != nil { + return status.Errorf(codes.Internal, "failed to finish sync: %v", err) + } + return syncErr } diff --git a/plugin/plugin.go b/plugin/plugin.go index 0d63b51ec9..0811ee4d5a 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -141,6 +141,19 @@ func (p *Plugin) OnBeforeSend(ctx context.Context, msg message.SyncMessage) (mes return msg, nil } +// OnSyncFinisher is an interface that can be implemented by a plugin client to be notified when a sync finishes. +type OnSyncFinisher interface { + OnSyncFinish(context.Context) error +} + +// OnSyncFinish gets called after a sync finishes. +func (p *Plugin) OnSyncFinish(ctx context.Context) error { + if v, ok := p.client.(OnSyncFinisher); ok { + return v.OnSyncFinish(ctx) + } + return nil +} + // IsStaticLinkingEnabled whether static linking is to be enabled func (p *Plugin) IsStaticLinkingEnabled() bool { return p.staticLinking diff --git a/premium/monitor.go b/premium/monitor.go new file mode 100644 index 0000000000..f5c707ce93 --- /dev/null +++ b/premium/monitor.go @@ -0,0 +1,80 @@ +package premium + +import ( + "context" + "errors" + "time" +) + +var ErrNoQuota = errors.New("no remaining quota for the month, please increase your usage limit if you want to continue syncing this plugin") + +const DefaultQuotaCheckInterval = 30 * time.Second + +type quotaChecker struct { + qm QuotaMonitor + duration time.Duration +} + +type QuotaCheckOption func(*quotaChecker) + +// WithQuotaCheckPeriod the time interval between quota checks +func WithQuotaCheckPeriod(duration time.Duration) QuotaCheckOption { + return func(m *quotaChecker) { + m.duration = duration + } +} + +// WithCancelOnQuotaExceeded monitors the quota usage at intervals defined by duration and cancels the context if the quota is exceeded +func WithCancelOnQuotaExceeded(ctx context.Context, qm QuotaMonitor, ops ...QuotaCheckOption) (context.Context, func(), error) { + m := quotaChecker{ + qm: qm, + duration: DefaultQuotaCheckInterval, + } + for _, op := range ops { + op(&m) + } + + if err := m.checkInitialQuota(ctx); err != nil { + return ctx, nil, err + } + + ctx, cancel := m.startQuotaMonitor(ctx) + + return ctx, cancel, nil +} + +func (qc quotaChecker) checkInitialQuota(ctx context.Context) error { + hasQuota, err := qc.qm.HasQuota(ctx) + if err != nil { + return err + } + + if !hasQuota { + return ErrNoQuota + } + + return nil +} + +func (qc quotaChecker) startQuotaMonitor(ctx context.Context) (context.Context, func()) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + ticker := time.NewTicker(qc.duration) + for { + select { + case <-newCtx.Done(): + return + case <-ticker.C: + hasQuota, err := qc.qm.HasQuota(newCtx) + if err != nil { + continue + } + if !hasQuota { + return + } + } + } + }() + return newCtx, cancel +} diff --git a/premium/monitor_test.go b/premium/monitor_test.go new file mode 100644 index 0000000000..68b1a68ce5 --- /dev/null +++ b/premium/monitor_test.go @@ -0,0 +1,52 @@ +package premium + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func newFakeQuotaMonitor(hasQuota ...bool) *fakeQuotaMonitor { + return &fakeQuotaMonitor{hasQuota: hasQuota} +} + +type fakeQuotaMonitor struct { + hasQuota []bool + calls int +} + +func (f *fakeQuotaMonitor) HasQuota(_ context.Context) (bool, error) { + hasQuota := f.hasQuota[f.calls] + if f.calls < len(f.hasQuota)-1 { + f.calls++ + } + return hasQuota, nil +} + +func TestWithCancelOnQuotaExceeded_NoInitialQuota(t *testing.T) { + ctx := context.Background() + + _, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(false)) + + require.Error(t, err) +} + +func TestWithCancelOnQuotaExceeded_NoQuota(t *testing.T) { + ctx := context.Background() + + ctx, _, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, false), WithQuotaCheckPeriod(1*time.Millisecond)) + require.NoError(t, err) + + <-ctx.Done() +} + +func TestWithCancelOnQuotaExceeded_HasQuotaCanceled(t *testing.T) { + ctx := context.Background() + + ctx, cancel, err := WithCancelOnQuotaExceeded(ctx, newFakeQuotaMonitor(true, true, true), WithQuotaCheckPeriod(1*time.Millisecond)) + require.NoError(t, err) + cancel() + + <-ctx.Done() +} diff --git a/premium/tables.go b/premium/tables.go new file mode 100644 index 0000000000..2f9a8217ba --- /dev/null +++ b/premium/tables.go @@ -0,0 +1,27 @@ +package premium + +import "github.com/cloudquery/plugin-sdk/v4/schema" + +// ContainsPaidTables returns true if any of the tables are paid +func ContainsPaidTables(tables schema.Tables) bool { + for _, t := range tables { + if t.IsPaid { + return true + } + } + return false +} + +// MakeAllTablesPaid sets all tables to paid +func MakeAllTablesPaid(tables schema.Tables) schema.Tables { + for _, table := range tables { + MakeTablePaid(table) + } + return tables +} + +// MakeTablePaid sets the table to paid +func MakeTablePaid(table *schema.Table) *schema.Table { + table.IsPaid = true + return table +} diff --git a/premium/tables_test.go b/premium/tables_test.go new file mode 100644 index 0000000000..251be02d0b --- /dev/null +++ b/premium/tables_test.go @@ -0,0 +1,39 @@ +package premium + +import ( + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestContainsPaidTables(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: true}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + assert.False(t, ContainsPaidTables(noPaidTables), "no paid tables") + assert.True(t, ContainsPaidTables(paidTables), "paid tables") +} + +func TestMakeAllTablesPaid(t *testing.T) { + noPaidTables := schema.Tables{ + &schema.Table{Name: "table1", IsPaid: false}, + &schema.Table{Name: "table2", IsPaid: false}, + &schema.Table{Name: "table3", IsPaid: false}, + } + + paidTables := MakeAllTablesPaid(noPaidTables) + + assert.Equal(t, 3, len(paidTables)) + for _, table := range paidTables { + assert.True(t, table.IsPaid) + } +} diff --git a/premium/usage.go b/premium/usage.go index e8016949aa..f41e7be318 100644 --- a/premium/usage.go +++ b/premium/usage.go @@ -4,7 +4,10 @@ import ( "context" "fmt" cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/auth" + "github.com/cloudquery/cloudquery-api-go/config" "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "math/rand" "net/http" @@ -13,6 +16,7 @@ import ( ) const ( + defaultAPIURL = "https://api.cloudquery.io" defaultBatchLimit = 1000 defaultMaxRetries = 5 defaultMaxWaitTime = 60 * time.Second @@ -20,60 +24,108 @@ const ( defaultMaxTimeBetweenFlushes = 30 * time.Second ) -type UsageClient interface { - // Increase updates the usage by the given number of rows - Increase(context.Context, uint32) +type QuotaMonitor interface { // HasQuota returns true if the quota has not been exceeded HasQuota(context.Context) (bool, error) +} + +type UsageClient interface { + QuotaMonitor + // Increase updates the usage by the given number of rows + Increase(uint32) error // Close flushes any remaining rows and closes the quota service Close() error } -type UpdaterOptions func(updater *BatchUpdater) +type UsageClientOptions func(updater *BatchUpdater) // WithBatchLimit sets the maximum number of rows to update in a single request -func WithBatchLimit(batchLimit uint32) UpdaterOptions { +func WithBatchLimit(batchLimit uint32) UsageClientOptions { return func(updater *BatchUpdater) { updater.batchLimit = batchLimit } } // WithMaxTimeBetweenFlushes sets the flush duration - the time at which an update will be triggered even if the batch limit is not reached -func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMaxTimeBetweenFlushes(maxTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxTimeBetweenFlushes = maxTimeBetweenFlushes } } // WithMinTimeBetweenFlushes sets the minimum time between updates -func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UpdaterOptions { +func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.minTimeBetweenFlushes = minTimeBetweenFlushes } } // WithMaxRetries sets the maximum number of retries to update the usage in case of an API error -func WithMaxRetries(maxRetries int) UpdaterOptions { +func WithMaxRetries(maxRetries int) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxRetries = maxRetries } } // WithMaxWaitTime sets the maximum time to wait before retrying a failed update -func WithMaxWaitTime(maxWaitTime time.Duration) UpdaterOptions { +func WithMaxWaitTime(maxWaitTime time.Duration) UsageClientOptions { return func(updater *BatchUpdater) { updater.maxWaitTime = maxWaitTime } } +// WithLogger sets the logger to use - defaults to a no-op logger +func WithLogger(logger zerolog.Logger) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.logger = logger + } +} + +// WithURL sets the API URL to use - defaults to https://api.cloudquery.io +func WithURL(url string) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.url = url + } +} + +// WithTeamName sets the team name to use - defaults to the team name from the configuration +func WithTeamName(teamName cqapi.TeamName) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.teamName = teamName + } +} + +// WithAPIClient sets the API client to use - defaults to a client using a bearer token generated from the refresh token stored in the configuration +func WithAPIClient(apiClient *cqapi.ClientWithResponses) UsageClientOptions { + return func(updater *BatchUpdater) { + updater.apiClient = apiClient + } +} + +func WithPluginTeam(pluginTeam string) cqapi.PluginTeam { + return pluginTeam +} + +func WithPluginKind(pluginKind string) cqapi.PluginKind { + return cqapi.PluginKind(pluginKind) +} + +func WithPluginName(pluginName string) cqapi.PluginName { + return pluginName +} + +var _ UsageClient = (*BatchUpdater)(nil) + type BatchUpdater struct { + logger zerolog.Logger + url string apiClient *cqapi.ClientWithResponses // Plugin details - teamName string - pluginTeam string - pluginKind string - pluginName string + teamName cqapi.TeamName + pluginTeam cqapi.PluginTeam + pluginKind cqapi.PluginKind + pluginName cqapi.PluginName // Configuration batchLimit uint32 @@ -91,11 +143,11 @@ type BatchUpdater struct { isClosed bool } -func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, teamName, pluginTeam, pluginKind, pluginName string, ops ...UpdaterOptions) *BatchUpdater { +func NewUsageClient(pluginTeam cqapi.PluginTeam, pluginKind cqapi.PluginKind, pluginName cqapi.PluginName, ops ...UsageClientOptions) (*BatchUpdater, error) { u := &BatchUpdater{ - apiClient: apiClient, + logger: zerolog.Nop(), + url: defaultAPIURL, - teamName: teamName, pluginTeam: pluginTeam, pluginKind: pluginKind, pluginName: pluginName, @@ -113,12 +165,38 @@ func NewUsageClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, t op(u) } - u.backgroundUpdater(ctx) + // Set team name from configuration if not provided + if u.teamName == "" { + teamName, err := config.GetValue("team") + if err != nil { + return nil, fmt.Errorf("failed to get team name from config: %w", err) + } + u.teamName = teamName + } + + // Create a default api client if none was provided + if u.apiClient == nil { + tokenClient := auth.NewTokenClient() + ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { + token, err := tokenClient.GetToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + })) + if err != nil { + return nil, fmt.Errorf("failed to create api client: %w", err) + } + u.apiClient = ac + } + + u.backgroundUpdater() - return u + return u, nil } -func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { +func (u *BatchUpdater) Increase(rows uint32) error { if rows <= 0 { return fmt.Errorf("rows must be greater than zero got %d", rows) } @@ -140,7 +218,7 @@ func (u *BatchUpdater) Increase(_ context.Context, rows uint32) error { } func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { - usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, cqapi.PluginKind(u.pluginKind), u.pluginName) + usage, err := u.apiClient.GetTeamPluginUsageWithResponse(ctx, u.teamName, u.pluginTeam, u.pluginKind, u.pluginName) if err != nil { return false, fmt.Errorf("failed to get usage: %w", err) } @@ -150,7 +228,7 @@ func (u *BatchUpdater) HasQuota(ctx context.Context) (bool, error) { return *usage.JSON200.RemainingRows > 0, nil } -func (u *BatchUpdater) Close(_ context.Context) error { +func (u *BatchUpdater) Close() error { u.isClosed = true close(u.done) @@ -158,7 +236,8 @@ func (u *BatchUpdater) Close(_ context.Context) error { return <-u.closeError } -func (u *BatchUpdater) backgroundUpdater(ctx context.Context) { +func (u *BatchUpdater) backgroundUpdater() { + ctx := context.Background() started := make(chan struct{}) flushDuration := time.NewTicker(u.maxTimeBetweenFlushes) @@ -221,7 +300,7 @@ func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, numbe resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, cqapi.IncreaseTeamPluginUsageJSONRequestBody{ RequestId: uuid.New(), PluginTeam: u.pluginTeam, - PluginKind: cqapi.PluginKind(u.pluginKind), + PluginKind: u.pluginKind, PluginName: u.pluginName, Rows: int(numberToUpdate), }) diff --git a/premium/usage_test.go b/premium/usage_test.go index 236b35c219..033bd55791 100644 --- a/premium/usage_test.go +++ b/premium/usage_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" cqapi "github.com/cloudquery/cloudquery-api-go" + "github.com/cloudquery/cloudquery-api-go/config" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "math" @@ -14,6 +16,45 @@ import ( "time" ) +func TestUsageService_NewUsageClient_Defaults(t *testing.T) { + err := config.SetConfigHome(t.TempDir()) + require.NoError(t, err) + + err = config.SetValue("team", "config-team") + require.NoError(t, err) + + uc, err := NewUsageClient( + WithPluginTeam("plugin-team"), + WithPluginKind("source"), + WithPluginName("vault"), + ) + require.NoError(t, err) + + assert.NotNil(t, uc.apiClient) + assert.Equal(t, "config-team", uc.teamName) + assert.Equal(t, zerolog.Nop(), uc.logger) + assert.Equal(t, 5, uc.maxRetries) + assert.Equal(t, 60*time.Second, uc.maxWaitTime) + assert.Equal(t, 30*time.Second, uc.maxTimeBetweenFlushes) +} + +func TestUsageService_NewUsageClient_Override(t *testing.T) { + ac, err := cqapi.NewClientWithResponses("http://localhost") + require.NoError(t, err) + + logger := zerolog.New(zerolog.NewTestWriter(t)) + + uc, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), WithLogger(logger), WithAPIClient(ac), WithTeamName("override-team-name"), WithMaxRetries(10), WithMaxWaitTime(120*time.Second), WithMaxTimeBetweenFlushes(10*time.Second)) + require.NoError(t, err) + + assert.Equal(t, ac, uc.apiClient) + assert.Equal(t, "override-team-name", uc.teamName) + assert.Equal(t, logger, uc.logger) + assert.Equal(t, 10, uc.maxRetries) + assert.Equal(t, 120*time.Second, uc.maxWaitTime) + assert.Equal(t, 10*time.Second, uc.maxTimeBetweenFlushes) +} + func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { ctx := context.Background() @@ -23,7 +64,7 @@ func TestUsageService_HasQuota_NoRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -40,7 +81,7 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) hasQuota, err := usageClient.HasQuota(ctx) require.NoError(t, err) @@ -49,29 +90,26 @@ func TestUsageService_HasQuota_WithRowsRemaining(t *testing.T) { } func TestUsageService_ZeroBatchSize(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") } func TestUsageService_WithBatchSize(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -80,13 +118,13 @@ func TestUsageService_WithBatchSize(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize))) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize))) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -94,7 +132,6 @@ func TestUsageService_WithBatchSize(t *testing.T) { } func TestUsageService_WithFlushDuration(t *testing.T) { - ctx := context.Background() batchSize := 2000 s := createTestServer(t) @@ -103,14 +140,14 @@ func TestUsageService_WithFlushDuration(t *testing.T) { apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) + usageClient := newClient(t, apiClient, WithBatchLimit(uint32(batchSize)), WithMaxTimeBetweenFlushes(1*time.Millisecond), WithMinTimeBetweenFlushes(0*time.Millisecond)) for i := 0; i < 10; i++ { - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.NoError(t, err) time.Sleep(5 * time.Millisecond) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 100, s.sumOfUpdates(), "total should equal number of updated rows") @@ -118,21 +155,19 @@ func TestUsageService_WithFlushDuration(t *testing.T) { } func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) + usageClient := newClient(t, apiClient, WithBatchLimit(0), WithMinTimeBetweenFlushes(30*time.Second)) for i := 0; i < 10000; i++ { - err = usageClient.Increase(ctx, 1) + err = usageClient.Increase(1) require.NoError(t, err) } - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 10000, s.sumOfUpdates(), "total should equal number of updated rows") @@ -140,58 +175,52 @@ func TestUsageService_WithMinimumUpdateDuration(t *testing.T) { } func TestUsageService_NoUpdates(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_UpdatesWithZeroRows(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) - err = usageClient.Increase(ctx, 0) + err = usageClient.Increase(0) require.Error(t, err, "should not be able to update with zero rows") - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") } func TestUsageService_ShouldNotUpdateClosedService(t *testing.T) { - ctx := context.Background() - s := createTestServer(t) defer s.server.Close() apiClient, err := cqapi.NewClientWithResponses(s.server.URL) require.NoError(t, err) - usageClient := newClient(ctx, apiClient, WithBatchLimit(0)) + usageClient := newClient(t, apiClient, WithBatchLimit(0)) // Close the service first - err = usageClient.Close(ctx) + err = usageClient.Close() require.NoError(t, err) - err = usageClient.Increase(ctx, 10) + err = usageClient.Increase(10) require.Error(t, err, "should not be able to update closed service") assert.Equal(t, 0, s.numberOfUpdates(), "total number of updates should be zero") @@ -247,7 +276,7 @@ func TestUsageService_CalculateRetryDuration_Exp(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -294,7 +323,7 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } for _, tt := range tests { - usageClient := newClient(context.Background(), nil) + usageClient := newClient(t, nil) if tt.ops != nil { tt.ops(usageClient) } @@ -311,8 +340,11 @@ func TestUsageService_CalculateRetryDuration_ServerBackPressure(t *testing.T) { } } -func newClient(ctx context.Context, apiClient *cqapi.ClientWithResponses, ops ...UpdaterOptions) *BatchUpdater { - return NewUsageClient(ctx, apiClient, "myteam", "mnorbury-team", "source", "vault", ops...) +func newClient(t *testing.T, apiClient *cqapi.ClientWithResponses, ops ...UsageClientOptions) *BatchUpdater { + client, err := NewUsageClient(WithPluginTeam("plugin-team"), WithPluginKind("source"), WithPluginName("vault"), append(ops, WithTeamName("team-name"), WithAPIClient(apiClient))...) + require.NoError(t, err) + + return client } func createTestServerWithRemainingRows(t *testing.T, remainingRows int) *testStage { diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 2332500e33..9243c192e5 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/apache/arrow/go/v14/arrow" "runtime/debug" "sync/atomic" "time" @@ -182,15 +183,23 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s } }() for resource := range resources { - vector := resource.GetValues() - bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) - scalar.AppendToRecordBuilder(bldr, vector) - rec := bldr.NewRecord() - res <- &message.SyncInsert{Record: rec} + select { + case res <- &message.SyncInsert{Record: resourceToRecord(resource)}: + case <-ctx.Done(): + return nil + } } return nil } +func resourceToRecord(resource *schema.Resource) arrow.Record { + vector := resource.GetValues() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) + scalar.AppendToRecordBuilder(bldr, vector) + rec := bldr.NewRecord() + return rec +} + func (s *syncClient) logTablesMetrics(tables schema.Tables, client Client) { clientName := client.ID() for _, table := range tables { diff --git a/scheduler/scheduler_dfs.go b/scheduler/scheduler_dfs.go index e30cd8a62c..e353a443cc 100644 --- a/scheduler/scheduler_dfs.go +++ b/scheduler/scheduler_dfs.go @@ -176,7 +176,10 @@ func (s *syncClient) resolveResourcesDfs(ctx context.Context, table *schema.Tabl atomic.AddUint64(&tableMetrics.Errors, 1) return } - resourcesChan <- resolvedResource + select { + case resourcesChan <- resolvedResource: + case <-ctx.Done(): + } }() } wg.Wait() diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go index 1e8ba41f5d..15e7777f08 100644 --- a/scheduler/scheduler_test.go +++ b/scheduler/scheduler_test.go @@ -2,6 +2,9 @@ package scheduler import ( "context" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "github.com/apache/arrow/go/v14/arrow" @@ -40,6 +43,22 @@ func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resourc panic("ColumnResolver") } +func testTableSuccessWithData(data []any) *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: func(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { + res <- data + return nil + }, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + func testTableSuccess() *schema.Table { return &schema.Table{ Name: "test_table_success", @@ -233,6 +252,67 @@ func TestScheduler(t *testing.T) { } } +func TestScheduler_Cancellation(t *testing.T) { + data := make([]any, 100) + + tests := []struct { + name string + data []any + cancel bool + messageCount int + }{ + { + name: "should consume all message", + data: data, + cancel: false, + messageCount: len(data) + 1, // 9 data + 1 migration message + }, + { + name: "should not consume all message on cancel", + data: data, + cancel: true, + messageCount: len(data) + 1, // 9 data + 1 migration message + }, + } + + for _, strategy := range AllStrategies { + for _, tc := range tests { + t.Run(fmt.Sprintf("%s_%s", tc.name, strategy.String()), func(t *testing.T) { + sc := NewScheduler(WithLogger(zerolog.New(zerolog.NewTestWriter(t))), WithStrategy(strategy)) + + messages := make(chan message.SyncMessage) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + err := sc.Sync( + ctx, + &testExecutionClient{}, + []*schema.Table{testTableSuccessWithData(tc.data)}, + messages, + ) + require.NoError(t, err) + close(messages) + }() + + messageConsumed := 0 + for range messages { + if tc.cancel { + cancel() + } + messageConsumed++ + } + + if tc.cancel { + assert.NotEqual(t, tc.messageCount, messageConsumed) + } else { + assert.Equal(t, tc.messageCount, messageConsumed) + } + }) + } + } +} + func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, deterministicCQID bool) { ctx := context.Background() tables := []*schema.Table{}