Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add contexts to pubsub.Subscribe to allow early cancelation #1756

Merged
merged 1 commit into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/infrastructure/docker-compose-rocketmq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
version: '2'
services:
namesrv:
image: apache/rocketmq:4.9.3
container_name: rmqnamesrv
ports:
- 9876:9876
#volumes:
# - ./data/namesrv/logs:/home/rocketmq/logs
command: sh mqnamesrv
broker:
image: apache/rocketmq:4.9.3
container_name: rmqbroker
ports:
- 10909:10909
- 10911:10911
- 10912:10912
#volumes:
# - ./data/broker/logs:/home/rocketmq/logs
# - ./data/broker/store:/home/rocketmq/store
# - ./data/broker/conf/broker.conf:/home/rocketmq/rocketmq-4.9.3/conf/broker.conf
command: sh mqbroker -n namesrv:9876 -c ../conf/broker.conf
depends_on:
- namesrv
console:
image: styletang/rocketmq-console-ng:latest
container_name: console
links:
- namesrv
ports:
- 8080:8080
environment:
JAVA_OPTS: "-Drocketmq.namesrv.addr=namesrv:9876 -Dcom.rocketmq.sendMessageWithVIPChannel=false"
2 changes: 1 addition & 1 deletion bindings/rabbitmq/rabbitmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"strconv"
"time"

"github.com/streadway/amqp"
amqp "github.com/rabbitmq/amqp091-go"

"github.com/dapr/components-contrib/bindings"
contrib_metadata "github.com/dapr/components-contrib/metadata"
Expand Down
7 changes: 4 additions & 3 deletions bindings/rabbitmq/rabbitmq_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ import (
"testing"
"time"

"github.com/google/uuid"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"

"github.com/dapr/components-contrib/bindings"
contrib_metadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
"github.com/google/uuid"
"github.com/streadway/amqp"
"github.com/stretchr/testify/assert"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ require (
github.com/sendgrid/rest v2.6.3+incompatible // indirect
github.com/sendgrid/sendgrid-go v3.5.0+incompatible
github.com/sergi/go-diff v1.2.0 // indirect
github.com/streadway/amqp v1.0.0
github.com/stretchr/testify v1.7.1
github.com/supplyon/gremcos v0.1.0
github.com/tidwall/gjson v1.8.1 // indirect
Expand Down Expand Up @@ -163,6 +162,7 @@ require (
github.com/huaweicloud/huaweicloud-sdk-go-v3 v0.0.87
github.com/labd/commercetools-go-sdk v0.3.2
github.com/nacos-group/nacos-sdk-go/v2 v2.0.1
github.com/rabbitmq/amqp091-go v1.3.4
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

github.com/rabbitmq/amqp091-go is the new version of github.com/streadway/amqp and it's maintained by RabbitMQ now

go.uber.org/ratelimit v0.2.0
gopkg.in/couchbase/gocb.v1 v1.6.4
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,8 @@ github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/statsd_exporter v0.21.0 h1:hA05Q5RFeIjgwKIYEdFd59xu5Wwaznf33yKI+pyX6T8=
github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3MLCHmSHelCh9hSGYNLTQ=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rabbitmq/amqp091-go v1.3.4 h1:tXuIslN1nhDqs2t6Jrz3BAoqvt4qIZzxvdbdcxWtHYU=
github.com/rabbitmq/amqp091-go v1.3.4/go.mod h1:ogQDLSOACsLPsIq0NpbtiifNZi2YOz0VTJ0kHRghqbM=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA=
Expand Down Expand Up @@ -1434,8 +1436,6 @@ github.com/stathat/consistent v1.0.0/go.mod h1:uajTPbgSygZBJ+V+0mY7meZ8i0XAcZs7A
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v1.0.0 h1:kuuDrUJFZL1QYL9hUNuCxNObNzB0bV/ZG5jV3RWAQgo=
github.com/streadway/amqp v1.0.0/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
27 changes: 20 additions & 7 deletions internal/component/kafka/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import (
)

type consumer struct {
k *Kafka
ready chan bool
once sync.Once
k *Kafka
ready chan bool
running chan struct{}
once sync.Once
}

func (consumer *consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
Expand Down Expand Up @@ -120,6 +121,12 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Close resources and reset synchronization primitives
k.closeSubscriptionResources()

topics := k.subscribeTopics.TopicList()
if len(topics) == 0 {
// Nothing to subscribe to
return nil
}

cg, err := sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config)
if err != nil {
return err
Expand All @@ -132,12 +139,11 @@ func (k *Kafka) Subscribe(ctx context.Context) error {

ready := make(chan bool)
k.consumer = consumer{
k: k,
ready: ready,
k: k,
ready: ready,
running: make(chan struct{}),
}

topics := k.subscribeTopics.TopicList()

go func() {
k.logger.Debugf("Subscribed and listening to topics: %s", topics)

Expand All @@ -153,6 +159,9 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Consume the requested topics
bo := backoff.WithContext(backoff.NewConstantBackOff(k.consumeRetryInterval), ctx)
innerErr := retry.NotifyRecover(func() error {
if ctxErr := ctx.Err(); ctxErr != nil {
return backoff.Permanent(ctxErr)
}
return k.cg.Consume(ctx, topics, &(k.consumer))
}, bo, func(err error, t time.Duration) {
k.logger.Errorf("Error consuming %v. Retrying...: %v", topics, err)
Expand All @@ -169,6 +178,8 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
if err != nil {
k.logger.Errorf("Error closing consumer group: %v", err)
}

close(k.consumer.running)
}()

<-ready
Expand All @@ -186,6 +197,8 @@ func (k *Kafka) closeSubscriptionResources() {
}

k.consumer.once.Do(func() {
// Wait for shutdown to be complete
<-k.consumer.running
close(k.consumer.ready)
k.consumer.once = sync.Once{}
})
Expand Down
93 changes: 66 additions & 27 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ type snsSqs struct {
// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
subscriptions sync.Map

snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
ctx context.Context
cancel context.CancelFunc
pollerCtx context.Context
pollerCancel context.CancelFunc
backOffConfig retry.Config
pollerSemaphore chan struct{}
snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
ctx context.Context
cancel context.CancelFunc
pollerCtx context.Context
pollerCancel context.CancelFunc
backOffConfig retry.Config
pollerRunning chan struct{}
}

type sqsQueueInfo struct {
Expand Down Expand Up @@ -101,10 +101,10 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
}

return &snsSqs{
logger: l,
id: id,
topicsLock: sync.RWMutex{},
pollerSemaphore: make(chan struct{}, 1),
logger: l,
id: id,
topicsLock: sync.RWMutex{},
pollerRunning: make(chan struct{}, 1),
}
}

Expand Down Expand Up @@ -406,6 +406,22 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t
return *subscribeOutput.SubscriptionArn, nil
}

func (s *snsSqs) removeSnsSqsSubscription(parentCtx context.Context, subscriptionArn string) error {
ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
_, err := s.snsClient.UnsubscribeWithContext(ctx, &sns.UnsubscribeInput{
SubscriptionArn: aws.String(subscriptionArn),
})
cancel()
if err != nil {
wrappedErr := fmt.Errorf("error unsubscribing to arn: %s %w", subscriptionArn, err)
s.logger.Error(wrappedErr)

return wrappedErr
}

return nil
}

func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) {
ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
Expand Down Expand Up @@ -644,7 +660,7 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}

// Signal that the poller stopped
<-s.pollerSemaphore
<-s.pollerRunning
}

func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
Expand Down Expand Up @@ -760,10 +776,10 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context,
return nil
}

func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
// subscribers declare a topic ARN and declare a SQS queue to use
// these should be idempotent - queues should not be created if they exist.
topicArn, sanitizedName, err := s.getOrCreateTopic(s.ctx, req.Topic)
topicArn, sanitizedName, err := s.getOrCreateTopic(subscribeCtx, req.Topic)
if err != nil {
wrappedErr := fmt.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
s.logger.Error(wrappedErr)
Expand All @@ -773,7 +789,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)

// this is the ID of the application, it is supplied via runtime as "consumerID".
var queueInfo *sqsQueueInfo
queueInfo, err = s.getOrCreateQueue(s.ctx, s.metadata.sqsQueueName)
queueInfo, err = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsQueueName)
if err != nil {
wrappedErr := fmt.Errorf("error retrieving SQS queue: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -783,7 +799,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)

// only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
// to prevent anyone but SNS to publish message to SQS.
err = s.restrictQueuePublishPolicyToOnlySNS(s.ctx, queueInfo, topicArn)
err = s.restrictQueuePublishPolicyToOnlySNS(subscribeCtx, queueInfo, topicArn)
if err != nil {
wrappedErr := fmt.Errorf("error setting sns-sqs subscription policy: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -796,15 +812,15 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
var derr error

if len(s.metadata.sqsDeadLettersQueueName) > 0 {
deadLettersQueueInfo, derr = s.getOrCreateQueue(s.ctx, s.metadata.sqsDeadLettersQueueName)
deadLettersQueueInfo, derr = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsDeadLettersQueueName)
if derr != nil {
wrappedErr := fmt.Errorf("error retrieving SQS dead-letter queue: %w", err)
s.logger.Error(wrappedErr)

return wrappedErr
}

err = s.setDeadLettersQueueAttributes(s.ctx, queueInfo, deadLettersQueueInfo)
err = s.setDeadLettersQueueAttributes(subscribeCtx, queueInfo, deadLettersQueueInfo)
if err != nil {
wrappedErr := fmt.Errorf("error creating dead-letter queue: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -814,7 +830,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
}

// subscription creation is idempotent. Subscriptions are unique by topic/queue.
_, err = s.getOrCreateSnsSqsSubscription(s.ctx, queueInfo.arn, topicArn)
subscriptionArn, err := s.getOrCreateSnsSqsSubscription(subscribeCtx, queueInfo.arn, topicArn)
if err != nil {
wrappedErr := fmt.Errorf("error subscribing topic: %s, to queue: %s, with error: %w", topicArn, queueInfo.arn, err)
s.logger.Error(wrappedErr)
Expand All @@ -828,12 +844,12 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
s.topicHandlers[sanitizedName] = topicHandler{
topicName: req.Topic,
handler: handler,
ctx: s.ctx,
ctx: subscribeCtx,
}

// Start the poller for the queue if it's not running already
select {
case s.pollerSemaphore <- struct{}{}:
case s.pollerRunning <- struct{}{}:
// If inserting in the channel succeeds, then it's not running already
// Use a context that is tied to the background context
s.pollerCtx, s.pollerCancel = context.WithCancel(s.ctx)
Expand All @@ -842,6 +858,29 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
// Do nothing, it means the poller is already running
}

// Watch for subscription context cancelation to remove this subscription
go func() {
<-subscribeCtx.Done()

s.topicsLock.Lock()
defer s.topicsLock.Unlock()

// Remove the handler
delete(s.topicHandlers, sanitizedName)

// If we can perform management operations, remove the subscription entirely
if !s.metadata.disableEntityManagement {
// Use a background context because subscribeCtx is canceled already
// Error is logged already
_ = s.removeSnsSqsSubscription(s.ctx, subscriptionArn)
}

// If we don't have any topic left, close the poller
if len(s.topicHandlers) == 0 {
s.pollerCancel()
}
}()

return nil
}

Expand Down
Loading