diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index 0e5c68dca8..f69648853f 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -6,7 +6,12 @@ import uuid from aiohttp import web -from aiohttp_apispec import docs, response_schema, setup_aiohttp_apispec +from aiohttp_apispec import ( + docs, + response_schema, + setup_aiohttp_apispec, + validation_middleware, +) import aiohttp_cors from marshmallow import fields, Schema @@ -149,7 +154,7 @@ def __init__( async def make_application(self) -> web.Application: """Get the aiohttp application instance.""" - middlewares = [] + middlewares = [validation_middleware] admin_api_key = self.context.settings.get("admin.admin_api_key") admin_insecure_mode = self.context.settings.get("admin.admin_insecure_mode") @@ -203,11 +208,11 @@ async def collect_stats(request, handler): app.add_routes( [ - web.get("/", self.redirect_handler), - web.get("/plugins", self.plugins_handler), - web.get("/status", self.status_handler), + web.get("/", self.redirect_handler, allow_head=False), + web.get("/plugins", self.plugins_handler, allow_head=False), + web.get("/status", self.status_handler, allow_head=False), web.post("/status/reset", self.status_reset_handler), - web.get("/ws", self.websocket_handler), + web.get("/ws", self.websocket_handler, allow_head=False), ] ) diff --git a/aries_cloudagent/admin/tests/test_admin_server.py b/aries_cloudagent/admin/tests/test_admin_server.py index 981e6dfacd..d6eccee917 100644 --- a/aries_cloudagent/admin/tests/test_admin_server.py +++ b/aries_cloudagent/admin/tests/test_admin_server.py @@ -4,6 +4,7 @@ from aiohttp import web from asynctest import TestCase as AsyncTestCase from asynctest.mock import patch +from asynctest.mock import CoroutineMock, patch from ...config.default_context import DefaultContextBuilder from ...config.injection_context import InjectionContext @@ -12,7 +13,7 @@ from ...core.protocol_registry import ProtocolRegistry from ...transport.outbound.message import OutboundMessage -from ..server import AdminServer +from ..server import AdminServer, AdminSetupError class TestAdminServerBasic(AsyncTestCase): @@ -64,6 +65,11 @@ async def test_start_stop(self): await server.start() await server.stop() + with patch.object(web.TCPSite, "start", CoroutineMock()) as mock_start: + mock_start.side_effect = OSError("Failure to launch") + with self.assertRaises(AdminSetupError): + await self.get_admin_server(settings).start() + async def test_responder_send(self): message = OutboundMessage(payload="{}") admin_server = self.get_admin_server() @@ -75,7 +81,11 @@ async def test_responder_webhook(self): admin_server = self.get_admin_server() test_url = "target_url" test_attempts = 99 - admin_server.add_webhook_target(test_url, max_attempts=test_attempts) + admin_server.add_webhook_target( + target_url=test_url, + topic_filter=["*"], # cover vacuous filter + max_attempts=test_attempts, + ) test_topic = "test_topic" test_payload = {"test": "TEST"} await admin_server.responder.send_webhook(test_topic, test_payload) @@ -83,6 +93,9 @@ async def test_responder_webhook(self): (test_topic, test_payload, test_url, test_attempts) ] + admin_server.remove_webhook_target(target_url=test_url) + assert test_url not in admin_server.webhook_targets + async def test_import_routes(self): # this test just imports all default admin routes # for routes with associated tests, this shouldn't make a difference in coverage diff --git a/aries_cloudagent/connections/models/connection_record.py b/aries_cloudagent/connections/models/connection_record.py index 882d293320..6161f55c42 100644 --- a/aries_cloudagent/connections/models/connection_record.py +++ b/aries_cloudagent/connections/models/connection_record.py @@ -1,7 +1,6 @@ """Handle connection information interface with non-secrets storage.""" -from marshmallow import fields -from marshmallow.validate import OneOf +from marshmallow import fields, validate from ...config.injection_context import InjectionContext from ...messaging.models.base_record import BaseRecord, BaseRecordSchema @@ -298,7 +297,7 @@ class Meta: required=False, description="Connection initiator: self, external, or multiuse", example=ConnectionRecord.INITIATOR_SELF, - validate=OneOf(["self", "external", "multiuse"]), + validate=validate.OneOf(["self", "external", "multiuse"]), ) invitation_key = fields.Str( required=False, description="Public key for connection", **INDY_RAW_PUBLIC_KEY @@ -317,7 +316,7 @@ class Meta: required=False, description="Connection acceptance: manual or auto", example=ConnectionRecord.ACCEPT_AUTO, - validate=OneOf(["manual", "auto"]), + validate=validate.OneOf(["manual", "auto"]), ) error_msg = fields.Str( required=False, @@ -328,7 +327,7 @@ class Meta: required=False, description="Invitation mode: once, multi, or static", example=ConnectionRecord.INVITATION_MODE_ONCE, - validate=OneOf(["once", "multi", "static"]), + validate=validate.OneOf(["once", "multi", "static"]), ) alias = fields.Str( required=False, diff --git a/aries_cloudagent/holder/routes.py b/aries_cloudagent/holder/routes.py index 2df0fec8d9..695d16f9f5 100644 --- a/aries_cloudagent/holder/routes.py +++ b/aries_cloudagent/holder/routes.py @@ -3,11 +3,19 @@ import json from aiohttp import web -from aiohttp_apispec import docs, response_schema +from aiohttp_apispec import docs, match_info_schema, querystring_schema, response_schema from marshmallow import fields, Schema -from .base import BaseHolder -from ..messaging.valid import INDY_CRED_DEF_ID, INDY_REV_REG_ID, INDY_SCHEMA_ID +from .base import BaseHolder, HolderError +from ..messaging.valid import ( + INDY_CRED_DEF_ID, + INDY_REV_REG_ID, + INDY_SCHEMA_ID, + INDY_WQL, + NATURAL_NUM, + WHOLE_NUM, + UUIDFour, +) from ..wallet.error import WalletNotFoundError @@ -60,13 +68,32 @@ class CredentialSchema(Schema): witness = fields.Nested(WitnessSchema) -class CredentialListSchema(Schema): +class CredentialsListSchema(Schema): """Result schema for a credential query.""" results = fields.List(fields.Nested(CredentialSchema())) +class CredentialsListQueryStringSchema(Schema): + """Parameters and validators for query string with DID only.""" + + start = fields.Int(description="Start index", required=False, **WHOLE_NUM,) + count = fields.Int( + description="Maximum number to retrieve", required=False, **NATURAL_NUM, + ) + wql = fields.Str(description="(JSON) WQL query", required=False, **INDY_WQL,) + + +class CredIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking credential id.""" + + credential_id = fields.Str( + description="Credential identifier", required=True, example=UUIDFour.EXAMPLE + ) + + @docs(tags=["credentials"], summary="Fetch a credential from wallet by id") +@match_info_schema(CredIdMatchInfoSchema()) @response_schema(CredentialSchema(), 200) async def credentials_get(request: web.BaseRequest): """ @@ -81,7 +108,7 @@ async def credentials_get(request: web.BaseRequest): """ context = request.app["request_context"] - credential_id = request.match_info["id"] + credential_id = request.match_info["credential_id"] holder: BaseHolder = await context.inject(BaseHolder) try: @@ -94,6 +121,7 @@ async def credentials_get(request: web.BaseRequest): @docs(tags=["credentials"], summary="Remove a credential from the wallet by id") +@match_info_schema(CredIdMatchInfoSchema()) async def credentials_remove(request: web.BaseRequest): """ Request handler for searching connection records. @@ -107,7 +135,7 @@ async def credentials_remove(request: web.BaseRequest): """ context = request.app["request_context"] - credential_id = request.match_info["id"] + credential_id = request.match_info["credential_id"] holder: BaseHolder = await context.inject(BaseHolder) try: @@ -119,25 +147,10 @@ async def credentials_remove(request: web.BaseRequest): @docs( - tags=["credentials"], - parameters=[ - { - "name": "start", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "count", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - {"name": "wql", "in": "query", "schema": {"type": "string"}, "required": False}, - ], - summary="Fetch credentials from wallet", + tags=["credentials"], summary="Fetch credentials from wallet", ) -@response_schema(CredentialListSchema(), 200) +@querystring_schema(CredentialsListQueryStringSchema()) +@response_schema(CredentialsListSchema(), 200) async def credentials_list(request: web.BaseRequest): """ Request handler for searching credential records. @@ -163,7 +176,10 @@ async def credentials_list(request: web.BaseRequest): count = int(count) if isinstance(count, str) else 10 holder: BaseHolder = await context.inject(BaseHolder) - credentials = await holder.get_credentials(start, count, wql) + try: + credentials = await holder.get_credentials(start, count, wql) + except HolderError as x_holder: + raise web.HTTPBadRequest(reason=x_holder.message) return web.json_response({"results": credentials}) @@ -173,8 +189,8 @@ async def register(app: web.Application): app.add_routes( [ - web.get("/credential/{id}", credentials_get), - web.post("/credential/{id}/remove", credentials_remove), - web.get("/credentials", credentials_list), + web.get("/credential/{credential_id}", credentials_get, allow_head=False), + web.post("/credential/{credential_id}/remove", credentials_remove), + web.get("/credentials", credentials_list, allow_head=False), ] ) diff --git a/aries_cloudagent/ledger/routes.py b/aries_cloudagent/ledger/routes.py index 07121f5f45..fec9069be5 100644 --- a/aries_cloudagent/ledger/routes.py +++ b/aries_cloudagent/ledger/routes.py @@ -1,10 +1,11 @@ """Ledger admin routes.""" from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import docs, querystring_schema, request_schema, response_schema -from marshmallow import fields, Schema +from marshmallow import fields, Schema, validate +from ..messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY from .base import BaseLedger from .error import LedgerTransactionError @@ -55,31 +56,33 @@ class TAAAcceptSchema(Schema): mechanism = fields.Str() +class RegisterLedgerNymQueryStringSchema(Schema): + """Query string parameters and validators for register ledger nym request.""" + + did = fields.Str(description="DID to register", required=True, **INDY_DID,) + verkey = fields.Str( + description="Verification key", required=True, **INDY_RAW_PUBLIC_KEY + ) + alias = fields.Str(description="Alias", required=False, example="Barry",) + role = fields.Str( + description="Role", + required=False, + validate=validate.OneOf( + ["TRUSTEE", "STEWARD", "ENDORSER", "NETWORK_MONITOR", "reset"] + ), + ) + + +class QueryStringDIDSchema(Schema): + """Parameters and validators for query string with DID only.""" + + did = fields.Str(description="DID of interest", required=True, **INDY_DID) + + @docs( - tags=["ledger"], - summary="Send a NYM registration to the ledger.", - parameters=[ - {"name": "did", "in": "query", "schema": {"type": "string"}, "required": True}, - { - "name": "verkey", - "in": "query", - "schema": {"type": "string"}, - "required": True, - }, - { - "name": "alias", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "role", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - ], + tags=["ledger"], summary="Send a NYM registration to the ledger.", ) +@querystring_schema(RegisterLedgerNymQueryStringSchema()) async def register_ledger_nym(request: web.BaseRequest): """ Request handler for registering a NYM with the ledger. @@ -97,7 +100,11 @@ async def register_ledger_nym(request: web.BaseRequest): if not did or not verkey: raise web.HTTPBadRequest() - alias, role = request.query.get("alias"), request.query.get("role") + alias = request.query.get("alias") + role = request.query.get("role") + if role == "reset": # indy: empty to reset, null for regular user + role = "" # visually: confusing - correct 'reset' to empty string here + success = False async with ledger: try: @@ -109,12 +116,9 @@ async def register_ledger_nym(request: web.BaseRequest): @docs( - tags=["ledger"], - summary="Get the verkey for a DID from the ledger.", - parameters=[ - {"name": "did", "in": "query", "schema": {"type": "string"}, "required": True} - ], + tags=["ledger"], summary="Get the verkey for a DID from the ledger.", ) +@querystring_schema(QueryStringDIDSchema()) async def get_did_verkey(request: web.BaseRequest): """ Request handler for getting a verkey for a DID from the ledger. @@ -137,12 +141,9 @@ async def get_did_verkey(request: web.BaseRequest): @docs( - tags=["ledger"], - summary="Get the endpoint for a DID from the ledger.", - parameters=[ - {"name": "did", "in": "query", "schema": {"type": "string"}, "required": True} - ], + tags=["ledger"], summary="Get the endpoint for a DID from the ledger.", ) +@querystring_schema(QueryStringDIDSchema()) async def get_did_endpoint(request: web.BaseRequest): """ Request handler for getting a verkey for a DID from the ledger. @@ -233,9 +234,9 @@ async def register(app: web.Application): app.add_routes( [ web.post("/ledger/register-nym", register_ledger_nym), - web.get("/ledger/did-verkey", get_did_verkey), - web.get("/ledger/did-endpoint", get_did_endpoint), - web.get("/ledger/taa", ledger_get_taa), + web.get("/ledger/did-verkey", get_did_verkey, allow_head=False), + web.get("/ledger/did-endpoint", get_did_endpoint, allow_head=False), + web.get("/ledger/taa", ledger_get_taa, allow_head=False), web.post("/ledger/taa/accept", ledger_accept_taa), ] ) diff --git a/aries_cloudagent/ledger/tests/test_routes.py b/aries_cloudagent/ledger/tests/test_routes.py index ddca8a3c73..2f41745704 100644 --- a/aries_cloudagent/ledger/tests/test_routes.py +++ b/aries_cloudagent/ledger/tests/test_routes.py @@ -24,8 +24,7 @@ def setUp(self): self.test_endpoint = "http://localhost:8021" async def test_missing_ledger(self): - request = async_mock.MagicMock() - request.app = self.app + request = async_mock.MagicMock(app=self.app,) self.context.injector.clear_binding(BaseLedger) with self.assertRaises(HTTPForbidden): @@ -80,9 +79,10 @@ async def test_get_endpoint_no_did(self): await test_module.get_did_endpoint(request) async def test_register_nym(self): - request = async_mock.MagicMock() - request.app = self.app - request.query = {"did": self.test_did, "verkey": self.test_verkey} + request = async_mock.MagicMock( + app=self.app, + query={"did": self.test_did, "verkey": self.test_verkey, "role": "reset",}, + ) with async_mock.patch.object( test_module.web, "json_response", async_mock.Mock() ) as json_response: diff --git a/aries_cloudagent/messaging/credential_definitions/routes.py b/aries_cloudagent/messaging/credential_definitions/routes.py index 3ca1712c7b..85ee4d9597 100644 --- a/aries_cloudagent/messaging/credential_definitions/routes.py +++ b/aries_cloudagent/messaging/credential_definitions/routes.py @@ -3,7 +3,13 @@ from asyncio import shield from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) from marshmallow import fields, Schema @@ -13,7 +19,7 @@ from ..valid import INDY_CRED_DEF_ID, INDY_SCHEMA_ID, INDY_VERSION -from .util import CRED_DEF_TAGS, CRED_DEF_SENT_RECORD_TYPE +from .util import CredDefQueryStringSchema, CRED_DEF_TAGS, CRED_DEF_SENT_RECORD_TYPE class CredentialDefinitionSendRequestSchema(Schema): @@ -81,6 +87,16 @@ class CredentialDefinitionsCreatedResultsSchema(Schema): ) +class CredDefIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking cred def id.""" + + cred_def_id = fields.Str( + description="Credential definition identifier", + required=True, + **INDY_CRED_DEF_ID + ) + + @docs( tags=["credential-definition"], summary="Sends a credential definition to the ledger", @@ -124,17 +140,9 @@ async def credential_definitions_send_credential_definition(request: web.BaseReq @docs( tags=["credential-definition"], - parameters=[ - { - "name": tag, - "in": "query", - "schema": {"type": "string", "pattern": pat}, - "required": False, - } - for (tag, pat) in CRED_DEF_TAGS.items() - ], summary="Search for matching credential definitions that agent originated", ) +@querystring_schema(CredDefQueryStringSchema()) @response_schema(CredentialDefinitionsCreatedResultsSchema(), 200) async def credential_definitions_created(request: web.BaseRequest): """ @@ -166,6 +174,7 @@ async def credential_definitions_created(request: web.BaseRequest): tags=["credential-definition"], summary="Gets a credential definition from the ledger", ) +@match_info_schema(CredDefIdMatchInfoSchema()) @response_schema(CredentialDefinitionGetResultsSchema(), 200) async def credential_definitions_get_credential_definition(request: web.BaseRequest): """ @@ -180,7 +189,7 @@ async def credential_definitions_get_credential_definition(request: web.BaseRequ """ context = request.app["request_context"] - credential_definition_id = request.match_info["id"] + credential_definition_id = request.match_info["cred_def_id"] ledger: BaseLedger = await context.inject(BaseLedger) async with ledger: @@ -198,17 +207,16 @@ async def register(app: web.Application): web.post( "/credential-definitions", credential_definitions_send_credential_definition, - ) - ] - ) - app.add_routes( - [web.get("/credential-definitions/created", credential_definitions_created,)] - ) - app.add_routes( - [ + ), web.get( - "/credential-definitions/{id}", + "/credential-definitions/created", + credential_definitions_created, + allow_head=False, + ), + web.get( + "/credential-definitions/{cred_def_id}", credential_definitions_get_credential_definition, - ) + allow_head=False, + ), ] ) diff --git a/aries_cloudagent/messaging/credential_definitions/util.py b/aries_cloudagent/messaging/credential_definitions/util.py index 921c608225..e9e7a133bc 100644 --- a/aries_cloudagent/messaging/credential_definitions/util.py +++ b/aries_cloudagent/messaging/credential_definitions/util.py @@ -1,13 +1,37 @@ """Credential definition utilities.""" -from ..valid import IndySchemaId, IndyDID, IndyVersion, IndyCredDefId - -CRED_DEF_TAGS = { - "schema_id": IndySchemaId.PATTERN, - "schema_issuer_did": IndyDID.PATTERN, - "schema_name": "^.+$", - "schema_version": IndyVersion.PATTERN, - "issuer_did": IndyDID.PATTERN, - "cred_def_id": IndyCredDefId.PATTERN, -} +from marshmallow import fields, Schema + +from ..valid import ( + INDY_CRED_DEF_ID, + INDY_DID, + INDY_SCHEMA_ID, + INDY_VERSION, +) + + CRED_DEF_SENT_RECORD_TYPE = "cred_def_sent" + + +class CredDefQueryStringSchema(Schema): + """Query string parameters for credential definition searches.""" + + schema_id = fields.Str( + description="Schema identifier", required=False, **INDY_SCHEMA_ID, + ) + schema_issuer_did = fields.Str( + description="Schema issuer DID", required=False, **INDY_DID, + ) + schema_name = fields.Str( + description="Schema name", required=False, example="membership", + ) + schema_version = fields.Str( + description="Schema version", required=False, **INDY_VERSION + ) + issuer_did = fields.Str(description="Issuer DID", required=False, **INDY_DID,) + cred_def_id = fields.Str( + description="Credential definition id", required=False, **INDY_CRED_DEF_ID, + ) + + +CRED_DEF_TAGS = [tag for tag in vars(CredDefQueryStringSchema)["_declared_fields"]] diff --git a/aries_cloudagent/messaging/decorators/transport_decorator.py b/aries_cloudagent/messaging/decorators/transport_decorator.py index 08fc48bac3..0f9bb5c2de 100644 --- a/aries_cloudagent/messaging/decorators/transport_decorator.py +++ b/aries_cloudagent/messaging/decorators/transport_decorator.py @@ -4,8 +4,7 @@ This decorator allows changes to agent response behaviour and queue status updates. """ -from marshmallow import fields -from marshmallow.validate import OneOf +from marshmallow import fields, validate from ..models.base import BaseModel, BaseModelSchema from ..valid import UUIDFour @@ -52,7 +51,7 @@ class Meta: required=False, description="Return routing mode: none, all, or thread", example="all", - validate=OneOf(["none", "all", "thread"]), + validate=validate.OneOf(["none", "all", "thread"]), ) return_route_thread = fields.Str( required=False, diff --git a/aries_cloudagent/messaging/schemas/routes.py b/aries_cloudagent/messaging/schemas/routes.py index f200453bbc..d66e8cb278 100644 --- a/aries_cloudagent/messaging/schemas/routes.py +++ b/aries_cloudagent/messaging/schemas/routes.py @@ -3,15 +3,21 @@ from asyncio import shield from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) from marshmallow import fields, Schema from ...issuer.base import BaseIssuer from ...ledger.base import BaseLedger from ...storage.base import BaseStorage -from ..valid import INDY_SCHEMA_ID, INDY_VERSION -from .util import SCHEMA_SENT_RECORD_TYPE, SCHEMA_TAGS +from ..valid import NATURAL_NUM, INDY_SCHEMA_ID, INDY_VERSION +from .util import SchemaQueryStringSchema, SCHEMA_SENT_RECORD_TYPE, SCHEMA_TAGS class SchemaSendRequestSchema(Schema): @@ -49,7 +55,7 @@ class SchemaSchema(Schema): description="Schema attribute names", data_key="attrNames", ) - seqNo = fields.Integer(description="Schema sequence number", example=999) + seqNo = fields.Int(description="Schema sequence number", **NATURAL_NUM) class SchemaGetResultsSchema(Schema): @@ -66,6 +72,14 @@ class SchemasCreatedResultsSchema(Schema): ) +class SchemaIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking schema id.""" + + schema_id = fields.Str( + description="Schema identifier", required=True, **INDY_SCHEMA_ID, + ) + + @docs(tags=["schema"], summary="Sends a schema to the ledger") @request_schema(SchemaSendRequestSchema()) @response_schema(SchemaSendResultsSchema(), 200) @@ -101,18 +115,9 @@ async def schemas_send_schema(request: web.BaseRequest): @docs( - tags=["schema"], - parameters=[ - { - "name": tag, - "in": "query", - "schema": {"type": "string", "pattern": pat}, - "required": False, - } - for (tag, pat) in SCHEMA_TAGS.items() - ], - summary="Search for matching schema that agent originated", + tags=["schema"], summary="Search for matching schema that agent originated", ) +@querystring_schema(SchemaQueryStringSchema()) @response_schema(SchemasCreatedResultsSchema(), 200) async def schemas_created(request: web.BaseRequest): """ @@ -139,6 +144,7 @@ async def schemas_created(request: web.BaseRequest): @docs(tags=["schema"], summary="Gets a schema from the ledger") +@match_info_schema(SchemaIdMatchInfoSchema()) @response_schema(SchemaGetResultsSchema(), 200) async def schemas_get_schema(request: web.BaseRequest): """ @@ -153,7 +159,7 @@ async def schemas_get_schema(request: web.BaseRequest): """ context = request.app["request_context"] - schema_id = request.match_info["id"] + schema_id = request.match_info["schema_id"] ledger: BaseLedger = await context.inject(BaseLedger) async with ledger: @@ -165,5 +171,7 @@ async def schemas_get_schema(request: web.BaseRequest): async def register(app: web.Application): """Register routes.""" app.add_routes([web.post("/schemas", schemas_send_schema)]) - app.add_routes([web.get("/schemas/created", schemas_created)]) - app.add_routes([web.get("/schemas/{id}", schemas_get_schema)]) + app.add_routes([web.get("/schemas/created", schemas_created, allow_head=False)]) + app.add_routes( + [web.get("/schemas/{schema_id}", schemas_get_schema, allow_head=False)] + ) diff --git a/aries_cloudagent/messaging/schemas/util.py b/aries_cloudagent/messaging/schemas/util.py index 32377cce23..48d190cb6b 100644 --- a/aries_cloudagent/messaging/schemas/util.py +++ b/aries_cloudagent/messaging/schemas/util.py @@ -1,11 +1,30 @@ """Schema utilities.""" -from ..valid import IndySchemaId, IndyDID, IndyVersion - -SCHEMA_TAGS = { - "schema_id": IndySchemaId.PATTERN, - "schema_issuer_did": IndyDID.PATTERN, - "schema_name": "^.+$", - "schema_version": IndyVersion.PATTERN, -} +from marshmallow import fields, Schema + +from ..valid import ( + INDY_DID, + INDY_SCHEMA_ID, + INDY_VERSION, +) + + +class SchemaQueryStringSchema(Schema): + """Query string parameters for schema searches.""" + + schema_id = fields.Str( + description="Schema identifier", required=False, **INDY_SCHEMA_ID, + ) + schema_issuer_did = fields.Str( + description="Schema issuer DID", required=False, **INDY_DID, + ) + schema_name = fields.Str( + description="Schema name", required=False, example="membership", + ) + schema_version = fields.Str( + description="Schema version", required=False, **INDY_VERSION + ) + + +SCHEMA_TAGS = [tag for tag in vars(SchemaQueryStringSchema)["_declared_fields"]] SCHEMA_SENT_RECORD_TYPE = "schema_sent" diff --git a/aries_cloudagent/messaging/tests/test_valid.py b/aries_cloudagent/messaging/tests/test_valid.py index 27f9382774..9cfcf61c26 100644 --- a/aries_cloudagent/messaging/tests/test_valid.py +++ b/aries_cloudagent/messaging/tests/test_valid.py @@ -1,25 +1,33 @@ -from marshmallow import ValidationError +import json +import pytest + from unittest import TestCase +from marshmallow import ValidationError + from ..valid import ( BASE58_SHA256_HASH, BASE64, BASE64URL, BASE64URL_NO_PAD, DID_KEY, + ENDPOINT, INDY_CRED_DEF_ID, INDY_DID, + INT_EPOCH, INDY_ISO8601_DATETIME, INDY_PREDICATE, INDY_RAW_PUBLIC_KEY, INDY_REV_REG_ID, INDY_SCHEMA_ID, INDY_VERSION, - INT_EPOCH, + INDY_WQL, + NATURAL_NUM, JWS_HEADER_KID, JWT, SHA256, UUID4, + WHOLE_NUM, ) @@ -35,6 +43,26 @@ def test_epoch(self): INT_EPOCH["validate"](-9223372036854775808) INT_EPOCH["validate"](9223372036854775807) + def test_whole(self): + non_wholes = [-9223372036854775809, 2.3, "Hello", None] + for non_whole in non_wholes: + with self.assertRaises(ValidationError): + WHOLE_NUM["validate"](non_whole) + + WHOLE_NUM["validate"](0) + WHOLE_NUM["validate"](1) + WHOLE_NUM["validate"](12345678901234567890) + + def test_natural(self): + non_naturals = [-9223372036854775809, 2.3, "Hello", 0, None] + for non_natural in non_naturals: + with self.assertRaises(ValidationError): + NATURAL_NUM["validate"](non_natural) + + NATURAL_NUM["validate"](1) + NATURAL_NUM["validate"](2) + NATURAL_NUM["validate"](12345678901234567890) + def test_indy_did(self): non_indy_dids = [ "Q4zqM7aXqm7gDQkUVLng9I", # 'I' not a base58 char @@ -233,6 +261,27 @@ def test_indy_date(self): INDY_ISO8601_DATETIME["validate"]("2020-01-01 00:00:00.1-00:00") INDY_ISO8601_DATETIME["validate"]("2020-01-01 00:00:00.123456-00:00") + def test_indy_wql(self): + non_wqls = [ + "nope", + "[a, b, c]", + "{1, 2, 3}", + set(), + '"Hello World"', + None, + "null", + "true", + False, + ] + for non_wql in non_wqls: + with self.assertRaises(ValidationError): + INDY_WQL["validate"](non_wql) + + INDY_WQL["validate"](json.dumps({})) + INDY_WQL["validate"](json.dumps({"a": "1234"})) + INDY_WQL["validate"](json.dumps({"a": "1234", "b": {"$not": "0"}})) + INDY_WQL["validate"](json.dumps({"$or": {"a": "1234", "b": "0"}})) + def test_base64(self): non_base64s = [ "####", @@ -313,3 +362,24 @@ def test_uuid4(self): UUID4["validate"]("3fa85f64-5717-4562-b3fc-2c963f66afa6") UUID4["validate"]("3FA85F64-5717-4562-B3FC-2C963F66AFA6") # upper case OK + + def test_endpoint(self): + non_endpoints = [ + "123", + "", + "/path/only", + "https://1.2.3.4?query=true&url=false", + "http://no_tld/bad", + "no-proto:8080/my/path", + "smtp:8080/my/path#fragment", + ] + + for non_endpoint in non_endpoints: + with self.assertRaises(ValidationError): + ENDPOINT["validate"](non_endpoint) + + ENDPOINT["validate"]("http://github.com") + ENDPOINT["validate"]("https://localhost:8080") + ENDPOINT["validate"]("newproto://myhost.ca:8080/path") + ENDPOINT["validate"]("ftp://10.10.100.90:8021") + ENDPOINT["validate"]("zzzp://someplace.ca:9999/path") diff --git a/aries_cloudagent/messaging/valid.py b/aries_cloudagent/messaging/valid.py index e7d1fe42a6..d6da19d1a6 100644 --- a/aries_cloudagent/messaging/valid.py +++ b/aries_cloudagent/messaging/valid.py @@ -1,9 +1,12 @@ """Validators for schema fields.""" +import json + from datetime import datetime from base58 import alphabet from marshmallow.validate import OneOf, Range, Regexp +from marshmallow.exceptions import ValidationError from .util import epoch_to_str @@ -25,6 +28,42 @@ def __init__(self): ) +class WholeNumber(Range): + """Validate value as non-negative integer.""" + + EXAMPLE = 0 + + def __init__(self): + """Initializer.""" + + super().__init__(min=0, error="Value {input} is not a non-negative integer") + + def __call__(self, value): + """Validate input value.""" + + if type(value) != int: + raise ValidationError("Value {input} is not a valid whole number") + super().__call__(value) + + +class NaturalNumber(Range): + """Validate value as positive integer.""" + + EXAMPLE = 10 + + def __init__(self): + """Initializer.""" + + super().__init__(min=1, error="Value {input} is not a positive integer") + + def __call__(self, value): + """Validate input value.""" + + if type(value) != int: + raise ValidationError("Value {input} is not a valid natural number") + super().__call__(value) + + class JWSHeaderKid(Regexp): """Validate value against JWS header kid.""" @@ -205,6 +244,33 @@ def __init__(self): ) +class IndyWQL(Regexp): # using Regexp brings in nice visual validator cue + """Validate value as potential WQL query.""" + + EXAMPLE = json.dumps({"name": "Alex"}) + PATTERN = r"^{.*}$" + + def __init__(self): + """Initializer.""" + + super().__init__( + IndyWQL.PATTERN, error="Value {input} is not a valid WQL query", + ) + + def __call__(self, value): + """Validate input value.""" + + super().__call__(value or "") + message = "Value {input} is not a valid WQL query".format(input=value) + + try: + json.loads(value) + except Exception: + raise ValidationError(message) + + return value + + class Base64(Regexp): """Validate base64 value.""" @@ -299,8 +365,29 @@ def __init__(self): ) +class Endpoint(Regexp): # using Regexp brings in nice visual validator cue + """Validate value against endpoint URL on any scheme.""" + + EXAMPLE = "https://myhost:8021" + PATTERN = ( + r"^[A-Za-z0-9\.\-\+]+:" # scheme + r"//([A-Za-z0-9][.A-Za-z0-9-]+[A-Za-z0-9])+" # host + r"(:[1-9][0-9]*)?" # port + r"(/[^?&#]+)?$" # path + ) + + def __init__(self): + """Initializer.""" + + super().__init__( + Endpoint.PATTERN, error="Value {input} is not a valid endpoint", + ) + + # Instances for marshmallow schema specification INT_EPOCH = {"validate": IntEpoch(), "example": IntEpoch.EXAMPLE} +WHOLE_NUM = {"validate": WholeNumber(), "example": WholeNumber.EXAMPLE} +NATURAL_NUM = {"validate": NaturalNumber(), "example": NaturalNumber.EXAMPLE} JWS_HEADER_KID = {"validate": JWSHeaderKid(), "example": JWSHeaderKid.EXAMPLE} JWT = {"validate": JSONWebToken(), "example": JSONWebToken.EXAMPLE} DID_KEY = {"validate": DIDKey(), "example": DIDKey.EXAMPLE} @@ -318,13 +405,14 @@ def __init__(self): "validate": IndyISO8601DateTime(), "example": IndyISO8601DateTime.EXAMPLE, } +INDY_WQL = {"validate": IndyWQL(), "example": IndyWQL.EXAMPLE} BASE64 = {"validate": Base64(), "example": Base64.EXAMPLE} BASE64URL = {"validate": Base64URL(), "example": Base64URL.EXAMPLE} BASE64URL_NO_PAD = {"validate": Base64URLNoPad(), "example": Base64URLNoPad.EXAMPLE} - SHA256 = {"validate": SHA256Hash(), "example": SHA256Hash.EXAMPLE} BASE58_SHA256_HASH = { "validate": Base58SHA256Hash(), "example": Base58SHA256Hash.EXAMPLE, } UUID4 = {"validate": UUIDFour(), "example": UUIDFour.EXAMPLE} +ENDPOINT = {"validate": Endpoint(), "example": Endpoint.EXAMPLE} diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py index fad1a86e1f..18273b9882 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py @@ -60,7 +60,6 @@ class Meta: fields.Nested(MenuFormParamSchema()), required=False, description="List of form parameters", - example="[alpha, x_offset, y_offset, height, width, bgcolor, fgcolor]", ) submit_label = fields.Str( required=False, diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/routes.py b/aries_cloudagent/protocols/actionmenu/v1_0/routes.py index 374f114c4e..257a6e42fc 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/routes.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/routes.py @@ -3,7 +3,7 @@ import logging from aiohttp import web -from aiohttp_apispec import docs, request_schema +from aiohttp_apispec import docs, match_info_schema, request_schema from marshmallow import fields, Schema @@ -61,9 +61,18 @@ class SendMenuSchema(Schema): ) +class ConnIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + @docs( tags=["action-menu"], summary="Close the active menu associated with a connection" ) +@match_info_schema(ConnIdMatchInfoSchema()) async def actionmenu_close(request: web.BaseRequest): """ Request handler for closing the menu associated with a connection. @@ -73,7 +82,7 @@ async def actionmenu_close(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] menu = await retrieve_connection_menu(connection_id, context) if not menu: @@ -84,6 +93,7 @@ async def actionmenu_close(request: web.BaseRequest): @docs(tags=["action-menu"], summary="Fetch the active menu") +@match_info_schema(ConnIdMatchInfoSchema()) async def actionmenu_fetch(request: web.BaseRequest): """ Request handler for fetching the previously-received menu for a connection. @@ -93,7 +103,7 @@ async def actionmenu_fetch(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] menu = await retrieve_connection_menu(connection_id, context) result = {"result": menu.serialize() if menu else None} @@ -101,6 +111,7 @@ async def actionmenu_fetch(request: web.BaseRequest): @docs(tags=["action-menu"], summary="Perform an action associated with the active menu") +@match_info_schema(ConnIdMatchInfoSchema()) @request_schema(PerformRequestSchema()) async def actionmenu_perform(request: web.BaseRequest): """ @@ -111,7 +122,7 @@ async def actionmenu_perform(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] params = await request.json() @@ -129,6 +140,7 @@ async def actionmenu_perform(request: web.BaseRequest): @docs(tags=["action-menu"], summary="Request the active menu") +@match_info_schema(ConnIdMatchInfoSchema()) async def actionmenu_request(request: web.BaseRequest): """ Request handler for requesting a menu from the connection target. @@ -138,7 +150,7 @@ async def actionmenu_request(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] try: @@ -156,6 +168,7 @@ async def actionmenu_request(request: web.BaseRequest): @docs(tags=["action-menu"], summary="Send an action menu to a connection") +@match_info_schema(ConnIdMatchInfoSchema()) @request_schema(SendMenuSchema()) async def actionmenu_send(request: web.BaseRequest): """ @@ -166,7 +179,7 @@ async def actionmenu_send(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] menu_json = await request.json() LOGGER.debug("Received send-menu request: %s %s", connection_id, menu_json) @@ -196,10 +209,10 @@ async def register(app: web.Application): app.add_routes( [ - web.post("/action-menu/{id}/close", actionmenu_close), - web.post("/action-menu/{id}/fetch", actionmenu_fetch), - web.post("/action-menu/{id}/perform", actionmenu_perform), - web.post("/action-menu/{id}/request", actionmenu_request), - web.post("/connections/{id}/send-menu", actionmenu_send), + web.post("/action-menu/{conn_id}/close", actionmenu_close), + web.post("/action-menu/{conn_id}/fetch", actionmenu_fetch), + web.post("/action-menu/{conn_id}/perform", actionmenu_perform), + web.post("/action-menu/{conn_id}/request", actionmenu_request), + web.post("/connections/{conn_id}/send-menu", actionmenu_send), ] ) diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/actionmenu/v1_0/tests/test_routes.py index 12fdefeb76..652ccd7c99 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/tests/test_routes.py @@ -56,7 +56,7 @@ async def test_actionmenu_fetch(self): mock_response.assert_called_once_with({"result": None}) async def test_actionmenu_perform(self): - mock_request = async_mock.MagicMock() + mock_request = async_mock.MagicMock(match_info={"conn_id": "dummy"}) mock_request.json = async_mock.CoroutineMock() mock_request.app = { @@ -77,7 +77,8 @@ async def test_actionmenu_perform(self): res = await test_module.actionmenu_perform(mock_request) mock_response.assert_called_once_with({}) mock_request.app["outbound_message_router"].assert_called_once_with( - mock_perform.return_value, connection_id=mock_request.match_info["id"] + mock_perform.return_value, + connection_id=mock_request.match_info["conn_id"], ) async def test_actionmenu_perform_no_conn_record(self): @@ -126,7 +127,7 @@ async def test_actionmenu_perform_conn_not_ready(self): await test_module.actionmenu_perform(mock_request) async def test_actionmenu_request(self): - mock_request = async_mock.MagicMock() + mock_request = async_mock.MagicMock(match_info={"conn_id": "dummy"}) mock_request.json = async_mock.CoroutineMock() mock_request.app = { @@ -147,7 +148,8 @@ async def test_actionmenu_request(self): res = await test_module.actionmenu_request(mock_request) mock_response.assert_called_once_with({}) mock_request.app["outbound_message_router"].assert_called_once_with( - menu_request.return_value, connection_id=mock_request.match_info["id"] + menu_request.return_value, + connection_id=mock_request.match_info["conn_id"], ) async def test_actionmenu_request_no_conn_record(self): @@ -196,7 +198,7 @@ async def test_actionmenu_request_conn_not_ready(self): await test_module.actionmenu_request(mock_request) async def test_actionmenu_send(self): - mock_request = async_mock.MagicMock() + mock_request = async_mock.MagicMock(match_info={"conn_id": "dummy"}) mock_request.json = async_mock.CoroutineMock() mock_request.app = { @@ -219,7 +221,7 @@ async def test_actionmenu_send(self): mock_response.assert_called_once_with({}) mock_request.app["outbound_message_router"].assert_called_once_with( mock_menu.deserialize.return_value, - connection_id=mock_request.match_info["id"], + connection_id=mock_request.match_info["conn_id"], ) async def test_actionmenu_send_deserialize_x(self): diff --git a/aries_cloudagent/protocols/basicmessage/v1_0/routes.py b/aries_cloudagent/protocols/basicmessage/v1_0/routes.py index 112196b9da..3484413f51 100644 --- a/aries_cloudagent/protocols/basicmessage/v1_0/routes.py +++ b/aries_cloudagent/protocols/basicmessage/v1_0/routes.py @@ -1,11 +1,12 @@ """Basic message admin routes.""" from aiohttp import web -from aiohttp_apispec import docs, request_schema +from aiohttp_apispec import docs, match_info_schema, request_schema from marshmallow import fields, Schema from aries_cloudagent.connections.models.connection_record import ConnectionRecord +from aries_cloudagent.messaging.valid import UUIDFour from aries_cloudagent.storage.error import StorageNotFoundError from .messages.basicmessage import BasicMessage @@ -17,7 +18,16 @@ class SendMessageSchema(Schema): content = fields.Str(description="Message content", example="Hello") +class ConnIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + @docs(tags=["basicmessage"], summary="Send a basic message to a connection") +@match_info_schema(ConnIdMatchInfoSchema()) @request_schema(SendMessageSchema()) async def connections_send_message(request: web.BaseRequest): """ @@ -28,7 +38,7 @@ async def connections_send_message(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] params = await request.json() @@ -48,5 +58,5 @@ async def register(app: web.Application): """Register routes.""" app.add_routes( - [web.post("/connections/{id}/send-message", connections_send_message)] + [web.post("/connections/{conn_id}/send-message", connections_send_message)] ) diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index 6c3287aba3..7afb30e136 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -73,7 +73,7 @@ async def create_invitation( my_label: str = None, my_endpoint: str = None, their_role: str = None, - accept: str = None, + auto_accept: bool = None, public: bool = False, multi_use: bool = False, alias: str = None, @@ -115,8 +115,9 @@ async def create_invitation( my_label: label for this connection my_endpoint: endpoint where other party can reach me their_role: a role to assign the connection - accept: set to 'auto' to auto-accept a corresponding connection request - public: set to True to create an invitation from the public DID + auto_accept: auto-accept a corresponding connection request + (None to use config) + public: set to create an invitation from the public DID multi_use: set to True to create an invitation for multiple use alias: optional alias to apply to connection for later use @@ -155,8 +156,17 @@ async def create_invitation( if not my_endpoint: my_endpoint = self.context.settings.get("default_endpoint") - if not accept and self.context.settings.get("debug.auto_accept_requests"): - accept = ConnectionRecord.ACCEPT_AUTO + accept = ( + ConnectionRecord.ACCEPT_AUTO + if ( + auto_accept + or ( + auto_accept is None + and self.context.settings.get("debug.auto_accept_requests") + ) + ) + else ConnectionRecord.ACCEPT_MANUAL + ) # Create and store new invitation key connection_key = await wallet.create_signing_key() @@ -188,7 +198,7 @@ async def receive_invitation( self, invitation: ConnectionInvitation, their_role: str = None, - accept: str = None, + auto_accept: bool = None, alias: str = None, ) -> ConnectionRecord: """ @@ -197,7 +207,7 @@ async def receive_invitation( Args: invitation: The `ConnectionInvitation` to store their_role: The role assigned to this connection - accept: set to 'auto' to auto-accept the invitation + auto_accept: set to auto-accept the invitation (None to use config) alias: optional alias to set on the record Returns: @@ -210,8 +220,17 @@ async def receive_invitation( if not invitation.endpoint: raise ConnectionManagerError("Invitation must contain an endpoint") - if accept is None and self.context.settings.get("debug.auto_accept_invites"): - accept = ConnectionRecord.ACCEPT_AUTO + accept = ( + ConnectionRecord.ACCEPT_AUTO + if ( + auto_accept + or ( + auto_accept is None + and self.context.settings.get("debug.auto_accept_invites") + ) + ) + else ConnectionRecord.ACCEPT_MANUAL + ) # Create connection record connection = ConnectionRecord( diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py index 556c510c98..5059dee06f 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py @@ -11,6 +11,7 @@ from ..message_types import CONNECTION_INVITATION, PROTOCOL_PACKAGE + HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers" ".connection_invitation_handler.ConnectionInvitationHandler" diff --git a/aries_cloudagent/protocols/connections/v1_0/routes.py b/aries_cloudagent/protocols/connections/v1_0/routes.py index 721565a0b7..c2fb341793 100644 --- a/aries_cloudagent/protocols/connections/v1_0/routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/routes.py @@ -1,15 +1,28 @@ """Connection handling admin routes.""" +import json + from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) -from marshmallow import fields, Schema +from marshmallow import fields, Schema, validate, validates_schema from aries_cloudagent.connections.models.connection_record import ( ConnectionRecord, ConnectionRecordSchema, ) -from aries_cloudagent.messaging.valid import IndyDID, UUIDFour +from aries_cloudagent.messaging.valid import ( + ENDPOINT, + INDY_DID, + INDY_RAW_PUBLIC_KEY, + UUIDFour, +) from aries_cloudagent.storage.error import StorageNotFoundError from .manager import ConnectionManager @@ -28,6 +41,14 @@ class ConnectionListSchema(Schema): ) +class ReceiveInvitationRequestSchema(ConnectionInvitationSchema): + """Request schema for receive invitation request.""" + + @validates_schema + def validate_fields(self, data, **kwargs): + """Bypass middleware field validation.""" + + class InvitationResultSchema(Schema): """Result schema for a new connection invitation.""" @@ -45,20 +66,14 @@ class ConnectionStaticRequestSchema(Schema): """Request schema for a new static connection.""" my_seed = fields.Str(description="Seed to use for the local DID", required=False) - my_did = fields.Str( - description="Local DID", required=False, example=IndyDID.EXAMPLE - ) + my_did = fields.Str(description="Local DID", required=False, **INDY_DID) their_seed = fields.Str( description="Seed to use for the remote DID", required=False ) - their_did = fields.Str( - description="Remote DID", required=False, example=IndyDID.EXAMPLE - ) + their_did = fields.Str(description="Remote DID", required=False, **INDY_DID) their_verkey = fields.Str(description="Remote verification key", required=False) their_endpoint = fields.Str( - description="URL endpoint for the other party", - required=False, - example="http://192.168.56.101:5000", + description="URL endpoint for the other party", required=False, **ENDPOINT ) their_role = fields.Str( description="Role to assign to this connection", required=False @@ -72,16 +87,113 @@ class ConnectionStaticRequestSchema(Schema): class ConnectionStaticResultSchema(Schema): """Result schema for new static connection.""" - my_did = fields.Str(description="Local DID", required=True, example=IndyDID.EXAMPLE) - mv_verkey = fields.Str(description="My verification key", required=True) - my_endpoint = fields.Str(description="My endpoint", required=True) - their_did = fields.Str( - description="Remote DID", required=True, example=IndyDID.EXAMPLE + my_did = fields.Str(description="Local DID", required=True, **INDY_DID) + mv_verkey = fields.Str( + description="My verification key", required=True, **INDY_RAW_PUBLIC_KEY + ) + my_endpoint = fields.Str(description="My URL endpoint", required=True, **ENDPOINT) + their_did = fields.Str(description="Remote DID", required=True, **INDY_DID) + their_verkey = fields.Str( + description="Remote verification key", required=True, **INDY_RAW_PUBLIC_KEY ) - their_verkey = fields.Str(description="Remote verification key", required=True) record = fields.Nested(ConnectionRecordSchema, required=True) +class ConnectionsListQueryStringSchema(Schema): + """Parameters and validators for connections list request query string.""" + + alias = fields.Str(description="Alias", required=False, example="Barry",) + initiator = fields.Str( + description="Connection initiator", + required=False, + validate=validate.OneOf(["self", "external"]), + ) + invitation_key = fields.Str( + description="invitation key", required=False, **INDY_RAW_PUBLIC_KEY + ) + my_did = fields.Str(description="My DID", required=False, **INDY_DID) + state = fields.Str( + description="Connection state", + required=False, + validate=validate.OneOf( + [ + getattr(ConnectionRecord, m) + for m in vars(ConnectionRecord) + if m.startswith("STATE_") + ] + ), + ) + their_did = fields.Str(description="Their DID", required=False, **INDY_DID) + their_role = fields.Str( + description="Their assigned connection role", + required=False, + example="Interlocutor", + ) + + +class CreateInvitationQueryStringSchema(Schema): + """Parameters and validators for create invitation request query string.""" + + alias = fields.Str(description="Alias", required=False, example="Barry",) + auto_accept = fields.Boolean( + description="Auto-accept connection (default as per configuration)", + required=False, + ) + public = fields.Boolean( + description="Create invitation from public DID (default false)", required=False + ) + multi_use = fields.Boolean( + description="Create invitation for multiple use (default false)", required=False + ) + + +class ReceiveInvitationQueryStringSchema(Schema): + """Parameters and validators for receive invitation request query string.""" + + alias = fields.Str(description="Alias", required=False, example="Barry",) + auto_accept = fields.Boolean( + description="Auto-accept connection (defaults to configuration)", + required=False, + ) + + +class AcceptInvitationQueryStringSchema(Schema): + """Parameters and validators for accept invitation request query string.""" + + my_endpoint = fields.Str(description="My URL endpoint", required=False, **ENDPOINT) + my_label = fields.Str( + description="Label for connection", required=False, example="Broker" + ) + + +class AcceptRequestQueryStringSchema(Schema): + """Parameters and validators for accept conn-request web-request query string.""" + + my_endpoint = fields.Str(description="My URL endpoint", required=False, **ENDPOINT) + + +class ConnIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + +class ConnIdRefIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection and ref ids.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + ref_id = fields.Str( + description="Inbound connection identifier", + required=True, + example=UUIDFour.EXAMPLE, + ) + + def connection_sort_key(conn): """Get the sorting key for a particular connection.""" if conn["state"] == ConnectionRecord.STATE_INACTIVE: @@ -94,64 +206,9 @@ def connection_sort_key(conn): @docs( - tags=["connection"], - summary="Query agent-to-agent connections", - parameters=[ - { - "name": "alias", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "initiator", - "in": "query", - "schema": {"type": "string", "enum": ["self", "external"]}, - "required": False, - }, - { - "name": "invitation_key", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "my_did", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "state", - "in": "query", - "schema": { - "type": "string", - "enum": [ - "init", - "invitation", - "request", - "response", - "active", - "error", - "inactive", - ], - }, - "required": False, - }, - { - "name": "their_did", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "their_role", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - ], + tags=["connection"], summary="Query agent-to-agent connections", ) +@querystring_schema(ConnectionsListQueryStringSchema()) @response_schema(ConnectionListSchema(), 200) async def connections_list(request: web.BaseRequest): """ @@ -190,6 +247,7 @@ async def connections_list(request: web.BaseRequest): @docs(tags=["connection"], summary="Fetch a single connection record") +@match_info_schema(ConnIdMatchInfoSchema()) @response_schema(ConnectionRecordSchema(), 200) async def connections_retrieve(request: web.BaseRequest): """ @@ -203,7 +261,7 @@ async def connections_retrieve(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] try: record = await ConnectionRecord.retrieve_by_id(context, connection_id) except StorageNotFoundError: @@ -212,30 +270,9 @@ async def connections_retrieve(request: web.BaseRequest): @docs( - tags=["connection"], - summary="Create a new connection invitation", - parameters=[ - { - "name": "alias", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "accept", - "in": "query", - "schema": {"type": "string", "enum": ["none", "auto"]}, - "required": False, - }, - {"name": "public", "in": "query", "schema": {"type": "int"}, "required": False}, - { - "name": "multi_use", - "in": "query", - "schema": {"type": "int"}, - "required": False, - }, - ], + tags=["connection"], summary="Create a new connection invitation", ) +@querystring_schema(CreateInvitationQueryStringSchema()) @response_schema(InvitationResultSchema(), 200) async def connections_create_invitation(request: web.BaseRequest): """ @@ -249,18 +286,18 @@ async def connections_create_invitation(request: web.BaseRequest): """ context = request.app["request_context"] - accept = request.query.get("accept") + auto_accept = json.loads(request.query.get("auto_accept", "null")) alias = request.query.get("alias") - public = request.query.get("public") - multi_use = request.query.get("multi_use") + public = json.loads(request.query.get("public", "false")) + multi_use = json.loads(request.query.get("multi_use", "false")) if public and not context.settings.get("public_invites"): raise web.HTTPForbidden() base_url = context.settings.get("invite_base_url") connection_mgr = ConnectionManager(context) - connection, invitation = await connection_mgr.create_invitation( - accept=accept, public=bool(public), multi_use=bool(multi_use), alias=alias + (connection, invitation) = await connection_mgr.create_invitation( + auto_accept=auto_accept, public=public, multi_use=multi_use, alias=alias ) result = { "connection_id": connection and connection.connection_id, @@ -275,24 +312,10 @@ async def connections_create_invitation(request: web.BaseRequest): @docs( - tags=["connection"], - summary="Receive a new connection invitation", - parameters=[ - { - "name": "alias", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "accept", - "in": "query", - "schema": {"type": "string", "enum": ["none", "auto"]}, - "required": False, - }, - ], + tags=["connection"], summary="Receive a new connection invitation", ) -@request_schema(ConnectionInvitationSchema()) +@querystring_schema(ReceiveInvitationQueryStringSchema()) +@request_schema(ReceiveInvitationRequestSchema()) @response_schema(ConnectionRecordSchema(), 200) async def connections_receive_invitation(request: web.BaseRequest): """ @@ -311,32 +334,20 @@ async def connections_receive_invitation(request: web.BaseRequest): connection_mgr = ConnectionManager(context) invitation_json = await request.json() invitation = ConnectionInvitation.deserialize(invitation_json) - accept = request.query.get("accept") + auto_accept = json.loads(request.query.get("auto_accept", "null")) alias = request.query.get("alias") + connection = await connection_mgr.receive_invitation( - invitation, accept=accept, alias=alias + invitation, auto_accept=auto_accept, alias=alias ) return web.json_response(connection.serialize()) @docs( - tags=["connection"], - summary="Accept a stored connection invitation", - parameters=[ - { - "name": "my_endpoint", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "my_label", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - ], + tags=["connection"], summary="Accept a stored connection invitation", ) +@match_info_schema(ConnIdMatchInfoSchema()) +@querystring_schema(AcceptInvitationQueryStringSchema()) @response_schema(ConnectionRecordSchema(), 200) async def connections_accept_invitation(request: web.BaseRequest): """ @@ -351,7 +362,7 @@ async def connections_accept_invitation(request: web.BaseRequest): """ context = request.app["request_context"] outbound_handler = request.app["outbound_message_router"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] try: connection = await ConnectionRecord.retrieve_by_id(context, connection_id) except StorageNotFoundError: @@ -365,17 +376,10 @@ async def connections_accept_invitation(request: web.BaseRequest): @docs( - tags=["connection"], - summary="Accept a stored connection request", - parameters=[ - { - "name": "my_endpoint", - "in": "query", - "schema": {"type": "string"}, - "required": False, - } - ], + tags=["connection"], summary="Accept a stored connection request", ) +@match_info_schema(ConnIdMatchInfoSchema()) +@querystring_schema(AcceptRequestQueryStringSchema()) @response_schema(ConnectionRecordSchema(), 200) async def connections_accept_request(request: web.BaseRequest): """ @@ -390,7 +394,7 @@ async def connections_accept_request(request: web.BaseRequest): """ context = request.app["request_context"] outbound_handler = request.app["outbound_message_router"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] try: connection = await ConnectionRecord.retrieve_by_id(context, connection_id) except StorageNotFoundError: @@ -405,6 +409,7 @@ async def connections_accept_request(request: web.BaseRequest): @docs( tags=["connection"], summary="Assign another connection as the inbound connection" ) +@match_info_schema(ConnIdRefIdMatchInfoSchema()) async def connections_establish_inbound(request: web.BaseRequest): """ Request handler for setting the inbound connection on a connection record. @@ -413,7 +418,7 @@ async def connections_establish_inbound(request: web.BaseRequest): request: aiohttp request object """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] inbound_connection_id = request.match_info["ref_id"] try: @@ -428,6 +433,7 @@ async def connections_establish_inbound(request: web.BaseRequest): @docs(tags=["connection"], summary="Remove an existing connection record") +@match_info_schema(ConnIdMatchInfoSchema()) async def connections_remove(request: web.BaseRequest): """ Request handler for removing a connection record. @@ -436,7 +442,7 @@ async def connections_remove(request: web.BaseRequest): request: aiohttp request object """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] try: connection = await ConnectionRecord.retrieve_by_id(context, connection_id) except StorageNotFoundError: @@ -491,19 +497,22 @@ async def register(app: web.Application): app.add_routes( [ - web.get("/connections", connections_list), - web.get("/connections/{id}", connections_retrieve), + web.get("/connections", connections_list, allow_head=False), + web.get("/connections/{conn_id}", connections_retrieve, allow_head=False), web.post("/connections/create-static", connections_create_static), web.post("/connections/create-invitation", connections_create_invitation), web.post("/connections/receive-invitation", connections_receive_invitation), web.post( - "/connections/{id}/accept-invitation", connections_accept_invitation + "/connections/{conn_id}/accept-invitation", + connections_accept_invitation, + ), + web.post( + "/connections/{conn_id}/accept-request", connections_accept_request ), - web.post("/connections/{id}/accept-request", connections_accept_request), web.post( - "/connections/{id}/establish-inbound/{ref_id}", + "/connections/{conn_id}/establish-inbound/{ref_id}", connections_establish_inbound, ), - web.post("/connections/{id}/remove", connections_remove), + web.post("/connections/{conn_id}/remove", connections_remove), ] ) diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index ab2c6a6905..354d2eaaff 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -210,7 +210,7 @@ async def test_receive_invitation_no_auto_accept(self): ) invitee_record = await self.manager.receive_invitation( - connect_invite, accept=ConnectionRecord.ACCEPT_MANUAL + connect_invite, auto_accept=False ) assert invitee_record.state == ConnectionRecord.STATE_INVITATION diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py index 266fcdf6d9..a55fd33db5 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py @@ -80,7 +80,7 @@ async def test_connections_retrieve(self): mock_req.app = { "request_context": context, } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_conn_rec = async_mock.MagicMock() mock_conn_rec.serialize = async_mock.MagicMock(return_value={"hello": "world"}) @@ -102,7 +102,7 @@ async def test_connections_retrieve_not_found(self): mock_req.app = { "request_context": context, } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} with async_mock.patch.object( test_module.ConnectionRecord, "retrieve_by_id", async_mock.CoroutineMock() @@ -120,10 +120,10 @@ async def test_connections_create_invitation(self): "request_context": context, } mock_req.query = { - "accept": "auto", + "auto_accept": "true", "alias": "alias", - "public": 1, - "multi_use": 1, + "public": "true", + "multi_use": "true", } with async_mock.patch.object( @@ -162,10 +162,10 @@ async def test_connections_create_invitation_public_forbidden(self): "request_context": context, } mock_req.query = { - "accept": "auto", + "auto_accept": "true", "alias": "alias", - "public": 1, - "multi_use": 1, + "public": "true", + "multi_use": "true", } with self.assertRaises(test_module.web.HTTPForbidden): @@ -179,7 +179,7 @@ async def test_connections_receive_invitation(self): } mock_req.json = async_mock.CoroutineMock() mock_req.query = { - "accept": "auto", + "auto_accept": "true", "alias": "alias", } @@ -215,7 +215,7 @@ async def test_connections_accept_invitation(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_req.query = { "my_label": "label", "my_endpoint": "http://endpoint.ca", @@ -245,7 +245,7 @@ async def test_connections_accept_invitation_not_found(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} with async_mock.patch.object( test_module.ConnectionRecord, "retrieve_by_id", async_mock.CoroutineMock() @@ -262,7 +262,7 @@ async def test_connections_accept_request(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_req.query = { "my_endpoint": "http://endpoint.ca", } @@ -290,7 +290,7 @@ async def test_connections_accept_request_not_found(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} with async_mock.patch.object( test_module.ConnectionRecord, "retrieve_by_id", async_mock.CoroutineMock() @@ -307,7 +307,7 @@ async def test_connections_establish_inbound(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy", "ref_id": "ref"} + mock_req.match_info = {"conn_id": "dummy", "ref_id": "ref"} mock_req.query = { "my_endpoint": "http://endpoint.ca", } @@ -335,7 +335,7 @@ async def test_connections_establish_inbound_not_found(self): "request_context": context, "outbound_message_router": async_mock.CoroutineMock(), } - mock_req.match_info = {"id": "dummy", "ref_id": "ref"} + mock_req.match_info = {"conn_id": "dummy", "ref_id": "ref"} with async_mock.patch.object( test_module.ConnectionRecord, "retrieve_by_id", async_mock.CoroutineMock() @@ -351,7 +351,7 @@ async def test_connections_remove(self): mock_req.app = { "request_context": context, } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_conn_rec = async_mock.MagicMock() mock_conn_rec.delete_record = async_mock.CoroutineMock() @@ -372,7 +372,7 @@ async def test_connections_remove_not_found(self): mock_req.app = { "request_context": context, } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_conn_rec = async_mock.MagicMock() @@ -403,10 +403,10 @@ async def test_connections_create_static(self): } ) mock_req.query = { - "accept": "auto", + "auto_accept": "true", "alias": "alias", } - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_conn_rec = async_mock.MagicMock() mock_conn_rec.serialize = async_mock.MagicMock() diff --git a/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py b/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py index ee826c76be..62dffef3ff 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py +++ b/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py @@ -2,8 +2,7 @@ from typing import Mapping, Sequence -from marshmallow import fields, Schema -from marshmallow.validate import OneOf +from marshmallow import fields, Schema, validate from aries_cloudagent.messaging.agent_message import AgentMessage, AgentMessageSchema @@ -41,7 +40,7 @@ class ProtocolDescriptorSchema(Schema): fields.Str( description="Role: requester or responder", example="requester", - validate=OneOf(["requester", "responder"]), + validate=validate.OneOf(["requester", "responder"]), ), required=False, allow_none=True, diff --git a/aries_cloudagent/protocols/discovery/v1_0/routes.py b/aries_cloudagent/protocols/discovery/v1_0/routes.py index ca3ca85048..35465afc65 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/routes.py +++ b/aries_cloudagent/protocols/discovery/v1_0/routes.py @@ -1,7 +1,7 @@ """Feature discovery admin routes.""" from aiohttp import web -from aiohttp_apispec import docs, response_schema +from aiohttp_apispec import docs, querystring_schema, response_schema from marshmallow import fields, Schema @@ -18,18 +18,16 @@ class QueryResultSchema(Schema): ) +class QueryFeaturesQueryStringSchema(Schema): + """Query string parameters for feature query.""" + + query = fields.Str(description="Query", required=False, example="did:sov:*") + + @docs( - tags=["server"], - summary="Query supported features", - parameters=[ - { - "name": "query", - "in": "query", - "schema": {"type": "string"}, - "required": False, - } - ], + tags=["server"], summary="Query supported features", ) +@querystring_schema(QueryFeaturesQueryStringSchema()) @response_schema(QueryResultSchema(), 200) async def query_features(request: web.BaseRequest): """ @@ -52,4 +50,4 @@ async def query_features(request: web.BaseRequest): async def register(app: web.Application): """Register routes.""" - app.add_routes([web.get("/features", query_features)]) + app.add_routes([web.get("/features", query_features, allow_head=False)]) diff --git a/aries_cloudagent/protocols/introduction/v0_1/routes.py b/aries_cloudagent/protocols/introduction/v0_1/routes.py index f527db26a3..aa0325905c 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/routes.py +++ b/aries_cloudagent/protocols/introduction/v0_1/routes.py @@ -3,31 +3,43 @@ import logging from aiohttp import web -from aiohttp_apispec import docs +from aiohttp_apispec import docs, match_info_schema, querystring_schema + +from marshmallow import fields, Schema + +from ....messaging.valid import UUIDFour from .base_service import BaseIntroductionService LOGGER = logging.getLogger(__name__) +class IntroStartQueryStringSchema(Schema): + """Query string parameters for request to start introduction.""" + + target_connection_id = fields.Str( + description="Target connection identifier", + required=True, + example=UUIDFour.EXAMPLE, + ) + message = fields.Str( + description="Message", required=False, example="Allow me to introduce ..." + ) + + +class ConnIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + @docs( - tags=["introduction"], - summary="Start an introduction between two connections", - parameters=[ - { - "name": "target_connection_id", - "in": "query", - "schema": {"type": "string"}, - "required": True, - }, - { - "name": "message", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - ], + tags=["introduction"], summary="Start an introduction between two connections", ) +@match_info_schema(ConnIdMatchInfoSchema()) +@querystring_schema(IntroStartQueryStringSchema()) async def introduction_start(request: web.BaseRequest): """ Request handler for starting an introduction. @@ -39,7 +51,7 @@ async def introduction_start(request: web.BaseRequest): LOGGER.info("Introduction requested") context = request.app["request_context"] outbound_handler = request.app["outbound_message_router"] - init_connection_id = request.match_info["id"] + init_connection_id = request.match_info["conn_id"] target_connection_id = request.query.get("target_connection_id") message = request.query.get("message") @@ -59,5 +71,5 @@ async def register(app: web.Application): """Register routes.""" app.add_routes( - [web.post("/connections/{id}/start-introduction", introduction_start)] + [web.post("/connections/{conn_id}/start-introduction", introduction_start)] ) diff --git a/aries_cloudagent/protocols/introduction/v0_1/tests/test_routes.py b/aries_cloudagent/protocols/introduction/v0_1/tests/test_routes.py index 640ca760c1..c60234abd1 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/tests/test_routes.py +++ b/aries_cloudagent/protocols/introduction/v0_1/tests/test_routes.py @@ -28,7 +28,7 @@ async def test_introduction_start_no_service(self): "alias": "alias", } ) - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_req.query = { "target_connection_id": "dummy", "message": "Hello", @@ -54,7 +54,7 @@ async def test_introduction_start(self): "alias": "alias", } ) - mock_req.match_info = {"id": "dummy"} + mock_req.match_info = {"conn_id": "dummy"} mock_req.query = { "target_connection_id": "dummy", "message": "Hello", @@ -74,7 +74,7 @@ async def test_introduction_start(self): await test_module.introduction_start(mock_req) mock_ctx_inject.return_value.start_introduction.assert_called_once_with( - mock_req.match_info["id"], + mock_req.match_info["conn_id"], mock_req.query["target_connection_id"], mock_req.query["message"], mock_req.app["outbound_message_router"], diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/models/credential_exchange.py b/aries_cloudagent/protocols/issue_credential/v1_0/models/credential_exchange.py index 3d73417e76..9c7082e0e6 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/models/credential_exchange.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/models/credential_exchange.py @@ -2,8 +2,7 @@ from typing import Any -from marshmallow import fields -from marshmallow.validate import OneOf +from marshmallow import fields, validate from .....config.injection_context import InjectionContext from .....messaging.models.base_record import BaseExchangeRecord, BaseExchangeSchema @@ -174,13 +173,13 @@ class Meta: required=False, description="Issue-credential exchange initiator: self or external", example=V10CredentialExchange.INITIATOR_SELF, - validate=OneOf(["self", "external"]), + validate=validate.OneOf(["self", "external"]), ) role = fields.Str( required=False, description="Issue-credential exchange role: holder or issuer", example=V10CredentialExchange.ROLE_ISSUER, - validate=OneOf(["holder", "issuer"]), + validate=validate.OneOf(["holder", "issuer"]), ) state = fields.Str( required=False, diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/routes.py b/aries_cloudagent/protocols/issue_credential/v1_0/routes.py index 0571bed8ea..0095e03f4a 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/routes.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/routes.py @@ -3,7 +3,13 @@ import json from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) from json.decoder import JSONDecodeError from marshmallow import fields, Schema @@ -17,7 +23,9 @@ INDY_REV_REG_ID, INDY_SCHEMA_ID, INDY_VERSION, + NATURAL_NUM, UUIDFour, + UUID4, ) from ....storage.error import StorageNotFoundError @@ -164,7 +172,42 @@ class V10PublishRevocationsResultSchema(Schema): ) +class RevokeQueryStringSchema(Schema): + """Parameters and validators for revocation request.""" + + rev_reg_id = fields.Str( + description="Revocation registry identifier", required=True, **INDY_REV_REG_ID, + ) + cred_rev_id = fields.Int( + description="Credential revocation identifier", required=True, **NATURAL_NUM, + ) + publish = fields.Boolean( + description=( + "(True) publish revocation to ledger immediately, or " + "(False) mark it pending (default value)" + ), + required=False, + ) + + +class CredIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking credential id.""" + + credential_id = fields.Str( + description="Credential identifier", required=True, example=UUIDFour.EXAMPLE + ) + + +class CredExIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking credential exchange id.""" + + cred_ex_id = fields.Str( + description="Credential exchange identifier", required=True, **UUID4 + ) + + @docs(tags=["issue-credential"], summary="Get attribute MIME types from wallet") +@match_info_schema(CredIdMatchInfoSchema()) @response_schema(V10AttributeMimeTypesResultSchema(), 200) async def attribute_mime_types_get(request: web.BaseRequest): """ @@ -210,6 +253,7 @@ async def credential_exchange_list(request: web.BaseRequest): @docs(tags=["issue-credential"], summary="Fetch a single credential exchange record") +@match_info_schema(CredExIdMatchInfoSchema()) @response_schema(V10CredentialExchangeSchema(), 200) async def credential_exchange_retrieve(request: web.BaseRequest): """ @@ -485,6 +529,7 @@ async def credential_exchange_send_free_offer(request: web.BaseRequest): tags=["issue-credential"], summary="Send holder a credential offer in reference to a proposal with preview", ) +@match_info_schema(CredExIdMatchInfoSchema()) @response_schema(V10CredentialExchangeSchema(), 200) async def credential_exchange_send_bound_offer(request: web.BaseRequest): """ @@ -544,6 +589,7 @@ async def credential_exchange_send_bound_offer(request: web.BaseRequest): @docs(tags=["issue-credential"], summary="Send issuer a credential request") +@match_info_schema(CredExIdMatchInfoSchema()) @response_schema(V10CredentialExchangeSchema(), 200) async def credential_exchange_send_request(request: web.BaseRequest): """ @@ -603,6 +649,7 @@ async def credential_exchange_send_request(request: web.BaseRequest): @docs(tags=["issue-credential"], summary="Send holder a credential") +@match_info_schema(CredExIdMatchInfoSchema()) @request_schema(V10CredentialIssueRequestSchema()) @response_schema(V10CredentialExchangeSchema(), 200) async def credential_exchange_issue(request: web.BaseRequest): @@ -675,6 +722,7 @@ async def credential_exchange_issue(request: web.BaseRequest): @docs(tags=["issue-credential"], summary="Store a received credential") +@match_info_schema(CredExIdMatchInfoSchema()) @request_schema(V10CredentialStoreRequestSchema()) @response_schema(V10CredentialExchangeSchema(), 200) async def credential_exchange_store(request: web.BaseRequest): @@ -741,33 +789,9 @@ async def credential_exchange_store(request: web.BaseRequest): @docs( - tags=["issue-credential"], - parameters=[ - { - "name": "rev_reg_id", - "in": "query", - "description": "revocation registry id", - "required": True, - }, - { - "name": "cred_rev_id", - "in": "query", - "description": "credential revocation id", - "required": True, - }, - { - "name": "publish", - "in": "query", - "description": ( - "(true) publish revocation to ledger immediately, or " - "(false) mark it pending" - ), - "schema": {"type": "boolean"}, - "required": False, - }, - ], - summary="Revoke an issued credential", + tags=["issue-credential"], summary="Revoke an issued credential", ) +@querystring_schema(RevokeQueryStringSchema()) async def credential_exchange_revoke(request: web.BaseRequest): """ Request handler for storing a credential request. @@ -782,7 +806,7 @@ async def credential_exchange_revoke(request: web.BaseRequest): context = request.app["request_context"] rev_reg_id = request.query.get("rev_reg_id") - cred_rev_id = request.query.get("cred_rev_id") + cred_rev_id = request.query.get("cred_rev_id") # numeric str here, which indy wants publish = bool(json.loads(request.query.get("publish", json.dumps(False)))) credential_manager = CredentialManager(context) @@ -819,6 +843,7 @@ async def credential_exchange_publish_revocations(request: web.BaseRequest): @docs( tags=["issue-credential"], summary="Remove an existing credential exchange record" ) +@match_info_schema(CredExIdMatchInfoSchema()) async def credential_exchange_remove(request: web.BaseRequest): """ Request handler for removing a credential exchange record. @@ -842,6 +867,7 @@ async def credential_exchange_remove(request: web.BaseRequest): @docs( tags=["issue-credential"], summary="Send a problem report for credential exchange" ) +@match_info_schema(CredExIdMatchInfoSchema()) @request_schema(V10CredentialProblemReportRequestSchema()) async def credential_exchange_problem_report(request: web.BaseRequest): """ @@ -889,11 +915,17 @@ async def register(app: web.Application): app.add_routes( [ web.get( - "/issue-credential/mime-types/{credential_id}", attribute_mime_types_get + "/issue-credential/mime-types/{credential_id}", + attribute_mime_types_get, + allow_head=False, + ), + web.get( + "/issue-credential/records", credential_exchange_list, allow_head=False ), - web.get("/issue-credential/records", credential_exchange_list), web.get( - "/issue-credential/records/{cred_ex_id}", credential_exchange_retrieve + "/issue-credential/records/{cred_ex_id}", + credential_exchange_retrieve, + allow_head=False, ), web.post("/issue-credential/send", credential_exchange_send), web.post( @@ -918,7 +950,7 @@ async def register(app: web.Application): "/issue-credential/records/{cred_ex_id}/store", credential_exchange_store, ), - web.post("/issue-credential/revoke", credential_exchange_revoke,), + web.post("/issue-credential/revoke", credential_exchange_revoke), web.post( "/issue-credential/publish-revocations", credential_exchange_publish_revocations, diff --git a/aries_cloudagent/protocols/present_proof/v1_0/models/presentation_exchange.py b/aries_cloudagent/protocols/present_proof/v1_0/models/presentation_exchange.py index 69632d3d96..c1fba4d436 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/models/presentation_exchange.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/models/presentation_exchange.py @@ -2,8 +2,7 @@ from typing import Any -from marshmallow import fields -from marshmallow.validate import OneOf +from marshmallow import fields, validate from .....messaging.models.base_record import BaseExchangeRecord, BaseExchangeSchema from .....messaging.valid import UUIDFour @@ -127,13 +126,13 @@ class Meta: required=False, description="Present-proof exchange initiator: self or external", example=V10PresentationExchange.INITIATOR_SELF, - validate=OneOf(["self", "external"]), + validate=validate.OneOf(["self", "external"]), ) role = fields.Str( required=False, description="Present-proof exchange role: prover or verifier", example=V10PresentationExchange.ROLE_PROVER, - validate=OneOf(["prover", "verifier"]), + validate=validate.OneOf(["prover", "verifier"]), ) state = fields.Str( required=False, @@ -154,7 +153,7 @@ class Meta: required=False, description="Whether presentation is verified: true or false", example="true", - validate=OneOf(["true", "false"]), + validate=validate.OneOf(["true", "false"]), ) auto_present = fields.Bool( required=False, diff --git a/aries_cloudagent/protocols/present_proof/v1_0/routes.py b/aries_cloudagent/protocols/present_proof/v1_0/routes.py index ab442d6c97..0bc1d78f63 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/routes.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/routes.py @@ -3,8 +3,15 @@ import json from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema -from marshmallow import Schema, fields +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) +from marshmallow import fields, Schema, validates_schema +from marshmallow.exceptions import ValidationError from ....connections.models.connection_record import ConnectionRecord from ....holder.base import BaseHolder @@ -15,8 +22,12 @@ INDY_PREDICATE, INDY_SCHEMA_ID, INDY_VERSION, + INDY_WQL, INT_EPOCH, + NATURAL_NUM, UUIDFour, + UUID4, + WHOLE_NUM, ) from ....storage.error import StorageNotFoundError from ....indy.util import generate_pr_nonce @@ -70,11 +81,6 @@ class V10PresentationProposalRequestSchema(AdminAPIMessageTracingSchema): class IndyProofReqSpecRestrictionsSchema(Schema): """Schema for restrictions in attr or pred specifier indy proof request.""" - credential_definition_id = fields.Str( - description="Credential definition identifier", - required=True, - **INDY_CRED_DEF_ID - ) schema_id = fields.String( description="Schema identifier", required=False, **INDY_SCHEMA_ID ) @@ -93,24 +99,42 @@ class IndyProofReqSpecRestrictionsSchema(Schema): cred_def_id = fields.String( description="Credential definition identifier", required=False, - **INDY_CRED_DEF_ID + **INDY_CRED_DEF_ID, ) -class IndyProofReqNonRevoked(Schema): +class IndyProofReqNonRevokedSchema(Schema): """Non-revocation times specification in indy proof request.""" - from_epoch = fields.Int( + fro = fields.Int( description="Earliest epoch of interest for non-revocation proof", - required=True, - **INT_EPOCH + required=False, + data_key="from", + **INT_EPOCH, ) - to_epoch = fields.Int( + to = fields.Int( description="Latest epoch of interest for non-revocation proof", - required=True, - **INT_EPOCH + required=False, + **INT_EPOCH, ) + @validates_schema + def validate_fields(self, data, **kwargs): + """ + Validate schema fields - must have from, to, or both. + + Args: + data: The data to validate + + Raises: + ValidationError: if data has neither from nor to + + """ + if not (data.get("from") or data.get("to")): + raise ValidationError( + "Non-revocation interval must have at least one end", ("fro", "to") + ) + class IndyProofReqAttrSpecSchema(Schema): """Schema for attribute specification in indy proof request.""" @@ -123,7 +147,7 @@ class IndyProofReqAttrSpecSchema(Schema): description="If present, credential must satisfy one of given restrictions", required=False, ) - non_revoked = fields.Nested(IndyProofReqNonRevoked(), required=False) + non_revoked = fields.Nested(IndyProofReqNonRevokedSchema(), required=False) class IndyProofReqPredSpecSchema(Schema): @@ -133,7 +157,7 @@ class IndyProofReqPredSpecSchema(Schema): p_type = fields.String( description="Predicate type ('<', '<=', '>=', or '>')", required=True, - **INDY_PREDICATE + **INDY_PREDICATE, ) p_value = fields.Integer(description="Threshold value", required=True) restrictions = fields.List( @@ -141,7 +165,7 @@ class IndyProofReqPredSpecSchema(Schema): description="If present, credential must satisfy one of given restrictions", required=False, ) - non_revoked = fields.Nested(IndyProofReqNonRevoked(), required=False) + non_revoked = fields.Nested(IndyProofReqNonRevokedSchema(), required=False) class IndyProofRequestSchema(Schema): @@ -158,7 +182,7 @@ class IndyProofRequestSchema(Schema): description="Proof request version", required=False, default="1.0", - **INDY_VERSION + **INDY_VERSION, ) requested_attributes = fields.Dict( description=("Requested attribute specifications of proof request"), @@ -172,6 +196,7 @@ class IndyProofRequestSchema(Schema): keys=fields.Str(example="0_age_GE_uuid"), # marshmallow/apispec v3.0 ignores values=fields.Nested(IndyProofReqPredSpecSchema()), ) + non_revoked = fields.Nested(IndyProofReqNonRevokedSchema(), required=False) class V10PresentationRequestRequestSchema(AdminAPIMessageTracingSchema): @@ -192,9 +217,15 @@ class IndyRequestedCredsRequestedAttrSchema(Schema): description=( "Wallet credential identifier (typically but not necessarily a UUID)" ), + required=True, ) revealed = fields.Bool( - description="Whether to reveal attribute in proof", default=True + description="Whether to reveal attribute in proof", required=True + ) + timestamp = fields.Int( + description="Epoch timestamp of interest for non-revocation proof", + required=False, + **INT_EPOCH, ) @@ -202,10 +233,16 @@ class IndyRequestedCredsRequestedPredSchema(Schema): """Schema for requested predicates within indy requested credentials structure.""" cred_id = fields.Str( - example="3fa85f64-5717-4562-b3fc-2c963f66afa6", description=( "Wallet credential identifier (typically but not necessarily a UUID)" ), + example="3fa85f64-5717-4562-b3fc-2c963f66afa6", + required=True, + ) + timestamp = fields.Int( + description="Epoch timestamp of interest for non-revocation proof", + required=False, + **INT_EPOCH, ) @@ -244,6 +281,31 @@ class V10PresentationRequestSchema(AdminAPIMessageTracingSchema): ) +class CredentialsFetchQueryStringSchema(Schema): + """Parameters and validators for credentials fetch request query string.""" + + referent = fields.Str( + description="Proof request referents of interest, comma-separated", + required=False, + example="1_name_uuid,2_score_uuid", + ) + start = fields.Int(description="Start index", required=False, **WHOLE_NUM) + count = fields.Int( + description="Maximum number to retrieve", required=False, **NATURAL_NUM + ) + extra_query = fields.Str( + description="(JSON) WQL extra query", required=False, **INDY_WQL, + ) + + +class PresExIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking presentation exchange id.""" + + pres_ex_id = fields.Str( + description="Presentation exchange identifier", required=True, **UUID4 + ) + + @docs(tags=["present-proof"], summary="Fetch all present-proof exchange records") @response_schema(V10PresentationExchangeListSchema(), 200) async def presentation_exchange_list(request: web.BaseRequest): @@ -270,6 +332,7 @@ async def presentation_exchange_list(request: web.BaseRequest): @docs(tags=["present-proof"], summary="Fetch a single presentation exchange record") +@match_info_schema(PresExIdMatchInfoSchema()) @response_schema(V10PresentationExchangeSchema(), 200) async def presentation_exchange_retrieve(request: web.BaseRequest): """ @@ -296,27 +359,9 @@ async def presentation_exchange_retrieve(request: web.BaseRequest): @docs( tags=["present-proof"], summary="Fetch credentials for a presentation request from wallet", - parameters=[ - { - "name": "start", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "count", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "extra_query", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - ], ) +@match_info_schema(PresExIdMatchInfoSchema()) +@querystring_schema(CredentialsFetchQueryStringSchema()) async def presentation_exchange_credentials_list(request: web.BaseRequest): """ Request handler for searching applicable credential records. @@ -331,8 +376,10 @@ async def presentation_exchange_credentials_list(request: web.BaseRequest): context = request.app["request_context"] presentation_exchange_id = request.match_info["pres_ex_id"] - referents = request.match_info.get("referent") - presentation_referents = referents.split(",") if referents else () + referents = request.query.get("referent") + presentation_referents = ( + (r.strip() for r in referents.split(",")) if referents else () + ) try: presentation_exchange_record = await V10PresentationExchange.retrieve_by_id( @@ -590,6 +637,7 @@ async def presentation_exchange_send_free_request(request: web.BaseRequest): tags=["present-proof"], summary="Sends a presentation request in reference to a proposal", ) +@match_info_schema(PresExIdMatchInfoSchema()) @request_schema(V10PresentationRequestRequestSchema()) @response_schema(V10PresentationExchangeSchema(), 200) async def presentation_exchange_send_bound_request(request: web.BaseRequest): @@ -652,6 +700,7 @@ async def presentation_exchange_send_bound_request(request: web.BaseRequest): @docs(tags=["present-proof"], summary="Sends a proof presentation") +@match_info_schema(PresExIdMatchInfoSchema()) @request_schema(V10PresentationRequestSchema()) @response_schema(V10PresentationExchangeSchema()) async def presentation_exchange_send_presentation(request: web.BaseRequest): @@ -723,6 +772,7 @@ async def presentation_exchange_send_presentation(request: web.BaseRequest): @docs(tags=["present-proof"], summary="Verify a received presentation") +@match_info_schema(PresExIdMatchInfoSchema()) @response_schema(V10PresentationExchangeSchema()) async def presentation_exchange_verify_presentation(request: web.BaseRequest): """ @@ -776,6 +826,7 @@ async def presentation_exchange_verify_presentation(request: web.BaseRequest): @docs(tags=["present-proof"], summary="Remove an existing presentation exchange record") +@match_info_schema(PresExIdMatchInfoSchema()) async def presentation_exchange_remove(request: web.BaseRequest): """ Request handler for removing a presentation exchange record. @@ -802,26 +853,32 @@ async def register(app: web.Application): app.add_routes( [ - web.get("/present-proof/records", presentation_exchange_list), web.get( - "/present-proof/records/{pres_ex_id}", presentation_exchange_retrieve + "/present-proof/records", presentation_exchange_list, allow_head=False ), web.get( - "/present-proof/records/{pres_ex_id}/credentials", - presentation_exchange_credentials_list, + "/present-proof/records/{pres_ex_id}", + presentation_exchange_retrieve, + allow_head=False, ), web.get( - "/present-proof/records/{pres_ex_id}/credentials/{referent}", + "/present-proof/records/{pres_ex_id}/credentials", presentation_exchange_credentials_list, + allow_head=False, ), + # web.get( + # "/present-proof/records/{pres_ex_id}/credentials/{referent}", + # presentation_exchange_credentials_list, + # allow_head=False + # ), web.post( - "/present-proof/send-proposal", presentation_exchange_send_proposal + "/present-proof/send-proposal", presentation_exchange_send_proposal, ), web.post( - "/present-proof/create-request", presentation_exchange_create_request + "/present-proof/create-request", presentation_exchange_create_request, ), web.post( - "/present-proof/send-request", presentation_exchange_send_free_request + "/present-proof/send-request", presentation_exchange_send_free_request, ), web.post( "/present-proof/records/{pres_ex_id}/send-request", diff --git a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py index 7ce8ab6639..366e4e190a 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py @@ -11,6 +11,14 @@ def setUp(self): self.mock_context = async_mock.MagicMock() self.test_instance = test_module.PresentationManager(self.mock_context) + async def test_validate_non_revoked(self): + non_revo = test_module.IndyProofReqNonRevokedSchema() + non_revo.validate({"from": 1234567890}) + non_revo.validate({"to": 1234567890}) + non_revo.validate({"from": 1234567890, "to": 1234567890}) + with self.assertRaises(test_module.ValidationError): + non_revo.validate_fields({}) + async def test_presentation_exchange_list(self): mock = async_mock.MagicMock() mock.query = { diff --git a/aries_cloudagent/protocols/trustping/v1_0/routes.py b/aries_cloudagent/protocols/trustping/v1_0/routes.py index 9608596229..3db197ffd9 100644 --- a/aries_cloudagent/protocols/trustping/v1_0/routes.py +++ b/aries_cloudagent/protocols/trustping/v1_0/routes.py @@ -1,13 +1,15 @@ """Trust ping admin routes.""" from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import docs, match_info_schema, request_schema, response_schema from marshmallow import fields, Schema from aries_cloudagent.connections.models.connection_record import ConnectionRecord +from aries_cloudagent.messaging.valid import UUIDFour from aries_cloudagent.storage.error import StorageNotFoundError + from .messages.ping import Ping @@ -23,7 +25,16 @@ class PingRequestResponseSchema(Schema): thread_id = fields.Str(required=False, description="Thread ID of the ping message") +class ConnIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking connection id.""" + + conn_id = fields.Str( + description="Connection identifier", required=True, example=UUIDFour.EXAMPLE + ) + + @docs(tags=["trustping"], summary="Send a trust ping to a connection") +@match_info_schema(ConnIdMatchInfoSchema()) @request_schema(PingRequestSchema()) @response_schema(PingRequestResponseSchema(), 200) async def connections_send_ping(request: web.BaseRequest): @@ -35,7 +46,7 @@ async def connections_send_ping(request: web.BaseRequest): """ context = request.app["request_context"] - connection_id = request.match_info["id"] + connection_id = request.match_info["conn_id"] outbound_handler = request.app["outbound_message_router"] body = await request.json() comment = body.get("comment") @@ -57,4 +68,6 @@ async def connections_send_ping(request: web.BaseRequest): async def register(app: web.Application): """Register routes.""" - app.add_routes([web.post("/connections/{id}/send-ping", connections_send_ping)]) + app.add_routes( + [web.post("/connections/{conn_id}/send-ping", connections_send_ping)] + ) diff --git a/aries_cloudagent/revocation/indy.py b/aries_cloudagent/revocation/indy.py index 147496dc43..6e900d4208 100644 --- a/aries_cloudagent/revocation/indy.py +++ b/aries_cloudagent/revocation/indy.py @@ -4,6 +4,7 @@ from ..config.injection_context import InjectionContext from ..ledger.base import BaseLedger +from ..storage.base import StorageNotFoundError from .error import RevocationNotSupportedError from .models.issuer_rev_reg_record import IssuerRevRegRecord @@ -64,7 +65,11 @@ async def get_active_issuer_rev_reg_record( current = await IssuerRevRegRecord.query_by_cred_def_id( self._context, cred_def_id, IssuerRevRegRecord.STATE_ACTIVE ) - return current[0] if current else None + if current: + return current[0] + raise StorageNotFoundError( + f"No active issuer revocation record found for cred def id {cred_def_id}" + ) async def get_issuer_rev_reg_record( self, revoc_reg_id: str diff --git a/aries_cloudagent/revocation/routes.py b/aries_cloudagent/revocation/routes.py index 19c4ec35ca..79ac7a9c41 100644 --- a/aries_cloudagent/revocation/routes.py +++ b/aries_cloudagent/revocation/routes.py @@ -1,21 +1,22 @@ """Revocation registry admin routes.""" +import logging + from asyncio import shield from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema - -import logging +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) -from marshmallow import fields, Schema +from marshmallow import fields, Schema, validate from ..messaging.credential_definitions.util import CRED_DEF_SENT_RECORD_TYPE -from ..messaging.valid import ( - INDY_CRED_DEF_ID, - IndyCredDefId, - INDY_REV_REG_ID, - IndyRevRegId, -) +from ..messaging.valid import INDY_CRED_DEF_ID, INDY_REV_REG_ID from ..storage.base import BaseStorage, StorageNotFoundError from .error import RevocationNotSupportedError @@ -64,12 +65,51 @@ class RevRegUpdateTailsFileUriSchema(Schema): description="Public URI to the tails file", example=( "http://192.168.56.133:5000/revocation/registry/" - f"{IndyRevRegId.EXAMPLE}/tails-file" + f"{INDY_REV_REG_ID['example']}/tails-file" ), required=True, ) +class RevRegsCreatedQueryStringSchema(Schema): + """Query string parameters and validators for rev regs created request.""" + + cred_def_id = fields.Str( + description="Credential definition identifier", + required=False, + **INDY_CRED_DEF_ID, + ) + state = fields.Str( + description="Revocation registry state", + required=False, + validate=validate.OneOf( + [ + getattr(IssuerRevRegRecord, m) + for m in vars(IssuerRevRegRecord) + if m.startswith("STATE_") + ] + ), + ) + + +class RevRegIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking rev reg id.""" + + rev_reg_id = fields.Str( + description="Revocation Registry identifier", required=True, **INDY_REV_REG_ID, + ) + + +class CredDefIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking cred def id.""" + + cred_def_id = fields.Str( + description="Credential definition identifier", + required=True, + **INDY_CRED_DEF_ID, + ) + + @docs(tags=["revocation"], summary="Creates a new revocation registry") @request_schema(RevRegCreateRequestSchema()) @response_schema(RevRegCreateResultSchema(), 200) @@ -122,31 +162,9 @@ async def revocation_create_registry(request: web.BaseRequest): @docs( tags=["revocation"], - parameters=[ - { - "name": "cred_def_id", - "in": "query", - "schema": {"type": "string", "pattern": IndyCredDefId.PATTERN}, - "required": False, - }, - { - "name": "state", - "in": "query", - "schema": { - "type": "string", - "pattern": ( - rf"^(?:{IssuerRevRegRecord.STATE_INIT}|" - rf"{IssuerRevRegRecord.STATE_GENERATED}|" - rf"{IssuerRevRegRecord.STATE_PUBLISHED}|" - rf"{IssuerRevRegRecord.STATE_ACTIVE}|" - rf"{IssuerRevRegRecord.STATE_FULL})$" - ), - }, - "required": False, - }, - ], summary="Search for matching revocation registries that current agent created", ) +@querystring_schema(RevRegsCreatedQueryStringSchema()) @response_schema(RevRegsCreatedSchema(), 200) async def revocation_registries_created(request: web.BaseRequest): """ @@ -161,10 +179,11 @@ async def revocation_registries_created(request: web.BaseRequest): """ context = request.app["request_context"] + search_tags = [ + tag for tag in vars(RevRegsCreatedQueryStringSchema)["_declared_fields"] + ] tag_filter = { - tag: request.query[tag] - for tag in ("cred_def_id", "state") - if tag in request.query + tag: request.query[tag] for tag in search_tags if tag in request.query } found = await IssuerRevRegRecord.query(context, tag_filter) @@ -172,17 +191,9 @@ async def revocation_registries_created(request: web.BaseRequest): @docs( - tags=["revocation"], - summary="Get revocation registry by revocation registry id", - parameters=[ - { - "in": "path", - "name": "id", - "schema": {"type": "string", "pattern": IndyRevRegId.PATTERN}, - "description": "revocation registry id", - } - ], + tags=["revocation"], summary="Get revocation registry by revocation registry id", ) +@match_info_schema(RevRegIdMatchInfoSchema()) @response_schema(RevRegCreateResultSchema(), 200) async def get_registry(request: web.BaseRequest): """ @@ -197,7 +208,7 @@ async def get_registry(request: web.BaseRequest): """ context = request.app["request_context"] - registry_id = request.match_info["id"] + registry_id = request.match_info["rev_reg_id"] try: revoc = IndyRevocation(context) @@ -211,10 +222,8 @@ async def get_registry(request: web.BaseRequest): @docs( tags=["revocation"], summary="Get an active revocation registry by credential definition id", - parameters=[ - {"in": "path", "name": "cred_def_id", "description": "credential definition id"} - ], ) +@match_info_schema(CredDefIdMatchInfoSchema()) @response_schema(RevRegCreateResultSchema(), 200) async def get_active_registry(request: web.BaseRequest): """ @@ -244,9 +253,9 @@ async def get_active_registry(request: web.BaseRequest): tags=["revocation"], summary="Download the tails file of revocation registry", produces="application/octet-stream", - parameters=[{"in": "path", "name": "id", "description": "revocation registry id"}], - responses={200: {"description": "tails file", "schema": {"type": "file"}}}, + responses={200: {"description": "tails file"}}, ) +@match_info_schema(RevRegIdMatchInfoSchema()) async def get_tails_file(request: web.BaseRequest) -> web.FileResponse: """ Request handler to download the tails file of the revocation registry. @@ -260,7 +269,7 @@ async def get_tails_file(request: web.BaseRequest) -> web.FileResponse: """ context = request.app["request_context"] - registry_id = request.match_info["id"] + registry_id = request.match_info["rev_reg_id"] try: revoc = IndyRevocation(context) @@ -272,10 +281,9 @@ async def get_tails_file(request: web.BaseRequest) -> web.FileResponse: @docs( - tags=["revocation"], - summary="Publish a given revocation registry", - parameters=[{"in": "path", "name": "id", "description": "revocation registry id"}], + tags=["revocation"], summary="Publish a given revocation registry", ) +@match_info_schema(RevRegIdMatchInfoSchema()) @response_schema(RevRegCreateResultSchema(), 200) async def publish_registry(request: web.BaseRequest): """ @@ -289,7 +297,7 @@ async def publish_registry(request: web.BaseRequest): """ context = request.app["request_context"] - registry_id = request.match_info["id"] + registry_id = request.match_info["rev_reg_id"] try: revoc = IndyRevocation(context) @@ -308,10 +316,8 @@ async def publish_registry(request: web.BaseRequest): @docs( tags=["revocation"], summary="Update revocation registry with new public URI to the tails file.", - parameters=[ - {"in": "path", "name": "id", "description": "revocation registry identifier"} - ], ) +@match_info_schema(RevRegIdMatchInfoSchema()) @request_schema(RevRegUpdateTailsFileUriSchema()) @response_schema(RevRegCreateResultSchema(), 200) async def update_registry(request: web.BaseRequest): @@ -330,7 +336,7 @@ async def update_registry(request: web.BaseRequest): body = await request.json() tails_public_uri = body.get("tails_public_uri") - registry_id = request.match_info["id"] + registry_id = request.match_info["rev_reg_id"] try: revoc = IndyRevocation(context) @@ -348,11 +354,25 @@ async def register(app: web.Application): app.add_routes( [ web.post("/revocation/create-registry", revocation_create_registry), - web.get("/revocation/registries/created", revocation_registries_created), - web.get("/revocation/registry/{id}", get_registry), - web.get("/revocation/active-registry/{cred_def_id}", get_active_registry), - web.get("/revocation/registry/{id}/tails-file", get_tails_file), - web.patch("/revocation/registry/{id}", update_registry), - web.post("/revocation/registry/{id}/publish", publish_registry), + web.get( + "/revocation/registries/created", + revocation_registries_created, + allow_head=False, + ), + web.get( + "/revocation/registry/{rev_reg_id}", get_registry, allow_head=False + ), + web.get( + "/revocation/active-registry/{cred_def_id}", + get_active_registry, + allow_head=False, + ), + web.get( + "/revocation/registry/{rev_reg_id}/tails-file", + get_tails_file, + allow_head=False, + ), + web.patch("/revocation/registry/{rev_reg_id}", update_registry), + web.post("/revocation/registry/{rev_reg_id}/publish", publish_registry), ] ) diff --git a/aries_cloudagent/revocation/tests/test_indy.py b/aries_cloudagent/revocation/tests/test_indy.py index 75b876dc4e..e61aa54da4 100644 --- a/aries_cloudagent/revocation/tests/test_indy.py +++ b/aries_cloudagent/revocation/tests/test_indy.py @@ -11,6 +11,7 @@ from ...ledger.base import BaseLedger from ...storage.base import BaseStorage from ...storage.basic import BasicStorage +from ...storage.error import StorageNotFoundError from ...wallet.base import BaseWallet from ...wallet.indy import IndyWallet @@ -91,8 +92,36 @@ async def test_get_active_issuer_rev_reg_record(self): async def test_get_active_issuer_rev_reg_record_none(self): CRED_DEF_ID = f"{self.test_did}:3:CL:1234:default" + with self.assertRaises(StorageNotFoundError) as x_init: + await self.revoc.get_active_issuer_rev_reg_record(CRED_DEF_ID) + + async def test_init_issuer_registry_no_revocation(self): + CRED_DEF_ID = f"{self.test_did}:3:CL:1234:default" + + self.context.injector.clear_binding(BaseLedger) + self.ledger.get_credential_definition = async_mock.CoroutineMock( + return_value={"value": {}} + ) + self.context.injector.bind_instance(BaseLedger, self.ledger) + + with self.assertRaises(RevocationNotSupportedError) as x_revo: + await self.revoc.init_issuer_registry(CRED_DEF_ID, self.test_did) + assert x_revo.message == "Credential definition does not support revocation" + + async def test_get_active_issuer_rev_reg_record(self): + CRED_DEF_ID = f"{self.test_did}:3:CL:1234:default" + rec = await self.revoc.init_issuer_registry(CRED_DEF_ID, self.test_did) + rec.revoc_reg_id = "dummy" + rec.state = IssuerRevRegRecord.STATE_ACTIVE + await rec.save(self.context) + result = await self.revoc.get_active_issuer_rev_reg_record(CRED_DEF_ID) - assert result is None + assert rec == result + + async def test_get_active_issuer_rev_reg_record_none(self): + CRED_DEF_ID = f"{self.test_did}:3:CL:1234:default" + with self.assertRaises(StorageNotFoundError): + result = await self.revoc.get_active_issuer_rev_reg_record(CRED_DEF_ID) async def test_get_issuer_rev_reg_record(self): CRED_DEF_ID = f"{self.test_did}:3:CL:1234:default" diff --git a/aries_cloudagent/revocation/tests/test_routes.py b/aries_cloudagent/revocation/tests/test_routes.py index 0376c6a713..de5e6d8dfe 100644 --- a/aries_cloudagent/revocation/tests/test_routes.py +++ b/aries_cloudagent/revocation/tests/test_routes.py @@ -141,7 +141,7 @@ async def test_get_registry(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -166,7 +166,7 @@ async def test_get_registry_not_found(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -233,7 +233,7 @@ async def test_get_tails_file(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -256,7 +256,7 @@ async def test_get_tails_file_not_found(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -279,7 +279,7 @@ async def test_publish_registry(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -306,7 +306,7 @@ async def test_publish_registry_not_found(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} with async_mock.patch.object( test_module, "IndyRevocation", autospec=True @@ -329,7 +329,7 @@ async def test_update_registry(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} request.json = async_mock.CoroutineMock( return_value={ "tails_public_uri": f"http://sample.ca:8181/tails/{REV_REG_ID}" @@ -361,7 +361,7 @@ async def test_update_registry_not_found(self): ) request = async_mock.MagicMock() request.app = self.app - request.match_info = {"id": REV_REG_ID} + request.match_info = {"rev_reg_id": REV_REG_ID} request.json = async_mock.CoroutineMock( return_value={ "tails_public_uri": f"http://sample.ca:8181/tails/{REV_REG_ID}" diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py index cbc0fe9e42..7d4d638213 100644 --- a/aries_cloudagent/wallet/routes.py +++ b/aries_cloudagent/wallet/routes.py @@ -1,12 +1,20 @@ """Wallet admin routes.""" +import json + from aiohttp import web -from aiohttp_apispec import docs, request_schema, response_schema +from aiohttp_apispec import ( + docs, + match_info_schema, + querystring_schema, + request_schema, + response_schema, +) from marshmallow import fields, Schema from ..ledger.base import BaseLedger -from ..messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY +from ..messaging.valid import INDY_CRED_DEF_ID, INDY_DID, INDY_RAW_PUBLIC_KEY from .base import DIDInfo, BaseWallet from .error import WalletError @@ -17,7 +25,7 @@ class DIDSchema(Schema): did = fields.Str(description="DID of interest", **INDY_DID) verkey = fields.Str(description="Public verification key", **INDY_RAW_PUBLIC_KEY) - public = fields.Bool(description="Whether DID is public", example=False) + public = fields.Boolean(description="Whether DID is public", example=False) class DIDResultSchema(Schema): @@ -52,38 +60,46 @@ class SetTagPolicyRequestSchema(Schema): ) +class DIDListQueryStringSchema(Schema): + """Parameters and validators for DID list request query string.""" + + did = fields.Str(description="DID of interest", required=False, **INDY_DID,) + verkey = fields.Str( + description="Verification key of interest", + required=False, + **INDY_RAW_PUBLIC_KEY, + ) + public = fields.Boolean(description="Whether DID is on the ledger", required=False) + + +class SetPublicDIDQueryStringSchema(Schema): + """Parameters and validators for set public DID request query string.""" + + did = fields.Str(description="DID of interest", required=True, **INDY_DID,) + + +class CredDefIdMatchInfoSchema(Schema): + """Path parameters and validators for request taking credential definition id.""" + + cred_def_id = fields.Str( + description="Credential identifier", required=True, **INDY_CRED_DEF_ID + ) + + def format_did_info(info: DIDInfo): """Serialize a DIDInfo object.""" if info: return { "did": info.did, "verkey": info.verkey, - "public": info.metadata - and info.metadata.get("public") - and "true" - or "false", + "public": json.dumps(bool(info.metadata.get("public"))), } @docs( - tags=["wallet"], - summary="List wallet DIDs", - parameters=[ - {"name": "did", "in": "query", "schema": {"type": "string"}, "required": False}, - { - "name": "verkey", - "in": "query", - "schema": {"type": "string"}, - "required": False, - }, - { - "name": "public", - "in": "query", - "schema": {"type": "boolean"}, - "required": False, - }, - ], + tags=["wallet"], summary="List wallet DIDs", ) +@querystring_schema(DIDListQueryStringSchema()) @response_schema(DIDListSchema, 200) async def wallet_did_list(request: web.BaseRequest): """ @@ -102,37 +118,42 @@ async def wallet_did_list(request: web.BaseRequest): raise web.HTTPForbidden() filter_did = request.query.get("did") filter_verkey = request.query.get("verkey") - filter_public = request.query.get("public") + filter_public = json.loads(request.query.get("public", json.dumps(None))) results = [] + public_did_info = await wallet.get_public_did() - if filter_public == "true": - info = await wallet.get_public_did() + if filter_public: # True (contrast False or None) if ( - info - and (not filter_verkey or info.verkey == filter_verkey) - and (not filter_did or info.did == filter_did) + public_did_info + and (not filter_verkey or public_did_info.verkey == filter_verkey) + and (not filter_did or public_did_info.did == filter_did) ): - results.append(format_did_info(info)) + results.append(format_did_info(public_did_info)) elif filter_did: try: info = await wallet.get_local_did(filter_did) except WalletError: # badly formatted DID or record not found info = None - if info and (not filter_verkey or info.verkey == filter_verkey): + if ( + info + and (not filter_verkey or info.verkey == filter_verkey) + and (filter_public is None or info != public_did_info) + ): results.append(format_did_info(info)) elif filter_verkey: try: info = await wallet.get_local_did_for_verkey(filter_verkey) except WalletError: info = None - if info: + if info and (filter_public is None or info != public_did_info): results.append(format_did_info(info)) else: dids = await wallet.get_local_dids() results = [] for info in dids: - results.append(format_did_info(info)) + if filter_public is None or info != public_did_info: + results.append(format_did_info(info)) results.sort(key=lambda info: info["did"]) return web.json_response({"results": results}) @@ -181,12 +202,9 @@ async def wallet_get_public_did(request: web.BaseRequest): @docs( - tags=["wallet"], - summary="Assign the current public DID", - parameters=[ - {"name": "did", "in": "query", "schema": {"type": "string"}, "required": True} - ], + tags=["wallet"], summary="Assign the current public DID", ) +@querystring_schema(SetPublicDIDQueryStringSchema()) @response_schema(DIDResultSchema, 200) async def wallet_set_public_did(request: web.BaseRequest): """ @@ -224,6 +242,7 @@ async def wallet_set_public_did(request: web.BaseRequest): @docs(tags=["wallet"], summary="Get the tagging policy for a credential definition") +@match_info_schema(CredDefIdMatchInfoSchema()) @response_schema(GetTagPolicyResultSchema()) async def wallet_get_tagging_policy(request: web.BaseRequest): """ @@ -238,7 +257,7 @@ async def wallet_get_tagging_policy(request: web.BaseRequest): """ context = request.app["request_context"] - credential_definition_id = request.match_info["id"] + credential_definition_id = request.match_info["cred_def_id"] wallet: BaseWallet = await context.inject(BaseWallet, required=False) if not wallet or wallet.WALLET_TYPE != "indy": @@ -248,6 +267,7 @@ async def wallet_get_tagging_policy(request: web.BaseRequest): @docs(tags=["wallet"], summary="Set the tagging policy for a credential definition") +@match_info_schema(CredDefIdMatchInfoSchema()) @request_schema(SetTagPolicyRequestSchema()) async def wallet_set_tagging_policy(request: web.BaseRequest): """ @@ -262,7 +282,7 @@ async def wallet_set_tagging_policy(request: web.BaseRequest): """ context = request.app["request_context"] - credential_definition_id = request.match_info["id"] + credential_definition_id = request.match_info["cred_def_id"] body = await request.json() taggables = body.get("taggables") # None for all attrs, [] for no attrs @@ -281,11 +301,15 @@ async def register(app: web.Application): app.add_routes( [ - web.get("/wallet/did", wallet_did_list), + web.get("/wallet/did", wallet_did_list, allow_head=False), web.post("/wallet/did/create", wallet_create_did), - web.get("/wallet/did/public", wallet_get_public_did), + web.get("/wallet/did/public", wallet_get_public_did, allow_head=False), web.post("/wallet/did/public", wallet_set_public_did), - web.get("/wallet/tag-policy/{id}", wallet_get_tagging_policy), - web.post("/wallet/tag-policy/{id}", wallet_set_tagging_policy), + web.get( + "/wallet/tag-policy/{cred_def_id}", + wallet_get_tagging_policy, + allow_head=False, + ), + web.post("/wallet/tag-policy/{cred_def_id}", wallet_set_tagging_policy), ] ) diff --git a/demo/runners/alice.py b/demo/runners/alice.py index 29e9cb1c15..36d4413810 100644 --- a/demo/runners/alice.py +++ b/demo/runners/alice.py @@ -148,8 +148,7 @@ async def handle_present_proof(self, message): predicates[referent] = { "cred_id": credentials_by_reft[referent]["cred_info"][ "referent" - ], - "revealed": True, + ] } log_status("#25 Generate the proof") diff --git a/demo/runners/performance.py b/demo/runners/performance.py index 4477b53325..36ea721994 100644 --- a/demo/runners/performance.py +++ b/demo/runners/performance.py @@ -44,16 +44,19 @@ def connection_id(self, conn_id: str): self._connection_id = conn_id self._connection_ready = asyncio.Future() - async def get_invite(self, accept: str = "auto"): + async def get_invite(self, auto_accept: bool = True): result = await self.admin_POST( - "/connections/create-invitation", params={"accept": accept} + "/connections/create-invitation", + params={"auto_accept": json.dumps(auto_accept)}, ) self.connection_id = result["connection_id"] return result["invitation"] - async def receive_invite(self, invite, accept: str = "auto"): + async def receive_invite(self, invite, auto_accept: bool = True): result = await self.admin_POST( - "/connections/receive-invitation", invite, params={"accept": accept} + "/connections/receive-invitation", + invite, + params={"auto_accept": json.dumps(auto_accept)}, ) self.connection_id = result["connection_id"] return self.connection_id @@ -292,7 +295,9 @@ async def main( invite = await faber.get_invite() if routing: - conn_id = await alice.receive_invite(invite, accept="manual") + conn_id = await alice.receive_invite( + invite, auto_accept=json_dumps(False) + ) await alice.establish_inbound(conn_id, alice_router_conn_id) await alice.accept_invite(conn_id) await asyncio.wait_for(alice.detect_connection(), 30) diff --git a/requirements.txt b/requirements.txt index 4a2eda93bc..c9b094fc12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -aiohttp==3.5.4 -aiohttp-apispec==1.1.2 +aiohttp~=3.6.2 +aiohttp-apispec==2.2.1 aiohttp-cors~=0.7.0 -apispec==2.0.2 +apispec~=3.3.0 async-timeout~=3.0.1 -base58~=1.0.3 +base58~=2.0.0 Markdown~=3.1.1 -marshmallow==3.0.0 +marshmallow==3.5.1 msgpack~=0.6.1 prompt_toolkit~=2.0.9 pynacl~=1.3.0