Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure transactions rollback on failure #767

Merged
merged 16 commits into from Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
92 changes: 36 additions & 56 deletions google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -105,14 +105,18 @@ async def _begin(self, retry_id: bytes = None) -> None:
)
self._id = transaction_response.transaction

async def _rollback(self) -> None:
async def _rollback(self, source_exc=None) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems unfortunate that we have to add an additional parameter to this function. I am guessing the underscore prefix means it is private? If so, I guess it's not that big of a deal, but if this argument is only being used for that one case in the call function then could we catch errors there caused by await transaction._rollback(source_exc=exc) and do raise exc from source_exc there instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not worried about the new parameter because 1) this is a private function, and 2) it's an optional paramater, so existing behaviour would remain consistent

But that's a good point that after this refactor, the rollback only happens in one place. In this case, we can rely on built-in exception chaining instead of raise from, since the exception occurs while handling a different exception. I simplified the code to remove the source_exc

"""Roll back the transaction.

Args:
source_exc (Optional[Exception]): The exception that caused the
rollback to occur. If an exception is created while rolling
back, it will be chained to this one.
Raises:
ValueError: If no transaction is in progress.
"""
if not self.in_progress:
raise ValueError(_CANT_ROLLBACK)
raise ValueError(_CANT_ROLLBACK) from source_exc

try:
# NOTE: The response is just ``google.protobuf.Empty``.
Expand All @@ -123,6 +127,9 @@ async def _rollback(self) -> None:
},
metadata=self._client._rpc_metadata,
)
except Exception as exc: # pylint: disable=broad-except
# attach source_exc to the exception raised by rollback
raise exc from source_exc
finally:
self._clean_up()

Expand Down Expand Up @@ -223,10 +230,6 @@ async def _pre_commit(
) -> Coroutine:
"""Begin transaction and call the wrapped coroutine.

If the coroutine raises an exception, the transaction will be rolled
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
it will have staged writes).

Args:
transaction
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
Expand All @@ -250,41 +253,7 @@ async def _pre_commit(
self.current_id = transaction._id
if self.retry_id is None:
self.retry_id = self.current_id
try:
return await self.to_wrap(transaction, *args, **kwargs)
except: # noqa
# NOTE: If ``rollback`` fails this will lose the information
# from the original failure.
await transaction._rollback()
raise

async def _maybe_commit(self, transaction: AsyncTransaction) -> bool:
"""Try to commit the transaction.

If the transaction is read-write and the ``Commit`` fails with the
``ABORTED`` status code, it will be retried. Any other failure will
not be caught.

Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
The transaction to be ``Commit``-ed.

Returns:
bool: Indicating if the commit succeeded.
"""
try:
await transaction._commit()
return True
except exceptions.GoogleAPICallError as exc:
if transaction._read_only:
raise

if isinstance(exc, exceptions.Aborted):
# If a read-write transaction returns ABORTED, retry.
return False
else:
raise
return await self.to_wrap(transaction, *args, **kwargs)

async def __call__(self, transaction, *args, **kwargs):
"""Execute the wrapped callable within a transaction.
Expand All @@ -306,22 +275,33 @@ async def __call__(self, transaction, *args, **kwargs):
``max_attempts``.
"""
self._reset()
retryable_exceptions = (
(exceptions.Aborted) if not transaction._read_only else ()
)
last_exc = None

for attempt in range(transaction._max_attempts):
result = await self._pre_commit(transaction, *args, **kwargs)
succeeded = await self._maybe_commit(transaction)
if succeeded:
return result

# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.

await transaction._rollback()
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg)
try:
for attempt in range(transaction._max_attempts):
result = await self._pre_commit(transaction, *args, **kwargs)
try:
await transaction._commit()
return result
except retryable_exceptions as exc:
last_exc = exc
# Retry attempts that result in retryable exceptions
# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.
# retries exhausted
# wrap the last exception in a ValueError before raising
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg) from last_exc

except BaseException as exc:
await transaction._rollback(source_exc=exc)
raise exc


def async_transactional(
Expand Down
3 changes: 0 additions & 3 deletions google/cloud/firestore_v1/base_transaction.py
Expand Up @@ -185,8 +185,5 @@ def _reset(self) -> None:
def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn:
raise NotImplementedError

def _maybe_commit(self, transaction) -> NoReturn:
raise NotImplementedError

def __call__(self, transaction, *args, **kwargs):
raise NotImplementedError
94 changes: 37 additions & 57 deletions google/cloud/firestore_v1/transaction.py
Expand Up @@ -44,7 +44,7 @@
# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.types import CommitResponse
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Generator


class Transaction(batch.WriteBatch, BaseTransaction):
Expand Down Expand Up @@ -103,14 +103,18 @@ def _begin(self, retry_id: bytes = None) -> None:
)
self._id = transaction_response.transaction

def _rollback(self) -> None:
def _rollback(self, source_exc=None) -> None:
"""Roll back the transaction.

Args:
source_exc (Optional[Exception]): The exception that caused the
rollback to occur. If an exception is created while rolling
back, it will be chained to this one.
Raises:
ValueError: If no transaction is in progress.
"""
if not self.in_progress:
raise ValueError(_CANT_ROLLBACK)
raise ValueError(_CANT_ROLLBACK) from source_exc

try:
# NOTE: The response is just ``google.protobuf.Empty``.
Expand All @@ -121,6 +125,9 @@ def _rollback(self) -> None:
},
metadata=self._client._rpc_metadata,
)
except Exception as exc:
# attach source_exc to the exception raised by rollback
raise exc from source_exc
finally:
self._clean_up()

Expand Down Expand Up @@ -214,10 +221,6 @@ def __init__(self, to_wrap) -> None:
def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
"""Begin transaction and call the wrapped callable.

If the callable raises an exception, the transaction will be rolled
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
it will have staged writes).

Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
Expand All @@ -241,41 +244,7 @@ def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
self.current_id = transaction._id
if self.retry_id is None:
self.retry_id = self.current_id
try:
return self.to_wrap(transaction, *args, **kwargs)
except: # noqa
# NOTE: If ``rollback`` fails this will lose the information
# from the original failure.
transaction._rollback()
raise

def _maybe_commit(self, transaction: Transaction) -> Optional[bool]:
"""Try to commit the transaction.

If the transaction is read-write and the ``Commit`` fails with the
``ABORTED`` status code, it will be retried. Any other failure will
not be caught.

Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
The transaction to be ``Commit``-ed.

Returns:
bool: Indicating if the commit succeeded.
"""
try:
transaction._commit()
return True
except exceptions.GoogleAPICallError as exc:
if transaction._read_only:
raise

if isinstance(exc, exceptions.Aborted):
# If a read-write transaction returns ABORTED, retry.
return False
else:
raise
return self.to_wrap(transaction, *args, **kwargs)

def __call__(self, transaction: Transaction, *args, **kwargs):
"""Execute the wrapped callable within a transaction.
Expand All @@ -297,22 +266,33 @@ def __call__(self, transaction: Transaction, *args, **kwargs):
``max_attempts``.
"""
self._reset()
retryable_exceptions = (
(exceptions.Aborted) if not transaction._read_only else ()
)
last_exc = None

for attempt in range(transaction._max_attempts):
result = self._pre_commit(transaction, *args, **kwargs)
succeeded = self._maybe_commit(transaction)
if succeeded:
return result

# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.

transaction._rollback()
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg)
try:
for attempt in range(transaction._max_attempts):
result = self._pre_commit(transaction, *args, **kwargs)
try:
transaction._commit()
return result
except retryable_exceptions as exc:
last_exc = exc
# Retry attempts that result in retryable exceptions
# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.
# retries exhausted
# wrap the last exception in a ValueError before raising
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg) from last_exc
except BaseException as exc: # noqa: B901
# rollback the transaction on any error
transaction._rollback(source_exc=exc)
raise exc


def transactional(to_wrap: Callable) -> _Transactional:
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/firestore_v1/watch.py
Expand Up @@ -401,7 +401,9 @@ def _on_snapshot_target_change_remove(self, target_change):

error_message = "Error %s: %s" % (code, message)

raise RuntimeError(error_message)
raise RuntimeError(error_message) from exceptions.from_grpc_status(
code, message
)

def _on_snapshot_target_change_reset(self, target_change):
# Whatever changes have happened so far no longer matter.
Expand Down