diff --git a/examples/10-cleaning-demo/cleaner.py b/examples/10-cleaning-demo/cleaner.py index 544e8489..2f94798f 100644 --- a/examples/10-cleaning-demo/cleaner.py +++ b/examples/10-cleaning-demo/cleaner.py @@ -44,8 +44,8 @@ async def startup(): await Availability.create( provider=provider, - time_start=utc.localize(datetime.fromisoformat("2022-12-31 10:00:00")), - time_end=utc.localize(datetime.fromisoformat("2022-12-31 22:00:00")), + time_start=utc.localize(datetime.fromisoformat("2022-01-31 00:00:00")), + time_end=utc.localize(datetime.fromisoformat("2023-02-01 00:00:00")), max_distance=10, min_hourly_price=5, ) diff --git a/examples/10-cleaning-demo/protocols/cleaning/__init__.py b/examples/10-cleaning-demo/protocols/cleaning/__init__.py index a7d3ed25..60a232c8 100644 --- a/examples/10-cleaning-demo/protocols/cleaning/__init__.py +++ b/examples/10-cleaning-demo/protocols/cleaning/__init__.py @@ -78,7 +78,7 @@ async def handle_query_request(ctx: Context, sender: str, msg: ServiceRequest): @cleaning_proto.on_message(model=ServiceBooking, replies=BookingResponse) async def handle_book_request(ctx: Context, sender: str, msg: ServiceBooking): - provider = await Provider.get(name=ctx.name) + provider = await Provider.filter(name=ctx.name).first() availability = await Availability.get(provider=provider) services = [int(service.type) for service in await provider.services] diff --git a/src/nexus/agent.py b/src/nexus/agent.py index 18f3a8c5..cdae0472 100644 --- a/src/nexus/agent.py +++ b/src/nexus/agent.py @@ -14,7 +14,7 @@ MsgDigest, ) from nexus.crypto import Identity, derive_key_from_seed, is_user_address -from nexus.dispatch import Sink, dispatcher +from nexus.dispatch import Sink, dispatcher, JsonStr from nexus.models import Model, ErrorMessage from nexus.protocol import Protocol from nexus.resolver import Resolver, AlmanacResolver @@ -57,7 +57,7 @@ def __init__( self._name = name self._intervals: List[Tuple[float, Any]] = [] self._port = port if port is not None else 8000 - self._background_tasks = set() + self._background_tasks: Set[asyncio.Task] = set() self._resolver = resolve if resolve is not None else AlmanacResolver() self._loop = asyncio.get_event_loop_policy().get_event_loop() if seed is None: @@ -111,11 +111,6 @@ def __init__( # register with the dispatcher self._dispatcher.register(self.address, self) - # start the background message queue processor - task = self._loop.create_task(self._process_message_queue()) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - self._server = ASGIServer(self._port, self._loop, self._queries) @property @@ -287,7 +282,7 @@ def include(self, protocol: Protocol): if protocol.digest is not None: self.protocols[protocol.canonical_name] = protocol.digest - async def handle_message(self, sender, schema_digest: str, message: Any): + async def handle_message(self, sender, schema_digest: str, message: JsonStr): await self._message_queue.put((schema_digest, sender, message)) async def startup(self): @@ -302,6 +297,11 @@ def setup(self): # register the internal agent protocol self.include(self._protocol) + # start the background message queue processor + task = self._loop.create_task(self._process_message_queue()) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + # start the contract registration update loop self._loop.create_task( _run_interval(self.register, self._ctx, REG_UPDATE_INTERVAL_SECONDS) @@ -321,7 +321,7 @@ async def _process_message_queue(self): schema_digest, sender, message = await self._message_queue.get() # lookup the model definition - model_class = self._models.get(schema_digest) + model_class: Model = self._models.get(schema_digest) if model_class is None: continue diff --git a/src/nexus/asgi.py b/src/nexus/asgi.py index c1a3b017..c63a4c01 100644 --- a/src/nexus/asgi.py +++ b/src/nexus/asgi.py @@ -63,7 +63,7 @@ async def __call__(self, scope, receive, send): return headers = dict(scope.get("headers", {})) - if headers[b"content-type"] != b"application/json": + if b"application/json" not in headers[b"content-type"]: await send( { "type": "http.response.start", @@ -143,14 +143,15 @@ async def __call__(self, scope, receive, send): return await dispatcher.dispatch( - env.sender, env.target, env.protocol, env.decode_payload() + env.sender, env.target, env.protocol, json.dumps(env.decode_payload()) ) # wait for any queries to be resolved if expects_response: response_msg: Model = await self._queries[env.sender] - if datetime.now() > datetime.fromtimestamp(env.expires): - response_msg = ErrorMessage("Query envelope expired") + if env.expires is not None: + if datetime.now() > datetime.fromtimestamp(env.expires): + response_msg = ErrorMessage("Query envelope expired") sender = env.target response = enclose_response(response_msg, sender, env.session) else: diff --git a/src/nexus/context.py b/src/nexus/context.py index 386966a8..d14cb8ee 100644 --- a/src/nexus/context.py +++ b/src/nexus/context.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import uuid from dataclasses import dataclass @@ -10,7 +11,7 @@ from cosmpy.aerial.wallet import LocalWallet from nexus.config import DEFAULT_ENVELOPE_TIMEOUT_SECONDS -from nexus.crypto import Identity, is_user_address +from nexus.crypto import Identity from nexus.dispatch import dispatcher from nexus.envelope import Envelope from nexus.models import Model, ErrorMessage @@ -105,15 +106,12 @@ async def send( # handle local dispatch of messages if dispatcher.contains(destination): await dispatcher.dispatch( - self.address, destination, schema_digest, json_message + self.address, destination, schema_digest, json.dumps(json_message) ) return # handle queries waiting for a response - if is_user_address(destination): - if destination not in self._queries: - logging.exception(f"Unable to resolve query to user {destination}") - return + if destination in self._queries: self._queries[destination].set_result(message) del self._queries[destination] return diff --git a/src/nexus/dispatch.py b/src/nexus/dispatch.py index 072c82e9..50bf145d 100644 --- a/src/nexus/dispatch.py +++ b/src/nexus/dispatch.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod -from typing import Dict, Set, Any +from typing import Dict, Set + +JsonStr = str class Sink(ABC): @abstractmethod - async def handle_message(self, sender: str, schema_digest: str, message: Any): + async def handle_message(self, sender: str, schema_digest: str, message: JsonStr): pass @@ -26,7 +28,7 @@ def contains(self, address: str) -> bool: return address in self._sinks async def dispatch( - self, sender: str, destination: str, schema_digest: str, message: Any + self, sender: str, destination: str, schema_digest: str, message: JsonStr ): for handler in self._sinks.get(destination, set()): await handler.handle_message(sender, schema_digest, message) diff --git a/src/nexus/query.py b/src/nexus/query.py index 7de3b99d..614e7411 100644 --- a/src/nexus/query.py +++ b/src/nexus/query.py @@ -73,5 +73,5 @@ def enclose_response(message: Model, sender: str, session: str) -> dict: session=session, protocol=Model.build_schema_digest(message), ) - response_env.encode_payload(message.json()) + response_env.encode_payload(message.dict()) return response_env.json() diff --git a/tests/test_agent_registration.py b/tests/test_agent_registration.py new file mode 100644 index 00000000..dbbeaa01 --- /dev/null +++ b/tests/test_agent_registration.py @@ -0,0 +1,46 @@ +# pylint: disable=protected-access +import unittest + +from nexus import Agent +from nexus.setup import fund_agent_if_low + + +class TestVerify(unittest.TestCase): + def test_agent_registration(self): + + agent = Agent(name="alice") + + reg_fee = "500000000000000000atestfet" + + fund_agent_if_low(agent.wallet.address()) + + sequence = agent.get_registration_sequence() + + signature = agent._identity.sign_registration( + agent._reg_contract.address, agent.get_registration_sequence() + ) + + msg = { + "register": { + "record": { + "service": { + "protocols": [], + "endpoints": [{"url": agent._endpoint, "weight": 1}], + } + }, + "signature": signature, + "sequence": sequence, + "agent_address": agent.address, + } + } + + transaction = agent._reg_contract.execute(msg, agent.wallet, funds=reg_fee) + transaction.wait_to_complete() + + is_registered = agent.registration_status() + + self.assertEqual(is_registered, True, "Registration failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 00000000..59e379f0 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,464 @@ +# pylint: disable=protected-access +import asyncio +import unittest +import uuid +from unittest.mock import patch, AsyncMock, call + +from nexus import Agent, Model +from nexus.envelope import Envelope +from nexus.crypto import generate_user_address +from nexus.query import enclose_response + + +class Message(Model): + message: str + + +class TestServer(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + self.agent = Agent(name="alice", seed="alice recovery password") + self.bob = Agent(name="bob", seed="bob recovery password") + self.loop: asyncio.BaseEventLoop = self._asyncioTestLoop + return super().setUp() + + async def mock_process_sync_message(self, sender: str, msg: Model): + while True: + if sender in self.agent._server._queries: + self.agent._server._queries[sender].set_result(msg) + return + + async def test_message_success(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.bob._identity) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b"{}", + } + ), + ] + ) + + async def test_message_success_unsigned(self): + message = Message(message="hello") + user = generate_user_address() + session = uuid.uuid4() + env = Envelope( + version=1, + sender=user, + target=self.agent.address, + session=session, + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b"{}", + } + ), + ] + ) + + async def test_message_success_sync(self): + message = Message(message="hello") + reply = Message(message="hey") + user = generate_user_address() + session = uuid.uuid4() + env = Envelope( + version=1, + sender=user, + target=self.agent.address, + session=session, + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await asyncio.gather( + self.loop.create_task( + self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={ + b"content-type": b"application/json", + b"x-uagents-connection": b"sync", + }, + ), + receive=None, + send=mock_send, + ) + ), + self.loop.create_task(self.mock_process_sync_message(user, reply)), + ) + response = enclose_response(reply, self.agent.address, session) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": response.encode(), + } + ), + ] + ) + + async def test_message_success_sync_unsigned(self): + message = Message(message="hello") + reply = Message(message="hey") + session = uuid.uuid4() + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=session, + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.bob._identity) + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await asyncio.gather( + self.loop.create_task( + self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={ + b"content-type": b"application/json", + b"x-uagents-connection": b"sync", + }, + ), + receive=None, + send=mock_send, + ) + ), + self.loop.create_task( + self.mock_process_sync_message(self.bob.address, reply) + ), + ) + response = enclose_response(reply, self.agent.address, session) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": response.encode(), + } + ), + ] + ) + + async def test_message_fail_wrong_path(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.bob._identity) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/bad/path", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "not found"}', + } + ), + ] + ) + + async def test_message_fail_wrong_headers(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.bob._identity) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/badapp"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 400, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "invalid format"}', + } + ), + ] + ) + + async def test_message_fail_bad_data(self): + message = Message(message="hello") + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = message.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 400, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "invalid format"}', + } + ), + ] + ) + + async def test_message_fail_unsigned(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 400, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "unable to verify payload"}', + } + ), + ] + ) + + async def test_message_fail_verify(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=self.agent.address, + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.agent._identity) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 400, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "unable to verify payload"}', + } + ), + ] + ) + + async def test_message_fail_dispatch(self): + message = Message(message="hello") + env = Envelope( + version=1, + sender=self.bob.address, + target=generate_user_address(), + session=uuid.uuid4(), + protocol=Model.build_schema_digest(message), + ) + env.encode_payload(message.json()) + env.sign(self.bob._identity) + + mock_send = AsyncMock() + with patch("nexus.asgi._read_asgi_body") as mock_receive: + mock_receive.return_value = env.json().encode() + await self.agent._server( + scope=dict( + type="http", + path="/submit", + headers={b"content-type": b"application/json"}, + ), + receive=None, + send=mock_send, + ) + mock_send.assert_has_calls( + [ + call( + { + "type": "http.response.start", + "status": 400, + "headers": [[b"content-type", b"application/json"]], + } + ), + call( + { + "type": "http.response.body", + "body": b'{"error": "unable to route envelope"}', + } + ), + ] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_verify_msg.py b/tests/test_verify_msg.py index a078d060..555f0413 100644 --- a/tests/test_verify_msg.py +++ b/tests/test_verify_msg.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import unittest @@ -12,7 +13,8 @@ def encode(message: str) -> bytes: class TestVerify(unittest.TestCase): - def test_verify_message(self): + def test_sign_and_verify_message(self): + asyncio.set_event_loop(asyncio.new_event_loop()) alice = Agent(name="alice", seed="alice recovery password") alice_msg = "hello there bob" @@ -25,6 +27,22 @@ def test_verify_message(self): self.assertEqual(result, True, "Verification failed") + def test_verify_dart_digest(self): + + # Generate public key + address = "agent1qf5gfqm48k9acegez3sg82ney2aa6l5fvpwh3n3z0ajh0nam3ssgwnn5me7" + + # Signature + signature = "sig1qyvn5fjzrhjzqcmj2gfg4us6xj00gvscs4u9uqxy6wpvp9agxjf723eh5l6w878p67lycgd3fz77zr3h0q6mrheg48e35zsvv0rm2tsuvyn3l" # pylint: disable=line-too-long + + # Message + dart_digest = "a29af8b704077d394a9756dc04f0bb5f1424fc391b3de91144d683c5893ca234" + bytes_dart_digest = bytes.fromhex(dart_digest) + + result = Identity.verify_digest(address, bytes_dart_digest, signature) + + self.assertEqual(result, True, "Verification failed") + if __name__ == "__main__": unittest.main()