diff --git a/tests/unit/test_client_v1.py b/tests/unit/test_read_client_v1.py similarity index 71% rename from tests/unit/test_client_v1.py rename to tests/unit/test_read_client_v1.py index f214b877..0c3e1dc1 100644 --- a/tests/unit/test_client_v1.py +++ b/tests/unit/test_read_client_v1.py @@ -15,11 +15,12 @@ import importlib from unittest import mock +import google.api_core.exceptions from google.api_core.gapic_v1 import client_info from google.auth import credentials import pytest -from google.cloud.bigquery_storage import types +from google.cloud.bigquery_storage_v1 import types PROJECT = "my-project" SERVICE_ACCOUNT_PROJECT = "project-from-credentials" @@ -29,20 +30,15 @@ def mock_transport(monkeypatch): from google.cloud.bigquery_storage_v1.services.big_query_read import transports - fake_create_session_rpc = mock.Mock(name="create_read_session_rpc") - fake_read_rows_rpc = mock.Mock(name="read_rows_rpc") - transport = mock.create_autospec( transports.grpc.BigQueryReadGrpcTransport, instance=True ) transport.create_read_session = mock.Mock(name="fake_create_read_session") transport.read_rows = mock.Mock(name="fake_read_rows") - - transport._wrapped_methods = { - transport.create_read_session: fake_create_session_rpc, - transport.read_rows: fake_read_rows_rpc, - } + transports.grpc.BigQueryReadGrpcTransport._prep_wrapped_messages( + transport, client_info.ClientInfo() + ) # _credentials property for TPC support transport._credentials = "" @@ -85,8 +81,11 @@ def __init__(self, *args, **kwargs): def test_create_read_session(mock_transport, client_under_test): - assert client_under_test._transport is mock_transport # sanity check + # validate test assumptions + assert client_under_test._transport is mock_transport + rpc_callable = mock.Mock() + mock_transport._wrapped_methods[mock_transport.create_read_session] = rpc_callable table = "projects/{}/datasets/{}/tables/{}".format( "data-project-id", "dataset_id", "table_id" ) @@ -101,12 +100,47 @@ def test_create_read_session(mock_transport, client_under_test): expected_session_arg = types.CreateReadSessionRequest( parent="projects/other-project", read_session=read_session ) - rpc_callable = mock_transport._wrapped_methods[mock_transport.create_read_session] rpc_callable.assert_called_once_with( expected_session_arg, metadata=mock.ANY, retry=mock.ANY, timeout=mock.ANY ) +def test_create_read_session_retries_serviceunavailable( + mock_transport, client_under_test +): + """Regression test for https://github.com/googleapis/python-bigquery-storage/issues/969.""" + # validate test assumptions + assert client_under_test._transport is mock_transport + + mock_transport.create_read_session.side_effect = [ + google.api_core.exceptions.ServiceUnavailable("connection reset"), + google.api_core.exceptions.ServiceUnavailable("connection reset"), + types.ReadSession(), + ] + table = "projects/{}/datasets/{}/tables/{}".format( + "data-project-id", "dataset_id", "table_id" + ) + read_session = types.ReadSession() + read_session.table = table + + # with pytest.raises(google.api_core.exceptions.ServiceUnavailable): + client_under_test.create_read_session( + parent="projects/other-project", read_session=read_session + ) + + expected_session_arg = types.CreateReadSessionRequest( + parent="projects/other-project", read_session=read_session + ) + expected_call = mock.call(expected_session_arg, metadata=mock.ANY, timeout=mock.ANY) + mock_transport.create_read_session.assert_has_calls( + [ + expected_call, + expected_call, + expected_call, + ] + ) + + def test_read_rows(mock_transport, client_under_test): stream_name = "teststream" offset = 0