From d998f30ebc10879f84b6ff78fa79e5cf45ed7308 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 6 Dec 2022 11:47:25 +0100 Subject: [PATCH 1/2] 8th and final round of migrating integration tests to TestKit --- tests/integration/test_readme.py | 8 +- tests/integration/test_tx_functions.py | 159 --------------------- tests/unit/async_/work/test_session.py | 13 +- tests/unit/async_/work/test_transaction.py | 29 ++++ tests/unit/sync/work/test_session.py | 13 +- tests/unit/sync/work/test_transaction.py | 29 ++++ 6 files changed, 86 insertions(+), 165 deletions(-) delete mode 100644 tests/integration/test_tx_functions.py diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py index f47f1ff84..7bebde9ac 100644 --- a/tests/integration/test_readme.py +++ b/tests/integration/test_readme.py @@ -22,6 +22,10 @@ # python -m pytest tests/integration/test_readme.py -s -v +def _work(tx, query, **params): + tx.run(query, **params).consume() + + def test_should_run_readme(uri, auth): names = set() print = names.add @@ -49,14 +53,14 @@ def print_friends(tx, name): with driver.session(database="neo4j") as session: # === END: README === - session.run("MATCH (a) DETACH DELETE a") + session.execute_write(_work, "MATCH (a) DETACH DELETE a") # === START: README === session.execute_write(add_friend, "Arthur", "Guinevere") session.execute_write(add_friend, "Arthur", "Lancelot") session.execute_write(add_friend, "Arthur", "Merlin") session.execute_read(print_friends, "Arthur") # === END: README === - session.run("MATCH (a) DETACH DELETE a") + session.execute_write(_work, "MATCH (a) DETACH DELETE a") # === START: README === driver.close() diff --git a/tests/integration/test_tx_functions.py b/tests/integration/test_tx_functions.py deleted file mode 100644 index 2528feb33..000000000 --- a/tests/integration/test_tx_functions.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from uuid import uuid4 - -import pytest - -from neo4j import unit_of_work -from neo4j.exceptions import ( - ClientError, - Neo4jError, -) - - -# python -m pytest tests/integration/test_tx_functions.py -s -v - - -@pytest.fixture(params=["read_transaction", "execute_read"]) -def read_transaction(request): - def executor(session, *args, **kwargs): - if request.param == "read_transaction": - with pytest.warns( - DeprecationWarning, - match="^read_transaction has been renamed to execute_read$" - ): - return session.read_transaction(*args, **kwargs) - elif request.param == "execute_read": - return session.execute_read(*args, **kwargs) - raise ValueError(request.param) - - return executor - - -@pytest.fixture(params=["write_transaction", "execute_write"]) -def write_transaction(request): - def executor(session, *args, **kwargs): - if request.param == "write_transaction": - with pytest.warns( - DeprecationWarning, - match="^write_transaction has been renamed to execute_write$" - ): - return session.write_transaction(*args, **kwargs) - elif request.param == "execute_write": - return session.execute_write(*args, **kwargs) - raise ValueError(request.param) - - return executor - - -def test_simple_read(session, read_transaction): - - def work(tx): - return tx.run("RETURN 1").single().value() - - value = read_transaction(session, work) - assert value == 1 - - -def test_read_with_arg(session, read_transaction): - - def work(tx, x): - return tx.run("RETURN $x", x=x).single().value() - - value = read_transaction(session, work, x=1) - assert value == 1 - - -def test_read_with_arg_and_metadata(session, read_transaction): - - @unit_of_work(timeout=25, metadata={"foo": "bar"}) - def work(tx): - return tx.run("CALL dbms.getTXMetaData").single().value() - - try: - value = read_transaction(session, work) - except ClientError: - pytest.skip("Transaction metadata and timeout only supported in Neo4j EE 3.5+") - else: - assert value == {"foo": "bar"} - - -def test_simple_write(session, write_transaction): - - def work(tx): - return tx.run("CREATE (a {x: 1}) RETURN a.x").single().value() - - value = write_transaction(session, work) - assert value == 1 - - -def test_write_with_arg(session, write_transaction): - - def work(tx, x): - return tx.run("CREATE (a {x: $x}) RETURN a.x", x=x).single().value() - - value = write_transaction(session, work, x=1) - assert value == 1 - - -def test_write_with_arg_and_metadata(session, write_transaction): - - @unit_of_work(timeout=25, metadata={"foo": "bar"}) - def work(tx, x): - return tx.run("CREATE (a {x: $x}) RETURN a.x", x=x).single().value() - - try: - value = write_transaction(session, work, x=1) - except ClientError: - pytest.skip("Transaction metadata and timeout only supported in Neo4j EE 3.5+") - else: - assert value == 1 - - -def test_error_on_write_transaction(session, write_transaction): - - def f(tx, uuid): - tx.run("CREATE (a:Thing {uuid:$uuid})", uuid=uuid), uuid4() - - with pytest.raises(TypeError): - write_transaction(session, f) - - -def test_retry_logic(driver, read_transaction): - # python -m pytest tests/integration/test_tx_functions.py -s -v -k test_retry_logic - - pytest.global_counter = 0 - - def get_one(tx): - result = tx.run("UNWIND [1,2,3,4] AS x RETURN x") - records = list(result) - pytest.global_counter += 1 - - if pytest.global_counter < 3: - database_unavailable = Neo4jError.hydrate(message="The database is not currently available to serve your request, refer to the database logs for more details. Retrying your request at a later time may succeed.", code="Neo.TransientError.Database.DatabaseUnavailable") - raise database_unavailable - - return records - - with driver.session() as session: - records = read_transaction(session, get_one) - - assert pytest.global_counter == 3 - - del pytest.global_counter diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 235ade560..5a0d0ad2b 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -40,8 +40,11 @@ @contextmanager def assert_warns_tx_func_deprecation(tx_func_name): if tx_func_name.endswith("_transaction"): - with pytest.warns(DeprecationWarning, - match=f"{tx_func_name}.*execute_"): + mode = tx_func_name.split("_")[0] + with pytest.warns( + DeprecationWarning, + match=f"^{mode}_transaction has been renamed to execute_{mode}$" + ): yield else: yield @@ -289,6 +292,12 @@ async def work(tx): with assert_warns_tx_func_deprecation(tx_type): await getattr(session, tx_type)(work) assert called + assert len(fake_pool.acquired_connection_mocks) == 1 + cx = fake_pool.acquired_connection_mocks[0] + cx.begin.assert_called_once() + for key in ("timeout", "metadata"): + value = decorator_kwargs.get(key) + assert cx.begin.call_args[1][key] == value @mark_async_test diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 24b79a81c..14ff98abc 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -130,6 +130,35 @@ async def test_transaction_run_takes_no_query_object(async_fake_connection): await tx.run(Query("RETURN 1")) +@mark_async_test +@pytest.mark.parametrize("params", ( + {"x": 1}, + {"x": "1"}, + {"x": "1", "y": 2}, + {"parameters": {"nested": "parameters"}}, +)) +@pytest.mark.parametrize("as_kwargs", (True, False)) +async def test_transaction_run_parameters( + async_fake_connection, params, as_kwargs +): + on_closed = MagicMock() + on_error = MagicMock() + on_cancel = MagicMock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error, + on_cancel) + if not as_kwargs: + params = {"parameters": params} + await tx.run("RETURN $x", **params) + calls = [call for call in async_fake_connection.method_calls + if call[0] in ("run", "send_all", "fetch_message")] + assert [call[0] for call in calls] == ["run", "send_all", "fetch_message"] + run = calls[0] + assert run[1][0] == "RETURN $x" + if "parameters" in params: + params = params["parameters"] + assert run[2]["parameters"] == params + + @mark_async_test async def test_transaction_rollbacks_on_open_connections( async_fake_connection diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 2116ec001..9ee9def04 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -40,8 +40,11 @@ @contextmanager def assert_warns_tx_func_deprecation(tx_func_name): if tx_func_name.endswith("_transaction"): - with pytest.warns(DeprecationWarning, - match=f"{tx_func_name}.*execute_"): + mode = tx_func_name.split("_")[0] + with pytest.warns( + DeprecationWarning, + match=f"^{mode}_transaction has been renamed to execute_{mode}$" + ): yield else: yield @@ -289,6 +292,12 @@ def work(tx): with assert_warns_tx_func_deprecation(tx_type): getattr(session, tx_type)(work) assert called + assert len(fake_pool.acquired_connection_mocks) == 1 + cx = fake_pool.acquired_connection_mocks[0] + cx.begin.assert_called_once() + for key in ("timeout", "metadata"): + value = decorator_kwargs.get(key) + assert cx.begin.call_args[1][key] == value @mark_sync_test diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 34cd384d7..a0bdbec28 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -130,6 +130,35 @@ def test_transaction_run_takes_no_query_object(fake_connection): tx.run(Query("RETURN 1")) +@mark_sync_test +@pytest.mark.parametrize("params", ( + {"x": 1}, + {"x": "1"}, + {"x": "1", "y": 2}, + {"parameters": {"nested": "parameters"}}, +)) +@pytest.mark.parametrize("as_kwargs", (True, False)) +def test_transaction_run_parameters( + fake_connection, params, as_kwargs +): + on_closed = MagicMock() + on_error = MagicMock() + on_cancel = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error, + on_cancel) + if not as_kwargs: + params = {"parameters": params} + tx.run("RETURN $x", **params) + calls = [call for call in fake_connection.method_calls + if call[0] in ("run", "send_all", "fetch_message")] + assert [call[0] for call in calls] == ["run", "send_all", "fetch_message"] + run = calls[0] + assert run[1][0] == "RETURN $x" + if "parameters" in params: + params = params["parameters"] + assert run[2]["parameters"] == params + + @mark_sync_test def test_transaction_rollbacks_on_open_connections( fake_connection From efe74b16db2cc8d8127667bbd325b6461cbadb79 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 7 Dec 2022 16:21:22 +0100 Subject: [PATCH 2/2] TestKit backend: except txMeta as Cypher types --- testkitbackend/fromtestkit.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index a320ea780..e28a63079 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -40,7 +40,9 @@ def to_cypher_and_params(data): if params is None: return data["cypher"], None # Transform the params to Python native - params_dict = {p: to_param(params[p]) for p in params} + params_dict = {k: to_param(v) for k, v in params.items()} + if isinstance(params, Request): + params.mark_all_as_read() return data["cypher"], params_dict @@ -48,9 +50,12 @@ def to_tx_kwargs(data): from .backend import Request kwargs = {} if "txMeta" in data: - kwargs["metadata"] = data["txMeta"] - if isinstance(kwargs["metadata"], Request): - kwargs["metadata"].mark_all_as_read() + metadata = data["txMeta"] + kwargs["metadata"] = metadata + if metadata is not None: + kwargs["metadata"] = {k: to_param(v) for k, v in metadata.items()} + if isinstance(metadata, Request): + metadata.mark_all_as_read() if "timeout" in data: kwargs["timeout"] = data["timeout"] if kwargs["timeout"] is not None: