From 6427be4873b3b6bb893ada340e8af93c46a4dc76 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Thu, 28 Jul 2022 19:30:17 +0200 Subject: [PATCH 01/18] Pin dependencies (#502) --- .github/dependbot.yml | 10 ++++++++ .github/workflows/publish.yml | 4 ++-- .github/workflows/test-suite.yml | 4 ++-- databases/backends/aiopg.py | 2 +- requirements.txt | 41 ++++++++++++++++---------------- 5 files changed, 35 insertions(+), 26 deletions(-) create mode 100644 .github/dependbot.yml diff --git a/.github/dependbot.yml b/.github/dependbot.yml new file mode 100644 index 00000000..b9038ca1 --- /dev/null +++ b/.github/dependbot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: monthly diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a41fd2bf..170e9558 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,8 +12,8 @@ jobs: runs-on: "ubuntu-latest" steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v1" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: 3.7 - name: "Install dependencies" diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 0690b4d1..bc271a65 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -39,8 +39,8 @@ jobs: options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v1" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: "${{ matrix.python-version }}" - name: "Install dependencies" diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 9ad12f63..60c741a7 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -31,7 +31,7 @@ def __init__( 
self._database_url = DatabaseURL(database_url) self._options = options self._dialect = self._get_dialect() - self._pool = None + self._pool: typing.Union[aiopg.Pool, None] = None def _get_dialect(self) -> Dialect: dialect = PGDialect_psycopg2( diff --git a/requirements.txt b/requirements.txt index 0d1d5b76..3d988585 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 +1,31 @@ -e . # Async database drivers -asyncmy -aiomysql -aiopg -aiosqlite -asyncpg +asyncmy==0.2.5 +aiomysql==0.1.1 +aiopg==1.3.4 +aiosqlite==0.17.0 +asyncpg==0.26.0 # Sync database drivers for standard tooling around setup/teardown/migrations. -psycopg2-binary -pymysql +psycopg2-binary==2.9.3 +pymysql==1.0.2 # Testing -autoflake -black -codecov -isort -mypy -pytest -pytest-cov -starlette -requests +autoflake==1.4 +black==22.6.0 +isort==5.10.1 +mypy==0.971 +pytest==7.1.2 +pytest-cov==3.0.0 +starlette==0.20.4 +requests==2.28.1 # Documentation -mkdocs -mkdocs-material -mkautodoc +mkdocs==1.3.1 +mkdocs-material==8.3.9 +mkautodoc==0.1.0 # Packaging -twine -wheel +twine==4.0.1 +wheel==0.37.1 From 6a52a4a37e69fcb25f309ff5d38e4d43eb80b32c Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Thu, 28 Jul 2022 19:49:45 +0200 Subject: [PATCH 02/18] Simplify mysql tests (#459) --- tests/test_databases.py | 51 --------------------------------------- tests/test_integration.py | 3 +-- 2 files changed, 1 insertion(+), 53 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index e6313e94..a7545e31 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -4,7 +4,6 @@ import functools import os import re -import sys from unittest.mock import MagicMock, patch import pytest @@ -17,23 +16,6 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -def mysql_versions(wrapped_func): - """ - Decorator used to handle multiple versions of Python for mysql drivers - """ - - @functools.wraps(wrapped_func) - def check(*args, **kwargs): # pragma: no cover - 
url = DatabaseURL(kwargs["database_url"]) - if url.scheme in ["mysql", "mysql+aiomysql"] and sys.version_info >= (3, 10): - pytest.skip("aiomysql supports python 3.9 and lower") - if url.scheme == "mysql+asyncmy" and sys.version_info < (3, 7): - pytest.skip("asyncmy supports python 3.7 and higher") - return wrapped_func(*args, **kwargs) - - return check - - class AsyncMock(MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -145,7 +127,6 @@ def run_sync(*args, **kwargs): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries(database_url): """ @@ -223,7 +204,6 @@ async def test_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_raw(database_url): """ @@ -285,7 +265,6 @@ async def test_queries_raw(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_ddl_queries(database_url): """ @@ -305,7 +284,6 @@ async def test_ddl_queries(database_url): @pytest.mark.parametrize("exception", [Exception, asyncio.CancelledError]) @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_after_error(database_url, exception): """ @@ -327,7 +305,6 @@ async def test_queries_after_error(database_url, exception): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_results_support_mapping_interface(database_url): """ @@ -356,7 +333,6 @@ async def test_results_support_mapping_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_results_support_column_reference(database_url): """ @@ -388,7 +364,6 @@ async def test_results_support_column_reference(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async 
def test_result_values_allow_duplicate_names(database_url): """ @@ -405,7 +380,6 @@ async def test_result_values_allow_duplicate_names(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_fetch_one_returning_no_results(database_url): """ @@ -420,7 +394,6 @@ async def test_fetch_one_returning_no_results(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_execute_return_val(database_url): """ @@ -447,7 +420,6 @@ async def test_execute_return_val(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_rollback_isolation(database_url): """ @@ -467,7 +439,6 @@ async def test_rollback_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_rollback_isolation_with_contextmanager(database_url): """ @@ -490,7 +461,6 @@ async def test_rollback_isolation_with_contextmanager(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_commit(database_url): """ @@ -508,7 +478,6 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_commit_serializable(database_url): """ @@ -553,7 +522,6 @@ def delete_independently(): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_rollback(database_url): """ @@ -576,7 +544,6 @@ async def test_transaction_rollback(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_commit_low_level(database_url): """ @@ -600,7 +567,6 @@ async def test_transaction_commit_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter 
async def test_transaction_rollback_low_level(database_url): """ @@ -625,7 +591,6 @@ async def test_transaction_rollback_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_transaction_decorator(database_url): """ @@ -656,7 +621,6 @@ async def insert_data(raise_exception): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_datetime_field(database_url): """ @@ -681,7 +645,6 @@ async def test_datetime_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_decimal_field(database_url): """ @@ -709,7 +672,6 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_json_field(database_url): """ @@ -732,7 +694,6 @@ async def test_json_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_custom_field(database_url): """ @@ -758,7 +719,6 @@ async def test_custom_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connections_isolation(database_url): """ @@ -781,7 +741,6 @@ async def test_connections_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_commit_on_root_transaction(database_url): """ @@ -806,7 +765,6 @@ async def test_commit_on_root_transaction(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connect_and_disconnect(database_url): """ @@ -830,7 +788,6 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connection_context(database_url): """ @@ -872,7 +829,6 @@ async def get_connection_2(): 
@pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_connection_context_with_raw_connection(database_url): """ @@ -886,7 +842,6 @@ async def test_connection_context_with_raw_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_queries_with_expose_backend_connection(database_url): """ @@ -993,7 +948,6 @@ async def test_queries_with_expose_backend_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_database_url_interface(database_url): """ @@ -1072,7 +1026,6 @@ async def test_iterate_outside_transaction_with_values(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_iterate_outside_transaction_with_temp_table(database_url): """ @@ -1102,7 +1055,6 @@ async def test_iterate_outside_transaction_with_temp_table(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"]) -@mysql_versions @async_adapter async def test_column_names(database_url, select_query): """ @@ -1170,7 +1122,6 @@ async def test_posgres_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_postcompile_queries(database_url): """ @@ -1188,7 +1139,6 @@ async def test_postcompile_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_result_named_access(database_url): async with Database(database_url) as database: @@ -1204,7 +1154,6 @@ async def test_result_named_access(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions @async_adapter async def test_mapping_property_interface(database_url): """ diff --git a/tests/test_integration.py b/tests/test_integration.py index 
c3e585b4..139f8ffe 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -5,7 +5,7 @@ from starlette.testclient import TestClient from databases import Database, DatabaseURL -from tests.test_databases import DATABASE_URLS, mysql_versions +from tests.test_databases import DATABASE_URLS metadata = sqlalchemy.MetaData() @@ -84,7 +84,6 @@ async def add_note(request): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions def test_integration(database_url): app = get_app(database_url) From 77270d82bf1ae186d08e51c48ae3c96fb0d8d5af Mon Sep 17 00:00:00 2001 From: Rickert Mulder Date: Wed, 3 Aug 2022 14:11:20 +0200 Subject: [PATCH 03/18] Allow string indexing into Record (#501) * Allow string indexing into Record The Record interface inherits from Sequence which only supports integer indexing. The Postgres backend supports string indexing into Records (https://github.com/encode/databases/blob/master/databases/backends/postgres.py#L135). This PR updates the interface to reflect that. At least on the Postgres backend __getitem__ deserializes some data types, so it's not equivalent to the _mapping method. 
--- databases/interfaces.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/databases/interfaces.py b/databases/interfaces.py index c2109a23..fd6a24ee 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -73,3 +73,6 @@ class Record(Sequence): @property def _mapping(self) -> typing.Mapping: raise NotImplementedError() # pragma: no cover + + def __getitem__(self, key: typing.Any) -> typing.Any: + raise NotImplementedError() # pragma: no cover From 385f3fd788f7a513ca1ee6a8efe5316b3d93f6cd Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 9 Aug 2022 10:30:04 +0200 Subject: [PATCH 04/18] Version 0.6.1 (#505) --- CHANGELOG.md | 7 +++++++ databases/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abe7da92..a0f30af5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 
+## 0.6.1 (Aug 9th, 2022) + +### Fixed + +* Improve typing for `Transaction` (#493) +* Allow string indexing into Record (#501) + ## 0.6.0 (May 29th, 2022) * Dropped Python 3.6 support (#458) diff --git a/databases/__init__.py b/databases/__init__.py index 8dd420b2..1a4a091c 100644 --- a/databases/__init__.py +++ b/databases/__init__.py @@ -1,4 +1,4 @@ from databases.core import Database, DatabaseURL -__version__ = "0.6.0" +__version__ = "0.6.1" __all__ = ["Database", "DatabaseURL"] From ff8e8a26a54cbb775535cb395df93543f3884eb1 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Mon, 7 Nov 2022 12:07:51 +0100 Subject: [PATCH 05/18] Pin SQLAlchemy 1.4.41 (#520) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index decbf7e5..c2b1aa64 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4,<1.5"], + install_requires=["sqlalchemy>=1.4,<=1.4.41"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], From b38cc4f74733f29508aed42f208c2867ee80c98a Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Mon, 7 Nov 2022 12:15:27 +0100 Subject: [PATCH 06/18] Version 0.6.2 (#521) Co-authored-by: Marcelo Trylesinski --- CHANGELOG.md | 6 ++++++ databases/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0f30af5..d5170ee0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.6.2 (Nov 7th, 2022) + +### Changes + +* Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes (#520). 
+ ## 0.6.1 (Aug 9th, 2022) ### Fixed diff --git a/databases/__init__.py b/databases/__init__.py index 1a4a091c..9c88e786 100644 --- a/databases/__init__.py +++ b/databases/__init__.py @@ -1,4 +1,4 @@ from databases.core import Database, DatabaseURL -__version__ = "0.6.1" +__version__ = "0.6.2" __all__ = ["Database", "DatabaseURL"] From b78e519da33114e9d2ff7dec6ca4f03c7437408c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20B=C3=B6hm?= Date: Thu, 1 Dec 2022 14:43:37 +0100 Subject: [PATCH 07/18] Add documentation about creating tables using sqlalchemy schemas; Close #234 (#515) * Add docs about creating tables using sqlalchemy schemas; Close #234 * Add note about preferring alembic outside of experiments * Improve wording --- docs/database_queries.md | 57 +++++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/docs/database_queries.md b/docs/database_queries.md index 898e7343..aeb67eb6 100644 --- a/docs/database_queries.md +++ b/docs/database_queries.md @@ -24,9 +24,48 @@ notes = sqlalchemy.Table( ) ``` -You can use any of the sqlalchemy column types such as `sqlalchemy.JSON`, or +You can use any of the SQLAlchemy column types such as `sqlalchemy.JSON`, or custom column types. +## Creating tables + +Databases doesn't use SQLAlchemy's engine for database access internally. [The usual SQLAlchemy core way to create tables with `create_all`](https://docs.sqlalchemy.org/en/20/core/metadata.html#sqlalchemy.schema.MetaData.create_all) is therefore not available. 
To work around this you can use SQLAlchemy to [compile the query to SQL](https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined) and then execute it with databases: + +```python +from databases import Database +import sqlalchemy + +database = Database("postgresql+asyncpg://localhost/example") + +# Establish the connection pool +await database.connect() + +metadata = sqlalchemy.MetaData() +dialect = sqlalchemy.dialects.postgresql.dialect() + +# Define your table(s) +notes = sqlalchemy.Table( + "notes", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("text", sqlalchemy.String(length=100)), + sqlalchemy.Column("completed", sqlalchemy.Boolean), +) + +# Create tables +for table in metadata.tables.values(): + # Set `if_not_exists=False` if you want the query to throw an + # exception when the table already exists + schema = sqlalchemy.schema.CreateTable(table, if_not_exists=True) + query = str(schema.compile(dialect=dialect)) + await database.execute(query=query) + +# Close all connections in the connection pool +await database.disconnect() +``` + +Note that this way of creating tables is only useful for local experimentation. For serious projects, we recommend using a proper database schema migrations solution like [Alembic](https://alembic.sqlalchemy.org/en/latest/). + ## Queries You can now use any [SQLAlchemy core][sqlalchemy-core] queries ([official tutorial][sqlalchemy-core-tutorial]). @@ -70,11 +109,11 @@ query = notes.select() async for row in database.iterate(query=query): ... -# Close all connection in the connection pool +# Close all connections in the connection pool await database.disconnect() ``` -Connections are managed as task-local state, with driver implementations +Connections are managed as a task-local state, with driver implementations transparently using connection pooling behind the scenes. 
## Raw queries @@ -111,17 +150,17 @@ Note that query arguments should follow the `:query_arg` style. ## Query result -To keep in line with [SQLAlchemy 1.4 changes][sqlalchemy-mapping-changes] -query result object no longer implements a mapping interface. -To access query result as a mapping you should use the `_mapping` property. -That way you can process both SQLAlchemy Rows and databases Records from raw queries +To keep in line with [SQLAlchemy 1.4 changes][sqlalchemy-mapping-changes] +query result object no longer implements a mapping interface. +To access query result as a mapping you should use the `_mapping` property. +That way you can process both SQLAlchemy Rows and databases Records from raw queries with the same function without any instance checks. ```python query = "SELECT * FROM notes WHERE id = :id" result = await database.fetch_one(query=query, values={"id": 1}) -result.id # access field via attribute -result._mapping['id'] # access field via mapping +result.id # Access field via attribute +result._mapping['id'] # Access field via mapping ``` [sqlalchemy-mapping-changes]: https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#rowproxy-is-no-longer-a-proxy-is-now-called-row-and-behaves-like-an-enhanced-named-tuple From 8ec9168775d889b7d259c5fb9a2fdff408e40972 Mon Sep 17 00:00:00 2001 From: Nathan Janke Date: Thu, 1 Dec 2022 09:07:21 -0700 Subject: [PATCH 08/18] docs: Update sqlalchemy core tutorial link (#517) The current link in the docs (specifying the version "latest") now points to the docs for the SQLAlchemy 2.0 beta. This PR changes the link to strictly specify version 1.4. It would also be possible to change the link to point to "stable" rather than "latest", however 1.4 seemed more appropriate as "stable" will likely point to 2.0 before this library migrates. 
Cheers --- docs/database_queries.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/database_queries.md b/docs/database_queries.md index aeb67eb6..66201089 100644 --- a/docs/database_queries.md +++ b/docs/database_queries.md @@ -146,7 +146,7 @@ result = await database.fetch_one(query=query, values={"id": 1}) Note that query arguments should follow the `:query_arg` style. [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ -[sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html +[sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/14/core/tutorial.html ## Query result From 7aa13262e3a1ac2fc402e60a4fde5d6d13925903 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 17 Dec 2022 11:21:44 +0000 Subject: [PATCH 09/18] Wrap types in `typing.Optional` where applicable (#510) Co-authored-by: tsunyoku Co-authored-by: Marcelo Trylesinski --- databases/core.py | 56 ++++++++++++++++++++++++++++++----------------- setup.cfg | 1 + 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/databases/core.py b/databases/core.py index efa59471..8415b836 100644 --- a/databases/core.py +++ b/databases/core.py @@ -129,20 +129,24 @@ async def __aenter__(self) -> "Database": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: await self.disconnect() async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.List[Record]: async with self.connection() as connection: return await connection.fetch_all(query, values) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: 
typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: async with self.connection() as connection: return await connection.fetch_one(query, values) @@ -150,14 +154,16 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: async with self.connection() as connection: return await connection.fetch_val(query, values, column=column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: async with self.connection() as connection: return await connection.execute(query, values) @@ -169,7 +175,9 @@ async def execute_many( return await connection.execute_many(query, values) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.AsyncGenerator[typing.Mapping, None]: async with self.connection() as connection: async for record in connection.iterate(query, values): @@ -232,9 +240,9 @@ async def __aenter__(self) -> "Connection": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: async with self._connection_lock: assert self._connection is not None @@ -243,14 +251,18 @@ async def __aexit__( await self._connection.release() async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.List[Record]: built_query = 
self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_all(built_query) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -259,7 +271,7 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: built_query = self._build_query(query, values) @@ -267,7 +279,9 @@ async def fetch_val( return await self._connection.fetch_val(built_query, column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: built_query = self._build_query(query, values) async with self._query_lock: @@ -281,7 +295,9 @@ async def execute_many( await self._connection.execute_many(queries) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.AsyncGenerator[typing.Any, None]: built_query = self._build_query(query, values) async with self.transaction(): @@ -303,7 +319,7 @@ def raw_connection(self) -> typing.Any: @staticmethod def _build_query( - query: typing.Union[ClauseElement, str], values: dict = None + query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None ) -> ClauseElement: if isinstance(query, str): query = text(query) @@ -338,9 +354,9 @@ async def __aenter__(self) -> "Transaction": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = 
None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: """ Called when exiting `async with database.transaction()` diff --git a/setup.cfg b/setup.cfg index 77c8c58d..da1831fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,7 @@ [mypy] disallow_untyped_defs = True ignore_missing_imports = True +no_implicit_optional = True [tool:isort] profile = black From 81cc6fdb1ce4e78875960a8a262a4b134745946e Mon Sep 17 00:00:00 2001 From: jonium <52005121+joniumGit@users.noreply.github.com> Date: Sat, 17 Dec 2022 13:27:06 +0200 Subject: [PATCH 10/18] Fixes breaking changes in SQLAlchemy cursor (#513) Co-authored-by: Marcelo Trylesinski fixes undefined --- databases/backends/aiopg.py | 1 + databases/backends/asyncmy.py | 1 + databases/backends/mysql.py | 1 + databases/backends/sqlite.py | 1 + setup.py | 2 +- 5 files changed, 5 insertions(+), 1 deletion(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 60c741a7..1d35749e 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -221,6 +221,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index e15dfa45..233d2e0e 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -211,6 +211,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 2a0a8425..c7ac9f4e 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -211,6 +211,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, 
compiled._loose_column_name_matching, ) else: diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 9626dcf8..69ef5b51 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -185,6 +185,7 @@ def _compile( compiled._result_columns, compiled._ordered_columns, compiled._textual_ordered_columns, + compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) diff --git a/setup.py b/setup.py index c2b1aa64..3725cab9 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4,<=1.4.41"], + install_requires=["sqlalchemy>=1.4.42,<1.5"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], From 6b0c767588f501d5edaabb1bdf665fcf1ded88ea Mon Sep 17 00:00:00 2001 From: Ben Beasley Date: Sun, 18 Dec 2022 04:14:45 -0500 Subject: [PATCH 11/18] Version 0.7.0 (#522) Co-authored-by: Marcelo Trylesinski --- CHANGELOG.md | 9 ++++++++- databases/__init__.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5170ee0..4816bc16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.7.0 (Dec 18th, 2022) + +### Fixed + +* Fixed breaking changes in SQLAlchemy cursor; supports `>=1.4.42,<1.5` (#513). +* Wrapped types in `typing.Optional` where applicable (#510). + ## 0.6.2 (Nov 7th, 2022) -### Changes +### Changed * Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes (#520). 
diff --git a/databases/__init__.py b/databases/__init__.py index 9c88e786..cfb75242 100644 --- a/databases/__init__.py +++ b/databases/__init__.py @@ -1,4 +1,4 @@ from databases.core import Database, DatabaseURL -__version__ = "0.6.2" +__version__ = "0.7.0" __all__ = ["Database", "DatabaseURL"] From 77d9b8aa7dc3871133b02adf0c498c583b89a6fd Mon Sep 17 00:00:00 2001 From: chaojie Date: Tue, 27 Dec 2022 14:34:19 +0800 Subject: [PATCH 12/18] Fix the type-hints using more standard mode (#526) --- databases/backends/aiopg.py | 2 +- databases/backends/asyncmy.py | 2 +- databases/backends/mysql.py | 2 +- databases/backends/postgres.py | 8 +++----- databases/backends/sqlite.py | 2 +- databases/core.py | 8 ++++---- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 1d35749e..8668b2b9 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -104,7 +104,7 @@ class AiopgConnection(ConnectionBackend): def __init__(self, database: AiopgBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[aiopg.Connection] + self._connection: typing.Optional[aiopg.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 233d2e0e..749e5afe 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -92,7 +92,7 @@ class AsyncMyConnection(ConnectionBackend): def __init__(self, database: AsyncMyBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[asyncmy.Connection] + self._connection: typing.Optional[asyncmy.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py 
index c7ac9f4e..6b86042f 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -92,7 +92,7 @@ class MySQLConnection(ConnectionBackend): def __init__(self, database: MySQLBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[aiomysql.Connection] + self._connection: typing.Optional[aiomysql.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 3e1a6fff..e30c12d7 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -45,7 +45,7 @@ def _get_dialect(self) -> Dialect: def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options - kwargs = {} # type: typing.Dict[str, typing.Any] + kwargs: typing.Dict[str, typing.Any] = {} min_size = url_options.get("min_size") max_size = url_options.get("max_size") ssl = url_options.get("ssl") @@ -162,7 +162,7 @@ class PostgresConnection(ConnectionBackend): def __init__(self, database: PostgresBackend, dialect: Dialect): self._database = database self._dialect = dialect - self._connection = None # type: typing.Optional[asyncpg.connection.Connection] + self._connection: typing.Optional[asyncpg.connection.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" @@ -305,9 +305,7 @@ def raw_connection(self) -> asyncpg.connection.Connection: class PostgresTransaction(TransactionBackend): def __init__(self, connection: PostgresConnection): self._connection = connection - self._transaction = ( - None - ) # type: typing.Optional[asyncpg.transaction.Transaction] + self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None async def start( self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py 
index 69ef5b51..19464627 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -80,7 +80,7 @@ class SQLiteConnection(ConnectionBackend): def __init__(self, pool: SQLitePool, dialect: Dialect): self._pool = pool self._dialect = dialect - self._connection = None # type: typing.Optional[aiosqlite.Connection] + self._connection: typing.Optional[aiosqlite.Connection] = None async def acquire(self) -> None: assert self._connection is None, "Connection is already acquired" diff --git a/databases/core.py b/databases/core.py index 8415b836..8394ab5c 100644 --- a/databases/core.py +++ b/databases/core.py @@ -64,12 +64,12 @@ def __init__( self._backend = backend_cls(self.url, **self.options) # Connections are stored as task-local state. - self._connection_context = ContextVar("connection_context") # type: ContextVar + self._connection_context: ContextVar = ContextVar("connection_context") # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. - self._global_connection = None # type: typing.Optional[Connection] - self._global_transaction = None # type: typing.Optional[Transaction] + self._global_connection: typing.Optional[Connection] = None + self._global_transaction: typing.Optional[Transaction] = None async def connect(self) -> None: """ @@ -223,7 +223,7 @@ def __init__(self, backend: DatabaseBackend) -> None: self._connection_counter = 0 self._transaction_lock = asyncio.Lock() - self._transaction_stack = [] # type: typing.List[Transaction] + self._transaction_stack: typing.List[Transaction] = [] self._query_lock = asyncio.Lock() From b6eba5f7a19eaf8966e3821f44fe00f4770cb822 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Jan 2023 13:56:18 +0100 Subject: [PATCH 13/18] Bump wheel from 0.37.1 to 0.38.1 (#524) Bumps [wheel](https://github.com/pypa/wheel) from 0.37.1 to 0.38.1. 
- [Release notes](https://github.com/pypa/wheel/releases) - [Changelog](https://github.com/pypa/wheel/blob/main/docs/news.rst) - [Commits](https://github.com/pypa/wheel/compare/0.37.1...0.38.1) --- updated-dependencies: - dependency-name: wheel dependency-type: direct:production ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Marcelo Trylesinski --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3d988585..0699d3cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ mkautodoc==0.1.0 # Packaging twine==4.0.1 -wheel==0.37.1 +wheel==0.38.1 From ab5eb718a78a27afe18775754e9c0fa2ad9cd211 Mon Sep 17 00:00:00 2001 From: wojtasiq <35078282+wojtasiq@users.noreply.github.com> Date: Wed, 24 May 2023 09:33:50 +0200 Subject: [PATCH 14/18] Bump up asyncmy version to fix `No module named 'asyncmy.connection'` (#553) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0699d3cc..5d98fb2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -e . 
# Async database drivers -asyncmy==0.2.5 +asyncmy==0.2.7 aiomysql==0.1.1 aiopg==1.3.4 aiosqlite==0.17.0 From 71ea4adfeafc397e0fa54067c012edcbc5a62f7a Mon Sep 17 00:00:00 2001 From: wojtasiq <35078282+wojtasiq@users.noreply.github.com> Date: Wed, 12 Jul 2023 03:12:08 +0200 Subject: [PATCH 15/18] Support for unix socket for aiomysql and asyncmy (#551) --- databases/backends/asyncmy.py | 3 +++ databases/backends/mysql.py | 3 +++ tests/test_connection_options.py | 18 ++++++++++++++++++ tests/test_database_url.py | 5 +++++ 4 files changed, 29 insertions(+) diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 749e5afe..0811ef21 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ -49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. 
diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 6b86042f..630f7cd3 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ -49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index e6fe6849..9e4435ad 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -77,6 +77,15 @@ def test_mysql_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mysql_unix_socket(): + backend = MySQLBackend( + "mysql+aiomysql://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") def test_mysql_explicit_pool_size(): backend = MySQLBackend("mysql://localhost/database", min_size=1, max_size=20) @@ -114,6 +123,15 @@ def test_asyncmy_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") +def test_asyncmy_unix_socket(): + backend = AsyncMyBackend( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = 
backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") def test_asyncmy_explicit_pool_size(): backend = AsyncMyBackend("mysql://localhost/database", min_size=1, max_size=20) diff --git a/tests/test_database_url.py b/tests/test_database_url.py index 9eea4fa6..7aa15926 100644 --- a/tests/test_database_url.py +++ b/tests/test_database_url.py @@ -69,6 +69,11 @@ def test_database_url_options(): u = DatabaseURL("postgresql://localhost/mydatabase?pool_size=20&ssl=true") assert u.options == {"pool_size": "20", "ssl": "true"} + u = DatabaseURL( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + assert u.options == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + def test_replace_database_url_components(): u = DatabaseURL("postgresql://localhost/mydatabase") From f3f0c6f0ba1b2af7d1716d5786d584410c745ed3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Jul 2023 20:32:27 -0500 Subject: [PATCH 16/18] Bump starlette from 0.20.4 to 0.27.0 (#560) * Bump starlette from 0.20.4 to 0.27.0 Bumps [starlette](https://github.com/encode/starlette) from 0.20.4 to 0.27.0. - [Release notes](https://github.com/encode/starlette/releases) - [Changelog](https://github.com/encode/starlette/blob/master/docs/release-notes.md) - [Commits](https://github.com/encode/starlette/compare/0.20.4...0.27.0) --- updated-dependencies: - dependency-name: starlette dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * Add httpx as dependency --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Zanie --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5d98fb2e..87971a6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,11 +14,12 @@ pymysql==1.0.2 # Testing autoflake==1.4 black==22.6.0 +httpx==0.24.1 isort==5.10.1 mypy==0.971 pytest==7.1.2 pytest-cov==3.0.0 -starlette==0.20.4 +starlette==0.27.0 requests==2.28.1 # Documentation From c09542802afdb1e4fbdde9803a136e9254461e5d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Jul 2023 20:39:40 -0500 Subject: [PATCH 17/18] Bump requests from 2.28.1 to 2.31.0 (#562) Bumps [requests](https://github.com/psf/requests) from 2.28.1 to 2.31.0. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.28.1...v2.31.0) --- updated-dependencies: - dependency-name: requests dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 87971a6f..46ed998b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ mypy==0.971 pytest==7.1.2 pytest-cov==3.0.0 starlette==0.27.0 -requests==2.28.1 +requests==2.31.0 # Documentation mkdocs==1.3.1 From 25fa29515d4b6387db482734c564152a1034fbe6 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 25 Jul 2023 10:38:38 -0700 Subject: [PATCH 18/18] fix: incorrect concurrent usage of connection and transaction (#546) * fix: incorrect concurrent usage of connection and transaction * refactor: rename contextvar class attributes, add some explaination comments * fix: contextvar.get takes no keyword arguments * test: add concurrent task tests * feat: use ContextVar[dict] to track connections and transactions per task * test: check multiple databases in the same task use independant connections * chore: changes for linting and typechecking * chore: use typing.Tuple for lower python version compatibility * docs: update comment on _connection_contextmap * Update `Connection` and `Transaction` to be robust to concurrent use * chore: remove optional annotation on asyncio.Task * test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup * feat: reimplement concurrency system with contextvar and weakmap * chore: apply corrections from linters * fix: quote WeakKeyDictionary typing for python<=3.7 * docs: add examples for async transaction context and nested transactions * fix: remove connection inheritance, add more tests, update docs Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. 
TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited. --------- Co-authored-by: Zanie --- databases/core.py | 92 +++++- docs/connections_and_transactions.md | 54 +++- tests/test_databases.py | 410 +++++++++++++++++++++++++-- 3 files changed, 521 insertions(+), 35 deletions(-) diff --git a/databases/core.py b/databases/core.py index 8394ab5c..795609ea 100644 --- a/databases/core.py +++ b/databases/core.py @@ -3,6 +3,7 @@ import functools import logging import typing +import weakref from contextvars import ContextVar from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit @@ -11,7 +12,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -35,6 +36,11 @@ logger = logging.getLogger("databases") +_ACTIVE_TRANSACTIONS: ContextVar[ + typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] +] = ContextVar("databases:active_transactions", default=None) + + class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", @@ -45,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -55,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -63,14 +72,35 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. 
- self._connection_context: ContextVar = ContextVar("connection_context") - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. self._global_connection: typing.Optional[Connection] = None self._global_transaction: typing.Optional[Transaction] = None + @property + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task + + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) + + @_connection.setter + def _connection( + self, connection: typing.Optional["Connection"] + ) -> typing.Optional["Connection"]: + task = self._current_task + + if connection is None: + self._connection_map.pop(task, None) + else: + self._connection_map[task] = connection + + return self._connection + async def connect(self) -> None: """ Establish the connection pool. @@ -89,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -113,7 +143,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = ContextVar("connection_context") + self._connection = None await self._backend.disconnect() logger.info( @@ -187,12 +217,10 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + if not self._connection: + self._connection = Connection(self, self._backend) + + return self._connection 
def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -215,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -249,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -345,6 +375,37 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs + @property + def _connection(self) -> "Connection": + # Returns the same connection if called multiple times + return self._connection_callable() + + @property + def _transaction(self) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + return None + + return transactions.get(self, None) + + @_transaction.setter + def _transaction( + self, transaction: typing.Optional["TransactionBackend"] + ) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + transactions = weakref.WeakKeyDictionary() + else: + transactions = transactions.copy() + + if transaction is None: + transactions.pop(self, None) + else: + transactions[self] = transaction + + _ACTIVE_TRANSACTIONS.set(transactions) + return transactions.get(self, None) + async def __aenter__(self) -> "Transaction": """ Called when entering `async with database.transaction()` @@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() self._transaction = self._connection._connection.transaction() async with self._connection._transaction_lock: @@ -401,15 +461,19 @@ async def commit(self) -> 
None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.commit() await self._connection.__aexit__() + self._transaction = None async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.rollback() await self._connection.__aexit__() + self._transaction = None class _EmptyNetloc(str): diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index aa45537d..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool: async with database.transaction(): ... ``` + It can also be acquired from a specific database connection: ```python @@ -95,8 +98,51 @@ async def create_users(request): ... 
```
 
-Transaction blocks are managed as task-local state. Nested transactions
-are fully supported, and are implemented using database savepoints.
+Transaction state is tied to the connection used in the currently executing asynchronous task.
+If you would like to influence an active transaction from another task, the connection must be
+shared. This state is _inherited_ by tasks that share the same connection:
+
+```python
+async def add_excitement(connection: databases.core.Connection, id: int):
+    await connection.execute(
+        "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
+        {"id": id}
+    )
+
+
+async with Database(database_url) as database:
+    async with database.transaction():
+        # This note won't exist until the transaction closes...
+        await database.execute(
+            "INSERT INTO notes(id, text) values (1, 'databases is cool')"
+        )
+        # ...but child tasks can use this connection now!
+        await asyncio.create_task(add_excitement(database.connection(), id=1))
+
+    await database.fetch_val("SELECT text FROM notes WHERE id=1")
+    # ^ returns: "databases is cool!!!"
+```
+
+Nested transactions are fully supported, and are implemented using database savepoints:
+
+```python
+async with databases.Database(database_url) as db:
+    async with db.transaction() as outer:
+        # Do something in the outer transaction
+        ...
+
+        # Suppress to prevent influence on the outer transaction
+        with contextlib.suppress(ValueError):
+            async with db.transaction():
+                # Do something in the inner transaction
+                ...
+
+                raise ValueError('Abort the inner transaction')
+
+    # Observe the results of the outer transaction,
+    # without effects from the inner transaction.
+ await db.fetch_all('SELECT * FROM ...') +``` Transaction isolation-level can be specified if the driver backend supports that: diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..4d737261 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,8 +2,11 @@ import datetime import decimal import functools +import gc +import itertools import os import re +from typing import MutableMapping from unittest.mock import MagicMock, patch import pytest @@ -477,6 +480,254 @@ async def test_transaction_commit(database_url): assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): + """ + Ensure that child tasks may influence inherited transactions. + """ + # This is an practical example of the above test. 
+ async with Database(database_url) as database: + async with database.transaction(): + # Create a note + await database.execute( + notes.insert().values(id=1, text="setup", completed=True) + ) + + # Change the note from the same task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="prior") + ) + + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" + + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( + notes.update().where(notes.c.id == 1).values(text="test") + ) + + await asyncio.create_task(run_update_from_child_task(database.connection())) + + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation(database_url): + """ + Ensure that transactions are isolated between sibling tasks. + """ + start = asyncio.Event() + end = asyncio.Event() + + async with Database(database_url) as database: + + async def check_transaction(transaction): + await start.wait() + # Parent task is now in a transaction, we should not + # see its transaction backend since this task was + # _started_ in a context where no transaction was active. 
+            assert transaction._transaction is None
+            end.set()
+
+        transaction = database.transaction()
+        assert transaction._transaction is None
+        task = asyncio.create_task(check_transaction(transaction))
+
+        async with transaction:
+            start.set()
+            assert transaction._transaction is not None
+            await end.wait()
+
+        # Cleanup for "Task not awaited" warning
+        await task
+
+
+@pytest.mark.parametrize("database_url", DATABASE_URLS)
+@async_adapter
+async def test_transaction_context_sibling_task_isolation_example(database_url):
+    """
+    Ensure that transactions running in sibling tasks are isolated from each other.
+    """
+    # This is a practical example of the above test.
+    setup = asyncio.Event()
+    done = asyncio.Event()
+
+    async def tx1(connection):
+        async with connection.transaction():
+            await db.execute(
+                notes.insert(), values={"id": 1, "text": "tx1", "completed": False}
+            )
+            setup.set()
+            await done.wait()
+
+    async def tx2(connection):
+        async with connection.transaction():
+            await setup.wait()
+            result = await db.fetch_all(notes.select())
+            assert result == [], result
+            done.set()
+
+    async with Database(database_url) as db:
+        await asyncio.gather(tx1(db), tx2(db))
+
+
+@pytest.mark.parametrize("database_url", DATABASE_URLS)
+@async_adapter
+async def test_connection_cleanup_contextmanager(database_url):
+    """
+    Ensure that task connections are not persisted unnecessarily.
+ """ + + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() + + async with Database(database_url) as database: + # Should have a connection in this task + # .connect is lazy, it doesn't create a Connection, but .connection does + connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection + + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection + + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_garbagecollector(database_url): + """ + Ensure that connections for tasks are not persisted unecessarily, even + if exit handlers are not called. 
+ """ + database = Database(database_url) + await database.connect() + + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() + + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task + gc.collect() + + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + async with database.transaction() as transaction: + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # Context manager closes, open_transactions is cleaned up + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily, even + if exit handlers are not called. + + This test should be an XFAIL, but cannot be due to the way that is hangs + during teardown. 
+ """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + transaction = database.transaction() + await transaction.start() + + # Should be tracking the transaction + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # neither .commit, .rollback, nor .__aexit__ are called + del transaction + gc.collect() + + # TODO(zevisert,review): Could skip instead of using the logic below + # A strong reference to the transaction is kept alive by the connection's + # ._transaction_stack, so it is still be tracked at this point. + assert len(open_transactions) == 1 + + # If that were magically cleared, the transaction would be cleaned up, + # but as it stands this always causes a hang during teardown at + # `Database(...).disconnect()` if the transaction is not closed. + transaction = database.connection()._transaction_stack[-1] + await transaction.rollback() + del transaction + + # Now with the transaction rolled-back, it should be cleaned up. 
+ assert len(open_transactions) == 0 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_serializable(database_url): @@ -609,17 +860,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. + """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -789,15 +1067,16 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context(database_url): - """ - Test connection contexts are task-local. 
- """ +async def test_connection_context_same_task(database_url): async with Database(database_url) as database: async with database.connection() as connection_1: async with database.connection() as connection_2: assert connection_1 is connection_2 + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -817,9 +1096,8 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) while connection_1 is None or connection_2 is None: await asyncio.sleep(0.000001) assert connection_1 is not connection_2 @@ -828,6 +1106,61 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not 
parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + +@pytest.mark.parametrize( + "database_url1,database_url2", + ( + pytest.param(db1, db2, id=f"{db1} | {db2}") + for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) + ), +) +@async_adapter +async def test_connection_context_multiple_databases(database_url1, database_url2): + async with Database(database_url1) as database1: + async with Database(database_url2) as database2: + assert database1.connection() is not database2.connection() + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connection_context_with_raw_connection(database_url): @@ -961,16 +1294,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + 
asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS)