From 3e47919825777d77b1cb56f89d0120cf04b916f3 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Wed, 10 Jan 2024 10:40:23 +0100 Subject: [PATCH 1/2] [ADDED] Drain for jetstream consume methods Signed-off-by: Piotr Piotrowski --- .golangci.yaml | 3 + go_test.mod | 14 ++-- go_test.sum | 29 ++++----- jetstream/ordered.go | 18 ++++++ jetstream/pull.go | 114 ++++++++++++++++++++++++++------- jetstream/test/ordered_test.go | 95 +++++++++++++++++++++++++++ jetstream/test/pull_test.go | 95 +++++++++++++++++++++++++++ nats.go | 6 ++ 8 files changed, 329 insertions(+), 45 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index be66189ed..fb548e50e 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -5,6 +5,9 @@ issues: - linters: - errcheck text: "Unsubscribe" + - linters: + - errcheck + text: "Drain" - linters: - errcheck text: "msg.Ack" diff --git a/go_test.mod b/go_test.mod index e6653b746..d28963c27 100644 --- a/go_test.mod +++ b/go_test.mod @@ -4,19 +4,19 @@ go 1.19 require ( github.com/golang/protobuf v1.4.2 - github.com/klauspost/compress v1.17.2 - github.com/nats-io/nats-server/v2 v2.10.0 + github.com/klauspost/compress v1.17.4 + github.com/nats-io/nats-server/v2 v2.10.7 github.com/nats-io/nkeys v0.4.6 github.com/nats-io/nuid v1.0.1 go.uber.org/goleak v1.2.1 - golang.org/x/text v0.13.0 + golang.org/x/text v0.14.0 google.golang.org/protobuf v1.23.0 ) require ( github.com/minio/highwayhash v1.0.2 // indirect - github.com/nats-io/jwt/v2 v2.5.2 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/sys v0.13.0 // indirect - golang.org/x/time v0.3.0 // indirect + github.com/nats-io/jwt/v2 v2.5.3 // indirect + golang.org/x/crypto v0.16.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/time v0.5.0 // indirect ) diff --git a/go_test.sum b/go_test.sum index 933432dc6..38fe6ef6f 100644 --- a/go_test.sum +++ b/go_test.sum @@ -10,14 +10,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= -github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= -github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= -github.com/nats-io/nats-server/v2 v2.10.0 h1:rcU++Hzo+wARxtJugrV3J5z5iGdHeVG8tT8Chb3bKDg= -github.com/nats-io/nats-server/v2 v2.10.0/go.mod h1:3PMvMSu2cuK0J9YInRLWdFpFsswKKGUS77zVSAudRto= +github.com/nats-io/jwt/v2 v2.5.3 h1:/9SWvzc6hTfamcgXJ3uYRpgj+QuY2aLNqRiqrKcrpEo= +github.com/nats-io/jwt/v2 v2.5.3/go.mod h1:iysuPemFcc7p4IoYots3IuELSI4EDe9Y0bQMe+I3Bf4= +github.com/nats-io/nats-server/v2 v2.10.7 h1:f5VDy+GMu7JyuFA0Fef+6TfulfCs5nBTgq7MMkFJx5Y= +github.com/nats-io/nats-server/v2 v2.10.7/go.mod h1:V2JHOvPiPdtfDXTuEUsthUnCvSDeFrK4Xn9hRo6du7c= github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -26,16 +26,15 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/jetstream/ordered.go b/jetstream/ordered.go index b2373fd57..e4e8bde1f 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -49,6 +49,7 @@ type ( consumer *orderedConsumer opts []PullMessagesOpt done chan struct{} + closed uint32 } cursor struct { @@ -298,6 +299,9 @@ func (s *orderedSubscription) Next() (Msg, error) { } func (s *orderedSubscription) Stop() { + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return + } sub, ok := s.consumer.currentConsumer.getSubscription("") if !ok { return @@ -308,6 +312,20 @@ func (s *orderedSubscription) Stop() { close(s.done) } +func (s *orderedSubscription) Drain() { + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return + } + sub, ok := s.consumer.currentConsumer.getSubscription("") + if !ok { + return + } + s.consumer.currentConsumer.Lock() + defer s.consumer.currentConsumer.Unlock() + sub.Drain() + close(s.done) +} + // Fetch is used to retrieve up to a provided number of messages from a stream. // This method will always send a single request and wait until either all messages are retrieved // or context reaches its deadline. diff --git a/jetstream/pull.go b/jetstream/pull.go index f7fb12bd7..9e45e94f0 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 The NATS Authors +// Copyright 2022-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -30,14 +30,29 @@ import ( type ( // MessagesContext supports iterating over a messages on a stream. MessagesContext interface { - // Next retreives next message on a stream. It will block until the next message is available. + // Next retreives next message on a stream. It will block until the next + // message is available. Next() (Msg, error) - // Stop closes the iterator and cancels subscription. + + // Stop unsubscribes from the stream and cancels subscription. Calling + // Next after calling Stop will return ErrMsgIteratorClosed error. Stop() + + // Drain unsubscribes from the stream and cancels subscription. All + // messages that are already in the buffer will be available on + // subsequent calls to Next. After the buffer is drained, Next will + // return ErrMsgIteratorClosed error. + Drain() } ConsumeContext interface { + // Stop unsubscribes from the stream and cancels subscription. + // No more messages will be received after calling this method. Stop() + + // Drain unsubscribes from the stream and cancels subscription. + // All messages that are already in the buffer will be processed in callback function. + Drain() } // MessageHandler is a handler function used as callback in [Consume] @@ -97,7 +112,9 @@ type ( hbMonitor *hbMonitor fetchInProgress uint32 closed uint32 + draining uint32 done chan struct{} + drained chan struct{} connStatusChanged chan nats.Status fetchNext chan *pullRequest consumeOpts *consumeOpts @@ -240,6 +257,13 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( if err != nil { return nil, err } + sub.subscription.SetClosedHandler(func(sid string) func(string) { + return func(subject string) { + p.Lock() + defer p.Unlock() + delete(sub.consumer.subscriptions, sid) + } + }(sub.id)) sub.Lock() // initial pull @@ -352,6 +376,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.resetPendingMsgs() } sub.Unlock() + case <-sub.done: + return } } }() @@ -438,6 +464,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error id: consumeID, consumer: p, done: make(chan struct{}, 1), + drained: make(chan struct{}, 1), msgs: msgs, errs: make(chan error, 1), fetchNext: make(chan *pullRequest, 1), @@ -450,11 +477,21 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error p.Unlock() return nil, err } + sub.subscription.SetClosedHandler(func(sid string) func(string) { + return func(subject string) { + p.Lock() + defer p.Unlock() + if atomic.LoadUint32(&sub.draining) != 1 { + // if we're not draining, subscription can be closed as soon + // as closed handler is called + // otherwise, we need to wait until all messages are drained + // in Next + delete(p.subscriptions, sid) + } + close(msgs) + } + }(sub.id)) - go func() { - <-sub.done - sub.cleanup() - }() p.subscriptions[sub.id] = sub p.Unlock() @@ -462,16 +499,20 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error go func() { for { - status, ok := <-sub.connStatusChanged - if !ok { + select { + case status, ok := <-sub.connStatusChanged: + if !ok { + return + } + if status == nats.CONNECTED { + sub.errs <- errConnected + } + if status == nats.RECONNECTING { + sub.errs <- errDisconnected + } + case <-sub.done: return } - if status == nats.CONNECTED { - sub.errs <- errConnected - } - if status == nats.RECONNECTING { - sub.errs <- errDisconnected - } } }() @@ -486,7 +527,7 @@ var ( func (s *pullSubscription) Next() (Msg, error) { s.Lock() defer s.Unlock() - if atomic.LoadUint32(&s.closed) == 1 { + if len(s.msgs) == 0 && (s.subscription == nil || !s.subscription.IsValid()) { return nil, ErrMsgIteratorClosed } hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat) @@ -506,8 +547,17 @@ func (s *pullSubscription) Next() (Msg, error) { s.checkPending() select { case <-s.done: + drainMode := atomic.LoadUint32(&s.draining) == 1 + if drainMode { + continue + } return nil, ErrMsgIteratorClosed - case msg := <-s.msgs: + case msg, ok := <-s.msgs: + if !ok { + // if msgs channel is closed, it means that subscription was either drained or stopped + delete(s.consumer.subscriptions, s.id) + return nil, ErrMsgIteratorClosed + } if hbMonitor != nil { hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) } @@ -650,6 +700,21 @@ func (s *pullSubscription) Stop() { } } +func (s *pullSubscription) Drain() { + if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return + } + atomic.StoreUint32(&s.draining, 1) + close(s.done) + if s.consumeOpts.stopAfterMsgsLeft != nil { + if s.delivered >= s.consumeOpts.StopAfter { + close(s.consumeOpts.stopAfterMsgsLeft) + } else { + s.consumeOpts.stopAfterMsgsLeft <- s.consumeOpts.StopAfter - s.delivered + } + } +} + // Fetch sends a single request to retrieve given number of messages. // It will wait up to provided expiry time if not all messages are available. func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { @@ -834,18 +899,21 @@ func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor } func (s *pullSubscription) cleanup() { - s.consumer.Lock() - defer s.consumer.Unlock() - if s.subscription == nil { + s.Lock() + defer s.Unlock() + if s.subscription == nil || !s.subscription.IsValid() { return } if s.hbMonitor != nil { s.hbMonitor.Stop() } - s.subscription.Unsubscribe() close(s.connStatusChanged) - s.subscription = nil - delete(s.consumer.subscriptions, s.id) + drainMode := atomic.LoadUint32(&s.draining) == 1 + if drainMode { + s.subscription.Drain() + } else { + s.subscription.Unsubscribe() + } atomic.StoreUint32(&s.closed, 1) } diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index f94f9c0e1..c8b529f16 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -408,6 +408,46 @@ func TestOrderedConsumerConsume(t *testing.T) { } time.Sleep(10 * time.Millisecond) }) + + t.Run("drain mode", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + wg := &sync.WaitGroup{} + wg.Add(5) + publishTestMsgs(t, nc) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + cc.Drain() + wg.Wait() + }) } func TestOrderedConsumerMessages(t *testing.T) { @@ -822,6 +862,61 @@ func TestOrderedConsumerMessages(t *testing.T) { t.Fatalf("Expected error: %v; got: %v", jetstream.ErrOrderedConsumerConcurrentRequests, err) } }) + + t.Run("drain mode", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]jetstream.Msg, 0) + it, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + go func() { + time.Sleep(100 * time.Millisecond) + it.Drain() + }() + for i := 0; i < len(testMsgs); i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + time.Sleep(50 * time.Millisecond) + msg.Ack() + msgs = append(msgs, msg) + } + _, err = it.Next() + if !errors.Is(err, jetstream.ErrMsgIteratorClosed) { + t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err) + } + + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs)) + } + }) } func TestOrderedConsumerFetch(t *testing.T) { diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 20354128c..6e92cfeea 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -1358,6 +1358,61 @@ func TestPullConsumerMessages(t *testing.T) { } } }) + + t.Run("drain mode", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]jetstream.Msg, 0) + it, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + go func() { + time.Sleep(100 * time.Millisecond) + it.Drain() + }() + for i := 0; i < len(testMsgs); i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + time.Sleep(50 * time.Millisecond) + msg.Ack() + msgs = append(msgs, msg) + } + _, err = it.Next() + if !errors.Is(err, jetstream.ErrMsgIteratorClosed) { + t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err) + } + + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs)) + } + }) } func TestPullConsumerConsume(t *testing.T) { @@ -2040,6 +2095,46 @@ func TestPullConsumerConsume(t *testing.T) { publishTestMsgs(t, nc) wg.Wait() }) + + t.Run("drain mode", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + wg := &sync.WaitGroup{} + wg.Add(5) + publishTestMsgs(t, nc) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + cc.Drain() + wg.Wait() + }) } func TestPullConsumerConsume_WithCluster(t *testing.T) { diff --git a/nats.go b/nats.go index da13692fd..77d9489c1 100644 --- a/nats.go +++ b/nats.go @@ -4298,6 +4298,12 @@ func (nc *Conn) removeSub(s *Subscription) { } } + if s.typ != AsyncSubscription { + done := s.pDone + if done != nil { + done(s.Subject) + } + } // Mark as invalid s.closed = true if s.pCond != nil { From 595538e14ec5e9eff5eccbda7ce17612cd9074ab Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Wed, 10 Jan 2024 12:15:58 +0100 Subject: [PATCH 2/2] Added tests for Stop() Signed-off-by: Piotr Piotrowski --- jetstream/pull.go | 11 +++- jetstream/test/jetstream_test.go | 14 +++-- jetstream/test/pull_test.go | 103 +++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 7 deletions(-) diff --git a/jetstream/pull.go b/jetstream/pull.go index 9e45e94f0..3ea7092b0 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -36,6 +36,7 @@ type ( // Stop unsubscribes from the stream and cancels subscription. Calling // Next after calling Stop will return ErrMsgIteratorClosed error. + // All messages that are already in the buffer are discarded. Stop() // Drain unsubscribes from the stream and cancels subscription. All @@ -48,6 +49,7 @@ type ( ConsumeContext interface { // Stop unsubscribes from the stream and cancels subscription. // No more messages will be received after calling this method. + // All messages that are already in the buffer are discarded. Stop() // Drain unsubscribes from the stream and cancels subscription. @@ -261,7 +263,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( return func(subject string) { p.Lock() defer p.Unlock() - delete(sub.consumer.subscriptions, sid) + delete(p.subscriptions, sid) + atomic.CompareAndSwapUint32(&sub.draining, 1, 0) } }(sub.id)) @@ -527,7 +530,9 @@ var ( func (s *pullSubscription) Next() (Msg, error) { s.Lock() defer s.Unlock() - if len(s.msgs) == 0 && (s.subscription == nil || !s.subscription.IsValid()) { + drainMode := atomic.LoadUint32(&s.draining) == 1 + closed := atomic.LoadUint32(&s.closed) == 1 + if closed && !drainMode { return nil, ErrMsgIteratorClosed } hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat) @@ -556,6 +561,7 @@ func (s *pullSubscription) Next() (Msg, error) { if !ok { // if msgs channel is closed, it means that subscription was either drained or stopped delete(s.consumer.subscriptions, s.id) + atomic.CompareAndSwapUint32(&s.draining, 1, 0) return nil, ErrMsgIteratorClosed } if hbMonitor != nil { @@ -907,7 +913,6 @@ func (s *pullSubscription) cleanup() { if s.hbMonitor != nil { s.hbMonitor.Stop() } - close(s.connStatusChanged) drainMode := atomic.LoadUint32(&s.draining) == 1 if drainMode { s.subscription.Drain() diff --git a/jetstream/test/jetstream_test.go b/jetstream/test/jetstream_test.go index 410c7f7bc..f5c9c8ee1 100644 --- a/jetstream/test/jetstream_test.go +++ b/jetstream/test/jetstream_test.go @@ -426,10 +426,16 @@ func TestCreateStreamMirrorCrossDomains(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - - if lStream.CachedInfo().State.Msgs != 3 { - t.Fatalf("Expected 3 msgs in stream; got: %d", lStream.CachedInfo().State.Msgs) - } + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + info, err := lStream.Info(ctx) + if err != nil { + return fmt.Errorf("Unexpected error when getting stream info: %v", err) + } + if info.State.Msgs != 3 { + return fmt.Errorf("Expected 3 msgs in stream; got: %d", lStream.CachedInfo().State.Msgs) + } + return nil + }) rjs, err := jetstream.NewWithDomain(lnc, "HUB") if err != nil { diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 6e92cfeea..4f4e85dbe 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -1359,6 +1359,61 @@ func TestPullConsumerMessages(t *testing.T) { } }) + t.Run("no messages received after stop", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]jetstream.Msg, 0) + it, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + go func() { + time.Sleep(100 * time.Millisecond) + it.Stop() + }() + for i := 0; i < 2; i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + time.Sleep(80 * time.Millisecond) + msg.Ack() + msgs = append(msgs, msg) + } + _, err = it.Next() + if !errors.Is(err, jetstream.ErrMsgIteratorClosed) { + t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err) + } + + if len(msgs) != 2 { + t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs)) + } + }) + t.Run("drain mode", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -2096,6 +2151,54 @@ func TestPullConsumerConsume(t *testing.T) { wg.Wait() }) + t.Run("no messages received after stop", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + wg := &sync.WaitGroup{} + wg.Add(2) + publishTestMsgs(t, nc) + msgs := make([]jetstream.Msg, 0) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(80 * time.Millisecond) + msg.Ack() + msgs = append(msgs, msg) + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + cc.Stop() + wg.Wait() + // wait for some time to make sure no new messages are received + time.Sleep(100 * time.Millisecond) + + if len(msgs) != 2 { + t.Fatalf("Unexpected received message count after stop; want 2; got %d", len(msgs)) + } + }) + t.Run("drain mode", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv)