Skip to content

Commit

Permalink
fix: ensure transactions rollback on failure (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 29, 2023
1 parent 9840d43 commit cdaf25b
Show file tree
Hide file tree
Showing 6 changed files with 406 additions and 493 deletions.
85 changes: 31 additions & 54 deletions google/cloud/firestore_v1/async_transaction.py
Expand Up @@ -110,6 +110,7 @@ async def _rollback(self) -> None:
Raises:
ValueError: If no transaction is in progress.
google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
"""
if not self.in_progress:
raise ValueError(_CANT_ROLLBACK)
Expand All @@ -124,6 +125,7 @@ async def _rollback(self) -> None:
metadata=self._client._rpc_metadata,
)
finally:
# clean up, even if rollback fails
self._clean_up()

async def _commit(self) -> list:
Expand Down Expand Up @@ -223,10 +225,6 @@ async def _pre_commit(
) -> Coroutine:
"""Begin transaction and call the wrapped coroutine.
If the coroutine raises an exception, the transaction will be rolled
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
it will have staged writes).
Args:
transaction
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
Expand All @@ -250,41 +248,7 @@ async def _pre_commit(
self.current_id = transaction._id
if self.retry_id is None:
self.retry_id = self.current_id
try:
return await self.to_wrap(transaction, *args, **kwargs)
except: # noqa
# NOTE: If ``rollback`` fails this will lose the information
# from the original failure.
await transaction._rollback()
raise

async def _maybe_commit(self, transaction: AsyncTransaction) -> bool:
"""Try to commit the transaction.
If the transaction is read-write and the ``Commit`` fails with the
``ABORTED`` status code, it will be retried. Any other failure will
not be caught.
Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
The transaction to be ``Commit``-ed.
Returns:
bool: Indicating if the commit succeeded.
"""
try:
await transaction._commit()
return True
except exceptions.GoogleAPICallError as exc:
if transaction._read_only:
raise

if isinstance(exc, exceptions.Aborted):
# If a read-write transaction returns ABORTED, retry.
return False
else:
raise
return await self.to_wrap(transaction, *args, **kwargs)

async def __call__(self, transaction, *args, **kwargs):
"""Execute the wrapped callable within a transaction.
Expand All @@ -306,22 +270,35 @@ async def __call__(self, transaction, *args, **kwargs):
``max_attempts``.
"""
self._reset()
retryable_exceptions = (
(exceptions.Aborted) if not transaction._read_only else ()
)
last_exc = None

for attempt in range(transaction._max_attempts):
result = await self._pre_commit(transaction, *args, **kwargs)
succeeded = await self._maybe_commit(transaction)
if succeeded:
return result

# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.

await transaction._rollback()
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg)
try:
for attempt in range(transaction._max_attempts):
result = await self._pre_commit(transaction, *args, **kwargs)
try:
await transaction._commit()
return result
except retryable_exceptions as exc:
last_exc = exc
# Retry attempts that result in retryable exceptions
# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.
# retries exhausted
# wrap the last exception in a ValueError before raising
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg) from last_exc

except BaseException:
# rollback the transaction on any error
# errors raised during _rollback will be chained to the original error through __context__
await transaction._rollback()
raise


def async_transactional(
Expand Down
3 changes: 0 additions & 3 deletions google/cloud/firestore_v1/base_transaction.py
Expand Up @@ -185,8 +185,5 @@ def _reset(self) -> None:
def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn:
raise NotImplementedError

def _maybe_commit(self, transaction) -> NoReturn:
raise NotImplementedError

def __call__(self, transaction, *args, **kwargs):
raise NotImplementedError
86 changes: 31 additions & 55 deletions google/cloud/firestore_v1/transaction.py
Expand Up @@ -44,7 +44,7 @@
# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.types import CommitResponse
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Generator


class Transaction(batch.WriteBatch, BaseTransaction):
Expand Down Expand Up @@ -108,6 +108,7 @@ def _rollback(self) -> None:
Raises:
ValueError: If no transaction is in progress.
google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
"""
if not self.in_progress:
raise ValueError(_CANT_ROLLBACK)
Expand All @@ -122,6 +123,7 @@ def _rollback(self) -> None:
metadata=self._client._rpc_metadata,
)
finally:
# clean up, even if rollback fails
self._clean_up()

def _commit(self) -> list:
Expand Down Expand Up @@ -214,10 +216,6 @@ def __init__(self, to_wrap) -> None:
def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
"""Begin transaction and call the wrapped callable.
If the callable raises an exception, the transaction will be rolled
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
it will have staged writes).
Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
Expand All @@ -241,41 +239,7 @@ def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any:
self.current_id = transaction._id
if self.retry_id is None:
self.retry_id = self.current_id
try:
return self.to_wrap(transaction, *args, **kwargs)
except: # noqa
# NOTE: If ``rollback`` fails this will lose the information
# from the original failure.
transaction._rollback()
raise

def _maybe_commit(self, transaction: Transaction) -> Optional[bool]:
"""Try to commit the transaction.
If the transaction is read-write and the ``Commit`` fails with the
``ABORTED`` status code, it will be retried. Any other failure will
not be caught.
Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
The transaction to be ``Commit``-ed.
Returns:
bool: Indicating if the commit succeeded.
"""
try:
transaction._commit()
return True
except exceptions.GoogleAPICallError as exc:
if transaction._read_only:
raise

if isinstance(exc, exceptions.Aborted):
# If a read-write transaction returns ABORTED, retry.
return False
else:
raise
return self.to_wrap(transaction, *args, **kwargs)

def __call__(self, transaction: Transaction, *args, **kwargs):
"""Execute the wrapped callable within a transaction.
Expand All @@ -297,22 +261,34 @@ def __call__(self, transaction: Transaction, *args, **kwargs):
``max_attempts``.
"""
self._reset()
retryable_exceptions = (
(exceptions.Aborted) if not transaction._read_only else ()
)
last_exc = None

for attempt in range(transaction._max_attempts):
result = self._pre_commit(transaction, *args, **kwargs)
succeeded = self._maybe_commit(transaction)
if succeeded:
return result

# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.

transaction._rollback()
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg)
try:
for attempt in range(transaction._max_attempts):
result = self._pre_commit(transaction, *args, **kwargs)
try:
transaction._commit()
return result
except retryable_exceptions as exc:
last_exc = exc
# Retry attempts that result in retryable exceptions
# Subsequent requests will use the failed transaction ID as part of
# the ``BeginTransactionRequest`` when restarting this transaction
# (via ``options.retry_transaction``). This preserves the "spot in
# line" of the transaction, so exponential backoff is not required
# in this case.
# retries exhausted
# wrap the last exception in a ValueError before raising
msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
raise ValueError(msg) from last_exc
except BaseException: # noqa: B901
# rollback the transaction on any error
# errors raised during _rollback will be chained to the original error through __context__
transaction._rollback()
raise


def transactional(to_wrap: Callable) -> _Transactional:
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/firestore_v1/watch.py
Expand Up @@ -401,7 +401,9 @@ def _on_snapshot_target_change_remove(self, target_change):

error_message = "Error %s: %s" % (code, message)

raise RuntimeError(error_message)
raise RuntimeError(error_message) from exceptions.from_grpc_status(
code, message
)

def _on_snapshot_target_change_reset(self, target_change):
# Whatever changes have happened so far no longer matter.
Expand Down

0 comments on commit cdaf25b

Please sign in to comment.