diff --git a/databases/core.py b/databases/core.py index efa59471..8415b836 100644 --- a/databases/core.py +++ b/databases/core.py @@ -129,20 +129,24 @@ async def __aenter__(self) -> "Database": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: await self.disconnect() async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.List[Record]: async with self.connection() as connection: return await connection.fetch_all(query, values) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: async with self.connection() as connection: return await connection.fetch_one(query, values) @@ -150,14 +154,16 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: async with self.connection() as connection: return await connection.fetch_val(query, values, column=column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: async with self.connection() as connection: return await connection.execute(query, values) @@ -169,7 +175,9 @@ async def execute_many( return await connection.execute_many(query, values) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.AsyncGenerator[typing.Mapping, None]: async with self.connection() as connection: async for record in connection.iterate(query, values): @@ -232,9 +240,9 @@ async def __aenter__(self) -> "Connection": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: async with self._connection_lock: assert self._connection is not None @@ -243,14 +251,18 @@ async def __aexit__( await self._connection.release() async def fetch_all( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.List[Record]: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_all(built_query) async def fetch_one( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Optional[Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -259,7 +271,7 @@ async def fetch_one( async def fetch_val( self, query: typing.Union[ClauseElement, str], - values: dict = None, + values: typing.Optional[dict] = None, column: typing.Any = 0, ) -> typing.Any: built_query = self._build_query(query, values) @@ -267,7 +279,9 @@ async def fetch_val( return await self._connection.fetch_val(built_query, column) async def execute( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.Any: built_query = self._build_query(query, values) async with self._query_lock: @@ -281,7 +295,9 @@ async def execute_many( await self._connection.execute_many(queries) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Optional[dict] = None, ) -> typing.AsyncGenerator[typing.Any, None]: built_query = self._build_query(query, values) async with self.transaction(): @@ -303,7 +319,7 @@ def raw_connection(self) -> typing.Any: @staticmethod def _build_query( - query: typing.Union[ClauseElement, str], values: dict = None + query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None ) -> ClauseElement: if isinstance(query, str): query = text(query) @@ -338,9 +354,9 @@ async def __aenter__(self) -> "Transaction": async def __aexit__( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, ) -> None: """ Called when exiting `async with database.transaction()` diff --git a/setup.cfg b/setup.cfg index 77c8c58d..da1831fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,7 @@ [mypy] disallow_untyped_defs = True ignore_missing_imports = True +no_implicit_optional = True [tool:isort] profile = black