diff --git a/writers/batchwriter/batchwriter.go b/writers/batchwriter/batchwriter.go index 1310b3956d..1ff75ca8f7 100644 --- a/writers/batchwriter/batchwriter.go +++ b/writers/batchwriter/batchwriter.go @@ -122,7 +122,10 @@ func (w *BatchWriter) Close(context.Context) error { func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) { sizeBytes := int64(0) resources := make([]*message.WriteInsert, 0, w.batchSize) - tick := timer(w.batchTimeout) + tick, tickClose := timer(w.batchTimeout) + defer func() { + tickClose() + }() for { select { case r, ok := <-ch: @@ -145,7 +148,8 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m w.flushTable(ctx, tableName, resources) resources, sizeBytes = resources[:0], 0 } - tick = timer(w.batchTimeout) + tickClose() + tick, tickClose = timer(w.batchTimeout) case done := <-flush: if len(resources) > 0 { w.flushTable(ctx, tableName, resources) @@ -325,9 +329,14 @@ func (w *BatchWriter) startWorker(ctx context.Context, msg *message.WriteInsert) return nil } -func timer(timeout time.Duration) <-chan time.Time { +func timer(timeout time.Duration) (<-chan time.Time, func()) { if timeout == 0 { - return nil + return nil, func() {} + } + t := time.NewTimer(timeout) + return t.C, func() { + if !t.Stop() { + <-t.C + } } - return time.After(timeout) } diff --git a/writers/mixedbatchwriter/mixedbatchwriter.go b/writers/mixedbatchwriter/mixedbatchwriter.go index 4f76a12fe6..719f8eebbc 100644 --- a/writers/mixedbatchwriter/mixedbatchwriter.go +++ b/writers/mixedbatchwriter/mixedbatchwriter.go @@ -17,7 +17,7 @@ type Client interface { DeleteStaleBatch(ctx context.Context, messages message.WriteDeleteStales) error } -type timerFn func(timeout time.Duration) <-chan time.Time +type timerFn func(timeout time.Duration) (<-chan time.Time, func()) type MixedBatchWriter struct { client Client @@ -116,7 +116,10 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri } prevMsgType := writers.MsgTypeUnset var err error - tick := w.timerFn(w.batchTimeout) + tick, tickClose := w.timerFn(w.batchTimeout) + defer func() { + tickClose() + }() loop: for { select { @@ -149,7 +152,8 @@ loop: return err } prevMsgType = writers.MsgTypeUnset - tick = w.timerFn(w.batchTimeout) + tickClose() + tick, tickClose = w.timerFn(w.batchTimeout) } } return flush(prevMsgType) @@ -216,9 +220,14 @@ func (m *insertBatchManager) flush(ctx context.Context) error { return nil } -func timer(timeout time.Duration) <-chan time.Time { +func timer(timeout time.Duration) (<-chan time.Time, func()) { if timeout == 0 { - return nil + return nil, func() {} + } + t := time.NewTimer(timeout) + return t.C, func() { + if !t.Stop() { + <-t.C + } } - return time.After(timeout) } diff --git a/writers/mixedbatchwriter/mixedbatchwriter_test.go b/writers/mixedbatchwriter/mixedbatchwriter_test.go index ade2d52558..78132c6c67 100644 --- a/writers/mixedbatchwriter/mixedbatchwriter_test.go +++ b/writers/mixedbatchwriter/mixedbatchwriter_test.go @@ -243,13 +243,13 @@ func TestMixedBatchWriterTimeout(t *testing.T) { wr, err := New(client, WithBatchSize(1000), WithBatchSizeBytes(1000000), - withTimerFn(func(_ time.Duration) <-chan time.Time { + withTimerFn(func(_ time.Duration) (<-chan time.Time, func()) { c := make(chan time.Time) go func() { <-triggerTimeout c <- time.Now() }() - return c + return c, func() {} }), ) if err != nil { diff --git a/writers/streamingbatchwriter/mocktimer_test.go b/writers/streamingbatchwriter/mocktimer_test.go index a4166de499..442861b3ca 100644 --- a/writers/streamingbatchwriter/mocktimer_test.go +++ b/writers/streamingbatchwriter/mocktimer_test.go @@ -6,8 +6,8 @@ type mockTimer struct { expire chan time.Time } -func (t *mockTimer) timer(time.Duration) <-chan time.Time { - return t.expire +func (t *mockTimer) timer(time.Duration) (<-chan time.Time, func()) { + return t.expire, func() {} } func newMockTimer() (timerFn, chan time.Time) { diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go index 14f1ce3a78..565626a579 100644 --- a/writers/streamingbatchwriter/streamingbatchwriter.go +++ b/writers/streamingbatchwriter/streamingbatchwriter.go @@ -64,7 +64,7 @@ type StreamingBatchWriter struct { timerFn timerFn } -type timerFn func(timeout time.Duration) <-chan time.Time +type timerFn func(timeout time.Duration) (<-chan time.Time, func()) // Assert at compile-time that StreamingBatchWriter implements the Writer interface var _ writers.Writer = (*StreamingBatchWriter)(nil) @@ -345,7 +345,10 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, } defer closeFlush() - tick := s.timerFn(s.batchTimeout) + tick, tickClose := s.timerFn(s.batchTimeout) + defer func() { + tickClose() + }() for { select { case r, ok := <-s.ch: @@ -370,7 +373,8 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, if sizeRows > 0 { closeFlush() } - tick = s.timerFn(s.batchTimeout) + tickClose() + tick, tickClose = s.timerFn(s.batchTimeout) case done := <-s.flush: if sizeRows > 0 { closeFlush() @@ -380,9 +384,14 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, } } -func timer(timeout time.Duration) <-chan time.Time { +func timer(timeout time.Duration) (<-chan time.Time, func()) { if timeout == 0 { - return nil + return nil, func() {} + } + t := time.NewTimer(timeout) + return t.C, func() { + if !t.Stop() { + <-t.C + } } - return time.After(timeout) }