-
Notifications
You must be signed in to change notification settings - Fork 39
/
async_dispatcher.py
113 lines (101 loc) · 3.24 KB
/
async_dispatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Async version of dispatcher.py"""
from functools import partial
from itertools import starmap
from typing import Any, Callable, Iterable, Tuple, Union
import asyncio
import logging
from oslash.either import Left # type: ignore
from .dispatcher import (
Deserialized,
create_request,
deserialize_request,
extract_args,
extract_kwargs,
extract_list,
get_method,
not_notification,
to_response,
validate_args,
validate_request,
validate_result,
)
from .exceptions import JsonRpcError
from .methods import Method, Methods
from .request import Request
from .result import Result, InternalErrorResult, ErrorResult
from .response import Response, ServerErrorResponse
from .utils import make_list
logger = logging.getLogger(__name__)
# pylint: disable=missing-function-docstring,duplicate-code
async def call(request: Request, context: Any, method: Method) -> Result:
    """Await ``method`` with the request's arguments and wrap any failure.

    Positional arguments come from ``extract_args(request, context)`` and
    keyword arguments from ``extract_kwargs(request)``. The awaited value is
    checked with ``validate_result`` and then returned unchanged.

    Failures are converted into results rather than raised:

    * ``JsonRpcError`` becomes ``Left(ErrorResult(...))`` carrying the
      exception's ``code``, ``message`` and ``data``.
    * Any other ``Exception`` is logged and becomes
      ``Left(InternalErrorResult(str(exc)))``.

    ``asyncio.CancelledError`` is always re-raised. Before Python 3.8 it is a
    subclass of ``Exception`` and would otherwise be swallowed by the broad
    handler, turning task cancellation into an internal-error result.
    """
    try:
        result = await method(
            *extract_args(request, context), **extract_kwargs(request)
        )
        validate_result(result)
    except JsonRpcError as exc:
        return Left(ErrorResult(code=exc.code, message=exc.message, data=exc.data))
    except asyncio.CancelledError:
        # Cancellation must reach the awaiting task; never report it as an error.
        raise
    except Exception as exc:  # pylint: disable=broad-except
        # Other error inside method - Internal error
        logger.exception(exc)
        return Left(InternalErrorResult(str(exc)))
    return result
async def dispatch_request(
    methods: Methods, context: Any, request: Request
) -> Tuple[Request, Result]:
    """Resolve and run the method for one request, pairing it with its result.

    The method is looked up with ``get_method`` and its arguments are checked
    with ``validate_args``. If either step yields a ``Left``, that ``Left`` is
    the result and nothing is invoked. Otherwise the validated method is
    awaited through :func:`call`.
    """
    lookup = get_method(methods, request.method).bind(
        partial(validate_args, request, context)
    )
    if isinstance(lookup, Left):
        return (request, lookup)
    # Unwrap the successful (non-Left) container to reach the callable.
    target = lookup._value  # pylint: disable=protected-access
    outcome = await call(request, context, target)
    return (request, outcome)
async def dispatch_deserialized(
    methods: Methods,
    context: Any,
    post_process: Callable[[Response], Iterable[Any]],
    deserialized: Deserialized,
) -> Union[Response, Iterable[Response], None]:
    """Concurrently dispatch every request in an already-deserialized payload.

    ``make_list`` lets a single request and a batch be handled the same way.
    All dispatches run together under ``asyncio.gather``. Pairs rejected by
    ``not_notification`` are dropped, and the rest are converted with
    ``to_response`` and passed through ``post_process``. ``extract_list`` is
    told whether the payload was a list so it can shape the return value.
    """
    pending = [
        dispatch_request(methods, context, create_request(item))
        for item in make_list(deserialized)
    ]
    pairs = await asyncio.gather(*pending)
    # Lazily build responses in the same per-element order as the results.
    responses = (
        post_process(to_response(*pair))
        for pair in pairs
        if not_notification(pair)
    )
    was_batch = isinstance(deserialized, list)
    return extract_list(was_batch, responses)
async def dispatch_to_response_pure(
    *,
    deserializer: Callable[[str], Deserialized],
    validator: Callable[[Deserialized], Deserialized],
    methods: Methods,
    context: Any,
    post_process: Callable[[Response], Iterable[Any]],
    request: str,
) -> Union[Response, Iterable[Response], None]:
    """Deserialize, validate and dispatch a raw request string.

    The string is parsed with ``deserializer`` and then checked with
    ``validator``. If either step yields a ``Left``, that ``Left`` is passed
    directly to ``post_process``. Otherwise the deserialized payload is
    dispatched with :func:`dispatch_deserialized`.

    Any unexpected ``Exception`` is logged and reported as
    ``post_process(Left(ServerErrorResponse(str(exc), None)))``.
    ``asyncio.CancelledError`` is re-raised instead. Before Python 3.8 it
    subclasses ``Exception`` and would otherwise be turned into a server-error
    response.
    """
    try:
        result = deserialize_request(deserializer, request).bind(
            partial(validate_request, validator)
        )
        if isinstance(result, Left):
            return post_process(result)
        return await dispatch_deserialized(
            methods,
            context,
            post_process,
            result._value,  # pylint: disable=protected-access
        )
    except asyncio.CancelledError:
        # Propagate task cancellation rather than answering with an error.
        raise
    except Exception as exc:  # pylint: disable=broad-except
        logger.exception(exc)
        return post_process(Left(ServerErrorResponse(str(exc), None)))