diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..b9038ca1 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a41fd2bf..170e9558 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,8 +12,8 @@ jobs: runs-on: "ubuntu-latest" steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v1" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: 3.7 - name: "Install dependencies" diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 0690b4d1..bc271a65 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -39,8 +39,8 @@ jobs: options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v1" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: "${{ matrix.python-version }}" - name: "Install dependencies" diff --git a/CHANGELOG.md b/CHANGELOG.md index abe7da92..4816bc16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.7.0 (Dec 18th, 2022) + +### Fixed + +* Fixed breaking changes in SQLAlchemy cursor; supports `>=1.4.42,<1.5` (#513). +* Wrapped types in `typing.Optional` where applicable (#510). +## 0.6.2 (Nov 7th, 2022) + +### Changed + +* Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes (#520). 
+ +## 0.6.1 (Aug 9th, 2022) + +### Fixed + +* Improve typing for `Transaction` (#493) +* Allow string indexing into Record (#501) + ## 0.6.0 (May 29th, 2022) * Dropped Python 3.6 support (#458) diff --git a/databases/__init__.py b/databases/__init__.py index 8dd420b2..cfb75242 100644 --- a/databases/__init__.py +++ b/databases/__init__.py @@ -1,4 +1,4 @@ from databases.core import Database, DatabaseURL -__version__ = "0.6.0" +__version__ = "0.7.0" __all__ = ["Database", "DatabaseURL"] diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 9ad12f63..8668b2b9 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -31,7 +31,7 @@ def __init__( self._database_url = DatabaseURL(database_url) self._options = options self._dialect = self._get_dialect() - self._pool = None + self._pool: typing.Union[aiopg.Pool, None] = None def _get_dialect(self) -> Dialect: dialect = PGDialect_psycopg2( @@ -104,7 +104,7 @@ class AiopgConnection(ConnectionBackend): def __init__(self, database: AiopgBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[aiopg.Connection] + self._connection: typing.Optional[aiopg.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -221,6 +221,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index e15dfa45..0811ef21 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ 
-49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. @@ -92,7 +95,7 @@ class AsyncMyConnection(ConnectionBackend): def __init__(self, database: AsyncMyBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[asyncmy.Connection] + self._connection: typing.Optional[asyncmy.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -211,6 +214,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 2a0a8425..630f7cd3 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ -49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. 
@@ -92,7 +95,7 @@ class MySQLConnection(ConnectionBackend): def __init__(self, database: MySQLBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[aiomysql.Connection] + self._connection: typing.Optional[aiomysql.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -211,6 +214,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 3e8c9f8b..a2468bad 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -45,7 +45,7 @@ def _get_dialect(self) -> Dialect: def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options - kwargs = {} # type: typing.Dict[str, typing.Any] + kwargs: typing.Dict[str, typing.Any] = {} min_size = url_options.get("min_size") max_size = url_options.get("max_size") ssl = url_options.get("ssl") @@ -165,7 +165,7 @@ class PostgresConnection(ConnectionBackend): def __init__(self, database: PostgresBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[asyncpg.connection.Connection] + self._connection: typing.Optional[asyncpg.connection.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -308,9 +308,7 @@ def raw_connection(self) -> asyncpg.connection.Connection: class PostgresTransaction(TransactionBackend): def __init__(self, connection: PostgresConnection): self._connection = connection - self._transaction = ( - None - ) # type: typing.Optional[asyncpg.transaction.Transaction] + self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None async def start( self, is_root: bool, extra_options: 
typing.Dict[typing.Any, typing.Any] diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 9626dcf8..19464627 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -80,7 +80,7 @@ class SQLiteConnection(ConnectionBackend): def __init__(self, pool: SQLitePool, dialect: Dialect): self._pool = pool self._dialect = dialect - self._connection = None # type: typing.Optional[aiosqlite.Connection] + self._connection: typing.Optional[aiosqlite.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -185,6 +185,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) diff --git a/databases/core.py b/databases/core.py index efa59471..795609ea 100644 --- a/databases/core.py +++ b/databases/core.py @@ -3,6 +3,7 @@ import functools import logging import typing +import weakref from contextvars import ContextVar from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit @@ -11,7 +12,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -35,6 +36,11 @@ logger = logging.getLogger("databases") +_ACTIVE_TRANSACTIONS: ContextVar[ + typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] +] = ContextVar("databases:active_transactions", default=None) + + class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", @@ -45,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -55,6 
+63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -63,13 +72,34 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. - self._connection_context = ContextVar("connection_context") # type: ContextVar - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. - self._global_connection = None # type: typing.Optional[Connection] - self._global_transaction = None # type: typing.Optional[Transaction] + self._global_connection: typing.Optional[Connection] = None + self._global_transaction: typing.Optional[Transaction] = None + + @property + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task + + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) + + @_connection.setter + def _connection( + self, connection: typing.Optional["Connection"] + ) -> typing.Optional["Connection"]: + task = self._current_task + + if connection is None: + self._connection_map.pop(task, None) + else: + self._connection_map[task] = connection + + return self._connection async def connect(self) -> None: """ @@ -89,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -113,7 +143,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = 
ContextVar("connection_context") + self._connection = None await self._backend.disconnect() logger.info( @@ -129,20 +159,24 @@ async def __aenter__(self) -> "Database": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: await self.disconnect() async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.List[Record]: async with self.connection() as connection: return await connection.fetch_all(query, values) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: async with self.connection() as connection: return await connection.fetch_one(query, values) @@ -150,14 +184,16 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: async with self.connection() as connection: return await connection.fetch_val(query, values, column=column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: async with self.connection() as connection: return await connection.execute(query, values) @@ -169,7 +205,9 @@ async def execute_many( return await connection.execute_many(query, values) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = 
None, ) -> typing.AsyncGenerator[typing.Mapping, None]: async with self.connection() as connection: async for record in connection.iterate(query, values): @@ -179,12 +217,10 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + if not self._connection: + self._connection = Connection(self, self._backend) + + return self._connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -207,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -215,7 +252,7 @@ def __init__(self, backend: DatabaseBackend) -> None: self._connection_counter = 0 self._transaction_lock = asyncio.Lock() - self._transaction_stack = [] # type: typing.List[Transaction] + self._transaction_stack: typing.List[Transaction] = [] self._query_lock = asyncio.Lock() @@ -232,25 +269,30 @@ async def __aenter__(self) -> "Connection": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: async with self._connection_lock: assert self._connection is not None self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] 
= None, ) -> typing.List[Record]: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_all(built_query) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -259,7 +301,7 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: built_query = self._build_query(query, values) @@ -267,7 +309,9 @@ async def fetch_val( return await self._connection.fetch_val(built_query, column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: built_query = self._build_query(query, values) async with self._query_lock: @@ -281,7 +325,9 @@ async def execute_many( await self._connection.execute_many(queries) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.AsyncGenerator[typing.Any, None]: built_query = self._build_query(query, values) async with self.transaction(): @@ -303,7 +349,7 @@ def raw_connection(self) -> typing.Any: @staticmethod def _build_query( - query: typing.Union[ClauseElement, str], values: dict = None + query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None ) -> ClauseElement: if isinstance(query, str): query = text(query) @@ -329,6 +375,37 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs + @property + def _connection(self) -> "Connection": + # Returns the same connection if called multiple times + return 
self._connection_callable() + + @property + def _transaction(self) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + return None + + return transactions.get(self, None) + + @_transaction.setter + def _transaction( + self, transaction: typing.Optional["TransactionBackend"] + ) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + transactions = weakref.WeakKeyDictionary() + else: + transactions = transactions.copy() + + if transaction is None: + transactions.pop(self, None) + else: + transactions[self] = transaction + + _ACTIVE_TRANSACTIONS.set(transactions) + return transactions.get(self, None) + async def __aenter__(self) -> "Transaction": """ Called when entering `async with database.transaction()` @@ -338,9 +415,9 @@ async def __aenter__(self) -> "Transaction": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: """ Called when exiting `async with database.transaction()` @@ -369,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() self._transaction = self._connection._connection.transaction() async with self._connection._transaction_lock: @@ -385,15 +461,19 @@ async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.commit() await self._connection.__aexit__() + self._transaction = None async def rollback(self) -> None: async with 
self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.rollback() await self._connection.__aexit__() + self._transaction = None class _EmptyNetloc(str): diff --git a/databases/interfaces.py b/databases/interfaces.py index c2109a23..fd6a24ee 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -73,3 +73,6 @@ class Record(Sequence): @property def _mapping(self) -> typing.Mapping: raise NotImplementedError() # pragma: no cover + + def __getitem__(self, key: typing.Any) -> typing.Any: + raise NotImplementedError() # pragma: no cover diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index aa45537d..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool: async with database.transaction(): ... 
``` + It can also be acquired from a specific database connection: ```python @@ -95,8 +98,51 @@ async def create_users(request): ... ``` -Transaction blocks are managed as task-local state. Nested transactions -are fully supported, and are implemented using database savepoints. +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. This state is _inherited_ by tasks that share the same connection: + +```python +async def add_excitement(connection: databases.core.Connection, id: int): + await connection.execute( + "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", + {"id": id} + ) + + +async with Database(database_url) as database: + async with database.transaction(): + # This note won't exist until the transaction closes... + await database.execute( + "INSERT INTO notes(id, text) values (1, 'databases is cool')" + ) + # ...but child tasks can use this connection now! + await asyncio.create_task(add_excitement(database.connection(), id=1)) + + await database.fetch_val("SELECT text FROM notes WHERE id=1") + # ^ returns: "databases is cool!!!" +``` + +Nested transactions are fully supported, and are implemented using database savepoints: + +```python +async with databases.Database(database_url) as db: + async with db.transaction() as outer: + # Do something in the outer transaction + ... + + # Suppress to prevent influence on the outer transaction + with contextlib.suppress(ValueError): + async with db.transaction(): + # Do something in the inner transaction + ... + + raise ValueError('Abort the inner transaction') + + # Observe the results of the outer transaction, + # without effects from the inner transaction. 
+ await db.fetch_all('SELECT * FROM ...') +``` Transaction isolation-level can be specified if the driver backend supports that: diff --git a/docs/database_queries.md b/docs/database_queries.md index 898e7343..66201089 100644 --- a/docs/database_queries.md +++ b/docs/database_queries.md @@ -24,9 +24,48 @@ notes = sqlalchemy.Table( ) ``` -You can use any of the sqlalchemy column types such as `sqlalchemy.JSON`, or +You can use any of the SQLAlchemy column types such as `sqlalchemy.JSON`, or custom column types. +## Creating tables + +Databases doesn't use SQLAlchemy's engine for database access internally. [The usual SQLAlchemy core way to create tables with `create_all`](https://docs.sqlalchemy.org/en/20/core/metadata.html#sqlalchemy.schema.MetaData.create_all) is therefore not available. To work around this you can use SQLAlchemy to [compile the query to SQL](https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined) and then execute it with databases: + +```python +from databases import Database +import sqlalchemy + +database = Database("postgresql+asyncpg://localhost/example") + +# Establish the connection pool +await database.connect() + +metadata = sqlalchemy.MetaData() +dialect = sqlalchemy.dialects.postgresql.dialect() + +# Define your table(s) +notes = sqlalchemy.Table( + "notes", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("text", sqlalchemy.String(length=100)), + sqlalchemy.Column("completed", sqlalchemy.Boolean), +) + +# Create tables +for table in metadata.tables.values(): + # Set `if_not_exists=False` if you want the query to throw an + # exception when the table already exists + schema = sqlalchemy.schema.CreateTable(table, if_not_exists=True) + query = str(schema.compile(dialect=dialect)) + await database.execute(query=query) + +# Close all connections in the connection pool +await database.disconnect() +``` + +Note 
that this way of creating tables is only useful for local experimentation. For serious projects, we recommend using a proper database schema migrations solution like [Alembic](https://alembic.sqlalchemy.org/en/latest/). + ## Queries You can now use any [SQLAlchemy core][sqlalchemy-core] queries ([official tutorial][sqlalchemy-core-tutorial]). @@ -70,11 +109,11 @@ query = notes.select() async for row in database.iterate(query=query): ... -# Close all connection in the connection pool +# Close all connections in the connection pool await database.disconnect() ``` -Connections are managed as task-local state, with driver implementations +Connections are managed as a task-local state, with driver implementations transparently using connection pooling behind the scenes. ## Raw queries @@ -107,21 +146,21 @@ result = await database.fetch_one(query=query, values={"id": 1}) Note that query arguments should follow the `:query_arg` style. [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ -[sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html +[sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/14/core/tutorial.html ## Query result -To keep in line with [SQLAlchemy 1.4 changes][sqlalchemy-mapping-changes] -query result object no longer implements a mapping interface. -To access query result as a mapping you should use the `_mapping` property. -That way you can process both SQLAlchemy Rows and databases Records from raw queries +To keep in line with [SQLAlchemy 1.4 changes][sqlalchemy-mapping-changes] +query result object no longer implements a mapping interface. +To access query result as a mapping you should use the `_mapping` property. +That way you can process both SQLAlchemy Rows and databases Records from raw queries with the same function without any instance checks. 
```python query = "SELECT * FROM notes WHERE id = :id" result = await database.fetch_one(query=query, values={"id": 1}) -result.id # access field via attribute -result._mapping['id'] # access field via mapping +result.id # Access field via attribute +result._mapping['id'] # Access field via mapping ``` [sqlalchemy-mapping-changes]: https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#rowproxy-is-no-longer-a-proxy-is-now-called-row-and-behaves-like-an-enhanced-named-tuple diff --git a/requirements.txt b/requirements.txt index 0d1d5b76..46ed998b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 +1,32 @@ -e . # Async database drivers -asyncmy -aiomysql -aiopg -aiosqlite -asyncpg +asyncmy==0.2.7 +aiomysql==0.1.1 +aiopg==1.3.4 +aiosqlite==0.17.0 +asyncpg==0.26.0 # Sync database drivers for standard tooling around setup/teardown/migrations. -psycopg2-binary -pymysql +psycopg2-binary==2.9.3 +pymysql==1.0.2 # Testing -autoflake -black -codecov -isort -mypy -pytest -pytest-cov -starlette -requests +autoflake==1.4 +black==22.6.0 +httpx==0.24.1 +isort==5.10.1 +mypy==0.971 +pytest==7.1.2 +pytest-cov==3.0.0 +starlette==0.27.0 +requests==2.31.0 # Documentation -mkdocs -mkdocs-material -mkautodoc +mkdocs==1.3.1 +mkdocs-material==8.3.9 +mkautodoc==0.1.0 # Packaging -twine -wheel +twine==4.0.1 +wheel==0.38.1 diff --git a/setup.cfg b/setup.cfg index 77c8c58d..da1831fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,7 @@ [mypy] disallow_untyped_defs = True ignore_missing_imports = True +no_implicit_optional = True [tool:isort] profile = black diff --git a/setup.py b/setup.py index decbf7e5..3725cab9 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4,<1.5"], + install_requires=["sqlalchemy>=1.4.42,<1.5"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], diff --git 
a/tests/test_connection_options.py b/tests/test_connection_options.py index e6fe6849..9e4435ad 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -77,6 +77,15 @@ def test_mysql_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mysql_unix_socket(): + backend = MySQLBackend( + "mysql+aiomysql://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") def test_mysql_explicit_pool_size(): backend = MySQLBackend("mysql://localhost/database", min_size=1, max_size=20) @@ -114,6 +123,15 @@ def test_asyncmy_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") +def test_asyncmy_unix_socket(): + backend = AsyncMyBackend( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") def test_asyncmy_explicit_pool_size(): backend = AsyncMyBackend("mysql://localhost/database", min_size=1, max_size=20) diff --git a/tests/test_database_url.py b/tests/test_database_url.py index 9eea4fa6..7aa15926 100644 --- a/tests/test_database_url.py +++ b/tests/test_database_url.py @@ -69,6 +69,11 @@ def test_database_url_options(): u = DatabaseURL("postgresql://localhost/mydatabase?pool_size=20&ssl=true") assert u.options == {"pool_size": "20", "ssl": "true"} + u = DatabaseURL( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + assert u.options == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + def 
test_replace_database_url_components(): u = DatabaseURL("postgresql://localhost/mydatabase") diff --git a/tests/test_databases.py b/tests/test_databases.py index 6683d4d0..a17d7f00 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,9 +2,11 @@ import datetime import decimal import functools +import gc +import itertools import os import re -import sys +from typing import MutableMapping from unittest.mock import MagicMock, patch import pytest @@ -17,23 +19,6 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -def mysql_versions(wrapped_func): - """ - Decorator used to handle multiple versions of Python for mysql drivers - """ - - @functools.wraps(wrapped_func) - def check(*args, **kwargs): # pragma: no cover - url = DatabaseURL(kwargs["database_url"]) - if url.scheme in ["mysql", "mysql+aiomysql"] and sys.version_info >= (3, 10): - pytest.skip("aiomysql supports python 3.9 and lower") - if url.scheme == "mysql+asyncmy" and sys.version_info < (3, 7): - pytest.skip("asyncmy supports python 3.7 and higher") - return wrapped_func(*args, **kwargs) - - return check - - class AsyncMock(MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -145,7 +130,6 @@ def run_sync(*args, **kwargs): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries(database_url): """ @@ -223,7 +207,6 @@ async def test_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_manual(database_url): async with Database(database_url) as database: @@ -303,7 +286,6 @@ async def test_queries_raw(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_ddl_queries(database_url): """ @@ -323,7 +305,6 @@ async def test_ddl_queries(database_url): @pytest.mark.parametrize("exception", [Exception, 
asyncio.CancelledError]) @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_after_error(database_url, exception): """ @@ -345,7 +326,6 @@ async def test_queries_after_error(database_url, exception): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_results_support_mapping_interface(database_url): """ @@ -374,7 +354,6 @@ async def test_results_support_mapping_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_results_support_column_reference(database_url): """ @@ -406,7 +385,6 @@ async def test_results_support_column_reference(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_result_values_allow_duplicate_names(database_url): """ @@ -423,7 +401,6 @@ async def test_result_values_allow_duplicate_names(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_fetch_one_returning_no_results(database_url): """ @@ -438,7 +415,6 @@ async def test_fetch_one_returning_no_results(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_execute_return_val(database_url): """ @@ -465,7 +441,6 @@ async def test_execute_return_val(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_rollback_isolation(database_url): """ @@ -485,7 +460,6 @@ async def test_rollback_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_rollback_isolation_with_contextmanager(database_url): """ @@ -508,7 +482,6 @@ async def test_rollback_isolation_with_contextmanager(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def 
test_transaction_commit(database_url): """ @@ -526,7 +499,254 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions +@async_adapter +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): + """ + Ensure that child tasks may influence inherited transactions. + """ + # This is a practical example of the above test. 
+ async with Database(database_url) as database: + async with database.transaction(): + # Create a note + await database.execute( + notes.insert().values(id=1, text="setup", completed=True) + ) + + # Change the note from the same task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="prior") + ) + + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" + + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( + notes.update().where(notes.c.id == 1).values(text="test") + ) + + await asyncio.create_task(run_update_from_child_task(database.connection())) + + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation(database_url): + """ + Ensure that transactions are isolated between sibling tasks. + """ + start = asyncio.Event() + end = asyncio.Event() + + async with Database(database_url) as database: + + async def check_transaction(transaction): + await start.wait() + # Parent task is now in a transaction, we should not + # see its transaction backend since this task was + # _started_ in a context where no transaction was active. 
+ assert transaction._transaction is None + end.set() + + transaction = database.transaction() + assert transaction._transaction is None + task = asyncio.create_task(check_transaction(transaction)) + + async with transaction: + start.set() + assert transaction._transaction is not None + await end.wait() + + # Cleanup for "Task not awaited" warning + await task + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions running in sibling tasks are isolated from each other. + """ + # This is a practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): + """ + Ensure that task connections are not persisted unnecessarily. 
+ """ + + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() + + async with Database(database_url) as database: + # Should have a connection in this task + # .connect is lazy, it doesn't create a Connection, but .connection does + connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection + + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection + + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_garbagecollector(database_url): + """ + Ensure that connections for tasks are not persisted unnecessarily, even + if exit handlers are not called. 
+ """ + database = Database(database_url) + await database.connect() + + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() + + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task + gc.collect() + + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar transactions are not persisted unnecessarily. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + async with database.transaction() as transaction: + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # Context manager closes, open_transactions is cleaned up + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar transactions are not persisted unnecessarily, even + if exit handlers are not called. + + This test should be an XFAIL, but cannot be due to the way that it hangs + during teardown. 
+ """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + transaction = database.transaction() + await transaction.start() + + # Should be tracking the transaction + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # neither .commit, .rollback, nor .__aexit__ are called + del transaction + gc.collect() + + # TODO(zevisert,review): Could skip instead of using the logic below + # A strong reference to the transaction is kept alive by the connection's + # ._transaction_stack, so it is still tracked at this point. + assert len(open_transactions) == 1 + + # If that were magically cleared, the transaction would be cleaned up, + # but as it stands this always causes a hang during teardown at + # `Database(...).disconnect()` if the transaction is not closed. + transaction = database.connection()._transaction_stack[-1] + await transaction.rollback() + del transaction + + # Now with the transaction rolled-back, it should be cleaned up. 
+ assert len(open_transactions) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_serializable(database_url): """ @@ -571,7 +791,6 @@ def delete_independently(): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_rollback(database_url): """ @@ -594,7 +813,6 @@ async def test_transaction_rollback(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_commit_low_level(database_url): """ @@ -618,7 +836,6 @@ async def test_transaction_commit_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_rollback_low_level(database_url): """ @@ -643,7 +860,6 @@ async def test_transaction_rollback_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_decorator(database_url): """ @@ -662,19 +878,45 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. 
+ """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): """ @@ -699,7 +941,6 @@ async def test_datetime_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_decimal_field(database_url): """ @@ -727,7 +968,6 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_json_field(database_url): """ @@ -750,7 +990,6 @@ async def test_json_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_custom_field(database_url): """ @@ -776,7 +1015,6 @@ async def test_custom_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connections_isolation(database_url): """ @@ -799,7 +1037,6 @@ async def test_connections_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_commit_on_root_transaction(database_url): """ @@ -824,7 +1061,6 @@ async def test_commit_on_root_transaction(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connect_and_disconnect(database_url): """ @@ -848,17 +1084,17 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter -async def 
test_connection_context(database_url): - """ - Test connection contexts are task-local. - """ +async def test_connection_context_same_task(database_url): async with Database(database_url) as database: async with database.connection() as connection_1: async with database.connection() as connection_2: assert connection_1 is connection_2 + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -878,9 +1114,8 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) while connection_1 is None or connection_2 is None: await asyncio.sleep(0.000001) assert connection_1 is not connection_2 @@ -890,7 +1125,61 @@ async def get_connection_2(): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await 
task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + +@pytest.mark.parametrize( + "database_url1,database_url2", + ( + pytest.param(db1, db2, id=f"{db1} | {db2}") + for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) + ), +) +@async_adapter +async def test_connection_context_multiple_databases(database_url1, database_url2): + async with Database(database_url1) as database1: + async with Database(database_url2) as database2: + assert database1.connection() is not database2.connection() + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connection_context_with_raw_connection(database_url): """ @@ -904,7 +1193,6 @@ async def test_connection_context_with_raw_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_with_expose_backend_connection(database_url): """ @@ -1011,7 +1299,6 @@ async def test_queries_with_expose_backend_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_database_url_interface(database_url): """ @@ -1025,16 +1312,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def 
test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -1090,7 +1420,6 @@ async def test_iterate_outside_transaction_with_values(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_iterate_outside_transaction_with_temp_table(database_url): """ @@ -1120,7 +1449,6 @@ async def test_iterate_outside_transaction_with_temp_table(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"]) -@mysql_versions @async_adapter async def test_column_names(database_url, select_query): """ @@ -1188,7 +1516,6 @@ async def test_posgres_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_postcompile_queries(database_url): """ @@ 
-1206,7 +1533,6 @@ async def test_postcompile_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_result_named_access(database_url): async with Database(database_url) as database: @@ -1222,7 +1548,6 @@ async def test_result_named_access(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_mapping_property_interface(database_url): """ diff --git a/tests/test_integration.py b/tests/test_integration.py index c3e585b4..139f8ffe 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -5,7 +5,7 @@ from starlette.testclient import TestClient from databases import Database, DatabaseURL -from tests.test_databases import DATABASE_URLS, mysql_versions +from tests.test_databases import DATABASE_URLS metadata = sqlalchemy.MetaData() @@ -84,7 +84,6 @@ async def add_note(request): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions def test_integration(database_url): app = get_app(database_url)