diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py index 5c7ae25c..5f6ba263 100644 --- a/databases/backends/common/records.py +++ b/databases/backends/common/records.py @@ -1,5 +1,4 @@ import typing -from collections import namedtuple from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.row import Row as SQLRow @@ -53,7 +52,10 @@ def values(self) -> typing.ValuesView: def __getitem__(self, key: typing.Any) -> typing.Any: if len(self._column_map) == 0: - return self._row[key] + try: + return self._row[key] + except TypeError: + return self._mapping[key] elif isinstance(key, Column): idx, datatype = self._column_map_full[str(key)] elif isinstance(key, int): diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py index 302f94e5..bb623cfa 100644 --- a/databases/backends/psycopg.py +++ b/databases/backends/psycopg.py @@ -3,9 +3,10 @@ import psycopg import psycopg_pool from psycopg.rows import namedtuple_row -from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.schema import Column from databases.backends.common.records import Record, create_column_maps from databases.core import DatabaseURL @@ -31,6 +32,7 @@ def __init__( self._database_url = DatabaseURL(database_url) self._options = options self._dialect = PGDialect_psycopg() + self._dialect.implicit_returning = True self._pool = None async def connect(self) -> None: @@ -94,7 +96,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: rows = await cursor.fetchall() column_maps = create_column_maps(result_columns) - return [Record(row, result_columns, self._dialect, column_maps) for row in rows] + return [PsycopgRecord(row, result_columns, self._dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: if self._connection is None: @@ -109,7 +111,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa if row is None: return None - return Record( + return PsycopgRecord( row, result_columns, self._dialect, @@ -154,7 +156,7 @@ async def iterate( if row is None: break - yield Record(row, result_columns, self._dialect, column_maps) + yield PsycopgRecord(row, result_columns, self._dialect, column_maps) def transaction(self) -> "TransactionBackend": return PsycopgTransaction(connection=self) @@ -214,3 +216,21 @@ async def rollback(self) -> None: async with self._transaction._conn.lock: await self._transaction._conn.wait(self._transaction._rollback_gen(None)) + + +class PsycopgRecord(Record): + @property + def _mapping(self) -> typing.Mapping: + return self._row._asdict() + + def __getitem__(self, key: typing.Any) -> typing.Any: + if len(self._column_map) == 0: + return self._mapping[key] + elif isinstance(key, Column): + idx, datatype = self._column_map_full[str(key)] + elif isinstance(key, int): + idx, datatype = self._column_map_int[key] + else: + idx, datatype = self._column_map[key] + + return self._row[idx]