Skip to content
Permalink
Browse files
fix(pubsublite): close clients after publisher and subscriber have te…
…rminated (#3512)

Close gapic clients in SubscriberClient, PublisherClient and tests in order to close client connections. Create new client connections to the mock server in unit tests.
  • Loading branch information
tmdiep committed Jan 19, 2021
1 parent 9d8fd2b commit 72d2affb957cea7b6a223b108d0fe67c5635b25c
@@ -27,7 +27,7 @@ import (
)

func newTestAdminClient(t *testing.T) *AdminClient {
admin, err := NewAdminClient(context.Background(), "us-central1", testClientOpts...)
admin, err := NewAdminClient(context.Background(), "us-central1", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
@@ -85,6 +85,7 @@ func TestAdminTopicCRUD(t *testing.T) {
defer mockServer.OnTestEnd()

admin := newTestAdminClient(t)
defer admin.Close()

if gotConfig, err := admin.CreateTopic(ctx, topicConfig); err != nil {
t.Errorf("CreateTopic() got err: %v", err)
@@ -172,6 +173,7 @@ func TestAdminListTopics(t *testing.T) {
defer mockServer.OnTestEnd()

admin := newTestAdminClient(t)
defer admin.Close()

var gotTopicConfigs []*TopicConfig
topicIt := admin.Topics(ctx, locationPath)
@@ -227,6 +229,7 @@ func TestAdminListTopicSubscriptions(t *testing.T) {
defer mockServer.OnTestEnd()

admin := newTestAdminClient(t)
defer admin.Close()

var gotSubscriptions []string
subsPathIt := admin.TopicSubscriptions(ctx, topicPath)
@@ -290,6 +293,7 @@ func TestAdminSubscriptionCRUD(t *testing.T) {
defer mockServer.OnTestEnd()

admin := newTestAdminClient(t)
defer admin.Close()

if gotConfig, err := admin.CreateSubscription(ctx, subscriptionConfig); err != nil {
t.Errorf("CreateSubscription() got err: %v", err)
@@ -362,6 +366,7 @@ func TestAdminListSubscriptions(t *testing.T) {
defer mockServer.OnTestEnd()

admin := newTestAdminClient(t)
defer admin.Close()

var gotSubscriptionConfigs []*SubscriptionConfig
subscriptionIt := admin.Subscriptions(ctx, locationPath)
@@ -66,23 +66,13 @@ func NewServer() (*Server, error) {
return &Server{LiteServer: liteServer, gRPCServer: srv}, nil
}

// NewServerWithConn creates a new mock Pub/Sub Lite server along with client
// options to connect to it.
func NewServerWithConn() (*Server, []option.ClientOption) {
testServer, err := NewServer()
// ClientConn creates a client connection to the gRPC test server.
func (s *Server) ClientConn() option.ClientOption {
conn, err := grpc.Dial(s.gRPCServer.Addr, grpc.WithInsecure())
if err != nil {
log.Fatal(err)
}
conn, err := grpc.Dial(testServer.Addr(), grpc.WithInsecure())
if err != nil {
log.Fatal(err)
}
return testServer, []option.ClientOption{option.WithGRPCConn(conn)}
}

// Addr returns the address that the server is listening on.
func (s *Server) Addr() string {
return s.gRPCServer.Addr
return option.WithGRPCConn(conn)
}

// Close shuts down the server and releases all resources.
@@ -75,7 +75,7 @@ type testAssigner struct {

func newTestAssigner(t *testing.T, subscription string) *testAssigner {
ctx := context.Background()
assignmentClient, err := newPartitionAssignmentClient(ctx, "ignored", testClientOpts...)
assignmentClient, err := newPartitionAssignmentClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
@@ -89,7 +89,7 @@ func newTestAssigner(t *testing.T, subscription string) *testAssigner {
t.Fatal(err)
}
ta.asn = asn
ta.initAndStart(t, ta.asn, "Assigner")
ta.initAndStart(t, ta.asn, "Assigner", assignmentClient)
return ta
}

@@ -30,15 +30,15 @@ type testCommitter struct {

func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
ctx := context.Background()
cursorClient, err := newCursorClient(ctx, "ignored", testClientOpts...)
cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}

tc := &testCommitter{
cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true),
}
tc.initAndStart(t, tc.cmt, "Committer")
tc.initAndStart(t, tc.cmt, "Committer", cursorClient)
return tc
}

@@ -15,25 +15,27 @@ package wire

import (
"flag"
"log"
"os"
"testing"

"cloud.google.com/go/pubsublite/internal/test"
"google.golang.org/api/option"
)

var (
// Initialized in TestMain.
mockServer test.MockServer
testClientOpts []option.ClientOption
testServer *test.Server
mockServer test.MockServer
)

func TestMain(m *testing.M) {
flag.Parse()

testServer, clientOpts := test.NewServerWithConn()
var err error
if testServer, err = test.NewServer(); err != nil {
log.Fatal(err)
}
mockServer = testServer.LiteServer
testClientOpts = clientOpts

exit := m.Run()
testServer.Close()
@@ -47,15 +47,15 @@ func (tw *testPartitionCountWatcher) UpdatePartitionCount() {

func newTestPartitionCountWatcher(t *testing.T, topicPath string, settings PublishSettings) *testPartitionCountWatcher {
ctx := context.Background()
adminClient, err := NewAdminClient(ctx, "ignored", testClientOpts...)
adminClient, err := NewAdminClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
tw := &testPartitionCountWatcher{
t: t,
}
tw.watcher = newPartitionCountWatcher(ctx, adminClient, testPublishSettings(), topicPath, tw.onCountChanged)
tw.initAndStart(t, tw.watcher, "PartitionCountWatcher")
tw.initAndStart(t, tw.watcher, "PartitionCountWatcher", adminClient)
return tw
}

@@ -279,6 +279,7 @@ func (pp *singlePartitionPublisher) unsafeCheckDone() {
// count, but not decreasing.
type routingPublisher struct {
// Immutable after creation.
clients apiClients
msgRouterFactory *messageRouterFactory
pubFactory *singlePartitionPublisherFactory
partitionWatcher *partitionCountWatcher
@@ -290,8 +291,9 @@ type routingPublisher struct {
compositeService
}

func newRoutingPublisher(adminClient *vkit.AdminClient, msgRouterFactory *messageRouterFactory, pubFactory *singlePartitionPublisherFactory) *routingPublisher {
func newRoutingPublisher(allClients apiClients, adminClient *vkit.AdminClient, msgRouterFactory *messageRouterFactory, pubFactory *singlePartitionPublisherFactory) *routingPublisher {
pub := &routingPublisher{
clients: allClients,
msgRouterFactory: msgRouterFactory,
pubFactory: pubFactory,
}
@@ -357,6 +359,12 @@ func (rp *routingPublisher) routeToPublisher(msg *pb.PubSubMessage) (*singlePart
return rp.publishers[partition], nil
}

func (rp *routingPublisher) WaitStopped() error {
err := rp.compositeService.WaitStopped()
rp.clients.Close()
return err
}

// Publisher is the client interface exported from this package for publishing
// messages.
type Publisher interface {
@@ -385,6 +393,7 @@ func NewPublisher(ctx context.Context, settings PublishSettings, region, topicPa
if err != nil {
return nil, err
}
allClients := apiClients{pubClient, adminClient}

msgRouterFactory := newMessageRouterFactory(rand.New(rand.NewSource(time.Now().UnixNano())))
pubFactory := &singlePartitionPublisherFactory{
@@ -393,5 +402,5 @@ func NewPublisher(ctx context.Context, settings PublishSettings, region, topicPa
settings: settings,
topicPath: topicPath,
}
return newRoutingPublisher(adminClient, msgRouterFactory, pubFactory), nil
return newRoutingPublisher(allClients, adminClient, msgRouterFactory, pubFactory), nil
}
@@ -47,7 +47,7 @@ type testPartitionPublisher struct {

func newTestSinglePartitionPublisher(t *testing.T, topic topicPartition, settings PublishSettings) *testPartitionPublisher {
ctx := context.Background()
pubClient, err := newPublisherClient(ctx, "ignored", testClientOpts...)
pubClient, err := newPublisherClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
@@ -61,7 +61,7 @@ func newTestSinglePartitionPublisher(t *testing.T, topic topicPartition, setting
tp := &testPartitionPublisher{
pub: pubFactory.New(topic.Partition),
}
tp.initAndStart(t, tp.pub, "Publisher")
tp.initAndStart(t, tp.pub, "Publisher", pubClient)
return tp
}

@@ -506,14 +506,15 @@ type testRoutingPublisher struct {

func newTestRoutingPublisher(t *testing.T, topicPath string, settings PublishSettings, fakeSourceVal int64) *testRoutingPublisher {
ctx := context.Background()
pubClient, err := newPublisherClient(ctx, "ignored", testClientOpts...)
pubClient, err := newPublisherClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
adminClient, err := NewAdminClient(ctx, "ignored", testClientOpts...)
adminClient, err := NewAdminClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
allClients := apiClients{pubClient, adminClient}

source := &test.FakeSource{Ret: fakeSourceVal}
msgRouterFactory := newMessageRouterFactory(rand.New(source))
@@ -523,7 +524,7 @@ func newTestRoutingPublisher(t *testing.T, topicPath string, settings PublishSet
settings: settings,
topicPath: topicPath,
}
pub := newRoutingPublisher(adminClient, msgRouterFactory, pubFactory)
pub := newRoutingPublisher(allClients, adminClient, msgRouterFactory, pubFactory)
pub.Start()
return &testRoutingPublisher{t: t, pub: pub}
}
@@ -138,6 +138,21 @@ func defaultClientOptions(region string) []option.ClientOption {
}
}

type apiClient interface {
Close() error
}

type apiClients []apiClient

func (ac apiClients) Close() (retErr error) {
for _, c := range ac {
if err := c.Close(); retErr == nil {
retErr = err
}
}
return
}

// NewAdminClient creates a new gapic AdminClient for a region.
func NewAdminClient(ctx context.Context, region string, opts ...option.ClientOption) (*vkit.AdminClient, error) {
options := append(defaultClientOptions(region), opts...)
@@ -33,14 +33,16 @@ type serviceTestProxy struct {
t *testing.T
service service
name string
clients apiClients
started chan struct{}
terminated chan struct{}
}

func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string) {
func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string, clients ...apiClient) {
sp.t = t
sp.service = s
sp.name = name
sp.clients = clients
sp.started = make(chan struct{})
sp.terminated = make(chan struct{})
s.AddStatusChangeReceiver(nil, sp.onStatusChange)
@@ -65,6 +67,7 @@ func (sp *serviceTestProxy) StartError() error {
case <-time.After(serviceTestWaitTimeout):
return fmt.Errorf("%s did not start within %v", sp.name, serviceTestWaitTimeout)
case <-sp.terminated:
sp.clients.Close()
return sp.service.Error()
case <-sp.started:
return sp.service.Error()
@@ -77,6 +80,7 @@ func (sp *serviceTestProxy) FinalError() error {
case <-time.After(serviceTestWaitTimeout):
return fmt.Errorf("%s did not terminate within %v", sp.name, serviceTestWaitTimeout)
case <-sp.terminated:
sp.clients.Close()
return sp.service.Error()
}
}
@@ -49,7 +49,7 @@ type testStreamHandler struct {

func newTestStreamHandler(t *testing.T, timeout time.Duration) *testStreamHandler {
ctx := context.Background()
pubClient, err := newPublisherClient(ctx, "ignored", testClientOpts...)
pubClient, err := newPublisherClient(ctx, "ignored", testServer.ClientConn())
if err != nil {
t.Fatal(err)
}
@@ -105,6 +105,11 @@ func (sh *testStreamHandler) initialRequest() (interface{}, initialResponseRequi

func (sh *testStreamHandler) onStreamStatusChange(status streamStatus) {
sh.statuses <- status

// Close connections.
if status == streamTerminated {
sh.pubClient.Close()
}
}

func (sh *testStreamHandler) onResponse(response interface{}) {

0 comments on commit 72d2aff

Please sign in to comment.