diff --git a/writers/batchwriter/batchwriter.go b/writers/batchwriter/batchwriter.go index 1310b3956d..4a7247a214 100644 --- a/writers/batchwriter/batchwriter.go +++ b/writers/batchwriter/batchwriter.go @@ -122,7 +122,8 @@ func (w *BatchWriter) Close(context.Context) error { func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) { sizeBytes := int64(0) resources := make([]*message.WriteInsert, 0, w.batchSize) - tick := timer(w.batchTimeout) + tick, done := writers.NewTicker(w.batchTimeout) + defer done() for { select { case r, ok := <-ch: @@ -145,7 +146,6 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m w.flushTable(ctx, tableName, resources) resources, sizeBytes = resources[:0], 0 } - tick = timer(w.batchTimeout) case done := <-flush: if len(resources) > 0 { w.flushTable(ctx, tableName, resources) @@ -324,10 +324,3 @@ func (w *BatchWriter) startWorker(ctx context.Context, msg *message.WriteInsert) ch <- msg return nil } - -func timer(timeout time.Duration) <-chan time.Time { - if timeout == 0 { - return nil - } - return time.After(timeout) -} diff --git a/writers/mixedbatchwriter/mixedbatchwriter.go b/writers/mixedbatchwriter/mixedbatchwriter.go index 4f76a12fe6..d4c797debc 100644 --- a/writers/mixedbatchwriter/mixedbatchwriter.go +++ b/writers/mixedbatchwriter/mixedbatchwriter.go @@ -17,15 +17,13 @@ type Client interface { DeleteStaleBatch(ctx context.Context, messages message.WriteDeleteStales) error } -type timerFn func(timeout time.Duration) <-chan time.Time - type MixedBatchWriter struct { client Client logger zerolog.Logger batchSize int batchSizeBytes int batchTimeout time.Duration - timerFn timerFn + tickerFn writers.TickerFunc } // Assert at compile-time that MixedBatchWriter implements the Writer interface @@ -57,9 +55,9 @@ func WithBatchTimeout(timeout time.Duration) Option { } } -func withTimerFn(timer timerFn) Option { +func withTickerFn(tickerFn writers.TickerFunc) Option { return func(p *MixedBatchWriter) { - p.timerFn = timer + p.tickerFn = tickerFn } } @@ -76,7 +74,7 @@ func New(client Client, opts ...Option) (*MixedBatchWriter, error) { batchSize: defaultBatchSize, batchSizeBytes: defaultBatchSizeBytes, batchTimeout: defaultBatchTimeout, - timerFn: timer, + tickerFn: writers.NewTicker, } for _, opt := range opts { opt(c) @@ -116,7 +114,8 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri } prevMsgType := writers.MsgTypeUnset var err error - tick := w.timerFn(w.batchTimeout) + tick, done := w.tickerFn(w.batchTimeout) + defer done() loop: for { select { @@ -149,7 +148,6 @@ loop: return err } prevMsgType = writers.MsgTypeUnset - tick = w.timerFn(w.batchTimeout) } } return flush(prevMsgType) @@ -215,10 +213,3 @@ func (m *insertBatchManager) flush(ctx context.Context) error { m.batch = m.batch[:0] return nil } - -func timer(timeout time.Duration) <-chan time.Time { - if timeout == 0 { - return nil - } - return time.After(timeout) -} diff --git a/writers/mixedbatchwriter/mixedbatchwriter_test.go b/writers/mixedbatchwriter/mixedbatchwriter_test.go index ade2d52558..a8536f8d4e 100644 --- a/writers/mixedbatchwriter/mixedbatchwriter_test.go +++ b/writers/mixedbatchwriter/mixedbatchwriter_test.go @@ -243,13 +243,14 @@ func TestMixedBatchWriterTimeout(t *testing.T) { wr, err := New(client, WithBatchSize(1000), WithBatchSizeBytes(1000000), - withTimerFn(func(_ time.Duration) <-chan time.Time { + withTickerFn(func(_ time.Duration) (<-chan time.Time, func()) { c := make(chan time.Time) go func() { - <-triggerTimeout - c <- time.Now() + for range triggerTimeout { + c <- time.Now() + } }() - return c + return c, func() { close(c) } }), ) if err != nil { diff --git a/writers/streamingbatchwriter/mocktimer_test.go b/writers/streamingbatchwriter/mocktimer_test.go index a4166de499..044066dfbe 100644 --- a/writers/streamingbatchwriter/mocktimer_test.go +++ b/writers/streamingbatchwriter/mocktimer_test.go @@ -1,16 +1,24 @@ package streamingbatchwriter -import "time" +import ( + "time" + + "github.com/cloudquery/plugin-sdk/v4/writers" +) type mockTimer struct { expire chan time.Time } -func (t *mockTimer) timer(time.Duration) <-chan time.Time { - return t.expire +func (t *mockTimer) timer(time.Duration) (<-chan time.Time, func()) { + return t.expire, t.close +} + +func (t *mockTimer) close() { + close(t.expire) } -func newMockTimer() (timerFn, chan time.Time) { +func newMockTimer() (writers.TickerFunc, chan time.Time) { expire := make(chan time.Time) t := &mockTimer{ expire: expire, diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go index 14f1ce3a78..01fb0d083b 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter.go +++ b/writers/streamingbatchwriter/streamingbatchwriter.go @@ -61,11 +61,9 @@ type StreamingBatchWriter struct { batchSizeRows int64 batchSizeBytes int64 - timerFn timerFn + tickerFn writers.TickerFunc } -type timerFn func(timeout time.Duration) <-chan time.Time - // Assert at compile-time that StreamingBatchWriter implements the Writer interface var _ writers.Writer = (*StreamingBatchWriter)(nil) @@ -95,9 +93,9 @@ func WithBatchSizeBytes(size int64) Option { } } -func withTimerFn(timer timerFn) Option { +func withTickerFn(tickerFn writers.TickerFunc) Option { return func(p *StreamingBatchWriter) { - p.timerFn = timer + p.tickerFn = tickerFn } } @@ -115,7 +113,7 @@ func New(client Client, opts ...Option) (*StreamingBatchWriter, error) { batchTimeout: defaultBatchTimeoutSeconds * time.Second, batchSizeRows: defaultBatchSize, batchSizeBytes: defaultBatchSizeBytes, - timerFn: timer, + tickerFn: writers.NewTicker, } for _, opt := range opts { opt(c) @@ -225,7 +223,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err batchSizeRows: w.batchSizeRows, batchTimeout: w.batchTimeout, - timerFn: w.timerFn, + tickerFn: w.tickerFn, } w.workersWaitGroup.Add(1) @@ -277,7 +275,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err batchSizeRows: w.batchSizeRows, batchSizeBytes: w.batchSizeBytes, batchTimeout: w.batchTimeout, - timerFn: w.timerFn, + tickerFn: w.tickerFn, } w.workersLock.Lock() w.insertWorkers[tableName] = wr @@ -303,7 +301,7 @@ type streamingWorkerManager[T message.WriteMessage] struct { batchSizeRows int64 batchSizeBytes int64 batchTimeout time.Duration - timerFn timerFn + tickerFn writers.TickerFunc } func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) { @@ -345,7 +343,8 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, } defer closeFlush() - tick := s.timerFn(s.batchTimeout) + tick, done := s.tickerFn(s.batchTimeout) + defer done() for { select { case r, ok := <-s.ch: @@ -370,7 +369,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, if sizeRows > 0 { closeFlush() } - tick = s.timerFn(s.batchTimeout) case done := <-s.flush: if sizeRows > 0 { closeFlush() @@ -379,10 +377,3 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, } } } - -func timer(timeout time.Duration) <-chan time.Time { - if timeout == 0 { - return nil - } - return time.After(timeout) -} diff --git a/writers/streamingbatchwriter/streamingbatchwriter_test.go b/writers/streamingbatchwriter/streamingbatchwriter_test.go index 12a5ffc043..9c5806747e 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter_test.go +++ b/writers/streamingbatchwriter/streamingbatchwriter_test.go @@ -229,7 +229,7 @@ func TestStreamingBatchTimeout(t *testing.T) { testClient := newClient() timerFn, timerExpire := newMockTimer() - wr, err := New(testClient, withTimerFn(timerFn)) + wr, err := New(testClient, withTickerFn(timerFn)) if err != nil { t.Fatal(err) } @@ -333,7 +333,7 @@ func TestStreamingBatchUpserts(t *testing.T) { testClient := newClient() timerFn, timerExpire := newMockTimer() - wr, err := New(testClient, WithBatchSizeRows(2), withTimerFn(timerFn)) + wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(timerFn)) if err != nil { t.Fatal(err) } diff --git a/writers/ticker.go b/writers/ticker.go new file mode 100644 index 0000000000..df657a2fbe --- /dev/null +++ b/writers/ticker.go @@ -0,0 +1,17 @@ +package writers + +import ( + "time" +) + +type TickerFunc func(interval time.Duration) (ch <-chan time.Time, done func()) + +func NewTicker(interval time.Duration) (<-chan time.Time, func()) { + if interval <= 0 { + return nil, nop + } + t := time.NewTicker(interval) + return t.C, t.Stop +} + +func nop() {}