Skip to content

Commit

Permalink
fix: pass transaction's options to API in 'begin' (#143)
Browse files Browse the repository at this point in the history
Closes #135.
  • Loading branch information
tseaver committed May 3, 2021
1 parent 4f90d04 commit 924b10b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 47 deletions.
8 changes: 7 additions & 1 deletion google/cloud/datastore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,12 @@ def Entity(*args, **kwargs):
def __init__(self, client, read_only=False):
super(Transaction, self).__init__(client)
self._id = None

if read_only:
options = TransactionOptions(read_only=TransactionOptions.ReadOnly())
else:
options = TransactionOptions()

self._options = options

@property
Expand Down Expand Up @@ -231,9 +233,13 @@ def begin(self, retry=None, timeout=None):

kwargs = _make_retry_timeout_kwargs(retry, timeout)

request = {
"project_id": self.project,
"transaction_options": self._options,
}
try:
response_pb = self._client._datastore_api.begin_transaction(
request={"project_id": self.project}, **kwargs
request=request, **kwargs
)
self._id = response_pb.transaction
except: # noqa: E722 do not use bare except, specify exception instead
Expand Down
135 changes: 89 additions & 46 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,50 @@ def _get_target_class():

return Transaction

def _get_options_class(self, **kw):
def _make_one(self, client, **kw):
return self._get_target_class()(client, **kw)

def _make_options(self, read_only=False, previous_transaction=None):
from google.cloud.datastore_v1.types import TransactionOptions

return TransactionOptions
kw = {}

def _make_one(self, client, **kw):
return self._get_target_class()(client, **kw)
if read_only:
kw["read_only"] = TransactionOptions.ReadOnly()

def _make_options(self, **kw):
return self._get_options_class()(**kw)
return TransactionOptions(**kw)

def test_ctor_defaults(self):
project = "PROJECT"
client = _Client(project)

xact = self._make_one(client)

self.assertEqual(xact.project, project)
self.assertIs(xact._client, client)
self.assertIsNone(xact.id)
self.assertEqual(xact._status, self._get_target_class()._INITIAL)
self.assertEqual(xact._mutations, [])
self.assertEqual(len(xact._partial_key_entities), 0)

def test_constructor_read_only(self):
project = "PROJECT"
id_ = 850302
ds_api = _make_datastore_api(xact=id_)
client = _Client(project, datastore_api=ds_api)
options = self._make_options(read_only=True)

xact = self._make_one(client, read_only=True)

self.assertEqual(xact._options, options)

def _make_begin_request(self, project, read_only=False):
expected_options = self._make_options(read_only=read_only)
return {
"project_id": project,
"transaction_options": expected_options,
}

def test_current(self):
from google.cloud.datastore_v1.types import datastore as datastore_pb2

Expand All @@ -57,24 +79,34 @@ def test_current(self):
xact2 = self._make_one(client)
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

with xact1:
self.assertIs(xact1.current(), xact1)
self.assertIs(xact2.current(), xact1)

with _NoCommitBatch(client):
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

with xact2:
self.assertIs(xact1.current(), xact2)
self.assertIs(xact2.current(), xact2)

with _NoCommitBatch(client):
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

self.assertIs(xact1.current(), xact1)
self.assertIs(xact2.current(), xact1)

self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

ds_api.rollback.assert_not_called()
begin_txn = ds_api.begin_transaction
self.assertEqual(begin_txn.call_count, 2)
expected_request = self._make_begin_request(project)
begin_txn.assert_called_with(request=expected_request)

commit_method = ds_api.commit
self.assertEqual(commit_method.call_count, 2)
mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL
Expand All @@ -87,21 +119,35 @@ def test_current(self):
}
)

begin_txn = ds_api.begin_transaction
self.assertEqual(begin_txn.call_count, 2)
begin_txn.assert_called_with(request={"project_id": project})
ds_api.rollback.assert_not_called()

def test_begin(self):
project = "PROJECT"
id_ = 889
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

xact.begin()

self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_begin_w_readonly(self):
project = "PROJECT"
id_ = 889
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client, read_only=True)

xact.begin()

self.assertEqual(xact.id, id_)

expected_request = self._make_begin_request(project, read_only=True)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_begin_w_retry_w_timeout(self):
project = "PROJECT"
Expand All @@ -116,8 +162,10 @@ def test_begin_w_retry_w_timeout(self):
xact.begin(retry=retry, timeout=timeout)

self.assertEqual(xact.id, id_)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}, retry=retry, timeout=timeout
request=expected_request, retry=retry, timeout=timeout,
)

def test_begin_tombstoned(self):
Expand All @@ -126,19 +174,23 @@ def test_begin_tombstoned(self):
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

xact.begin()

self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

xact.rollback()

client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)
self.assertIsNone(xact.id)

self.assertRaises(ValueError, xact.begin)
with self.assertRaises(ValueError):
xact.begin()

def test_begin_w_begin_transaction_failure(self):
project = "PROJECT"
Expand All @@ -152,9 +204,9 @@ def test_begin_w_begin_transaction_failure(self):
xact.begin()

self.assertIsNone(xact.id)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_rollback(self):
project = "PROJECT"
Expand Down Expand Up @@ -256,11 +308,14 @@ def test_context_manager_no_raise(self):
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

with xact:
self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)
self.assertEqual(xact.id, id_) # only set between begin / commit

self.assertIsNone(xact.id)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL
client._datastore_api.commit.assert_called_once_with(
Expand All @@ -272,9 +327,6 @@ def test_context_manager_no_raise(self):
},
)

self.assertIsNone(xact.id)
self.assertEqual(ds_api.begin_transaction.call_count, 1)

def test_context_manager_w_raise(self):
class Foo(Exception):
pass
Expand All @@ -288,29 +340,20 @@ class Foo(Exception):
try:
with xact:
self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)
raise Foo()
except Foo:
self.assertIsNone(xact.id)
client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)
pass

client._datastore_api.commit.assert_not_called()
self.assertIsNone(xact.id)
self.assertEqual(ds_api.begin_transaction.call_count, 1)

def test_constructor_read_only(self):
project = "PROJECT"
id_ = 850302
ds_api = _make_datastore_api(xact=id_)
client = _Client(project, datastore_api=ds_api)
read_only = self._get_options_class().ReadOnly()
options = self._make_options(read_only=read_only)
xact = self._make_one(client, read_only=True)
self.assertEqual(xact._options, options)
expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

client._datastore_api.commit.assert_not_called()

client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)

def test_put_read_only(self):
project = "PROJECT"
Expand Down

0 comments on commit 924b10b

Please sign in to comment.