Skip to content

Commit

Permalink
feat: adding support for spanner request options tags (#276)
Browse files Browse the repository at this point in the history
* feat: added support for request options with request tag and transaction tag in supported classes

* feat: corrected import for RequestOptions

* feat: request options added lint corrections

* feat: added system test for request tagging

* feat: added annotation to skip request tags validation test while using emulator

* feat: lint fix

* fix: remove request_option from batch

* lint: lint fixes

* refactor: undo changes

* refactor: undo changes

* refactor: remove test_system file, as it has been removed in master

* refactor: update code to latest changes

* feat: added support for request options with request tag and transaction tag in supported classes

* feat: corrected import for RequestOptions

* fix: add transaction_tag test for transaction_tag set in transaction class

* fix: lint fixes

* refactor: lint fixes

* fix: change request_options dictionary to RequestOptions object

* refactor: fix lint issues

* refactor: lint fixes

* refactor: move write txn properties to BatchBase

* fix: use transaction tag on all write methods

* feat: add support for batch commit

* feat: add support for setting a transaction tag on batch checkout

* refactor: update checks for readability

* test: use separate expectation object for readability

* test: add run_in_transaction test

* test: remove test for unsupported behaviour

* style: lint fixes

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
vi3k6i5 and larkee authored Sep 29, 2021
1 parent f59d08b commit e16f376
Show file tree
Hide file tree
Showing 10 changed files with 425 additions and 25 deletions.
14 changes: 11 additions & 3 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class _BatchBase(_SessionWrapper):
:param session: the session used to perform the commit
"""

transaction_tag = None
_read_only = False

def __init__(self, session):
super(_BatchBase, self).__init__(session)
self._mutations = []
Expand Down Expand Up @@ -118,8 +121,7 @@ def delete(self, table, keyset):


class Batch(_BatchBase):
"""Accumulate mutations for transmission during :meth:`commit`.
"""
"""Accumulate mutations for transmission during :meth:`commit`."""

committed = None
commit_stats = None
Expand Down Expand Up @@ -160,8 +162,14 @@ def commit(self, return_commit_stats=False, request_options=None):
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
Expand Down
16 changes: 14 additions & 2 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,15 +494,20 @@ def execute_partitioned_dml(
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
Please note, the `transactionTag` setting will be ignored as it is
not supported for partitioned DML.
:rtype: int
:returns: Count of rows affected by the DML statement.
"""
query_options = _merge_query_options(
self._instance._client._query_options, query_options
)
if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = None

if params is not None:
from google.cloud.spanner_v1.transaction import Transaction
Expand Down Expand Up @@ -796,12 +801,19 @@ class BatchCheckout(object):
def __init__(self, database, request_options=None):
self._database = database
self._session = self._batch = None
self._request_options = request_options
if request_options is None:
self._request_options = RequestOptions()
elif type(request_options) == dict:
self._request_options = RequestOptions(request_options)
else:
self._request_options = request_options

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
batch = self._batch = Batch(session)
if self._request_options.transaction_tag:
batch.transaction_tag = self._request_options.transaction_tag
return batch

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,13 @@ def run_in_transaction(self, func, *args, **kw):
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
commit_request_options = kw.pop("commit_request_options", None)
transaction_tag = kw.pop("transaction_tag", None)
attempts = 0

while True:
if self._transaction is None:
txn = self.transaction()
txn.transaction_tag = transaction_tag
else:
txn = self._transaction
if txn._transaction_id is None:
Expand Down
22 changes: 20 additions & 2 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class _SnapshotBase(_SessionWrapper):
"""

_multi_use = False
_read_only = True
_transaction_id = None
_read_request_count = 0
_execute_sql_count = 0
Expand Down Expand Up @@ -160,6 +161,8 @@ def read(
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
Please note, the `transactionTag` setting will be ignored for
snapshot as it's not supported for read-only transactions.
:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.
Expand All @@ -185,9 +188,17 @@ def read(
metadata = _metadata_with_prefix(database.name)
transaction = self._make_txn_selector()

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)

if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
request_options.transaction_tag = self.transaction_tag

request = ReadRequest(
session=self._session.name,
table=table,
Expand Down Expand Up @@ -312,8 +323,15 @@ def execute_sql(
default_query_options = database._instance._client._query_options
query_options = _merge_query_options(default_query_options, query_options)

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
request_options.transaction_tag = self.transaction_tag

request = ExecuteSqlRequest(
session=self._session.name,
Expand Down
19 changes: 16 additions & 3 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,15 @@ def commit(self, return_commit_stats=False, request_options=None):
metadata = _metadata_with_prefix(database.name)
trace_attributes = {"num_mutations": len(self._mutations)}

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
Expand Down Expand Up @@ -267,8 +274,11 @@ def execute_update(
default_query_options = database._instance._client._query_options
query_options = _merge_query_options(default_query_options, query_options)

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

trace_attributes = {"db.statement": dml}

Expand Down Expand Up @@ -343,8 +353,11 @@ def batch_update(self, statements, request_options=None):
self._execute_sql_count + 1,
)

if type(request_options) == dict:
if request_options is None:
request_options = RequestOptions()
elif type(request_options) == dict:
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

trace_attributes = {
# Get just the queries from the DML statement batch
Expand Down
83 changes: 77 additions & 6 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import unittest
from tests._helpers import OpenTelemetryBase, StatusCode
from google.cloud.spanner_v1 import RequestOptions

TABLE_NAME = "citizens"
COLUMNS = ["email", "first_name", "last_name", "age"]
Expand All @@ -39,6 +40,7 @@ class _BaseTest(unittest.TestCase):
DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID
SESSION_ID = "session-id"
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
TRANSACTION_TAG = "transaction-tag"

def _make_one(self, *args, **kwargs):
return self._getTargetClass()(*args, **kwargs)
Expand Down Expand Up @@ -232,18 +234,87 @@ def test_commit_ok(self):
self.assertEqual(committed, now)
self.assertEqual(batch.committed, committed)

(session, mutations, single_use_txn, metadata, request_options) = api._committed
(session, mutations, single_use_txn, request_options, metadata) = api._committed
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
self.assertEqual(request_options, None)
self.assertEqual(request_options, RequestOptions())

self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

def _test_commit_with_request_options(self, request_options=None):
import datetime
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
from google.cloud._helpers import _datetime_to_pb_timestamp

now = datetime.datetime.utcnow().replace(tzinfo=UTC)
now_pb = _datetime_to_pb_timestamp(now)
response = CommitResponse(commit_timestamp=now_pb)
database = _Database()
api = database.spanner_api = _FauxSpannerAPI(_commit_response=response)
session = _Session(database)
batch = self._make_one(session)
batch.transaction_tag = self.TRANSACTION_TAG
batch.insert(TABLE_NAME, COLUMNS, VALUES)
committed = batch.commit(request_options=request_options)

self.assertEqual(committed, now)
self.assertEqual(batch.committed, committed)

if type(request_options) == dict:
expected_request_options = RequestOptions(request_options)
else:
expected_request_options = request_options
expected_request_options.transaction_tag = self.TRANSACTION_TAG
expected_request_options.request_tag = None

(
session,
mutations,
single_use_txn,
actual_request_options,
metadata,
) = api._committed
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
self.assertEqual(actual_request_options, expected_request_options)

self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

def test_commit_w_request_tag_success(self):
request_options = RequestOptions(request_tag="tag-1",)
self._test_commit_with_request_options(request_options=request_options)

def test_commit_w_transaction_tag_success(self):
request_options = RequestOptions(transaction_tag="tag-1-1",)
self._test_commit_with_request_options(request_options=request_options)

def test_commit_w_request_and_transaction_tag_success(self):
request_options = RequestOptions(
request_tag="tag-1", transaction_tag="tag-1-1",
)
self._test_commit_with_request_options(request_options=request_options)

def test_commit_w_request_and_transaction_tag_dictionary_success(self):
request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"}
self._test_commit_with_request_options(request_options=request_options)

def test_commit_w_incorrect_tag_dictionary_error(self):
request_options = {"incorrect_tag": "tag-1-1"}
with self.assertRaises(ValueError):
self._test_commit_with_request_options(request_options=request_options)

def test_context_mgr_already_committed(self):
import datetime
from google.cloud._helpers import UTC
Expand Down Expand Up @@ -281,13 +352,13 @@ def test_context_mgr_success(self):

self.assertEqual(batch.committed, now)

(session, mutations, single_use_txn, metadata, request_options) = api._committed
(session, mutations, single_use_txn, request_options, metadata) = api._committed
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
self.assertEqual(request_options, None)
self.assertEqual(request_options, RequestOptions())

self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
Expand Down Expand Up @@ -341,7 +412,7 @@ def __init__(self, **kwargs):
self.__dict__.update(**kwargs)

def commit(
self, request=None, metadata=None, request_options=None,
self, request=None, metadata=None,
):
from google.api_core.exceptions import Unknown

Expand All @@ -350,8 +421,8 @@ def commit(
request.session,
request.mutations,
request.single_use_transaction,
request.request_options,
metadata,
request_options,
)
if self._rpc_error:
raise Unknown("error")
Expand Down
Loading

0 comments on commit e16f376

Please sign in to comment.