Skip to content

Commit

Permalink
Wrap types in typing.Optional where applicable (#510)
Browse files Browse the repository at this point in the history
Co-authored-by: tsunyoku <mbruhyo@gmail.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
3 people committed Dec 17, 2022
1 parent 8ec9168 commit 7aa1326
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
56 changes: 36 additions & 20 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,35 +129,41 @@ async def __aenter__(self) -> "Database":

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
await self.disconnect()

async def fetch_all(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.List[Record]:
async with self.connection() as connection:
return await connection.fetch_all(query, values)

async def fetch_one(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.Optional[Record]:
async with self.connection() as connection:
return await connection.fetch_one(query, values)

async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
values: typing.Optional[dict] = None,
column: typing.Any = 0,
) -> typing.Any:
async with self.connection() as connection:
return await connection.fetch_val(query, values, column=column)

async def execute(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.Any:
async with self.connection() as connection:
return await connection.execute(query, values)
Expand All @@ -169,7 +175,9 @@ async def execute_many(
return await connection.execute_many(query, values)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Mapping, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
Expand Down Expand Up @@ -232,9 +240,9 @@ async def __aenter__(self) -> "Connection":

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
async with self._connection_lock:
assert self._connection is not None
Expand All @@ -243,14 +251,18 @@ async def __aexit__(
await self._connection.release()

async def fetch_all(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.List[Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
return await self._connection.fetch_all(built_query)

async def fetch_one(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.Optional[Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
Expand All @@ -259,15 +271,17 @@ async def fetch_one(
async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
values: typing.Optional[dict] = None,
column: typing.Any = 0,
) -> typing.Any:
built_query = self._build_query(query, values)
async with self._query_lock:
return await self._connection.fetch_val(built_query, column)

async def execute(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.Any:
built_query = self._build_query(query, values)
async with self._query_lock:
Expand All @@ -281,7 +295,9 @@ async def execute_many(
await self._connection.execute_many(queries)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Any, None]:
built_query = self._build_query(query, values)
async with self.transaction():
Expand All @@ -303,7 +319,7 @@ def raw_connection(self) -> typing.Any:

@staticmethod
def _build_query(
query: typing.Union[ClauseElement, str], values: dict = None
query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None
) -> ClauseElement:
if isinstance(query, str):
query = text(query)
Expand Down Expand Up @@ -338,9 +354,9 @@ async def __aenter__(self) -> "Transaction":

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
"""
Called when exiting `async with database.transaction()`
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[mypy]
disallow_untyped_defs = True
ignore_missing_imports = True
no_implicit_optional = True

[tool:isort]
profile = black
Expand Down

0 comments on commit 7aa1326

Please sign in to comment.