From d5acc263d86fcbde7d5f972930255119e2f60e76 Mon Sep 17 00:00:00 2001 From: nginsberg-google <131713109+nginsberg-google@users.noreply.github.com> Date: Sun, 4 Feb 2024 20:17:21 -0800 Subject: [PATCH] feat: Add support for max commit delay (#1050) * proto generation * max commit delay * Fix some errors * Unit tests * regenerate proto changes * Fix unit tests * Finish test_transaction.py * Finish test_batch.py * Formatting * Cleanup * Fix merge conflict * Add optional=True * Remove optional=True, try calling HasField. * Update HasField to be called on the protobuf. * Update to timedelta.duration instead of an int. * Cleanup * Changes from Sri to pipe value to top-level functions and to add integration tests. Thanks Sri * Run nox -s blacken * feat(spanner): remove unused imports and add line * feat(spanner): add empty line in python docs * Update comment with valid values. * Update comment with valid values. * feat(spanner): fix lint * feat(spanner): revert nox file changes --------- Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> Co-authored-by: Sri Harsha CH --- google/cloud/spanner_v1/batch.py | 10 ++++- google/cloud/spanner_v1/database.py | 25 ++++++++++-- google/cloud/spanner_v1/session.py | 4 ++ google/cloud/spanner_v1/transaction.py | 11 +++++- tests/system/test_database_api.py | 40 +++++++++++++++++++ tests/unit/test_batch.py | 54 +++++++++++++++++++++----- tests/unit/test_transaction.py | 34 ++++++++++++++-- 7 files changed, 159 insertions(+), 19 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index da74bf35f0..9cb2afbc2c 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -146,7 +146,9 @@ def _check_state(self): if self.committed is not None: raise ValueError("Batch already committed") - def commit(self, return_commit_stats=False, request_options=None): + def commit( + self, return_commit_stats=False, request_options=None, max_commit_delay=None + 
): """Commit mutations to the database. :type return_commit_stats: bool @@ -160,6 +162,11 @@ def commit(self, return_commit_stats=False, request_options=None): If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type max_commit_delay: :class:`datetime.timedelta` + :param max_commit_delay: + (Optional) The amount of latency this request is willing to incur + in order to improve throughput. + :rtype: datetime :returns: timestamp of the committed changes. """ @@ -188,6 +195,7 @@ def commit(self, return_commit_stats=False, request_options=None): mutations=self._mutations, single_use_transaction=txn_options, return_commit_stats=return_commit_stats, + max_commit_delay=max_commit_delay, request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1a651a66f5..b23db95284 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -721,7 +721,7 @@ def snapshot(self, **kw): """ return SnapshotCheckout(self, **kw) - def batch(self, request_options=None): + def batch(self, request_options=None, max_commit_delay=None): """Return an object which wraps a batch. The wrapper *must* be used as a context manager, with the batch @@ -734,10 +734,16 @@ def batch(self, request_options=None): If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type max_commit_delay: :class:`datetime.timedelta` + :param max_commit_delay: + (Optional) The amount of latency this request is willing to incur + in order to improve throughput. Value must be between 0ms and + 500ms. 
+ :rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout` :returns: new wrapper """ - return BatchCheckout(self, request_options) + return BatchCheckout(self, request_options, max_commit_delay) def mutation_groups(self): """Return an object which wraps a mutation_group. @@ -796,9 +802,13 @@ def run_in_transaction(self, func, *args, **kw): :type kw: dict :param kw: (Optional) keyword arguments to be passed to ``func``. - If passed, "timeout_secs" will be removed and used to + If passed, + "timeout_secs" will be removed and used to override the default retry timeout which defines maximum timestamp to continue retrying the transaction. + "max_commit_delay" will be removed and used to set the + max_commit_delay for the request. Value must be between + 0ms and 500ms. :rtype: Any :returns: The return value of ``func``. @@ -1035,9 +1045,14 @@ class BatchCheckout(object): (Optional) Common options for the commit request. If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + + :type max_commit_delay: :class:`datetime.timedelta` + :param max_commit_delay: + (Optional) The amount of latency this request is willing to incur + in order to improve throughput. 
""" - def __init__(self, database, request_options=None): + def __init__(self, database, request_options=None, max_commit_delay=None): self._database = database self._session = self._batch = None if request_options is None: @@ -1046,6 +1061,7 @@ def __init__(self, database, request_options=None): self._request_options = RequestOptions(request_options) else: self._request_options = request_options + self._max_commit_delay = max_commit_delay def __enter__(self): """Begin ``with`` block.""" @@ -1062,6 +1078,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._batch.commit( return_commit_stats=self._database.log_commit_stats, request_options=self._request_options, + max_commit_delay=self._max_commit_delay, ) finally: if self._database.log_commit_stats and self._batch.commit_stats: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index b25af53805..d0a44f6856 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -363,6 +363,8 @@ def run_in_transaction(self, func, *args, **kw): to continue retrying the transaction. "commit_request_options" will be removed and used to set the request options for the commit request. + "max_commit_delay" will be removed and used to set the max commit delay for the request. + "transaction_tag" will be removed and used to set the transaction tag for the request. :rtype: Any :returns: The return value of ``func``. 
@@ -372,6 +374,7 @@ def run_in_transaction(self, func, *args, **kw): """ deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS) commit_request_options = kw.pop("commit_request_options", None) + max_commit_delay = kw.pop("max_commit_delay", None) transaction_tag = kw.pop("transaction_tag", None) attempts = 0 @@ -400,6 +403,7 @@ def run_in_transaction(self, func, *args, **kw): txn.commit( return_commit_stats=self._database.log_commit_stats, request_options=commit_request_options, + max_commit_delay=max_commit_delay, ) except Aborted as exc: del self._transaction diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index d564d0d488..3c950401ac 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -180,7 +180,9 @@ def rollback(self): self.rolled_back = True del self._session._transaction - def commit(self, return_commit_stats=False, request_options=None): + def commit( + self, return_commit_stats=False, request_options=None, max_commit_delay=None + ): """Commit mutations to the database. :type return_commit_stats: bool @@ -194,6 +196,12 @@ def commit(self, return_commit_stats=False, request_options=None): If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type max_commit_delay: :class:`datetime.timedelta` + :param max_commit_delay: + (Optional) The amount of latency this request is willing to incur + in order to improve throughput. + :class:`~google.cloud.spanner_v1.types.MaxCommitDelay`. + :rtype: datetime :returns: timestamp of the committed changes. :raises ValueError: if there are no mutations to commit. 
@@ -228,6 +236,7 @@ def commit(self, return_commit_stats=False, request_options=None): mutations=self._mutations, transaction_id=self._transaction_id, return_commit_stats=return_commit_stats, + max_commit_delay=max_commit_delay, request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 052e628188..fbaee7476d 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import time import uuid @@ -819,3 +820,42 @@ def _transaction_read(transaction): with pytest.raises(exceptions.InvalidArgument): shared_database.run_in_transaction(_transaction_read) + + +def test_db_batch_insert_w_max_commit_delay(shared_database): + _helpers.retry_has_all_dll(shared_database.reload)() + sd = _sample_data + + with shared_database.batch( + max_commit_delay=datetime.timedelta(milliseconds=100) + ) as batch: + batch.delete(sd.TABLE, sd.ALL) + batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA) + + with shared_database.snapshot(read_timestamp=batch.committed) as snapshot: + from_snap = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL)) + + sd._check_rows_data(from_snap) + + +def test_db_run_in_transaction_w_max_commit_delay(shared_database): + _helpers.retry_has_all_dll(shared_database.reload)() + sd = _sample_data + + with shared_database.batch() as batch: + batch.delete(sd.TABLE, sd.ALL) + + def _unit_of_work(transaction, test): + rows = list(transaction.read(test.TABLE, test.COLUMNS, sd.ALL)) + assert rows == [] + + transaction.insert_or_update(test.TABLE, test.COLUMNS, test.ROW_DATA) + + shared_database.run_in_transaction( + _unit_of_work, test=sd, max_commit_delay=datetime.timedelta(milliseconds=100) + ) + + with shared_database.snapshot() as after: + rows = list(after.execute_sql(sd.SQL)) 
+ + sd._check_rows_data(rows) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 203c8a0cb5..1c02e93f1d 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -233,7 +233,14 @@ def test_commit_ok(self): self.assertEqual(committed, now) self.assertEqual(batch.committed, committed) - (session, mutations, single_use_txn, request_options, metadata) = api._committed + ( + session, + mutations, + single_use_txn, + request_options, + max_commit_delay, + metadata, + ) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) @@ -246,12 +253,13 @@ def test_commit_ok(self): ], ) self.assertEqual(request_options, RequestOptions()) + self.assertEqual(max_commit_delay, None) self.assertSpanAttributes( "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) ) - def _test_commit_with_request_options(self, request_options=None): + def _test_commit_with_options(self, request_options=None, max_commit_delay_in=None): import datetime from google.cloud.spanner_v1 import CommitResponse from google.cloud.spanner_v1 import TransactionOptions @@ -267,7 +275,9 @@ def _test_commit_with_request_options(self, request_options=None): batch = self._make_one(session) batch.transaction_tag = self.TRANSACTION_TAG batch.insert(TABLE_NAME, COLUMNS, VALUES) - committed = batch.commit(request_options=request_options) + committed = batch.commit( + request_options=request_options, max_commit_delay=max_commit_delay_in + ) self.assertEqual(committed, now) self.assertEqual(batch.committed, committed) @@ -284,6 +294,7 @@ def _test_commit_with_request_options(self, request_options=None): mutations, single_use_txn, actual_request_options, + max_commit_delay, metadata, ) = api._committed self.assertEqual(session, self.SESSION_NAME) @@ -303,33 +314,46 @@ def _test_commit_with_request_options(self, request_options=None): "CloudSpanner.Commit", 
attributes=dict(BASE_ATTRIBUTES, num_mutations=1) ) + self.assertEqual(max_commit_delay_in, max_commit_delay) + def test_commit_w_request_tag_success(self): request_options = RequestOptions( request_tag="tag-1", ) - self._test_commit_with_request_options(request_options=request_options) + self._test_commit_with_options(request_options=request_options) def test_commit_w_transaction_tag_success(self): request_options = RequestOptions( transaction_tag="tag-1-1", ) - self._test_commit_with_request_options(request_options=request_options) + self._test_commit_with_options(request_options=request_options) def test_commit_w_request_and_transaction_tag_success(self): request_options = RequestOptions( request_tag="tag-1", transaction_tag="tag-1-1", ) - self._test_commit_with_request_options(request_options=request_options) + self._test_commit_with_options(request_options=request_options) def test_commit_w_request_and_transaction_tag_dictionary_success(self): request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} - self._test_commit_with_request_options(request_options=request_options) + self._test_commit_with_options(request_options=request_options) def test_commit_w_incorrect_tag_dictionary_error(self): request_options = {"incorrect_tag": "tag-1-1"} with self.assertRaises(ValueError): - self._test_commit_with_request_options(request_options=request_options) + self._test_commit_with_options(request_options=request_options) + + def test_commit_w_max_commit_delay(self): + import datetime + + request_options = RequestOptions( + request_tag="tag-1", + ) + self._test_commit_with_options( + request_options=request_options, + max_commit_delay_in=datetime.timedelta(milliseconds=100), + ) def test_context_mgr_already_committed(self): import datetime @@ -368,7 +392,14 @@ def test_context_mgr_success(self): self.assertEqual(batch.committed, now) - (session, mutations, single_use_txn, request_options, metadata) = api._committed + ( + session, + mutations, + 
single_use_txn, + request_options, + _, + metadata, + ) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) @@ -565,12 +596,17 @@ def commit( ): from google.api_core.exceptions import Unknown + max_commit_delay = None + if type(request).pb(request).HasField("max_commit_delay"): + max_commit_delay = request.max_commit_delay + assert request.transaction_id == b"" self._committed = ( request.session, request.mutations, request.single_use_transaction, request.request_options, + max_commit_delay, metadata, ) if self._rpc_error: diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 2d2f208424..d391fe4c13 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -346,9 +346,14 @@ def test_commit_w_other_error(self): ) def _commit_helper( - self, mutate=True, return_commit_stats=False, request_options=None + self, + mutate=True, + return_commit_stats=False, + request_options=None, + max_commit_delay_in=None, ): import datetime + from google.cloud.spanner_v1 import CommitResponse from google.cloud.spanner_v1.keyset import KeySet from google.cloud._helpers import UTC @@ -370,13 +375,22 @@ def _commit_helper( transaction.delete(TABLE_NAME, keyset) transaction.commit( - return_commit_stats=return_commit_stats, request_options=request_options + return_commit_stats=return_commit_stats, + request_options=request_options, + max_commit_delay=max_commit_delay_in, ) self.assertEqual(transaction.committed, now) self.assertIsNone(session._transaction) - session_id, mutations, txn_id, actual_request_options, metadata = api._committed + ( + session_id, + mutations, + txn_id, + actual_request_options, + max_commit_delay, + metadata, + ) = api._committed if request_options is None: expected_request_options = RequestOptions( @@ -391,6 +405,7 @@ def _commit_helper( expected_request_options.transaction_tag = 
self.TRANSACTION_TAG expected_request_options.request_tag = None + self.assertEqual(max_commit_delay_in, max_commit_delay) self.assertEqual(session_id, session.name) self.assertEqual(txn_id, self.TRANSACTION_ID) self.assertEqual(mutations, transaction._mutations) @@ -423,6 +438,11 @@ def test_commit_w_mutations(self): def test_commit_w_return_commit_stats(self): self._commit_helper(return_commit_stats=True) + def test_commit_w_max_commit_delay(self): + import datetime + + self._commit_helper(max_commit_delay_in=datetime.timedelta(milliseconds=100)) + def test_commit_w_request_tag_success(self): request_options = RequestOptions( request_tag="tag-1", @@ -851,7 +871,7 @@ def test_context_mgr_success(self): self.assertEqual(transaction.committed, now) - session_id, mutations, txn_id, _, metadata = api._committed + session_id, mutations, txn_id, _, _, metadata = api._committed self.assertEqual(session_id, self.SESSION_NAME) self.assertEqual(txn_id, self.TRANSACTION_ID) self.assertEqual(mutations, transaction._mutations) @@ -938,11 +958,17 @@ def commit( metadata=None, ): assert not request.single_use_transaction + + max_commit_delay = None + if type(request).pb(request).HasField("max_commit_delay"): + max_commit_delay = request.max_commit_delay + self._committed = ( request.session, request.mutations, request.transaction_id, request.request_options, + max_commit_delay, metadata, ) return self._commit_response