Skip to content

Commit

Permalink
feat: add new_transaction support (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Apr 4, 2024
1 parent f4f3bc7 commit 43855dd
Show file tree
Hide file tree
Showing 12 changed files with 580 additions and 59 deletions.
10 changes: 4 additions & 6 deletions google/cloud/datastore/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,11 @@ def _next_page(self):
return None

query_pb = self._build_protobuf()
transaction = self.client.current_transaction
if transaction is None:
transaction_id = None
else:
transaction_id = transaction.id
transaction_id, new_transaction_options = helpers.get_transaction_options(
self.client.current_transaction
)
read_options = helpers.get_read_options(
self._eventual, transaction_id, self._read_time
self._eventual, transaction_id, self._read_time, new_transaction_options
)

partition_id = entity_pb2.PartitionId(
Expand Down
27 changes: 21 additions & 6 deletions google/cloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ def mutations(self):
"""
return self._mutations

def _allow_mutations(self) -> bool:
"""
This method is called to see if the batch is in a proper state to allow
`put` and `delete` operations.
the Transaction subclass overrides this method to support
the `begin_later` flag.
:rtype: bool
:returns: True if the batch is in a state to allow mutations.
"""
return self._status == self._IN_PROGRESS

def put(self, entity):
"""Remember an entity's state to be saved during :meth:`commit`.
Expand All @@ -218,7 +231,7 @@ def put(self, entity):
progress, if entity has no key assigned, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
if not self._allow_mutations():
raise ValueError("Batch must be in progress to put()")

if entity.key is None:
Expand Down Expand Up @@ -248,7 +261,7 @@ def delete(self, key):
progress, if key is not complete, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
if not self._allow_mutations():
raise ValueError("Batch must be in progress to delete()")

if key.is_partial:
Expand Down Expand Up @@ -370,10 +383,12 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is None:
self.commit()
else:
self.rollback()
# commit or rollback if not in terminal state
if self._status not in (self._ABORTED, self._FINISHED):
if exc_type is None:
self.commit()
else:
self.rollback()
finally:
self._client._pop_batch()

Expand Down
26 changes: 18 additions & 8 deletions google/cloud/datastore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _extended_lookup(
missing=None,
deferred=None,
eventual=False,
transaction_id=None,
transaction=None,
retry=None,
timeout=None,
read_time=None,
Expand Down Expand Up @@ -158,10 +158,10 @@ def _extended_lookup(
consistency. If True, request ``EVENTUAL`` read
consistency.
:type transaction_id: str
:param transaction_id: If passed, make the request in the scope of
the given transaction. Incompatible with
``eventual==True`` or ``read_time``.
:type transaction: Transaction
:param transaction: If passed, make the request in the scope of
the given transaction. Incompatible with
``eventual==True`` or ``read_time``.
:type retry: :class:`google.api_core.retry.Retry`
:param retry:
Expand All @@ -177,7 +177,7 @@ def _extended_lookup(
:type read_time: datetime
:param read_time:
(Optional) Read time to use for read consistency. Incompatible with
``eventual==True`` or ``transaction_id``.
``eventual==True`` or ``transaction``.
This feature is in private preview.
:type database: str
Expand All @@ -199,8 +199,14 @@ def _extended_lookup(

results = []

transaction_id = None
transaction_id, new_transaction_options = helpers.get_transaction_options(
transaction
)
read_options = helpers.get_read_options(
eventual, transaction_id, read_time, new_transaction_options
)
loop_num = 0
read_options = helpers.get_read_options(eventual, transaction_id, read_time)
while loop_num < _MAX_LOOPS: # loop against possible deferred.
loop_num += 1
request = {
Expand All @@ -214,6 +220,10 @@ def _extended_lookup(
**kwargs,
)

# set new transaction id if we just started a transaction
if transaction and lookup_response.transaction:
transaction._begin_with_id(lookup_response.transaction)

# Accumulate the new results.
results.extend(result.entity for result in lookup_response.found)

Expand Down Expand Up @@ -570,7 +580,7 @@ def get_multi(
eventual=eventual,
missing=missing,
deferred=deferred,
transaction_id=transaction and transaction.id,
transaction=transaction,
retry=retry,
timeout=timeout,
read_time=read_time,
Expand Down
72 changes: 48 additions & 24 deletions google/cloud/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def entity_to_protobuf(entity):
return entity_pb


def get_read_options(eventual, transaction_id, read_time=None):
def get_read_options(
eventual, transaction_id, read_time=None, new_transaction_options=None
):
"""Validate rules for read options, and assign to the request.
Helper method for ``lookup()`` and ``run_query``.
Expand All @@ -245,33 +247,55 @@ def get_read_options(eventual, transaction_id, read_time=None):
:type read_time: datetime
:param read_time: Read data from the specified time (may be null). This feature is in private preview.
:type new_transaction_options: :class:`google.cloud.datastore_v1.types.TransactionOptions`
:param new_transaction_options: Options for a new transaction.
:rtype: :class:`.datastore_pb2.ReadOptions`
:returns: The read options corresponding to the inputs.
:raises: :class:`ValueError` if more than one of ``eventual==True``,
``transaction``, and ``read_time`` is specified.
``transaction_id``, ``read_time``, and ``new_transaction_options`` is specified.
"""
if transaction_id is None:
if eventual:
if read_time is not None:
raise ValueError("eventual must be False when read_time is specified")
else:
return datastore_pb2.ReadOptions(
read_consistency=datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL
)
else:
if read_time is None:
return datastore_pb2.ReadOptions()
else:
read_time_pb = timestamp_pb2.Timestamp()
read_time_pb.FromDatetime(read_time)
return datastore_pb2.ReadOptions(read_time=read_time_pb)
else:
if eventual:
raise ValueError("eventual must be False when in a transaction")
elif read_time is not None:
raise ValueError("transaction and read_time are mutual exclusive")
else:
return datastore_pb2.ReadOptions(transaction=transaction_id)
is_set = [
bool(x) for x in (eventual, transaction_id, read_time, new_transaction_options)
]
if sum(is_set) > 1:
raise ValueError(
"At most one of eventual, transaction, or read_time is allowed."
)
new_options = datastore_pb2.ReadOptions()
if transaction_id is not None:
new_options.transaction = transaction_id
if read_time is not None:
read_time_pb = timestamp_pb2.Timestamp()
read_time_pb.FromDatetime(read_time)
new_options.read_time = read_time_pb
if new_transaction_options is not None:
new_options.new_transaction = new_transaction_options
if eventual:
new_options.read_consistency = (
datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL
)
return new_options


def get_transaction_options(transaction):
"""
Get the transaction_id or new_transaction_options field from an active transaction object,
for use in get_read_options
These are mutually-exclusive fields, so one or both will be None.
:rtype: Tuple[Optional[bytes], Optional[google.cloud.datastore_v1.types.TransactionOptions]]
:returns: The transaction_id and new_transaction_options fields from the transaction object.
"""
transaction_id, new_transaction_options = None, None
if transaction is not None:
if transaction.id is not None:
transaction_id = transaction.id
elif transaction._begin_later and transaction._status == transaction._INITIAL:
# If the transaction has not yet been begun, we can use the new_transaction_options field.
new_transaction_options = transaction._options
return transaction_id, new_transaction_options


def key_from_protobuf(pb):
Expand Down
11 changes: 5 additions & 6 deletions google/cloud/datastore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,13 +778,12 @@ def _next_page(self):
return None

query_pb = self._build_protobuf()
transaction = self.client.current_transaction
if transaction is None:
transaction_id = None
else:
transaction_id = transaction.id
new_transaction_options = None
transaction_id, new_transaction_options = helpers.get_transaction_options(
self.client.current_transaction
)
read_options = helpers.get_read_options(
self._eventual, transaction_id, self._read_time
self._eventual, transaction_id, self._read_time, new_transaction_options
)

partition_id = entity_pb2.PartitionId(
Expand Down
60 changes: 56 additions & 4 deletions google/cloud/datastore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""Create / interact with Google Cloud Datastore transactions."""

from google.cloud.datastore.batch import Batch
from google.cloud.datastore_v1.types import TransactionOptions
from google.protobuf import timestamp_pb2
Expand Down Expand Up @@ -149,15 +148,23 @@ class Transaction(Batch):
:param read_time: (Optional) Time at which the transaction reads entities.
Only allowed when ``read_only=True``. This feature is in private preview.
:type begin_later: bool
:param begin_later: (Optional) If True, the transaction will be started
lazily (i.e. when the first RPC is made). If False,
the transaction will be started as soon as the context manager
is entered. `self.begin()` can also be called manually to begin
the transaction at any time. Default is False.
:raises: :class:`ValueError` if read_time is specified when
``read_only=False``.
"""

_status = None

def __init__(self, client, read_only=False, read_time=None):
def __init__(self, client, read_only=False, read_time=None, begin_later=False):
super(Transaction, self).__init__(client)
self._id = None
self._begin_later = begin_later

if read_only:
if read_time is not None:
Expand All @@ -180,8 +187,8 @@ def __init__(self, client, read_only=False, read_time=None):
def id(self):
"""Getter for the transaction ID.
:rtype: str
:returns: The ID of the current transaction.
:rtype: bytes or None
:returns: The ID of the current transaction, or None if not started.
"""
return self._id

Expand Down Expand Up @@ -240,6 +247,21 @@ def begin(self, retry=None, timeout=None):
self._status = self._ABORTED
raise

def _begin_with_id(self, transaction_id):
"""
Attach newly created transaction to an existing transaction ID.
This is used when begin_later is True, when the first lookup request
associated with this transaction creates a new transaction ID.
:type transaction_id: bytes
:param transaction_id: ID of the transaction to attach to.
"""
if self._status is not self._INITIAL:
raise ValueError("Transaction already begun.")
self._id = transaction_id
self._status = self._IN_PROGRESS

def rollback(self, retry=None, timeout=None):
"""Rolls back the current transaction.
Expand All @@ -258,6 +280,12 @@ def rollback(self, retry=None, timeout=None):
Note that if ``retry`` is specified, the timeout applies
to each individual attempt.
"""
# if transaction has not started, abort it
if self._status == self._INITIAL:
self._status = self._ABORTED
self._id = None
return None

kwargs = _make_retry_timeout_kwargs(retry, timeout)

try:
Expand Down Expand Up @@ -296,6 +324,15 @@ def commit(self, retry=None, timeout=None):
Note that if ``retry`` is specified, the timeout applies
to each individual attempt.
"""
# if transaction has not begun, either begin now, or abort if empty
if self._status == self._INITIAL:
if not self._mutations:
self._status = self._ABORTED
self._id = None
return None
else:
self.begin()

kwargs = _make_retry_timeout_kwargs(retry, timeout)

try:
Expand All @@ -321,3 +358,18 @@ def put(self, entity):
raise RuntimeError("Transaction is read only")
else:
super(Transaction, self).put(entity)

def __enter__(self):
if not self._begin_later:
self.begin()
self._client._push_batch(self)
return self

def _allow_mutations(self):
"""
Mutations can be added to a transaction if it is in IN_PROGRESS state,
or if it is in INITIAL state and the begin_later flag is set.
"""
return self._status == self._IN_PROGRESS or (
self._begin_later and self._status == self._INITIAL
)
Loading

0 comments on commit 43855dd

Please sign in to comment.