Skip to content

Commit

Permalink
pubsub/rabbitpubsub: Add URL parameter that sets the qos prefetch co…
Browse files Browse the repository at this point in the history
…unt (#3431)
  • Loading branch information
peczenyj committed May 28, 2024
1 parent d8b9c94 commit e677ded
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 49 deletions.
6 changes: 6 additions & 0 deletions pubsub/rabbitpubsub/amqp.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type amqpChannel interface {
QueueDeclareAndBind(qname, ename string) error
ExchangeDelete(string) error
QueueDelete(qname string) error
Qos(prefetchCount, prefetchSize int, global bool) error
}

// connection adapts an *amqp.Connection to the amqpConnection interface.
Expand All @@ -79,6 +80,7 @@ func (c *connection) Channel() (amqpChannel, error) {
if err := ch.Confirm(wait); err != nil {
return nil, err
}

return &channel{ch}, nil
}

Expand Down Expand Up @@ -168,3 +170,7 @@ func (ch *channel) QueueDelete(qname string) error {
_, err := ch.ch.QueueDelete(qname, false, false, false)
return err
}

func (ch *channel) Qos(prefetchCount, prefetchSize int, global bool) error {
return ch.ch.Qos(prefetchCount, prefetchSize, global)
}
8 changes: 8 additions & 0 deletions pubsub/rabbitpubsub/fake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,14 @@ func (ch *fakeChannel) QueueDelete(name string) error {
return nil
}

func (ch *fakeChannel) Qos(_, _ int, _ bool) error {
if ch.isClosed() {
return amqp.ErrClosed
}

return nil
}

// Assumes nothing is ever written to the channel.
func chanIsClosed(ch chan struct{}) bool {
select {
Expand Down
62 changes: 55 additions & 7 deletions pubsub/rabbitpubsub/rabbit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/url"
"os"
"path"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -96,7 +97,9 @@ const Scheme = "rabbit"
//
// For subscriptions, the URL's host+path is used as the queue name.
//
// No query parameters are supported.
// An optional query string can be used to set the Qos consumer prefetch on subscriptions
// like "rabbit://myqueue?prefetch_count=1000" to set the consumer prefetch count to 1000
// see also https://www.rabbitmq.com/docs/consumer-prefetch
type URLOpener struct {
// Connection to use for communication with the server.
Connection *amqp.Connection
Expand All @@ -118,11 +121,27 @@ func (o *URLOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic

// OpenSubscriptionURL opens a pubsub.Subscription based on u.
func (o *URLOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) {
for param := range u.Query() {
return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param)
opts := o.SubscriptionOptions
for param, value := range u.Query() {
switch param {
case "prefetch_count":
if len(value) != 1 || len(value[0]) == 0 {
return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param)
}

prefetchCount, err := strconv.Atoi(value[0])
if err != nil {
return nil, fmt.Errorf("open subscription %v: invalid query parameter %q: %w", u, param, err)
}

opts.PrefetchCount = &prefetchCount
default:
return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param)
}
}

queueName := path.Join(u.Host, u.Path)
return OpenSubscription(o.Connection, queueName, &o.SubscriptionOptions), nil
return OpenSubscription(o.Connection, queueName, &opts), nil
}

type topic struct {
Expand All @@ -142,7 +161,10 @@ type TopicOptions struct{}

// SubscriptionOptions sets options for constructing a *pubsub.Subscription
// backed by RabbitMQ.
type SubscriptionOptions struct{}
type SubscriptionOptions struct {
// Qos property prefetch count. Optional.
PrefetchCount *int
}

// OpenTopic returns a *pubsub.Topic corresponding to the named exchange.
// See the package documentation for an example.
Expand Down Expand Up @@ -515,14 +537,16 @@ func (*topic) Close() error { return nil }
// The documentation of the amqp package recommends using separate connections for
// publishing and subscribing.
func OpenSubscription(conn *amqp.Connection, name string, opts *SubscriptionOptions) *pubsub.Subscription {
return pubsub.NewSubscription(newSubscription(&connection{conn}, name), nil, nil)
return pubsub.NewSubscription(newSubscription(&connection{conn}, name, opts), nil, nil)
}

type subscription struct {
conn amqpConnection
queue string // the AMQP queue name
consumer string // the client-generated name for this particular subscriber

opts *SubscriptionOptions

mu sync.Mutex
ch amqpChannel // AMQP channel used for all communication.
delc <-chan amqp.Delivery
Expand All @@ -533,11 +557,16 @@ type subscription struct {

var nextConsumer int64 // atomic

func newSubscription(conn amqpConnection, name string) *subscription {
func newSubscription(conn amqpConnection, name string, opts *SubscriptionOptions) *subscription {
if opts == nil {
opts = &SubscriptionOptions{}
}

return &subscription{
conn: conn,
queue: name,
consumer: fmt.Sprintf("c%d", atomic.AddInt64(&nextConsumer, 1)),
opts: opts,
receiveBatchHook: func() {},
}
}
Expand All @@ -564,15 +593,34 @@ func (s *subscription) establishChannel(ctx context.Context) error {
if err != nil {
return err
}
// Apply subscription options to channel.
err = applyOptionsToChannel(s.opts, ch)
if err != nil {
return err
}
// Subscribe to messages from the queue.
s.delc, err = ch.Consume(s.queue, s.consumer)
return err
})
if err != nil {
return err
}

s.ch = ch
s.closec = ch.NotifyClose(make(chan *amqp.Error, 1)) // closec will get at most one element

return nil
}

func applyOptionsToChannel(opts *SubscriptionOptions, ch amqpChannel) error {
if opts.PrefetchCount == nil {
return nil
}

if err := ch.Qos(*opts.PrefetchCount, 0, false); err != nil {
return fmt.Errorf("unable to set channel Qos: %w", err)
}

return nil
}

Expand Down
131 changes: 89 additions & 42 deletions pubsub/rabbitpubsub/rabbit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ const rabbitURL = "amqp://guest:guest@localhost:5672/"
var logOnce sync.Once

func mustDialRabbit(t testing.TB) amqpConnection {
t.Helper()

if !setup.HasDockerTestEnvironment() {
logOnce.Do(func() {
t.Log("using the fake because the RabbitMQ server is not available")
Expand All @@ -61,6 +63,8 @@ func mustDialRabbit(t testing.TB) amqpConnection {

func TestConformance(t *testing.T) {
harnessMaker := func(_ context.Context, t *testing.T) (drivertest.Harness, error) {
t.Helper()

return &harness{conn: mustDialRabbit(t)}, nil
}
_, isFake := mustDialRabbit(t).(*fakeConnection)
Expand All @@ -73,6 +77,8 @@ func TestConformance(t *testing.T) {
}
t.Logf("now running tests with the fake")
harnessMaker = func(_ context.Context, t *testing.T) (drivertest.Harness, error) {
t.Helper()

return &harness{conn: newFakeConnection()}, nil
}
asTests = []drivertest.AsTest{rabbitAsTest{true}}
Expand Down Expand Up @@ -138,12 +144,12 @@ func (h *harness) CreateSubscription(_ context.Context, dt driver.Topic, testNam
}
ch.QueueDelete(queue)
}
ds = newSubscription(h.conn, queue)
ds = newSubscription(h.conn, queue, nil)
return ds, cleanup, nil
}

func (h *harness) MakeNonexistentSubscription(_ context.Context) (driver.Subscription, func(), error) {
return newSubscription(h.conn, "nonexistent-subscription"), func() {}, nil
return newSubscription(h.conn, "nonexistent-subscription", nil), func() {}, nil
}

func (h *harness) Close() {
Expand Down Expand Up @@ -379,62 +385,103 @@ func (rabbitAsTest) AfterSend(as func(interface{}) bool) error {
return nil
}

func fakeConnectionStringInEnv() func() {
oldEnvVal := os.Getenv("RABBIT_SERVER_URL")
os.Setenv("RABBIT_SERVER_URL", "amqp://localhost:10000/vhost")
return func() {
os.Setenv("RABBIT_SERVER_URL", oldEnvVal)
}
}

func TestOpenTopicFromURL(t *testing.T) {
cleanup := fakeConnectionStringInEnv()
defer cleanup()
t.Setenv("RABBIT_SERVER_URL", rabbitURL)

tests := []struct {
URL string
WantErr bool
label string
URLTemplate string
WantErr bool
}{
// OK, but still error because Dial fails.
{"rabbit://myexchange", true},
// Invalid parameter.
{"rabbit://myexchange?param=value", true},
{"valid url", "rabbit://%s", false},
{"invalid url with parameters", "rabbit://%s?param=value", true},
}

ctx := context.Background()
for _, test := range tests {
topic, err := pubsub.OpenTopic(ctx, test.URL)
if (err != nil) != test.WantErr {
t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
}
if topic != nil {
topic.Shutdown(ctx)
}
t.Run(test.label, func(t *testing.T) {
conn := mustDialRabbit(t)
_, isFake := conn.(*fakeConnection)
if isFake {
t.Skip("test requires real rabbitmq")
}

h := &harness{conn: conn}

ctx := context.Background()

dt, cleanupTopic, err := h.CreateTopic(ctx, t.Name())
if err != nil {
t.Fatalf("unable to create topic: %v", err)
}

t.Cleanup(cleanupTopic)

exchange := dt.(*topic).exchange
url := fmt.Sprintf(test.URLTemplate, exchange)

topic, err := pubsub.OpenTopic(ctx, url)
if (err != nil) != test.WantErr {
t.Errorf("%s: got error %v, want error %v", test.URLTemplate, err, test.WantErr)
}
if topic != nil {
topic.Shutdown(ctx)
}
})
}
}

func TestOpenSubscriptionFromURL(t *testing.T) {
cleanup := fakeConnectionStringInEnv()
defer cleanup()
t.Setenv("RABBIT_SERVER_URL", rabbitURL)

tests := []struct {
URL string
WantErr bool
label string
URLTemplate string
WantErr bool
}{
// OK, but error because Dial fails.
{"rabbit://myqueue", true},
// Invalid parameter.
{"rabbit://myqueue?param=value", true},

{"url with no QoS prefetch count", "rabbit://%s", false},
{"invalid parameters", "rabbit://%s?param=value", true},
{"valid url with QoS prefetch count", "rabbit://%s?prefetch_count=1024", false},
{"invalid url with QoS prefetch count", "rabbit://%s?prefetch_count=value", true},
}

ctx := context.Background()
for _, test := range tests {
sub, err := pubsub.OpenSubscription(ctx, test.URL)
if (err != nil) != test.WantErr {
t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
}
if sub != nil {
sub.Shutdown(ctx)
}
t.Run(test.label, func(t *testing.T) {
conn := mustDialRabbit(t)
_, isFake := conn.(*fakeConnection)
if isFake {
t.Skip("test requires real rabbitmq")
}

h := &harness{conn: conn}

ctx := context.Background()

dt, cleanupTopic, err := h.CreateTopic(ctx, t.Name())
if err != nil {
t.Fatalf("unable to create topic: %v", err)
}

t.Cleanup(cleanupTopic)

ds, cleanupSubscription, err := h.CreateSubscription(ctx, dt, t.Name())
if err != nil {
t.Fatalf("unable to create subscription: %v", err)
}

t.Cleanup(cleanupSubscription)

queue := ds.(*subscription).queue
url := fmt.Sprintf(test.URLTemplate, queue)

sub, err := pubsub.OpenSubscription(ctx, url)
if (err != nil) != test.WantErr {
t.Errorf("%s: got error %v, want error %v", test.URLTemplate, err, test.WantErr)
}

if sub != nil {
sub.Shutdown(ctx)
}
})
}
}

0 comments on commit e677ded

Please sign in to comment.