diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go index c2ad4578bc..bd08a4e79f 100644 --- a/internal/servers/plugin/v3/plugin.go +++ b/internal/servers/plugin/v3/plugin.go @@ -210,6 +210,10 @@ func (s *Server) Sync(req *pb.Sync_Request, stream pb.Plugin_SyncServer) error { } } + if err := s.Plugin.OnSyncFinish(ctx); err != nil { + return status.Errorf(codes.Internal, "failed to finish sync: %v", err) + } + return syncErr } diff --git a/plugin/plugin.go b/plugin/plugin.go index 0d63b51ec9..0811ee4d5a 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -141,6 +141,19 @@ func (p *Plugin) OnBeforeSend(ctx context.Context, msg message.SyncMessage) (mes return msg, nil } +// OnSyncFinisher is an interface that can be implemented by a plugin client to be notified when a sync finishes. +type OnSyncFinisher interface { + OnSyncFinish(context.Context) error +} + +// OnSyncFinish gets called after a sync finishes. +func (p *Plugin) OnSyncFinish(ctx context.Context) error { + if v, ok := p.client.(OnSyncFinisher); ok { + return v.OnSyncFinish(ctx) + } + return nil +} + // IsStaticLinkingEnabled whether static linking is to be enabled func (p *Plugin) IsStaticLinkingEnabled() bool { return p.staticLinking diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 2332500e33..87f33097d3 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/apache/arrow/go/v14/arrow" "runtime/debug" "sync/atomic" "time" @@ -182,15 +183,23 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s } }() for resource := range resources { - vector := resource.GetValues() - bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) - scalar.AppendToRecordBuilder(bldr, vector) - rec := bldr.NewRecord() - res <- &message.SyncInsert{Record: rec} + select { + case res <- &message.SyncInsert{Record: resourceToRecord(resource)}: + case <-ctx.Done(): + return ctx.Err() + } } return nil } +func resourceToRecord(resource *schema.Resource) arrow.Record { + vector := resource.GetValues() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) + scalar.AppendToRecordBuilder(bldr, vector) + rec := bldr.NewRecord() + return rec +} + func (s *syncClient) logTablesMetrics(tables schema.Tables, client Client) { clientName := client.ID() for _, table := range tables { diff --git a/scheduler/scheduler_dfs.go b/scheduler/scheduler_dfs.go index e30cd8a62c..e353a443cc 100644 --- a/scheduler/scheduler_dfs.go +++ b/scheduler/scheduler_dfs.go @@ -176,7 +176,10 @@ func (s *syncClient) resolveResourcesDfs(ctx context.Context, table *schema.Tabl atomic.AddUint64(&tableMetrics.Errors, 1) return } - resourcesChan <- resolvedResource + select { + case resourcesChan <- resolvedResource: + case <-ctx.Done(): + } }() } wg.Wait() diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go index 1e8ba41f5d..1fda03d545 100644 --- a/scheduler/scheduler_test.go +++ b/scheduler/scheduler_test.go @@ -2,6 +2,9 @@ package scheduler import ( "context" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "github.com/apache/arrow/go/v14/arrow" @@ -40,6 +43,22 @@ func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resourc panic("ColumnResolver") } +func testTableSuccessWithData(data []any) *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: func(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { + res <- data + return nil + }, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + func testTableSuccess() *schema.Table { return &schema.Table{ Name: "test_table_success", @@ -233,6 +252,72 @@ func TestScheduler(t *testing.T) { } } +func TestScheduler_Cancellation(t *testing.T) { + data := make([]any, 100) + + tests := []struct { + name string + data []any + cancel bool + messageCount int + }{ + { + name: "should consume all message", + data: data, + cancel: false, + messageCount: len(data) + 1, // 9 data + 1 migration message + }, + { + name: "should not consume all message on cancel", + data: data, + cancel: true, + messageCount: len(data) + 1, // 9 data + 1 migration message + }, + } + + for _, strategy := range AllStrategies { + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("%s_%s", tc.name, strategy.String()), func(t *testing.T) { + sc := NewScheduler(WithLogger(zerolog.New(zerolog.NewTestWriter(t))), WithStrategy(strategy)) + + messages := make(chan message.SyncMessage) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + err := sc.Sync( + ctx, + &testExecutionClient{}, + []*schema.Table{testTableSuccessWithData(tc.data)}, + messages, + ) + if tc.cancel { + assert.Equal(t, err, context.Canceled) + } else { + require.NoError(t, err) + } + close(messages) + }() + + messageConsumed := 0 + for range messages { + if tc.cancel { + cancel() + } + messageConsumed++ + } + + if tc.cancel { + assert.NotEqual(t, tc.messageCount, messageConsumed) + } else { + assert.Equal(t, tc.messageCount, messageConsumed) + } + }) + } + } +} + func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, deterministicCQID bool) { ctx := context.Background() tables := []*schema.Table{}