Skip to content
Permalink
Browse files
fix(pubsublite): wire user context to api clients (#4318)
  • Loading branch information
tmdiep committed Jun 25, 2021
1 parent b34783a commit ae34396b1a2a970a0d871cd5496527294f3310d4
Showing with 66 additions and 26 deletions.
  1. +50 −10 pubsublite/pscompat/integration_test.go
  2. +1 −4 pubsublite/pscompat/publisher.go
  3. +13 −10 pubsublite/pscompat/subscriber.go
  4. +2 −2 pubsublite/pscompat/subscriber_test.go
@@ -30,7 +30,9 @@ import (
"cloud.google.com/go/pubsublite/internal/wire"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"

vkit "cloud.google.com/go/pubsublite/apiv1"
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
@@ -167,7 +169,7 @@ func partitionNumbers(partitionCount int) []int {

func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgs ...*pubsub.Message) {
ctx := context.Background()
publisher := publisherClient(ctx, t, settings, topic)
publisher := publisherClient(context.Background(), t, settings, topic)
defer publisher.Stop()

var pubResults []*pubsub.PublishResult
@@ -179,7 +181,7 @@ func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPat

func publishPrefixedMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgPrefix string, msgCount, msgSize int) []string {
ctx := context.Background()
publisher := publisherClient(ctx, t, settings, topic)
publisher := publisherClient(context.Background(), t, settings, topic)
defer publisher.Stop()

orderingSender := test.NewOrderingSender()
@@ -271,7 +273,7 @@ func receiveAllMessages(t *testing.T, msgTracker *test.MsgTracker, settings Rece
}
}

subscriber := subscriberClient(cctx, t, settings, subscription)
subscriber := subscriberClient(context.Background(), t, settings, subscription)
if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
}
@@ -298,7 +300,7 @@ func receiveAndVerifyMessage(t *testing.T, want *pubsub.Message, settings Receiv
}
}

subscriber := subscriberClient(cctx, t, settings, subscription)
subscriber := subscriberClient(context.Background(), t, settings, subscription)
if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
}
@@ -383,7 +385,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
}
got.Nack()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)
if gotErr := subscriber.Receive(cctx, messageReceiver1); !test.ErrorEqual(gotErr, errNackCalled) {
t.Errorf("Receive() got err: (%v), want err: (%v)", gotErr, errNackCalled)
}
@@ -400,7 +402,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
}
return fmt.Errorf("Received unexpected message: %q", truncateMsg(string(msg.Data)))
}
subscriber = subscriberClient(cctx, t, customSettings, subscriptionPath)
subscriber = subscriberClient(context.Background(), t, customSettings, subscriptionPath)

messageReceiver2 := func(ctx context.Context, got *pubsub.Message) {
got.Nack()
@@ -434,7 +436,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
got.Ack()
stopSubscriber()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)

// The message receiver stops the subscriber after receiving the first
// message. However, the subscriber isn't guaranteed to immediately stop, so
@@ -485,7 +487,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
// next test, which would receive an incorrect message.
got.Ack()
}
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath)

if err := subscriber.Receive(cctx, messageReceiver); err != nil {
t.Errorf("Receive() got err: %v", err)
@@ -539,6 +541,44 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
receiveAllMessages(t, msgTracker, recvSettings, subscriptionPath)
})

// Verifies that cancelling the context passed to NewPublisherClient can shut
// down the publisher.
t.Run("CancelPublisherContext", func(t *testing.T) {
cctx, cancel := context.WithCancel(context.Background())
publisher := publisherClient(cctx, t, DefaultPublishSettings, topicPath)

cancel()

wantCode := codes.Canceled
result := publisher.Publish(ctx, &pubsub.Message{Data: []byte("cancel_publisher_context")})
if _, err := result.Get(ctx); !test.ErrorHasCode(err, wantCode) {
t.Errorf("Publish() got err: %v, want code: %v", err, wantCode)
}
if err := xerrors.Unwrap(publisher.Error()); !test.ErrorHasCode(err, wantCode) {
t.Errorf("Error() got err: %v, want code: %v", err, wantCode)
}
publisher.Stop()
})

// Verifies that cancelling the context passed to NewSubscriberClient can shut
// down the subscriber.
t.Run("CancelSubscriberContext", func(t *testing.T) {
msg := &pubsub.Message{Data: []byte("cancel_subscriber_context")}
publishMessages(t, DefaultPublishSettings, topicPath, msg)

cctx, cancel := context.WithCancel(context.Background())
subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath)

subsErr := subscriber.Receive(context.Background(), func(ctx context.Context, got *pubsub.Message) {
got.Ack()
cancel()
})

if err, wantCode := xerrors.Unwrap(subsErr), codes.Canceled; !test.ErrorHasCode(err, wantCode) {
t.Errorf("Receive() got err: %v, want code: %v", err, wantCode)
}
})

// NOTE: This should be the last test case.
// Verifies that increasing the number of topic partitions is handled
// correctly by publishers.
@@ -547,7 +587,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) {
const pollPeriod = 5 * time.Second
pubSettings := DefaultPublishSettings
pubSettings.configPollPeriod = pollPeriod // Poll updates more frequently
publisher := publisherClient(ctx, t, pubSettings, topicPath)
publisher := publisherClient(context.Background(), t, pubSettings, topicPath)
defer publisher.Stop()

// Update the number of partitions.
@@ -661,7 +701,7 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) {
for i := 0; i < subscriberCount; i++ {
// Subscribers must be started in a goroutine as Receive() blocks.
g.Go(func() error {
subscriber := subscriberClient(cctx, t, DefaultReceiveSettings, subscriptionPath)
subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath)
err := subscriber.Receive(cctx, messageReceiver)
if err != nil {
t.Errorf("Receive() got err: %v", err)
@@ -82,10 +82,7 @@ func NewPublisherClientWithSettings(ctx context.Context, topic string, settings
return nil, err
}

// Note: ctx is not used to create the wire publisher, because if it is
// cancelled, the publisher will not be able to perform graceful shutdown
// (e.g. flush pending messages).
wirePub, err := wire.NewPublisher(context.Background(), settings.toWireSettings(), region, topic, opts...)
wirePub, err := wire.NewPublisher(ctx, settings.toWireSettings(), region, topic, opts...)
if err != nil {
return nil, err
}
@@ -72,7 +72,7 @@ func (ah *pslAckHandler) OnNack() {
// wireSubscriberFactory is a factory for creating wire subscribers, which can
// be overridden with a mock in unit tests.
type wireSubscriberFactory interface {
New(wire.MessageReceiverFunc) (wire.Subscriber, error)
New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error)
}

type wireSubscriberFactoryImpl struct {
@@ -82,8 +82,8 @@ type wireSubscriberFactoryImpl struct {
options []option.ClientOption
}

func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...)
func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...)
}

type messageReceiverFunc = func(context.Context, *pubsub.Message)
@@ -103,19 +103,20 @@ type subscriberInstance struct {
err error
}

func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) {
recvCtx, recvCancel := context.WithCancel(ctx)
func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) {
recvCtx, recvCancel := context.WithCancel(recvCtx)
subInstance := &subscriberInstance{
settings: settings,
recvCtx: recvCtx,
recvCancel: recvCancel,
receiver: receiver,
}

// Note: ctx is not used to create the wire subscriber, because if it is
// cancelled, the subscriber will not be able to perform graceful shutdown
// (e.g. process acks and commit the final cursor offset).
wireSub, err := factory.New(subInstance.onMessage)
// Note: The context from Receive (recvCtx) should not be used, as when it is
// cancelled, the gRPC streams will be disconnected and the subscriber will
// not be able to process acks and commit the final cursor offset. Use the
// context from NewSubscriberClient (clientCtx) instead.
wireSub, err := factory.New(clientCtx, subInstance.onMessage)
if err != nil {
return nil, err
}
@@ -229,6 +230,7 @@ func (si *subscriberInstance) Wait(ctx context.Context) error {
// See https://cloud.google.com/pubsub/lite/docs/subscribing for more
// information about receiving messages.
type SubscriberClient struct {
clientCtx context.Context
settings ReceiveSettings
wireSubFactory wireSubscriberFactory

@@ -265,6 +267,7 @@ func NewSubscriberClientWithSettings(ctx context.Context, subscription string, s
options: opts,
}
subClient := &SubscriberClient{
clientCtx: ctx,
settings: settings,
wireSubFactory: factory,
}
@@ -303,7 +306,7 @@ func (s *SubscriberClient) Receive(ctx context.Context, f func(context.Context,
defer s.setReceiveActive(false)

// Initialize a subscriber instance.
subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f)
subInstance, err := newSubscriberInstance(ctx, s.clientCtx, s.wireSubFactory, s.settings, f)
if err != nil {
return err
}
@@ -113,7 +113,7 @@ func (ms *mockWireSubscriber) WaitStopped() error {

type mockWireSubscriberFactory struct{}

func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) {
return &mockWireSubscriber{
receiver: receiver,
msgsC: make(chan *wire.ReceivedMessage, 10),
@@ -122,7 +122,7 @@ func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire
}

func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance {
sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver)
sub, _ := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver)
return sub
}

0 comments on commit ae34396

Please sign in to comment.