From a20f5fbc5d06ee6a46f802316c192a2ee8f9eb53 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Fri, 4 Aug 2023 13:27:54 +0100 Subject: [PATCH 1/4] fix(writers): StreamingBatchWriter should close when Write is done --- .../streamingbatchwriter/streamingbatchwriter.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go index ce84105b92..c87fe83a0a 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter.go +++ b/writers/streamingbatchwriter/streamingbatchwriter.go @@ -156,13 +156,16 @@ func (w *StreamingBatchWriter) Close(context.Context) error { } w.workersWaitGroup.Wait() - w.insertWorkers = nil + w.insertWorkers = make(map[string]*streamingWorkerManager[*message.WriteInsert]) + w.migrateWorker = nil + w.deleteWorker = nil return nil } func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessage) error { errCh := make(chan error) + defer close(errCh) go func() { for err := range errCh { @@ -172,7 +175,7 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr for msg := range msgs { msgType := writers.MsgID(msg) - if w.lastMsgType != msgType { + if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType { if err := w.Flush(ctx); err != nil { return err } @@ -183,12 +186,7 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr } } - if err := w.Flush(ctx); err != nil { - return err - } - - close(errCh) - return nil + return w.Close(ctx) } func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error { From eb0f99e2399facd77ed26b50f6cce6dbc3d59472 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Fri, 4 Aug 2023 13:36:53 +0100 Subject: [PATCH 2/4] Fix the test --- writers/streamingbatchwriter/mocktimer_test.go | 16 ++++++++++++---- .../streamingbatchwriter_test.go | 8 ++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/writers/streamingbatchwriter/mocktimer_test.go b/writers/streamingbatchwriter/mocktimer_test.go index 1e3f1412b2..2b1e44757c 100644 --- a/writers/streamingbatchwriter/mocktimer_test.go +++ b/writers/streamingbatchwriter/mocktimer_test.go @@ -1,17 +1,25 @@ package streamingbatchwriter import ( + "sync" "time" "github.com/cloudquery/plugin-sdk/v4/writers" ) type mockTicker struct { - expire chan time.Time + expire chan time.Time + stopped sync.Once } func (t *mockTicker) Stop() { - close(t.expire) + t.stopped.Do(func() { + close(t.expire) + }) +} + +func (t *mockTicker) Tick() { + t.expire <- time.Now() } func (*mockTicker) Reset(time.Duration) {} @@ -20,12 +28,12 @@ func (t *mockTicker) Chan() <-chan time.Time { return t.expire } -func newMockTicker() (writers.TickerFunc, chan<- time.Time) { +func newMockTicker() (writers.TickerFunc, func()) { expire := make(chan time.Time) t := &mockTicker{ expire: expire, } return func(time.Duration) writers.Ticker { return t - }, expire + }, t.Tick } diff --git a/writers/streamingbatchwriter/streamingbatchwriter_test.go b/writers/streamingbatchwriter/streamingbatchwriter_test.go index e66779d3d9..ae1004ee96 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter_test.go +++ b/writers/streamingbatchwriter/streamingbatchwriter_test.go @@ -227,7 +227,7 @@ func TestStreamingBatchTimeout(t *testing.T) { ch := make(chan message.WriteMessage) testClient := newClient() - tickerFn, expire := newMockTicker() + tickerFn, tickFn := newMockTicker() wr, err := New(testClient, withTickerFn(tickerFn)) if err != nil { @@ -258,7 +258,7 @@ func TestStreamingBatchTimeout(t *testing.T) { } // flush - close(expire) + tickFn() waitForLength(t, testClient.MessageLen, messageTypeInsert, 1) close(ch) @@ -332,7 +332,7 @@ func TestStreamingBatchUpserts(t *testing.T) { ch := make(chan message.WriteMessage) testClient := newClient() - tickerFn, expire := newMockTicker() + tickerFn, tickFn := newMockTicker() wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(tickerFn)) if err != nil { t.Fatal(err) @@ -363,7 +363,7 @@ func TestStreamingBatchUpserts(t *testing.T) { time.Sleep(50 * time.Millisecond) // flush the batch - close(expire) + tickFn() waitForLength(t, testClient.MessageLen, messageTypeInsert, 2) close(ch) From aafcf9dab3cba65f196a78694fce227b7e1d270b Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Fri, 4 Aug 2023 13:59:37 +0100 Subject: [PATCH 3/4] Reset lastMsgType as well --- writers/streamingbatchwriter/streamingbatchwriter.go | 1 + 1 file changed, 1 insertion(+) diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go index c87fe83a0a..50c1779070 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter.go +++ b/writers/streamingbatchwriter/streamingbatchwriter.go @@ -159,6 +159,7 @@ func (w *StreamingBatchWriter) Close(context.Context) error { w.insertWorkers = make(map[string]*streamingWorkerManager[*message.WriteInsert]) w.migrateWorker = nil w.deleteWorker = nil + w.lastMsgType = writers.MsgTypeUnset return nil } From e453015dd40d096930062af9c0895a6c156032fe Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Mon, 7 Aug 2023 18:45:40 +0100 Subject: [PATCH 4/4] UnimplementedDeleteStale should empty out the channel --- writers/streamingbatchwriter/unimplemented.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/writers/streamingbatchwriter/unimplemented.go b/writers/streamingbatchwriter/unimplemented.go index 715feb6f1a..3d560f585e 100644 --- a/writers/streamingbatchwriter/unimplemented.go +++ b/writers/streamingbatchwriter/unimplemented.go @@ -18,9 +18,12 @@ func (IgnoreMigrateTable) MigrateTable(_ context.Context, ch <-chan *message.Wri return nil } -// UnimplementedDeleteStale is a dummy handler to error on DeleteStale messages +// UnimplementedDeleteStale is a dummy handler to consume and error on DeleteStale messages type UnimplementedDeleteStale struct{} -func (UnimplementedDeleteStale) DeleteStale(_ context.Context, _ <-chan *message.WriteDeleteStale) error { +func (UnimplementedDeleteStale) DeleteStale(_ context.Context, ch <-chan *message.WriteDeleteStale) error { + // nolint:revive + for range ch { + } return fmt.Errorf("DeleteStale: %w", plugin.ErrNotImplemented) }