Skip to content

Commit

Permalink
feat: use DML batches in executemany() method (#412)
Browse files Browse the repository at this point in the history
* feat: use mutations for executemany() inserts

* add unit test and fix parsing

* add use_mutations flag into Connection class

* use three-values flag for use_mutations

* update docstrings

* use batch DMLs for executemany() method

* prepare args before inserting into SQL statement

* erase mutation mentions

* next step

* next step

* next step

* fixes

* add unit tests for UPDATE and DELETE statements

* don't propagate errors to users on retry

* lint fixes

* use run_in_transaction

* refactor the tests code

* fix merge conflict

* fix the unit test

* revert some changes

* use executemany for test data insert

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
Ilya Gurov and larkee authored Aug 9, 2021
1 parent a2b81be commit cbb4ee3
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 27 deletions.
50 changes: 34 additions & 16 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

from google.rpc.code_pb2 import ABORTED


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
MAX_INTERNAL_RETRIES = 50
Expand Down Expand Up @@ -175,25 +177,41 @@ def _rerun_previous_statements(self):
from the last transaction.
"""
for statement in self._statements:
res_iter, retried_checksum = self.run_statement(statement, retried=True)
# executing all the completed statements
if statement != self._statements[-1]:
for res in res_iter:
retried_checksum.consume_result(res)

_compare_checksums(statement.checksum, retried_checksum)
# executing the failed statement
if isinstance(statement, list):
statements, checksum = statement

transaction = self.transaction_checkout()
status, res = transaction.batch_update(statements)

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
retried_checksum.consume_result(res)
retried_checksum.consume_result(status.code)

_compare_checksums(checksum, retried_checksum)
else:
# streaming up to the failed result or
# to the end of the streaming iterator
while len(retried_checksum) < len(statement.checksum):
try:
res = next(iter(res_iter))
res_iter, retried_checksum = self.run_statement(statement, retried=True)
# executing all the completed statements
if statement != self._statements[-1]:
for res in res_iter:
retried_checksum.consume_result(res)
except StopIteration:
break

_compare_checksums(statement.checksum, retried_checksum)
_compare_checksums(statement.checksum, retried_checksum)
# executing the failed statement
else:
# streaming up to the failed result or
# to the end of the streaming iterator
while len(retried_checksum) < len(statement.checksum):
try:
res = next(iter(res_iter))
retried_checksum.consume_result(res)
except StopIteration:
break

_compare_checksums(statement.checksum, retried_checksum)

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
Expand Down
58 changes: 55 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

from google.rpc.code_pb2 import ABORTED, OK

_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Expand Down Expand Up @@ -156,6 +158,15 @@ def _do_execute_update(self, transaction, sql, params):

return result

def _do_batch_update(self, transaction, statements, many_result_set):
status, res = transaction.batch_update(statements)
many_result_set.add_iter(res)

if status.code == ABORTED:
raise Aborted(status.details)
elif status.code != OK:
raise OperationalError(status.details)

def execute(self, sql, args=None):
"""Prepares and executes a Spanner database operation.
Expand Down Expand Up @@ -258,9 +269,50 @@ def executemany(self, operation, seq_of_params):

many_result_set = StreamedManyResultSets()

for params in seq_of_params:
self.execute(operation, params)
many_result_set.add_iter(self._itr)
if classification in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
statements = []

for params in seq_of_params:
sql, params = parse_utils.sql_pyformat_args_to_spanner(
operation, params
)
statements.append((sql, params, get_param_types(params)))

if self.connection.autocommit:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
else:
retried = False
while True:
try:
transaction = self.connection.transaction_checkout()

res_checksum = ResultsChecksum()
if not retried:
self.connection._statements.append(
(statements, res_checksum)
)

status, res = transaction.batch_update(statements)
many_result_set.add_iter(res)
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.details)
elif status.code != OK:
raise OperationalError(status.details)
break
except Aborted:
self.connection.retry_transaction()
retried = True

else:
for params in seq_of_params:
self.execute(operation, params)
many_result_set.add_iter(self._itr)

self._result_set = many_result_set
self._itr = many_result_set
Expand Down
16 changes: 8 additions & 8 deletions tests/system/test_system_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,20 @@ def test_execute_many(self):
conn = Connection(Config.INSTANCE, self._db)
cursor = conn.cursor()

cursor.execute(
cursor.executemany(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
(2, 'first-name2', 'last-name2', 'test.email2@example.com')
"""
VALUES (%s, %s, %s, %s)
""",
[
(1, "first-name", "last-name", "test.email@example.com"),
(2, "first-name2", "last-name2", "test.email2@example.com"),
],
)
conn.commit()

cursor.executemany(
"""
SELECT * FROM contacts WHERE contact_id = @a1
""",
({"a1": 1}, {"a1": 2}),
"""SELECT * FROM contacts WHERE contact_id = @a1""", ({"a1": 1}, {"a1": 2}),
)
res = cursor.fetchall()
conn.commit()
Expand Down
Loading

0 comments on commit cbb4ee3

Please sign in to comment.