From 465167115f6f7ad83d5bddd35a28472fd5207f75 Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Mon, 1 Jul 2019 02:40:24 +0930 Subject: [PATCH 1/8] [WIP] `aiopg` support #39 --- databases/backends/aiopg.py | 260 ++++++++++++++++++++++++++++++++++++ databases/core.py | 15 ++- tests/conftest.py | 39 ++++++ tests/test_databases.py | 3 +- tests/test_integration.py | 1 + 5 files changed, 316 insertions(+), 2 deletions(-) create mode 100644 databases/backends/aiopg.py create mode 100644 tests/conftest.py diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py new file mode 100644 index 00000000..7121d8f6 --- /dev/null +++ b/databases/backends/aiopg.py @@ -0,0 +1,260 @@ +import logging +import typing +import uuid + +import aiopg + +from sqlalchemy.dialects.postgresql import pypostgresql +from sqlalchemy.engine.interfaces import Dialect, ExecutionContext +from sqlalchemy.engine.result import ResultMetaData, RowProxy +from sqlalchemy.sql import ClauseElement +from sqlalchemy.types import TypeEngine + +from databases.core import DatabaseURL +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend + +logger = logging.getLogger("databases") + + +class AiopgBackend(DatabaseBackend): + def __init__( + self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any + ) -> None: + self._database_url = DatabaseURL(database_url) + self._options = options + self._dialect = self._get_dialect() + self._pool = None + + def _get_dialect(self) -> Dialect: + dialect = pypostgresql.dialect(paramstyle="pyformat") + + dialect.implicit_returning = True + dialect.supports_native_enum = True + dialect.supports_smallserial = True # 9.2+ + dialect._backslash_escapes = False + dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ + dialect._has_native_hstore = True + dialect.supports_native_decimal = True + + return dialect + + def _get_connection_kwargs(self) -> dict: # TODO move to `core.py` + url_options = self._database_url.options + + kwargs = {} + min_size = url_options.get("min_size") + max_size = url_options.get("max_size") + ssl = url_options.get("ssl") + + if min_size is not None: + kwargs["minsize"] = int(min_size) + if max_size is not None: + kwargs["maxsize"] = int(max_size) + if ssl is not None: + kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + + for key, value in self._options.items(): + # Coerce 'min_size' and 'max_size' for consistency. + if key == "min_size": + key = "minsize" + elif key == "max_size": + key = "maxsize" + kwargs[key] = value + + return kwargs + + async def connect(self) -> None: # TODO as MySQL one? + assert self._pool is None, "DatabaseBackend is already running" + kwargs = self._get_connection_kwargs() + self._pool = await aiopg.create_pool( + host=self._database_url.hostname, + port=self._database_url.port, + user=self._database_url.username or getpass.getuser(), + password=self._database_url.password, + database=self._database_url.database, + # autocommit=True, + **kwargs, + ) + + async def disconnect(self) -> None: + assert self._pool is not None, "DatabaseBackend is not running" + self._pool.close() + await self._pool.wait_closed() + self._pool = None + + def connection(self) -> "AiopgConnection": + return AiopgConnection(self, self._dialect) + + +class CompilationContext: + def __init__(self, context: ExecutionContext): + self.context = context + + +class AiopgConnection(ConnectionBackend): + def __init__(self, database: AiopgBackend, dialect: Dialect): + self._database = database + self._dialect = dialect + self._connection = None # type: typing.Optional[aiopg.Connection] + + async def acquire(self) -> None: + assert self._connection is None, "Connection is already acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + self._connection = await self._database._pool.acquire() + + async def release(self) -> None: + assert self._connection is not None, "Connection is not acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + await self._database._pool.release(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + # TODO + import pdb; pdb.set_trace() + try: + await cursor.execute(query, args) + rows = await cursor.fetchall() + metadata = ResultMetaData(context, cursor.description) + return [ + RowProxy(metadata, row, metadata._processors, metadata._keymap) + for row in rows + ] + finally: + cursor.close() + + async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + # TODO + import pdb; pdb.set_trace() + try: + await cursor.execute(query, args) + row = await cursor.fetchone() + if row is None: + return None + metadata = ResultMetaData(context, cursor.description) + return RowProxy(metadata, row, metadata._processors, metadata._keymap) + finally: + cursor.close() + + async def execute(self, query: ClauseElement) -> typing.Any: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + # TODO + import pdb; pdb.set_trace() + try: + await cursor.execute(query, args) + return cursor.lastrowid + finally: + cursor.close() + + async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + assert self._connection is not None, "Connection is not acquired" + cursor = await self._connection.cursor() + # TODO + import pdb; pdb.set_trace() + try: + for single_query in queries: + single_query, args, context = self._compile(single_query) + await cursor.execute(single_query, args) + finally: + cursor.close() + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Any, None]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + # TODO + import pdb; pdb.set_trace() + try: + await cursor.execute(query, args) + metadata = ResultMetaData(context, cursor.description) + async for row in cursor: + yield RowProxy(metadata, row, metadata._processors, metadata._keymap) + finally: + cursor.close() + + def transaction(self) -> TransactionBackend: + return AiopgTransaction(self) + + def _compile( + self, query: ClauseElement + ) -> typing.Tuple[str, dict, CompilationContext]: + compiled = query.compile(dialect=self._dialect) + args = compiled.construct_params() + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + + execution_context = self._dialect.execution_ctx_cls() + execution_context.dialect = self._dialect + execution_context.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) + + logger.debug("Query: %s\nArgs: %s", compiled.string, args) + return compiled.string, args, CompilationContext(execution_context) + + @property + def raw_connection(self) -> aiopg.connection.Connection: + assert self._connection is not None, "Connection is not acquired" + return self._connection + + +class AiopgTransaction(TransactionBackend): + def __init__(self, connection: AiopgConnection): + self._connection = connection + self._is_root = False + self._savepoint_name = "" + + async def start(self, is_root: bool) -> None: + import pdb; pdb.set_trace() + assert self._connection._connection is not None, "Connection is not acquired" + self._is_root = is_root + cursor = await self._connection._connection.cursor() + if self._is_root: + # await self._connection._connection.begin() + await cursor.execute("BEGIN") + else: + id = str(uuid.uuid4()).replace("-", "_") + self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" + # cursor = await self._connection._connection.cursor() + try: + await cursor.execute(f"SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() + + async def commit(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + if self._is_root: + # await self._connection._connection.commit() + await cursor.execute("COMMIT") + else: + # cursor = await self._connection._connection.cursor() + try: + await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() + + async def rollback(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + if self._is_root: + # await self._connection._connection.rollback() + await cursor.execute("ROLLBACK") + else: + # cursor = await self._connection._connection.cursor() + try: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() diff --git a/databases/core.py b/databases/core.py index 9034e5fe..fc3c8838 100644 --- a/databases/core.py +++ b/databases/core.py @@ -18,8 +18,17 @@ class Database: + # TODO Nested schema? + # { + # "postgresql": { + # "asyncpg": "...", # Default + # "aiopg": "..." + # } + # } SUPPORTED_BACKENDS = { + # TODO `postgresql+asyncpg`? "postgresql": "databases.backends.postgres:PostgresBackend", + "postgresql+psycopg2": "databases.backends.aiopg:AiopgBackend", "mysql": "databases.backends.mysql:MySQLBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", } @@ -37,7 +46,7 @@ def __init__( self._force_rollback = force_rollback - backend_str = self.SUPPORTED_BACKENDS[self.url.dialect] + backend_str = self.SUPPORTED_BACKENDS[self.url.scheme] backend_cls = import_from_string(backend_str) assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) @@ -330,6 +339,10 @@ def components(self) -> SplitResult: self._components = urlsplit(self._url) return self._components + @property + def scheme(self) -> str: + return self.components.scheme + @property def dialect(self) -> str: return self.components.scheme.split("+")[0] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..10d6b774 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +import os + +import pytest +import sqlalchemy + +from databases import DatabaseURL + +assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." + +DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] + + +# @pytest.fixture(autouse=True, scope="module") +# def metadata(): +# yield sqlalchemy.MetaData() + + + +# @pytest.fixture(autouse=True, scope="module") +# def create_test_database(): +# # Create test databases +# import pdb; pdb.set_trace() +# for url in DATABASE_URLS: +# database_url = DatabaseURL(url) +# if database_url.dialect == "mysql": +# url = str(database_url.replace(driver="pymysql")) +# engine = sqlalchemy.create_engine(url) +# metadata.create_all(engine) + +# # Run the test suite +# yield + +# # Drop test databases +# for url in DATABASE_URLS: +# database_url = DatabaseURL(url) +# if database_url.dialect == "mysql": +# url = str(database_url.replace(driver="pymysql")) +# engine = sqlalchemy.create_engine(url) +# metadata.drop_all(engine) diff --git a/tests/test_databases.py b/tests/test_databases.py index b7123abc..3692b25e 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -71,9 +71,10 @@ def process_result_value(self, value, dialect): ) +# TODO Move to `conftest.py` @pytest.fixture(autouse=True, scope="module") def create_test_database(): - # Create test databases + # Create test databases with tables creation for url in DATABASE_URLS: database_url = DatabaseURL(url) if database_url.dialect == "mysql": diff --git a/tests/test_integration.py b/tests/test_integration.py index 54fd14c2..6b19eb4a 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,6 +23,7 @@ ) +# TODO Move to `conftest.py` with tables creation @pytest.fixture(autouse=True, scope="module") def create_test_database(): # Create test databases From 66ec568b35aba65e0c0f299a757afad7ac60e095 Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Thu, 11 Jul 2019 02:26:45 +0930 Subject: [PATCH 2/8] Checks for failing unittests --- databases/backends/aiopg.py | 20 +------------------- tests/test_databases.py | 37 +++++++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 7121d8f6..7dbd7585 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -63,7 +63,7 @@ def _get_connection_kwargs(self) -> dict: # TODO move to `core.py` return kwargs - async def connect(self) -> None: # TODO as MySQL one? + async def connect(self) -> None: assert self._pool is None, "DatabaseBackend is already running" kwargs = self._get_connection_kwargs() self._pool = await aiopg.create_pool( @@ -72,7 +72,6 @@ async def connect(self) -> None: # TODO as MySQL one? user=self._database_url.username or getpass.getuser(), password=self._database_url.password, database=self._database_url.database, - # autocommit=True, **kwargs, ) @@ -112,8 +111,6 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]: assert self._connection is not None, "Connection is not acquired" query, args, context = self._compile(query) cursor = await self._connection.cursor() - # TODO - import pdb; pdb.set_trace() try: await cursor.execute(query, args) rows = await cursor.fetchall() @@ -129,8 +126,6 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin assert self._connection is not None, "Connection is not acquired" query, args, context = self._compile(query) cursor = await self._connection.cursor() - # TODO - import pdb; pdb.set_trace() try: await cursor.execute(query, args) row = await cursor.fetchone() @@ -145,8 +140,6 @@ async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" query, args, context = self._compile(query) cursor = await self._connection.cursor() - # TODO - import pdb; pdb.set_trace() try: await cursor.execute(query, args) return cursor.lastrowid @@ -156,8 +149,6 @@ async def execute(self, query: ClauseElement) -> typing.Any: async def execute_many(self, queries: typing.List[ClauseElement]) -> None: assert self._connection is not None, "Connection is not acquired" cursor = await self._connection.cursor() - # TODO - import pdb; pdb.set_trace() try: for single_query in queries: single_query, args, context = self._compile(single_query) @@ -171,8 +162,6 @@ async def iterate( assert self._connection is not None, "Connection is not acquired" query, args, context = self._compile(query) cursor = await self._connection.cursor() - # TODO - import pdb; pdb.set_trace() try: await cursor.execute(query, args) metadata = ResultMetaData(context, cursor.description) @@ -217,17 +206,14 @@ def __init__(self, connection: AiopgConnection): self._savepoint_name = "" async def start(self, is_root: bool) -> None: - import pdb; pdb.set_trace() assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root cursor = await self._connection._connection.cursor() if self._is_root: - # await self._connection._connection.begin() await cursor.execute("BEGIN") else: id = str(uuid.uuid4()).replace("-", "_") self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" - # cursor = await self._connection._connection.cursor() try: await cursor.execute(f"SAVEPOINT {self._savepoint_name}") finally: @@ -237,10 +223,8 @@ async def commit(self) -> None: assert self._connection._connection is not None, "Connection is not acquired" cursor = await self._connection._connection.cursor() if self._is_root: - # await self._connection._connection.commit() await cursor.execute("COMMIT") else: - # cursor = await self._connection._connection.cursor() try: await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}") finally: @@ -250,10 +234,8 @@ async def rollback(self) -> None: assert self._connection._connection is not None, "Connection is not acquired" cursor = await self._connection._connection.cursor() if self._is_root: - # await self._connection._connection.rollback() await cursor.execute("ROLLBACK") else: - # cursor = await self._connection._connection.cursor() try: await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}") finally: diff --git a/tests/test_databases.py b/tests/test_databases.py index 3692b25e..622e5f52 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,6 +2,7 @@ import datetime import decimal import functools +import json import os import pytest @@ -300,6 +301,9 @@ async def test_execute_return_val(database_url): query = notes.insert() values = {"text": "example1", "completed": True} pk = await database.execute(query, values) + # Apparently for `aiopg` it's OID that will always 0 in this case + # As it's only one action within this cursor life cycle + # Something to triple check assert isinstance(pk, int) query = notes.select().where(notes.c.id == pk) @@ -505,7 +509,15 @@ async def test_json_field(database_url): # execute() query = session.insert() values = {"data": {"text": "hello", "boolean": True, "int": 1}} - await database.execute(query, values) + if str(database_url).startswith("postgresql+psycopg2"): + await database.execute( + query.values( + # or wrapped with `psycopg2.extras.Json` + data=json.dumps({"text": "hello", "boolean": True, "int": 1}) + ) + ) + else: + await database.execute(query, values) # fetch_all() query = session.select() @@ -667,7 +679,7 @@ async def test_queries_with_expose_backend_connection(database_url): raw_connection = connection.raw_connection # Insert query - if str(database_url).startswith("mysql"): + if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)" @@ -675,11 +687,11 @@ async def test_queries_with_expose_backend_connection(database_url): # execute() values = ("example1", True) - if str(database_url).startswith("postgresql"): - await raw_connection.execute(insert_query, *values) - elif str(database_url).startswith("mysql"): + if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): cursor = await raw_connection.cursor() await cursor.execute(insert_query, values) + elif str(database_url).startswith("postgresql"): + await raw_connection.execute(insert_query, *values) elif str(database_url).startswith("sqlite"): await raw_connection.execute(insert_query, values) @@ -689,6 +701,11 @@ async def test_queries_with_expose_backend_connection(database_url): if str(database_url).startswith("mysql"): cursor = await raw_connection.cursor() await cursor.executemany(insert_query, values) + elif str(database_url).startswith("postgresql+psycopg2"): + cursor = await raw_connection.cursor() + # No async support for `executemany` + for value in values: + await cursor.execute(insert_query, value) else: await raw_connection.executemany(insert_query, values) @@ -696,12 +713,12 @@ async def test_queries_with_expose_backend_connection(database_url): select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" # fetch_all() - if str(database_url).startswith("postgresql"): - results = await raw_connection.fetch(select_query) - elif str(database_url).startswith("mysql"): + if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): cursor = await raw_connection.cursor() await cursor.execute(select_query) results = await cursor.fetchall() + elif str(database_url).startswith("postgresql"): + results = await raw_connection.fetch(select_query) elif str(database_url).startswith("sqlite"): results = await raw_connection.execute_fetchall(select_query) @@ -713,9 +730,9 @@ async def test_queries_with_expose_backend_connection(database_url): assert results[1][2] == False assert results[2][1] == "example3" assert results[2][2] == True - + # fetch_one() - if str(database_url).startswith("postgresql"): + if str(database_url).startswith("postgresql") and not str(database_url).startswith("postgresql+psycopg2"): result = await raw_connection.fetchrow(select_query) else: cursor = await raw_connection.cursor() From 83c14557cc1fe5a523e8deabf773ee69e1b7719c Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Sat, 13 Jul 2019 02:28:09 +0930 Subject: [PATCH 3/8] Sticking with `postgresql+aiopg` and some clean up --- databases/backends/aiopg.py | 4 +-- databases/core.py | 10 +----- tests/conftest.py | 39 --------------------- tests/test_connection_options.py | 27 +++++++++++++++ tests/test_databases.py | 58 +++++++++++++++++--------------- tests/test_integration.py | 5 ++- 6 files changed, 65 insertions(+), 78 deletions(-) delete mode 100644 tests/conftest.py diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 7dbd7585..2482f73c 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -1,9 +1,9 @@ +import getpass import logging import typing import uuid import aiopg - from sqlalchemy.dialects.postgresql import pypostgresql from sqlalchemy.engine.interfaces import Dialect, ExecutionContext from sqlalchemy.engine.result import ResultMetaData, RowProxy @@ -38,7 +38,7 @@ def _get_dialect(self) -> Dialect: return dialect - def _get_connection_kwargs(self) -> dict: # TODO move to `core.py` + def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options kwargs = {} diff --git a/databases/core.py b/databases/core.py index fc3c8838..8401de22 100644 --- a/databases/core.py +++ b/databases/core.py @@ -18,17 +18,9 @@ class Database: - # TODO Nested schema? - # { - # "postgresql": { - # "asyncpg": "...", # Default - # "aiopg": "..." - # } - # } SUPPORTED_BACKENDS = { - # TODO `postgresql+asyncpg`? "postgresql": "databases.backends.postgres:PostgresBackend", - "postgresql+psycopg2": "databases.backends.aiopg:AiopgBackend", + "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", "mysql": "databases.backends.mysql:MySQLBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", } diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 10d6b774..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,39 +0,0 @@ -import os - -import pytest -import sqlalchemy - -from databases import DatabaseURL - -assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." - -DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] - - -# @pytest.fixture(autouse=True, scope="module") -# def metadata(): -# yield sqlalchemy.MetaData() - - - -# @pytest.fixture(autouse=True, scope="module") -# def create_test_database(): -# # Create test databases -# import pdb; pdb.set_trace() -# for url in DATABASE_URLS: -# database_url = DatabaseURL(url) -# if database_url.dialect == "mysql": -# url = str(database_url.replace(driver="pymysql")) -# engine = sqlalchemy.create_engine(url) -# metadata.create_all(engine) - -# # Run the test suite -# yield - -# # Drop test databases -# for url in DATABASE_URLS: -# database_url = DatabaseURL(url) -# if database_url.dialect == "mysql": -# url = str(database_url.replace(driver="pymysql")) -# engine = sqlalchemy.create_engine(url) -# metadata.drop_all(engine) diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index 0cdc33bd..d850a1c4 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -2,6 +2,7 @@ Unit tests for the backend connection arguments. """ +from databases.backends.aiopg import AiopgBackend from databases.backends.mysql import MySQLBackend from databases.backends.postgres import PostgresBackend @@ -52,3 +53,29 @@ def test_mysql_explicit_ssl(): backend = MySQLBackend("mysql://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} + + +def test_aiopg_pool_size(): + backend = AiopgBackend("postgres+aiopg://localhost/database?min_size=1&max_size=20") + kwargs = backend._get_connection_kwargs() + assert kwargs == {"minsize": 1, "maxsize": 20} + + +def test_aiopg_explicit_pool_size(): + backend = AiopgBackend( + "postgres+aiopg://localhost/database", min_size=1, max_size=20 + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"minsize": 1, "maxsize": 20} + + +def test_aiopg_ssl(): + backend = AiopgBackend("postgres+aiopg://localhost/database?ssl=true") + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": True} + + +def test_aiopg_explicit_ssl(): + backend = AiopgBackend("postgres+aiopg://localhost/database", ssl=True) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": True} diff --git a/tests/test_databases.py b/tests/test_databases.py index 622e5f52..cea0366d 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -72,7 +72,6 @@ def process_result_value(self, value, dialect): ) -# TODO Move to `conftest.py` @pytest.fixture(autouse=True, scope="module") def create_test_database(): # Create test databases with tables creation @@ -80,6 +79,8 @@ def create_test_database(): database_url = DatabaseURL(url) if database_url.dialect == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.driver == "aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -91,6 +92,8 @@ def create_test_database(): database_url = DatabaseURL(url) if database_url.dialect == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.driver == "aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) @@ -301,15 +304,19 @@ async def test_execute_return_val(database_url): query = notes.insert() values = {"text": "example1", "completed": True} pk = await database.execute(query, values) - # Apparently for `aiopg` it's OID that will always 0 in this case - # As it's only one action within this cursor life cycle - # Something to triple check assert isinstance(pk, int) - query = notes.select().where(notes.c.id == pk) - result = await database.fetch_one(query) - assert result["text"] == "example1" - assert result["completed"] == True + # Apparently for `aiopg` it's OID that will always 0 in this case + # As it's only one action within this cursor life cycle + # It's recommended to use the `RETURNING` clause + # For obtaining the record id + if database.url.scheme == "postgresql+aiopg": + assert pk == 0 + else: + query = notes.select().where(notes.c.id == pk) + result = await database.fetch_one(query) + assert result["text"] == "example1" + assert result["completed"] == True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -508,14 +515,11 @@ async def test_json_field(database_url): async with database.transaction(force_rollback=True): # execute() query = session.insert() - values = {"data": {"text": "hello", "boolean": True, "int": 1}} - if str(database_url).startswith("postgresql+psycopg2"): - await database.execute( - query.values( - # or wrapped with `psycopg2.extras.Json` - data=json.dumps({"text": "hello", "boolean": True, "int": 1}) - ) - ) + data = {"text": "hello", "boolean": True, "int": 1} + values = {"data": data} + + if database.url.scheme == "postgresql+aiopg": + await database.execute(query, {"data": json.dumps(data)}) else: await database.execute(query, values) @@ -679,7 +683,7 @@ async def test_queries_with_expose_backend_connection(database_url): raw_connection = connection.raw_connection # Insert query - if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)" @@ -687,21 +691,21 @@ async def test_queries_with_expose_backend_connection(database_url): # execute() values = ("example1", True) - if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: cursor = await raw_connection.cursor() await cursor.execute(insert_query, values) - elif str(database_url).startswith("postgresql"): + elif database.url.scheme == "postgresql": await raw_connection.execute(insert_query, *values) - elif str(database_url).startswith("sqlite"): + elif database.url.scheme == "sqlite": await raw_connection.execute(insert_query, values) # execute_many() values = [("example2", False), ("example3", True)] - if str(database_url).startswith("mysql"): + if database.url.scheme == "mysql": cursor = await raw_connection.cursor() await cursor.executemany(insert_query, values) - elif str(database_url).startswith("postgresql+psycopg2"): + elif database.url.scheme == "postgresql+aiopg": cursor = await raw_connection.cursor() # No async support for `executemany` for value in values: @@ -713,13 +717,13 @@ async def test_queries_with_expose_backend_connection(database_url): select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" # fetch_all() - if str(database_url).startswith("mysql") or str(database_url).startswith("postgresql+psycopg2"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: cursor = await raw_connection.cursor() await cursor.execute(select_query) results = await cursor.fetchall() - elif str(database_url).startswith("postgresql"): + elif database.url.scheme == "postgresql": results = await raw_connection.fetch(select_query) - elif str(database_url).startswith("sqlite"): + elif database.url.scheme == "sqlite": results = await raw_connection.execute_fetchall(select_query) assert len(results) == 3 @@ -730,9 +734,9 @@ async def test_queries_with_expose_backend_connection(database_url): assert results[1][2] == False assert results[2][1] == "example3" assert results[2][2] == True - + # fetch_one() - if str(database_url).startswith("postgresql") and not str(database_url).startswith("postgresql+psycopg2"): + if database.url.scheme == "postgresql": result = await raw_connection.fetchrow(select_query) else: cursor = await raw_connection.cursor() diff --git a/tests/test_integration.py b/tests/test_integration.py index 6b19eb4a..4b4e3781 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,7 +23,6 @@ ) -# TODO Move to `conftest.py` with tables creation @pytest.fixture(autouse=True, scope="module") def create_test_database(): # Create test databases @@ -31,6 +30,8 @@ def create_test_database(): database_url = DatabaseURL(url) if database_url.dialect == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.driver == "aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -42,6 +43,8 @@ def create_test_database(): database_url = DatabaseURL(url) if database_url.dialect == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.driver == "aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) From bf4dfe6538fd999b247545c8b1514ee911da8e3f Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Sat, 13 Jul 2019 02:35:30 +0930 Subject: [PATCH 4/8] Add `aiopg` to the travis build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1506b4ec..914dd9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: - "3.7" env: - - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db" + - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db, postgresql+aiopg://localhost/test_database" services: - postgresql From cefc3433a19c748fe109176289e358a39da83fed Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Sat, 13 Jul 2019 02:40:34 +0930 Subject: [PATCH 5/8] Adjustments for `requirements.txt` and `setup.py` --- requirements.txt | 1 + setup.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cae52dc1..1b06551a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ aiocontextvars;python_version<"3.7" # Async database drivers aiomysql +aiopg aiosqlite asyncpg diff --git a/setup.py b/setup.py index 292e6379..d0cbf4a1 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,8 @@ def get_packages(package): extras_require={ "postgresql": ["asyncpg", "psycopg2-binary"], "mysql": ["aiomysql", "pymysql"], - "sqlite": ["aiosqlite"] + "sqlite": ["aiosqlite"], + "postgresql+aiopg": ["aiopg"] }, classifiers=[ "Development Status :: 3 - Alpha", From 7e37d41bf67e918f6b93aad8c42b90c52539a6c8 Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Mon, 30 Sep 2019 20:01:07 +0930 Subject: [PATCH 6/8] Fix the dsn for some tests --- tests/test_connection_options.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index d850a1c4..437ee360 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -56,26 +56,26 @@ def test_mysql_explicit_ssl(): def test_aiopg_pool_size(): - backend = AiopgBackend("postgres+aiopg://localhost/database?min_size=1&max_size=20") + backend = AiopgBackend("postgresql+aiopg://localhost/database?min_size=1&max_size=20") kwargs = backend._get_connection_kwargs() assert kwargs == {"minsize": 1, "maxsize": 20} def test_aiopg_explicit_pool_size(): backend = AiopgBackend( - "postgres+aiopg://localhost/database", min_size=1, max_size=20 + "postgresql+aiopg://localhost/database", min_size=1, max_size=20 ) kwargs = backend._get_connection_kwargs() assert kwargs == {"minsize": 1, "maxsize": 20} def test_aiopg_ssl(): - backend = AiopgBackend("postgres+aiopg://localhost/database?ssl=true") + backend = AiopgBackend("postgresql+aiopg://localhost/database?ssl=true") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} def test_aiopg_explicit_ssl(): - backend = AiopgBackend("postgres+aiopg://localhost/database", ssl=True) + backend = AiopgBackend("postgresql+aiopg://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} From 30522119c0b1e2b73a252312c7869c2705c8732b Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Wed, 2 Oct 2019 22:27:07 +0930 Subject: [PATCH 7/8] Minor fixes and linter --- tests/test_connection_options.py | 4 +++- tests/test_databases.py | 8 ++++---- tests/test_integration.py | 8 ++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index 437ee360..d6d254ed 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -56,7 +56,9 @@ def test_mysql_explicit_ssl(): def test_aiopg_pool_size(): - backend = AiopgBackend("postgresql+aiopg://localhost/database?min_size=1&max_size=20") + backend = AiopgBackend( + "postgresql+aiopg://localhost/database?min_size=1&max_size=20" + ) kwargs = backend._get_connection_kwargs() assert kwargs == {"minsize": 1, "maxsize": 20} diff --git a/tests/test_databases.py b/tests/test_databases.py index cea0366d..0cb318d5 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -77,9 +77,9 @@ def create_test_database(): # Create test databases with tables creation for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) - elif database_url.driver == "aiopg": + elif database_url.scheme == "postgresql+aiopg": url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -90,9 +90,9 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) - elif database_url.driver == "aiopg": + elif database_url.scheme == "postgresql+aiopg": url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) diff --git a/tests/test_integration.py b/tests/test_integration.py index 4b4e3781..c0cef2db 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -28,9 +28,9 @@ def create_test_database(): # Create test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) - elif database_url.driver == "aiopg": + elif database_url.scheme == "postgresql+aiopg": url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -41,9 +41,9 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) - elif database_url.driver == "aiopg": + elif database_url.scheme == "postgresql+aiopg": url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) From 9e58341c10052fbe3215d61db3b63ce8c07d5e60 Mon Sep 17 00:00:00 2001 From: George Bogodukhov Date: Thu, 3 Oct 2019 19:07:00 +0930 Subject: [PATCH 8/8] Change the dialect for `aiopg` for better support. --- databases/backends/aiopg.py | 10 +++++++--- tests/test_databases.py | 10 +++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 2482f73c..3c2f1bef 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -1,10 +1,12 @@ import getpass +import json import logging import typing import uuid import aiopg -from sqlalchemy.dialects.postgresql import pypostgresql +from aiopg.sa.engine import APGCompiler_psycopg2 +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.interfaces import Dialect, ExecutionContext from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement @@ -26,8 +28,10 @@ def __init__( self._pool = None def _get_dialect(self) -> Dialect: - dialect = pypostgresql.dialect(paramstyle="pyformat") - + dialect = PGDialect_psycopg2( + json_serializer=json.dumps, json_deserializer=lambda x: x + ) + dialect.statement_compiler = APGCompiler_psycopg2 dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ diff --git a/tests/test_databases.py b/tests/test_databases.py index 0cb318d5..a9728745 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,7 +2,6 @@ import datetime import decimal import functools -import json import os import pytest @@ -514,14 +513,10 @@ async def test_json_field(database_url): async with Database(database_url) as database: async with database.transaction(force_rollback=True): # execute() - query = session.insert() data = {"text": "hello", "boolean": True, "int": 1} values = {"data": data} - - if database.url.scheme == "postgresql+aiopg": - await database.execute(query, {"data": json.dumps(data)}) - else: - await database.execute(query, values) + query = session.insert() + await database.execute(query, values) # fetch_all() query = session.select() @@ -544,6 +539,7 @@ async def test_custom_field(database_url): # execute() query = custom_date.insert() values = {"title": "Hello, world", "published": today} + await database.execute(query, values) # fetch_all()