Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mypy to 0.930 #105

Merged
merged 5 commits into from
Dec 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/105.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update mypy to 0.930 and fix newly discovered type errors
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ lint =
flake8>=4.0.1
flake8-commas>=2.1
typecheck =
mypy>=0.920
mypy>=0.930
types-python-dateutil
types-toml
types-setuptools
Expand Down
212 changes: 143 additions & 69 deletions src/ai/backend/common/etcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using callbacks in separate threads.
'''

from __future__ import annotations

import asyncio
from collections import namedtuple, ChainMap
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -15,10 +17,22 @@
import logging
import time
from typing import (
Any, Awaitable, Callable, Iterable, Optional, Union,
AsyncGenerator,
Dict, Mapping,
Awaitable,
Callable,
Dict,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from typing_extensions import ( # FIXME: move to typing when we migrate to Python 3.10
Concatenate,
ParamSpec,
)
from urllib.parse import quote as _quote, unquote

Expand Down Expand Up @@ -101,9 +115,18 @@ async def reauthenticate(etcd_sync, creds, executor):
EtcdTokenCallCredentials(resp.token))


def reconn_reauth_adaptor(meth: Callable[..., Awaitable[Any]]):
P = ParamSpec("P")
R = TypeVar("R")


# FIXME: when mypy begins to support typing.Concatenate, remove "type: ignore" comments
# (ref: https://github.com/python/mypy/issues/8645)
def reconn_reauth_adaptor(
meth: Callable[Concatenate[AsyncEtcd, P], Awaitable[R]], # type: ignore
) -> Callable[Concatenate[AsyncEtcd, P], Awaitable[R]]: # type: ignore

@functools.wraps(meth)
async def wrapped(self, *args, **kwargs):
async def wrapped(self: AsyncEtcd, *args: P.args, **kwargs: P.kwargs) -> R:
num_reauth_tries = 0
num_reconn_tries = 0
while True:
Expand Down Expand Up @@ -131,14 +154,21 @@ async def wrapped(self, *args, **kwargs):
continue
else:
raise

return wrapped


class AsyncEtcd:

def __init__(self, addr: HostPortPair, namespace: str,
scope_prefix_map: Mapping[ConfigScopes, str], *,
credentials=None, encoding='utf8'):
def __init__(
self,
addr: HostPortPair,
namespace: str,
scope_prefix_map: Mapping[ConfigScopes, str],
*,
credentials=None,
encoding='utf8',
) -> None:
self.scope_prefix_map = t.Dict({
t.Key(ConfigScopes.GLOBAL): t.String(allow_blank=True),
t.Key(ConfigScopes.SGROUP, optional=True): t.String,
Expand Down Expand Up @@ -188,10 +218,26 @@ def _demangle_key(self, k: Union[bytes, str]) -> str:
k = k[len(prefix):]
return k

def _merge_scope_prefix_map(
self,
override: Mapping[ConfigScopes, str] = None,
) -> Mapping[ConfigScopes, str]:
"""
This stub ensures immutable usage of the ChainMap because ChainMap does *not*
have the immutable version in typeshed.
(ref: https://github.com/python/typeshed/issues/6042)
"""
return ChainMap(cast(MutableMapping, override) or {}, self.scope_prefix_map)

@reconn_reauth_adaptor
async def put(self, key: str, val: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put(
self,
key: str,
val: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a single key-value pair to the etcd.

Expand All @@ -201,17 +247,21 @@ async def put(self, key: str, val: str, *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.put(mangled_key, str(val).encode(self.encoding)))

@reconn_reauth_adaptor
async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put_prefix(
self,
key: str,
dict_obj: Mapping[str, str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a nested dict object under the given key prefix.
All keys in the dict object are automatically quoted to avoid conflicts with the path separator.
Expand All @@ -222,8 +272,7 @@ async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
flattened_dict: Dict[str, str] = {}

def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
Expand All @@ -250,9 +299,13 @@ def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
))

@reconn_reauth_adaptor
async def put_dict(self, dict_obj: Mapping[str, str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put_dict(
self,
dict_obj: Mapping[str, str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a flattened key-value pairs into the etcd.
Since the given dict must be a flattened one, its keys must be quoted as needed by the caller.
Expand All @@ -263,8 +316,7 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.transaction(
Expand All @@ -276,10 +328,13 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
))

@reconn_reauth_adaptor
async def get(self, key: str, *,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-> Optional[str]:
async def get(
self,
key: str,
*,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> Optional[str]:
"""
Get a single key from the etcd.
Returns ``None`` if the key does not exist.
Expand All @@ -298,22 +353,22 @@ async def get_impl(key: str) -> Optional[str]:
lambda: self.etcd_sync.get(mangled_key))
return val.decode(self.encoding) if val is not None else None

scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
_scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
p = scope_prefix_map.get(ConfigScopes.NODE)
p = _scope_prefix_map.get(ConfigScopes.NODE)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.SGROUP:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.GLOBAL:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
else:
raise ValueError('Invalid scope prefix value')
values = await asyncio.gather(*[
Expand All @@ -328,10 +383,13 @@ async def get_impl(key: str) -> Optional[str]:
return value

@reconn_reauth_adaptor
async def get_prefix(self, key_prefix: str,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-> Mapping[str, Optional[str]]:
async def get_prefix(
self,
key_prefix: str,
*,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> Mapping[str, Optional[str]]:
"""
Retrieves all key-value pairs under the given key prefix as a nested dictionary.
All dictionary keys are automatically unquoted.
Expand Down Expand Up @@ -376,22 +434,22 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
t[0].decode(self.encoding))
for t in results)

scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
_scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
p = scope_prefix_map.get(ConfigScopes.NODE)
p = _scope_prefix_map.get(ConfigScopes.NODE)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.SGROUP:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.GLOBAL:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
else:
raise ValueError('Invalid scope prefix value')
pair_sets = await asyncio.gather(*[
Expand All @@ -408,34 +466,45 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
get_prefix_dict = get_prefix

@reconn_reauth_adaptor
async def replace(self, key: str, initial_val: str, new_val: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None) -> bool:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def replace(
self,
key: str,
initial_val: str,
new_val: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> bool:
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
success = await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.replace(mangled_key, initial_val, new_val))
return success

@reconn_reauth_adaptor
async def delete(self, key: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete(
self,
key: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.delete(mangled_key))

@reconn_reauth_adaptor
async def delete_multi(self, keys: Iterable[str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete_multi(
self,
keys: Iterable[str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.transaction(
Expand All @@ -446,17 +515,24 @@ async def delete_multi(self, keys: Iterable[str], *,
))

@reconn_reauth_adaptor
async def delete_prefix(self, key_prefix: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete_prefix(
self,
key_prefix: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.delete_prefix(mangled_key_prefix))

def _watch_cb(self, queue: asyncio.Queue, resp: etcd3.watch.WatchResponse) -> None:
def _watch_cb(
self,
queue: asyncio.Queue,
resp: etcd3.watch.WatchResponse,
) -> None:
if isinstance(resp, grpc.RpcError):
if (
resp.code() == grpc.StatusCode.UNAVAILABLE or
Expand Down Expand Up @@ -543,8 +619,7 @@ async def watch(
cleanup_event: asyncio.Event = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
scope_prefix_len = len(f'{_slash(scope_prefix)}')
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
# NOTE: yield from in async-generator is not supported.
Expand Down Expand Up @@ -582,8 +657,7 @@ async def watch_prefix(
cleanup_event: asyncio.Event = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
scope_prefix_len = len(f'{_slash(scope_prefix)}')
mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
while True:
Expand Down