diff --git a/js.go b/js.go index d7cf8ea05..e82de7fe9 100644 --- a/js.go +++ b/js.go @@ -2770,13 +2770,13 @@ type MessageBatch interface { Error() error // Done signals end of execution. - Done() bool + Done() <-chan struct{} } type messageBatch struct { msgs chan *Msg err error - done bool + done chan struct{} } func (mb *messageBatch) Messages() <-chan *Msg { @@ -2787,7 +2787,7 @@ func (mb *messageBatch) Error() error { return mb.err } -func (mb *messageBatch) Done() bool { +func (mb *messageBatch) Done() <-chan struct{} { return mb.done } @@ -2882,6 +2882,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e result := &messageBatch{ msgs: make(chan *Msg, batch), + done: make(chan struct{}, 1), } var msg *Msg for pmc && len(result.msgs) < batch { @@ -2906,7 +2907,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } if len(result.msgs) == batch || result.err != nil { close(result.msgs) - result.done = true + result.done <- struct{}{} return result, nil } @@ -2928,7 +2929,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e reqJSON, err := json.Marshal(req) if err != nil { close(result.msgs) - result.done = true + result.done <- struct{}{} result.err = err return result, nil } @@ -2937,7 +2938,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e return nil, err } close(result.msgs) - result.done = true + result.done <- struct{}{} result.err = err return result, nil } @@ -2971,7 +2972,7 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e result.err = o.checkCtxErr(err) } close(result.msgs) - result.done = true + result.done <- struct{}{} }() return result, nil } diff --git a/test/js_test.go b/test/js_test.go index 826aa493c..879b8bc09 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1019,6 +1019,7 @@ func TestPullSubscribeFetchBatch(t *testing.T) { started.Add(3) errs := make(chan error, 3) go func() { + var err error r1, err = sub.FetchBatch(10) if err != nil { errs <- err @@ -1026,6 +1027,7 @@ func TestPullSubscribeFetchBatch(t *testing.T) { started.Done() }() go func() { + var err error r2, err = sub.FetchBatch(10) if err != nil { errs <- err @@ -1033,6 +1035,7 @@ func TestPullSubscribeFetchBatch(t *testing.T) { started.Done() }() go func() { + var err error r3, err = sub.FetchBatch(10) if err != nil { errs <- err @@ -1062,6 +1065,21 @@ func TestPullSubscribeFetchBatch(t *testing.T) { t.Fatalf("Timeout waiting for incoming messages") } } + select { + case <-r1.Done(): + case <-time.After(1 * time.Second): + t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r1") + } + select { + case <-r2.Done(): + case <-time.After(1 * time.Second): + t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r2") + } + select { + case <-r3.Done(): + case <-time.After(1 * time.Second): + t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r3") + } if r1.Error() != nil { t.Fatalf("Unexpected error: %s", r1.Error()) } @@ -1071,15 +1089,6 @@ func TestPullSubscribeFetchBatch(t *testing.T) { if r3.Error() != nil { t.Fatalf("Unexpected error: %s", r3.Error()) } - if !r1.Done() { - t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r1") - } - if !r2.Done() { - t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r2") - } - if !r3.Done() { - t.Fatalf("FetchBatch result channel should be closed after receiving all messages on r3") - } if msgsReceived != 30 { t.Fatalf("Expected %d messages; got: %d", 30, msgsReceived) }