Skip to content

Commit

Permalink
Fix race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrpio committed Feb 15, 2023
1 parent 51374f6 commit 1b84e26
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
15 changes: 8 additions & 7 deletions js.go
Expand Up @@ -2770,13 +2770,13 @@ type MessageBatch interface {
Error() error

// Done signals end of execution.
Done() bool
Done() <-chan struct{}
}

type messageBatch struct {
msgs chan *Msg
err error
done bool
done chan struct{}
}

func (mb *messageBatch) Messages() <-chan *Msg {
Expand All @@ -2787,7 +2787,7 @@ func (mb *messageBatch) Error() error {
return mb.err
}

func (mb *messageBatch) Done() bool {
func (mb *messageBatch) Done() <-chan struct{} {
return mb.done
}

Expand Down Expand Up @@ -2882,6 +2882,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e

result := &messageBatch{
msgs: make(chan *Msg, batch),
done: make(chan struct{}, 1),
}
var msg *Msg
for pmc && len(result.msgs) < batch {
Expand All @@ -2906,7 +2907,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
}
if len(result.msgs) == batch || result.err != nil {
close(result.msgs)
result.done = true
result.done <- struct{}{}
return result, nil
}

Expand All @@ -2928,7 +2929,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
reqJSON, err := json.Marshal(req)
if err != nil {
close(result.msgs)
result.done = true
result.done <- struct{}{}
result.err = err
return result, nil
}
Expand All @@ -2937,7 +2938,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
return nil, err
}
close(result.msgs)
result.done = true
result.done <- struct{}{}
result.err = err
return result, nil
}
Expand Down Expand Up @@ -2971,7 +2972,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
result.err = o.checkCtxErr(err)
}
close(result.msgs)
result.done = true
result.done <- struct{}{}
}()
return result, nil
}
Expand Down
27 changes: 18 additions & 9 deletions test/js_test.go
Expand Up @@ -1019,20 +1019,23 @@ func TestPullSubscribeFetchBatch(t *testing.T) {
started.Add(3)
errs := make(chan error, 3)
go func() {
var err error
r1, err = sub.FetchBatch(10)
if err != nil {
errs <- err
}
started.Done()
}()
go func() {
var err error
r2, err = sub.FetchBatch(10)
if err != nil {
errs <- err
}
started.Done()
}()
go func() {
var err error
r3, err = sub.FetchBatch(10)
if err != nil {
errs <- err
Expand Down Expand Up @@ -1062,6 +1065,21 @@ func TestPullSubscribeFetchBatch(t *testing.T) {
t.Fatalf("Timeout waiting for incoming messages")
}
}
select {
case <-r1.Done():
case <-time.After(1 * time.Second):
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r1")
}
select {
case <-r2.Done():
case <-time.After(1 * time.Second):
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r2")
}
select {
case <-r3.Done():
case <-time.After(1 * time.Second):
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r3")
}
if r1.Error() != nil {
t.Fatalf("Unexpected error: %s", r1.Error())
}
Expand All @@ -1071,15 +1089,6 @@ func TestPullSubscribeFetchBatch(t *testing.T) {
if r3.Error() != nil {
t.Fatalf("Unexpected error: %s", r3.Error())
}
if !r1.Done() {
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r1")
}
if !r2.Done() {
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r2")
}
if !r3.Done() {
t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r3")
}
if msgsReceived != 30 {
t.Fatalf("Expected %d messages; got: %d", 30, msgsReceived)
}
Expand Down

0 comments on commit 1b84e26

Please sign in to comment.