Skip to content

Commit

Permalink
[FIXED] Race condition when getting pull subscriptions in ordered con…
Browse files Browse the repository at this point in the history
…sumer (#1497)

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Jan 9, 2024
1 parent dce16a5 commit c067746
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
48 changes: 35 additions & 13 deletions jetstream/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,20 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
}
meta, err := msg.Metadata()
if err != nil {
c.errHandler(serial)(c.currentConsumer.subscriptions[""], err)
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(serial)(sub, err)
return
}
dseq := meta.Sequence.Consumer
if dseq != c.cursor.deliverSeq+1 {
c.errHandler(serial)(c.currentConsumer.subscriptions[""], errOrderedSequenceMismatch)
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(serial)(sub, errOrderedSequenceMismatch)
return
}
c.cursor.deliverSeq = dseq
Expand Down Expand Up @@ -141,7 +149,11 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
}
}
if err := c.reset(); err != nil {
c.errHandler(c.serial)(c.currentConsumer.subscriptions[""], err)
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(c.serial)(sub, err)
}
if c.stopAfter > 0 {
opts = opts[:len(opts)-2]
Expand All @@ -155,7 +167,11 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
opts = append(opts, consumeStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft))
}
if _, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil {
c.errHandler(c.serial)(c.currentConsumer.subscriptions[""], err)
sub, ok := c.currentConsumer.getSubscription("")
if !ok {
return
}
c.errHandler(c.serial)(sub, err)
}
case <-sub.done:
return
Expand Down Expand Up @@ -234,7 +250,11 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
func (s *orderedSubscription) Next() (Msg, error) {
for {
currentConsumer := s.consumer.currentConsumer
msg, err := currentConsumer.subscriptions[""].Next()
sub, ok := currentConsumer.getSubscription("")
if !ok {
return nil, ErrMsgIteratorClosed
}
msg, err := sub.Next()
if err != nil {
if errors.Is(err, ErrMsgIteratorClosed) {
s.Stop()
Expand Down Expand Up @@ -262,13 +282,13 @@ func (s *orderedSubscription) Next() (Msg, error) {
}
meta, err := msg.Metadata()
if err != nil {
s.consumer.errHandler(s.consumer.serial)(currentConsumer.subscriptions[""], err)
s.consumer.errHandler(s.consumer.serial)(sub, err)
continue
}
serial := serialNumberFromConsumer(meta.Consumer)
dseq := meta.Sequence.Consumer
if dseq != s.consumer.cursor.deliverSeq+1 {
s.consumer.errHandler(serial)(currentConsumer.subscriptions[""], errOrderedSequenceMismatch)
s.consumer.errHandler(serial)(sub, errOrderedSequenceMismatch)
continue
}
s.consumer.cursor.deliverSeq = dseq
Expand All @@ -278,12 +298,13 @@ func (s *orderedSubscription) Next() (Msg, error) {
}

func (s *orderedSubscription) Stop() {
s.consumer.currentConsumer.Lock()
defer s.consumer.currentConsumer.Unlock()
if s.consumer.currentConsumer.subscriptions[""] == nil {
sub, ok := s.consumer.currentConsumer.getSubscription("")
if !ok {
return
}
s.consumer.currentConsumer.subscriptions[""].Stop()
s.consumer.currentConsumer.Lock()
defer s.consumer.currentConsumer.Unlock()
sub.Stop()
close(s.done)
}

Expand Down Expand Up @@ -390,9 +411,10 @@ func (c *orderedConsumer) reset() error {
defer c.Unlock()
defer atomic.StoreUint32(&c.resetInProgress, 0)
if c.currentConsumer != nil {
sub, ok := c.currentConsumer.getSubscription("")
c.currentConsumer.Lock()
if c.currentConsumer.subscriptions[""] != nil {
c.currentConsumer.subscriptions[""].Stop()
if ok {
sub.Stop()
}
consName := c.currentConsumer.CachedInfo().Name
c.currentConsumer.Unlock()
Expand Down
7 changes: 7 additions & 0 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,3 +1042,10 @@ func retryWithBackoff(f func(int) (bool, error), opts backoffOpts) error {
}
return err
}

func (c *pullConsumer) getSubscription(id string) (*pullSubscription, bool) {
c.Lock()
defer c.Unlock()
sub, ok := c.subscriptions[id]
return sub, ok
}

0 comments on commit c067746

Please sign in to comment.