Skip to content

Commit

Permalink
feat: Fixing and refactoring transaction retry logic in dbapi. Also a…
Browse files Browse the repository at this point in the history
…dding interceptors support for testing (#1056)

* feat: Fixing and refactoring transaction retry logic in dbapi. Also adding interceptors support for testing

* Comments incorporated and changes for also storing Cursor object with the statements details added for retry

* Some refactoring of transaction_helper.py and maintaining state of rows update count for batch dml in cursor

* Small fix

* Maintaining a map from cursor to last statement added in transaction_helper.py

* Rolling back the transaction when Aborted exception is thrown from interceptor

* Small change

* Disabling a test for emulator run

* Reformatting
  • Loading branch information
ankiaga committed Jan 17, 2024
1 parent 7ada21c commit 6640888
Show file tree
Hide file tree
Showing 16 changed files with 1,812 additions and 988 deletions.
25 changes: 12 additions & 13 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
Expand Up @@ -16,7 +16,6 @@

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Expand Down Expand Up @@ -80,8 +79,10 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
"""
from google.cloud.spanner_dbapi import OperationalError

connection = cursor.connection
many_result_set = StreamedManyResultSets()
if not statements:
return many_result_set
connection = cursor.connection
statements_tuple = []
for statement in statements:
statements_tuple.append(statement.get_tuple())
Expand All @@ -90,28 +91,26 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
retried = False
while True:
try:
transaction = connection.transaction_checkout()
status, res = transaction.batch_update(statements_tuple)
many_result_set.add_iter(res)
res_checksum = ResultsChecksum()
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
if not retried:
connection._statements.append((statements, res_checksum))
cursor._row_count = sum([max(val, 0) for val in res])

if status.code == ABORTED:
connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)

cursor._batch_dml_rows_count = res
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
return many_result_set
except Aborted:
connection.retry_transaction()
retried = True
# We are raising it so it could be handled in transaction_helper.py and is retried
if cursor._in_retry_mode:
raise
else:
connection._transaction_helper.retry_transaction()


def _do_batch_update(transaction, statements):
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/spanner_dbapi/checksum.py
Expand Up @@ -62,6 +62,8 @@ def consume_result(self, result):


def _compare_checksums(original, retried):
from google.cloud.spanner_dbapi.transaction_helper import RETRY_ABORTED_ERROR

"""Compare the given checksums.
Raise an error if the given checksums are not equal.
Expand All @@ -75,6 +77,4 @@ def _compare_checksums(original, retried):
:raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal.
"""
if retried != original:
raise RetryAborted(
"The transaction was aborted and could not be retried due to a concurrent modification."
)
raise RetryAborted(RETRY_ABORTED_ERROR)
127 changes: 19 additions & 108 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""
import time
import warnings

from google.api_core.exceptions import Aborted
Expand All @@ -23,19 +22,16 @@
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
Statement,
StatementType,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
OperationalError,
Expand All @@ -44,13 +40,10 @@
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

from google.rpc.code_pb2 import ABORTED


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as a transaction has not been started."
)
MAX_INTERNAL_RETRIES = 50


def check_not_closed(function):
Expand Down Expand Up @@ -106,9 +99,6 @@ def __init__(self, instance, database=None, read_only=False):
self._transaction = None
self._session = None
self._snapshot = None
# SQL statements, which were executed
# within the current transaction
self._statements = []

self.is_closed = False
self._autocommit = False
Expand All @@ -125,6 +115,7 @@ def __init__(self, instance, database=None, read_only=False):
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionRetryHelper(self)

@property
def autocommit(self):
Expand Down Expand Up @@ -288,76 +279,6 @@ def _release_session(self):
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
"""Retry the aborted transaction.
All the statements executed in the original transaction
will be re-executed in new one. Results checksums of the
original statements and the retried ones will be compared.
:raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
If results checksum of the retried statement is
not equal to the checksum of the original one.
"""
attempt = 0
while True:
self._spanner_transaction_started = False
attempt += 1
if attempt > MAX_INTERNAL_RETRIES:
raise

try:
self._rerun_previous_statements()
break
except Aborted as exc:
delay = _get_retry_delay(exc.errors[0], attempt)
if delay:
time.sleep(delay)

def _rerun_previous_statements(self):
"""
Helper to run all the remembered statements
from the last transaction.
"""
for statement in self._statements:
if isinstance(statement, list):
statements, checksum = statement

transaction = self.transaction_checkout()
statements_tuple = []
for single_statement in statements:
statements_tuple.append(single_statement.get_tuple())
status, res = transaction.batch_update(statements_tuple)

if status.code == ABORTED:
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
retried_checksum.consume_result(res)
retried_checksum.consume_result(status.code)

_compare_checksums(checksum, retried_checksum)
else:
res_iter, retried_checksum = self.run_statement(statement, retried=True)
# executing all the completed statements
if statement != self._statements[-1]:
for res in res_iter:
retried_checksum.consume_result(res)

_compare_checksums(statement.checksum, retried_checksum)
# executing the failed statement
else:
# streaming up to the failed result or
# to the end of the streaming iterator
while len(retried_checksum) < len(statement.checksum):
try:
res = next(iter(res_iter))
retried_checksum.consume_result(res)
except StopIteration:
break

_compare_checksums(statement.checksum, retried_checksum)

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
Expand Down Expand Up @@ -433,12 +354,10 @@ def begin(self):

def commit(self):
"""Commits any pending transaction to the database.
This is a no-op if there is no active client transaction.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")

if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
Expand All @@ -450,33 +369,31 @@ def commit(self):
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
except Aborted:
self.retry_transaction()
self._transaction_helper.retry_transaction()
self.commit()
finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False
self._reset_post_commit_or_rollback()

def rollback(self):
"""Rolls back any pending transaction.
This is a no-op if there is no active client transaction.
"""
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.rollback()
finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False
self._reset_post_commit_or_rollback()

def _reset_post_commit_or_rollback(self):
self._release_session()
self._transaction_helper.reset()
self._transaction_begin_marked = False
self._spanner_transaction_started = False

@check_not_closed
def cursor(self):
Expand All @@ -493,7 +410,7 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement: Statement, retried=False):
def run_statement(self, statement: Statement):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
Expand All @@ -513,17 +430,11 @@ def run_statement(self, statement: Statement, retried=False):
checksum of this statement results.
"""
transaction = self.transaction_checkout()
if not retried:
self._statements.append(statement)

return (
transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
),
ResultsChecksum() if retried else statement.checksum,
return transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
)

@check_not_closed
Expand Down

0 comments on commit 6640888

Please sign in to comment.