Skip to content

Commit

Permalink
fix: increment seqno before execute calls to prevent InvalidArgument … (
Browse files Browse the repository at this point in the history
#19)

* fix: increment seqno before execute calls to prevent InvalidArgument errors after a previous error

* make assignments atomic

* add and update tests

* revert snapshot.py change

* formatting

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
larkee and larkee committed Mar 24, 2020
1 parent 13a9027 commit adeacee
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 6 deletions.
18 changes: 12 additions & 6 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ def execute_update(
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = database._instance._client._query_options
Expand All @@ -214,11 +219,9 @@ def execute_update(
param_types=param_types,
query_mode=query_mode,
query_options=query_options,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
return response.stats.row_count_exact

def batch_update(self, statements):
Expand Down Expand Up @@ -259,15 +262,18 @@ def batch_update(self, statements):
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

response = api.execute_batch_dml(
session=self._session.name,
transaction=transaction,
statements=parsed,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
row_counts = [
result_set.stats.row_count_exact for result_set in response.result_sets
]
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def test_execute_sql_other_error(self):
with self.assertRaises(RuntimeError):
list(derived.execute_sql(SQL_QUERY))

self.assertEqual(derived._execute_sql_count, 1)

def test_execute_sql_w_params_wo_param_types(self):
database = _Database()
session = _Session(database)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,19 @@ def test_execute_update_new_transaction(self):
def test_execute_update_w_count(self):
self._execute_update_helper(count=1)

def test_execute_update_error(self):
database = _Database()
database.spanner_api = self._make_spanner_api()
database.spanner_api.execute_sql.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

with self.assertRaises(RuntimeError):
transaction.execute_update(DML_QUERY)

self.assertEqual(transaction._execute_sql_count, 1)

def test_execute_update_w_query_options(self):
from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest

Expand Down Expand Up @@ -513,6 +526,31 @@ def test_batch_update_wo_errors(self):
def test_batch_update_w_errors(self):
self._batch_update_helper(error_after=2, count=1)

def test_batch_update_error(self):
database = _Database()
api = database.spanner_api = self._make_spanner_api()
api.execute_batch_dml.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
insert_param_types = {"pkey": "INT64", "desc": "STRING"}
update_dml = 'UPDATE table SET desc = desc + "-amended"'
delete_dml = "DELETE FROM table WHERE desc IS NULL"

dml_statements = [
(insert_dml, insert_params, insert_param_types),
update_dml,
delete_dml,
]

with self.assertRaises(RuntimeError):
transaction.batch_update(dml_statements)

self.assertEqual(transaction._execute_sql_count, 1)

def test_context_mgr_success(self):
import datetime
from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse
Expand Down

0 comments on commit adeacee

Please sign in to comment.