feat(spanner): add implementation and integration tests for max commit delay

harshachinta committed Jan 24, 2024
1 parent f3b23b2 commit 0b51124
Showing 5 changed files with 47 additions and 6 deletions.
google/cloud/spanner_v1/batch.py (2 additions, 1 deletion)
@@ -146,7 +146,7 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

-def commit(self, return_commit_stats=False, request_options=None):
+def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None):
"""Commit mutations to the database.
:type return_commit_stats: bool
@@ -189,6 +189,7 @@ def commit(self, return_commit_stats=False, request_options=None):
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
request_options=request_options,
+max_commit_delay=max_commit_delay,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
method = functools.partial(
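
The new argument is forwarded into the CommitRequest that Batch.commit() assembles, where max_commit_delay is a protobuf Duration field; proto-plus lets callers pass a datetime.timedelta for it. A minimal sketch of an equivalent request, not taken from this commit (the session name and empty mutation list are placeholders):

import datetime

from google.cloud.spanner_v1 import CommitRequest, TransactionOptions

# Sketch: roughly what Batch.commit() builds for a single-use read-write commit.
request = CommitRequest(
    session="projects/p/instances/i/databases/d/sessions/s",  # placeholder name
    mutations=[],  # a real batch carries its queued mutations here
    single_use_transaction=TransactionOptions(read_write=TransactionOptions.ReadWrite()),
    max_commit_delay=datetime.timedelta(milliseconds=100),  # converted to a Duration
)
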
google/cloud/spanner_v1/database.py (5 additions, 3 deletions)
@@ -721,7 +721,7 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

-def batch(self, request_options=None):
+def batch(self, request_options=None, max_commit_delay=None):
"""Return an object which wraps a batch.
The wrapper *must* be used as a context manager, with the batch
@@ -737,7 +737,7 @@ def batch(self, request_options=None):
:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
-return BatchCheckout(self, request_options)
+return BatchCheckout(self, request_options, max_commit_delay)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
@@ -1037,7 +1037,7 @@ class BatchCheckout(object):
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
"""

-def __init__(self, database, request_options=None):
+def __init__(self, database, request_options=None, max_commit_delay=None):
self._database = database
self._session = self._batch = None
if request_options is None:
@@ -1046,6 +1046,7 @@ def __init__(self, database, request_options=None):
self._request_options = RequestOptions(request_options)
else:
self._request_options = request_options
+self._max_commit_delay = max_commit_delay

def __enter__(self):
"""Begin ``with`` block."""
@@ -1062,6 +1063,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._batch.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
+max_commit_delay=self._max_commit_delay,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
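
With BatchCheckout now storing the delay, it can be supplied when checking out a batch from a Database, and the commit issued on context exit carries it. A short usage sketch mirroring the integration test added below; the database handle, table, and columns are placeholders, not part of this commit:

import datetime

# `database` is assumed to be an existing google.cloud.spanner_v1.database.Database.
with database.batch(max_commit_delay=datetime.timedelta(milliseconds=100)) as batch:
    batch.insert(
        table="Singers",  # hypothetical table
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )
# On __exit__, Batch.commit() runs with the configured max_commit_delay.
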
google/cloud/spanner_v1/session.py (2 additions, 0 deletions)
@@ -372,6 +372,7 @@ def run_in_transaction(self, func, *args, **kw):
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
commit_request_options = kw.pop("commit_request_options", None)
+max_commit_delay = kw.pop("max_commit_delay", None)
transaction_tag = kw.pop("transaction_tag", None)
attempts = 0

@@ -400,6 +401,7 @@
txn.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=commit_request_options,
+max_commit_delay=max_commit_delay,
)
except Aborted as exc:
del self._transaction
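
Because run_in_transaction pops max_commit_delay from **kw before the retry loop, callers of Database.run_in_transaction (which forwards to the session) can pass it alongside the unit-of-work arguments. A sketch with a hypothetical DML unit of work, mirroring the integration test added below:

import datetime

# `database` is assumed to be an existing google.cloud.spanner_v1.database.Database.
def insert_singer(transaction):
    # Hypothetical statement; any write inside the unit of work behaves the same.
    transaction.execute_update(
        "INSERT INTO Singers (SingerId, FirstName) VALUES (2, 'Cat')"
    )

database.run_in_transaction(
    insert_singer,
    max_commit_delay=datetime.timedelta(milliseconds=100),
)
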
google/cloud/spanner_v1/transaction.py (2 additions, 1 deletion)
@@ -180,7 +180,7 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

-def commit(self, return_commit_stats=False, request_options=None):
+def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None):
"""Commit mutations to the database.
:type return_commit_stats: bool
@@ -229,6 +229,7 @@
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
request_options=request_options,
+max_commit_delay=max_commit_delay,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
method = functools.partial(
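
The keyword is also accepted when a Transaction is driven by hand, although the batch and run_in_transaction helpers above are the usual entry points. A sketch shown only to indicate where the argument lands; the table and the explicit session handling are assumptions, not part of this commit:

import datetime

# `database` is assumed to be an existing google.cloud.spanner_v1.database.Database.
session = database.session()
session.create()
try:
    transaction = session.transaction()
    transaction.begin()
    transaction.insert("Singers", ("SingerId", "FirstName"), [(3, "Alice")])  # hypothetical table
    transaction.commit(max_commit_delay=datetime.timedelta(milliseconds=100))
finally:
    session.delete()
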
tests/system/test_database_api.py (36 additions, 1 deletion)
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import datetime
import time
import uuid

@@ -819,3 +819,38 @@ def _transaction_read(transaction):

with pytest.raises(exceptions.InvalidArgument):
shared_database.run_in_transaction(_transaction_read)


+def test_db_batch_insert_w_max_commit_delay(shared_database):
+    _helpers.retry_has_all_dll(shared_database.reload)()
+    sd = _sample_data
+
+    with shared_database.batch(max_commit_delay=datetime.timedelta(milliseconds=100)) as batch:
+        batch.delete(sd.TABLE, sd.ALL)
+        batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA)
+
+    with shared_database.snapshot(read_timestamp=batch.committed) as snapshot:
+        from_snap = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))
+
+    sd._check_rows_data(from_snap)
+
+
+def test_db_run_in_transaction_w_max_commit_delay(shared_database):
+    _helpers.retry_has_all_dll(shared_database.reload)()
+    sd = _sample_data
+
+    with shared_database.batch() as batch:
+        batch.delete(sd.TABLE, sd.ALL)
+
+    def _unit_of_work(transaction, test):
+        rows = list(transaction.read(test.TABLE, test.COLUMNS, sd.ALL))
+        assert rows == []
+
+        transaction.insert_or_update(test.TABLE, test.COLUMNS, test.ROW_DATA)
+
+    shared_database.run_in_transaction(_unit_of_work, test=sd, max_commit_delay=datetime.timedelta(milliseconds=100))
+
+    with shared_database.snapshot() as after:
+        rows = list(after.execute_sql(sd.SQL))
+
+    sd._check_rows_data(rows)
