diff --git a/.travis.yml b/.travis.yml index 1506b4ec..914dd9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: - "3.7" env: - - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db" + - TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db, postgresql+aiopg://localhost/test_database" services: - postgresql diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py new file mode 100644 index 00000000..3c2f1bef --- /dev/null +++ b/databases/backends/aiopg.py @@ -0,0 +1,246 @@ +import getpass +import json +import logging +import typing +import uuid + +import aiopg +from aiopg.sa.engine import APGCompiler_psycopg2 +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.engine.interfaces import Dialect, ExecutionContext +from sqlalchemy.engine.result import ResultMetaData, RowProxy +from sqlalchemy.sql import ClauseElement +from sqlalchemy.types import TypeEngine + +from databases.core import DatabaseURL +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend + +logger = logging.getLogger("databases") + + +class AiopgBackend(DatabaseBackend): + def __init__( + self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any + ) -> None: + self._database_url = DatabaseURL(database_url) + self._options = options + self._dialect = self._get_dialect() + self._pool = None + + def _get_dialect(self) -> Dialect: + dialect = PGDialect_psycopg2( + json_serializer=json.dumps, json_deserializer=lambda x: x + ) + dialect.statement_compiler = APGCompiler_psycopg2 + dialect.implicit_returning = True + dialect.supports_native_enum = True + dialect.supports_smallserial = True # 9.2+ + dialect._backslash_escapes = False + dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ + dialect._has_native_hstore = True + dialect.supports_native_decimal = True + + return dialect + + def _get_connection_kwargs(self) -> dict: + url_options = self._database_url.options + + kwargs = {} + min_size = url_options.get("min_size") + max_size = url_options.get("max_size") + ssl = url_options.get("ssl") + + if min_size is not None: + kwargs["minsize"] = int(min_size) + if max_size is not None: + kwargs["maxsize"] = int(max_size) + if ssl is not None: + kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + + for key, value in self._options.items(): + # Coerce 'min_size' and 'max_size' for consistency. + if key == "min_size": + key = "minsize" + elif key == "max_size": + key = "maxsize" + kwargs[key] = value + + return kwargs + + async def connect(self) -> None: + assert self._pool is None, "DatabaseBackend is already running" + kwargs = self._get_connection_kwargs() + self._pool = await aiopg.create_pool( + host=self._database_url.hostname, + port=self._database_url.port, + user=self._database_url.username or getpass.getuser(), + password=self._database_url.password, + database=self._database_url.database, + **kwargs, + ) + + async def disconnect(self) -> None: + assert self._pool is not None, "DatabaseBackend is not running" + self._pool.close() + await self._pool.wait_closed() + self._pool = None + + def connection(self) -> "AiopgConnection": + return AiopgConnection(self, self._dialect) + + +class CompilationContext: + def __init__(self, context: ExecutionContext): + self.context = context + + +class AiopgConnection(ConnectionBackend): + def __init__(self, database: AiopgBackend, dialect: Dialect): + self._database = database + self._dialect = dialect + self._connection = None # type: typing.Optional[aiopg.Connection] + + async def acquire(self) -> None: + assert self._connection is None, "Connection is already acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + self._connection = await self._database._pool.acquire() + + async def release(self) -> None: + assert self._connection is not None, "Connection is not acquired" + assert self._database._pool is not None, "DatabaseBackend is not running" + await self._database._pool.release(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + try: + await cursor.execute(query, args) + rows = await cursor.fetchall() + metadata = ResultMetaData(context, cursor.description) + return [ + RowProxy(metadata, row, metadata._processors, metadata._keymap) + for row in rows + ] + finally: + cursor.close() + + async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + try: + await cursor.execute(query, args) + row = await cursor.fetchone() + if row is None: + return None + metadata = ResultMetaData(context, cursor.description) + return RowProxy(metadata, row, metadata._processors, metadata._keymap) + finally: + cursor.close() + + async def execute(self, query: ClauseElement) -> typing.Any: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + try: + await cursor.execute(query, args) + return cursor.lastrowid + finally: + cursor.close() + + async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + assert self._connection is not None, "Connection is not acquired" + cursor = await self._connection.cursor() + try: + for single_query in queries: + single_query, args, context = self._compile(single_query) + await cursor.execute(single_query, args) + finally: + cursor.close() + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Any, None]: + assert self._connection is not None, "Connection is not acquired" + query, args, context = self._compile(query) + cursor = await self._connection.cursor() + try: + await cursor.execute(query, args) + metadata = ResultMetaData(context, cursor.description) + async for row in cursor: + yield RowProxy(metadata, row, metadata._processors, metadata._keymap) + finally: + cursor.close() + + def transaction(self) -> TransactionBackend: + return AiopgTransaction(self) + + def _compile( + self, query: ClauseElement + ) -> typing.Tuple[str, dict, CompilationContext]: + compiled = query.compile(dialect=self._dialect) + args = compiled.construct_params() + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + + execution_context = self._dialect.execution_ctx_cls() + execution_context.dialect = self._dialect + execution_context.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) + + logger.debug("Query: %s\nArgs: %s", compiled.string, args) + return compiled.string, args, CompilationContext(execution_context) + + @property + def raw_connection(self) -> aiopg.connection.Connection: + assert self._connection is not None, "Connection is not acquired" + return self._connection + + +class AiopgTransaction(TransactionBackend): + def __init__(self, connection: AiopgConnection): + self._connection = connection + self._is_root = False + self._savepoint_name = "" + + async def start(self, is_root: bool) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + self._is_root = is_root + cursor = await self._connection._connection.cursor() + if self._is_root: + await cursor.execute("BEGIN") + else: + id = str(uuid.uuid4()).replace("-", "_") + self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}" + try: + await cursor.execute(f"SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() + + async def commit(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + if self._is_root: + await cursor.execute("COMMIT") + else: + try: + await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() + + async def rollback(self) -> None: + assert self._connection._connection is not None, "Connection is not acquired" + cursor = await self._connection._connection.cursor() + if self._is_root: + await cursor.execute("ROLLBACK") + else: + try: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}") + finally: + cursor.close() diff --git a/databases/core.py b/databases/core.py index c2f61da4..513021a4 100644 --- a/databases/core.py +++ b/databases/core.py @@ -42,6 +42,7 @@ class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", + "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", "postgres": "databases.backends.postgres:PostgresBackend", "mysql": "databases.backends.mysql:MySQLBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", @@ -60,7 +61,7 @@ def __init__( self._force_rollback = force_rollback - backend_str = self.SUPPORTED_BACKENDS[self.url.dialect] + backend_str = self.SUPPORTED_BACKENDS[self.url.scheme] backend_cls = import_from_string(backend_str) assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) @@ -367,6 +368,10 @@ def components(self) -> SplitResult: self._components = urlsplit(self._url) return self._components + @property + def scheme(self) -> str: + return self.components.scheme + @property def dialect(self) -> str: return self.components.scheme.split("+")[0] diff --git a/requirements.txt b/requirements.txt index cae52dc1..1b06551a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ aiocontextvars;python_version<"3.7" # Async database drivers aiomysql +aiopg aiosqlite asyncpg diff --git a/setup.py b/setup.py index 292e6379..d0cbf4a1 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,8 @@ def get_packages(package): extras_require={ "postgresql": ["asyncpg", "psycopg2-binary"], "mysql": ["aiomysql", "pymysql"], - "sqlite": ["aiosqlite"] + "sqlite": ["aiosqlite"], + "postgresql+aiopg": ["aiopg"] }, classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index 0cdc33bd..d6d254ed 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -2,6 +2,7 @@ Unit tests for the backend connection arguments. """ +from databases.backends.aiopg import AiopgBackend from databases.backends.mysql import MySQLBackend from databases.backends.postgres import PostgresBackend @@ -52,3 +53,31 @@ def test_mysql_explicit_ssl(): backend = MySQLBackend("mysql://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} + + +def test_aiopg_pool_size(): + backend = AiopgBackend( + "postgresql+aiopg://localhost/database?min_size=1&max_size=20" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"minsize": 1, "maxsize": 20} + + +def test_aiopg_explicit_pool_size(): + backend = AiopgBackend( + "postgresql+aiopg://localhost/database", min_size=1, max_size=20 + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"minsize": 1, "maxsize": 20} + + +def test_aiopg_ssl(): + backend = AiopgBackend("postgresql+aiopg://localhost/database?ssl=true") + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": True} + + +def test_aiopg_explicit_ssl(): + backend = AiopgBackend("postgresql+aiopg://localhost/database", ssl=True) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": True} diff --git a/tests/test_databases.py b/tests/test_databases.py index 5b2eca57..510a1a6f 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -73,11 +73,13 @@ def process_result_value(self, value, dialect): @pytest.fixture(autouse=True, scope="module") def create_test_database(): - # Create test databases + # Create test databases with tables creation for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.scheme == "postgresql+aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -87,8 +89,10 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.scheme == "postgresql+aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) @@ -301,10 +305,17 @@ async def test_execute_return_val(database_url): pk = await database.execute(query, values) assert isinstance(pk, int) - query = notes.select().where(notes.c.id == pk) - result = await database.fetch_one(query) - assert result["text"] == "example1" - assert result["completed"] == True + # Apparently for `aiopg` it's OID that will always 0 in this case + # As it's only one action within this cursor life cycle + # It's recommended to use the `RETURNING` clause + # For obtaining the record id + if database.url.scheme == "postgresql+aiopg": + assert pk == 0 + else: + query = notes.select().where(notes.c.id == pk) + result = await database.fetch_one(query) + assert result["text"] == "example1" + assert result["completed"] == True @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -502,8 +513,9 @@ async def test_json_field(database_url): async with Database(database_url) as database: async with database.transaction(force_rollback=True): # execute() + data = {"text": "hello", "boolean": True, "int": 1} + values = {"data": data} query = session.insert() - values = {"data": {"text": "hello", "boolean": True, "int": 1}} await database.execute(query, values) # fetch_all() @@ -527,6 +539,7 @@ async def test_custom_field(database_url): # execute() query = custom_date.insert() values = {"title": "Hello, world", "published": today} + await database.execute(query, values) # fetch_all() @@ -666,7 +679,7 @@ async def test_queries_with_expose_backend_connection(database_url): raw_connection = connection.raw_connection # Insert query - if str(database_url).startswith("mysql"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)" @@ -674,20 +687,25 @@ async def test_queries_with_expose_backend_connection(database_url): # execute() values = ("example1", True) - if str(database_url).startswith("postgresql"): - await raw_connection.execute(insert_query, *values) - elif str(database_url).startswith("mysql"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: cursor = await raw_connection.cursor() await cursor.execute(insert_query, values) - elif str(database_url).startswith("sqlite"): + elif database.url.scheme == "postgresql": + await raw_connection.execute(insert_query, *values) + elif database.url.scheme == "sqlite": await raw_connection.execute(insert_query, values) # execute_many() values = [("example2", False), ("example3", True)] - if str(database_url).startswith("mysql"): + if database.url.scheme == "mysql": cursor = await raw_connection.cursor() await cursor.executemany(insert_query, values) + elif database.url.scheme == "postgresql+aiopg": + cursor = await raw_connection.cursor() + # No async support for `executemany` + for value in values: + await cursor.execute(insert_query, value) else: await raw_connection.executemany(insert_query, values) @@ -695,13 +713,13 @@ async def test_queries_with_expose_backend_connection(database_url): select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" # fetch_all() - if str(database_url).startswith("postgresql"): - results = await raw_connection.fetch(select_query) - elif str(database_url).startswith("mysql"): + if database.url.scheme in ["mysql", "postgresql+aiopg"]: cursor = await raw_connection.cursor() await cursor.execute(select_query) results = await cursor.fetchall() - elif str(database_url).startswith("sqlite"): + elif database.url.scheme == "postgresql": + results = await raw_connection.fetch(select_query) + elif database.url.scheme == "sqlite": results = await raw_connection.execute_fetchall(select_query) assert len(results) == 3 @@ -714,7 +732,7 @@ async def test_queries_with_expose_backend_connection(database_url): assert results[2][2] == True # fetch_one() - if str(database_url).startswith("postgresql"): + if database.url.scheme == "postgresql": result = await raw_connection.fetchrow(select_query) else: cursor = await raw_connection.cursor() diff --git a/tests/test_integration.py b/tests/test_integration.py index 54fd14c2..c0cef2db 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -28,8 +28,10 @@ def create_test_database(): # Create test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.scheme == "postgresql+aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.create_all(engine) @@ -39,8 +41,10 @@ def create_test_database(): # Drop test databases for url in DATABASE_URLS: database_url = DatabaseURL(url) - if database_url.dialect == "mysql": + if database_url.scheme == "mysql": url = str(database_url.replace(driver="pymysql")) + elif database_url.scheme == "postgresql+aiopg": + url = str(database_url.replace(driver=None)) engine = sqlalchemy.create_engine(url) metadata.drop_all(engine)