From 3d01a16d3113d732746c1f3a6e64462ef12c29c2 Mon Sep 17 00:00:00 2001 From: Abdulbois Date: Mon, 25 Jul 2022 15:59:32 +0500 Subject: [PATCH] feat: Connection protocol (RFC-0160) DIDCommV1 implementation (partially) Add implementation only for protocol's package Signed-off-by: Abdulbois --- pkg/didcomm/protocol/didconnection/event.go | 55 + .../protocol/didconnection/event_test.go | 30 + pkg/didcomm/protocol/didconnection/keys.go | 91 + .../protocol/didconnection/keys_test.go | 132 + pkg/didcomm/protocol/didconnection/models.go | 78 + pkg/didcomm/protocol/didconnection/service.go | 845 +++++++ .../protocol/didconnection/service_test.go | 2164 +++++++++++++++++ pkg/didcomm/protocol/didconnection/states.go | 808 ++++++ .../protocol/didconnection/states_test.go | 1565 ++++++++++++ 9 files changed, 5768 insertions(+) create mode 100644 pkg/didcomm/protocol/didconnection/event.go create mode 100644 pkg/didcomm/protocol/didconnection/event_test.go create mode 100644 pkg/didcomm/protocol/didconnection/keys.go create mode 100644 pkg/didcomm/protocol/didconnection/keys_test.go create mode 100644 pkg/didcomm/protocol/didconnection/models.go create mode 100644 pkg/didcomm/protocol/didconnection/service.go create mode 100644 pkg/didcomm/protocol/didconnection/service_test.go create mode 100644 pkg/didcomm/protocol/didconnection/states.go create mode 100644 pkg/didcomm/protocol/didconnection/states_test.go diff --git a/pkg/didcomm/protocol/didconnection/event.go b/pkg/didcomm/protocol/didconnection/event.go new file mode 100644 index 0000000000..92a9e092a7 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/event.go @@ -0,0 +1,55 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +// connectionEvent implements connection.Event interface. +type connectionEvent struct { + connectionID string + invitationID string +} + +// ConnectionID returns Connection connectionID. +func (ex *connectionEvent) ConnectionID() string { + return ex.connectionID +} + +// InvitationID returns Connection invitationID. +func (ex *connectionEvent) InvitationID() string { + return ex.invitationID +} + +// connectionEventError for sending events with processing error. +type connectionEventError struct { + connectionEvent + err error +} + +// Error implements error interface. +func (ex *connectionEventError) Error() string { + if ex.err != nil { + return ex.err.Error() + } + + return "" +} + +// All implements EventProperties interface. +func (ex *connectionEvent) All() map[string]interface{} { + return map[string]interface{}{ + "connectionID": ex.ConnectionID(), + "invitationID": ex.InvitationID(), + } +} + +// All implements EventProperties interface. +func (ex *connectionEventError) All() map[string]interface{} { + return map[string]interface{}{ + "connectionID": ex.ConnectionID(), + "invitationID": ex.InvitationID(), + "error": ex.Error(), + } +} diff --git a/pkg/didcomm/protocol/didconnection/event_test.go b/pkg/didcomm/protocol/didconnection/event_test.go new file mode 100644 index 0000000000..238b178532 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/event_test.go @@ -0,0 +1,30 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnectionEvent(t *testing.T) { + ev := connectionEvent{connectionID: "abc", invitationID: "xyz"} + require.Equal(t, ev.ConnectionID(), "abc") + require.Equal(t, ev.InvitationID(), "xyz") + require.Equal(t, ev.All()["connectionID"], ev.ConnectionID()) + require.Equal(t, ev.All()["invitationID"], ev.InvitationID()) + + err := errors.New("processing error") + evErr := connectionEventError{err: err} + require.Equal(t, err.Error(), evErr.Error()) + require.Equal(t, evErr.All()["error"], evErr.Error()) + + evErr = connectionEventError{} + require.Equal(t, "", evErr.Error()) +} diff --git a/pkg/didcomm/protocol/didconnection/keys.go b/pkg/didcomm/protocol/didconnection/keys.go new file mode 100644 index 0000000000..0936bf2fb8 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/keys.go @@ -0,0 +1,91 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/json" + "fmt" + + "github.com/hyperledger/aries-framework-go/pkg/crypto" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/kms" +) + +func (ctx *context) createNewKeyAndVM(didDoc *did.Doc) error { + vm, err := ctx.createSigningVM() + if err != nil { + return err + } + + kaVM, err := ctx.createEncryptionVM() + if err != nil { + return err + } + + didDoc.VerificationMethod = append(didDoc.VerificationMethod, *vm) + // TODO is Authentication needed? + didDoc.Authentication = append(didDoc.Authentication, *did.NewReferencedVerification(vm, did.Authentication)) + // TODO is KeyAgreement needed? + didDoc.KeyAgreement = append(didDoc.KeyAgreement, *did.NewReferencedVerification(kaVM, did.KeyAgreement)) + + return nil +} + +func (ctx *context) createSigningVM() (*did.VerificationMethod, error) { + vmType := getVerMethodType(ctx.keyType) + + _, pubKeyBytes, err := ctx.kms.CreateAndExportPubKeyBytes(ctx.keyType) + if err != nil { + return nil, fmt.Errorf("createSigningVM: %w", err) + } + + vmID := "#key-1" + + switch vmType { + case ed25519VerificationKey2018: + return did.NewVerificationMethodFromBytes(vmID, vmType, "", pubKeyBytes), nil + default: + return nil, fmt.Errorf("createSigningVM: unsupported verification method: '%s'", vmType) + } +} + +func (ctx *context) createEncryptionVM() (*did.VerificationMethod, error) { + encKeyType := ctx.keyAgreementType + + vmType := getVerMethodType(encKeyType) + + _, kaPubKeyBytes, err := ctx.kms.CreateAndExportPubKeyBytes(encKeyType) + if err != nil { + return nil, fmt.Errorf("createEncryptionVM: %w", err) + } + + vmID := "#key-2" + + switch vmType { + case x25519KeyAgreementKey2019: + key := &crypto.PublicKey{} + + err = json.Unmarshal(kaPubKeyBytes, key) + if err != nil { + return nil, fmt.Errorf("createEncryptionVM: unable to unmarshal X25519 key: %w", err) + } + + return did.NewVerificationMethodFromBytes(vmID, vmType, "", key.X), nil + default: + return nil, fmt.Errorf("unsupported verification method for KeyAgreement: '%s'", vmType) + } +} + +// nolint:gochecknoglobals +var vmType = map[kms.KeyType]string{ + kms.ED25519Type: ed25519VerificationKey2018, + kms.X25519ECDHKWType: x25519KeyAgreementKey2019, +} + +func getVerMethodType(kt kms.KeyType) string { + return vmType[kt] +} diff --git a/pkg/didcomm/protocol/didconnection/keys_test.go b/pkg/didcomm/protocol/didconnection/keys_test.go new file mode 100644 index 0000000000..1fcd5f3b37 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/keys_test.go @@ -0,0 +1,132 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/kms" + "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol" + mockroute "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/mediator" + mockstorage "github.com/hyperledger/aries-framework-go/pkg/mock/storage" +) + +func TestCreateNewKeyAndVM(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + + p, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + }) + require.NoError(t, err) + + t.Run("createNewKeyAndVM success", func(t *testing.T) { + didDoc := &did.Doc{} + + p.ctx.keyType = kms.ED25519 + p.ctx.keyAgreementType = kms.X25519ECDHKWType + + err = p.ctx.createNewKeyAndVM(didDoc) + require.NoError(t, err) + require.Equal(t, ed25519VerificationKey2018, didDoc.VerificationMethod[0].Type) + require.Equal(t, x25519KeyAgreementKey2019, didDoc.KeyAgreement[0].VerificationMethod.Type) + }) + + t.Run("createNewKeyAndVM invalid keyType export signing key", func(t *testing.T) { + didDoc := &did.Doc{} + + p.ctx.keyType = kms.HMACSHA256Tag256Type // invalid signing key + p.ctx.keyAgreementType = kms.X25519ECDHKWType + + err = p.ctx.createNewKeyAndVM(didDoc) + require.EqualError(t, err, "createSigningVM: createAndExportPubKeyBytes: failed to export new public key bytes: "+ + "exportPubKeyBytes: failed to export marshalled key: exportPubKeyBytes: failed to get public keyset "+ + "handle: keyset.Handle: keyset.Handle: keyset contains a non-private key") + require.Empty(t, didDoc.VerificationMethod) + require.Empty(t, didDoc.KeyAgreement) + }) +} + +func TestCreateSigningVM(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + + p, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + }) + require.NoError(t, err) + + t.Run("createSigningVM success", func(t *testing.T) { + p.ctx.keyType = kms.ED25519 + + svm, err := p.ctx.createSigningVM() + require.NoError(t, err) + require.NotEmpty(t, svm) + }) + + t.Run("createSigningVM with empty vmType", func(t *testing.T) { + p.ctx.keyType = "" + + svm, err := p.ctx.createSigningVM() + require.EqualError(t, err, "createSigningVM: createAndExportPubKeyBytes: failed to create new key: "+ + "failed to create new key, missing key type") + require.Empty(t, svm) + }) + + t.Run("createSigningVM with unsupported keyType", func(t *testing.T) { + p.ctx.keyType = kms.X25519ECDHKW + + svm, err := p.ctx.createSigningVM() + require.EqualError(t, err, "createSigningVM: unsupported verification method: 'X25519KeyAgreementKey2019'") + require.Empty(t, svm) + }) +} + +func TestCreateEncryptionVM(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + + p, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + }) + require.NoError(t, err) + + t.Run("createEncryptionVM success", func(t *testing.T) { + p.ctx.keyAgreementType = kms.X25519ECDHKW + + evm, err := p.ctx.createEncryptionVM() + require.NoError(t, err) + require.NotEmpty(t, evm) + }) + + t.Run("createEncryptionVM with empty keyAgreementType", func(t *testing.T) { + p.ctx.keyAgreementType = "" + + evm, err := p.ctx.createEncryptionVM() + require.EqualError(t, err, "createEncryptionVM: createAndExportPubKeyBytes: failed to create new key: "+ + "failed to create new key, missing key type") + require.Empty(t, evm) + }) + + t.Run("createEncryptionVM with unsupported keyType", func(t *testing.T) { + p.ctx.keyAgreementType = kms.ED25519Type + + evm, err := p.ctx.createEncryptionVM() + require.EqualError(t, err, "unsupported verification method for KeyAgreement: 'Ed25519VerificationKey2018'") + require.Empty(t, evm) + }) +} diff --git a/pkg/didcomm/protocol/didconnection/models.go b/pkg/didcomm/protocol/didconnection/models.go new file mode 100644 index 0000000000..a38d5bc391 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/models.go @@ -0,0 +1,78 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" +) + +// Invitation model +// +// Invitation defines Connection protocol invitation message +// https://github.com/hyperledger/aries-rfcs/tree/main/features/0160-connection-protocol#0-invitation-to-connect +type Invitation struct { + // the Type of the connection invitation + Type string `json:"@type,omitempty"` + + // the ID of the connection invitation + ID string `json:"@id,omitempty"` + + // the Label of the connection invitation + Label string `json:"label,omitempty"` + + // the RecipientKeys for the connection invitation + RecipientKeys []string `json:"recipientKeys,omitempty"` + + // the Service endpoint of the connection invitation + ServiceEndpoint string `json:"serviceEndpoint,omitempty"` + + // the RoutingKeys of the connection invitation + RoutingKeys []string `json:"routingKeys,omitempty"` + + // the DID of the connection invitation + DID string `json:"did,omitempty"` +} + +// Request defines a2a Connection request +// https://github.com/hyperledger/aries-rfcs/tree/main/features/0160-connection-protocol#1-connection-request +type Request struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + Label string `json:"label"` + Thread *decorator.Thread `json:"~thread,omitempty"` + Connection *Connection `json:"connection,omitempty"` +} + +// Response defines a2a Connection response +// https://github.com/hyperledger/aries-rfcs/tree/main/features/0160-connection-protocol#2-connection-response +type Response struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + ConnectionSignature *ConnectionSignature `json:"connection~sig,omitempty"` + Thread *decorator.Thread `json:"~thread,omitempty"` + PleaseAck *PleaseAck `json:"~please_ack,omitempty"` +} + +// ConnectionSignature connection signature. +type ConnectionSignature struct { + Type string `json:"@type,omitempty"` + Signature string `json:"signature,omitempty"` + SignedData string `json:"sig_data,omitempty"` + SignVerKey string `json:"signers,omitempty"` +} + +// PleaseAck connection response accepted acknowledgement. +type PleaseAck struct { + On []string `json:"on,omitempty"` +} + +// Connection defines connection body of connection request. +type Connection struct { + DID string `json:"DID,omitempty"` + DIDDoc *did.Doc `json:"DIDDoc,omitempty"` +} diff --git a/pkg/didcomm/protocol/didconnection/service.go b/pkg/didcomm/protocol/didconnection/service.go new file mode 100644 index 0000000000..46b3b42204 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/service.go @@ -0,0 +1,845 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/google/uuid" + + "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/common/model" + "github.com/hyperledger/aries-framework-go/pkg/crypto" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr" + "github.com/hyperledger/aries-framework-go/pkg/internal/logutil" + "github.com/hyperledger/aries-framework-go/pkg/kms" + "github.com/hyperledger/aries-framework-go/pkg/store/connection" + didstore "github.com/hyperledger/aries-framework-go/pkg/store/did" + "github.com/hyperledger/aries-framework-go/pkg/vdr" + "github.com/hyperledger/aries-framework-go/spi/storage" +) + +var logger = log.New("aries-framework/connection/service") + +const ( + // DIDConnection connection protocol. + DIDConnection = "DIDConnection" + // PIURI is the connection protocol identifier URI. + PIURI = "https://didcomm.org/connection/1.0" + // InvitationMsgType defines the did-connection invite message type. + InvitationMsgType = PIURI + "/invitation" + // RequestMsgType defines the did-connection request message type. + RequestMsgType = PIURI + "/request" + // ResponseMsgType defines the did-connection response message type. + ResponseMsgType = PIURI + "/response" + // AckMsgType defines the did-connection ack message type. + AckMsgType = PIURI + "/ack" + routerConnsMetadataKey = "routerConnections" +) + +const ( + myNSPrefix = "my" + // TODO: https://github.com/hyperledger/aries-framework-go/issues/556 It will not be constant, this namespace + // will need to be figured with verification key + theirNSPrefix = "their" +) + +// message type to store data for eventing. This is retrieved during callback. +type message struct { + Msg service.DIDCommMsgMap + ThreadID string + Options *options + NextStateName string + ConnRecord *connection.Record + // err is used to determine whether callback was stopped + // e.g the user received an action event and executes Stop(err) function + // in that case `err` is equal to `err` which was passing to Stop function + err error +} + +// provider contains dependencies for the Connection protocol and is typically created by using aries.Context(). +type provider interface { + OutboundDispatcher() dispatcher.Outbound + StorageProvider() storage.Provider + ProtocolStateStorageProvider() storage.Provider + DIDConnectionStore() didstore.ConnectionStore + Crypto() crypto.Crypto + KMS() kms.KeyManager + VDRegistry() vdrapi.Registry + Service(id string) (interface{}, error) + KeyType() kms.KeyType + KeyAgreementType() kms.KeyType + MediaTypeProfiles() []string +} + +// stateMachineMsg is an internal struct used to pass data to state machine. +type stateMachineMsg struct { + service.DIDCommMsg + connRecord *connection.Record + options *options +} + +type options struct { + publicDID string + routerConnections []string + label string +} + +// Service for Connection protocol. +type Service struct { + service.Action + service.Message + ctx *context + callbackChannel chan *message + connectionRecorder *connection.Recorder + connectionStore didstore.ConnectionStore + initialized bool +} + +type context struct { + outboundDispatcher dispatcher.Outbound + crypto crypto.Crypto + kms kms.KeyManager + connectionRecorder *connection.Recorder + connectionStore didstore.ConnectionStore + vdRegistry vdrapi.Registry + routeSvc mediator.ProtocolService + doACAPyInterop bool + keyType kms.KeyType + keyAgreementType kms.KeyType + mediaTypeProfiles []string +} + +// opts are used to provide client properties to Connection service. +type opts interface { + // PublicDID allows for setting public DID + PublicDID() string + + // Label allows for setting label + Label() string + + // RouterConnections allows for setting router connections + RouterConnections() []string +} + +// New return connection service. +func New(prov provider) (*Service, error) { + svc := Service{} + + err := svc.Initialize(prov) + if err != nil { + return nil, err + } + + return &svc, nil +} + +// Initialize initializes the Service. If Initialize succeeds, any further call is a no-op. +func (s *Service) Initialize(p interface{}) error { + if s.initialized { + return nil + } + + prov, ok := p.(provider) + if !ok { + return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p) + } + + connRecorder, err := connection.NewRecorder(prov) + if err != nil { + return fmt.Errorf("failed to initialize connection recorder: %w", err) + } + + routeSvcBase, err := prov.Service(mediator.Coordination) + if err != nil { + return err + } + + routeSvc, ok := routeSvcBase.(mediator.ProtocolService) + if !ok { + return errors.New("cast service to Route Service failed") + } + + const callbackChannelSize = 10 + + keyType := kms.ED25519Type + + keyAgreementType := kms.X25519ECDHKWType + + mediaTypeProfiles := []string{transport.MediaTypeProfileDIDCommAIP1} + + s.ctx = &context{ + outboundDispatcher: prov.OutboundDispatcher(), + crypto: prov.Crypto(), + kms: prov.KMS(), + vdRegistry: prov.VDRegistry(), + connectionRecorder: connRecorder, + connectionStore: prov.DIDConnectionStore(), + routeSvc: routeSvc, + keyType: keyType, + keyAgreementType: keyAgreementType, + mediaTypeProfiles: mediaTypeProfiles, + } + + // TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246 + s.callbackChannel = make(chan *message, callbackChannelSize) + s.connectionRecorder = connRecorder + s.connectionStore = prov.DIDConnectionStore() + + // start the listener + go s.startInternalListener() + + s.initialized = true + + return nil +} + +func retrievingRouterConnections(msg service.DIDCommMsg) []string { + raw, found := msg.Metadata()[routerConnsMetadataKey] + if !found { + return nil + } + + connections, ok := raw.([]string) + if !ok { + return nil + } + + return connections +} + +// HandleInbound handles inbound connection messages. +func (s *Service) HandleInbound(msg service.DIDCommMsg, _ service.DIDCommContext) (string, error) { + logger.Debugf("receive inbound message : %s", msg) + + // fetch the thread id + thID, err := msg.ThreadID() + if err != nil { + return "", err + } + + // valid state transition and get the next state + next, err := s.nextState(msg.Type(), thID) + if err != nil { + return "", fmt.Errorf("handle inbound - next state : %w", err) + } + + // connection record + connRecord, err := s.connectionRecord(msg) + if err != nil { + return "", fmt.Errorf("failed to fetch connection record : %w", err) + } + + logger.Debugf("connection record: %+v", connRecord) + + internalMsg := &message{ + Options: &options{routerConnections: retrievingRouterConnections(msg)}, + Msg: msg.Clone(), + ThreadID: thID, + NextStateName: next.Name(), + ConnRecord: connRecord, + } + + go func(msg *message, aEvent chan<- service.DIDCommAction) { + if err = s.handle(msg, aEvent); err != nil { + logutil.LogError(logger, DIDConnection, "processMessage", err.Error(), + logutil.CreateKeyValueString("msgType", msg.Msg.Type()), + logutil.CreateKeyValueString("msgID", msg.Msg.ID()), + logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID)) + } + + logutil.LogDebug(logger, DIDConnection, "processMessage", "success", + logutil.CreateKeyValueString("msgType", msg.Msg.Type()), + logutil.CreateKeyValueString("msgID", msg.Msg.ID()), + logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID)) + }(internalMsg, s.ActionEvent()) + + logutil.LogDebug(logger, DIDConnection, "handleInbound", "success", + logutil.CreateKeyValueString("msgType", msg.Type()), + logutil.CreateKeyValueString("msgID", msg.ID()), + logutil.CreateKeyValueString("connectionID", internalMsg.ConnRecord.ConnectionID)) + + return connRecord.ConnectionID, nil +} + +// Name return service name. +func (s *Service) Name() string { + return DIDConnection +} + +func findNamespace(msgType string) string { + namespace := theirNSPrefix + if msgType == InvitationMsgType || msgType == ResponseMsgType { + namespace = myNSPrefix + } + + return namespace +} + +// Accept msg checks the msg type. +func (s *Service) Accept(msgType string) bool { + return msgType == InvitationMsgType || + msgType == RequestMsgType || + msgType == ResponseMsgType || + msgType == AckMsgType +} + +// HandleOutbound handles outbound connection messages. +func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) { + return "", errors.New("not implemented") +} + +func (s *Service) nextState(msgType, thID string) (state, error) { + logger.Debugf("msgType=%s thID=%s", msgType, thID) + + nsThID, err := connection.CreateNamespaceKey(findNamespace(msgType), thID) + if err != nil { + return nil, err + } + + current, err := s.currentState(nsThID) + if err != nil { + return nil, err + } + + logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID) + + next, err := stateFromMsgType(msgType) + if err != nil { + return nil, err + } + + logger.Debugf("check if current state [%s] can transition to [%s]", current.Name(), next.Name()) + + if !current.CanTransitionTo(next) { + return nil, fmt.Errorf("invalid state transition: %s -> %s", current.Name(), next.Name()) + } + + return next, nil +} + +func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) error { //nolint:funlen,gocyclo + logger.Debugf("handling msg: %+v", msg) + + next, err := stateFromName(msg.NextStateName) + if err != nil { + return fmt.Errorf("invalid state name: %w", err) + } + + for !isNoOp(next) { + s.sendMsgEvents(&service.StateMsg{ + ProtocolName: DIDConnection, + Type: service.PreState, + Msg: msg.Msg.Clone(), + StateID: next.Name(), + Properties: createEventProperties(msg.ConnRecord.ConnectionID, msg.ConnRecord.InvitationID), + }) + logger.Debugf("sent pre event for state %s", next.Name()) + + var ( + action stateAction + followup state + connectionRecord *connection.Record + ) + + connectionRecord, followup, action, err = next.ExecuteInbound( + &stateMachineMsg{ + DIDCommMsg: msg.Msg, + connRecord: msg.ConnRecord, + options: msg.Options, + }, + msg.ThreadID, + s.ctx) + + if err != nil { + return fmt.Errorf("failed to execute state '%s': %w", next.Name(), err) + } + + connectionRecord.State = next.Name() + logger.Debugf("finished execute state: %s", next.Name()) + + if err = s.update(msg.Msg.Type(), connectionRecord); err != nil { + return fmt.Errorf("failed to persist state '%s': %w", next.Name(), err) + } + + if connectionRecord.State == StateIDCompleted { + err = s.connectionStore.SaveDIDByResolving(connectionRecord.TheirDID, connectionRecord.RecipientKeys...) + if err != nil { + return fmt.Errorf("save theirDID: %w", err) + } + } + + if err = action(); err != nil { + return fmt.Errorf("failed to execute state action '%s': %w", next.Name(), err) + } + + logger.Debugf("finish execute state action: '%s'", next.Name()) + + prev := next + next = followup + haltExecution := false + + // trigger action event based on message type for inbound messages + if canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) { + logger.Debugf("action event triggered for msg type: %s", msg.Msg.Type()) + + msg.NextStateName = next.Name() + if err = s.sendActionEvent(msg, aEvent); err != nil { + return fmt.Errorf("handle inbound: %w", err) + } + + haltExecution = true + } + + s.sendMsgEvents(&service.StateMsg{ + ProtocolName: DIDConnection, + Type: service.PostState, + Msg: msg.Msg.Clone(), + StateID: prev.Name(), + Properties: createEventProperties(connectionRecord.ConnectionID, connectionRecord.InvitationID), + }) + logger.Debugf("sent post event for state %s", prev.Name()) + + if haltExecution { + logger.Debugf("halted execution before state=%s", msg.NextStateName) + + break + } + } + + return nil +} + +func (s *Service) handleWithoutAction(msg *message) error { + return s.handle(msg, nil) +} + +func createEventProperties(connectionID, invitationID string) *connectionEvent { + return &connectionEvent{ + connectionID: connectionID, + invitationID: invitationID, + } +} + +// sendActionEvent triggers the action event. This function stores the state of current processing and passes a callback +// function in the event message. +func (s *Service) sendActionEvent(internalMsg *message, aEvent chan<- service.DIDCommAction) error { + // save data to support AcceptConnectionRequest APIs (when client will not be able to invoke the callback function) + err := s.storeEventProtocolStateData(internalMsg) + if err != nil { + return fmt.Errorf("send action event : %w", err) + } + + if aEvent != nil { + // trigger action event + aEvent <- service.DIDCommAction{ + ProtocolName: DIDConnection, + Message: internalMsg.Msg.Clone(), + Continue: func(args interface{}) { + switch v := args.(type) { + case opts: + internalMsg.Options = &options{ + publicDID: v.PublicDID(), + label: v.Label(), + routerConnections: v.RouterConnections(), + } + default: + // nothing to do + } + + s.processCallback(internalMsg) + }, + Stop: func(err error) { + // sets an error to the message + internalMsg.err = err + s.processCallback(internalMsg) + }, + Properties: createEventProperties(internalMsg.ConnRecord.ConnectionID, internalMsg.ConnRecord.InvitationID), + } + + logger.Debugf("dispatched action for msg: %+v", internalMsg.Msg) + } + + return nil +} + +// sendEvent triggers the message events. +func (s *Service) sendMsgEvents(msg *service.StateMsg) { + // trigger the message events + for _, handler := range s.MsgEvents() { + handler <- *msg + + logger.Debugf("sent msg event to handler: %+v", msg) + } +} + +// startInternalListener listens to messages in gochannel for callback messages from clients. +func (s *Service) startInternalListener() { + for msg := range s.callbackChannel { + // TODO https://github.com/hyperledger/aries-framework-go/issues/242 - retry logic + // if no error - do handle + if msg.err == nil { + msg.err = s.handleWithoutAction(msg) + } + + // no error - continue + if msg.err == nil { + continue + } + } +} + +// AcceptInvitation accepts/approves connection invitation. +func (s *Service) AcceptInvitation(connectionID, publicDID, label string, routerConnections []string) error { + return s.accept(connectionID, publicDID, label, StateIDInvited, + "accept connection invitation", routerConnections) +} + +// AcceptConnectionRequest accepts/approves connection request. +func (s *Service) AcceptConnectionRequest(connectionID, publicDID, label string, routerConnections []string) error { + return s.accept(connectionID, publicDID, label, StateIDRequested, + "accept exchange request", routerConnections) +} + +func (s *Service) accept(connectionID, publicDID, label, stateID, errMsg string, routerConnections []string) error { + msg, err := s.getEventProtocolStateData(connectionID) + if err != nil { + return fmt.Errorf("failed to accept invitation for connectionID=%s : %s : %w", connectionID, errMsg, err) + } + + connRecord, err := s.connectionRecorder.GetConnectionRecord(connectionID) + if err != nil { + return fmt.Errorf("%s : %w", errMsg, err) + } + + if connRecord.State != stateID { + return fmt.Errorf("current state (%s) is different from "+ + "expected state (%s)", connRecord.State, stateID) + } + + msg.Options = &options{publicDID: publicDID, label: label, routerConnections: routerConnections} + + return s.handleWithoutAction(msg) +} + +func (s *Service) storeEventProtocolStateData(msg *message) error { + bytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("store protocol state data : %w", err) + } + + return s.connectionRecorder.SaveEvent(msg.ConnRecord.ConnectionID, bytes) +} + +func (s *Service) getEventProtocolStateData(connectionID string) (*message, error) { + val, err := s.connectionRecorder.GetEvent(connectionID) + if err != nil { + return nil, fmt.Errorf("get protocol state data : %w", err) + } + + msg := &message{} + + err = json.Unmarshal(val, msg) + if err != nil { + return nil, fmt.Errorf("get protocol state data : %w", err) + } + + return msg, nil +} + +func (s *Service) processCallback(msg *message) { + // pass the callback data to internal channel. This is created to unblock consumer go routine and wrap the callback + // channel internally. + s.callbackChannel <- msg +} + +func isNoOp(s state) bool { + _, ok := s.(*noOp) + return ok +} + +func (s *Service) currentState(nsThID string) (state, error) { + connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID) + if err != nil { + if errors.Is(err, storage.ErrDataNotFound) { + return &null{}, nil + } + + return nil, fmt.Errorf("cannot fetch state from store: thID=%s err=%w", nsThID, err) + } + + return stateFromName(connRec.State) +} + +func (s *Service) update(msgType string, record *connection.Record) error { + if (msgType == RequestMsgType && record.State == StateIDRequested) || + (msgType == InvitationMsgType && record.State == StateIDInvited) { + return s.connectionRecorder.SaveConnectionRecordWithMappings(record) + } + + return s.connectionRecorder.SaveConnectionRecord(record) +} + +// CreateConnection saves the record to the connection store and maps TheirDID to their recipient keys in +// the did connection store. +func (s *Service) CreateConnection(record *connection.Record, theirDID *did.Doc) error { + logger.Debugf("creating connection using record [%+v] and theirDID [%+v]", record, theirDID) + + didMethod, err := vdr.GetDidMethod(theirDID.ID) + if err != nil { + return err + } + + _, err = s.ctx.vdRegistry.Create(didMethod, theirDID, vdrapi.WithOption("store", true)) + if err != nil { + return fmt.Errorf("vdr failed to store theirDID : %w", err) + } + + err = s.connectionStore.SaveDIDFromDoc(theirDID) + if err != nil { + return fmt.Errorf("failed to save theirDID to the did.ConnectionStore: %w", err) + } + + err = s.connectionStore.SaveDIDByResolving(record.MyDID) + if err != nil { + return fmt.Errorf("failed to save myDID to the did.ConnectionStore: %w", err) + } + + if isDIDCommV2(record.MediaTypeProfiles) { + record.DIDCommVersion = service.V2 + } else { + record.DIDCommVersion = service.V1 + } + + return s.connectionRecorder.SaveConnectionRecord(record) +} + +func (s *Service) connectionRecord(msg service.DIDCommMsg) (*connection.Record, error) { + switch msg.Type() { + case InvitationMsgType: + return s.invitationMsgRecord(msg) + case RequestMsgType: + return s.requestMsgRecord(msg) + case ResponseMsgType: + return s.responseMsgRecord(msg) + case AckMsgType: + return s.fetchConnectionRecord(theirNSPrefix, msg) + } + + return nil, errors.New("invalid message type") +} + +func (s *Service) invitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) { + thID, msgErr := msg.ThreadID() + if msgErr != nil { + return nil, msgErr + } + + invitation := &Invitation{} + + err := msg.Decode(invitation) + if err != nil { + return nil, err + } + + recKey, err := s.ctx.getInvitationRecipientKey(invitation) + if err != nil { + return nil, err + } + + connRecord := &connection.Record{ + ConnectionID: generateRandomID(), + ThreadID: thID, + State: stateNameNull, + InvitationID: invitation.ID, + InvitationDID: invitation.DID, + ServiceEndPoint: model.NewDIDCommV1Endpoint(invitation.ServiceEndpoint), + RecipientKeys: []string{recKey}, + TheirLabel: invitation.Label, + Namespace: findNamespace(msg.Type()), + DIDCommVersion: service.V1, + } + + if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil { + return nil, err + } + + return connRecord, nil +} + +func (s *Service) requestMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) { + request := Request{} + + err := msg.Decode(&request) + if err != nil { + return nil, fmt.Errorf("unmarshalling failed: %w", err) + } + + invitationID := msg.ParentThreadID() + if invitationID == "" { + return nil, fmt.Errorf("missing parent thread ID on connection request with @id=%s", request.ID) + } + + if request.Connection == nil { + return nil, fmt.Errorf("missing connection field on connection request with @id=%s", request.ID) + } + + connRecord := &connection.Record{ + TheirLabel: request.Label, + ConnectionID: generateRandomID(), + ThreadID: request.ID, + State: stateNameNull, + TheirDID: request.Connection.DID, + InvitationID: invitationID, + Namespace: theirNSPrefix, + DIDCommVersion: service.V1, + } + + if !strings.HasPrefix(connRecord.TheirDID, "did") { + connRecord.TheirDID = "did:peer:" + connRecord.TheirDID + } + + if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil { + return nil, err + } + + return connRecord, nil +} + +func (s *Service) responseMsgRecord(payload service.DIDCommMsg) (*connection.Record, error) { + return s.fetchConnectionRecord(myNSPrefix, payload) +} + +func (s *Service) fetchConnectionRecord(nsPrefix string, payload service.DIDCommMsg) (*connection.Record, error) { + msg := &struct { + Thread decorator.Thread `json:"~thread,omitempty"` + }{} + + err := payload.Decode(msg) + if err != nil { + return nil, err + } + + key, err := connection.CreateNamespaceKey(nsPrefix, msg.Thread.ID) + if err != nil { + return nil, err + } + + return s.connectionRecorder.GetConnectionRecordByNSThreadID(key) +} + +func generateRandomID() string { + return uuid.New().String() +} + +// canTriggerActionEvents true based on role and state. +// 1. Role is invitee and state is invited. +// 2. Role is inviter and state is requested. +func canTriggerActionEvents(stateID, ns string) bool { + return (stateID == StateIDInvited && ns == myNSPrefix) || (stateID == StateIDRequested && ns == theirNSPrefix) +} + +// CreateImplicitInvitation creates implicit invitation. Inviter DID is required, invitee DID is optional. +// If invitee DID is not provided new peer DID will be created for implicit invitation connection request. +//nolint:funlen +func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID, + inviteeLabel, inviteeDID string, routerConnections []string) (string, error) { + logger.Debugf("implicit invitation requested inviterDID[%s] inviteeDID[%s]", inviterDID, inviteeDID) + + docResolution, err := s.ctx.vdRegistry.Resolve(inviterDID) + if err != nil { + return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err) + } + + dest, err := service.CreateDestination(docResolution.DIDDocument) + if err != nil { + return "", err + } + + thID := generateRandomID() + + var connRecord *connection.Record + + if accept, e := dest.ServiceEndpoint.Accept(); e == nil && isDIDCommV2(accept) { + connRecord = &connection.Record{ + ConnectionID: generateRandomID(), + ThreadID: thID, + State: stateNameNull, + InvitationDID: inviterDID, + Implicit: true, + ServiceEndPoint: dest.ServiceEndpoint, + RecipientKeys: dest.RecipientKeys, + TheirLabel: inviterLabel, + Namespace: findNamespace(InvitationMsgType), + } + } else { + connRecord = &connection.Record{ + ConnectionID: generateRandomID(), + ThreadID: thID, + State: stateNameNull, + InvitationDID: inviterDID, + Implicit: true, + ServiceEndPoint: dest.ServiceEndpoint, + RecipientKeys: dest.RecipientKeys, + RoutingKeys: dest.RoutingKeys, + MediaTypeProfiles: dest.MediaTypeProfiles, + TheirLabel: inviterLabel, + Namespace: findNamespace(InvitationMsgType), + } + } + + if e := s.connectionRecorder.SaveConnectionRecordWithMappings(connRecord); e != nil { + return "", fmt.Errorf("failed to save new connection record for implicit invitation: %w", e) + } + + invitation := &Invitation{ + ID: uuid.New().String(), + Label: inviterLabel, + DID: inviterDID, + Type: InvitationMsgType, + } + + msg, err := createDIDCommMsg(invitation) + if err != nil { + return "", fmt.Errorf("failed to create DIDCommMsg for implicit invitation: %w", err) + } + + next := &requested{} + internalMsg := &message{ + Msg: msg.Clone(), + ThreadID: thID, + NextStateName: next.Name(), + ConnRecord: connRecord, + } + internalMsg.Options = &options{publicDID: inviteeDID, label: inviteeLabel, routerConnections: routerConnections} + + go func(msg *message, aEvent chan<- service.DIDCommAction) { + if err = s.handle(msg, aEvent); err != nil { + logger.Errorf("error from handle for implicit invitation: %s", err) + } + }(internalMsg, s.ActionEvent()) + + return connRecord.ConnectionID, nil +} + +func createDIDCommMsg(invitation *Invitation) (service.DIDCommMsg, error) { + payload, err := json.Marshal(invitation) + if err != nil { + return nil, fmt.Errorf("marshal invitation: %w", err) + } + + return service.ParseDIDCommMsgMap(payload) +} diff --git a/pkg/didcomm/protocol/didconnection/service_test.go b/pkg/didcomm/protocol/didconnection/service_test.go new file mode 100644 index 0000000000..2a546857fd --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/service_test.go @@ -0,0 +1,2164 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/btcsuite/btcutil/base58" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + commonmodel "github.com/hyperledger/aries-framework-go/pkg/common/model" + "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr" + "github.com/hyperledger/aries-framework-go/pkg/kms" + "github.com/hyperledger/aries-framework-go/pkg/kms/localkms" + "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol" + mockroute "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/mediator" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/mock/diddoc" + mockkms "github.com/hyperledger/aries-framework-go/pkg/mock/kms" + mockprovider "github.com/hyperledger/aries-framework-go/pkg/mock/provider" + mockstorage "github.com/hyperledger/aries-framework-go/pkg/mock/storage" + mockvdr "github.com/hyperledger/aries-framework-go/pkg/mock/vdr" + "github.com/hyperledger/aries-framework-go/pkg/secretlock/noop" + "github.com/hyperledger/aries-framework-go/pkg/store/connection" + didstore "github.com/hyperledger/aries-framework-go/pkg/store/did" + "github.com/hyperledger/aries-framework-go/pkg/vdr/peer" + "github.com/hyperledger/aries-framework-go/spi/storage" +) + +const ( + testMethod = "peer" + threadIDValue = "xyz" +) + +type event interface { + // connection ID + ConnectionID() string + + // invitation ID + InvitationID() string +} + +func TestService_Name(t *testing.T) { + t.Run("test success", func(t *testing.T) { + prov, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + require.Equal(t, DIDConnection, prov.Name()) + }) +} + +func TestServiceNew(t *testing.T) { + t.Run("test error from open store", func(t *testing.T) { + _, err := New( + &protocol.MockProvider{StoreProvider: &mockstorage.MockStoreProvider{ + ErrOpenStoreHandle: fmt.Errorf("failed to open store"), + }}) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to open store") + }) + + t.Run("test error from open protocol state store", func(t *testing.T) { + _, err := New( + &protocol.MockProvider{ProtocolStateStoreProvider: &mockstorage.MockStoreProvider{ + ErrOpenStoreHandle: fmt.Errorf("failed to open protocol state store"), + }}) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to open protocol state store") + }) + + t.Run("test service new error - no route service found", func(t *testing.T) { + _, err := New(&protocol.MockProvider{ServiceErr: errors.New("service not found")}) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found") + }) + + t.Run("test service new error - casting to route service failed", func(t *testing.T) { + _, err := New(&protocol.MockProvider{}) + require.Error(t, err) + require.Contains(t, err.Error(), "cast service to Route Service failed") + }) +} + +func TestService_Initialize(t *testing.T) { + t.Run("success: already initialized", func(t *testing.T) { + prov := &protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + } + + svc, err := New(prov) + require.NoError(t, err) + + require.NoError(t, svc.Initialize(prov)) + }) + + t.Run("fail: provider of wrong type", func(t *testing.T) { + prov := "this is not a provider" + + svc := Service{} + + err := svc.Initialize(prov) + + require.Error(t, err) + require.Contains(t, err.Error(), "expected provider of type") + }) +} + +// connection flow with role Inviter. +func TestService_Handle_Inviter(t *testing.T) { + mockStore := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)} + storeProv := mockstorage.NewCustomMockStoreProvider(mockStore) + k := newKMS(t, storeProv) + prov := &protocol.MockProvider{ + StoreProvider: storeProv, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: kms.ED25519Type, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + } + + ctx := &context{ + outboundDispatcher: prov.OutboundDispatcher(), + crypto: &tinkcrypto.Crypto{}, + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDocWithKey(pubKey)} + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + ctx.connectionRecorder = connRec + + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + + s, err := New(prov) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = s.RegisterActionEvent(actionCh) + require.NoError(t, err) + + statusCh := make(chan service.StateMsg, 10) + err = s.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + completedFlag := make(chan struct{}) + respondedFlag := make(chan struct{}) + + go msgEventListener(t, statusCh, respondedFlag, completedFlag) + + go func() { service.AutoExecuteActionEvent(actionCh) }() + + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(pubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + thid := randomString() + + // Invitation was previously sent by Alice to Bob. + // Bob now sends a connection Request + payloadBytes, err := json.Marshal( + &Request{ + Type: RequestMsgType, + ID: thid, + Label: "Bob", + Thread: &decorator.Thread{ + PID: invitation.ID, + }, + Connection: &Connection{ + DID: doc.DIDDocument.ID, + DIDDoc: doc.DIDDocument, + }, + }) + require.NoError(t, err) + msg, err := service.ParseDIDCommMsgMap(payloadBytes) + require.NoError(t, err) + _, err = s.HandleInbound(msg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + select { + case <-respondedFlag: + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive post event responded") + } + // Alice automatically sends connection Response to Bob + // Bob replies with an ACK + payloadBytes, err = json.Marshal( + &model.Ack{ + Type: AckMsgType, + ID: randomString(), + Status: "OK", + Thread: &decorator.Thread{ID: thid}, + }) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(payloadBytes) + require.NoError(t, err) + + _, err = s.HandleInbound(didMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + select { + case <-completedFlag: + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive post event complete") + } + + validateState(t, s, thid, findNamespace(AckMsgType), (&completed{}).Name()) +} + +func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) { + for e := range statusCh { + require.Equal(t, DIDConnection, e.ProtocolName) + + prop, ok := e.Properties.(event) + if !ok { + require.Fail(t, "Failed to cast the event properties to service.Event") + } + // Get the connectionID when it's created + if e.Type == service.PreState { + if e.StateID == "requested" { + require.NotNil(t, prop.ConnectionID()) + require.NotNil(t, prop.InvitationID()) + } + } + + if e.Type == service.PostState { + // receive the events + if e.StateID == "completed" { + // validate connectionID received during state transition with original connectionID + require.NotNil(t, prop.ConnectionID()) + require.NotNil(t, prop.InvitationID()) + close(completedFlag) + } + + if e.StateID == "responded" { + // validate connectionID received during state transition with original connectionID + require.NotNil(t, prop.ConnectionID()) + require.NotNil(t, prop.InvitationID()) + close(respondedFlag) + } + } + } +} + +func newKMS(t *testing.T, store storage.Provider) kms.KeyManager { + t.Helper() + + kmsProv := &protocol.MockProvider{ + StoreProvider: store, + CustomLock: &noop.NoLock{}, + } + + customKMS, err := localkms.New("local-lock://primary/test/", kmsProv) + require.NoError(t, err) + + return customKMS +} + +// connection flow with role Invitee. +func TestService_Handle_Invitee(t *testing.T) { + protocolStateStore := mockstorage.NewMockStoreProvider() + store := mockstorage.NewMockStoreProvider() + k := newKMS(t, store) + prov := &protocol.MockProvider{ + StoreProvider: store, + ProtocolStateStoreProvider: protocolStateStore, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: kms.ED25519Type, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + } + + mtp := transport.MediaTypeRFC0019EncryptedEnvelope + + ctx := &context{ + outboundDispatcher: prov.OutboundDispatcher(), + crypto: &tinkcrypto.Crypto{}, + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{mtp}, + } + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDocWithKey(verPubKey)} + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + ctx.connectionRecorder = connRec + + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + + s, err := New(prov) + require.NoError(t, err) + + s.ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: doc.DIDDocument} + actionCh := make(chan service.DIDCommAction, 10) + err = s.RegisterActionEvent(actionCh) + require.NoError(t, err) + + statusCh := make(chan service.StateMsg, 10) + err = s.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + requestedCh := make(chan string) + completedCh := make(chan struct{}) + + go handleMessagesInvitee(statusCh, requestedCh, completedCh) + + go func() { service.AutoExecuteActionEvent(actionCh) }() + + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(verPubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + // Alice receives an invitation from Bob + payloadBytes, err := json.Marshal(invitation) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(payloadBytes) + require.NoError(t, err) + + _, err = s.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + + var connID string + select { + case connID = <-requestedCh: + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive post event requested") + } + + // Alice automatically sends a Request to Bob and is now in REQUESTED state. + connRecord, err := s.connectionRecorder.GetConnectionRecord(connID) + require.NoError(t, err) + require.Equal(t, (&requested{}).Name(), connRecord.State) + require.Equal(t, invitation.ID, connRecord.InvitationID) + require.Equal(t, invitation.RecipientKeys, connRecord.RecipientKeys) + uri, err := connRecord.ServiceEndPoint.URI() + require.NoError(t, err) + require.Equal(t, invitation.ServiceEndpoint, uri) + + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, invitation.ID) + require.NoError(t, err) + + // Bob replies with a Response + payloadBytes, err = json.Marshal( + &Response{ + Type: ResponseMsgType, + ID: randomString(), + Thread: &decorator.Thread{ + ID: connRecord.ThreadID, + }, + ConnectionSignature: connectionSignature, + PleaseAck: &PleaseAck{ + On: []string{PlsAckOnReceipt}, + }, + }, + ) + require.NoError(t, err) + + didMsg, err = service.ParseDIDCommMsgMap(payloadBytes) + require.NoError(t, err) + + _, err = s.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + + // Alice automatically sends an ACK to Bob + // Alice must now be in COMPLETED state + select { + case <-completedCh: + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive post event complete") + } + + validateState(t, s, connRecord.ThreadID, findNamespace(ResponseMsgType), (&completed{}).Name()) +} + +func handleMessagesInvitee(statusCh chan service.StateMsg, requestedCh chan string, completedCh chan struct{}) { + for e := range statusCh { + if e.Type == service.PostState { + // receive the events + if e.StateID == StateIDCompleted { + close(completedCh) + } else if e.StateID == StateIDRequested { + prop, ok := e.Properties.(event) + if !ok { + panic("Failed to cast the event properties to service.Event") + } + + requestedCh <- prop.ConnectionID() + } + } + } +} + +func TestService_Handle_EdgeCases(t *testing.T) { + t.Run("handleInbound - must not transition to same state", func(t *testing.T) { + s, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = s.RegisterActionEvent(make(chan service.DIDCommAction)) + require.NoError(t, err) + + response, err := json.Marshal( + &Response{ + Type: ResponseMsgType, + ID: randomString(), + Thread: &decorator.Thread{ID: randomString()}, + }, + ) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(response) + require.NoError(t, err) + + _, err = s.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.Error(t, err) + require.Contains(t, err.Error(), "handle inbound - next state : invalid state transition: "+ + "null -> responded") + }) + + t.Run("handleInbound - connection record error", func(t *testing.T) { + protocolStateStore := &mockstorage.MockStore{ + Store: make(map[string]mockstorage.DBEntry), + ErrPut: errors.New("db error"), + } + prov := &protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider(protocolStateStore), + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + } + svc, err := New(prov) + require.NoError(t, err) + + err = svc.RegisterActionEvent(make(chan service.DIDCommAction)) + require.NoError(t, err) + + svc.connectionRecorder, err = connection.NewRecorder(prov) + require.NotNil(t, svc.connectionRecorder) + require.NoError(t, err) + + _, err = svc.HandleInbound( + generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), randomString()), + service.EmptyDIDCommContext()) + require.Error(t, err) + require.Contains(t, err.Error(), "save connection record") + }) + + t.Run("handleInbound - no error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = svc.RegisterActionEvent(make(chan service.DIDCommAction)) + require.NoError(t, err) + + protocolStateStore := &mockStore{ + get: func(s string) (bytes []byte, e error) { + return nil, storage.ErrDataNotFound + }, + put: func(s string, bytes []byte, tags ...storage.Tag) error { + if strings.Contains(s, "didex-event-") { + return errors.New("db error") + } + + return nil + }, + } + + svc.connectionRecorder, err = connection.NewRecorder(&protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider(protocolStateStore), + }) + require.NotNil(t, svc.connectionRecorder) + require.NoError(t, err) + + requestBytes, err := json.Marshal(&Request{ + Type: RequestMsgType, + ID: generateRandomID(), + Connection: &Connection{ + DID: "xyz", + }, + Thread: &decorator.Thread{ + PID: randomString(), + }, + }) + require.NoError(t, err) + + // send invite + didMsg, err := service.ParseDIDCommMsgMap(requestBytes) + require.NoError(t, err) + + _, err = svc.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + }) +} + +func TestService_Accept(t *testing.T) { + s := &Service{} + + require.Equal(t, true, s.Accept("https://didcomm.org/connection/1.0/invitation")) + require.Equal(t, true, s.Accept("https://didcomm.org/connection/1.0/request")) + require.Equal(t, true, s.Accept("https://didcomm.org/connection/1.0/response")) + require.Equal(t, true, s.Accept("https://didcomm.org/connection/1.0/ack")) + require.Equal(t, false, s.Accept("unsupported msg type")) +} + +func TestService_CurrentState(t *testing.T) { + t.Run("null state if not found in store", func(t *testing.T) { + connRec, err := connection.NewRecorder(&protocol.MockProvider{ + StoreProvider: mockstorage.NewCustomMockStoreProvider(&mockStore{ + get: func(string) ([]byte, error) { return nil, storage.ErrDataNotFound }, + }), + }) + require.NotNil(t, connRec) + require.NoError(t, err) + + svc := &Service{ + connectionRecorder: connRec, + } + thid, err := connection.CreateNamespaceKey(theirNSPrefix, "ignored") + require.NoError(t, err) + s, err := svc.currentState(thid) + require.NoError(t, err) + require.Equal(t, (&null{}).Name(), s.Name()) + }) + + t.Run("returns state from store", func(t *testing.T) { + expected := &requested{} + connRecord, err := json.Marshal(&connection.Record{State: expected.Name()}) + require.NoError(t, err) + + connRec, err := connection.NewRecorder(&protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider(&mockStore{ + get: func(string) ([]byte, error) { return connRecord, nil }, + }), + }) + require.NotNil(t, connRec) + require.NoError(t, err) + + svc := &Service{ + connectionRecorder: connRec, + } + thid, err := connection.CreateNamespaceKey(theirNSPrefix, "ignored") + require.NoError(t, err) + actual, err := svc.currentState(thid) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + + t.Run("forwards generic error from store", func(t *testing.T) { + connRec, err := connection.NewRecorder(&protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider(&mockStore{ + get: func(string) ([]byte, error) { + return nil, errors.New("test") + }, + }), + }) + require.NotNil(t, connRec) + require.NoError(t, err) + + svc := &Service{connectionRecorder: connRec} + thid, err := connection.CreateNamespaceKey(theirNSPrefix, "ignored") + require.NoError(t, err) + _, err = svc.currentState(thid) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot fetch state from store") + }) +} + +func TestService_Update(t *testing.T) { + s := &requested{} + data := make(map[string][]byte) + connRecord := &connection.Record{ + ThreadID: "123", ConnectionID: "123456", State: s.Name(), + Namespace: findNamespace(RequestMsgType), + } + bytes, err := json.Marshal(connRecord) + require.NoError(t, err) + + connRec, err := connection.NewRecorder(&protocol.MockProvider{ + StoreProvider: mockstorage.NewCustomMockStoreProvider(&mockStore{ + put: func(k string, v []byte, tags ...storage.Tag) error { + data[k] = bytes + return nil + }, + get: func(k string) ([]byte, error) { + return bytes, nil + }, + }), + }) + require.NotNil(t, connRec) + require.NoError(t, err) + + svc := &Service{connectionRecorder: connRec} + + require.NoError(t, svc.update(RequestMsgType, connRecord)) + + cr := &connection.Record{} + err = json.Unmarshal(bytes, cr) + require.NoError(t, err) + require.Equal(t, cr, connRecord) +} + +func TestCreateConnection(t *testing.T) { + store := mockstorage.NewMockStoreProvider() + k := newKMS(t, store) + + t.Run("create connection", func(t *testing.T) { + theirDID := newPeerDID(t, k) + record := &connection.Record{ + ConnectionID: uuid.New().String(), + State: StateIDCompleted, + ThreadID: uuid.New().String(), + ParentThreadID: uuid.New().String(), + TheirLabel: uuid.New().String(), + TheirDID: theirDID.ID, + MyDID: newPeerDID(t, k).ID, + ServiceEndPoint: commonmodel.NewDIDCommV1Endpoint("http://example.com"), + RecipientKeys: []string{"testkeys"}, + InvitationID: uuid.New().String(), + Namespace: myNSPrefix, + } + storedInVDR := false + storageProvider := &mockprovider.Provider{ + StorageProviderValue: mockstorage.NewMockStoreProvider(), + ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), + } + provider := &mockprovider.Provider{ + KMSValue: &mockkms.KeyManager{}, + StorageProviderValue: storageProvider.StorageProvider(), + ProtocolStateStorageProviderValue: storageProvider.ProtocolStateStorageProvider(), + VDRegistryValue: &mockvdr.MockVDRegistry{ + CreateFunc: func(method string, result *did.Doc, _ ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { + storedInVDR = true + require.Equal(t, theirDID, result) + + return nil, nil + }, + }, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + DIDConnectionStoreValue: &mockConnectionStore{}, + } + s, err := New(provider) + require.NoError(t, err) + + err = s.CreateConnection(record, theirDID) + require.True(t, storedInVDR) + require.NoError(t, err) + + connRec, err := connection.NewRecorder(provider) + require.NoError(t, err) + result, err := connRec.GetConnectionRecord(record.ConnectionID) + require.NoError(t, err) + require.Equal(t, record, result) + }) + + t.Run("wraps vdr registry error", func(t *testing.T) { + expected := errors.New("test") + s, err := New(&mockprovider.Provider{ + KMSValue: &mockkms.KeyManager{}, + StorageProviderValue: mockstorage.NewMockStoreProvider(), + ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), + VDRegistryValue: &mockvdr.MockVDRegistry{ + CreateFunc: func(s string, doc *did.Doc, option ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { + return nil, expected + }, + }, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = s.CreateConnection(&connection.Record{}, newPeerDID(t, k)) + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) + + t.Run("wraps connection store error", func(t *testing.T) { + expected := errors.New("test") + s, err := New(&mockprovider.Provider{ + KMSValue: &mockkms.KeyManager{}, + StorageProviderValue: &mockstorage.MockStoreProvider{ + Store: &mockstorage.MockStore{ErrPut: expected}, + }, + ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), + VDRegistryValue: &mockvdr.MockVDRegistry{}, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + DIDConnectionStoreValue: &mockConnectionStore{}, + }) + require.NoError(t, err) + + err = s.CreateConnection(&connection.Record{ + State: StateIDCompleted, + }, newPeerDID(t, k)) + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) + + t.Run("wraps did.ConnectionStore.SaveDIDFromDoc error", func(t *testing.T) { + expected := errors.New("test") + s, err := New(&mockprovider.Provider{ + KMSValue: &mockkms.KeyManager{}, + KeyTypeValue: kms.ECDSAP384TypeIEEEP1363, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + StorageProviderValue: mockstorage.NewMockStoreProvider(), + ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), + VDRegistryValue: &mockvdr.MockVDRegistry{}, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + DIDConnectionStoreValue: &mockConnectionStore{ + saveDIDFromDocErr: expected, + }, + }) + require.NoError(t, err) + + err = s.CreateConnection(&connection.Record{ + State: StateIDCompleted, + }, newPeerDID(t, k)) + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) + + t.Run("wraps did.ConnectionStore.SaveDIDByResolving error", func(t *testing.T) { + expected := errors.New("test") + s, err := New(&mockprovider.Provider{ + KMSValue: &mockkms.KeyManager{}, + KeyTypeValue: kms.ED25519Type, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + StorageProviderValue: mockstorage.NewMockStoreProvider(), + ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), + VDRegistryValue: &mockvdr.MockVDRegistry{}, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + DIDConnectionStoreValue: &mockConnectionStore{ + saveDIDByResolvingErr: expected, + }, + }) + require.NoError(t, err) + + err = s.CreateConnection(&connection.Record{ + State: StateIDCompleted, + }, newPeerDID(t, k)) + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) +} + +type mockStore struct { + put func(string, []byte, ...storage.Tag) error + get func(string) ([]byte, error) + delete func(string) error +} + +// Put stores the key and the record. +func (m *mockStore) Put(k string, v []byte, tags ...storage.Tag) error { + return m.put(k, v, tags...) +} + +// Get fetches the record based on key. +func (m *mockStore) Get(k string) ([]byte, error) { + return m.get(k) +} + +func (m *mockStore) GetTags(key string) ([]storage.Tag, error) { + panic("implement me") +} + +func (m *mockStore) GetBulk(keys ...string) ([][]byte, error) { + panic("implement me") +} + +func (m *mockStore) Query(expression string, options ...storage.QueryOption) (storage.Iterator, error) { + panic("implement me") +} + +// Delete the record based on key. +func (m *mockStore) Delete(k string) error { + return m.delete(k) +} + +func (m *mockStore) Batch(operations []storage.Operation) error { + panic("implement me") +} + +func (m *mockStore) Flush() error { + panic("implement me") +} + +func (m *mockStore) Close() error { + panic("implement me") +} + +func TestEventsSuccess(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + go func() { service.AutoExecuteActionEvent(actionCh) }() + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for e := range statusCh { + if e.Type == service.PostState && e.StateID == StateIDRequested { + done <- struct{}{} + } + } + }() + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + id := randomString() + invite, err := json.Marshal( + &Invitation{ + Type: InvitationMsgType, + ID: id, + Label: "test", + RecipientKeys: []string{base58.Encode(pubKey)}, + }, + ) + require.NoError(t, err) + + // send invite + didMsg, err := service.ParseDIDCommMsgMap(invite) + require.NoError(t, err) + + _, err = svc.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } +} + +func TestContinueWithPublicDID(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc := mockdiddoc.GetMockDIDDoc(t, false) + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + go func() { continueWithPublicDID(actionCh, didDoc.ID) }() + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + id := randomString() + invite, err := json.Marshal( + &Invitation{ + Type: InvitationMsgType, + ID: id, + Label: "test", + RecipientKeys: []string{base58.Encode(pubKey)}, + }, + ) + require.NoError(t, err) + + // send invite + didMsg, err := service.ParseDIDCommMsgMap(invite) + require.NoError(t, err) + + _, err = svc.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) +} + +func continueWithPublicDID(ch chan service.DIDCommAction, pubDID string) { + for msg := range ch { + msg.Continue(&testOptions{publicDID: pubDID}) + } +} + +type testOptions struct { + publicDID string + label string + routerConnections []string +} + +func (to *testOptions) PublicDID() string { + return to.publicDID +} + +func (to *testOptions) Label() string { + return to.label +} + +func (to *testOptions) RouterConnections() []string { + return to.routerConnections +} + +func TestEventsUserError(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for { + select { + case e := <-actionCh: + e.Stop(errors.New("invalid id")) + case e := <-statusCh: + if e.Type == service.PostState { + done <- struct{}{} + } + } + } + }() + + id := randomString() + connRec := &connection.Record{ + ConnectionID: randomString(), ThreadID: id, + Namespace: findNamespace(RequestMsgType), State: (&null{}).Name(), + } + + err = svc.connectionRecorder.SaveConnectionRecordWithMappings(connRec) + require.NoError(t, err) + + _, err = svc.HandleInbound( + generateRequestMsgPayload(t, &protocol.MockProvider{}, id, randomString()), + service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } +} + +func TestEventStoreError(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + go func() { + for e := range actionCh { + e.Continue = func(args interface{}) { + svc.processCallback(&message{Msg: service.NewDIDCommMsgMap(struct{}{})}) + } + e.Continue(&service.Empty{}) + } + }() + + _, err = svc.HandleInbound( + generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), randomString()), + service.EmptyDIDCommContext()) + require.NoError(t, err) +} + +func TestEventProcessCallback(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + msg := &message{ + ThreadID: threadIDValue, + Msg: service.NewDIDCommMsgMap(model.Ack{Type: AckMsgType}), + } + + err = svc.handleWithoutAction(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid state name: invalid state name ") +} + +func validateState(t *testing.T, svc *Service, id, namespace, expected string) { + nsThid, err := connection.CreateNamespaceKey(namespace, id) + require.NoError(t, err) + s, err := svc.currentState(nsThid) + require.NoError(t, err) + require.Equal(t, expected, s.Name()) +} + +func TestServiceErrors(t *testing.T) { + requestBytes, err := json.Marshal( + &Request{ + Type: ResponseMsgType, + ID: randomString(), + Label: "test", + }, + ) + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap(requestBytes) + require.NoError(t, err) + + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + // fetch current state error + mockStore := &mockStore{get: func(s string) (bytes []byte, e error) { + return nil, errors.New("error") + }} + + prov := &protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider( + mockStore, + ), + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + } + svc, err = New(prov) + require.NoError(t, err) + + payload := generateRequestMsgPayload(t, prov, randomString(), "") + _, err = svc.HandleInbound(payload, service.EmptyDIDCommContext()) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot fetch state from store") + + svc, err = New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + // invalid message type + msg["@type"] = "invalid" + svc.connectionRecorder, err = connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + + _, err = svc.HandleInbound(msg, service.EmptyDIDCommContext()) + require.Error(t, err) + require.Contains(t, err.Error(), "unrecognized msgType: invalid") + + // test handle - invalid state name + msg["@type"] = ResponseMsgType + m := &message{Msg: msg, ThreadID: randomString()} + err = svc.handleWithoutAction(m) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid state name:") + + // invalid state name + m.NextStateName = StateIDInvited + m.ConnRecord = &connection.Record{ConnectionID: "abc"} + err = svc.handleWithoutAction(m) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to execute state 'invited':") +} + +func TestHandleOutbound(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + _, err = svc.HandleOutbound(service.DIDCommMsgMap{}, "", "") + require.Error(t, err) + require.Contains(t, err.Error(), "not implemented") +} + +func TestConnectionRecord(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + conn, err := svc.connectionRecord(generateRequestMsgPayload(t, &protocol.MockProvider{}, + randomString(), randomString())) + require.NoError(t, err) + require.NotNil(t, conn) + + // invalid type + requestBytes, err := json.Marshal(&Request{ + Type: "invalid-type", + }) + require.NoError(t, err) + msg, err := service.ParseDIDCommMsgMap(requestBytes) + require.NoError(t, err) + + _, err = svc.connectionRecord(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid message type") +} + +func TestInvitationRecord(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + invitationBytes, err := json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: "id", + RecipientKeys: []string{base58.Encode(verPubKey)}, + }) + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap(invitationBytes) + require.NoError(t, err) + + conn, err := svc.invitationMsgRecord(msg) + require.NoError(t, err) + require.NotNil(t, conn) + + // invalid thread id + invitationBytes, err = json.Marshal(&Invitation{ + Type: "invalid-type", + }) + require.NoError(t, err) + msg, err = service.ParseDIDCommMsgMap(invitationBytes) + require.NoError(t, err) + + _, err = svc.invitationMsgRecord(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "threadID not found") + + // db error + svc, err = New(&protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider(&mockstorage.MockStore{ + Store: make(map[string]mockstorage.DBEntry), ErrPut: errors.New("db error"), + }), + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NotNil(t, svc.connectionRecorder) + require.NoError(t, err) + + invitationBytes, err = json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: "id", + RecipientKeys: []string{base58.Encode(verPubKey)}, + }) + require.NoError(t, err) + + msg, err = service.ParseDIDCommMsgMap(invitationBytes) + require.NoError(t, err) + + _, err = svc.invitationMsgRecord(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "save connection record") +} + +func TestRequestRecord(t *testing.T) { + t.Run("returns connection record", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + didcommMsg := generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), uuid.New().String()) + require.NotEmpty(t, didcommMsg.ParentThreadID()) + conn, err := svc.requestMsgRecord(didcommMsg) + require.NoError(t, err) + require.NotNil(t, conn) + require.Equal(t, didcommMsg.ParentThreadID(), conn.InvitationID) + }) + + t.Run("returns connection record from request without connection DIDDoc field", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + didcommMsg := generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), uuid.New().String()) + require.NotEmpty(t, didcommMsg.ParentThreadID()) + delete(didcommMsg, "connection") + + _, err = svc.requestMsgRecord(didcommMsg) + require.Error(t, err) + require.Contains(t, err.Error(), "missing connection field") + }) + + t.Run("fails on db error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ProtocolStateStoreProvider: mockstorage.NewCustomMockStoreProvider( + &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry), ErrPut: errors.New("db error")}, + ), + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NotNil(t, svc.connectionRecorder) + require.NoError(t, err) + + _, err = svc.requestMsgRecord(generateRequestMsgPayload(t, &protocol.MockProvider{}, + randomString(), uuid.New().String())) + require.Error(t, err) + require.Contains(t, err.Error(), "save connection record") + }) + + t.Run("fails if parent thread ID is missing", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + parentThreadID := "" + didcommMsg := generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), parentThreadID) + require.Empty(t, didcommMsg.ParentThreadID()) + _, err = svc.requestMsgRecord(didcommMsg) + require.Error(t, err) + }) +} + +func TestAcceptExchangeRequest(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + + svc, err := New(&protocol.MockProvider{ + StoreProvider: sp, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(verPubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err = svc.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + go func() { + for e := range actionCh { + prop, ok := e.Properties.(event) + require.True(t, ok, "Failed to cast the event properties to service.Event") + require.NoError(t, svc.AcceptConnectionRequest(prop.ConnectionID(), "", "", nil)) + } + }() + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for e := range statusCh { + if e.Type == service.PostState && e.StateID == StateIDResponded { + done <- struct{}{} + } + } + }() + + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{ + StoreProvider: mockstorage.NewMockStoreProvider(), + }, randomString(), invitation.ID), service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } +} + +func TestAcceptExchangeRequestWithPublicDID(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + svc, err := New(&protocol.MockProvider{ + StoreProvider: sp, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + const publicDIDMethod = "sidetree" + publicDID := fmt.Sprintf("did:%s:123456", publicDIDMethod) + doc, err := svc.ctx.vdRegistry.Create(publicDIDMethod, nil) + require.NoError(t, err) + + svc.ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: doc.DIDDocument} + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(verPubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err = svc.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + go func() { + for e := range actionCh { + prop, ok := e.Properties.(event) + require.True(t, ok, "Failed to cast the event properties to service.Event") + require.NoError(t, svc.AcceptConnectionRequest(prop.ConnectionID(), publicDID, "sample-label", nil)) + } + }() + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for e := range statusCh { + if e.Type == service.PostState && e.StateID == StateIDResponded { + done <- struct{}{} + } + } + }() + + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{ + StoreProvider: mockstorage.NewMockStoreProvider(), + }, randomString(), invitation.ID), service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } +} + +func TestAcceptInvitation(t *testing.T) { + t.Run("accept invitation - success", func(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + svc, err := New(&protocol.MockProvider{ + StoreProvider: sp, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + go func() { + for e := range actionCh { + _, ok := e.Properties.(event) + require.True(t, ok, "Failed to cast the event properties to service.Event") + + // ignore action event + } + }() + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for e := range statusCh { + prop, ok := e.Properties.(event) + if !ok { + require.Fail(t, "Failed to cast the event properties to service.Event") + } + + if e.Type == service.PostState && e.StateID == StateIDInvited { + require.NoError(t, svc.AcceptInvitation(prop.ConnectionID(), "", "", nil)) + } + + if e.Type == service.PostState && e.StateID == StateIDRequested { + done <- struct{}{} + } + } + }() + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitationBytes, err := json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: generateRandomID(), + RecipientKeys: []string{base58.Encode(verPubKey)}, + }) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(invitationBytes) + require.NoError(t, err) + + _, err = svc.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } + }) + + t.Run("accept invitation - error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = svc.AcceptInvitation(generateRandomID(), "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept connection invitation : get protocol state data : data not found") + }) + + t.Run("accept invitation - state error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + id := generateRandomID() + connRecord := &connection.Record{ + ConnectionID: id, + State: StateIDRequested, + } + err = svc.connectionRecorder.SaveConnectionRecord(connRecord) + require.NoError(t, err) + + err = svc.storeEventProtocolStateData(&message{ConnRecord: connRecord}) + require.NoError(t, err) + + err = svc.AcceptInvitation(id, "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "current state (requested) is different from expected state (invited)") + }) + + t.Run("accept invitation - no connection record error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + id := generateRandomID() + connRecord := &connection.Record{ + ConnectionID: id, + State: StateIDRequested, + } + + err = svc.storeEventProtocolStateData(&message{ConnRecord: connRecord}) + require.NoError(t, err) + + err = svc.AcceptInvitation(id, "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept connection invitation : data not found") + }) +} + +func TestAcceptInvitationWithPublicDID(t *testing.T) { + t.Run("accept invitation with public DID - success", func(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + svc, err := New(&protocol.MockProvider{ + StoreProvider: sp, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + }) + require.NoError(t, err) + + const publicDIDMethod = "sidetree" + publicDID := fmt.Sprintf("did:%s:123456", publicDIDMethod) + doc, err := svc.ctx.vdRegistry.Create(publicDIDMethod, nil) + require.NoError(t, err) + svc.ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: doc.DIDDocument} + + actionCh := make(chan service.DIDCommAction, 10) + err = svc.RegisterActionEvent(actionCh) + require.NoError(t, err) + + go func() { + for e := range actionCh { + _, ok := e.Properties.(event) + require.True(t, ok, "Failed to cast the event properties to service.Event") + + // ignore action event + } + }() + + statusCh := make(chan service.StateMsg, 10) + err = svc.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + done := make(chan struct{}) + + go func() { + for e := range statusCh { + prop, ok := e.Properties.(event) + if !ok { + require.Fail(t, "Failed to cast the event properties to service.Event") + } + + if e.Type == service.PostState && e.StateID == StateIDInvited { + require.NoError(t, svc.AcceptInvitation(prop.ConnectionID(), publicDID, "sample-label", nil)) + } + + if e.Type == service.PostState && e.StateID == StateIDRequested { + done <- struct{}{} + } + } + }() + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitationBytes, err := json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: generateRandomID(), + RecipientKeys: []string{base58.Encode(verPubKey)}, + }) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(invitationBytes) + require.NoError(t, err) + + _, err = svc.HandleInbound(didMsg, service.EmptyDIDCommContext()) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(5 * time.Second): + require.Fail(t, "tests are not validated") + } + }) + + t.Run("accept invitation - error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = svc.AcceptInvitation(generateRandomID(), "sample-public-did", "sample-label", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept connection invitation : get protocol state data : data not found") + }) + + t.Run("accept invitation - state error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + id := generateRandomID() + connRecord := &connection.Record{ + ConnectionID: id, + State: StateIDRequested, + } + err = svc.connectionRecorder.SaveConnectionRecord(connRecord) + require.NoError(t, err) + + err = svc.storeEventProtocolStateData(&message{ConnRecord: connRecord}) + require.NoError(t, err) + + err = svc.AcceptInvitation(id, "sample-public-did", "sample-label", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "current state (requested) is different from expected state (invited)") + }) + + t.Run("accept invitation - no connection record error", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + id := generateRandomID() + connRecord := &connection.Record{ + ConnectionID: id, + State: StateIDRequested, + } + + err = svc.storeEventProtocolStateData(&message{ConnRecord: connRecord}) + require.NoError(t, err) + + err = svc.AcceptInvitation(id, "sample-public-did", "sample-label", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept connection invitation : data not found") + }) +} + +func TestEventProtocolStateData(t *testing.T) { + t.Run("event protocol state data - success", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + connID := generateRandomID() + + msg := &message{ + ConnRecord: &connection.Record{ConnectionID: connID}, + } + err = svc.storeEventProtocolStateData(msg) + require.NoError(t, err) + + retrievedMsg, err := svc.getEventProtocolStateData(connID) + require.NoError(t, err) + require.Equal(t, msg, retrievedMsg) + }) + + t.Run("event protocol state data - data not found", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + err = svc.AcceptConnectionRequest(generateRandomID(), "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept exchange request : get protocol state data : data not found") + + err = svc.AcceptConnectionRequest(generateRandomID(), "sample-public-did", "sample-label", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "accept exchange request : get protocol state data : data not found") + }) + + t.Run("event protocol state data - invalid data", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + connID := generateRandomID() + + err = svc.connectionRecorder.SaveEvent(connID, []byte("invalid data")) + require.NoError(t, err) + + _, err = svc.getEventProtocolStateData(connID) + require.Error(t, err) + require.Contains(t, err.Error(), "get protocol state data : invalid character") + }) +} + +func TestNextState(t *testing.T) { + t.Run("empty thread ID", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + _, err = svc.nextState(RequestMsgType, "") + require.EqualError(t, err, "unable to compute hash, empty bytes") + }) + + t.Run("valid inputs", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + s, errState := svc.nextState(RequestMsgType, generateRandomID()) + require.NoError(t, errState) + require.Equal(t, StateIDRequested, s.Name()) + }) +} + +func TestFetchConnectionRecord(t *testing.T) { + t.Run("fetch connection record - invalid payload", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + _, err = svc.fetchConnectionRecord("", service.DIDCommMsgMap{"~thread": map[int]int{1: 1}}) + require.Contains(t, fmt.Sprintf("%v", err), `'~thread' needs a map with string keys`) + }) + + t.Run("fetch connection record - no thread id", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + _, err = svc.fetchConnectionRecord(theirNSPrefix, toDIDCommMsg(t, &Request{ + Type: ResponseMsgType, + ID: generateRandomID(), + })) + require.Error(t, err) + require.Contains(t, err.Error(), "unable to compute hash, empty bytes") + }) + + t.Run("fetch connection record - valid input", func(t *testing.T) { + svc, err := New(&protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + }) + require.NoError(t, err) + + _, err = svc.fetchConnectionRecord(theirNSPrefix, toDIDCommMsg(t, &Response{ + Type: ResponseMsgType, + ID: generateRandomID(), + Thread: &decorator.Thread{ID: generateRandomID()}, + })) + require.Error(t, err) + require.Contains(t, err.Error(), "get connectionID by namespaced threadID: data not found") + }) +} + +func generateRequestMsgPayload(t *testing.T, prov provider, id, invitationID string) service.DIDCommMsgMap { + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + ctx := context{ + outboundDispatcher: prov.OutboundDispatcher(), + vdRegistry: &mockvdr.MockVDRegistry{CreateValue: mockdiddoc.GetMockDIDDoc(t, false)}, + connectionRecorder: connRec, + } + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + + requestBytes, err := json.Marshal(&Request{ + Type: RequestMsgType, + ID: id, + Thread: &decorator.Thread{ + PID: invitationID, + }, + Connection: &Connection{ + DID: doc.DIDDocument.ID, + DIDDoc: doc.DIDDocument, + }, + }) + require.NoError(t, err) + + didMsg, err := service.ParseDIDCommMsgMap(requestBytes) + require.NoError(t, err) + + return didMsg +} + +func TestService_CreateImplicitInvitation(t *testing.T) { + t.Run("success", func(t *testing.T) { + routeSvc := &mockroute.MockMediatorSvc{} + prov := &protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: routeSvc, + }, + } + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + outboundDispatcher: prov.OutboundDispatcher(), + routeSvc: routeSvc, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + newDIDDoc := createDIDDocWithKey(verPubKey) + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + didConnStore, err := didstore.NewConnectionStore(prov) + require.NoError(t, err) + require.NotNil(t, didConnStore) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: newDIDDoc} + ctx.connectionRecorder = connRec + ctx.connectionStore = didConnStore + + s, err := New(prov) + require.NoError(t, err) + + s.ctx = ctx + connID, err := s.CreateImplicitInvitation("label", newDIDDoc.ID, "", "", nil) + require.NoError(t, err) + require.NotEmpty(t, connID) + }) + + t.Run("error during did resolution", func(t *testing.T) { + routeSvc := &mockroute.MockMediatorSvc{} + prov := &protocol.MockProvider{ + ServiceMap: map[string]interface{}{ + mediator.Coordination: routeSvc, + }, + } + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + outboundDispatcher: prov.OutboundDispatcher(), + routeSvc: routeSvc, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + newDIDDoc := createDIDDocWithKey(verPubKey) + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + didConnStore, err := didstore.NewConnectionStore(prov) + require.NoError(t, err) + require.NotNil(t, didConnStore) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveErr: errors.New("resolve error")} + ctx.connectionRecorder = connRec + ctx.connectionStore = didConnStore + + s, err := New(prov) + require.NoError(t, err) + s.ctx = ctx + + connID, err := s.CreateImplicitInvitation("label", newDIDDoc.ID, "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "resolve error") + require.Empty(t, connID) + }) + + t.Run("error during saving connection", func(t *testing.T) { + sp := mockstorage.NewMockStoreProvider() + k := newKMS(t, sp) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + routeSvc := &mockroute.MockMediatorSvc{} + protocolStateStore := mockstorage.NewMockStoreProvider() + protocolStateStore.Store.ErrPut = errors.New("store put error") + prov := &protocol.MockProvider{ + ProtocolStateStoreProvider: protocolStateStore, + ServiceMap: map[string]interface{}{ + mediator.Coordination: routeSvc, + }, + KeyTypeValue: ctx.keyType, + KeyAgreementTypeValue: ctx.keyAgreementType, + } + + ctx.outboundDispatcher = prov.OutboundDispatcher() + ctx.routeSvc = routeSvc + + _, verPubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + newDIDDoc := createDIDDocWithKey(verPubKey) + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + didConnStore, err := didstore.NewConnectionStore(prov) + require.NoError(t, err) + require.NotNil(t, didConnStore) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: newDIDDoc} + ctx.connectionRecorder = connRec + ctx.connectionStore = didConnStore + + s, err := New(prov) + require.NoError(t, err) + s.ctx = ctx + + connID, err := s.CreateImplicitInvitation("label", newDIDDoc.ID, "", "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "store put error") + require.Empty(t, connID) + }) +} + +func testProvider() *protocol.MockProvider { + return &protocol.MockProvider{ + StoreProvider: mockstorage.NewMockStoreProvider(), + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + KeyTypeValue: kms.ED25519Type, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + } +} + +func newPeerDID(t *testing.T, k kms.KeyManager) *did.Doc { + kid, pubKey, err := k.CreateAndExportPubKeyBytes(kms.ED25519) + require.NoError(t, err) + + key := did.VerificationMethod{ + ID: kid, + Type: "Ed25519VerificationKey2018", + Controller: "", + Value: pubKey, + } + doc, err := peer.NewDoc( + []did.VerificationMethod{key}, + did.WithAuthentication([]did.Verification{{ + VerificationMethod: key, + Relationship: 0, + Embedded: true, + }}), + did.WithService([]did.Service{{ + ID: "didcomm", + Type: "did-communication", + Priority: 0, + RecipientKeys: []string{base58.Encode(pubKey)}, + ServiceEndpoint: commonmodel.NewDIDCommV1Endpoint("http://example.com"), + }}), + ) + require.NoError(t, err) + + return doc +} + +type mockConnectionStore struct { + saveDIDByResolvingErr error + saveDIDFromDocErr error +} + +// GetDID returns DID associated with key. +func (m *mockConnectionStore) GetDID(string) (string, error) { + return "", nil +} + +// SaveDID saves DID to the underlying storage. +func (m *mockConnectionStore) SaveDID(string, ...string) error { + return nil +} + +// SaveDIDFromDoc saves DID from did.Doc to the underlying storage. +func (m *mockConnectionStore) SaveDIDFromDoc(*did.Doc) error { + return m.saveDIDFromDocErr +} + +// SaveDIDByResolving saves DID resolved by VDR to the underlying storage. +func (m *mockConnectionStore) SaveDIDByResolving(string, ...string) error { + return m.saveDIDByResolvingErr +} + +func randomString() string { + u := uuid.New() + return u.String() +} diff --git a/pkg/didcomm/protocol/didconnection/states.go b/pkg/didcomm/protocol/didconnection/states.go new file mode 100644 index 0000000000..edde13ced2 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/states.go @@ -0,0 +1,808 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/btcsuite/btcutil/base58" + "github.com/google/uuid" + + "github.com/hyperledger/aries-framework-go/pkg/common/model" + model2 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr" + "github.com/hyperledger/aries-framework-go/pkg/kms" + "github.com/hyperledger/aries-framework-go/pkg/kms/localkms" + connectionstore "github.com/hyperledger/aries-framework-go/pkg/store/connection" +) + +const ( + stateNameNoop = "noop" + stateNameNull = "null" + // StateIDInvited marks the invited phase of the connection protocol. + StateIDInvited = "invited" + // StateIDRequested marks the requested phase of the connection protocol. + StateIDRequested = "requested" + // StateIDResponded marks the responded phase of the connection protocol. + StateIDResponded = "responded" + // StateIDCompleted marks the completed phase of the connection protocol. + StateIDCompleted = "completed" + didCommServiceType = "did-communication" + ackStatusOK = "ok" + // legacyDIDCommServiceType for aca-py interop. + legacyDIDCommServiceType = "IndyAgent" + ed25519VerificationKey2018 = "Ed25519VerificationKey2018" + didMethod = "peer" + x25519KeyAgreementKey2019 = "X25519KeyAgreementKey2019" + signatureType = "https://didcomm.org/signature/1.0/ed25519Sha512_single" + // PlsAckOnReceipt ack type that says, "Please send me an ack as soon as you receive this message.". + PlsAckOnReceipt = "RECEIPT" + timestampLength = 8 +) + +// state action for network call. +type stateAction func() error + +// The connection protocol's state. +type state interface { + // Name of this state. + Name() string + + // CanTransitionTo Whether this state allows transitioning into the next state. + CanTransitionTo(next state) bool + + // ExecuteInbound this state, returning a followup state to be immediately executed as well. + // The 'noOp' state should be returned if the state has no followup. + ExecuteInbound(msg *stateMachineMsg, thid string, ctx *context) (connRecord *connectionstore.Record, + state state, action stateAction, err error) +} + +// Returns the state towards which the protocol will transition to if the msgType is processed. +func stateFromMsgType(msgType string) (state, error) { + switch msgType { + case InvitationMsgType: + return &invited{}, nil + case RequestMsgType: + return &requested{}, nil + case ResponseMsgType: + return &responded{}, nil + case AckMsgType: + return &completed{}, nil + default: + return nil, fmt.Errorf("unrecognized msgType: %s", msgType) + } +} + +// Returns the state representing the name. +func stateFromName(name string) (state, error) { + switch name { + case stateNameNoop: + return &noOp{}, nil + case stateNameNull: + return &null{}, nil + case StateIDInvited: + return &invited{}, nil + case StateIDRequested: + return &requested{}, nil + case StateIDResponded: + return &responded{}, nil + case StateIDCompleted: + return &completed{}, nil + default: + return nil, fmt.Errorf("invalid state name %s", name) + } +} + +type noOp struct{} + +func (s *noOp) Name() string { + return stateNameNoop +} + +func (s *noOp) CanTransitionTo(_ state) bool { + return false +} + +func (s *noOp) ExecuteInbound(_ *stateMachineMsg, _ string, _ *context) (*connectionstore.Record, + state, stateAction, error) { + return nil, nil, nil, errors.New("cannot execute no-op") +} + +// null state. +type null struct{} + +func (s *null) Name() string { + return stateNameNull +} + +func (s *null) CanTransitionTo(next state) bool { + return StateIDInvited == next.Name() || StateIDRequested == next.Name() +} + +func (s *null) ExecuteInbound(_ *stateMachineMsg, _ string, _ *context) (*connectionstore.Record, + state, stateAction, error) { + return &connectionstore.Record{}, &noOp{}, nil, nil +} + +// invited state. +type invited struct{} + +func (s *invited) Name() string { + return StateIDInvited +} + +func (s *invited) CanTransitionTo(next state) bool { + return StateIDRequested == next.Name() +} + +func (s *invited) ExecuteInbound(msg *stateMachineMsg, _ string, _ *context) (*connectionstore.Record, + state, stateAction, error) { + if msg.Type() != InvitationMsgType { + return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name()) + } + + return msg.connRecord, &requested{}, func() error { return nil }, nil +} + +// requested state. +type requested struct{} + +func (s *requested) Name() string { + return StateIDRequested +} + +func (s *requested) CanTransitionTo(next state) bool { + return StateIDResponded == next.Name() +} + +func (s *requested) ExecuteInbound(msg *stateMachineMsg, thid string, ctx *context) (*connectionstore.Record, + state, stateAction, error) { + switch msg.Type() { + case InvitationMsgType: + invitation := &Invitation{} + + err := msg.Decode(invitation) + if err != nil { + return nil, nil, nil, fmt.Errorf("JSON unmarshalling of invitation: %w", err) + } + + action, connRecord, err := ctx.handleInboundInvitation(invitation, thid, msg.options, msg.connRecord) + if err != nil { + return nil, nil, nil, fmt.Errorf("handle inbound invitation: %w", err) + } + + return connRecord, &noOp{}, action, nil + case RequestMsgType: + return msg.connRecord, &responded{}, func() error { return nil }, nil + default: + return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name()) + } +} + +// responded state. +type responded struct{} + +func (s *responded) Name() string { + return StateIDResponded +} + +func (s *responded) CanTransitionTo(next state) bool { + return StateIDCompleted == next.Name() +} + +func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) (*connectionstore.Record, + state, stateAction, error) { + switch msg.Type() { + case RequestMsgType: + request := &Request{} + + err := msg.Decode(request) + if err != nil { + return nil, nil, nil, fmt.Errorf("JSON unmarshalling of request: %w", err) + } + + action, connRecord, err := ctx.handleInboundRequest(request, msg.options, msg.connRecord) + if err != nil { + return nil, nil, nil, fmt.Errorf("handle inbound request: %w", err) + } + + return connRecord, &noOp{}, action, nil + case ResponseMsgType: + return msg.connRecord, &completed{}, func() error { return nil }, nil + default: + return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name()) + } +} + +// completed state. +type completed struct{} + +func (s *completed) Name() string { + return StateIDCompleted +} + +func (s *completed) CanTransitionTo(_ state) bool { + return false +} + +func (s *completed) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) (*connectionstore.Record, + state, stateAction, error) { + switch msg.Type() { + case ResponseMsgType: + response := &Response{} + + err := msg.Decode(response) + if err != nil { + return nil, nil, nil, fmt.Errorf("JSON unmarshalling of response: %w", err) + } + + action, connRecord, err := ctx.handleInboundResponse(response) + if err != nil { + return nil, nil, nil, fmt.Errorf("handle inbound response: %w", err) + } + + return connRecord, &noOp{}, action, nil + case AckMsgType: + action := func() error { return nil } + return msg.connRecord, &noOp{}, action, nil + default: + return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name()) + } +} + +func (ctx *context) handleInboundInvitation(invitation *Invitation, thid string, options *options, + connRec *connectionstore.Record) (stateAction, *connectionstore.Record, error) { + // create a destination from invitation + destination, err := ctx.getDestination(invitation) + if err != nil { + return nil, nil, err + } + + pid := invitation.ID + if connRec.Implicit { + pid = invitation.DID + } + + return ctx.createConnectionRequest(destination, getLabel(options), thid, pid, options, connRec) +} + +func (ctx *context) createConnectionRequest(destination *service.Destination, label, thid, pthid string, + options *options, connRec *connectionstore.Record) (stateAction, *connectionstore.Record, error) { + request := &Request{ + Type: RequestMsgType, + ID: thid, + Label: label, + Thread: &decorator.Thread{ + PID: pthid, + }, + } + // get did document to use in connection request + myDIDDoc, err := ctx.getMyDIDDoc(getPublicDID(options), getRouterConnections(options), didCommServiceType) + if err != nil { + return nil, nil, err + } + + connRec.MyDID = myDIDDoc.ID + + senderKey, err := recipientKey(myDIDDoc) + if err != nil { + return nil, nil, fmt.Errorf("getting recipient key: %w", err) + } + + request.Connection = &Connection{ + DID: myDIDDoc.ID, + DIDDoc: myDIDDoc, + } + + return func() error { + return ctx.outboundDispatcher.Send(request, senderKey, destination) + }, connRec, nil +} + +func (ctx *context) handleInboundRequest(request *Request, options *options, + connRec *connectionstore.Record) (stateAction, *connectionstore.Record, error) { + logger.Debugf("handling request: %#v", request) + + requestDidDoc, err := ctx.resolveDidDocFromConnection(request.Connection) + if err != nil { + return nil, nil, fmt.Errorf("resolve did doc from connection request: %w", err) + } + + // get did document that will be used in connection response + // (my did doc) + myDID := getPublicDID(options) + + destination, err := service.CreateDestination(requestDidDoc) + if err != nil { + return nil, nil, err + } + + var serviceType string + if len(requestDidDoc.Service) > 0 { + serviceType = requestDidDoc.Service[0].Type + } else { + serviceType = didCommServiceType + } + + responseDidDoc, err := ctx.getMyDIDDoc(myDID, getRouterConnections(options), serviceType) + if err != nil { + return nil, nil, fmt.Errorf("get response did doc and connection: %w", err) + } + + // prepare connection signature + connectionSignature, err := ctx.prepareConnectionSignature(responseDidDoc, request.Thread.PID) + if err != nil { + return nil, nil, err + } + + response := ctx.prepareResponse(request, connectionSignature) + + var senderVerKey string + + senderVerKey, err = recipientKey(responseDidDoc) + if err != nil { + return nil, nil, fmt.Errorf("get recipient key: %w", err) + } + + connRec.MyDID = responseDidDoc.ID + connRec.TheirDID = request.Connection.DID + connRec.TheirLabel = request.Label + + accept, err := destination.ServiceEndpoint.Accept() + if err != nil { + accept = []string{} + } + + if len(accept) > 0 { + connRec.MediaTypeProfiles = accept + } + // send connection response + return func() error { + return ctx.outboundDispatcher.Send(response, senderVerKey, destination) + }, connRec, nil +} + +func (ctx *context) prepareConnectionSignature(didDoc *did.Doc, + invitationID string) (*ConnectionSignature, error) { + connection := &Connection{ + DID: didDoc.ID, + DIDDoc: didDoc, + } + logger.Debugf("connection=%+v invitationID=%s", connection, invitationID) + + connAttributeBytes, err := json.Marshal(connection) + if err != nil { + return nil, fmt.Errorf("failed to marshal connection : %w", err) + } + + now := time.Now().Unix() + timestampBuf := make([]byte, timestampLength) + binary.BigEndian.PutUint64(timestampBuf, uint64(now)) + + concatenateSignData := append(timestampBuf, connAttributeBytes...) + + pubKey, err := ctx.getVerKey(invitationID) + if err != nil { + return nil, fmt.Errorf("failed to get verkey : %w", err) + } + + signingKID, err := localkms.CreateKID(base58.Decode(pubKey), kms.ED25519Type) + if err != nil { + return nil, fmt.Errorf("failed to generate KID from public key: %w", err) + } + + kh, err := ctx.kms.Get(signingKID) + if err != nil { + return nil, fmt.Errorf("failed to get key handle: %w", err) + } + + var signature []byte + + signature, err = ctx.crypto.Sign(concatenateSignData, kh) + if err != nil { + return nil, fmt.Errorf("signing data: %w", err) + } + + return &ConnectionSignature{ + Type: signatureType, + SignedData: base64.URLEncoding.EncodeToString(concatenateSignData), + SignVerKey: base64.URLEncoding.EncodeToString(base58.Decode(pubKey)), + Signature: base64.URLEncoding.EncodeToString(signature), + }, nil +} + +func (ctx *context) prepareResponse(request *Request, signature *ConnectionSignature) *Response { + // prepare the response + response := &Response{ + Type: ResponseMsgType, + ID: uuid.New().String(), + Thread: &decorator.Thread{ + ID: request.ID, + }, + ConnectionSignature: signature, + PleaseAck: &PleaseAck{ + []string{PlsAckOnReceipt}, + }, + } + + if request.Thread != nil { + response.Thread.PID = request.Thread.PID + } + + return response +} + +func getPublicDID(options *options) string { + if options == nil { + return "" + } + + return options.publicDID +} + +func getRouterConnections(options *options) []string { + if options == nil { + return nil + } + + return options.routerConnections +} + +// returns the label given in the options, otherwise an empty string. +func getLabel(options *options) string { + if options == nil { + return "" + } + + return options.label +} + +func (ctx *context) getDestination(invitation *Invitation) (*service.Destination, error) { + if invitation.DID != "" { + return service.GetDestination(invitation.DID, ctx.vdRegistry) + } + + accept := ctx.mediaTypeProfiles + if isDIDCommV2(accept) { + return nil, fmt.Errorf("DIDComm V2 profile type(s): %v - are not supported", accept) + } + + return &service.Destination{ + RecipientKeys: invitation.RecipientKeys, + ServiceEndpoint: model.NewDIDCommV1Endpoint(invitation.ServiceEndpoint), + MediaTypeProfiles: accept, + RoutingKeys: invitation.RoutingKeys, + }, nil +} + +// nolint:gocyclo,funlen +func (ctx *context) getMyDIDDoc(pubDID string, routerConnections []string, serviceType string) (*did.Doc, error) { + if pubDID != "" { + logger.Debugf("using public did[%s] for connection", pubDID) + + docResolution, err := ctx.vdRegistry.Resolve(pubDID) + if err != nil { + return nil, fmt.Errorf("resolve public did[%s]: %w", pubDID, err) + } + + err = ctx.connectionStore.SaveDIDFromDoc(docResolution.DIDDocument) + if err != nil { + return nil, err + } + + return docResolution.DIDDocument, nil + } + + logger.Debugf("creating new '%s' did for connection", didMethod) + + var ( + services []did.Service + newService bool + ) + + for _, connID := range routerConnections { + // get the route configs (pass empty service endpoint, as default service endpoint added in VDR) + serviceEndpoint, routingKeys, err := mediator.GetRouterConfig(ctx.routeSvc, connID, "") + if err != nil { + return nil, fmt.Errorf("did doc - fetch router config: %w", err) + } + + var svc did.Service + + switch serviceType { + case didCommServiceType, legacyDIDCommServiceType: + svc = did.Service{ + Type: didCommServiceType, + ServiceEndpoint: model.NewDIDCommV1Endpoint(serviceEndpoint), + RoutingKeys: routingKeys, + } + default: + return nil, fmt.Errorf("service type %s is not supported", serviceType) + } + + services = append(services, svc) + } + + if len(services) == 0 { + newService = true + + services = append(services, did.Service{Type: serviceType}) + } + + newDID := &did.Doc{Service: services} + + err := ctx.createNewKeyAndVM(newDID) + if err != nil { + return nil, fmt.Errorf("failed to create and export public key: %w", err) + } + + if newService { + switch newDID.Service[0].Type { + case didCommServiceType, legacyDIDCommServiceType: + newDID.Service[0].RecipientKeys = []string{base58.Encode(newDID.VerificationMethod[0].Value)} + default: + return nil, fmt.Errorf("service type %s is not supported", newDID.Service[0].Type) + } + } + // by default use peer did + docResolution, err := ctx.vdRegistry.Create(didMethod, newDID) + if err != nil { + return nil, fmt.Errorf("create %s did: %w", didMethod, err) + } + + if len(routerConnections) != 0 { + err = ctx.addRouterKeys(docResolution.DIDDocument, routerConnections) + if err != nil { + return nil, err + } + } + + err = ctx.connectionStore.SaveDIDFromDoc(docResolution.DIDDocument) + if err != nil { + return nil, err + } + + return docResolution.DIDDocument, nil +} + +func (ctx *context) addRouterKeys(doc *did.Doc, routerConnections []string) error { + svc, ok := did.LookupService(doc, didCommServiceType) + if ok { + for _, recKey := range svc.RecipientKeys { + for _, connID := range routerConnections { + // TODO https://github.com/hyperledger/aries-framework-go/issues/1105 Support to Add multiple + // recKeys to the Router + if err := mediator.AddKeyToRouter(ctx.routeSvc, connID, recKey); err != nil { + return fmt.Errorf("did doc - add key to the router: %w", err) + } + } + } + } + + return nil +} + +func (ctx *context) isPrivateDIDMethod(method string) bool { + // todo: find better solution to forcing test dids to be treated as private dids + if method == "local" || method == "test" { + return true + } + + return method == "peer" || method == "sov" +} + +func (ctx *context) resolveDidDocFromConnection(con *Connection) (*did.Doc, error) { + parsedDID, err := did.Parse(con.DID) + if err != nil { + return nil, fmt.Errorf("failed to parse did: %w", err) + } + + if err == nil && !ctx.isPrivateDIDMethod(parsedDID.Method) { + docResolution, e := ctx.vdRegistry.Resolve(con.DID) + if e != nil { + return nil, fmt.Errorf("failed to resolve public did %s: %w", con.DID, e) + } + + return docResolution.DIDDocument, nil + } + + if con.DIDDoc == nil { + return nil, fmt.Errorf("missing DIDDoc") + } + + var method string + + if parsedDID != nil && parsedDID.Method != "sov" { + method = parsedDID.Method + } else { + method = "peer" + } + // store provided did document + _, err = ctx.vdRegistry.Create(method, con.DIDDoc, vdrapi.WithOption("store", true)) + if err != nil { + return nil, fmt.Errorf("failed to store provided did document: %w", err) + } + + return con.DIDDoc, nil +} + +func (ctx *context) handleInboundResponse(response *Response) (stateAction, *connectionstore.Record, error) { + ack := model2.Ack{ + Type: AckMsgType, + ID: uuid.New().String(), + Status: ackStatusOK, + Thread: &decorator.Thread{ + ID: response.Thread.ID, + }, + } + + nsThID, err := connectionstore.CreateNamespaceKey(myNSPrefix, ack.Thread.ID) + if err != nil { + return nil, nil, err + } + + connRecord, err := ctx.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID) + if err != nil { + return nil, nil, fmt.Errorf("get connection record: %w", err) + } + + conn, err := ctx.verifySignature(response.ConnectionSignature, connRecord.RecipientKeys[0]) + if err != nil { + return nil, nil, err + } + + connRecord.TheirDID = conn.DID + + responseDidDoc, err := ctx.resolveDidDocFromConnection(conn) + if err != nil { + return nil, nil, fmt.Errorf("resolve response did doc: %w", err) + } + + destination, err := service.CreateDestination(responseDidDoc) + if err != nil { + return nil, nil, fmt.Errorf("prepare destination from response did doc: %w", err) + } + + docResolution, err := ctx.vdRegistry.Resolve(connRecord.MyDID) + if err != nil { + return nil, nil, fmt.Errorf("fetching did document: %w", err) + } + + recKey, err := recipientKey(docResolution.DIDDocument) + if err != nil { + return nil, nil, fmt.Errorf("handle inbound response: %w", err) + } + + return func() error { + return ctx.outboundDispatcher.Send(ack, recKey, destination) + }, connRecord, nil +} + +// verifySignature verifies connection signature and returns connection. +func (ctx *context) verifySignature(connSignature *ConnectionSignature, recipientKeys string) (*Connection, error) { + sigData, err := base64.URLEncoding.DecodeString(connSignature.SignedData) + if err != nil { + return nil, fmt.Errorf("decode signature data: %w", err) + } + + if len(sigData) == 0 { + return nil, fmt.Errorf("missing or invalid signature data") + } + + signature, err := base64.URLEncoding.DecodeString(connSignature.Signature) + if err != nil { + return nil, fmt.Errorf("decode signature: %w", err) + } + + // The signature data must be used to verify against the invitation's recipientKeys for continuity. + pubKey := base58.Decode(recipientKeys) + + kh, err := ctx.kms.PubKeyBytesToHandle(pubKey, kms.ED25519Type) + if err != nil { + return nil, fmt.Errorf("failed to get key handle: %w", err) + } + + err = ctx.crypto.Verify(signature, sigData, kh) + if err != nil { + return nil, fmt.Errorf("verify signature: %w", err) + } + + // trimming the timestamp and delimiter - only taking out connection attribute bytes + if len(sigData) <= timestampLength { + return nil, fmt.Errorf("missing connection attribute bytes") + } + + connBytes := sigData[timestampLength:] + conn := &Connection{} + + err = json.Unmarshal(connBytes, conn) + if err != nil { + return nil, fmt.Errorf("JSON unmarshalling of connection: %w", err) + } + + return conn, nil +} + +func (ctx *context) getVerKey(invitationID string) (string, error) { + var invitation Invitation + + if isDID(invitationID) { + invitation = Invitation{ID: invitationID, DID: invitationID} + } else { + err := ctx.connectionRecorder.GetInvitation(invitationID, &invitation) + if err != nil { + return "", fmt.Errorf("get invitation for [invitationID=%s]: %w", invitationID, err) + } + } + + invPubKey, err := ctx.getInvitationRecipientKey(&invitation) + if err != nil { + return "", fmt.Errorf("get invitation recipient key: %w", err) + } + + return invPubKey, nil +} + +func (ctx *context) getInvitationRecipientKey(invitation *Invitation) (string, error) { + if invitation.DID != "" { + docResolution, err := ctx.vdRegistry.Resolve(invitation.DID) + if err != nil { + return "", fmt.Errorf("get invitation recipient key: %w", err) + } + + recKey, err := recipientKey(docResolution.DIDDocument) + if err != nil { + return "", fmt.Errorf("getInvitationRecipientKey: %w", err) + } + + return recKey, nil + } + + return invitation.RecipientKeys[0], nil +} + +func isDID(str string) bool { + const didPrefix = "did:" + return strings.HasPrefix(str, didPrefix) +} + +func isDIDCommV2(mediaTypeProfiles []string) bool { + for _, mtp := range mediaTypeProfiles { + switch mtp { + case transport.MediaTypeDIDCommV2Profile, transport.MediaTypeAIP2RFC0587Profile: + return true + } + } + + return false +} + +// returns the did:key ID of the first element in the doc's destination RecipientKeys. +func recipientKey(doc *did.Doc) (string, error) { + switch doc.Service[0].Type { + case vdrapi.DIDCommServiceType, legacyDIDCommServiceType: + dest, err := service.CreateDestination(doc) + if err != nil { + return "", fmt.Errorf("failed to create destination: %w", err) + } + + return dest.RecipientKeys[0], nil + default: + return "", fmt.Errorf("recipientKeyAsDIDKey: invalid DID Doc service type: '%v'", doc.Service[0].Type) + } +} diff --git a/pkg/didcomm/protocol/didconnection/states_test.go b/pkg/didcomm/protocol/didconnection/states_test.go new file mode 100644 index 0000000000..e8121ea1a2 --- /dev/null +++ b/pkg/didcomm/protocol/didconnection/states_test.go @@ -0,0 +1,1565 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcutil/base58" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + commonmodel "github.com/hyperledger/aries-framework-go/pkg/common/model" + "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr" + "github.com/hyperledger/aries-framework-go/pkg/kms" + mockcrypto "github.com/hyperledger/aries-framework-go/pkg/mock/crypto" + "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol" + mockroute "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/mediator" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/mock/diddoc" + mockstorage "github.com/hyperledger/aries-framework-go/pkg/mock/storage" + mockvdr "github.com/hyperledger/aries-framework-go/pkg/mock/vdr" + "github.com/hyperledger/aries-framework-go/pkg/store/connection" + didstore "github.com/hyperledger/aries-framework-go/pkg/store/did" +) + +func TestNoopState(t *testing.T) { + noop := &noOp{} + require.Equal(t, "noop", noop.Name()) + + t.Run("must not transition to any state", func(t *testing.T) { + all := []state{&null{}, &invited{}, &requested{}, &responded{}, &completed{}} + for _, s := range all { + require.False(t, noop.CanTransitionTo(s)) + } + }) +} + +// null state can transition to invited state or requested state. +func TestNullState(t *testing.T) { + nul := &null{} + require.Equal(t, "null", nul.Name()) + require.False(t, nul.CanTransitionTo(nul)) + require.True(t, nul.CanTransitionTo(&invited{})) + require.True(t, nul.CanTransitionTo(&requested{})) + require.False(t, nul.CanTransitionTo(&responded{})) + require.False(t, nul.CanTransitionTo(&completed{})) +} + +// invited can only transition to requested state. +func TestInvitedState(t *testing.T) { + inv := &invited{} + require.Equal(t, "invited", inv.Name()) + require.False(t, inv.CanTransitionTo(&null{})) + require.False(t, inv.CanTransitionTo(inv)) + require.True(t, inv.CanTransitionTo(&requested{})) + require.False(t, inv.CanTransitionTo(&responded{})) + require.False(t, inv.CanTransitionTo(&completed{})) +} + +// requested can only transition to responded state. +func TestRequestedState(t *testing.T) { + req := &requested{} + require.Equal(t, "requested", req.Name()) + require.False(t, req.CanTransitionTo(&null{})) + require.False(t, req.CanTransitionTo(&invited{})) + require.False(t, req.CanTransitionTo(req)) + require.True(t, req.CanTransitionTo(&responded{})) + require.False(t, req.CanTransitionTo(&completed{})) +} + +// responded can only transition to completed state. +func TestRespondedState(t *testing.T) { + res := &responded{} + require.Equal(t, "responded", res.Name()) + require.False(t, res.CanTransitionTo(&null{})) + require.False(t, res.CanTransitionTo(&invited{})) + require.False(t, res.CanTransitionTo(&requested{})) + require.False(t, res.CanTransitionTo(res)) + require.True(t, res.CanTransitionTo(&completed{})) +} + +// completed is an end state. +func TestCompletedState(t *testing.T) { + comp := &completed{} + require.Equal(t, "completed", comp.Name()) + require.False(t, comp.CanTransitionTo(&null{})) + require.False(t, comp.CanTransitionTo(&invited{})) + require.False(t, comp.CanTransitionTo(&requested{})) + require.False(t, comp.CanTransitionTo(&responded{})) + require.False(t, comp.CanTransitionTo(comp)) +} + +func TestStateFromMsgType(t *testing.T) { + t.Run("invited", func(t *testing.T) { + expected := &invited{} + actual, err := stateFromMsgType(InvitationMsgType) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("requested", func(t *testing.T) { + expected := &requested{} + actual, err := stateFromMsgType(RequestMsgType) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("responded", func(t *testing.T) { + expected := &responded{} + actual, err := stateFromMsgType(ResponseMsgType) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("completed", func(t *testing.T) { + expected := &completed{} + actual, err := stateFromMsgType(AckMsgType) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("invalid", func(t *testing.T) { + actual, err := stateFromMsgType("invalid") + require.Nil(t, actual) + require.Error(t, err) + require.Contains(t, err.Error(), "unrecognized msgType: invalid") + }) +} + +func TestStateFromName(t *testing.T) { + t.Run("noop", func(t *testing.T) { + expected := &noOp{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("null", func(t *testing.T) { + expected := &null{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("invited", func(t *testing.T) { + expected := &invited{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("requested", func(t *testing.T) { + expected := &requested{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("responded", func(t *testing.T) { + expected := &responded{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("completed", func(t *testing.T) { + expected := &completed{} + actual, err := stateFromName(expected.Name()) + require.NoError(t, err) + require.Equal(t, expected.Name(), actual.Name()) + }) + t.Run("undefined", func(t *testing.T) { + actual, err := stateFromName("undefined") + require.Nil(t, actual) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid state name") + }) +} + +// noOp.ExecuteInbound() returns nil, error. +func TestNoOpState_Execute(t *testing.T) { + _, followup, _, err := (&noOp{}).ExecuteInbound(&stateMachineMsg{}, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot execute no-op") + require.Nil(t, followup) +} + +// null.ExecuteInbound() is a no-op. +func TestNullState_Execute(t *testing.T) { + _, followup, _, err := (&null{}).ExecuteInbound(&stateMachineMsg{}, "", &context{}) + require.NoError(t, err) + require.IsType(t, &noOp{}, followup) +} + +func TestInvitedState_Execute(t *testing.T) { + t.Run("rejects msgs other than invitations", func(t *testing.T) { + others := []service.DIDCommMsg{ + service.NewDIDCommMsgMap(Request{Type: RequestMsgType}), + service.NewDIDCommMsgMap(Response{Type: ResponseMsgType}), + service.NewDIDCommMsgMap(model.Ack{Type: AckMsgType}), + } + for _, msg := range others { + _, _, _, err := (&invited{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + }, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "illegal msg type") + } + }) + t.Run("followup to 'requested' on inbound invitations", func(t *testing.T) { + invitationPayloadBytes, err := json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + ServiceEndpoint: "https://localhost:8090", + RoutingKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + }) + require.NoError(t, err) + connRec, followup, _, err := (&invited{}).ExecuteInbound( + &stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, invitationPayloadBytes), + connRecord: &connection.Record{}, + }, + "", + &context{}) + require.NoError(t, err) + require.Equal(t, &requested{}, followup) + require.NotNil(t, connRec) + }) +} + +func TestRequestedState_Execute(t *testing.T) { + prov := getProvider(t) + // Alice receives an invitation from Bob + invitationPayloadBytes, err := json.Marshal(&Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + ServiceEndpoint: "https://localhost:8090", + RoutingKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + }) + require.NoError(t, err) + + mtps := []string{ + transport.MediaTypeRFC0019EncryptedEnvelope, + transport.MediaTypeProfileDIDCommAIP1, + } + + for _, mtp := range mtps { + t.Run("rejects messages other than invitations or requests", func(t *testing.T) { + others := []service.DIDCommMsg{ + service.NewDIDCommMsgMap(Response{Type: ResponseMsgType}), + service.NewDIDCommMsgMap(model.Ack{Type: AckMsgType}), + } + for _, msg := range others { + _, _, _, e := (&requested{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + }, "", &context{}) + require.Error(t, e) + require.Contains(t, e.Error(), "illegal msg type") + } + }) + t.Run("handle inbound invitations", func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + msg, e := service.ParseDIDCommMsgMap(invitationPayloadBytes) + require.NoError(t, e) + thid, e := msg.ThreadID() + require.NoError(t, e) + connRec, _, _, e := (&requested{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + connRecord: &connection.Record{}, + }, thid, ctx) + require.NoError(t, e) + require.NotNil(t, connRec.MyDID) + }) + t.Run("handling invitations fails if my diddoc does not have a valid didcomm service", func(t *testing.T) { + msg, e := service.ParseDIDCommMsgMap(invitationPayloadBytes) + require.NoError(t, e) + + ctx := getContext(t, &prov, mtp) + + myDoc := createDIDDoc(t, ctx) + myDoc.Service = []diddoc.Service{{ + ID: uuid.New().String(), + Type: "invalid", + Priority: 0, + RecipientKeys: nil, + ServiceEndpoint: commonmodel.NewDIDCommV1Endpoint("https://localhost:8090"), + }} + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: myDoc} + _, _, _, err = (&requested{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + connRecord: &connection.Record{}, + }, "", ctx) + require.Error(t, err) + }) + t.Run("inbound request unmarshalling error", func(t *testing.T) { + _, followup, _, err := (&requested{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: service.DIDCommMsgMap{ + "@type": InvitationMsgType, + "@id": map[int]int{}, + }, + }, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "JSON unmarshalling of invitation") + require.Nil(t, followup) + }) + t.Run("create DID error", func(t *testing.T) { + ctx2 := &context{ + outboundDispatcher: prov.OutboundDispatcher(), + vdRegistry: &mockvdr.MockVDRegistry{CreateErr: fmt.Errorf("create DID error")}, + } + didDoc, err := ctx2.vdRegistry.Create(testMethod, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "create DID error") + require.Nil(t, didDoc) + }) + } +} + +func TestRespondedState_Execute(t *testing.T) { + mtps := []string{transport.MediaTypeProfileDIDCommAIP1, transport.MediaTypeRFC0019EncryptedEnvelope} + + for _, mtp := range mtps { + prov := getProvider(t) + ctx := getContext(t, &prov, mtp) + + request, err := createRequest(t, ctx) + require.NoError(t, err) + + requestPayloadBytes, err := json.Marshal(request) + require.NoError(t, err) + + response, err := createResponse(request, ctx) + require.NoError(t, err) + + responsePayloadBytes, err := json.Marshal(response) + require.NoError(t, err) + + t.Run("rejects messages other than requests and responses", func(t *testing.T) { + others := []service.DIDCommMsg{ + service.NewDIDCommMsgMap(Invitation{Type: InvitationMsgType}), + service.NewDIDCommMsgMap(model.Ack{Type: AckMsgType}), + } + for _, msg := range others { + _, _, _, e := (&responded{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + }, "", &context{}) + require.Error(t, e) + require.Contains(t, e.Error(), "illegal msg type") + } + }) + t.Run("no followup for inbound requests", func(t *testing.T) { + connRec, followup, _, e := (&responded{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, requestPayloadBytes), + connRecord: &connection.Record{}, + }, "", ctx) + require.NoError(t, e) + require.NotNil(t, connRec) + require.IsType(t, &noOp{}, followup) + }) + t.Run("followup to 'completed' on inbound responses", func(t *testing.T) { + connRec := &connection.Record{ + State: (&responded{}).Name(), + ThreadID: request.ID, + ConnectionID: "123", + Namespace: findNamespace(ResponseMsgType), + } + err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(connRec) + require.NoError(t, err) + connRec, followup, _, e := (&responded{}).ExecuteInbound( + &stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, responsePayloadBytes), + connRecord: connRec, + }, "", ctx) + require.NoError(t, e) + require.NotNil(t, connRec) + require.Equal(t, (&completed{}).Name(), followup.Name()) + }) + + t.Run("handle inbound request unmarshalling error", func(t *testing.T) { + _, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: service.DIDCommMsgMap{"@id": map[int]int{}, "@type": RequestMsgType}, + }, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "JSON unmarshalling of request") + require.Nil(t, followup) + }) + + t.Run("fails if my did has an invalid didcomm service entry", func(t *testing.T) { + myDoc := createDIDDoc(t, ctx) + myDoc.Service = []diddoc.Service{{ + ID: uuid.New().String(), + Type: "invalid", + Priority: 0, + RecipientKeys: nil, + ServiceEndpoint: commonmodel.NewDIDCommV1Endpoint("http://localhost:58416"), + }} + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: myDoc} + _, _, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, requestPayloadBytes), + connRecord: &connection.Record{}, + }, "", ctx) + require.Error(t, err) + }) + } +} + +// completed is an end state. +func TestCompletedState_Execute(t *testing.T) { + prov := getProvider(t) + customKMS := newKMS(t, prov.StoreProvider) + ctx := &context{ + crypto: &tinkcrypto.Crypto{}, + kms: customKMS, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + connRec, err := connection.NewRecorder(&prov) + + require.NoError(t, err) + require.NotNil(t, connRec) + + ctx.connectionRecorder = connRec + + newDIDDoc := createDIDDocWithKey(pubKey) + + invitation, err := createMockInvitation(pubKey, ctx) + require.NoError(t, err) + + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + + response := &Response{ + Type: ResponseMsgType, + ID: randomString(), + ConnectionSignature: connectionSignature, + Thread: &decorator.Thread{ + ID: "test", + }, + PleaseAck: &PleaseAck{On: []string{PlsAckOnReceipt}}, + } + + t.Run("no followup for inbound responses", func(t *testing.T) { + var responsePayloadBytes []byte + + responsePayloadBytes, err = json.Marshal(response) + require.NoError(t, err) + + newConnRec := &connection.Record{ + State: (&responded{}).Name(), + ThreadID: response.Thread.ID, + ConnectionID: "123", + MyDID: "did:peer:123456789abcdefghi#inbox", + Namespace: myNSPrefix, + InvitationID: invitation.ID, + RecipientKeys: []string{base58.Encode(pubKey)}, + } + err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(newConnRec) + require.NoError(t, err) + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: mockdiddoc.GetMockDIDDoc(t, false)} + require.NoError(t, err) + _, followup, _, e := (&completed{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, responsePayloadBytes), + connRecord: newConnRec, + }, "", ctx) + require.NoError(t, e) + require.IsType(t, &noOp{}, followup) + }) + t.Run("no followup for inbound acks", func(t *testing.T) { + newConnRec := &connection.Record{ + State: (&responded{}).Name(), + ThreadID: response.Thread.ID, + ConnectionID: "123", + Namespace: findNamespace(AckMsgType), + RecipientKeys: []string{base58.Encode(pubKey)}, + } + err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(newConnRec) + require.NoError(t, err) + ack := &model.Ack{ + Type: AckMsgType, + ID: randomString(), + Status: ackStatusOK, + Thread: &decorator.Thread{ + ID: response.Thread.ID, + }, + } + ackPayloadBytes, e := json.Marshal(ack) + require.NoError(t, e) + _, followup, _, e := (&completed{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, ackPayloadBytes), + }, "", ctx) + require.NoError(t, e) + require.IsType(t, &noOp{}, followup) + }) + t.Run("rejects messages other than responses, acks, and completes", func(t *testing.T) { + others := []service.DIDCommMsg{ + service.NewDIDCommMsgMap(Invitation{Type: InvitationMsgType}), + service.NewDIDCommMsgMap(Request{Type: RequestMsgType}), + } + + for _, msg := range others { + _, _, _, err = (&completed{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: msg, + }, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "illegal msg type") + } + }) + t.Run("no followup for inbound responses unmarshalling error", func(t *testing.T) { + _, followup, _, err := (&completed{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: service.DIDCommMsgMap{"@id": map[int]int{}, "@type": ResponseMsgType}, + }, "", &context{}) + require.Error(t, err) + require.Contains(t, err.Error(), "JSON unmarshalling of response") + require.Nil(t, followup) + }) + + t.Run("execute inbound handle inbound response error", func(t *testing.T) { + response.ConnectionSignature = &ConnectionSignature{} + responsePayloadBytes, err := json.Marshal(response) + require.NoError(t, err) + + _, followup, _, err := (&completed{}).ExecuteInbound(&stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, responsePayloadBytes), + }, "", ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "handle inbound response") + require.Nil(t, followup) + }) +} + +func TestNewRequestFromInvitation(t *testing.T) { + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + ServiceEndpoint: "https://localhost:8090", + RoutingKeys: []string{"8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K"}, + } + + t.Run("successful new request from invitation", func(t *testing.T) { + prov := getProvider(t) + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + _, connRec, err := ctx.handleInboundInvitation(invitation, invitation.ID, &options{}, &connection.Record{}) + require.NoError(t, err) + require.NotNil(t, connRec.MyDID) + }) + t.Run("successful response to invitation with public did", func(t *testing.T) { + prov := getProvider(t) + ctx := &context{ + kms: prov.CustomKMS, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + doc := createDIDDoc(t, ctx) + connRec, err := connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + didConnStore, err := didstore.NewConnectionStore(&protocol.MockProvider{}) + require.NoError(t, err) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: doc} + ctx.connectionRecorder = connRec + ctx.connectionStore = didConnStore + + _, connRecord, err := ctx.handleInboundInvitation(invitation, invitation.ID, &options{publicDID: doc.ID}, + &connection.Record{}) + require.NoError(t, err) + require.NotNil(t, connRecord.MyDID) + require.Equal(t, connRecord.MyDID, doc.ID) + }) + t.Run("unsuccessful new request from invitation ", func(t *testing.T) { + prov := protocol.MockProvider{} + customKMS := newKMS(t, prov.StoreProvider) + + ctx := &context{ + kms: customKMS, + outboundDispatcher: prov.OutboundDispatcher(), + routeSvc: &mockroute.MockMediatorSvc{}, + vdRegistry: &mockvdr.MockVDRegistry{CreateErr: fmt.Errorf("create DID error")}, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + _, connRec, err := ctx.handleInboundInvitation(invitation, invitation.ID, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create DID error") + require.Nil(t, connRec) + }) +} + +func TestNewResponseFromRequest(t *testing.T) { + prov := getProvider(t) + store := mockstorage.NewMockStoreProvider() + k := newKMS(t, store) + + t.Run("successful new response from request", func(t *testing.T) { + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + request, err := createRequest(t, ctx) + require.NoError(t, err) + _, connRec, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.NoError(t, err) + require.NotNil(t, connRec.MyDID) + require.NotNil(t, connRec.TheirDID) + }) + + t.Run("unsuccessful new response from request due to resolve DID error", func(t *testing.T) { + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + request, err := createRequest(t, ctx) + require.NoError(t, err) + + request.Connection.DID = "invalid" + _, connRec, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "resolve did doc from connection request") + require.Nil(t, connRec) + }) + + t.Run("unsuccessful new response from request due to create did error", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc(t, false) + ctx := &context{ + vdRegistry: &mockvdr.MockVDRegistry{ + CreateErr: fmt.Errorf("create DID error"), + ResolveValue: mockdiddoc.GetMockDIDDoc(t, false), + }, + routeSvc: &mockroute.MockMediatorSvc{}, + } + request := &Request{ + Connection: &Connection{ + DID: didDoc.ID, + DIDDoc: didDoc, + }, + } + _, connRec, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create DID error") + require.Nil(t, connRec) + }) + + t.Run("unsuccessful new response from request due to get did doc error", func(t *testing.T) { + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + ctx.connectionStore = &mockConnectionStore{saveDIDFromDocErr: fmt.Errorf("save did error")} + + request, err := createRequest(t, ctx) + require.NoError(t, err) + _, connRec, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "get response did doc and connection") + require.Nil(t, connRec) + }) + + t.Run("unsuccessful new response from request due to sign error", func(t *testing.T) { + connRec, err := connection.NewRecorder(&prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + didConnStore, err := didstore.NewConnectionStore(&prov) + require.NoError(t, err) + require.NotNil(t, didConnStore) + + ctx := &context{ + vdRegistry: &mockvdr.MockVDRegistry{CreateValue: mockdiddoc.GetMockDIDDoc(t, false)}, + crypto: &mockcrypto.Crypto{SignErr: errors.New("sign error")}, + connectionRecorder: connRec, + connectionStore: didConnStore, + routeSvc: &mockroute.MockMediatorSvc{}, + kms: prov.CustomKMS, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + doACAPyInterop: true, + } + + request, err := createRequest(t, ctx) + require.NoError(t, err) + + _, connRecord, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + + require.Error(t, err) + require.Contains(t, err.Error(), "sign error") + require.Nil(t, connRecord) + }) + + t.Run("unsuccessful new response from request due to resolve public did from request error", func(t *testing.T) { + ctx := &context{vdRegistry: &mockvdr.MockVDRegistry{ResolveErr: errors.New("resolver error")}} + request := &Request{Connection: &Connection{DID: "did:sidetree:abc"}} + _, _, err := ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "resolver error") + }) + + t.Run("unsuccessful new response from request due to invalid did for creating destination", func(t *testing.T) { + mockDoc := newPeerDID(t, k) + mockDoc.Service = nil + + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + + request, err := createRequest(t, ctx) + require.NoError(t, err) + + request.Connection.DID = mockDoc.ID + request.Connection.DIDDoc = mockDoc + + _, _, err = ctx.handleInboundRequest(request, &options{}, &connection.Record{}) + require.Error(t, err) + require.Contains(t, err.Error(), "missing DID doc service") + }) +} + +func TestPrepareConnectionSignature(t *testing.T) { + prov := getProvider(t) + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitation, err := createMockInvitation(pubKey, ctx) + require.NoError(t, err) + + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + + t.Run("prepare connection signature", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, invitation.ID) + require.NoError(t, err) + require.NotNil(t, connectionSignature) + sigData, err := base64.URLEncoding.DecodeString(connectionSignature.SignedData) + require.NoError(t, err) + connBytes := sigData[timestampLength:] + sigDataConnection := &Connection{} + err = json.Unmarshal(connBytes, sigDataConnection) + require.NoError(t, err) + require.Equal(t, doc.DIDDocument.ID, sigDataConnection.DID) + }) + t.Run("implicit invitation with DID - success", func(t *testing.T) { + connRec, err := connection.NewRecorder(&prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + didConnStore, err := didstore.NewConnectionStore(&prov) + require.NoError(t, err) + require.NotNil(t, didConnStore) + + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, invitation.ID) + require.NoError(t, err) + require.NotNil(t, connectionSignature) + sigData, err := base64.URLEncoding.DecodeString(connectionSignature.SignedData) + require.NoError(t, err) + connBytes := sigData[timestampLength:] + sigDataConnection := &Connection{} + err = json.Unmarshal(connBytes, sigDataConnection) + require.NoError(t, err) + require.Equal(t, doc.DIDDocument.ID, sigDataConnection.DID) + }) + t.Run("prepare connection signature get invitation", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, "test") + require.Error(t, err) + require.Contains(t, err.Error(), "get invitation for [invitationID=test]: data not found") + require.Nil(t, connectionSignature) + }) + t.Run("prepare connection signature get invitation", func(t *testing.T) { + invID := randomString() + inv := &Invitation{ + Type: InvitationMsgType, + ID: invID, + DID: "test", + } + err := ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, inv.ID) + require.Error(t, err) + require.Contains(t, err.Error(), + fmt.Sprintf("get invitation for [invitationID=%s]: data not found", invID)) + require.Nil(t, connectionSignature) + }) + t.Run("prepare connection signature error", func(t *testing.T) { + ctx2 := ctx + ctx2.crypto = &mockcrypto.Crypto{SignErr: errors.New("sign error")} + newDIDDoc := mockdiddoc.GetMockDIDDoc(t, false) + + connectionSignature, err := ctx2.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.Error(t, err) + require.Contains(t, err.Error(), "sign error") + require.Nil(t, connectionSignature) + }) +} + +func TestVerifySignature(t *testing.T) { + prov := getProvider(t) + + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + + keyID, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + newDIDDoc := createDIDDocWithKey(pubKey) + + invitation, err := createMockInvitation(pubKey, ctx) + require.NoError(t, err) + + t.Run("signature verified", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + con, err := ctx.verifySignature(connectionSignature, invitation.RecipientKeys[0]) + require.NoError(t, err) + require.NotNil(t, con) + require.Equal(t, newDIDDoc.ID, con.DID) + }) + t.Run("missing/invalid signature data", func(t *testing.T) { + con, err := ctx.verifySignature(&ConnectionSignature{}, invitation.RecipientKeys[0]) + require.Error(t, err) + require.Contains(t, err.Error(), "missing or invalid signature data") + require.Nil(t, con) + }) + t.Run("decode signature data error", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + + connectionSignature.SignedData = "invalid-signed-data" + con, err := ctx.verifySignature(connectionSignature, "") + require.Error(t, err) + require.Contains(t, err.Error(), "decode signature data: illegal base64 data") + require.Nil(t, con) + }) + t.Run("decode signature error", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + + connectionSignature.Signature = "invalid-signature" + con, err := ctx.verifySignature(connectionSignature, "") + require.Error(t, err) + require.Contains(t, err.Error(), "decode signature: illegal base64 data") + require.Nil(t, con) + }) + t.Run("decode verification key error", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + + con, err := ctx.verifySignature(connectionSignature, "invalid-key") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to get key handle: pubKey is empty") + require.Nil(t, con) + }) + t.Run("verify signature error", func(t *testing.T) { + connectionSignature, err := ctx.prepareConnectionSignature(newDIDDoc, invitation.ID) + require.NoError(t, err) + + // generate different key and assign it to signature verification key + pubKey2, _ := generateKeyPair() + con, err := ctx.verifySignature(connectionSignature, pubKey2) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid signature") + require.Nil(t, con) + }) + t.Run("connection unmarshal error", func(t *testing.T) { + connAttributeBytes := []byte("{hello world}") + + now := getEpochTime() + timestampBuf := make([]byte, timestampLength) + binary.BigEndian.PutUint64(timestampBuf, uint64(now)) + concatenateSignData := append(timestampBuf, connAttributeBytes...) + + kh, err := ctx.kms.Get(keyID) + require.NoError(t, err) + + signature, err := ctx.crypto.Sign(concatenateSignData, kh) + require.NoError(t, err) + + cs := &ConnectionSignature{ + Type: "https://didcomm.org/signature/1.0/ed25519Sha512_single", + SignedData: base64.URLEncoding.EncodeToString(concatenateSignData), + SignVerKey: base64.URLEncoding.EncodeToString(pubKey), + Signature: base64.URLEncoding.EncodeToString(signature), + } + + con, err := ctx.verifySignature(cs, invitation.RecipientKeys[0]) + require.Error(t, err) + require.Contains(t, err.Error(), "JSON unmarshalling of connection") + require.Nil(t, con) + }) + t.Run("missing connection attribute bytes", func(t *testing.T) { + now := getEpochTime() + timestampBuf := make([]byte, timestampLength) + binary.BigEndian.PutUint64(timestampBuf, uint64(now)) + + kh, err := ctx.kms.Get(keyID) + require.NoError(t, err) + + signature, err := ctx.crypto.Sign(timestampBuf, kh) + require.NoError(t, err) + + cs := &ConnectionSignature{ + Type: "https://didcomm.org/signature/1.0/ed25519Sha512_single", + SignedData: base64.URLEncoding.EncodeToString(timestampBuf), + SignVerKey: base64.URLEncoding.EncodeToString(pubKey), + Signature: base64.URLEncoding.EncodeToString(signature), + } + + con, err := ctx.verifySignature(cs, invitation.RecipientKeys[0]) + require.Error(t, err) + require.Contains(t, err.Error(), "missing connection attribute bytes") + require.Nil(t, con) + }) +} + +func TestResolveDIDDocFromConnection(t *testing.T) { + prov := getProvider(t) + mtps := []string{transport.MediaTypeProfileDIDCommAIP1, transport.MediaTypeRFC0019EncryptedEnvelope} + + for _, mtp := range mtps { + t.Run(fmt.Sprintf("success with media type profile: %s", mtp), func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + docIn := mockdiddoc.GetMockDIDDoc(t, false) + con := &Connection{ + DID: docIn.ID, + DIDDoc: docIn, + } + doc, err := ctx.resolveDidDocFromConnection(con) + require.NoError(t, err) + + require.Equal(t, docIn.ID, doc.ID) + }) + + t.Run(fmt.Sprintf("success - public resolution with media type profile: %s", mtp), func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + docIn := mockdiddoc.GetMockDIDDoc(t, false) + docIn.ID = "did:remote:abc" + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveValue: docIn} + + con := &Connection{ + DID: docIn.ID, + DIDDoc: docIn, + } + doc, err := ctx.resolveDidDocFromConnection(con) + require.NoError(t, err) + + require.Equal(t, docIn.ID, doc.ID) + }) + + t.Run(fmt.Sprintf("failure - can't do public resolution with media type profile: %s", mtp), + func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + docIn := mockdiddoc.GetMockDIDDoc(t, false) + docIn.ID = "did:remote:abc" + + ctx.vdRegistry = &mockvdr.MockVDRegistry{ResolveErr: fmt.Errorf("resolve error")} + + con := &Connection{ + DID: docIn.ID, + DIDDoc: docIn, + } + _, err := ctx.resolveDidDocFromConnection(con) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to resolve public did") + }) + + t.Run(fmt.Sprintf("failure - can't parse did with media type profile: %s", mtp), func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + + _, err := ctx.resolveDidDocFromConnection(&Connection{DID: "blah blah"}) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to parse did") + }) + + t.Run(fmt.Sprintf("failure - missing attachment for private did with media type profile: %s", mtp), + func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + _, err := ctx.resolveDidDocFromConnection(&Connection{DID: "did:peer:abcdefg"}) + require.Error(t, err) + require.Contains(t, err.Error(), "missing DIDDoc") + }) + + t.Run(fmt.Sprintf("failure - can't store document locally with media type profile: %s", mtp), + func(t *testing.T) { + ctx := getContext(t, &prov, mtp) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateErr: fmt.Errorf("create error")} + + docIn := mockdiddoc.GetMockDIDDoc(t, false) + + con := &Connection{ + DID: docIn.ID, + DIDDoc: docIn, + } + _, err := ctx.resolveDidDocFromConnection(con) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to store provided did document") + }) + } +} + +func TestHandleInboundResponse(t *testing.T) { + prov := getProvider(t) + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + _, err = createMockInvitation(pubKey, ctx) + require.NoError(t, err) + + t.Run("handle inbound responses get connection record error", func(t *testing.T) { + response := &Response{Thread: &decorator.Thread{ID: "test"}} + _, connRec, e := ctx.handleInboundResponse(response) + require.Error(t, e) + require.Contains(t, e.Error(), "get connection record") + require.Nil(t, connRec) + }) + t.Run("handle inbound responses get connection record error", func(t *testing.T) { + response := &Response{Thread: &decorator.Thread{ID: ""}} + _, connRec, e := ctx.handleInboundResponse(response) + require.Error(t, e) + require.Contains(t, e.Error(), "empty bytes") + require.Nil(t, connRec) + }) +} + +func TestGetInvitationRecipientKey(t *testing.T) { + prov := getProvider(t) + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + + t.Run("successfully getting invitation recipient key", func(t *testing.T) { + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{"test"}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + recKey, err := ctx.getInvitationRecipientKey(invitation) + require.NoError(t, err) + require.Equal(t, invitation.RecipientKeys[0], recKey) + }) + t.Run("failed to get invitation recipient key", func(t *testing.T) { + doc := mockdiddoc.GetMockDIDDoc(t, false) + _, ok := diddoc.LookupService(doc, "did-communication") + require.True(t, ok) + + ctx := context{vdRegistry: &mockvdr.MockVDRegistry{ResolveValue: doc}} + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + DID: doc.ID, + } + + recKey, err := ctx.getInvitationRecipientKey(invitation) + require.NoError(t, err) + require.Equal(t, doc.Service[0].RecipientKeys[0], recKey) + }) + t.Run("failed to get invitation recipient key", func(t *testing.T) { + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + DID: "test", + } + _, err := ctx.getInvitationRecipientKey(invitation) + require.Error(t, err) + require.Contains(t, err.Error(), "get invitation recipient key: DID does not exist") + }) +} + +func TestGetPublicKey(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + t.Run("successfully getting public key by id", func(t *testing.T) { + prov := protocol.MockProvider{CustomKMS: k} + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + pubkey, ok := diddoc.LookupPublicKey(doc.DIDDocument.VerificationMethod[0].ID, doc.DIDDocument) + require.True(t, ok) + require.NotNil(t, pubkey) + }) + t.Run("failed to get public key", func(t *testing.T) { + prov := protocol.MockProvider{CustomKMS: k} + ctx := getContext(t, &prov, transport.MediaTypeRFC0019EncryptedEnvelope) + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + pubkey, ok := diddoc.LookupPublicKey("invalid-key", doc.DIDDocument) + require.False(t, ok) + require.Nil(t, pubkey) + }) +} + +func TestGetDIDDocAndConnection(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + + t.Run("successfully getting did doc and connection for public did", func(t *testing.T) { + doc := createDIDDoc(t, ctx) + connRec, err := connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + didConnStore, err := didstore.NewConnectionStore(&protocol.MockProvider{}) + require.NoError(t, err) + ctx := context{ + vdRegistry: &mockvdr.MockVDRegistry{ResolveValue: doc}, + connectionRecorder: connRec, + connectionStore: didConnStore, + } + didDoc, err := ctx.getMyDIDDoc(doc.ID, nil, "") + require.NoError(t, err) + require.NotNil(t, didDoc) + }) + t.Run("error getting public did doc from resolver", func(t *testing.T) { + ctx := context{ + vdRegistry: &mockvdr.MockVDRegistry{ResolveErr: errors.New("resolver error")}, + } + didDoc, err := ctx.getMyDIDDoc("did-id", nil, "") + require.Error(t, err) + require.Contains(t, err.Error(), "resolver error") + require.Nil(t, didDoc) + }) + t.Run("error creating peer did", func(t *testing.T) { + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateErr: errors.New("creator error")}, + routeSvc: &mockroute.MockMediatorSvc{}, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc, err := ctx.getMyDIDDoc("", nil, didCommServiceType) + require.Error(t, err) + require.Contains(t, err.Error(), "creator error") + require.Nil(t, didDoc) + }) + t.Run("error creating peer did with DIDCommV2 service type", func(t *testing.T) { + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateErr: errors.New("DIDCommMessaging is not supported")}, + routeSvc: &mockroute.MockMediatorSvc{}, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc, err := ctx.getMyDIDDoc("", nil, vdrapi.DIDCommV2ServiceType) + require.Error(t, err) + require.Contains(t, err.Error(), "DIDCommMessaging is not supported") + require.Nil(t, didDoc) + }) + t.Run("error creating peer did with empty service type", func(t *testing.T) { + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateErr: errors.New("is not supported")}, + routeSvc: &mockroute.MockMediatorSvc{}, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc, err := ctx.getMyDIDDoc("", nil, "") + require.Error(t, err) + require.Contains(t, err.Error(), "is not supported") + require.Nil(t, didDoc) + }) + + t.Run("successfully created peer did", func(t *testing.T) { + connRec, err := connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + didConnStore, err := didstore.NewConnectionStore(&protocol.MockProvider{}) + require.NoError(t, err) + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateValue: mockdiddoc.GetMockDIDDoc(t, false)}, + connectionRecorder: connRec, + connectionStore: didConnStore, + routeSvc: &mockroute.MockMediatorSvc{}, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc, err := ctx.getMyDIDDoc("", nil, didCommServiceType) + require.NoError(t, err) + require.NotNil(t, didDoc) + }) + t.Run("test create did doc - router service config error", func(t *testing.T) { + connRec, err := connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateValue: mockdiddoc.GetMockDIDDoc(t, false)}, + connectionRecorder: connRec, + routeSvc: &mockroute.MockMediatorSvc{ + Connections: []string{"xyz"}, + ConfigErr: errors.New("router config error"), + }, + } + didDoc, err := ctx.getMyDIDDoc("", []string{"xyz"}, "") + require.Error(t, err) + require.Contains(t, err.Error(), "did doc - fetch router config") + require.Nil(t, didDoc) + }) + + t.Run("test create did doc - router service config error", func(t *testing.T) { + connRec, err := connection.NewRecorder(&protocol.MockProvider{}) + require.NoError(t, err) + customKMS := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := context{ + kms: customKMS, + vdRegistry: &mockvdr.MockVDRegistry{CreateValue: mockdiddoc.GetMockDIDDoc(t, false)}, + connectionRecorder: connRec, + routeSvc: &mockroute.MockMediatorSvc{ + Connections: []string{"xyz"}, + AddKeyErr: errors.New("router add key error"), + }, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + didDoc, err := ctx.getMyDIDDoc("", []string{"xyz"}, didCommServiceType) + require.Error(t, err) + require.Contains(t, err.Error(), "did doc - add key to the router") + require.Nil(t, didDoc) + }) +} + +func TestGetVerKey(t *testing.T) { + k := newKMS(t, mockstorage.NewMockStoreProvider()) + ctx := &context{ + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{transport.MediaTypeRFC0019EncryptedEnvelope}, + } + + t.Run("returns verkey from explicit connection invitation", func(t *testing.T) { + expected := newServiceBlock() + invitation := newConnectionInvite(t, "", expected) + ctx.connectionRecorder = connRecorder(t, testProvider()) + + err := ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + result, err := ctx.getVerKey(invitation.ID) + require.NoError(t, err) + require.Equal(t, expected.RecipientKeys[0], result) + + expected = newServiceBlock() + invitation = newConnectionInvite(t, "", expected) + ctx.connectionRecorder = connRecorder(t, testProvider()) + + err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + result, err = ctx.getVerKey(invitation.ID) + require.NoError(t, err) + require.Equal(t, expected.RecipientKeys[0], result) + }) + + t.Run("returns verkey from implicit connection invitation", func(t *testing.T) { + publicDID := createDIDDoc(t, ctx) + ctx.connectionRecorder = connRecorder(t, testProvider()) + ctx.vdRegistry = &mockvdr.MockVDRegistry{ + ResolveValue: publicDID, + } + + svc, found := diddoc.LookupService(publicDID, "did-communication") + require.True(t, found) + + result, err := ctx.getVerKey(publicDID.ID) + require.NoError(t, err) + require.Equal(t, svc.RecipientKeys[0], result) + }) + + t.Run("wraps error from store", func(t *testing.T) { + expected := errors.New("test") + pr := testProvider() + pr.StoreProvider = &mockstorage.MockStoreProvider{ + Store: &mockstorage.MockStore{ + Store: make(map[string]mockstorage.DBEntry), + ErrGet: expected, + }, + } + ctx.connectionRecorder = connRecorder(t, pr) + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitation, err := createMockInvitation(pubKey, ctx) + require.NoError(t, err) + + err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + _, err = ctx.getVerKey(invitation.ID) + require.Error(t, err) + }) + + t.Run("wraps error from vdr resolution", func(t *testing.T) { + expected := errors.New("test") + ctx.connectionRecorder = connRecorder(t, testProvider()) + ctx.vdRegistry = &mockvdr.MockVDRegistry{ + ResolveErr: expected, + } + + _, err := ctx.getVerKey("did:example:123") + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) +} + +func createDIDDoc(t *testing.T, ctx *context) *diddoc.Doc { + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + return createDIDDocWithKey(pubKey) +} + +func createDIDDocWithKey(pubKey []byte) *diddoc.Doc { + const ( + didFormat = "did:%s:%s" + didPKID = "%s#keys-%d" + didServiceID = "%s#endpoint-%d" + method = "test" + ) + + pub := base58.Encode(pubKey) + id := fmt.Sprintf(didFormat, method, pub[:16]) + pubKeyID := fmt.Sprintf(didPKID, id, 1) + verPubKeyVM := diddoc.VerificationMethod{ + ID: pubKeyID, + Type: "Ed25519VerificationKey2018", + Controller: id, + Value: pubKey, + } + services := []diddoc.Service{ + { + ID: fmt.Sprintf(didServiceID, id, 1), + Type: vdrapi.DIDCommServiceType, + ServiceEndpoint: commonmodel.NewDIDCommV1Endpoint("http://localhost:58416"), + Priority: 0, + RecipientKeys: []string{pubKeyID}, + }, + } + + services[0].Accept = []string{transport.MediaTypeRFC0019EncryptedEnvelope} + + createdTime := time.Now() + didDoc := &diddoc.Doc{ + Context: []string{diddoc.ContextV1}, + ID: id, + VerificationMethod: []diddoc.VerificationMethod{verPubKeyVM}, + Service: services, + Created: &createdTime, + Updated: &createdTime, + } + + return didDoc +} + +func getProvider(t *testing.T) protocol.MockProvider { + t.Helper() + + store := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)} + sProvider := mockstorage.NewCustomMockStoreProvider(store) + customKMS := newKMS(t, sProvider) + + return protocol.MockProvider{ + StoreProvider: sProvider, + CustomKMS: customKMS, + } +} + +func getContext(t *testing.T, prov *protocol.MockProvider, mediaTypeProfile string) *context { + t.Helper() + + ctx := &context{ + outboundDispatcher: prov.OutboundDispatcher(), + crypto: &tinkcrypto.Crypto{}, + routeSvc: &mockroute.MockMediatorSvc{}, + kms: prov.KMS(), + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + mediaTypeProfiles: []string{mediaTypeProfile}, + } + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + + didConnStore, err := didstore.NewConnectionStore(prov) + require.NoError(t, err) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDoc(t, ctx)} + ctx.connectionRecorder = connRec + ctx.connectionStore = didConnStore + + return ctx +} + +func createRequest(t *testing.T, ctx *context) (*Request, error) { + t.Helper() + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + invitation, err := createMockInvitation(pubKey, ctx) + if err != nil { + return nil, err + } + + newDidDoc := createDIDDocWithKey(pubKey) + + // Prepare connection inbound request + request := &Request{ + Type: RequestMsgType, + ID: randomString(), + Label: "Bob", + Thread: &decorator.Thread{ + PID: invitation.ID, + }, + Connection: &Connection{ + DID: newDidDoc.ID, + DIDDoc: newDidDoc, + }, + } + + return request, nil +} + +func generateKeyPair() (string, []byte) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + return base58.Encode(pubKey[:]), privKey +} + +func createResponse(request *Request, ctx *context) (*Response, error) { + doc, err := ctx.vdRegistry.Create(testMethod, nil) + if err != nil { + return nil, err + } + + connectionSignature, err := ctx.prepareConnectionSignature(doc.DIDDocument, request.Thread.PID) + if err != nil { + return nil, err + } + + response := &Response{ + Type: ResponseMsgType, + ID: randomString(), + Thread: &decorator.Thread{ + ID: request.ID, + }, + ConnectionSignature: connectionSignature, + PleaseAck: &PleaseAck{ + On: []string{PlsAckOnReceipt}, + }, + } + + return response, nil +} + +func createMockInvitation(pubKey []byte, ctx *context) (*Invitation, error) { + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(pubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err := ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + if err != nil { + return nil, err + } + + return invitation, nil +} + +func toDIDCommMsg(t *testing.T, v interface{}) service.DIDCommMsgMap { + msg, err := service.ParseDIDCommMsgMap(toBytes(t, v)) + require.NoError(t, err) + + return msg +} + +func bytesToDIDCommMsg(t *testing.T, v []byte) service.DIDCommMsg { + msg, err := service.ParseDIDCommMsgMap(v) + require.NoError(t, err) + + return msg +} + +func toBytes(t *testing.T, data interface{}) []byte { + t.Helper() + + src, err := json.Marshal(data) + require.NoError(t, err) + + return src +} + +func newConnectionInvite(t *testing.T, publicDID string, svc *diddoc.Service) *Invitation { + t.Helper() + + i := &Invitation{ + ID: uuid.New().String(), + Type: InvitationMsgType, + DID: publicDID, + } + + if svc != nil { + var err error + + i.RecipientKeys = svc.RecipientKeys + i.RoutingKeys = svc.RoutingKeys + + i.ServiceEndpoint, err = svc.ServiceEndpoint.URI() + require.NoError(t, err) + } + + return i +} + +func newServiceBlock() *diddoc.Service { + var ( + sp commonmodel.Endpoint + didCommV1RoutingKeys []string + ) + + sp = commonmodel.NewDIDCommV1Endpoint("http://test.com") + didCommV1RoutingKeys = []string{uuid.New().String()} + + svc := &diddoc.Service{ + ID: uuid.New().String(), + Type: didCommServiceType, + RecipientKeys: []string{uuid.New().String()}, + ServiceEndpoint: sp, + } + + svc.Accept = []string{transport.MediaTypeRFC0019EncryptedEnvelope} + svc.RoutingKeys = didCommV1RoutingKeys + + return svc +} + +func connRecorder(t *testing.T, p provider) *connection.Recorder { + s, err := connection.NewRecorder(p) + require.NoError(t, err) + + return s +} + +func getEpochTime() int64 { + return time.Now().Unix() +}