Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,6 @@ def __init__(
self._global_connection = None # type: typing.Optional[Connection]
self._global_transaction = None # type: typing.Optional[Transaction]

if self._force_rollback:
self._global_connection = Connection(self._backend)
self._global_transaction = self._global_connection.transaction(
force_rollback=True
)

async def connect(self) -> None:
"""
Establish the connection pool.
Expand All @@ -91,7 +85,14 @@ async def connect(self) -> None:
self.is_connected = True

if self._force_rollback:
assert self._global_transaction is not None
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self._backend)
self._global_transaction = self._global_connection.transaction(
force_rollback=True
)

await self._global_transaction.__aenter__()

async def disconnect(self) -> None:
Expand All @@ -101,9 +102,14 @@ async def disconnect(self) -> None:
assert self.is_connected, "Already disconnected."

if self._force_rollback:
assert self._global_connection is not None
assert self._global_transaction is not None

await self._global_transaction.__aexit__()

self._global_transaction = None
self._global_connection = None

await self._backend.disconnect()
logger.info(
"Disconnected from database %s",
Expand Down
30 changes: 29 additions & 1 deletion tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def async_adapter(wrapped_func):

@functools.wraps(wrapped_func)
def run_sync(*args, **kwargs):
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
task = wrapped_func(*args, **kwargs)
return loop.run_until_complete(task)

Expand Down Expand Up @@ -752,6 +752,34 @@ async def db_lookup():
await asyncio.gather(db_lookup(), db_lookup())


@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_global_connection_is_initialized_lazily(database_url):
"""
Ensure that global connection is initialized at latest possible time
so it's _query_lock will belong to same event loop that async_adapter has
initialized.

See https://github.com/encode/databases/issues/157 for more context.
"""

database_url = DatabaseURL(database_url)
if database_url.dialect != "postgresql":
pytest.skip("Test requires `pg_sleep()`")

database = Database(database_url, force_rollback=True)

@async_adapter
async def run_database_queries():
async with database:

async def db_lookup():
await database.fetch_one("SELECT pg_sleep(1)")

await asyncio.gather(db_lookup(), db_lookup())

run_database_queries()


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_iterate_outside_transaction_with_values(database_url):
Expand Down