diff --git a/tests/test_databases.py b/tests/test_databases.py index e744c8bb..36d86564 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -360,3 +360,44 @@ async def test_connect_and_disconnect(database_url): assert database.is_connected await database.disconnect() assert not database.is_connected + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context(database_url): + """ + Test connection contexts are task-local. + """ + async with Database(database_url) as database: + async with database.connection() as connection_1: + async with database.connection() as connection_2: + assert connection_1 is connection_2 + + async with Database(database_url) as database: + connection_1 = None + connection_2 = None + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + await test_complete.wait() + + loop = asyncio.get_event_loop() + task_1 = loop.create_task(get_connection_1()) + task_2 = loop.create_task(get_connection_2()) + while connection_1 is None or connection_2 is None: + await asyncio.sleep(0.000001) + assert connection_1 is not connection_2 + test_complete.set() + await task_1 + await task_2