Skip to content

Commit

Permalink
Add contexts to pubsub.Subscribe to allow early cancelation (#1756)
Browse files Browse the repository at this point in the history
This commit is related to dapr/dapr#4624. As noted there, we have an issue in the runtime where all components are shut down after the grace period, when the app is likely already stopped. Because of that, certain input components (the subscribe part of pubsub and the input part of bindings - the latter is not in scope for this PR) can continue bringing in new work even though it is known to fail.

In order to fix the issue linked above properly, we need to implement a way for PubSub components to have the "publish" part closed before the "subscribe" one (and in the future that will need to be done for input bindings too).

This commit achieves precisely that by adding a context in the Subscribe method. When that context is canceled (which can be at any time), the subscription is removed.

PS: This API change was implemented so that it can one day be used for dapr/dapr#814 too, as it allows canceling individual subscriptions by using a different context for each. Although that's not possible today because it requires more work in the runtime, this change already implements everything that's needed in the pubsub components.
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
  • Loading branch information
ItalyPaleAle committed Jun 2, 2022
1 parent 7c35a4e commit 704f4dd
Show file tree
Hide file tree
Showing 37 changed files with 948 additions and 1,789 deletions.
33 changes: 33 additions & 0 deletions .github/infrastructure/docker-compose-rocketmq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Docker Compose stack for RocketMQ: a name server, a broker, and a web console.
version: '2'
services:
  # RocketMQ name server, published on host port 9876.
  namesrv:
    image: apache/rocketmq:4.9.3
    container_name: rmqnamesrv
    ports:
      - 9876:9876
    #volumes:
    #  - ./data/namesrv/logs:/home/rocketmq/logs
    command: sh mqnamesrv
  # RocketMQ broker, started with "-n namesrv:9876" and only after namesrv.
  broker:
    image: apache/rocketmq:4.9.3
    container_name: rmqbroker
    ports:
      - 10909:10909
      - 10911:10911
      - 10912:10912
    #volumes:
    #  - ./data/broker/logs:/home/rocketmq/logs
    #  - ./data/broker/store:/home/rocketmq/store
    #  - ./data/broker/conf/broker.conf:/home/rocketmq/rocketmq-4.9.3/conf/broker.conf
    command: sh mqbroker -n namesrv:9876 -c ../conf/broker.conf
    depends_on:
      - namesrv
  # Web console, configured through JAVA_OPTS to use namesrv:9876.
  console:
    image: styletang/rocketmq-console-ng:latest
    container_name: console
    links:
      - namesrv
    ports:
      - 8080:8080
    environment:
      JAVA_OPTS: "-Drocketmq.namesrv.addr=namesrv:9876 -Dcom.rocketmq.sendMessageWithVIPChannel=false"
2 changes: 1 addition & 1 deletion bindings/rabbitmq/rabbitmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"strconv"
"time"

"github.com/streadway/amqp"
amqp "github.com/rabbitmq/amqp091-go"

"github.com/dapr/components-contrib/bindings"
contrib_metadata "github.com/dapr/components-contrib/metadata"
Expand Down
7 changes: 4 additions & 3 deletions bindings/rabbitmq/rabbitmq_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ import (
"testing"
"time"

"github.com/google/uuid"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"

"github.com/dapr/components-contrib/bindings"
contrib_metadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
"github.com/google/uuid"
"github.com/streadway/amqp"
"github.com/stretchr/testify/assert"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ require (
github.com/sendgrid/rest v2.6.3+incompatible // indirect
github.com/sendgrid/sendgrid-go v3.5.0+incompatible
github.com/sergi/go-diff v1.2.0 // indirect
github.com/streadway/amqp v1.0.0
github.com/stretchr/testify v1.7.1
github.com/supplyon/gremcos v0.1.0
github.com/tidwall/gjson v1.8.1 // indirect
Expand Down Expand Up @@ -163,6 +162,7 @@ require (
github.com/huaweicloud/huaweicloud-sdk-go-v3 v0.0.87
github.com/labd/commercetools-go-sdk v0.3.2
github.com/nacos-group/nacos-sdk-go/v2 v2.0.1
github.com/rabbitmq/amqp091-go v1.3.4
go.uber.org/ratelimit v0.2.0
gopkg.in/couchbase/gocb.v1 v1.6.4
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,8 @@ github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/statsd_exporter v0.21.0 h1:hA05Q5RFeIjgwKIYEdFd59xu5Wwaznf33yKI+pyX6T8=
github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3MLCHmSHelCh9hSGYNLTQ=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rabbitmq/amqp091-go v1.3.4 h1:tXuIslN1nhDqs2t6Jrz3BAoqvt4qIZzxvdbdcxWtHYU=
github.com/rabbitmq/amqp091-go v1.3.4/go.mod h1:ogQDLSOACsLPsIq0NpbtiifNZi2YOz0VTJ0kHRghqbM=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA=
Expand Down Expand Up @@ -1434,8 +1436,6 @@ github.com/stathat/consistent v1.0.0/go.mod h1:uajTPbgSygZBJ+V+0mY7meZ8i0XAcZs7A
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v1.0.0 h1:kuuDrUJFZL1QYL9hUNuCxNObNzB0bV/ZG5jV3RWAQgo=
github.com/streadway/amqp v1.0.0/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
27 changes: 20 additions & 7 deletions internal/component/kafka/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import (
)

type consumer struct {
k *Kafka
ready chan bool
once sync.Once
k *Kafka
ready chan bool
running chan struct{}
once sync.Once
}

func (consumer *consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
Expand Down Expand Up @@ -120,6 +121,12 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Close resources and reset synchronization primitives
k.closeSubscriptionResources()

topics := k.subscribeTopics.TopicList()
if len(topics) == 0 {
// Nothing to subscribe to
return nil
}

cg, err := sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config)
if err != nil {
return err
Expand All @@ -132,12 +139,11 @@ func (k *Kafka) Subscribe(ctx context.Context) error {

ready := make(chan bool)
k.consumer = consumer{
k: k,
ready: ready,
k: k,
ready: ready,
running: make(chan struct{}),
}

topics := k.subscribeTopics.TopicList()

go func() {
k.logger.Debugf("Subscribed and listening to topics: %s", topics)

Expand All @@ -153,6 +159,9 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Consume the requested topics
bo := backoff.WithContext(backoff.NewConstantBackOff(k.consumeRetryInterval), ctx)
innerErr := retry.NotifyRecover(func() error {
if ctxErr := ctx.Err(); ctxErr != nil {
return backoff.Permanent(ctxErr)
}
return k.cg.Consume(ctx, topics, &(k.consumer))
}, bo, func(err error, t time.Duration) {
k.logger.Errorf("Error consuming %v. Retrying...: %v", topics, err)
Expand All @@ -169,6 +178,8 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
if err != nil {
k.logger.Errorf("Error closing consumer group: %v", err)
}

close(k.consumer.running)
}()

<-ready
Expand All @@ -186,6 +197,8 @@ func (k *Kafka) closeSubscriptionResources() {
}

k.consumer.once.Do(func() {
// Wait for shutdown to be complete
<-k.consumer.running
close(k.consumer.ready)
k.consumer.once = sync.Once{}
})
Expand Down
93 changes: 66 additions & 27 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ type snsSqs struct {
// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
subscriptions sync.Map

snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
ctx context.Context
cancel context.CancelFunc
pollerCtx context.Context
pollerCancel context.CancelFunc
backOffConfig retry.Config
pollerSemaphore chan struct{}
snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
ctx context.Context
cancel context.CancelFunc
pollerCtx context.Context
pollerCancel context.CancelFunc
backOffConfig retry.Config
pollerRunning chan struct{}
}

type sqsQueueInfo struct {
Expand Down Expand Up @@ -101,10 +101,10 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
}

return &snsSqs{
logger: l,
id: id,
topicsLock: sync.RWMutex{},
pollerSemaphore: make(chan struct{}, 1),
logger: l,
id: id,
topicsLock: sync.RWMutex{},
pollerRunning: make(chan struct{}, 1),
}
}

Expand Down Expand Up @@ -406,6 +406,22 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t
return *subscribeOutput.SubscriptionArn, nil
}

// removeSnsSqsSubscription issues an SNS Unsubscribe request for the
// subscription identified by subscriptionArn.
//
// The request runs under a context derived from parentCtx and bounded by
// s.opsTimeout. On failure the error is wrapped with the subscription ARN,
// logged through s.logger, and returned; on success nil is returned.
func (s *snsSqs) removeSnsSqsSubscription(parentCtx context.Context, subscriptionArn string) error {
	opCtx, release := context.WithTimeout(parentCtx, s.opsTimeout)
	// Free the timeout's resources as soon as this call returns.
	defer release()

	input := &sns.UnsubscribeInput{
		SubscriptionArn: aws.String(subscriptionArn),
	}
	if _, unsubErr := s.snsClient.UnsubscribeWithContext(opCtx, input); unsubErr != nil {
		failure := fmt.Errorf("error unsubscribing to arn: %s %w", subscriptionArn, unsubErr)
		s.logger.Error(failure)

		return failure
	}

	return nil
}

func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) {
ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
Expand Down Expand Up @@ -644,7 +660,7 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}

// Signal that the poller stopped
<-s.pollerSemaphore
<-s.pollerRunning
}

func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
Expand Down Expand Up @@ -760,10 +776,10 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context,
return nil
}

func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
// subscribers declare a topic ARN and declare a SQS queue to use
// these should be idempotent - queues should not be created if they exist.
topicArn, sanitizedName, err := s.getOrCreateTopic(s.ctx, req.Topic)
topicArn, sanitizedName, err := s.getOrCreateTopic(subscribeCtx, req.Topic)
if err != nil {
wrappedErr := fmt.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
s.logger.Error(wrappedErr)
Expand All @@ -773,7 +789,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)

// this is the ID of the application, it is supplied via runtime as "consumerID".
var queueInfo *sqsQueueInfo
queueInfo, err = s.getOrCreateQueue(s.ctx, s.metadata.sqsQueueName)
queueInfo, err = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsQueueName)
if err != nil {
wrappedErr := fmt.Errorf("error retrieving SQS queue: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -783,7 +799,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)

// only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
// to prevent anyone but SNS to publish message to SQS.
err = s.restrictQueuePublishPolicyToOnlySNS(s.ctx, queueInfo, topicArn)
err = s.restrictQueuePublishPolicyToOnlySNS(subscribeCtx, queueInfo, topicArn)
if err != nil {
wrappedErr := fmt.Errorf("error setting sns-sqs subscription policy: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -796,15 +812,15 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
var derr error

if len(s.metadata.sqsDeadLettersQueueName) > 0 {
deadLettersQueueInfo, derr = s.getOrCreateQueue(s.ctx, s.metadata.sqsDeadLettersQueueName)
deadLettersQueueInfo, derr = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsDeadLettersQueueName)
if derr != nil {
wrappedErr := fmt.Errorf("error retrieving SQS dead-letter queue: %w", err)
s.logger.Error(wrappedErr)

return wrappedErr
}

err = s.setDeadLettersQueueAttributes(s.ctx, queueInfo, deadLettersQueueInfo)
err = s.setDeadLettersQueueAttributes(subscribeCtx, queueInfo, deadLettersQueueInfo)
if err != nil {
wrappedErr := fmt.Errorf("error creating dead-letter queue: %w", err)
s.logger.Error(wrappedErr)
Expand All @@ -814,7 +830,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
}

// subscription creation is idempotent. Subscriptions are unique by topic/queue.
_, err = s.getOrCreateSnsSqsSubscription(s.ctx, queueInfo.arn, topicArn)
subscriptionArn, err := s.getOrCreateSnsSqsSubscription(subscribeCtx, queueInfo.arn, topicArn)
if err != nil {
wrappedErr := fmt.Errorf("error subscribing topic: %s, to queue: %s, with error: %w", topicArn, queueInfo.arn, err)
s.logger.Error(wrappedErr)
Expand All @@ -828,12 +844,12 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
s.topicHandlers[sanitizedName] = topicHandler{
topicName: req.Topic,
handler: handler,
ctx: s.ctx,
ctx: subscribeCtx,
}

// Start the poller for the queue if it's not running already
select {
case s.pollerSemaphore <- struct{}{}:
case s.pollerRunning <- struct{}{}:
// If inserting in the channel succeeds, then it's not running already
// Use a context that is tied to the background context
s.pollerCtx, s.pollerCancel = context.WithCancel(s.ctx)
Expand All @@ -842,6 +858,29 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
// Do nothing, it means the poller is already running
}

// Watch for subscription context cancelation to remove this subscription
go func() {
<-subscribeCtx.Done()

s.topicsLock.Lock()
defer s.topicsLock.Unlock()

// Remove the handler
delete(s.topicHandlers, sanitizedName)

// If we can perform management operations, remove the subscription entirely
if !s.metadata.disableEntityManagement {
// Use a background context because subscribeCtx is canceled already
// Error is logged already
_ = s.removeSnsSqsSubscription(s.ctx, subscriptionArn)
}

// If we don't have any topic left, close the poller
if len(s.topicHandlers) == 0 {
s.pollerCancel()
}
}()

return nil
}

Expand Down
Loading

0 comments on commit 704f4dd

Please sign in to comment.