Skip to content

Commit

Permalink
OOB v1.1 support w/ backward comp
Browse files Browse the repository at this point in the history
Signed-off-by: Shaanjot Gill <gill.shaanjots@gmail.com>
  • Loading branch information
shaangill025 committed Oct 1, 2022
1 parent e62d5ba commit 590f31b
Show file tree
Hide file tree
Showing 25 changed files with 295 additions and 115 deletions.
8 changes: 5 additions & 3 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import logging
from typing import List, Sequence, Tuple
from typing import Optional, List, Sequence, Tuple, Text

from pydid import (
BaseDIDDocument as ResolvedDocument,
Expand Down Expand Up @@ -227,7 +227,9 @@ async def remove_keys_for_did(self, did: str):
storage: BaseStorage = session.inject(BaseStorage)
await storage.delete_all_records(self.RECORD_TYPE_DID_KEY, {"did": did})

async def resolve_invitation(self, did: str):
async def resolve_invitation(
self, did: str, accept: Optional[Sequence[Text]] = None
):
"""
Resolve invitation with the DID Resolver.
Expand All @@ -241,7 +243,7 @@ async def resolve_invitation(self, did: str):

resolver = self._profile.inject(DIDResolver)
try:
doc_dict: dict = await resolver.resolve(self._profile, did)
doc_dict: dict = await resolver.resolve(self._profile, did, accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
Expand Down
15 changes: 12 additions & 3 deletions aries_cloudagent/core/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,19 @@ def test_get_version_from_message_type(self):
)

def test_get_version_from_message(self):
assert test_module.get_version_from_message(HandshakeReuse()) == "1.0"
assert test_module.get_version_from_message(HandshakeReuse()) == "1.1"

async def test_get_proto_default_version(self):
async def test_get_proto_default_version_from_msg_class(self):
profile = make_profile()
assert (
await test_module.get_proto_default_version(profile, HandshakeReuse)
await test_module.get_proto_default_version_from_msg_class(
profile, HandshakeReuse
)
) == "1.1"

def test_get_proto_default_version(self):
assert (
test_module.get_proto_default_version(
"aries_cloudagent.protocols.out_of_band.definition"
)
) == "1.1"
36 changes: 26 additions & 10 deletions aries_cloudagent/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,20 @@ def get_version_from_message(msg: AgentMessage) -> str:
return get_version_from_message_type(msg_type)


async def get_proto_default_version(
async def get_proto_default_version_from_msg_class(
profile: Profile, msg_class: type, major_version: int = 1
) -> str:
"""Return default protocol version from version_definition."""
version_definition = await get_version_def_from_msg_class(
profile, msg_class, major_version
)
default_major_version = version_definition["major_version"]
default_minor_version = version_definition["current_minor_version"]
return f"{default_major_version}.{default_minor_version}"
return _get_default_version_from_version_def(version_definition)


def get_proto_default_version(def_path: str, major_version: int = 1) -> str:
"""Return default protocol version from version_definition."""
version_definition = _get_version_def_from_path(def_path, major_version)
return _get_default_version_from_version_def(version_definition)


def _get_path_from_msg_class(msg_class: type) -> str:
Expand All @@ -116,10 +120,26 @@ def _get_path_from_msg_class(msg_class: type) -> str:
return (path.replace("/", ".")) + "definition"


def _get_version_def_from_path(definition_path: str, major_version: int = 1):
version_definition = None
definition = ClassLoader.load_module(definition_path)
for protocol_version in definition.versions:
if major_version == protocol_version["major_version"]:
version_definition = protocol_version
break
return version_definition


def _get_default_version_from_version_def(version_definition) -> str:
default_major_version = version_definition["major_version"]
default_minor_version = version_definition["current_minor_version"]
return f"{default_major_version}.{default_minor_version}"


async def get_version_def_from_msg_class(
profile: Profile, msg_class: type, major_version: int = 1
):
"""Return version_definition of a protocol."""
"""Return version_definition of a protocol from msg_class."""
cache = profile.inject_or(BaseCache)
version_definition = None
if cache:
Expand All @@ -129,11 +149,7 @@ async def get_version_def_from_msg_class(
if version_definition:
return version_definition
definition_path = _get_path_from_msg_class(msg_class)
definition = ClassLoader.load_module(definition_path)
for protocol_version in definition.versions:
if major_version == protocol_version["major_version"]:
version_definition = protocol_version
break
version_definition = _get_version_def_from_path(definition_path, major_version)
if not version_definition:
raise ProtocolDefinitionValidationError(
f"Unable to load protocol version_definition for {str(msg_class)}"
Expand Down
17 changes: 16 additions & 1 deletion aries_cloudagent/messaging/agent_message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Agent message base class and schema."""

import re
import uuid

from collections import OrderedDict
from typing import Mapping, Union
import uuid
from string import Template

from marshmallow import (
EXCLUDE,
Expand Down Expand Up @@ -98,6 +101,18 @@ def _get_handler_class(cls):
"""
return resolve_class(cls.Meta.handler_class, cls)

@classmethod
def assign_version_to_message_type(cls, version: str):
"""Assign version to Meta.message_type."""
if "$version" in cls.Meta.message_type:
cls.Meta.message_type = Template(cls.Meta.message_type).substitute(
version=version
)
else:
cls.Meta.message_type = re.sub(
r"(\d+\.)?(\*|\d+)", version, cls.Meta.message_type
)

@property
def Handler(self) -> type:
"""
Expand Down
15 changes: 14 additions & 1 deletion aries_cloudagent/messaging/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for Models and Schemas."""

import re
import logging
import json

Expand Down Expand Up @@ -319,7 +320,19 @@ def make_model(self, data: dict, **kwargs):
A model instance
"""
return self.Model(**data)
try:
cls_inst = self.Model(**data)
except TypeError as err:
msg_type_version = None
if "_type" in str(err) and "_type" in data:
match = re.search(r"(\d+\.)?(\*|\d+)", data["_type"])
if match:
msg_type_version = match.group()
del data["_type"]
cls_inst = self.Model(**data)
if msg_type_version:
cls_inst.assign_version_to_message_type(msg_type_version)
return cls_inst

@post_dump
def remove_skipped_values(self, data, **kwargs):
Expand Down
34 changes: 25 additions & 9 deletions aries_cloudagent/protocols/out_of_band/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
import re
from typing import Mapping, Optional, Sequence, Union
from typing import Mapping, Optional, Sequence, Union, Text
import uuid


Expand Down Expand Up @@ -87,6 +87,7 @@ async def create_invitation(
attachments: Sequence[Mapping] = None,
metadata: dict = None,
mediation_id: str = None,
accept: Optional[Sequence[Text]] = None,
) -> InvitationRecord:
"""
Generate new connection invitation.
Expand All @@ -105,6 +106,8 @@ async def create_invitation(
multi_use: set to True to create an invitation for multiple-use connection
alias: optional alias to apply to connection for later use
attachments: list of dicts in form of {"id": ..., "type": ...}
accept: Optional list of mime types in the order of preference of the sender
that the receiver can use in responding to the message
Returns:
Invitation record
Expand All @@ -122,7 +125,7 @@ async def create_invitation(
"request attachments, or both"
)

accept = bool(
auto_accept = bool(
auto_accept
or (
auto_accept is None
Expand Down Expand Up @@ -227,6 +230,7 @@ async def create_invitation(
handshake_protocols=handshake_protocols,
requests_attach=message_attachments,
services=[f"did:sov:{public_did.did}"],
accept=accept,
)

our_recipient_key = public_did.verkey
Expand All @@ -242,7 +246,7 @@ async def create_invitation(
their_role=ConnRecord.Role.REQUESTER.rfc23,
state=ConnRecord.State.INVITATION.rfc23,
accept=ConnRecord.ACCEPT_AUTO
if accept
if auto_accept
else ConnRecord.ACCEPT_MANUAL,
alias=alias,
connection_protocol=connection_protocol,
Expand Down Expand Up @@ -286,7 +290,7 @@ async def create_invitation(
their_role=ConnRecord.Role.REQUESTER.rfc23,
state=ConnRecord.State.INVITATION.rfc23,
accept=ConnRecord.ACCEPT_AUTO
if accept
if auto_accept
else ConnRecord.ACCEPT_MANUAL,
invitation_mode=invitation_mode,
alias=alias,
Expand Down Expand Up @@ -322,6 +326,7 @@ async def create_invitation(
invi_msg.label = my_label or self.profile.settings.get("default_label")
invi_msg.handshake_protocols = handshake_protocols
invi_msg.requests_attach = message_attachments
invi_msg.accept = accept
invi_msg.services = [
ServiceMessage(
_id="#inline",
Expand Down Expand Up @@ -415,6 +420,9 @@ async def receive_invitation(
# Get the single service item
oob_service_item = invitation.services[0]

# accept
accept = invitation.accept

# Get the DID public did, if any
public_did = None
if isinstance(oob_service_item, str):
Expand Down Expand Up @@ -446,7 +454,9 @@ async def receive_invitation(

# Try to reuse the connection. If not accepted sets the conn_rec to None
if conn_rec and not invitation.requests_attach:
oob_record = await self._handle_hanshake_reuse(oob_record, conn_rec)
oob_record = await self._handle_hanshake_reuse(
oob_record, conn_rec, get_version_from_message(invitation)
)

LOGGER.warning(
f"Connection reuse request finished with state {oob_record.state}"
Expand All @@ -467,6 +477,7 @@ async def receive_invitation(
alias=alias,
auto_accept=auto_accept,
mediation_id=mediation_id,
accept=accept,
)
LOGGER.debug(
f"Performed handshake with connection {oob_record.connection_id}"
Expand Down Expand Up @@ -674,10 +685,12 @@ async def _wait_for_state() -> ConnRecord:
return None

async def _handle_hanshake_reuse(
self, oob_record: OobRecord, conn_record: ConnRecord
self, oob_record: OobRecord, conn_record: ConnRecord, version: str
) -> OobRecord:
# Send handshake reuse
oob_record = await self._create_handshake_reuse_message(oob_record, conn_record)
oob_record = await self._create_handshake_reuse_message(
oob_record, conn_record, version
)

# Wait for the reuse accepted message
oob_record = await self._wait_for_reuse_response(oob_record.oob_id)
Expand Down Expand Up @@ -719,6 +732,7 @@ async def _perform_handshake(
alias: Optional[str] = None,
auto_accept: Optional[bool] = None,
mediation_id: Optional[str] = None,
accept: Optional[Sequence[Text]] = None,
) -> OobRecord:
invitation = oob_record.invitation

Expand Down Expand Up @@ -746,7 +760,8 @@ async def _perform_handshake(
# or something else that includes the key type. We now assume
# ED25519 keys
endpoint, recipient_keys, routing_keys = await self.resolve_invitation(
service
service,
accept=accept,
)
service = ServiceMessage.deserialize(
{
Expand Down Expand Up @@ -824,6 +839,7 @@ async def _create_handshake_reuse_message(
self,
oob_record: OobRecord,
conn_record: ConnRecord,
version: str,
) -> OobRecord:
"""
Create and Send a Handshake Reuse message under RFC 0434.
Expand All @@ -840,7 +856,7 @@ async def _create_handshake_reuse_message(
"""
try:
reuse_msg = HandshakeReuse()
reuse_msg = HandshakeReuse(version=version)
reuse_msg.assign_thread_id(thid=reuse_msg._id, pthid=oob_record.invi_msg_id)

connection_targets = await self.fetch_connection_targets(
Expand Down
7 changes: 7 additions & 0 deletions aries_cloudagent/protocols/out_of_band/v1_0/message_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Message and inner object type identifiers for Out of Band messages."""

from ....core.util import get_proto_default_version

from ...didcomm_prefix import DIDCommPrefix

SPEC_URI = (
Expand All @@ -13,6 +15,11 @@
MESSAGE_REUSE_ACCEPT = "out-of-band/$version/handshake-reuse-accepted"
PROBLEM_REPORT = "out-of-band/$version/problem_report"

# Default Version
DEFAULT_VERSION = get_proto_default_version(
"aries_cloudagent.protocols.out_of_band.definition", 1
)

PROTOCOL_PACKAGE = "aries_cloudagent.protocols.out_of_band.v1_0"

MESSAGE_TYPES = DIDCommPrefix.qualify_all(
Expand Down

0 comments on commit 590f31b

Please sign in to comment.