From 90e535027ed1f7bb99e06da1a42e2e5f706a0d9a Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 12:47:14 -0700 Subject: [PATCH] test: add concurrent task tests --- tests/test_databases.py | 55 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..b286a27a 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -961,16 +961,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS)