Skip to content

Commit

Permalink
➕ Code cleaning for mysql, postgres and async
Browse files Browse the repository at this point in the history
* Moved common functions to a common package
* Created common Record for the DB supported
  • Loading branch information
tarsil committed Mar 22, 2023
1 parent c73f55e commit b3f6d51
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 195 deletions.
62 changes: 43 additions & 19 deletions databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record,
Record as RecordInterface,
TransactionBackend,
)

Expand Down Expand Up @@ -105,15 +105,18 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect

async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
return [
rows = [
Row(
metadata,
metadata._processors,
Expand All @@ -123,32 +126,38 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
)
for row in rows
]
return [
Record(row, result_columns, dialect, column_maps) for row in rows
]
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
return Row(
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
await cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, results_map, context = self._compile(query)
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
Expand All @@ -163,7 +172,9 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
async with self._connection.cursor() as cursor:
try:
for single_query in queries:
single_query, args, context = self._compile(single_query)
single_query, args, results_map, context = self._compile(
single_query
)
await cursor.execute(single_query, args)
finally:
await cursor.close()
Expand All @@ -172,36 +183,38 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
await cursor.close()

def transaction(self) -> TransactionBackend:
return AsyncMyTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

args = compiled.construct_params()
for key, val in args.items():
if key in compiled._bind_processors:
Expand All @@ -214,12 +227,23 @@ def _compile(
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping
result_map = compiled._result_columns

else:
args = {}
result_map = None
compiled_query = compiled.string

query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA)
return compiled.string, args, CompilationContext(execution_context)
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled.string, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> asyncmy.connection.Connection:
Expand Down
Empty file.
139 changes: 139 additions & 0 deletions databases/backends/common/records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import json
import typing
from datetime import date, datetime

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
from sqlalchemy.sql.compiler import _CompileLabel
from sqlalchemy.sql.schema import Column
from sqlalchemy.types import TypeEngine

from databases.interfaces import Record as RecordInterface


class Record(RecordInterface):
    """
    Dialect-aware wrapper around a raw database row.

    Item access runs the column's SQLAlchemy result processor so driver
    values come back as the declared Python type. A column may be looked
    up by name, by integer position, or by a SQLAlchemy ``Column`` object;
    when no column metadata was supplied, access falls through to the raw
    row unchanged.
    """

    __slots__ = (
        "_row",
        "_result_columns",
        "_dialect",
        "_column_map",
        "_column_map_int",
        "_column_map_full",
    )

    def __init__(
        self,
        row: typing.Any,
        result_columns: tuple,
        dialect: Dialect,
        column_maps: typing.Tuple[
            typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
            typing.Mapping[int, typing.Tuple[int, TypeEngine]],
            typing.Mapping[str, typing.Tuple[int, TypeEngine]],
        ],
    ) -> None:
        self._row = row
        self._result_columns = result_columns
        self._dialect = dialect
        # The three maps address a column by identifier, index, and
        # stringified Column object respectively (see create_column_maps).
        self._column_map, self._column_map_int, self._column_map_full = column_maps

    @property
    def _mapping(self) -> typing.Mapping:
        # The wrapped row object itself provides the mapping interface.
        return self._row

    def keys(self) -> typing.KeysView:
        return self._mapping.keys()

    def values(self) -> typing.ValuesView:
        return self._mapping.values()

    def __getitem__(self, key: typing.Any) -> typing.Any:
        # No column metadata: defer entirely to the raw row's indexing.
        if not self._column_map:
            return self._row[key]

        if isinstance(key, Column):
            idx, datatype = self._column_map_full[str(key)]
        elif isinstance(key, int):
            idx, datatype = self._column_map_int[key]
        else:
            idx, datatype = self._column_map[key]

        raw = self._row[idx]
        processor = datatype._cached_result_processor(self._dialect, None)

        # Processors are deliberately skipped for datetime/date values.
        if processor is None or isinstance(raw, (datetime, date)):
            return raw
        return processor(raw)

    def __iter__(self) -> typing.Iterator:
        return iter(self._row.keys())

    def __len__(self) -> int:
        return len(self._row)

    def __getattr__(self, name: str) -> typing.Any:
        # Allow attribute-style access (record.column) to fall back to
        # item lookup, surfacing misses as AttributeError.
        try:
            return self.__getitem__(name)
        except KeyError as e:
            raise AttributeError(e.args[0]) from e


class Row(SQLRow):
    """
    SQLAlchemy ``Row`` subclass with mapping-style access.

    ``Row._fields`` supplies positional column names and ``Row._mapping``
    the values; integer keys are resolved to a field name first. A dict
    value fetched by integer position is serialized to a JSON string.
    """

    def __getitem__(self, key: typing.Any) -> typing.Any:
        # Non-integer keys go straight to the mapping.
        if not isinstance(key, int):
            return self._mapping[key]

        value = self._mapping[self._fields[key]]
        # Positional access returns dict payloads as JSON text.
        if isinstance(value, dict):
            return json.dumps(value)
        return value

    def keys(self):
        return self._mapping.keys()

    def values(self):
        return self._mapping.values()

    def __getattr__(self, name: str) -> typing.Any:
        # Attribute-style access (row.column) delegates to item lookup.
        try:
            return self.__getitem__(name)
        except KeyError as e:
            raise AttributeError(e.args[0]) from e


def create_column_maps(
    result_columns: typing.Any,
) -> typing.Tuple[
    typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
    typing.Mapping[int, typing.Tuple[int, TypeEngine]],
    typing.Mapping[str, typing.Tuple[int, TypeEngine]],
]:
    """
    Build column -> (index, datatype) lookup tables from compiled
    result-column definitions.

    The backends use these maps to construct ``Record`` instances, since
    the underlying drivers do not perform type conversion themselves.

    :return: Three mappings keyed by, respectively: \
        1. column identifier; \
        2. column index; \
        3. string form of the column's SQLAlchemy object.
    """
    by_name: dict = {}
    by_index: dict = {}
    by_full_name: dict = {}

    for position, (column_name, _, column, datatype) in enumerate(result_columns):
        entry = (position, datatype)
        by_name[column_name] = entry
        by_index[position] = entry

        # SQLA 2.0: _CompileLabel entries lack _annotations, and the
        # usable column object sits at index 2 instead of index 0.
        if isinstance(column[0], _CompileLabel):
            by_full_name[str(column[2])] = entry
        else:
            by_full_name[str(column[0])] = entry

    return by_name, by_index, by_full_name
Loading

0 comments on commit b3f6d51

Please sign in to comment.