Skip to content

Commit

Permalink
feat: Add support for max commit delay (#1050)
Browse files Browse the repository at this point in the history
* proto generation

* max commit delay

* Fix some errors

* Unit tests

* regenerate proto changes

* Fix unit tests

* Finish test_transaction.py

* Finish test_batch.py

* Formatting

* Cleanup

* Fix merge conflict

* Add optional=True

* Remove optional=True, try calling HasField.

* Update HasField to be called on the protobuf.

* Update to timedelta.duration instead of an int.

* Cleanup

* Changes from Sri to pipe value to top-level funcitons and to add integration tests. Thanks Sri

* Run nox -s blacken

* feat(spanner): remove unused imports and add line

* feat(spanner): add empty line in python docs

* Update comment with valid values.

* Update comment with valid values.

* feat(spanner): fix lint

* feat(spanner): rever nox file changes

---------

Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com>
Co-authored-by: Sri Harsha CH <sriharshach@google.com>
  • Loading branch information
3 people committed Feb 5, 2024
1 parent 122ab36 commit d5acc26
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 19 deletions.
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/batch.py
Expand Up @@ -146,7 +146,9 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self, return_commit_stats=False, request_options=None):
def commit(
self, return_commit_stats=False, request_options=None, max_commit_delay=None
):
"""Commit mutations to the database.
:type return_commit_stats: bool
Expand All @@ -160,6 +162,11 @@ def commit(self, return_commit_stats=False, request_options=None):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type max_commit_delay: :class:`datetime.timedelta`
:param max_commit_delay:
(Optional) The amount of latency this request is willing to incur
in order to improve throughput.
:rtype: datetime
:returns: timestamp of the committed changes.
"""
Expand Down Expand Up @@ -188,6 +195,7 @@ def commit(self, return_commit_stats=False, request_options=None):
mutations=self._mutations,
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
Expand Down
25 changes: 21 additions & 4 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -721,7 +721,7 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

def batch(self, request_options=None):
def batch(self, request_options=None, max_commit_delay=None):
"""Return an object which wraps a batch.
The wrapper *must* be used as a context manager, with the batch
Expand All @@ -734,10 +734,16 @@ def batch(self, request_options=None):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type max_commit_delay: :class:`datetime.timedelta`
:param max_commit_delay:
(Optional) The amount of latency this request is willing to incur
in order to improve throughput. Value must be between 0ms and
500ms.
:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
return BatchCheckout(self, request_options)
return BatchCheckout(self, request_options, max_commit_delay)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
Expand Down Expand Up @@ -796,9 +802,13 @@ def run_in_transaction(self, func, *args, **kw):
:type kw: dict
:param kw: (Optional) keyword arguments to be passed to ``func``.
If passed, "timeout_secs" will be removed and used to
If passed,
"timeout_secs" will be removed and used to
override the default retry timeout which defines maximum timestamp
to continue retrying the transaction.
"max_commit_delay" will be removed and used to set the
max_commit_delay for the request. Value must be between
0ms and 500ms.
:rtype: Any
:returns: The return value of ``func``.
Expand Down Expand Up @@ -1035,9 +1045,14 @@ class BatchCheckout(object):
(Optional) Common options for the commit request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type max_commit_delay: :class:`datetime.timedelta`
:param max_commit_delay:
(Optional) The amount of latency this request is willing to incur
in order to improve throughput.
"""

def __init__(self, database, request_options=None):
def __init__(self, database, request_options=None, max_commit_delay=None):
self._database = database
self._session = self._batch = None
if request_options is None:
Expand All @@ -1046,6 +1061,7 @@ def __init__(self, database, request_options=None):
self._request_options = RequestOptions(request_options)
else:
self._request_options = request_options
self._max_commit_delay = max_commit_delay

def __enter__(self):
"""Begin ``with`` block."""
Expand All @@ -1062,6 +1078,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._batch.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/session.py
Expand Up @@ -363,6 +363,8 @@ def run_in_transaction(self, func, *args, **kw):
to continue retrying the transaction.
"commit_request_options" will be removed and used to set the
request options for the commit request.
"max_commit_delay" will be removed and used to set the max commit delay for the request.
"transaction_tag" will be removed and used to set the transaction tag for the request.
:rtype: Any
:returns: The return value of ``func``.
Expand All @@ -372,6 +374,7 @@ def run_in_transaction(self, func, *args, **kw):
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
commit_request_options = kw.pop("commit_request_options", None)
max_commit_delay = kw.pop("max_commit_delay", None)
transaction_tag = kw.pop("transaction_tag", None)
attempts = 0

Expand Down Expand Up @@ -400,6 +403,7 @@ def run_in_transaction(self, func, *args, **kw):
txn.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=commit_request_options,
max_commit_delay=max_commit_delay,
)
except Aborted as exc:
del self._transaction
Expand Down
11 changes: 10 additions & 1 deletion google/cloud/spanner_v1/transaction.py
Expand Up @@ -180,7 +180,9 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

def commit(self, return_commit_stats=False, request_options=None):
def commit(
self, return_commit_stats=False, request_options=None, max_commit_delay=None
):
"""Commit mutations to the database.
:type return_commit_stats: bool
Expand All @@ -194,6 +196,12 @@ def commit(self, return_commit_stats=False, request_options=None):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type max_commit_delay: :class:`datetime.timedelta`
:param max_commit_delay:
(Optional) The amount of latency this request is willing to incur
in order to improve throughput.
:class:`~google.cloud.spanner_v1.types.MaxCommitDelay`.
:rtype: datetime
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
Expand Down Expand Up @@ -228,6 +236,7 @@ def commit(self, return_commit_stats=False, request_options=None):
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
Expand Down
40 changes: 40 additions & 0 deletions tests/system/test_database_api.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import time
import uuid

Expand Down Expand Up @@ -819,3 +820,42 @@ def _transaction_read(transaction):

with pytest.raises(exceptions.InvalidArgument):
shared_database.run_in_transaction(_transaction_read)


def test_db_batch_insert_w_max_commit_delay(shared_database):
_helpers.retry_has_all_dll(shared_database.reload)()
sd = _sample_data

with shared_database.batch(
max_commit_delay=datetime.timedelta(milliseconds=100)
) as batch:
batch.delete(sd.TABLE, sd.ALL)
batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA)

with shared_database.snapshot(read_timestamp=batch.committed) as snapshot:
from_snap = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))

sd._check_rows_data(from_snap)


def test_db_run_in_transaction_w_max_commit_delay(shared_database):
_helpers.retry_has_all_dll(shared_database.reload)()
sd = _sample_data

with shared_database.batch() as batch:
batch.delete(sd.TABLE, sd.ALL)

def _unit_of_work(transaction, test):
rows = list(transaction.read(test.TABLE, test.COLUMNS, sd.ALL))
assert rows == []

transaction.insert_or_update(test.TABLE, test.COLUMNS, test.ROW_DATA)

shared_database.run_in_transaction(
_unit_of_work, test=sd, max_commit_delay=datetime.timedelta(milliseconds=100)
)

with shared_database.snapshot() as after:
rows = list(after.execute_sql(sd.SQL))

sd._check_rows_data(rows)
54 changes: 45 additions & 9 deletions tests/unit/test_batch.py
Expand Up @@ -233,7 +233,14 @@ def test_commit_ok(self):
self.assertEqual(committed, now)
self.assertEqual(batch.committed, committed)

(session, mutations, single_use_txn, request_options, metadata) = api._committed
(
session,
mutations,
single_use_txn,
request_options,
max_commit_delay,
metadata,
) = api._committed
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
Expand All @@ -246,12 +253,13 @@ def test_commit_ok(self):
],
)
self.assertEqual(request_options, RequestOptions())
self.assertEqual(max_commit_delay, None)

self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

def _test_commit_with_request_options(self, request_options=None):
def _test_commit_with_options(self, request_options=None, max_commit_delay_in=None):
import datetime
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
Expand All @@ -267,7 +275,9 @@ def _test_commit_with_request_options(self, request_options=None):
batch = self._make_one(session)
batch.transaction_tag = self.TRANSACTION_TAG
batch.insert(TABLE_NAME, COLUMNS, VALUES)
committed = batch.commit(request_options=request_options)
committed = batch.commit(
request_options=request_options, max_commit_delay=max_commit_delay_in
)

self.assertEqual(committed, now)
self.assertEqual(batch.committed, committed)
Expand All @@ -284,6 +294,7 @@ def _test_commit_with_request_options(self, request_options=None):
mutations,
single_use_txn,
actual_request_options,
max_commit_delay,
metadata,
) = api._committed
self.assertEqual(session, self.SESSION_NAME)
Expand All @@ -303,33 +314,46 @@ def _test_commit_with_request_options(self, request_options=None):
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

self.assertEqual(max_commit_delay_in, max_commit_delay)

def test_commit_w_request_tag_success(self):
request_options = RequestOptions(
request_tag="tag-1",
)
self._test_commit_with_request_options(request_options=request_options)
self._test_commit_with_options(request_options=request_options)

def test_commit_w_transaction_tag_success(self):
request_options = RequestOptions(
transaction_tag="tag-1-1",
)
self._test_commit_with_request_options(request_options=request_options)
self._test_commit_with_options(request_options=request_options)

def test_commit_w_request_and_transaction_tag_success(self):
request_options = RequestOptions(
request_tag="tag-1",
transaction_tag="tag-1-1",
)
self._test_commit_with_request_options(request_options=request_options)
self._test_commit_with_options(request_options=request_options)

def test_commit_w_request_and_transaction_tag_dictionary_success(self):
request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"}
self._test_commit_with_request_options(request_options=request_options)
self._test_commit_with_options(request_options=request_options)

def test_commit_w_incorrect_tag_dictionary_error(self):
request_options = {"incorrect_tag": "tag-1-1"}
with self.assertRaises(ValueError):
self._test_commit_with_request_options(request_options=request_options)
self._test_commit_with_options(request_options=request_options)

def test_commit_w_max_commit_delay(self):
import datetime

request_options = RequestOptions(
request_tag="tag-1",
)
self._test_commit_with_options(
request_options=request_options,
max_commit_delay_in=datetime.timedelta(milliseconds=100),
)

def test_context_mgr_already_committed(self):
import datetime
Expand Down Expand Up @@ -368,7 +392,14 @@ def test_context_mgr_success(self):

self.assertEqual(batch.committed, now)

(session, mutations, single_use_txn, request_options, metadata) = api._committed
(
session,
mutations,
single_use_txn,
request_options,
_,
metadata,
) = api._committed
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
Expand Down Expand Up @@ -565,12 +596,17 @@ def commit(
):
from google.api_core.exceptions import Unknown

max_commit_delay = None
if type(request).pb(request).HasField("max_commit_delay"):
max_commit_delay = request.max_commit_delay

assert request.transaction_id == b""
self._committed = (
request.session,
request.mutations,
request.single_use_transaction,
request.request_options,
max_commit_delay,
metadata,
)
if self._rpc_error:
Expand Down

0 comments on commit d5acc26

Please sign in to comment.