diff --git a/packages/google-cloud-bigquery-storage/google/cloud/bigquery_storage_v1/writer.py b/packages/google-cloud-bigquery-storage/google/cloud/bigquery_storage_v1/writer.py index 4f0da5e1d933..b7fe68c26f9a 100644 --- a/packages/google-cloud-bigquery-storage/google/cloud/bigquery_storage_v1/writer.py +++ b/packages/google-cloud-bigquery-storage/google/cloud/bigquery_storage_v1/writer.py @@ -43,6 +43,17 @@ # but still work for all expected requests _DEFAULT_TIMEOUT = 600 +# Errors that indicate a transient connectivity failure. The stream can be +# transparently reconnected and in-flight requests replayed on these errors. +_STREAM_RESUMPTION_EXCEPTIONS = ( + exceptions.ServiceUnavailable, + exceptions.Unknown, +) + + +def _is_retryable_error(reason: Optional[BaseException]) -> bool: + return isinstance(reason, _STREAM_RESUMPTION_EXCEPTIONS) + def _wrap_as_exception(maybe_exception) -> BaseException: """Wrap an object as a Python exception, if needed. @@ -191,14 +202,13 @@ def close(self, reason: Optional[Exception] = None) -> None: def _renew_connection(self, reason: Optional[Exception] = None) -> None: """Helper function that is called when the RPC connection is closed without recovery. It first creates a new Connection instance in an - atomic manner, and then cleans up the failed connection. Note that a - new RPC connection is not established by instantiating _Connection, - but only when `send()` is called for the first time. + atomic manner, and then cleans up the failed connection. + + On transient errors (:data:`_STREAM_RESUMPTION_EXCEPTIONS`) any + in-flight requests are replayed on the new connection so that callers + do not need to handle reconnection themselves. On non-transient errors + the pending futures are failed immediately as before. """ - # Creates a new Connection instance, but doesn't establish a new RPC - # connection. New connection is only started when `send()` is called - # again, in order to save resource if the stream is idle. This action - # is atomic. with self._thread_lock: _closed_connection = self._connection self._connection = _Connection( @@ -206,10 +216,22 @@ def _renew_connection(self, reason: Optional[Exception] = None) -> None: writer=self, metadata=self._metadata, ) - - # Cleanup, and marks futures as failed. To minimize the length of the - # critical section, this step is not guaranteed to be atomic. - _closed_connection._shutdown(reason=reason) + # Copy the stream name so the new connection can build routing + # metadata even before the first send(). + self._connection._stream_name = self._stream_name + + # Shutdown the old connection. On transient errors this returns the + # in-flight (request, future) pairs so we can replay them; on + # non-transient errors it returns an empty list after failing futures. + pending = _closed_connection._shutdown(reason=reason) + + if pending: + _LOGGER.debug( + "Replaying %d in-flight request(s) after transient error: %s", + len(pending), + reason, + ) + self._connection._reopen_with_pending(pending) def _on_rpc_done(self, reason: Optional[BaseException] = None) -> None: """Callback passecd to _Connection. It's called when the RPC connection @@ -257,7 +279,9 @@ def __init__( self._rpc: Union[bidi.BidiRpc | None] = None self._consumer: Union[bidi.BackgroundConsumer | None] = None self._stream_name: str = "" - self._queue: queue.Queue[AppendRowsFuture] = queue.Queue() + self._queue: queue.Queue[ + Tuple[gapic_types.AppendRowsRequest, AppendRowsFuture] + ] = queue.Queue() # statuses self._closed = False @@ -314,7 +338,7 @@ def _open( request = self._make_initial_request(initial_request) future = AppendRowsFuture(self._writer) - self._queue.put(future) + self._queue.put((initial_request, future)) self._rpc = bidi.BidiRpc( self._client.append_rows, @@ -428,22 +452,32 @@ def send(self, request: gapic_types.AppendRowsRequest) -> AppendRowsFuture: # future to the queue so that when the response comes, the callback can # pull it off and notify completion. future = AppendRowsFuture(self._writer) - self._queue.put(future) + self._queue.put((request, future)) if self._rpc is not None: self._rpc.send(request) return future - def _shutdown(self, reason: Optional[Exception] = None) -> None: + def _shutdown( + self, reason: Optional[Exception] = None + ) -> List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]]: """Run the actual shutdown sequence (stop the stream and all helper threads). Args: reason: The reason to close the stream. If ``None``, this is considered an "intentional" shutdown. + + Returns: + A list of ``(request, future)`` pairs for requests that were + in-flight when the connection closed. On transient errors these + are returned instead of being failed so the caller can replay + them on a new connection. On non-transient errors the list is + always empty (futures are failed immediately). """ + pending: List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]] = [] with self._thread_lock: if self._closed: - return + return pending # Stop consuming messages. if self.is_active: @@ -459,19 +493,25 @@ def _shutdown(self, reason: Optional[Exception] = None) -> None: # We know that no new items will be added to the queue because # we've marked the stream as closed. + retryable = _is_retryable_error(reason) while not self._queue.empty(): - # Mark each future as failed. Since the consumer thread has - # stopped (or at least is attempting to stop), we won't get - # response callbacks to populate the remaining futures. - future = self._queue.get_nowait() - exc: Union[Exception, bqstorage_exceptions.StreamClosedError] - if reason is None: - exc = bqstorage_exceptions.StreamClosedError( - "Stream closed before receiving a response." - ) + # On transient errors, collect in-flight requests so they can + # be replayed on a fresh connection instead of surfacing an + # error to the caller. + request, future = self._queue.get_nowait() + if retryable: + pending.append((request, future)) else: - exc = reason - future.set_exception(exc) + exc: Union[Exception, bqstorage_exceptions.StreamClosedError] + if reason is None: + exc = bqstorage_exceptions.StreamClosedError( + "Stream closed before receiving a response." + ) + else: + exc = reason + future.set_exception(exc) + + return pending def close(self, reason: Optional[Exception] = None) -> None: """Stop consuming messages and shutdown all helper threads. @@ -496,7 +536,7 @@ def _on_response(self, response: gapic_types.AppendRowsResponse) -> None: # Since we have 1 response per request, if we get here from a response # callback, the queue should never be empty. - future: AppendRowsFuture = self._queue.get_nowait() + _, future = self._queue.get_nowait() if response.error.code: exc = exceptions.from_grpc_status( response.error.code, response.error.message, response=response @@ -505,6 +545,86 @@ def _on_response(self, response: gapic_types.AppendRowsResponse) -> None: else: future.set_result(response) + def _reopen_with_pending( + self, + pending: List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]], + timeout: float = _DEFAULT_TIMEOUT, + ) -> None: + """Open a fresh RPC connection and replay ``pending`` in-flight requests. + + The existing :class:`AppendRowsFuture` objects are reused so callers + that already hold references transparently receive their results once + the server acknowledges the replayed requests. + + Args: + pending: + ``(request, future)`` pairs collected from the failed + connection's queue. The first entry is used as the stream's + initial request (merged with the writer template); subsequent + entries are sent in order once the connection is active. + timeout: + How long (in seconds) to wait for the stream to be ready. + """ + if not pending: + return + + initial_user_request, initial_future = pending[0] + + with self._thread_lock: + # Inject the existing future so _on_response resolves it. + self._queue.put((initial_user_request, initial_future)) + + merged = self._make_initial_request(initial_user_request) + metadata = tuple(self._metadata) + ( + ( + "x-goog-request-params", + f"write_stream={self._stream_name}", + ), + ) + rpc = bidi.BidiRpc( + self._client.append_rows, + initial_request=merged, + metadata=metadata, + ) + rpc.add_done_callback(self._on_rpc_done) + + consumer = bidi.BackgroundConsumer(rpc, self._on_response) + consumer.start() + + self._rpc = rpc + self._consumer = consumer + + start_time = time.monotonic() + while not rpc.is_active and consumer.is_active: + time.sleep(_WRITE_OPEN_INTERVAL) + if timeout is not None and time.monotonic() - start_time > timeout: + break + + if not consumer.is_active: + # Connection failed — drain the queue and fail futures directly + # rather than going through close() to avoid triggering another + # reconnect attempt (which would cause an infinite retry loop). + exc = exceptions.Unknown( + "There was a problem reopening the stream after a transient error. " + "Try turning on DEBUG level logs to see the error." + ) + with self._thread_lock: + self._closed = True + while not self._queue.empty(): + _, future = self._queue.get_nowait() + if not future.done(): + future.set_exception(exc) + for _, future in pending: + if not future.done(): + future.set_exception(exc) + return + + # Send remaining pending requests over the live connection. + for request, future in pending[1:]: + self._queue.put((request, future)) + if self._rpc is not None: + self._rpc.send(request) + def _on_rpc_done(self, future: AppendRowsFuture) -> None: """Triggered when the underlying RPC terminates without recovery. diff --git a/packages/google-cloud-bigquery-storage/tests/unit/test_writer_v1.py b/packages/google-cloud-bigquery-storage/tests/unit/test_writer_v1.py index d27c5b14123a..1d7d8e27b3f6 100644 --- a/packages/google-cloud-bigquery-storage/tests/unit/test_writer_v1.py +++ b/packages/google-cloud-bigquery-storage/tests/unit/test_writer_v1.py @@ -158,6 +158,8 @@ def test__renew_connection(self): stream._thread_lock = mock_lock reason = Exception("test exception") old_connection = stream._connection = self._make_mock_connection() + # Non-transient error: _shutdown returns no pending pairs. + old_connection._shutdown.return_value = [] stream._renew_connection(reason=reason) @@ -169,6 +171,33 @@ def test__renew_connection(self): mock_lock.__enter__.assert_called_once() mock_lock.__exit__.assert_called_once() + @mock.patch( + "google.cloud.bigquery_storage_v1.writer._Connection._reopen_with_pending", + autospec=True, + ) + def test__renew_connection_replays_on_transient_error(self, reopen_mock): + """On transient errors, pending in-flight requests are replayed.""" + from google.cloud.bigquery_storage_v1 import writer + + mock_client = self._make_mock_client() + stream = self._make_one(mock_client, REQUEST_TEMPLATE) + old_connection = stream._connection = self._make_mock_connection() + + pending_request = gapic_types.AppendRowsRequest( + write_stream="projects/p/datasets/d/tables/t/streams/s" + ) + pending_future = writer.AppendRowsFuture(stream) + old_connection._shutdown.return_value = [(pending_request, pending_future)] + + reason = exceptions.ServiceUnavailable("server unavailable") + stream._renew_connection(reason=reason) + + assert stream._connection is not old_connection + # New connection should have been asked to replay the pending request. + reopen_mock.assert_called_once_with( + stream._connection, [(pending_request, pending_future)] + ) + @mock.patch("threading.Thread", autospec=True) def test__on_rpc_done(self, thread): from google.cloud.bigquery_storage_v1.writer import _RPC_ERROR_THREAD_NAME @@ -382,8 +411,9 @@ def test_close(self, background_consumer, bidi_rpc): connection._rpc = bidi_rpc futures = [writer.AppendRowsFuture(connection._writer) for _ in range(3)] + fake_request = gapic_types.AppendRowsRequest() for f in futures: - connection._queue.put(f) + connection._queue.put((fake_request, f)) close_exception = Exception("test exception") assert connection._closed is False @@ -424,7 +454,8 @@ def test__on_response_exception(self): connection = self._make_one(mock_client, mock_stream) connection._queue = mock.Mock() future = AppendRowsFuture(mock_stream) - connection._queue.get_nowait.return_value = future + fake_request = gapic_types.AppendRowsRequest() + connection._queue.get_nowait.return_value = (fake_request, future) response = gapic_types.AppendRowsResponse( { "error": { @@ -448,7 +479,8 @@ def test__on_response_result(self): connection = self._make_one(mock_client, mock_stream) connection._queue = mock.Mock() future = AppendRowsFuture(mock_stream) - connection._queue.get_nowait.return_value = future + fake_request = gapic_types.AppendRowsRequest() + connection._queue.get_nowait.return_value = (fake_request, future) response = gapic_types.AppendRowsResponse() connection._on_response(response) @@ -472,6 +504,73 @@ def test__on_rpc_done(self): assert isinstance(reason, Exception) assert reason.args[0] is future + @mock.patch("google.api_core.bidi.BidiRpc", autospec=True) + @mock.patch("google.api_core.bidi.BackgroundConsumer", autospec=True) + def test__reopen_with_pending_resolves_futures(self, background_consumer, bidi_rpc): + """_reopen_with_pending replays requests and resolves the existing futures.""" + from google.cloud.bigquery_storage_v1.writer import AppendRowsFuture + + type(bidi_rpc.return_value).is_active = mock.PropertyMock(return_value=True) + type(background_consumer.return_value).is_active = mock.PropertyMock( + return_value=True + ) + + mock_client = self._make_mock_client() + mock_stream = self._make_mock_stream() + connection = self._make_one(mock_client, mock_stream) + connection._stream_name = "projects/p/datasets/d/tables/t/streams/s" + + req1 = gapic_types.AppendRowsRequest( + write_stream="projects/p/datasets/d/tables/t/streams/s" + ) + req2 = gapic_types.AppendRowsRequest( + write_stream="projects/p/datasets/d/tables/t/streams/s" + ) + fut1 = AppendRowsFuture(mock_stream) + fut2 = AppendRowsFuture(mock_stream) + + connection._reopen_with_pending([(req1, fut1), (req2, fut2)]) + + # BidiRpc should have been created with the merged initial request. + bidi_rpc.assert_called_once() + # The second request should have been sent via rpc.send(). + bidi_rpc.return_value.send.assert_called_once_with(req2) + # Both futures should be in the queue awaiting responses. + assert connection._queue.qsize() == 2 + + @mock.patch("google.api_core.bidi.BidiRpc", autospec=True) + @mock.patch("google.api_core.bidi.BackgroundConsumer", autospec=True) + def test__reopen_with_pending_fails_futures_on_connection_failure( + self, background_consumer, bidi_rpc + ): + """When the reconnect itself fails, futures are failed (no infinite retry).""" + from google.cloud.bigquery_storage_v1.writer import AppendRowsFuture + from google.api_core import exceptions as core_exceptions + + # Simulate consumer never becoming active (connection failure). + type(background_consumer.return_value).is_active = mock.PropertyMock( + return_value=False + ) + + mock_client = self._make_mock_client() + mock_stream = self._make_mock_stream() + connection = self._make_one(mock_client, mock_stream) + connection._stream_name = "projects/p/datasets/d/tables/t/streams/s" + + req = gapic_types.AppendRowsRequest( + write_stream="projects/p/datasets/d/tables/t/streams/s" + ) + fut = AppendRowsFuture(mock_stream) + + connection._reopen_with_pending([(req, fut)]) + + # Future should be failed with a non-retryable Unknown error. + assert fut._is_done is True + with pytest.raises(core_exceptions.Unknown): + fut.result() + # Connection should be closed to prevent further use. + assert connection._closed is True + def test__process_request_template(self): from google.cloud.bigquery_storage_v1.writer import _process_request_template