Skip to content

Commit

Permalink
[ADDED] Drain for jetstream consume methods
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Jan 10, 2024
1 parent c067746 commit 704cde6
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .golangci.yaml
Expand Up @@ -5,6 +5,9 @@ issues:
- linters:
- errcheck
text: "Unsubscribe"
- linters:
- errcheck
text: "Drain"
- linters:
- errcheck
text: "msg.Ack"
Expand Down
18 changes: 18 additions & 0 deletions jetstream/ordered.go
Expand Up @@ -49,6 +49,7 @@ type (
consumer *orderedConsumer
opts []PullMessagesOpt
done chan struct{}
closed uint32
}

cursor struct {
Expand Down Expand Up @@ -298,6 +299,9 @@ func (s *orderedSubscription) Next() (Msg, error) {
}

func (s *orderedSubscription) Stop() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return
}
sub, ok := s.consumer.currentConsumer.getSubscription("")
if !ok {
return
Expand All @@ -308,6 +312,20 @@ func (s *orderedSubscription) Stop() {
close(s.done)
}

func (s *orderedSubscription) Drain() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return
}
sub, ok := s.consumer.currentConsumer.getSubscription("")
if !ok {
return
}
s.consumer.currentConsumer.Lock()
defer s.consumer.currentConsumer.Unlock()
sub.Drain()
close(s.done)
}

// Fetch is used to retrieve up to a provided number of messages from a stream.
// This method will always send a single request and wait until either all messages are retrieved
// or context reaches its deadline.
Expand Down
114 changes: 91 additions & 23 deletions jetstream/pull.go
@@ -1,4 +1,4 @@
// Copyright 2022-2023 The NATS Authors
// Copyright 2022-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -30,14 +30,29 @@ import (
type (
// MessagesContext supports iterating over a messages on a stream.
MessagesContext interface {
// Next retreives next message on a stream. It will block until the next message is available.
// Next retreives next message on a stream. It will block until the next
// message is available.
Next() (Msg, error)
// Stop closes the iterator and cancels subscription.

// Stop unsubscribes from the stream and cancels subscription. Calling
// Next after calling Stop will return ErrMsgIteratorClosed error.
Stop()

// Drain unsubscribes from the stream and cancels subscription. All
// messages that are already in the buffer will be available on
// subsequent calls to Next. After the buffer is drained, Next will
// return ErrMsgIteratorClosed error.
Drain()
}

ConsumeContext interface {
// Stop unsubscribes from the stream and cancels subscription.
// No more messages will be received after calling this method.
Stop()

// Drain unsubscribes from the stream and cancels subscription.
// All messages that are already in the buffer will be processed in callback function.
Drain()
}

// MessageHandler is a handler function used as callback in [Consume]
Expand Down Expand Up @@ -97,7 +112,9 @@ type (
hbMonitor *hbMonitor
fetchInProgress uint32
closed uint32
draining uint32
done chan struct{}
drained chan struct{}
connStatusChanged chan nats.Status
fetchNext chan *pullRequest
consumeOpts *consumeOpts
Expand Down Expand Up @@ -240,6 +257,13 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
if err != nil {
return nil, err
}
sub.subscription.SetClosedHandler(func(sid string) func(string) {
return func(subject string) {
p.Lock()
defer p.Unlock()
delete(sub.consumer.subscriptions, sid)
}
}(sub.id))

sub.Lock()
// initial pull
Expand Down Expand Up @@ -352,6 +376,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
sub.resetPendingMsgs()
}
sub.Unlock()
case <-sub.done:
return
}
}
}()
Expand Down Expand Up @@ -438,6 +464,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
id: consumeID,
consumer: p,
done: make(chan struct{}, 1),
drained: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
fetchNext: make(chan *pullRequest, 1),
Expand All @@ -450,28 +477,42 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
p.Unlock()
return nil, err
}
sub.subscription.SetClosedHandler(func(sid string) func(string) {
return func(subject string) {
p.Lock()
defer p.Unlock()
if atomic.LoadUint32(&sub.draining) != 1 {
// if we're not draining, subscription can be closed as soon
// as closed handler is called
// otherwise, we need to wait until all messages are drained
// in Next
delete(p.subscriptions, sid)
}
close(msgs)
}
}(sub.id))

go func() {
<-sub.done
sub.cleanup()
}()
p.subscriptions[sub.id] = sub
p.Unlock()

go sub.pullMessages(subject)

go func() {
for {
status, ok := <-sub.connStatusChanged
if !ok {
select {
case status, ok := <-sub.connStatusChanged:
if !ok {
return
}
if status == nats.CONNECTED {
sub.errs <- errConnected
}
if status == nats.RECONNECTING {
sub.errs <- errDisconnected
}
case <-sub.done:
return
}
if status == nats.CONNECTED {
sub.errs <- errConnected
}
if status == nats.RECONNECTING {
sub.errs <- errDisconnected
}
}
}()

Expand All @@ -486,7 +527,7 @@ var (
func (s *pullSubscription) Next() (Msg, error) {
s.Lock()
defer s.Unlock()
if atomic.LoadUint32(&s.closed) == 1 {
if len(s.msgs) == 0 && (s.subscription == nil || !s.subscription.IsValid()) {
return nil, ErrMsgIteratorClosed
}
hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat)
Expand All @@ -506,8 +547,17 @@ func (s *pullSubscription) Next() (Msg, error) {
s.checkPending()
select {
case <-s.done:
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
continue
}
return nil, ErrMsgIteratorClosed
case msg := <-s.msgs:
case msg, ok := <-s.msgs:
if !ok {
// if msgs channel is closed, it means that subscription was either drained or stopped
delete(s.consumer.subscriptions, s.id)
return nil, ErrMsgIteratorClosed
}
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
}
Expand Down Expand Up @@ -650,6 +700,21 @@ func (s *pullSubscription) Stop() {
}
}

func (s *pullSubscription) Drain() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return
}
atomic.StoreUint32(&s.draining, 1)
close(s.done)
if s.consumeOpts.stopAfterMsgsLeft != nil {
if s.delivered >= s.consumeOpts.StopAfter {
close(s.consumeOpts.stopAfterMsgsLeft)
} else {
s.consumeOpts.stopAfterMsgsLeft <- s.consumeOpts.StopAfter - s.delivered
}
}
}

// Fetch sends a single request to retrieve given number of messages.
// It will wait up to provided expiry time if not all messages are available.
func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
Expand Down Expand Up @@ -834,18 +899,21 @@ func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor
}

func (s *pullSubscription) cleanup() {
s.consumer.Lock()
defer s.consumer.Unlock()
if s.subscription == nil {
s.Lock()
defer s.Unlock()
if s.subscription == nil || !s.subscription.IsValid() {
return
}
if s.hbMonitor != nil {
s.hbMonitor.Stop()
}
s.subscription.Unsubscribe()
close(s.connStatusChanged)
s.subscription = nil
delete(s.consumer.subscriptions, s.id)
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
s.subscription.Drain()
} else {
s.subscription.Unsubscribe()
}
atomic.StoreUint32(&s.closed, 1)
}

Expand Down
95 changes: 95 additions & 0 deletions jetstream/test/ordered_test.go
Expand Up @@ -408,6 +408,46 @@ func TestOrderedConsumerConsume(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
})

t.Run("drain mode", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
wg := &sync.WaitGroup{}
wg.Add(5)
publishTestMsgs(t, nc)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
wg.Done()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
cc.Drain()
wg.Wait()
})
}

func TestOrderedConsumerMessages(t *testing.T) {
Expand Down Expand Up @@ -822,6 +862,61 @@ func TestOrderedConsumerMessages(t *testing.T) {
t.Fatalf("Expected error: %v; got: %v", jetstream.ErrOrderedConsumerConcurrentRequests, err)
}
})

t.Run("drain mode", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

msgs := make([]jetstream.Msg, 0)
it, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

publishTestMsgs(t, nc)
go func() {
time.Sleep(100 * time.Millisecond)
it.Drain()
}()
for i := 0; i < len(testMsgs); i++ {
msg, err := it.Next()
if err != nil {
t.Fatal(err)
}
time.Sleep(50 * time.Millisecond)
msg.Ack()
msgs = append(msgs, msg)
}
_, err = it.Next()
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
}

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs))
}
})
}

func TestOrderedConsumerFetch(t *testing.T) {
Expand Down

0 comments on commit 704cde6

Please sign in to comment.