diff --git a/js.go b/js.go index abd71c7ff..5f152eea6 100644 --- a/js.go +++ b/js.go @@ -150,6 +150,16 @@ func (nc *Conn) JetStream(opts ...JSOpt) (JetStreamContext, error) { return js, nil } +// JSMOpt configures a JetStream manager. +type JSMOpt interface { + configureJSManager(opts *jsmOpts) error +} + +type jsmOpts struct { + ctx context.Context + ttl time.Duration +} + // JSOpt configures a JetStream context. type JSOpt interface { configureJSContext(opts *js) error @@ -340,6 +350,11 @@ func (ttl MaxWait) configureJSContext(js *js) error { return nil } +func (ttl MaxWait) configureJSManager(opt *jsmOpts) error { + opt.ttl = time.Duration(ttl) + return nil +} + // AckWait sets the maximum amount of time we will wait for an ack. type AckWait time.Duration @@ -363,6 +378,11 @@ func (ctx ContextOpt) configurePublish(opts *pubOpts) error { return nil } +func (ctx ContextOpt) configureJSManager(opts *jsmOpts) error { + opts.ctx = ctx + return nil +} + // Context returns an option that can be used to configure a context. func Context(ctx context.Context) ContextOpt { return ContextOpt{ctx} diff --git a/jsm.go b/jsm.go index 1120e09f7..948a90c21 100644 --- a/jsm.go +++ b/jsm.go @@ -41,10 +41,10 @@ type JetStreamManager interface { PurgeStream(name string) error // StreamsInfo can be used to retrieve a list of StreamInfo objects. - StreamsInfo(ctx context.Context) <-chan *StreamInfo + StreamsInfo(opts ...JSMOpt) <-chan *StreamInfo // StreamNames is used to retrieve a list of Stream names. - StreamNames(ctx context.Context) <-chan string + StreamNames(opts ...JSMOpt) <-chan string // GetMsg retrieves a raw stream message stored in JetStream by sequence number. GetMsg(name string, seq uint64) (*RawStreamMsg, error) @@ -62,10 +62,10 @@ type JetStreamManager interface { ConsumerInfo(stream, name string) (*ConsumerInfo, error) // ConsumersInfo is used to retrieve a list of ConsumerInfo objects. - ConsumersInfo(ctx context.Context, stream string) <-chan *ConsumerInfo + ConsumersInfo(stream string, opts ...JSMOpt) <-chan *ConsumerInfo // ConsumerNames is used to retrieve a list of Consumer names. - ConsumerNames(ctx context.Context, stream string) <-chan string + ConsumerNames(stream string, opts ...JSMOpt) <-chan string // AccountInfo retrieves info about the JetStream usage from an account. AccountInfo() (*AccountInfo, error) @@ -286,10 +286,25 @@ type consumerLister struct { } // ConsumersInfo returns a receive only channel to iterate on the consumers info. -func (js *js) ConsumersInfo(ctx context.Context, stream string) <-chan *ConsumerInfo { +func (js *js) ConsumersInfo(stream string, opts ...JSMOpt) <-chan *ConsumerInfo { + var o jsmOpts + if len(opts) > 0 { + for _, opt := range opts { + if err := opt.configureJSManager(&o); err != nil { + return nil + } + } + } + // Check for option collisions. Right now just timeout and context. + if o.ctx != nil && o.ttl != 0 { + return nil + } + if o.ttl == 0 && o.ctx == nil { + o.ttl = js.wait + } var cancel context.CancelFunc - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), js.wait) + if o.ctx == nil && o.ttl > 0 { + o.ctx, cancel = context.WithTimeout(context.Background(), o.ttl) } ach := make(chan *ConsumerInfo) @@ -305,7 +320,7 @@ func (js *js) ConsumersInfo(ctx context.Context, stream string) <-chan *Consumer for _, info := range cl.Page() { select { case ach <- info: - case <-ctx.Done(): + case <-o.ctx.Done(): return } } @@ -442,10 +457,25 @@ func (c *consumerNamesLister) Err() error { } // ConsumerNames is used to retrieve a list of Consumer names. -func (js *js) ConsumerNames(ctx context.Context, stream string) <-chan string { +func (js *js) ConsumerNames(stream string, opts ...JSMOpt) <-chan string { + var o jsmOpts + if len(opts) > 0 { + for _, opt := range opts { + if err := opt.configureJSManager(&o); err != nil { + return nil + } + } + } + // Check for option collisions. Right now just timeout and context. + if o.ctx != nil && o.ttl != 0 { + return nil + } + if o.ttl == 0 && o.ctx == nil { + o.ttl = js.wait + } var cancel context.CancelFunc - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), js.wait) + if o.ctx == nil && o.ttl > 0 { + o.ctx, cancel = context.WithTimeout(context.Background(), o.ttl) } ch := make(chan string) @@ -457,11 +487,12 @@ func (js *js) ConsumerNames(ctx context.Context, stream string) <-chan string { } }() defer close(ch) + for l.Next() { for _, info := range l.Page() { select { case ch <- info: - case <-ctx.Done(): + case <-o.ctx.Done(): return } } @@ -763,10 +794,25 @@ type streamLister struct { } // StreamsInfo returns a receive only channel to iterate on the streams. -func (js *js) StreamsInfo(ctx context.Context) <-chan *StreamInfo { +func (js *js) StreamsInfo(opts ...JSMOpt) <-chan *StreamInfo { + var o jsmOpts + if len(opts) > 0 { + for _, opt := range opts { + if err := opt.configureJSManager(&o); err != nil { + return nil + } + } + } + // Check for option collisions. Right now just timeout and context. + if o.ctx != nil && o.ttl != 0 { + return nil + } + if o.ttl == 0 && o.ctx == nil { + o.ttl = js.wait + } var cancel context.CancelFunc - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), js.wait) + if o.ctx == nil && o.ttl > 0 { + o.ctx, cancel = context.WithTimeout(context.Background(), o.ttl) } ach := make(chan *StreamInfo) @@ -782,7 +828,7 @@ func (js *js) StreamsInfo(ctx context.Context) <-chan *StreamInfo { for _, info := range sl.Page() { select { case ach <- info: - case <-ctx.Done(): + case <-o.ctx.Done(): return } } @@ -906,10 +952,25 @@ func (l *streamNamesLister) Err() error { } // StreamNames is used to retrieve a list of Stream names. -func (js *js) StreamNames(ctx context.Context) <-chan string { +func (js *js) StreamNames(opts ...JSMOpt) <-chan string { + var o jsmOpts + if len(opts) > 0 { + for _, opt := range opts { + if err := opt.configureJSManager(&o); err != nil { + return nil + } + } + } + // Check for option collisions. Right now just timeout and context. + if o.ctx != nil && o.ttl != 0 { + return nil + } + if o.ttl == 0 && o.ctx == nil { + o.ttl = js.wait + } var cancel context.CancelFunc - if ctx == nil { - ctx, cancel = context.WithTimeout(context.Background(), js.wait) + if o.ctx == nil && o.ttl > 0 { + o.ctx, cancel = context.WithTimeout(context.Background(), o.ttl) } ch := make(chan string) @@ -925,7 +986,7 @@ func (js *js) StreamNames(ctx context.Context) <-chan string { for _, info := range l.Page() { select { case ch <- info: - case <-ctx.Done(): + case <-o.ctx.Done(): return } } diff --git a/test/js_test.go b/test/js_test.go index ba4bc5a7a..e895a38f1 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -254,9 +254,7 @@ func TestJetStreamSubscribe(t *testing.T) { expectConsumers := func(t *testing.T, expected int) { t.Helper() var count int - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - for range js.ConsumersInfo(ctx, "TEST") { + for range js.ConsumersInfo("TEST") { count++ } if count != expected { @@ -1100,9 +1098,7 @@ func TestJetStreamManagement(t *testing.T) { t.Run("list consumer names", func(t *testing.T) { var names []string - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - for name := range js.ConsumerNames(ctx, "foo") { + for name := range js.ConsumerNames("foo") { names = append(names, name) } if got, want := len(names), 1; got != want { @@ -1113,8 +1109,7 @@ func TestJetStreamManagement(t *testing.T) { t.Run("streams info", func(t *testing.T) { var i int expected := "foo" - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - for stream := range js.StreamsInfo(ctx) { + for stream := range js.StreamsInfo(nats.MaxWait(3 * time.Second)) { i++ got := stream.Config.Name @@ -1122,7 +1117,6 @@ func TestJetStreamManagement(t *testing.T) { t.Fatalf("Expected stream to be %v, got: %v", expected, got) } } - cancel() if i != 1 { t.Errorf("Expected single stream: %v", err) } @@ -1130,22 +1124,18 @@ func TestJetStreamManagement(t *testing.T) { t.Run("consumers info", func(t *testing.T) { var called bool - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - for range js.ConsumersInfo(ctx, "") { + for range js.ConsumersInfo("") { called = true } - cancel() if called { t.Error("Expected not not receive entries") } - ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) - for ci := range js.ConsumersInfo(ctx, "foo") { + for ci := range js.ConsumersInfo("foo") { if ci.Stream != "foo" || ci.Config.Durable != "dlc" { t.Fatalf("ConsumerInfo is not correct %+v", ci) } } - cancel() }) t.Run("delete consumers", func(t *testing.T) { @@ -1170,9 +1160,7 @@ func TestJetStreamManagement(t *testing.T) { t.Run("list stream names", func(t *testing.T) { var names []string - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - for name := range js.StreamNames(ctx) { + for name := range js.StreamNames() { names = append(names, name) } if got, want := len(names), 1; got != want { @@ -2562,9 +2550,7 @@ func TestJetStream_Unsubscribe(t *testing.T) { t.Helper() var i int - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - for range js.ConsumersInfo(ctx, "foo") { + for range js.ConsumersInfo("foo") { i++ } if i != expected { @@ -2693,10 +2679,10 @@ func TestJetStream_UnsubscribeCloseDrain(t *testing.T) { fetchConsumers := func(t *testing.T, expected int) { t.Helper() - var i int ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - for range jsm.ConsumersInfo(ctx, "foo") { + var i int + for range jsm.ConsumersInfo("foo", nats.Context(ctx)) { i++ } if i != expected {