From 8ca7384510072dc7d257bcb26c330d5222cf64c5 Mon Sep 17 00:00:00 2001 From: rafaeleyng Date: Fri, 4 Sep 2020 17:52:21 -0300 Subject: [PATCH] fix a race condition that could cause the dataCh receive data before the closeCh received the signal to end the consume loop --- Makefile | 2 +- buffer.go | 23 +++++++++++++---------- buffer_test.go | 27 +++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 06d5fcb..b2e4f39 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,2 @@ test: - @go run github.com/onsi/ginkgo/ginkgo -keepGoing -progress -timeout 1m -race + @go run github.com/onsi/ginkgo/ginkgo -keepGoing -progress -timeout 1m -race --randomizeAllSpecs --randomizeSuites diff --git a/buffer.go b/buffer.go index 614339c..c565d87 100644 --- a/buffer.go +++ b/buffer.go @@ -45,22 +45,25 @@ func (buffer *Buffer) Flush() error { } } -// Close flushes the buffer and prevents it from being further used. The buffer -// cannot be used after it has been closed as all further operations will panic. +// Close flushes the buffer and prevents it from being further used. If it succeeds, +// the buffer cannot be used after it has been closed as all further operations will panic. func (buffer *Buffer) Close() error { - close(buffer.closeCh) + select { + case buffer.closeCh <- struct{}{}: + // noop + case <-time.After(buffer.options.CloseTimeout): + return ErrTimeout + } - var err error select { case <-buffer.doneCh: - err = nil + close(buffer.dataCh) + close(buffer.flushCh) + close(buffer.closeCh) + return nil case <-time.After(buffer.options.CloseTimeout): - err = ErrTimeout + return ErrTimeout } - - close(buffer.dataCh) - close(buffer.flushCh) - return err } func (buffer *Buffer) consume() { diff --git a/buffer_test.go b/buffer_test.go index ca9ebcf..f74055a 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -243,6 +243,33 @@ var _ = Describe("Buffer", func() { // assert Expect(err).To(MatchError(buffer.ErrTimeout)) }) + + It("allow Close to be called again if it fails", func() { + // arrange + flusher.Func = func() { time.Sleep(2 * time.Second) } + + sut := buffer.New( + buffer.WithSize(1), + buffer.WithFlusher(flusher), + buffer.WithCloseTimeout(time.Second), + ) + _ = sut.Push(1) + + // act + err := sut.Close() + + // assert + Expect(err).To(MatchError(buffer.ErrTimeout)) + + // arrange + time.Sleep(time.Second) + + // act + err = sut.Close() + + // assert + Expect(err).To(BeNil()) + }) }) })