Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions pubsub/google/cloud/pubsub_v1/subscriber/_protocol/bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def __init__(self, start_rpc, initial_request=None):
self._request_queue = queue.Queue()
self._request_generator = None
self._is_active = False
self.call = None
self._callbacks = []
self.call = None

def add_done_callback(self, callback):
"""Adds a callback that will be called when the RPC terminates.
Expand Down Expand Up @@ -311,14 +311,25 @@ def __init__(self, start_rpc, should_recover, initial_request=None):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request)
self._should_recover = should_recover
self._operational_lock = threading.Lock()
self._finalized = False
self._finalize_lock = threading.Lock()

def _finalize(self, result):
with self._finalize_lock:
if self._finalized:
return

for callback in self._callbacks:
callback(result)

self._finalized = True

def _on_call_done(self, future):
# Unlike the base class, we only execute the callbacks on a terminal
# error, not for errors that we can recover from. Note that grpc's
# "future" here is also a grpc.RpcError.
if not self._should_recover(future):
for callback in self._callbacks:
callback(future)
self._finalize(future)

def _reopen(self):
with self._operational_lock:
Expand All @@ -330,7 +341,14 @@ def _reopen(self):
# Request generator should exit cleanly since the RPC its bound to
# has exited.
self.request_generator = None
self.open()

try:
self.open()
# If re-opening fails, consider this a terminal error and finalize
# the object.
except Exception as exc:
self._finalize(exc)
raise

def _recoverable(self, method, *args, **kwargs):
"""Wraps a method to recover the stream and retry on error.
Expand Down
50 changes: 46 additions & 4 deletions pubsub/tests/unit/pubsub_v1/subscriber/test_bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def test_send_recover(self):
grpc.StreamStreamMultiCallable,
instance=True,
side_effect=[call_1, call_2])
should_recover = mock.Mock(autospec=['__call__'], return_value=True)
should_recover = mock.Mock(spec=['__call__'], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

bidi_rpc.open()
Expand All @@ -331,7 +331,7 @@ def test_send_failure(self):
grpc.StreamStreamMultiCallable,
instance=True,
return_value=call)
should_recover = mock.Mock(autospec=['__call__'], return_value=False)
should_recover = mock.Mock(spec=['__call__'], return_value=False)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

bidi_rpc.open()
Expand All @@ -355,7 +355,7 @@ def test_recv_recover(self):
grpc.StreamStreamMultiCallable,
instance=True,
side_effect=[call_1, call_2])
should_recover = mock.Mock(autospec=['__call__'], return_value=True)
should_recover = mock.Mock(spec=['__call__'], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

bidi_rpc.open()
Expand Down Expand Up @@ -412,7 +412,7 @@ def test_recv_failure(self):
grpc.StreamStreamMultiCallable,
instance=True,
return_value=call)
should_recover = mock.Mock(autospec=['__call__'], return_value=False)
should_recover = mock.Mock(spec=['__call__'], return_value=False)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

bidi_rpc.open()
Expand All @@ -426,6 +426,48 @@ def test_recv_failure(self):
assert bidi_rpc.is_active is False
assert call.cancelled is True

def test_reopen_failure_on_rpc_restart(self):
error1 = ValueError('1')
error2 = ValueError('2')
call = CallStub([error1])
# Invoking start RPC a second time will trigger an error.
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable,
instance=True,
side_effect=[call, error2])
should_recover = mock.Mock(spec=['__call__'], return_value=True)
callback = mock.Mock(spec=['__call__'])

bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
bidi_rpc.add_done_callback(callback)

bidi_rpc.open()

with pytest.raises(ValueError) as exc_info:
bidi_rpc.recv()

assert exc_info.value == error2
should_recover.assert_called_once_with(error1)
assert bidi_rpc.call is None
assert bidi_rpc.is_active is False
callback.assert_called_once_with(error2)

def test_finalize_idempotent(self):
error1 = ValueError('1')
error2 = ValueError('2')
callback = mock.Mock(spec=['__call__'])
should_recover = mock.Mock(spec=['__call__'], return_value=False)

bidi_rpc = bidi.ResumableBidiRpc(
mock.sentinel.start_rpc, should_recover)

bidi_rpc.add_done_callback(callback)

bidi_rpc._on_call_done(error1)
bidi_rpc._on_call_done(error2)

callback.assert_called_once_with(error1)


class TestBackgroundConsumer(object):
def test_consume_once_then_exit(self):
Expand Down