From a4f34473b813a5b305f1d616a03907c4a2c84108 Mon Sep 17 00:00:00 2001 From: ansipunk Date: Tue, 5 Mar 2024 22:48:08 +0500 Subject: [PATCH] Use correct type hints for query methods --- databases/backends/aiopg.py | 14 ++++---------- databases/backends/asyncmy.py | 13 ++++--------- databases/backends/common/records.py | 2 -- databases/backends/mysql.py | 13 ++++--------- databases/backends/postgres.py | 13 ++++--------- databases/backends/sqlite.py | 2 +- databases/core.py | 4 ++-- databases/interfaces.py | 2 +- 8 files changed, 20 insertions(+), 43 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 0b4d95a3..413c7d85 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -7,7 +7,6 @@ import aiopg from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement @@ -15,12 +14,7 @@ from databases.backends.compilers.psycopg import PGCompiler_psycopg from databases.backends.dialects.psycopg import PGDialect_psycopg from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record as RecordInterface, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -118,7 +112,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -142,7 +136,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: finally: cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -186,7 +180,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 040a4346..951e338d 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -12,12 +12,7 @@ from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record as RecordInterface, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -108,7 +103,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -134,7 +129,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -180,7 +175,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py index e963af50..f57de839 100644 --- a/databases/backends/common/records.py +++ b/databases/backends/common/records.py @@ -1,6 +1,4 @@ -import enum import typing -from datetime import date, datetime, time from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.row import Row as SQLRow diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 792f3685..3a9960ba 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -12,12 +12,7 @@ from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record as RecordInterface, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -108,7 +103,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -131,7 +126,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) @@ -177,7 +172,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index c42688e1..a0f13aab 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -9,12 +9,7 @@ from databases.backends.common.records import Record, create_column_maps from databases.backends.dialects.psycopg import dialect as psycopg_dialect from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record as RecordInterface, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -99,7 +94,7 @@ async def release(self) -> None: self._connection = await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) @@ -107,7 +102,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: column_maps = create_column_maps(result_columns) return [Record(row, result_columns, dialect, column_maps) for row in rows] - async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) row = await self._connection.fetchrow(query_str, *args) @@ -151,7 +146,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) column_maps = create_column_maps(result_columns) diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 16e17e9e..3f01752c 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -144,7 +144,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns, context = self._compile(query) column_maps = create_column_maps(result_columns) diff --git a/databases/core.py b/databases/core.py index d55dd3c8..6c06e4ba 100644 --- a/databases/core.py +++ b/databases/core.py @@ -208,7 +208,7 @@ async def iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, - ) -> typing.AsyncGenerator[typing.Mapping, None]: + ) -> typing.AsyncGenerator[Record, None]: async with self.connection() as connection: async for record in connection.iterate(query, values): yield record @@ -328,7 +328,7 @@ async def iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, - ) -> typing.AsyncGenerator[typing.Any, None]: + ) -> typing.AsyncGenerator[Record, None]: built_query = self._build_query(query, values) async with self.transaction(): async with self._query_lock: diff --git a/databases/interfaces.py b/databases/interfaces.py index fd6a24ee..d97500eb 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -42,7 +42,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async def iterate( self, query: ClauseElement - ) -> typing.AsyncGenerator[typing.Mapping, None]: + ) -> typing.AsyncGenerator["Record", None]: raise NotImplementedError() # pragma: no cover # mypy needs async iterators to contain a `yield` # https://github.com/python/mypy/issues/5385#issuecomment-407281656