Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Catch rst stream error for all transactions #934

Merged
merged 8 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import decimal
import math
import time

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -292,3 +293,57 @@ def _metadata_with_prefix(prefix, **kw):
List[Tuple[str, str]]: RPC metadata with supplied prefix
"""
return [("google-cloud-resource-prefix", prefix)]


def _retry(
func,
retry_count=5,
delay=2,
allowed_exceptions=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.

Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

Returns:
The result of the function if it is successful, or raises the last exception if all retries fail.
"""
retries = 0
while retries <= retry_count:
try:
result = func()
except Exception as exc:
if (
allowed_exceptions is None or exc.__class__ in allowed_exceptions
) and retries < retry_count:
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
if (
allowed_exceptions is not None
and allowed_exceptions[exc.__class__] is not None
):
allowed_exceptions[exc.__class__](exc)
time.sleep(delay)
delay = delay * 2
retries = retries + 1
else:
raise exc
else:
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
return result


def _check_rst_stream_error(exc):
resumable_error = (
any(
resumable_message in exc.message
for resumable_message in (
"RST_STREAM",
"Received unexpected EOS on DATA frame from server",
)
),
)
if not resumable_error:
raise
11 changes: 10 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Context manager for Cloud Spanner batched writes."""
import functools

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
Expand All @@ -23,6 +24,9 @@
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -179,10 +183,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed
Expand Down
23 changes: 20 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _retry
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -545,12 +547,17 @@ def partition_read(
with trace_call(
"CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
):
response = api.partition_read(
method = functools.partial(
api.partition_read,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -640,12 +647,17 @@ def partition_query(
self._session,
trace_attributes,
):
response = api.partition_query(
method = functools.partial(
api.partition_query,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -768,10 +780,15 @@ def begin(self):
metadata = _metadata_with_prefix(database.name)
txn_selector = self._make_txn_selector()
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_selector.begin,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
34 changes: 29 additions & 5 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _retry
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError


class Transaction(_SnapshotBase, _BatchBase):
Expand Down Expand Up @@ -100,7 +103,11 @@ def _execute_request(
transaction = self._make_txn_selector()
request.transaction = transaction
with trace_call(trace_name, session, attributes):
response = method(request=request)
method = functools.partial(method, request=request)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return response

Expand All @@ -126,8 +133,15 @@ def begin(self):
metadata = _metadata_with_prefix(database.name)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
session=self._session.name, options=txn_options, metadata=metadata
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_options,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
Expand All @@ -141,11 +155,16 @@ def rollback(self):
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
with trace_call("CloudSpanner.Rollback", self._session):
api.rollback(
method = functools.partial(
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=metadata,
)
_retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.rolled_back = True
del self._session._transaction

Expand Down Expand Up @@ -196,10 +215,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test__session_checkout(self, mock_database):
connection._session_checkout()
self.assertEqual(connection._session, "db_session")

def test__session_checkout_database_error(self):
def test_session_checkout_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand All @@ -190,7 +190,7 @@ def test__release_session(self, mock_database):
pool.put.assert_called_once_with("session")
self.assertIsNone(connection._session)

def test__release_session_database_error(self):
def test_release_session_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand Down
74 changes: 74 additions & 0 deletions tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import unittest
import mock


class Test_merge_query_options(unittest.TestCase):
Expand Down Expand Up @@ -669,3 +670,76 @@ def test(self):
prefix = "prefix"
metadata = self._call_fut(prefix)
self.assertEqual(metadata, [("google-cloud-resource-prefix", prefix)])


class Test_retry(unittest.TestCase):
class test_class:
def test_fxn(self):
return True

def test_retry_on_error(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
NotFound("testing"),
True,
]

_retry(functools.partial(test_api.test_fxn))

self.assertEqual(test_api.test_fxn.call_count, 3)

def test_retry_allowed_exceptions(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
InternalServerError("testing"),
True,
]

with self.assertRaises(InternalServerError):
_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={NotFound: None},
)
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

def test_retry_count(self):
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
InternalServerError("testing"),
]

with self.assertRaises(InternalServerError):
_retry(functools.partial(test_api.test_fxn), retry_count=1)
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

def test_check_rst_stream_error(self):
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry, _check_rst_stream_error
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("Received unexpected EOS on DATA frame from server"),
InternalServerError("RST_STREAM"),
True,
]

_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

self.assertEqual(test_api.test_fxn.call_count, 3)
19 changes: 19 additions & 0 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,25 @@ def test_begin_w_other_error(self):
attributes=BASE_ATTRIBUTES,
)

def test_begin_w_retry(self):
from google.cloud.spanner_v1 import (
Transaction as TransactionPB,
)
from google.api_core.exceptions import InternalServerError

database = _Database()
api = database.spanner_api = self._make_spanner_api()
database.spanner_api.begin_transaction.side_effect = [
InternalServerError("Received unexpected EOS on DATA frame from server"),
TransactionPB(id=TXN_ID),
]
timestamp = self._makeTimestamp()
session = _Session(database)
snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True)

snapshot.begin()
self.assertEqual(api.begin_transaction.call_count, 2)

def test_begin_ok_exact_staleness(self):
from google.protobuf.duration_pb2 import Duration
from google.cloud.spanner_v1 import (
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ def test_begin_ok(self):
"CloudSpanner.BeginTransaction", attributes=TestTransaction.BASE_ATTRIBUTES
)

def test_begin_w_retry(self):
asthamohta marked this conversation as resolved.
Show resolved Hide resolved
from google.cloud.spanner_v1 import (
Transaction as TransactionPB,
)
from google.api_core.exceptions import InternalServerError

database = _Database()
api = database.spanner_api = self._make_spanner_api()
database.spanner_api.begin_transaction.side_effect = [
InternalServerError("Received unexpected EOS on DATA frame from server"),
TransactionPB(id=self.TRANSACTION_ID),
]

session = _Session(database)
transaction = self._make_one(session)
transaction.begin()

self.assertEqual(api.begin_transaction.call_count, 2)

def test_rollback_not_begun(self):
database = _Database()
api = database.spanner_api = self._make_spanner_api()
Expand Down