Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions writers/batchwriter/batchwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ func (w *BatchWriter) Close(context.Context) error {
func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) {
sizeBytes := int64(0)
resources := make([]*message.WriteInsert, 0, w.batchSize)
tick := timer(w.batchTimeout)
tick, done := writers.NewTicker(w.batchTimeout)
defer done()
for {
select {
case r, ok := <-ch:
Expand All @@ -145,7 +146,6 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m
w.flushTable(ctx, tableName, resources)
resources, sizeBytes = resources[:0], 0
}
tick = timer(w.batchTimeout)
case done := <-flush:
if len(resources) > 0 {
w.flushTable(ctx, tableName, resources)
Expand Down Expand Up @@ -324,10 +324,3 @@ func (w *BatchWriter) startWorker(ctx context.Context, msg *message.WriteInsert)
ch <- msg
return nil
}

func timer(timeout time.Duration) <-chan time.Time {
if timeout == 0 {
return nil
}
return time.After(timeout)
}
21 changes: 6 additions & 15 deletions writers/mixedbatchwriter/mixedbatchwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ type Client interface {
DeleteStaleBatch(ctx context.Context, messages message.WriteDeleteStales) error
}

type timerFn func(timeout time.Duration) <-chan time.Time

type MixedBatchWriter struct {
client Client
logger zerolog.Logger
batchSize int
batchSizeBytes int
batchTimeout time.Duration
timerFn timerFn
tickerFn writers.TickerFunc
}

// Assert at compile-time that MixedBatchWriter implements the Writer interface
Expand Down Expand Up @@ -57,9 +55,9 @@ func WithBatchTimeout(timeout time.Duration) Option {
}
}

func withTimerFn(timer timerFn) Option {
func withTickerFn(tickerFn writers.TickerFunc) Option {
return func(p *MixedBatchWriter) {
p.timerFn = timer
p.tickerFn = tickerFn
}
}

Expand All @@ -76,7 +74,7 @@ func New(client Client, opts ...Option) (*MixedBatchWriter, error) {
batchSize: defaultBatchSize,
batchSizeBytes: defaultBatchSizeBytes,
batchTimeout: defaultBatchTimeout,
timerFn: timer,
tickerFn: writers.NewTicker,
}
for _, opt := range opts {
opt(c)
Expand Down Expand Up @@ -116,7 +114,8 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
}
prevMsgType := writers.MsgTypeUnset
var err error
tick := w.timerFn(w.batchTimeout)
tick, done := w.tickerFn(w.batchTimeout)
defer done()
loop:
for {
select {
Expand Down Expand Up @@ -149,7 +148,6 @@ loop:
return err
}
prevMsgType = writers.MsgTypeUnset
tick = w.timerFn(w.batchTimeout)
}
}
return flush(prevMsgType)
Expand Down Expand Up @@ -215,10 +213,3 @@ func (m *insertBatchManager) flush(ctx context.Context) error {
m.batch = m.batch[:0]
return nil
}

func timer(timeout time.Duration) <-chan time.Time {
if timeout == 0 {
return nil
}
return time.After(timeout)
}
9 changes: 5 additions & 4 deletions writers/mixedbatchwriter/mixedbatchwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,14 @@ func TestMixedBatchWriterTimeout(t *testing.T) {
wr, err := New(client,
WithBatchSize(1000),
WithBatchSizeBytes(1000000),
withTimerFn(func(_ time.Duration) <-chan time.Time {
withTickerFn(func(_ time.Duration) (<-chan time.Time, func()) {
c := make(chan time.Time)
go func() {
<-triggerTimeout
c <- time.Now()
for range triggerTimeout {
c <- time.Now()
}
}()
return c
return c, func() { close(c) }
}),
)
if err != nil {
Expand Down
16 changes: 12 additions & 4 deletions writers/streamingbatchwriter/mocktimer_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
package streamingbatchwriter

import "time"
import (
"time"

"github.com/cloudquery/plugin-sdk/v4/writers"
)

type mockTimer struct {
expire chan time.Time
}

func (t *mockTimer) timer(time.Duration) <-chan time.Time {
return t.expire
func (t *mockTimer) timer(time.Duration) (<-chan time.Time, func()) {
return t.expire, t.close
}

func (t *mockTimer) close() {
close(t.expire)
}

func newMockTimer() (timerFn, chan time.Time) {
func newMockTimer() (writers.TickerFunc, chan time.Time) {
expire := make(chan time.Time)
t := &mockTimer{
expire: expire,
Expand Down
27 changes: 9 additions & 18 deletions writers/streamingbatchwriter/streamingbatchwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ type StreamingBatchWriter struct {
batchSizeRows int64
batchSizeBytes int64

timerFn timerFn
tickerFn writers.TickerFunc
}

type timerFn func(timeout time.Duration) <-chan time.Time

// Assert at compile-time that StreamingBatchWriter implements the Writer interface
var _ writers.Writer = (*StreamingBatchWriter)(nil)

Expand Down Expand Up @@ -95,9 +93,9 @@ func WithBatchSizeBytes(size int64) Option {
}
}

func withTimerFn(timer timerFn) Option {
func withTickerFn(tickerFn writers.TickerFunc) Option {
return func(p *StreamingBatchWriter) {
p.timerFn = timer
p.tickerFn = tickerFn
}
}

Expand All @@ -115,7 +113,7 @@ func New(client Client, opts ...Option) (*StreamingBatchWriter, error) {
batchTimeout: defaultBatchTimeoutSeconds * time.Second,
batchSizeRows: defaultBatchSize,
batchSizeBytes: defaultBatchSizeBytes,
timerFn: timer,
tickerFn: writers.NewTicker,
}
for _, opt := range opts {
opt(c)
Expand Down Expand Up @@ -225,7 +223,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err

batchSizeRows: w.batchSizeRows,
batchTimeout: w.batchTimeout,
timerFn: w.timerFn,
tickerFn: w.tickerFn,
}

w.workersWaitGroup.Add(1)
Expand Down Expand Up @@ -277,7 +275,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
batchSizeRows: w.batchSizeRows,
batchSizeBytes: w.batchSizeBytes,
batchTimeout: w.batchTimeout,
timerFn: w.timerFn,
tickerFn: w.tickerFn,
}
w.workersLock.Lock()
w.insertWorkers[tableName] = wr
Expand All @@ -303,7 +301,7 @@ type streamingWorkerManager[T message.WriteMessage] struct {
batchSizeRows int64
batchSizeBytes int64
batchTimeout time.Duration
timerFn timerFn
tickerFn writers.TickerFunc
}

func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
Expand Down Expand Up @@ -345,7 +343,8 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
}
defer closeFlush()

tick := s.timerFn(s.batchTimeout)
tick, done := s.tickerFn(s.batchTimeout)
defer done()
for {
select {
case r, ok := <-s.ch:
Expand All @@ -370,7 +369,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if sizeRows > 0 {
closeFlush()
}
tick = s.timerFn(s.batchTimeout)
case done := <-s.flush:
if sizeRows > 0 {
closeFlush()
Expand All @@ -379,10 +377,3 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
}
}
}

func timer(timeout time.Duration) <-chan time.Time {
if timeout == 0 {
return nil
}
return time.After(timeout)
}
4 changes: 2 additions & 2 deletions writers/streamingbatchwriter/streamingbatchwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestStreamingBatchTimeout(t *testing.T) {
testClient := newClient()
timerFn, timerExpire := newMockTimer()

wr, err := New(testClient, withTimerFn(timerFn))
wr, err := New(testClient, withTickerFn(timerFn))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -333,7 +333,7 @@ func TestStreamingBatchUpserts(t *testing.T) {

testClient := newClient()
timerFn, timerExpire := newMockTimer()
wr, err := New(testClient, WithBatchSizeRows(2), withTimerFn(timerFn))
wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(timerFn))
if err != nil {
t.Fatal(err)
}
Expand Down
17 changes: 17 additions & 0 deletions writers/ticker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package writers

import (
"time"
)

type TickerFunc func(interval time.Duration) (ch <-chan time.Time, done func())

func NewTicker(interval time.Duration) (<-chan time.Time, func()) {
if interval <= 0 {
return nil, nop
}
t := time.NewTicker(interval)
return t.C, t.Stop
}

func nop() {}