Skip to content

Remove DispatchResult, use plain tuple #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions jsonrpcserver/async_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
"""Asynchronous dispatch"""

from functools import partial
from typing import Any, Callable, Iterable, Union
from itertools import starmap
from typing import Any, Callable, Iterable, Tuple, Union
import asyncio
import logging

from oslash.either import Left # type: ignore

from .dispatcher import (
Deserialized,
DispatchResult,
create_request,
deserialize,
extract_args,
extract_kwargs,
extract_list,
get_method,
not_notification,
to_response,
validate_args,
validate_request,
validate_result,
)
from .exceptions import JsonRpcError
from .methods import Method, Methods
from .request import Request, NOID
from .request import Request
from .result import Result, InternalErrorResult, ErrorResult
from .response import Response, ServerErrorResponse
from .utils import compose, make_list
from .utils import make_list


async def call(request: Request, context: Any, method: Method) -> Result:
Expand All @@ -45,36 +46,35 @@ async def call(request: Request, context: Any, method: Method) -> Result:

async def dispatch_request(
methods: Methods, context: Any, request: Request
) -> DispatchResult:
) -> Tuple[Request, Result]:
method = get_method(methods, request.method).bind(
partial(validate_args, request, context)
)
return DispatchResult(
request=request,
result=(
method
if isinstance(method, Left)
else await call(request, context, method._value)
),
return (
request,
method
if isinstance(method, Left)
else await call(request, context, method._value),
)


async def dispatch_deserialized(
methods: Methods,
context: Any,
post_process: Callable[[Deserialized], Iterable[Any]],
post_process: Callable[[Response], Iterable[Any]],
deserialized: Deserialized,
) -> Union[Response, Iterable[Response], None]:
coroutines = (
dispatch_request(methods, context, r)
for r in map(create_request, make_list(deserialized))
results = await asyncio.gather(
*(
dispatch_request(methods, context, r)
for r in map(create_request, make_list(deserialized))
)
)
results = await asyncio.gather(*coroutines)
return extract_list(
isinstance(deserialized, list),
map(
compose(post_process, to_response),
filter(lambda dr: dr.request.id is not NOID, results),
post_process,
starmap(to_response, filter(not_notification, results)),
),
)

Expand All @@ -85,7 +85,7 @@ async def dispatch_to_response_pure(
schema_validator: Callable[[Deserialized], Deserialized],
methods: Methods,
context: Any,
post_process: Callable[[Deserialized], Iterable[Any]],
post_process: Callable[[Response], Iterable[Any]],
request: str,
) -> Union[Response, Iterable[Response], None]:
try:
Expand Down
60 changes: 24 additions & 36 deletions jsonrpcserver/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Union
from itertools import starmap
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import logging

from oslash.either import Either, Left, Right # type: ignore
Expand Down Expand Up @@ -30,11 +31,6 @@
Deserialized = Union[Dict[str, Any], List[Dict[str, Any]]]


class DispatchResult(NamedTuple):
request: Request
result: Result


def extract_list(
is_batch: bool, responses: Iterable[Response]
) -> Union[Response, List[Response], None]:
Expand All @@ -58,22 +54,14 @@ def extract_list(
)


def to_response(dispatch_result: DispatchResult) -> Response:
"""Maps DispatchResults to Responses."""
def to_response(request: Request, result: Result) -> Response:
"""Maps Requests & Results to Responses."""
# Don't pass a notification to this function - should return a Server Error.
assert dispatch_result.request.id is not NOID
assert request.id is not NOID
return (
Left(
ErrorResponse(
**dispatch_result.result._error._asdict(), id=dispatch_result.request.id
)
)
if isinstance(dispatch_result.result, Left)
else Right(
SuccessResponse(
**dispatch_result.result._value._asdict(), id=dispatch_result.request.id
)
)
Left(ErrorResponse(**result._error._asdict(), id=request.id))
if isinstance(result, Left)
else Right(SuccessResponse(**result._value._asdict(), id=request.id))
)


Expand Down Expand Up @@ -124,10 +112,10 @@ def get_method(methods: Methods, method_name: str) -> Either[ErrorResult, Method

def dispatch_request(
methods: Methods, context: Any, request: Request
) -> DispatchResult:
return DispatchResult(
request=request,
result=get_method(methods, request.method)
) -> Tuple[Request, Result]:
return (
request,
get_method(methods, request.method)
.bind(partial(validate_args, request, context))
.bind(partial(call, request, context)),
)
Expand All @@ -139,25 +127,25 @@ def create_request(request: Dict[str, Any]) -> Request:
)


def not_notification(request_result: Any) -> bool:
return request_result[0].id is not NOID


def dispatch_deserialized(
methods: Methods,
context: Any,
post_process: Callable[[Deserialized], Iterable[Any]],
post_process: Callable[[Response], Iterable[Any]],
deserialized: Deserialized,
) -> Union[Response, Iterable[Response], None]:
results = map(
compose(partial(dispatch_request, methods, context), create_request),
make_list(deserialized),
)
return extract_list(
isinstance(deserialized, list),
map(
compose(post_process, to_response),
filter(
lambda dr: dr.request.id is not NOID,
map(
compose(
partial(dispatch_request, methods, context), create_request
),
make_list(deserialized),
),
),
post_process,
starmap(to_response, filter(not_notification, results)),
),
)

Expand Down Expand Up @@ -193,7 +181,7 @@ def dispatch_to_response_pure(
schema_validator: Callable[[Deserialized], Deserialized],
methods: Methods,
context: Any,
post_process: Callable[[Deserialized], Iterable[Any]],
post_process: Callable[[Response], Iterable[Any]],
request: str,
) -> Union[Response, Iterable[Response], None]:
try:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_async_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
dispatch_to_response_pure,
)
from jsonrpcserver.async_main import default_deserializer, default_schema_validator
from jsonrpcserver.dispatcher import DispatchResult
from jsonrpcserver.methods import Methods
from jsonrpcserver.request import Request
from jsonrpcserver.response import SuccessResponse
Expand All @@ -31,8 +30,9 @@ async def test_call():
@pytest.mark.asyncio
async def test_dispatch_request():
request = Request("ping", [], 1)
assert await dispatch_request(Methods(ping), None, request) == DispatchResult(
request, Right(SuccessResult("pong"))
assert await dispatch_request(Methods(ping), None, request) == (
request,
Right(SuccessResult("pong")),
)


Expand Down
40 changes: 15 additions & 25 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ERROR_SERVER_ERROR,
)
from jsonrpcserver.dispatcher import (
DispatchResult,
create_request,
dispatch_request,
dispatch_to_response_pure,
Expand Down Expand Up @@ -51,22 +50,18 @@ def ping() -> Result:

def test_to_response_SuccessResult():
response = to_response(
DispatchResult(
Request("ping", [], sentinel.id), Right(SuccessResult(sentinel.result))
)
Request("ping", [], sentinel.id), Right(SuccessResult(sentinel.result))
)
assert response == Right(SuccessResponse(sentinel.result, sentinel.id))


def test_to_response_ErrorResult():
response = to_response(
DispatchResult(
Request("ping", [], sentinel.id),
Left(
ErrorResult(
code=sentinel.code, message=sentinel.message, data=sentinel.data
)
),
Request("ping", [], sentinel.id),
Left(
ErrorResult(
code=sentinel.code, message=sentinel.message, data=sentinel.data
)
),
)
assert response == Left(
Expand All @@ -76,29 +71,23 @@ def test_to_response_ErrorResult():

def test_to_response_InvalidParams():
response = to_response(
DispatchResult(Request("ping", [], sentinel.id), InvalidParams(sentinel.data))
Request("ping", [], sentinel.id), InvalidParams(sentinel.data)
)
assert response == Left(
ErrorResponse(-32602, "Invalid params", sentinel.data, sentinel.id)
)


def test_to_response_InvalidParams_no_data():
response = to_response(
DispatchResult(Request("ping", [], sentinel.id), InvalidParams())
)
response = to_response(Request("ping", [], sentinel.id), InvalidParams())
assert response == Left(
ErrorResponse(-32602, "Invalid params", NODATA, sentinel.id)
)


def test_to_response_notification():
with pytest.raises(AssertionError):
to_response(
DispatchResult(
Request("ping", [], NOID), SuccessResult(result=sentinel.result)
)
)
to_response(Request("ping", [], NOID), SuccessResult(result=sentinel.result))


# validate_args
Expand Down Expand Up @@ -151,10 +140,11 @@ def foo(self, one, two):
# dispatch_request


def test_dispatch_request_success():
def test_dispatch_request():
request = Request("ping", [], 1)
assert dispatch_request(Methods(ping), None, request) == DispatchResult(
request, Right(SuccessResult("pong"))
assert dispatch_request(Methods(ping), None, request) == (
request,
Right(SuccessResult("pong")),
)


Expand Down Expand Up @@ -182,7 +172,7 @@ def test_create_request():
# dispatch_to_response_pure


def test_dispatch_to_response_pure_success():
def test_dispatch_to_response_pure():
assert (
dispatch_to_response_pure(
deserializer=default_deserializer,
Expand Down Expand Up @@ -368,7 +358,7 @@ def raise_exception():
# dispatch_to_response_pure -- Notifications


def test_dispatch_to_response_pure_notification_success():
def test_dispatch_to_response_pure_notification():
assert (
dispatch_to_response_pure(
deserializer=default_deserializer,
Expand Down