From a86edfaae53792d7eccceb0d860d5f0200faebb0 Mon Sep 17 00:00:00 2001 From: ansipunk Date: Sat, 2 Mar 2024 23:17:09 +0500 Subject: [PATCH] S01E01 --- databases/backends/asyncpg.py | 45 ++++------------------- databases/backends/dialects/psycopg.py | 30 ++++++++++++++- databases/backends/psycopg.py | 51 +++++++++++++++++--------- 3 files changed, 70 insertions(+), 56 deletions(-) diff --git a/databases/backends/asyncpg.py b/databases/backends/asyncpg.py index 98ac44ea..ff61fe26 100644 --- a/databases/backends/asyncpg.py +++ b/databases/backends/asyncpg.py @@ -4,11 +4,10 @@ import asyncpg from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement -from sqlalchemy.sql.ddl import DDLElement from databases.backends.common.records import Record, create_column_maps -from databases.backends.dialects.psycopg import get_dialect -from databases.core import LOG_EXTRA, DatabaseURL +from databases.backends.dialects.psycopg import compile_query, get_dialect +from databases.core import DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, @@ -88,7 +87,7 @@ async def release(self) -> None: async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, result_columns = compile_query(query, self._dialect) rows = await self._connection.fetch(query_str, *args) dialect = self._dialect column_maps = create_column_maps(result_columns) @@ -96,7 +95,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, result_columns = compile_query(query, self._dialect) row = await self._connection.fetchrow(query_str, *args) if row is None: return None @@ -124,7 +123,7 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, _ = self._compile(query) + query_str, args, _ = compile_query(query, self._dialect) return await self._connection.fetchval(query_str, *args) async def execute_many(self, queries: typing.List[ClauseElement]) -> None: @@ -133,14 +132,14 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: # loop through multiple executes here, which should all end up # using the same prepared statement. for single_query in queries: - single_query, args, _ = self._compile(single_query) + single_query, args, _ = compile_query(single_query, self._dialect) await self._connection.execute(single_query, *args) async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, result_columns = compile_query(query, self._dialect) column_maps = create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) @@ -148,36 +147,6 @@ async def iterate( def transaction(self) -> TransactionBackend: return AsyncpgTransaction(connection=self) - def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: - compiled = query.compile( - dialect=self._dialect, compile_kwargs={"render_postcompile": True} - ) - - if not isinstance(query, DDLElement): - compiled_params = sorted(compiled.params.items()) - - mapping = { - key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) - } - compiled_query = compiled.string % mapping - - processors = compiled._bind_processors - args = [ - processors[key](val) if key in processors else val - for key, val in compiled_params - ] - result_map = compiled._result_columns - else: - compiled_query = compiled.string - args = [] - result_map = None - - query_message = compiled_query.replace(" \n", " ").replace("\n", " ") - logger.debug( - "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA - ) - return compiled_query, args, result_map - @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py index 1caf49fe..cfce052e 100644 --- a/databases/backends/dialects/psycopg.py +++ b/databases/backends/dialects/psycopg.py @@ -10,6 +10,9 @@ from sqlalchemy import types, util from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext from sqlalchemy.engine import processors +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.ddl import DDLElement from sqlalchemy.types import Float, Numeric @@ -43,7 +46,7 @@ class PGDialect_psycopg(PGDialect): execution_ctx_cls = PGExecutionContext_psycopg -def get_dialect() -> PGDialect_psycopg: +def get_dialect() -> Dialect: dialect = PGDialect_psycopg(paramstyle="pyformat") dialect.implicit_returning = True dialect.supports_native_enum = True @@ -53,3 +56,28 @@ def get_dialect() -> PGDialect_psycopg: dialect._has_native_hstore = True dialect.supports_native_decimal = True return dialect + + +def compile_query(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list, tuple]: + compiled = query.compile(dialect=dialect, compile_kwargs={"render_postcompile": True}) + + if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + + processors = compiled._bind_processors + args = [ + processors[key](val) if key in processors else val + for key, val in compiled_params + ] + result_map = compiled._result_columns + else: + compiled_query = compiled.string + args = [] + result_map = None + + return compiled_query, args, result_map diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py index 981742ce..a047a85c 100644 --- a/databases/backends/psycopg.py +++ b/databases/backends/psycopg.py @@ -1,26 +1,36 @@ import typing -from collections.abc import Sequence +import psycopg import psycopg_pool +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement -from databases.backends.dialects.psycopg import get_dialect +from databases.backends.common.records import Record, create_column_maps +from databases.backends.dialects.psycopg import compile_query, get_dialect from databases.core import DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, + Record as RecordInterface, TransactionBackend, ) class PsycopgBackend(DatabaseBackend): + _database_url: DatabaseURL + _options: typing.Dict[str, typing.Any] + _dialect: Dialect + _pool: typing.Optional[psycopg_pool.AsyncConnectionPool] + def __init__( - self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any + self, + database_url: typing.Union[DatabaseURL, str], + **options: typing.Dict[str, typing.Any], ) -> None: self._database_url = DatabaseURL(database_url) self._options = options self._dialect = get_dialect() - self._pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None + self._pool = None async def connect(self) -> None: if self._pool is not None: @@ -28,22 +38,31 @@ async def connect(self) -> None: self._pool = psycopg_pool.AsyncConnectionPool( self._database_url.url, open=False, **self._options) + + # TODO: Add configurable timeouts await self._pool.open() async def disconnect(self) -> None: if self._pool is None: return + # TODO: Add configurable timeouts await self._pool.close() self._pool = None def connection(self) -> "PsycopgConnection": - return PsycopgConnection(self) + return PsycopgConnection(self, self._dialect) class PsycopgConnection(ConnectionBackend): - def __init__(self, database: PsycopgBackend) -> None: + _database: PsycopgBackend + _dialect: Dialect + _connection: typing.Optional[psycopg.AsyncConnection] + + def __init__(self, database: PsycopgBackend, dialect: Dialect) -> None: self._database = database + self._dialect = dialect + self._connection = None async def acquire(self) -> None: if self._connection is not None: @@ -62,10 +81,17 @@ async def release(self) -> None: await self._database._pool.putconn(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + if self._connection is None: + raise RuntimeError("Connection is not acquired") + + query_str, args, result_columns = compile_query(query, self._dialect) + rows = await self._connection.fetch(query_str, *args) + column_maps = create_column_maps(result_columns) + return [Record(row, result_columns, self._dialect, column_maps) for row in rows] raise NotImplementedError() # pragma: no cover - async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: raise NotImplementedError() # pragma: no cover async def fetch_val( @@ -107,12 +133,3 @@ async def commit(self) -> None: async def rollback(self) -> None: raise NotImplementedError() # pragma: no cover - - -class Record(Sequence): - @property - def _mapping(self) -> typing.Mapping: - raise NotImplementedError() # pragma: no cover - - def __getitem__(self, key: typing.Any) -> typing.Any: - raise NotImplementedError() # pragma: no cover