Skip to content
Permalink
Browse files
fix(pubsublite): ensure timeout settings are respected (#4329)
Fixes for retryableStream and partitionCountWatcher to ensure PublishSettings.Timeout and ReceiveSettings.Timeout are respected.
  • Loading branch information
tmdiep committed Jun 28, 2021
1 parent 5ac09d9 commit e75262cf5eba845271965eab3c28c0a23bec14c4
@@ -16,6 +16,9 @@ package wire
import (
"context"
"fmt"
"time"

"golang.org/x/xerrors"

vkit "cloud.google.com/go/pubsublite/apiv1"
gax "github.com/googleapis/gax-go/v2"
@@ -30,11 +33,13 @@ type partitionCountReceiver func(partitionCount int)
// topic and notifies a receiver if it changes.
type partitionCountWatcher struct {
// Immutable after creation.
ctx context.Context
adminClient *vkit.AdminClient
topicPath string
receiver partitionCountReceiver
callOption gax.CallOption
ctx context.Context
adminClient *vkit.AdminClient
topicPath string
receiver partitionCountReceiver
callOption gax.CallOption
initialTimeout time.Duration
pollPeriod time.Duration

// Fields below must be guarded with mu.
partitionCount int
@@ -47,11 +52,13 @@ func newPartitionCountWatcher(ctx context.Context, adminClient *vkit.AdminClient
settings PublishSettings, topicPath string, receiver partitionCountReceiver) *partitionCountWatcher {

p := &partitionCountWatcher{
ctx: ctx,
adminClient: adminClient,
topicPath: topicPath,
receiver: receiver,
callOption: resourceExhaustedRetryer(),
ctx: ctx,
adminClient: adminClient,
topicPath: topicPath,
receiver: receiver,
callOption: resourceExhaustedRetryer(),
initialTimeout: settings.Timeout,
pollPeriod: settings.ConfigPollPeriod,
}

// Polling the topic partition count can be disabled in settings if the period
@@ -88,8 +95,17 @@ func (p *partitionCountWatcher) updatePartitionCount() {
p.mu.Unlock()

newPartitionCount, err := func() (int, error) {
// Ensure the first update respects PublishSettings.Timeout.
timeout := p.initialTimeout
if prevPartitionCount > 0 {
timeout = p.pollPeriod
}
cctx, cancel := context.WithCancel(p.ctx)
rt := newRequestTimer(timeout, cancel, ErrBackendUnavailable)

req := &pb.GetTopicPartitionsRequest{Name: p.topicPath}
resp, err := p.adminClient.GetTopicPartitions(p.ctx, req, p.callOption)
resp, err := p.adminClient.GetTopicPartitions(cctx, req, p.callOption)
rt.Stop()

p.mu.Lock()
defer p.mu.Unlock()
@@ -105,7 +121,7 @@ func (p *partitionCountWatcher) updatePartitionCount() {
// TODO: Log the error.
return p.partitionCount, nil
}
err = fmt.Errorf("pubsublite: failed to update topic partition count: %v", err)
err = xerrors.Errorf("pubsublite: failed to update topic partition count: %w", rt.ResolveError(err))
p.unsafeInitiateShutdown(err)
return 0, err
}
@@ -16,6 +16,7 @@ package wire
import (
"context"
"testing"
"time"

"cloud.google.com/go/internal/testutil"
"cloud.google.com/go/pubsublite/internal/test"
@@ -54,7 +55,7 @@ func newTestPartitionCountWatcher(t *testing.T, topicPath string, settings Publi
tw := &testPartitionCountWatcher{
t: t,
}
tw.watcher = newPartitionCountWatcher(ctx, adminClient, testPublishSettings(), topicPath, tw.onCountChanged)
tw.watcher = newPartitionCountWatcher(ctx, adminClient, settings, topicPath, tw.onCountChanged)
tw.initAndStart(t, tw.watcher, "PartitionCountWatcher", adminClient)
return tw
}
@@ -95,6 +96,59 @@ func TestPartitionCountWatcherZeroPartitionCountFails(t *testing.T) {
watcher.VerifyCounts(nil)
}

// TestPartitionCountWatcherInitialRequestTimesOut checks that Start() reports
// ErrBackendUnavailable when the first partition count request does not
// complete within PublishSettings.Timeout.
func TestPartitionCountWatcherInitialRequestTimesOut(t *testing.T) {
	const topic = "projects/123456/locations/us-central1-b/topics/my-topic"

	verifiers := test.NewVerifiers(t)
	// The response stays behind the barrier until the start error has been
	// checked, so the initial request cannot complete before the timeout.
	responseBarrier := verifiers.GlobalVerifier.PushWithBarrier(topicPartitionsReq(topic), topicPartitionsResp(1), nil)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	// Set low timeout for initial request.
	publishSettings := testPublishSettings()
	publishSettings.Timeout = 20 * time.Millisecond
	watcher := newTestPartitionCountWatcher(t, topic, publishSettings)

	gotErr := watcher.StartError()
	if wantErr := ErrBackendUnavailable; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Start() got err: (%v), want err: (%v)", gotErr, wantErr)
	}
	responseBarrier.Release()
	watcher.VerifyCounts(nil)
}

// TestPartitionCountWatcherUpdateLongerTimeouts checks that partition count
// updates after the first one are not bounded by the initial timeout: the
// second response is delayed well past a shortened initial timeout and the
// update is still expected to succeed.
func TestPartitionCountWatcherUpdateLongerTimeouts(t *testing.T) {
	const topic = "projects/123456/locations/us-central1-b/topics/my-topic"
	firstCount := 1
	secondCount := 2

	verifiers := test.NewVerifiers(t)
	verifiers.GlobalVerifier.Push(topicPartitionsReq(topic), topicPartitionsResp(firstCount), nil)
	// Barrier used to delay response.
	delayedResponse := verifiers.GlobalVerifier.PushWithBarrier(topicPartitionsReq(topic), topicPartitionsResp(secondCount), nil)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	tw := newTestPartitionCountWatcher(t, topic, testPublishSettings())
	if startErr := tw.StartError(); startErr != nil {
		t.Errorf("Start() got err: (%v)", startErr)
	}
	tw.VerifyCounts([]int{firstCount})

	// Override the initial timeout after the first request to verify that it
	// isn't used. If set at creation, the first request will fail.
	const shortTimeout = time.Millisecond
	tw.watcher.initialTimeout = shortTimeout
	// Hold the second response for several multiples of the shortened timeout
	// before letting it through.
	go delayedResponse.ReleaseAfter(func() {
		time.Sleep(5 * shortTimeout)
	})
	tw.UpdatePartitionCount()
	tw.VerifyCounts([]int{firstCount, secondCount})
	tw.StopVerifyNoError()
}

func TestPartitionCountWatcherPartitionCountUnchanged(t *testing.T) {
const topic = "projects/123456/locations/us-central1-b/topics/my-topic"
wantPartitionCount1 := 4
@@ -34,8 +34,9 @@ func testPublishSettings() PublishSettings {
// Send messages with minimal delay to speed up tests.
settings.DelayThreshold = time.Millisecond
settings.Timeout = 5 * time.Second
// Disable topic partition count background polling.
settings.ConfigPollPeriod = 0
// Set long poll period to prevent background update, but still have non-zero
// request timeout.
settings.ConfigPollPeriod = 1 * time.Minute
return settings
}

@@ -0,0 +1,78 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package wire

import (
"sync"
"time"
)

// requestTimerStatus is the lifecycle state of a requestTimer.
type requestTimerStatus int

const (
	// requestTimerNew is the initial state: the timer has been neither stopped
	// nor triggered.
	requestTimerNew requestTimerStatus = iota
	// requestTimerStopped means Stop was called before the timer triggered.
	requestTimerStopped
	// requestTimerTriggered means the timer expired before Stop was called.
	requestTimerTriggered
)

// requestTimer bounds the duration of a request and executes `onTimeout` if
// the timer is triggered.
type requestTimer struct {
	// Immutable after creation.
	onTimeout  func()
	timeoutErr error
	timer      *time.Timer

	// Fields below must be guarded with mu.
	mu     sync.Mutex
	status requestTimerStatus
}

// newRequestTimer starts a timer that invokes `onTimeout` once `duration` has
// elapsed, unless Stop is called first. `timeoutErr` is the error returned by
// ResolveError after the timer has triggered.
func newRequestTimer(duration time.Duration, onTimeout func(), timeoutErr error) *requestTimer {
	rt := &requestTimer{
		onTimeout:  onTimeout,
		timeoutErr: timeoutErr,
		status:     requestTimerNew,
	}
	rt.timer = time.AfterFunc(duration, rt.onTriggered)
	return rt
}

// onTriggered runs when the underlying timer expires. It invokes `onTimeout`
// at most once, and only if the timer was not stopped beforehand.
func (rt *requestTimer) onTriggered() {
	// The callback is invoked after the mutex has been released: sync.Mutex is
	// not reentrant, so a callback that calls Stop or ResolveError on this
	// timer would otherwise deadlock.
	if rt.markTriggered() {
		rt.onTimeout()
	}
}

// markTriggered moves the timer from the new state to the triggered state. It
// returns true only for the call that performed the transition.
func (rt *requestTimer) markTriggered() bool {
	rt.mu.Lock()
	defer rt.mu.Unlock()
	if rt.status != requestTimerNew {
		return false
	}
	rt.status = requestTimerTriggered
	return true
}

// Stop should be called upon a successful request to prevent the timer from
// expiring. It has no effect if the timer was already stopped or triggered.
func (rt *requestTimer) Stop() {
	rt.mu.Lock()
	defer rt.mu.Unlock()
	if rt.status == requestTimerNew {
		rt.status = requestTimerStopped
		rt.timer.Stop()
	}
}

// ResolveError returns `timeoutErr` if the timer triggered, or otherwise
// `originalErr`.
func (rt *requestTimer) ResolveError(originalErr error) error {
	rt.mu.Lock()
	defer rt.mu.Unlock()
	if rt.status == requestTimerTriggered {
		return rt.timeoutErr
	}
	return originalErr
}
@@ -0,0 +1,61 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package wire

import (
"errors"
"testing"
"time"

"cloud.google.com/go/pubsublite/internal/test"
)

// TestRequestTimerStop verifies that a timer stopped before its deadline never
// invokes onTimeout and that ResolveError returns the original error unchanged.
func TestRequestTimerStop(t *testing.T) {
	const timeout = 5 * time.Millisecond
	onTimeout := func() {
		t.Error("onTimeout should not be called")
	}

	rt := newRequestTimer(timeout, onTimeout, errors.New("unused"))
	rt.Stop()
	// Wait past the deadline so that a timer which was not actually stopped
	// would have had the chance to fire.
	time.Sleep(2 * timeout)

	if err := rt.ResolveError(nil); err != nil {
		t.Errorf("ResolveError() got err: %v, want err: <nil>", err)
	}
	wantErr := errors.New("original error")
	if gotErr := rt.ResolveError(wantErr); !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("ResolveError() got err: %v, want err: %v", gotErr, wantErr)
	}
}

// TestRequestTimerExpires verifies that an unstopped timer invokes onTimeout
// and that ResolveError then returns the timeout error regardless of the
// original error passed in.
func TestRequestTimerExpires(t *testing.T) {
	const timeout = 5 * time.Millisecond
	wantErr := errors.New("on timeout")

	fired := test.NewCondition("request timer expired")
	rt := newRequestTimer(timeout, func() { fired.SetDone() }, wantErr)
	fired.WaitUntilDone(t, serviceTestWaitTimeout)

	// Both a nil and a non-nil original error must be replaced by the timeout
	// error once the timer has triggered.
	for _, originalErr := range []error{nil, errors.New("ignored")} {
		if gotErr := rt.ResolveError(originalErr); !test.ErrorEqual(gotErr, wantErr) {
			t.Errorf("ResolveError() got err: %v, want err: %v", gotErr, wantErr)
		}
	}
}

0 comments on commit e75262c

Please sign in to comment.