diff --git a/jsonrpc_async/jsonrpc.py b/jsonrpc_async/jsonrpc.py index a9cfccf..c8ea942 100644 --- a/jsonrpc_async/jsonrpc.py +++ b/jsonrpc_async/jsonrpc.py @@ -9,7 +9,7 @@ class Server(jsonrpc_base.Server): """A connection to a HTTP JSON-RPC server, backed by aiohttp""" - def __init__(self, url, session=None, **post_kwargs): + def __init__(self, url, session=None, *, loads=None, **post_kwargs): super().__init__() object.__setattr__(self, 'session', session or aiohttp.ClientSession()) post_kwargs['headers'] = post_kwargs.get('headers', {}) @@ -19,6 +19,10 @@ def __init__(self, url, session=None, **post_kwargs): 'Accept', 'application/json-rpc') self._request = functools.partial(self.session.post, url, **post_kwargs) + self._json_args = {} + if loads is not None: + self._json_args['loads'] = loads + @asyncio.coroutine def send_message(self, message): """Send the HTTP message to the server and return the message response. @@ -38,7 +42,7 @@ def send_message(self, message): return None try: - response_data = yield from response.json() + response_data = yield from response.json(**self._json_args) except ValueError as value_error: raise TransportError('Cannot deserialize response body', message, value_error) diff --git a/tests.py b/tests.py index ec395a0..93e01bd 100644 --- a/tests.py +++ b/tests.py @@ -1,8 +1,8 @@ import asyncio +from unittest import mock import unittest import random import json -import inspect import os import aiohttp @@ -15,11 +15,6 @@ import jsonrpc_base from jsonrpc_async import Server, ProtocolError, TransportError -try: - # python 3.3 - from unittest.mock import Mock -except ImportError: - from mock import Mock class JsonTestClient(aiohttp.test_utils.TestClient): def __init__(self, app, **kwargs): @@ -31,6 +26,7 @@ def request(self, method, path, *args, **kwargs): self.request_callback(method, path, *args, **kwargs) return super().request(method, path, *args, **kwargs) + class TestCase(unittest.TestCase): def assertSameJSON(self, json1, json2): """Tells whether two json strings, once decoded, are the same dictionary""" @@ -40,8 +36,7 @@ def assertRaisesRegex(self, *args, **kwargs): return super(TestCase, self).assertRaisesRegex(*args, **kwargs) -class TestJSONRPCClient(TestCase): - +class TestJSONRPCClientBase(TestCase): def setUp(self): self.loop = setup_test_loop() self.app = self.get_app() @@ -53,8 +48,11 @@ def create_client(app, loop): self.client = self.loop.run_until_complete( create_client(self.app, self.loop)) self.loop.run_until_complete(self.client.start_server()) - random.randint = Mock(return_value=1) - self.server = Server('/xmlrpc', session=self.client, timeout=0.2) + random.randint = mock.Mock(return_value=1) + self.server = self.get_server() + + def get_server(self): + return Server('/xmlrpc', session=self.client, timeout=0.2) def tearDown(self): self.loop.run_until_complete(self.client.close()) @@ -68,6 +66,8 @@ def response_func(request): app.router.add_post('/xmlrpc', response_func) return app + +class TestJSONRPCClient(TestJSONRPCClientBase): def test_pep8_conformance(self): """Test that we conform to PEP8.""" @@ -249,5 +249,25 @@ def handler(request): self.assertIsNone((yield from self.server.subtract(42, 23, _notification=True))) +class TestJSONRPCClientCustomLoads(TestJSONRPCClientBase): + def get_server(self): + self.loads_mock = mock.Mock(wraps=json.loads) + return Server('/xmlrpc', session=self.client, loads=self.loads_mock, timeout=0.2) + + @unittest_run_loop + @asyncio.coroutine + def test_custom_loads(self): + # rpc call with positional parameters: + @asyncio.coroutine + def handler1(request): + request_message = yield from request.json() + self.assertEqual(request_message["params"], [42, 23]) + return aiohttp.web.Response(text='{"jsonrpc": "2.0", "result": 19, "id": 1}', content_type='application/json') + + self.handler = handler1 + self.assertEqual((yield from self.server.subtract(42, 23)), 19) + self.assertEqual(self.loads_mock.call_count, 1) + + if __name__ == '__main__': unittest.main()