diff --git a/databases/core.py b/databases/core.py index 8394ab5c..795609ea 100644 --- a/databases/core.py +++ b/databases/core.py @@ -3,6 +3,7 @@ import functools import logging import typing +import weakref from contextvars import ContextVar from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit @@ -11,7 +12,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -35,6 +36,11 @@ logger = logging.getLogger("databases") +_ACTIVE_TRANSACTIONS: ContextVar[ + typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] +] = ContextVar("databases:active_transactions", default=None) + + class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", @@ -45,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -55,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -63,14 +72,35 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. - self._connection_context: ContextVar = ContextVar("connection_context") - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. self._global_connection: typing.Optional[Connection] = None self._global_transaction: typing.Optional[Transaction] = None + @property + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task + + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) + + @_connection.setter + def _connection( + self, connection: typing.Optional["Connection"] + ) -> typing.Optional["Connection"]: + task = self._current_task + + if connection is None: + self._connection_map.pop(task, None) + else: + self._connection_map[task] = connection + + return self._connection + async def connect(self) -> None: """ Establish the connection pool. @@ -89,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -113,7 +143,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = ContextVar("connection_context") + self._connection = None await self._backend.disconnect() logger.info( @@ -187,12 +217,10 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + if not self._connection: + self._connection = Connection(self, self._backend) + + return self._connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -215,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -249,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -345,6 +375,37 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs + @property + def _connection(self) -> "Connection": + # Returns the same connection if called multiple times + return self._connection_callable() + + @property + def _transaction(self) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + return None + + return transactions.get(self, None) + + @_transaction.setter + def _transaction( + self, transaction: typing.Optional["TransactionBackend"] + ) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + transactions = weakref.WeakKeyDictionary() + else: + transactions = transactions.copy() + + if transaction is None: + transactions.pop(self, None) + else: + transactions[self] = transaction + + _ACTIVE_TRANSACTIONS.set(transactions) + return transactions.get(self, None) + async def __aenter__(self) -> "Transaction": """ Called when entering `async with database.transaction()` @@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() self._transaction = self._connection._connection.transaction() async with self._connection._transaction_lock: @@ -401,15 +461,19 @@ async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.commit() await self._connection.__aexit__() + self._transaction = None async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.rollback() await self._connection.__aexit__() + self._transaction = None class _EmptyNetloc(str): diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index aa45537d..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool: async with database.transaction(): ... ``` + It can also be acquired from a specific database connection: ```python @@ -95,8 +98,51 @@ async def create_users(request): ... ``` -Transaction blocks are managed as task-local state. Nested transactions -are fully supported, and are implemented using database savepoints. +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. This state is _inherited_ by tasks that are share the same connection: + +```python +async def add_excitement(connnection: databases.core.Connection, id: int): + await connection.execute( + "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", + {"id": id} + ) + + +async with Database(database_url) as database: + async with database.transaction(): + # This note won't exist until the transaction closes... + await database.execute( + "INSERT INTO notes(id, text) values (1, 'databases is cool')" + ) + # ...but child tasks can use this connection now! + await asyncio.create_task(add_excitement(database.connection(), id=1)) + + await database.fetch_val("SELECT text FROM notes WHERE id=1") + # ^ returns: "databases is cool!!!" +``` + +Nested transactions are fully supported, and are implemented using database savepoints: + +```python +async with databases.Database(database_url) as db: + async with db.transaction() as outer: + # Do something in the outer transaction + ... + + # Suppress to prevent influence on the outer transaction + with contextlib.suppress(ValueError): + async with db.transaction(): + # Do something in the inner transaction + ... + + raise ValueError('Abort the inner transaction') + + # Observe the results of the outer transaction, + # without effects from the inner transaction. + await db.fetch_all('SELECT * FROM ...') +``` Transaction isolation-level can be specified if the driver backend supports that: diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..4d737261 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,8 +2,11 @@ import datetime import decimal import functools +import gc +import itertools import os import re +from typing import MutableMapping from unittest.mock import MagicMock, patch import pytest @@ -477,6 +480,254 @@ async def test_transaction_commit(database_url): assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): + """ + Ensure that child tasks may influence inherited transactions. + """ + # This is an practical example of the above test. + async with Database(database_url) as database: + async with database.transaction(): + # Create a note + await database.execute( + notes.insert().values(id=1, text="setup", completed=True) + ) + + # Change the note from the same task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="prior") + ) + + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" + + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( + notes.update().where(notes.c.id == 1).values(text="test") + ) + + await asyncio.create_task(run_update_from_child_task(database.connection())) + + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation(database_url): + """ + Ensure that transactions are isolated between sibling tasks. + """ + start = asyncio.Event() + end = asyncio.Event() + + async with Database(database_url) as database: + + async def check_transaction(transaction): + await start.wait() + # Parent task is now in a transaction, we should not + # see its transaction backend since this task was + # _started_ in a context where no transaction was active. + assert transaction._transaction is None + end.set() + + transaction = database.transaction() + assert transaction._transaction is None + task = asyncio.create_task(check_transaction(transaction)) + + async with transaction: + start.set() + assert transaction._transaction is not None + await end.wait() + + # Cleanup for "Task not awaited" warning + await task + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions are running in sibling tasks are isolated from eachother. + """ + # This is an practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): + """ + Ensure that task connections are not persisted unecessarily. + """ + + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() + + async with Database(database_url) as database: + # Should have a connection in this task + # .connect is lazy, it doesn't create a Connection, but .connection does + connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection + + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection + + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_garbagecollector(database_url): + """ + Ensure that connections for tasks are not persisted unecessarily, even + if exit handlers are not called. + """ + database = Database(database_url) + await database.connect() + + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() + + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task + gc.collect() + + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + async with database.transaction() as transaction: + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # Context manager closes, open_transactions is cleaned up + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily, even + if exit handlers are not called. + + This test should be an XFAIL, but cannot be due to the way that is hangs + during teardown. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + transaction = database.transaction() + await transaction.start() + + # Should be tracking the transaction + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # neither .commit, .rollback, nor .__aexit__ are called + del transaction + gc.collect() + + # TODO(zevisert,review): Could skip instead of using the logic below + # A strong reference to the transaction is kept alive by the connection's + # ._transaction_stack, so it is still be tracked at this point. + assert len(open_transactions) == 1 + + # If that were magically cleared, the transaction would be cleaned up, + # but as it stands this always causes a hang during teardown at + # `Database(...).disconnect()` if the transaction is not closed. + transaction = database.connection()._transaction_stack[-1] + await transaction.rollback() + del transaction + + # Now with the transaction rolled-back, it should be cleaned up. + assert len(open_transactions) == 0 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_serializable(database_url): @@ -609,17 +860,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. + """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -789,15 +1067,16 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context(database_url): - """ - Test connection contexts are task-local. - """ +async def test_connection_context_same_task(database_url): async with Database(database_url) as database: async with database.connection() as connection_1: async with database.connection() as connection_2: assert connection_1 is connection_2 + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -817,9 +1096,8 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) while connection_1 is None or connection_2 is None: await asyncio.sleep(0.000001) assert connection_1 is not connection_2 @@ -828,6 +1106,61 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + +@pytest.mark.parametrize( + "database_url1,database_url2", + ( + pytest.param(db1, db2, id=f"{db1} | {db2}") + for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) + ), +) +@async_adapter +async def test_connection_context_multiple_databases(database_url1, database_url2): + async with Database(database_url1) as database1: + async with Database(database_url2) as database2: + assert database1.connection() is not database2.connection() + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connection_context_with_raw_connection(database_url): @@ -961,16 +1294,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS)