From 810e302a99dd73146dc2b3ee8286bc946523996a Mon Sep 17 00:00:00 2001 From: deptyped Date: Wed, 8 Jul 2020 02:58:31 +0300 Subject: [PATCH] Add ability to use custom serializer and deserializer --- jsonrpcserver/async_dispatcher.py | 36 ++++++--- jsonrpcserver/dispatcher.py | 50 +++++++++++-- jsonrpcserver/response.py | 22 ++++-- tests/test_dispatcher.py | 120 ++++++++++++++++++++++++++---- 4 files changed, 188 insertions(+), 40 deletions(-) diff --git a/jsonrpcserver/async_dispatcher.py b/jsonrpcserver/async_dispatcher.py index df3f53e..0195184 100644 --- a/jsonrpcserver/async_dispatcher.py +++ b/jsonrpcserver/async_dispatcher.py @@ -2,8 +2,8 @@ import asyncio import collections.abc from json import JSONDecodeError -from json import dumps as serialize, loads as deserialize -from typing import Any, Iterable, Optional, Union +from json import dumps as default_serialize, loads as default_deserialize +from typing import Any, Iterable, Optional, Union, Callable from apply_defaults import apply_config # type: ignore from jsonschema import ValidationError # type: ignore @@ -34,7 +34,9 @@ async def call(method: Method, *args: Any, **kwargs: Any) -> Any: return await validate_args(method, *args, **kwargs)(*args, **kwargs) -async def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response: +async def safe_call( + request: Request, methods: Methods, *, debug: bool, serialize: Callable +) -> Response: with handle_exceptions(request, debug) as handler: result = await call( lookup(methods, request.method), *request.args, **request.kwargs @@ -42,17 +44,24 @@ async def safe_call(request: Request, methods: Methods, *, debug: bool) -> Respo # Ensure value returned from the method is JSON-serializable. If not, # handle_exception will set handler.response to an ExceptionResponse serialize(result) - handler.response = SuccessResponse(result=result, id=request.id) + handler.response = SuccessResponse( + result=result, id=request.id, serialize_func=serialize + ) return handler.response async def call_requests( - requests: Union[Request, Iterable[Request]], methods: Methods, debug: bool + requests: Union[Request, Iterable[Request]], + methods: Methods, + debug: bool, + serialize: Callable, ) -> Response: if isinstance(requests, collections.abc.Iterable): - responses = (safe_call(r, methods, debug=debug) for r in requests) - return BatchResponse(await asyncio.gather(*responses)) - return await safe_call(requests, methods, debug=debug) + responses = ( + safe_call(r, methods, debug=debug, serialize=serialize) for r in requests + ) + return BatchResponse(await asyncio.gather(*responses), serialize_func=serialize) + return await safe_call(requests, methods, debug=debug, serialize=serialize) async def dispatch_pure( @@ -61,7 +70,9 @@ async def dispatch_pure( *, context: Any, convert_camel_case: bool, - debug: bool + debug: bool, + serialize: Callable, + deserialize: Callable, ) -> Response: try: deserialized = validate(deserialize(request), schema) @@ -75,6 +86,7 @@ async def dispatch_pure( ), methods, debug=debug, + serialize=serialize, ) @@ -88,7 +100,9 @@ async def dispatch( context: Any = NOCONTEXT, debug: bool = False, trim_log_values: bool = False, - **kwargs: Any + serialize: Callable = default_serialize, + deserialize: Callable = default_deserialize, + **kwargs: Any, ) -> Response: # Use the global methods object if no methods object was passed. methods = global_methods if methods is None else methods @@ -102,6 +116,8 @@ async def dispatch( debug=debug, context=context, convert_camel_case=convert_camel_case, + serialize=serialize, + deserialize=deserialize, ) log_response(str(response), trim_log_values=trim_log_values) # Remove the temporary stream handlers diff --git a/jsonrpcserver/dispatcher.py b/jsonrpcserver/dispatcher.py index 33d5b54..8bebb29 100644 --- a/jsonrpcserver/dispatcher.py +++ b/jsonrpcserver/dispatcher.py @@ -10,9 +10,20 @@ from configparser import ConfigParser from contextlib import contextmanager from json import JSONDecodeError -from json import dumps as serialize, loads as deserialize +from json import dumps as default_serialize, loads as default_deserialize from types import SimpleNamespace -from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + Callable, +) from apply_defaults import apply_config # type: ignore from jsonschema import ValidationError # type: ignore @@ -40,7 +51,7 @@ response_logger = logging.getLogger(__name__ + ".response") # Prepare the jsonschema validator -schema = deserialize(resource_string(__name__, "request-schema.json")) +schema = default_deserialize(resource_string(__name__, "request-schema.json")) klass = validator_for(schema) klass.check_schema(schema) validator = klass(schema) @@ -144,7 +155,9 @@ def handle_exceptions(request: Request, debug: bool) -> Generator: handler.response = NotificationResponse() -def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response: +def safe_call( + request: Request, methods: Methods, *, debug: bool, serialize: Callable +) -> Response: """ Call a Request, catching exceptions to ensure we always return a Response. @@ -152,6 +165,7 @@ def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response: request: The Request object. methods: The list of methods that can be called. debug: Include more information in error responses. + serialize: Function that is used to serialize data. Returns: A Response object. @@ -161,12 +175,17 @@ def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response: # Ensure value returned from the method is JSON-serializable. If not, # handle_exception will set handler.response to an ExceptionResponse serialize(result) - handler.response = SuccessResponse(result=result, id=request.id) + handler.response = SuccessResponse( + result=result, id=request.id, serialize_func=serialize + ) return handler.response def call_requests( - requests: Union[Request, Iterable[Request]], methods: Methods, debug: bool + requests: Union[Request, Iterable[Request]], + methods: Methods, + debug: bool, + serialize: Callable, ) -> Response: """ Takes a request or list of Requests and calls them. @@ -175,10 +194,14 @@ def call_requests( requests: Request object, or a collection of them. methods: The list of methods that can be called. debug: Include more information in error responses. + serialize: Function that is used to serialize data. """ if isinstance(requests, Iterable): - return BatchResponse(safe_call(r, methods, debug=debug) for r in requests) - return safe_call(requests, methods, debug=debug) + return BatchResponse( + [safe_call(r, methods, debug=debug, serialize=serialize) for r in requests], + serialize_func=serialize, + ) + return safe_call(requests, methods, debug=debug, serialize=serialize) def create_requests( @@ -211,6 +234,8 @@ def dispatch_pure( context: Any, convert_camel_case: bool, debug: bool, + serialize: Callable, + deserialize: Callable, ) -> Response: """ Pure version of dispatch - no logging, no optional parameters. @@ -225,6 +250,8 @@ def dispatch_pure( context: If specified, will be the first positional argument in all requests. convert_camel_case: Will convert the method name/any named params to snake case. debug: Include more information in error responses. + serialize: Function that is used to serialize data. + deserialize: Function that is used to deserialize data. Returns: A Response. """ @@ -240,6 +267,7 @@ def dispatch_pure( ), methods, debug=debug, + serialize=serialize, ) @@ -253,6 +281,8 @@ def dispatch( context: Any = NOCONTEXT, debug: bool = False, trim_log_values: bool = False, + serialize: Callable = default_serialize, + deserialize: Callable = default_deserialize, **kwargs: Any, ) -> Response: """ @@ -270,6 +300,8 @@ def dispatch( case. debug: Include more information in error responses. trim_log_values: Show abbreviated requests and responses in log. + serialize: Function that is used to serialize data. + deserialize: Function that is used to deserialize data. Returns: A Response. @@ -289,6 +321,8 @@ def dispatch( debug=debug, context=context, convert_camel_case=convert_camel_case, + serialize=serialize, + deserialize=deserialize, ) log_response(str(response), trim_log_values=trim_log_values) # Remove the temporary stream handlers diff --git a/jsonrpcserver/response.py b/jsonrpcserver/response.py index b5ebc5c..d503c5d 100644 --- a/jsonrpcserver/response.py +++ b/jsonrpcserver/response.py @@ -30,10 +30,10 @@ ExceptionResponse BatchResponse - a list of DictResponses """ -import json from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Dict, Iterable, cast +from typing import Any, Dict, Iterable, cast, Callable +from json import dumps as default_serialize from . import status @@ -43,8 +43,11 @@ class Response(ABC): """Base class of all responses.""" - def __init__(self, http_status: int) -> None: + def __init__( + self, http_status: int, serialize_func: Callable = default_serialize + ) -> None: self.http_status = http_status + self._serialize = serialize_func @property @abstractmethod @@ -130,7 +133,7 @@ def deserialized(self) -> dict: def __str__(self) -> str: """Use str() to get the JSON-RPC response string.""" - return json.dumps(sort_dict_response(self.deserialized())) + return self._serialize(sort_dict_response(self.deserialized())) class SuccessResponse(DictResponse): @@ -150,7 +153,7 @@ def __init__( The payload from processing the request. If the request was a JSON-RPC notification (i.e. the request id is `None`), the result must also be `None` because notifications don't require any data returned. - http_status: + http_status: """ super().__init__(http_status=http_status, **kwargs) self.result = result @@ -297,9 +300,12 @@ class BatchResponse(Response): """ def __init__( - self, responses: Iterable[Response], http_status: int = status.HTTP_OK + self, + responses: Iterable[Response], + http_status: int = status.HTTP_OK, + **kwargs: Any, ) -> None: - super().__init__(http_status=http_status) + super().__init__(http_status=http_status, **kwargs) # Remove notifications; these are not allowed in batch responses self.responses = cast( Iterable[DictResponse], {r for r in responses if r.wanted} @@ -317,4 +323,4 @@ def __str__(self) -> str: dicts = self.deserialized() # For an all-notifications response, an empty string should be returned, as per # spec - return json.dumps(dicts) if len(dicts) else "" + return self._serialize(dicts) if len(dicts) else "" diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 360e3c9..e0d1a3e 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -12,6 +12,8 @@ log_response, remove_handlers, safe_call, + default_deserialize, + default_serialize, ) from jsonrpcserver.methods import Methods, global_methods from jsonrpcserver.request import NOCONTEXT, Request @@ -52,14 +54,21 @@ def test_log_response(): def test_safe_call_success_response(): - response = safe_call(Request(method="ping", id=1), Methods(ping), debug=True) + response = safe_call( + Request(method="ping", id=1), + Methods(ping), + debug=True, + serialize=default_serialize, + ) assert isinstance(response, SuccessResponse) assert response.result == "pong" assert response.id == 1 def test_safe_call_notification(): - response = safe_call(Request(method="ping"), Methods(ping), debug=True) + response = safe_call( + Request(method="ping"), Methods(ping), debug=True, serialize=default_serialize + ) assert isinstance(response, NotificationResponse) @@ -67,18 +76,28 @@ def test_safe_call_notification_failure(): def fail(): raise ValueError() - response = safe_call(Request(method="foo"), Methods(fail), debug=True) + response = safe_call( + Request(method="foo"), Methods(fail), debug=True, serialize=default_serialize + ) assert isinstance(response, NotificationResponse) def test_safe_call_method_not_found(): - response = safe_call(Request(method="nonexistant", id=1), Methods(ping), debug=True) + response = safe_call( + Request(method="nonexistant", id=1), + Methods(ping), + debug=True, + serialize=default_serialize, + ) assert isinstance(response, MethodNotFoundResponse) def test_safe_call_invalid_args(): response = safe_call( - Request(method="ping", params=[1], id=1), Methods(ping), debug=True + Request(method="ping", params=[1], id=1), + Methods(ping), + debug=True, + serialize=default_serialize, ) assert isinstance(response, InvalidParamsResponse) @@ -87,7 +106,12 @@ def test_safe_call_api_error(): def error(): raise ApiError("Client Error", code=123, data={"data": 42}) - response = safe_call(Request(method="error", id=1), Methods(error), debug=True) + response = safe_call( + Request(method="error", id=1), + Methods(error), + debug=True, + serialize=default_serialize, + ) assert isinstance(response, ErrorResponse) error_dict = response.deserialized()["error"] assert error_dict["message"] == "Client Error" @@ -99,7 +123,12 @@ def test_safe_call_api_error_minimal(): def error(): raise ApiError("Client Error") - response = safe_call(Request(method="error", id=1), Methods(error), debug=True) + response = safe_call( + Request(method="error", id=1), + Methods(error), + debug=True, + serialize=default_serialize, + ) assert isinstance(response, ErrorResponse) response_dict = response.deserialized() error_dict = response_dict["error"] @@ -112,7 +141,12 @@ def test_non_json_encodable_resonse(): def method(): return b"Hello, World" - response = safe_call(Request(method="method", id=1), Methods(method), debug=False) + response = safe_call( + Request(method="method", id=1), + Methods(method), + debug=False, + serialize=default_serialize, + ) # response must be serializable here str(response) assert isinstance(response, ErrorResponse) @@ -134,6 +168,7 @@ def ping_with_context(context=None): Request("ping_with_context", convert_camel_case=False), Methods(ping_with_context), debug=True, + serialize=default_serialize, ) # Assert is in the method @@ -147,6 +182,7 @@ def test_call_requests_batch_all_notifications(): }, Methods(ping), debug=True, + serialize=default_serialize, ) assert str(response) == "" @@ -179,6 +215,8 @@ def test_dispatch_pure_request(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, SuccessResponse) assert response.result == "pong" @@ -192,6 +230,8 @@ def test_dispatch_pure_notification(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, NotificationResponse) @@ -203,6 +243,8 @@ def test_dispatch_pure_notification_invalid_jsonrpc(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, ErrorResponse) @@ -210,7 +252,13 @@ def test_dispatch_pure_notification_invalid_jsonrpc(): def test_dispatch_pure_invalid_json(): """Unable to parse, must return an error""" response = dispatch_pure( - "{", Methods(ping), convert_camel_case=False, context=NOCONTEXT, debug=True + "{", + Methods(ping), + convert_camel_case=False, + context=NOCONTEXT, + debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, InvalidJSONResponse) @@ -218,7 +266,13 @@ def test_dispatch_pure_invalid_json(): def test_dispatch_pure_invalid_jsonrpc(): """Invalid JSON-RPC, must return an error. (impossible to determine if notification)""" response = dispatch_pure( - "{}", Methods(ping), convert_camel_case=False, context=NOCONTEXT, debug=True + "{}", + Methods(ping), + convert_camel_case=False, + context=NOCONTEXT, + debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, InvalidJSONRPCResponse) @@ -233,6 +287,8 @@ def foo(colour): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, InvalidParamsResponse) @@ -247,9 +303,11 @@ def foo(colour: str, size: str): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, InvalidParamsResponse) - assert response.data == 'missing a required argument: \'size\'' + assert response.data == "missing a required argument: 'size'" # def test_dispatch_pure_invalid_params_notification(): @@ -295,6 +353,8 @@ def subtract(minuend, subtrahend): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, SuccessResponse) assert response.result == 19 @@ -306,6 +366,8 @@ def subtract(minuend, subtrahend): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, SuccessResponse) assert response.result == -19 @@ -321,6 +383,8 @@ def subtract(**kwargs): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, SuccessResponse) assert response.result == 19 @@ -332,6 +396,8 @@ def subtract(**kwargs): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, SuccessResponse) assert response.result == 19 @@ -345,6 +411,8 @@ def test_examples_notification(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, NotificationResponse) @@ -355,6 +423,8 @@ def test_examples_notification(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, NotificationResponse) @@ -366,6 +436,8 @@ def test_examples_invalid_json(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, ErrorResponse) assert ( @@ -377,7 +449,13 @@ def test_examples_invalid_json(): def test_examples_empty_array(): # This is an invalid JSON-RPC request, should return an error. response = dispatch_pure( - "[]", Methods(ping), convert_camel_case=False, context=NOCONTEXT, debug=True + "[]", + Methods(ping), + convert_camel_case=False, + context=NOCONTEXT, + debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, ErrorResponse) assert ( @@ -392,7 +470,13 @@ def test_examples_invalid_jsonrpc_batch(): The examples are expecting a batch response full of error responses. """ response = dispatch_pure( - "[1]", Methods(ping), convert_camel_case=False, context=NOCONTEXT, debug=True + "[1]", + Methods(ping), + convert_camel_case=False, + context=NOCONTEXT, + debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, InvalidJSONRPCResponse) assert ( @@ -412,6 +496,8 @@ def test_examples_multiple_invalid_jsonrpc(): convert_camel_case=False, context=NOCONTEXT, debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) assert isinstance(response, ErrorResponse) assert ( @@ -452,7 +538,13 @@ def test_examples_mixed_requests_and_notifications(): ] ) response = dispatch_pure( - requests, methods, convert_camel_case=False, context=NOCONTEXT, debug=True + requests, + methods, + convert_camel_case=False, + context=NOCONTEXT, + debug=True, + serialize=default_serialize, + deserialize=default_deserialize, ) expected = [ {"jsonrpc": "2.0", "result": 7, "id": "1"},