Skip to content

Commit

Permalink
[ADDED] FetchBatch method to utilize non-blocking pull subscription r…
Browse files Browse the repository at this point in the history
…equests (#1211)

Co-authored-by: Waldemar Quevedo <wally@nats.io>
  • Loading branch information
piotrpio and wallyqs committed Feb 15, 2023
1 parent 2805753 commit edb105c
Show file tree
Hide file tree
Showing 3 changed files with 613 additions and 16 deletions.
240 changes: 227 additions & 13 deletions js.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2022 The NATS Authors
// Copyright 2020-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -2673,16 +2673,6 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
return nil, err
}

// Use the deadline of the context to base the expire times.
deadline, _ := ctx.Deadline()
ttl = time.Until(deadline)
checkCtxErr := func(err error) error {
if o.ctx == nil && err == context.DeadlineExceeded {
return ErrTimeout
}
return err
}

var (
msgs = make([]*Msg, 0, batch)
msg *Msg
Expand Down Expand Up @@ -2716,7 +2706,7 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
sendReq := func() error {
// The current deadline for the context will be used
// to set the expires TTL for a fetch request.
deadline, _ = ctx.Deadline()
deadline, _ := ctx.Deadline()
ttl = time.Until(deadline)

// Check if context has already been canceled or expired.
Expand Down Expand Up @@ -2766,11 +2756,235 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
}
// If there is at least a message added to msgs, then need to return OK and no error
if err != nil && len(msgs) == 0 {
return nil, checkCtxErr(err)
return nil, o.checkCtxErr(err)
}
return msgs, nil
}

// MessageBatch provides methods to retrieve messages consumed using [Subscribe.FetchBatch].
type MessageBatch interface {
// Messages returns a channel on which messages will be published.
Messages() <-chan *Msg

// Error returns an error encountered when fetching messages.
Error() error

// Done signals end of execution.
Done() <-chan struct{}
}

type messageBatch struct {
msgs chan *Msg
err error
done chan struct{}
}

func (mb *messageBatch) Messages() <-chan *Msg {
return mb.msgs
}

func (mb *messageBatch) Error() error {
return mb.err
}

func (mb *messageBatch) Done() <-chan struct{} {
return mb.done
}

// FetchBatch pulls a batch of messages from a stream for a pull consumer.
// Unlike [Subscription.Fetch], it is non blocking and returns [MessageBatch],
// allowing to retrieve incoming messages from a channel.
// The returned channel is always closed after all messages for a batch have been
// delivered by the server - it is safe to iterate over it using range.
//
// To avoid using default JetStream timeout as fetch expiry time, use [nats.MaxWait]
// or [nats.Context] (with deadline set).
//
// This method will not return error in case of pull request expiry (even if there are no messages).
// Any other error encountered when receiving messages will cause FetchBatch to stop receiving new messages.
func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, error) {
if sub == nil {
return nil, ErrBadSubscription
}
if batch < 1 {
return nil, ErrInvalidArg
}

var o pullOpts
for _, opt := range opts {
if err := opt.configurePull(&o); err != nil {
return nil, err
}
}
if o.ctx != nil && o.ttl != 0 {
return nil, ErrContextAndTimeout
}
sub.mu.Lock()
jsi := sub.jsi
// Reject if this is not a pull subscription. Note that sub.typ is SyncSubscription,
// so check for jsi.pull boolean instead.
if jsi == nil || !jsi.pull {
sub.mu.Unlock()
return nil, ErrTypeSubscription
}

nc := sub.conn
nms := sub.jsi.nms
rply := sub.jsi.deliver
js := sub.jsi.js
pmc := len(sub.mch) > 0

// All fetch requests have an expiration, in case of no explicit expiration
// then the default timeout of the JetStream context is used.
ttl := o.ttl
if ttl == 0 {
ttl = js.opts.wait
}
sub.mu.Unlock()

// Use the given context or setup a default one for the span
// of the pull batch request.
var (
ctx = o.ctx
cancel context.CancelFunc
cancelContext = true
)
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), ttl)
} else if _, hasDeadline := ctx.Deadline(); !hasDeadline {
// Prevent from passing the background context which will just block
// and cannot be canceled either.
if octx, ok := ctx.(ContextOpt); ok && octx.Context == context.Background() {
return nil, ErrNoDeadlineContext
}

// If the context did not have a deadline, then create a new child context
// that will use the default timeout from the JS context.
ctx, cancel = context.WithTimeout(ctx, ttl)
}
defer func() {
// only cancel the context here if we are sure the fetching goroutine has not been started yet
if cancel != nil && cancelContext {
cancel()
}
}()

// Check if context not done already before making the request.
select {
case <-ctx.Done():
if o.ctx != nil { // Timeout or Cancel triggered by context object option
return nil, ctx.Err()
} else { // Timeout triggered by timeout option
return nil, ErrTimeout
}
default:
}

result := &messageBatch{
msgs: make(chan *Msg, batch),
done: make(chan struct{}, 1),
}
var msg *Msg
for pmc && len(result.msgs) < batch {
// Check next msg with booleans that say that this is an internal call
// for a pull subscribe (so don't reject it) and don't wait if there
// are no messages.
msg, err := sub.nextMsgWithContext(ctx, true, false)
if err != nil {
if err == errNoMessages {
err = nil
}
result.err = err
break
}
// Check msg but just to determine if this is a user message
// or status message, however, we don't care about values of status
// messages at this point in the Fetch() call, so checkMsg can't
// return an error.
if usrMsg, _ := checkMsg(msg, false, false); usrMsg {
result.msgs <- msg
}
}
if len(result.msgs) == batch || result.err != nil {
close(result.msgs)
result.done <- struct{}{}
return result, nil
}

deadline, _ := ctx.Deadline()
ttl = time.Until(deadline)

// Make our request expiration a bit shorter than the current timeout.
expires := ttl
if ttl >= 20*time.Millisecond {
expires = ttl - 10*time.Millisecond
}

requestBatch := batch - len(result.msgs)
req := nextRequest{
Expires: expires,
Batch: requestBatch,
MaxBytes: o.maxBytes,
}
reqJSON, err := json.Marshal(req)
if err != nil {
close(result.msgs)
result.done <- struct{}{}
result.err = err
return result, nil
}
if err := nc.PublishRequest(nms, rply, reqJSON); err != nil {
if len(result.msgs) == 0 {
return nil, err
}
close(result.msgs)
result.done <- struct{}{}
result.err = err
return result, nil
}
cancelContext = false
go func() {
if cancel != nil {
defer cancel()
}
var requestMsgs int
for requestMsgs < requestBatch {
// Ask for next message and wait if there are no messages
msg, err = sub.nextMsgWithContext(ctx, true, true)
if err != nil {
break
}
var usrMsg bool

usrMsg, err = checkMsg(msg, true, false)
if err != nil {
if err == ErrTimeout {
err = nil
}
break
}
if usrMsg {
result.msgs <- msg
requestMsgs++
}
}
if err != nil {
result.err = o.checkCtxErr(err)
}
close(result.msgs)
result.done <- struct{}{}
}()
return result, nil
}

// checkCtxErr is used to determine whether ErrTimeout should be returned in case of context timeout
func (o *pullOpts) checkCtxErr(err error) error {
if o.ctx == nil && err == context.DeadlineExceeded {
return ErrTimeout
}
return err
}

func (js *js) getConsumerInfo(stream, consumer string) (*ConsumerInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), js.opts.wait)
defer cancel()
Expand Down
3 changes: 1 addition & 2 deletions js_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2012-2022 The NATS Authors
// Copyright 2012-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -116,7 +116,6 @@ func TestJetStreamOrderedConsumer(t *testing.T) {

// Create a sample asset.
msg := make([]byte, 1024*1024)
//lint:ignore SA1019 crypto/rand.Read is recommended after Go 1.20 but fine for this test.
rand.Read(msg)
msg = []byte(base64.StdEncoding.EncodeToString(msg))
mlen, sum := len(msg), sha256.Sum256(msg)
Expand Down
Loading

0 comments on commit edb105c

Please sign in to comment.