Skip to content

Commit

Permalink
feat(prism-agent): expose connect/issue/presentation records 'thid' a…
Browse files Browse the repository at this point in the history
…nd add it to REST API queries (#583)

* chore(connect): make 'thid' field non-optional

* feat(prism-agent): expose 'thid' when getting connect/issue/presentation records from REST API

* feat(prism-agent): add supprot for filtering on 'thid' in connect/issue/presentation records retrieval (in-memory filtering for now)

* feat(connect): change 'thid' type from UUID to string and expose record retrieval by 'thid' in connect service

* feat(pollux): expose record retrieval by 'thid' in credential and presentation services

* feat(prism-agent): implement connect/issue/presentation records filtering by 'thid' at DB level

* test(prism-agent): declare 'thid' field in Kotlin Connection, Credential, and PresentationProof model classes
  • Loading branch information
bvoiturier committed Jul 6, 2023
1 parent c9e69f6 commit 9a97c7a
Show file tree
Hide file tree
Showing 27 changed files with 142 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ case class ConnectionRecord(
id: UUID,
createdAt: Instant,
updatedAt: Option[Instant],
thid: Option[UUID],
thid: String,
label: Option[String],
role: Role,
protocolState: ProtocolState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ sealed trait ConnectionServiceError
object ConnectionServiceError {
final case class RepositoryError(cause: Throwable) extends ConnectionServiceError
final case class RecordIdNotFound(recordId: UUID) extends ConnectionServiceError
final case class ThreadIdNotFound(thid: UUID) extends ConnectionServiceError
final case class ThreadIdNotFound(thid: String) extends ConnectionServiceError
final case class InvitationParsingError(cause: Throwable) extends ConnectionServiceError
final case class UnexpectedError(msg: String) extends ConnectionServiceError
final case class InvalidFlowStateError(msg: String) extends ConnectionServiceError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ trait ConnectionRepository[F[_]] {

def deleteConnectionRecord(recordId: UUID): F[Int]

def getConnectionRecordByThreadId(thid: UUID): F[Option[ConnectionRecord]]
def getConnectionRecordByThreadId(thid: String): F[Option[ConnectionRecord]]

def updateWithConnectionRequest(
recordId: UUID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ class ConnectionRepositoryInMemory(storeRef: Ref[Map[UUID, ConnectionRecord]]) e
.getOrElse(ZIO.succeed(0))
} yield count

override def getConnectionRecordByThreadId(thid: UUID): Task[Option[ConnectionRecord]] = {
override def getConnectionRecordByThreadId(thid: String): Task[Option[ConnectionRecord]] = {
for {
store <- storeRef.get
} yield store.values.find(_.thid.contains(thid))
} yield store.values.find(_.thid.toString == thid)
}

override def getConnectionRecords: Task[Seq[ConnectionRecord]] = {
Expand All @@ -161,16 +161,13 @@ class ConnectionRepositoryInMemory(storeRef: Ref[Map[UUID, ConnectionRecord]]) e

override def createConnectionRecord(record: ConnectionRecord): Task[Int] = {
for {
_ <- record.thid match
case None => ZIO.unit
case Some(value) =>
for {
store <- storeRef.get
maybeRecord <- ZIO.succeed(store.values.find(_.thid == record.thid))
_ <- maybeRecord match
case None => ZIO.unit
case Some(value) => ZIO.fail(UniqueConstraintViolation("Unique Constraint Violation on 'thid'"))
} yield ()
_ <- for {
store <- storeRef.get
maybeRecord <- ZIO.succeed(store.values.find(_.thid == record.thid))
_ <- maybeRecord match
case None => ZIO.unit
case Some(value) => ZIO.fail(UniqueConstraintViolation("Unique Constraint Violation on 'thid'"))
} yield ()
_ <- storeRef.update(r => r + (record.id -> record))
} yield 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ trait ConnectionService {
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]]

/** Get the ConnectionRecord by the record id. If the record is id is not found the value None will be return */
def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def getConnectionRecordByThreadId(thid: String): IO[ConnectionServiceError, Option[ConnectionRecord]]

def deleteConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Int]

def reportProcessingFailure(recordId: UUID, failReason: Option[String]): IO[ConnectionServiceError, Unit]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private class ConnectionServiceImpl(
id = UUID.fromString(invitation.id),
createdAt = Instant.now,
updatedAt = None,
thid = Some(UUID.fromString(invitation.id)), // this is the default, can't with just use None?
thid = invitation.id,
label = label,
role = ConnectionRecord.Role.Inviter,
protocolState = ConnectionRecord.ProtocolState.InvitationGenerated,
Expand Down Expand Up @@ -80,6 +80,13 @@ private class ConnectionServiceImpl(
} yield record
}

override def getConnectionRecordByThreadId(thid: String): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- connectionRepository
.getConnectionRecordByThreadId(thid)
.mapError(RepositoryError.apply)
} yield record

override def deleteConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Int] = ???

override def receiveConnectionInvitation(invitation: String): IO[ConnectionServiceError, ConnectionRecord] =
Expand All @@ -88,7 +95,7 @@ private class ConnectionServiceImpl(
.fromEither(io.circe.parser.decode[Invitation](Base64Utils.decodeUrlToString(invitation)))
.mapError(err => InvitationParsingError(err))
_ <- connectionRepository
.getConnectionRecordByThreadId(UUID.fromString(invitation.id))
.getConnectionRecordByThreadId(invitation.id)
.mapError(RepositoryError.apply)
.flatMap {
case None => ZIO.unit
Expand All @@ -99,9 +106,8 @@ private class ConnectionServiceImpl(
id = UUID.randomUUID(),
createdAt = Instant.now,
updatedAt = None,
thid = Some(
UUID.fromString(invitation.id)
), // TODO: According to the standard, we should rather use 'pthid' and not 'thid'
// TODO: According to the standard, we should rather use 'pthid' and not 'thid'
thid = invitation.id,
label = None,
role = ConnectionRecord.Role.Invitee,
protocolState = ConnectionRecord.ProtocolState.InvitationReceived,
Expand Down Expand Up @@ -285,7 +291,6 @@ private class ConnectionServiceImpl(
thid <- ZIO
.fromOption(thid)
.mapError(_ => UnexpectedError("No `thid` found in credential request"))
.map(UUID.fromString)
maybeRecord <- connectionRepository
.getConnectionRecordByThreadId(thid)
.mapError(RepositoryError.apply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object ConnectionRepositorySpecSuite {
UUID.randomUUID,
Instant.ofEpochSecond(Instant.now.getEpochSecond),
None,
None,
UUID.randomUUID().toString,
None,
ConnectionRecord.Role.Inviter,
ConnectionRecord.ProtocolState.InvitationGenerated,
Expand Down Expand Up @@ -56,9 +56,9 @@ object ConnectionRepositorySpecSuite {
test("createConnectionRecord prevents creation of 2 records with the same thid") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
thid = UUID.randomUUID()
aRecord = connectionRecord.copy(thid = Some(thid))
bRecord = connectionRecord.copy(thid = Some(thid))
thid = UUID.randomUUID().toString
aRecord = connectionRecord.copy(thid = thid)
bRecord = connectionRecord.copy(thid = thid)
aCount <- repo.createConnectionRecord(aRecord)
bCount <- repo.createConnectionRecord(bRecord).exit
} yield {
Expand Down Expand Up @@ -208,8 +208,8 @@ object ConnectionRepositorySpecSuite {
test("getConnectionRecordByThreadId correctly returns an existing thid") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
thid = UUID.randomUUID()
aRecord = connectionRecord.copy(thid = Some(thid))
thid = UUID.randomUUID().toString
aRecord = connectionRecord.copy(thid = thid)
bRecord = connectionRecord
_ <- repo.createConnectionRecord(aRecord)
_ <- repo.createConnectionRecord(bRecord)
Expand All @@ -219,11 +219,11 @@ object ConnectionRepositorySpecSuite {
test("getConnectionRecordByThreadId returns nothing for an unknown thid") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
aRecord = connectionRecord.copy(thid = Some(UUID.randomUUID()))
bRecord = connectionRecord.copy(thid = Some(UUID.randomUUID()))
aRecord = connectionRecord.copy(thid = UUID.randomUUID().toString)
bRecord = connectionRecord.copy(thid = UUID.randomUUID().toString)
_ <- repo.createConnectionRecord(aRecord)
_ <- repo.createConnectionRecord(bRecord)
record <- repo.getConnectionRecordByThreadId(UUID.randomUUID())
record <- repo.getConnectionRecordByThreadId(UUID.randomUUID().toString)
} yield assertTrue(record.isEmpty)
},
test("updateConnectionProtocolState updates the record") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package io.iohk.atala.connect.core.service

import io.iohk.atala.connect.core.model.ConnectionRecord._
import io.iohk.atala.connect.core.repository.ConnectionRepositoryInMemory

import zio._
import zio.test._
import zio.test.Assertion._
import io.iohk.atala.mercury.model.DidId
import io.circe.syntax.*
import io.iohk.atala.connect.core.model.ConnectionRecord
import java.util.UUID
import io.iohk.atala.connect.core.model.ConnectionRecord.*
import io.iohk.atala.connect.core.model.error.ConnectionServiceError
import java.time.Instant
import io.circe.syntax._
import io.iohk.atala.mercury.model.Message
import io.iohk.atala.connect.core.repository.ConnectionRepositoryInMemory
import io.iohk.atala.mercury.model.{DidId, Message}
import io.iohk.atala.mercury.protocol.connection.ConnectionResponse
import zio.*
import zio.test.*
import zio.test.Assertion.*

import java.time.Instant
import java.util.UUID

object ConnectionServiceImplSpec extends ZIOSpecDefault {

Expand All @@ -32,7 +31,7 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
assertTrue(record.role == Role.Inviter) &&
assertTrue(record.connectionRequest.isEmpty) &&
assertTrue(record.connectionResponse.isEmpty) &&
assertTrue(record.thid.contains(record.id)) &&
assertTrue(record.thid == record.id.toString) &&
assertTrue(record.updatedAt.isEmpty) &&
assertTrue(record.invitation.from == did) &&
assertTrue(record.invitation.attachments.isEmpty) &&
Expand Down Expand Up @@ -113,7 +112,7 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
assertTrue(inviteeRecord.role == Role.Invitee) &&
assertTrue(inviteeRecord.connectionRequest.isEmpty) &&
assertTrue(inviteeRecord.connectionResponse.isEmpty) &&
assertTrue(inviteeRecord.thid.contains(UUID.fromString(inviterRecord.invitation.id))) &&
assertTrue(inviteeRecord.thid == inviterRecord.invitation.id) &&
assertTrue(inviteeRecord.updatedAt.isEmpty) &&
assertTrue(inviteeRecord.invitation == inviterRecord.invitation)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
package io.iohk.atala.connect.sql.repository

import cats.data.NonEmptyList
import doobie.*
import doobie.implicits.*
import doobie.postgres.implicits._
import io.circe._
import io.circe.parser._
import io.circe.syntax._
import io.iohk.atala.connect.core.model.ConnectionRecord
import io.iohk.atala.connect.core.model.ConnectionRecord.ProtocolState
import io.iohk.atala.connect.core.model.ConnectionRecord.Role
import doobie.postgres.implicits.*
import io.circe.*
import io.circe.parser.*
import io.circe.syntax.*
import io.iohk.atala.connect.core.model.*
import io.iohk.atala.connect.core.model.error.ConnectionRepositoryError._
import io.iohk.atala.connect.core.model.ConnectionRecord.{ProtocolState, Role}
import io.iohk.atala.connect.core.model.error.ConnectionRepositoryError.*
import io.iohk.atala.connect.core.repository.ConnectionRepository
import io.iohk.atala.mercury.protocol.connection.*
import io.iohk.atala.mercury.protocol.invitation.v2.Invitation
Expand All @@ -20,7 +19,6 @@ import zio.interop.catz.*

import java.time.Instant
import java.util.UUID
import cats.data.NonEmptyList

class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepository[Task] {

Expand Down Expand Up @@ -184,7 +182,7 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
.transact(xa)
}

override def getConnectionRecordByThreadId(thid: UUID): Task[Option[ConnectionRecord]] = {
override def getConnectionRecordByThreadId(thid: String): Task[Option[ConnectionRecord]] = {
val cxnIO = sql"""
| SELECT
| id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ import java.util.UUID

trait CredentialService {

/** Copy pasted from Castor codebase for now TODO: replace with actual data from castor later
*
* @param method
* @param methodSpecificId
*/
final case class DID(
method: String,
methodSpecificId: String
Expand Down Expand Up @@ -63,11 +58,6 @@ trait CredentialService {
issuingDID: Option[CanonicalPrismDID]
): IO[CredentialServiceError, IssueCredentialRecord]

/** Return the full list of CredentialRecords.
*
* TODO this function API maybe change in the future to return a lazy sequence of records or something similar to a
* batabase cursor.
*/
def getIssueCredentialRecords: IO[CredentialServiceError, Seq[IssueCredentialRecord]]

def getIssueCredentialRecordsByStates(
Expand All @@ -76,11 +66,10 @@ trait CredentialService {
states: IssueCredentialRecord.ProtocolState*
): IO[CredentialServiceError, Seq[IssueCredentialRecord]]

/** Get the CredentialRecord by the record's id. If the record's id is not found the value None will be return
* instead.
*/
def getIssueCredentialRecord(recordId: DidCommID): IO[CredentialServiceError, Option[IssueCredentialRecord]]

def getIssueCredentialRecordByThreadId(thid: DidCommID): IO[CredentialServiceError, Option[IssueCredentialRecord]]

def receiveCredentialOffer(offer: OfferCredential): IO[CredentialServiceError, IssueCredentialRecord]

def acceptCredentialOffer(recordId: DidCommID, subjectId: String): IO[CredentialServiceError, IssueCredentialRecord]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ private class CredentialServiceImpl(
} yield records
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID
): IO[CredentialServiceError, Option[IssueCredentialRecord]] =
for {
record <- credentialRepository
.getIssueCredentialRecordByThreadId(thid)
.mapError(RepositoryError.apply)
} yield record

override def getIssueCredentialRecord(
recordId: DidCommID
): IO[CredentialServiceError, Option[IssueCredentialRecord]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ trait PresentationService {

def getPresentationRecord(recordId: DidCommID): IO[PresentationError, Option[PresentationRecord]]

def getPresentationRecordByThreadId(thid: DidCommID): IO[PresentationError, Option[PresentationRecord]]

def receiveRequestPresentation(
connectionId: Option[String],
request: RequestPresentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ private class PresentationServiceImpl(
} yield record
}

override def getPresentationRecordByThreadId(thid: DidCommID): IO[PresentationError, Option[PresentationRecord]] =
for {
record <- presentationRepository
.getPresentationRecordByThreadId(thid)
.mapError(RepositoryError.apply)
} yield record

override def rejectRequestPresentation(recordId: DidCommID): IO[PresentationError, Option[PresentationRecord]] = {
markRequestPresentationRejected(recordId)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.iohk.atala.connect.controller

import io.iohk.atala.api.http.model.Pagination
import io.iohk.atala.api.http.model.PaginationInput
import io.iohk.atala.api.http.{ErrorResponse, RequestContext}
import io.iohk.atala.connect.controller.http.{
AcceptConnectionInvitationRequest,
Expand All @@ -22,7 +22,7 @@ trait ConnectionController {
rc: RequestContext
): IO[ErrorResponse, Connection]

def getConnections(pagination: Pagination)(implicit
def getConnections(paginationInput: PaginationInput, thid: Option[String])(implicit
rc: RequestContext
): IO[ErrorResponse, ConnectionsPage]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.iohk.atala.connect.controller

import io.iohk.atala.agent.server.config.AppConfig
import io.iohk.atala.agent.walletapi.service.ManagedDIDService
import io.iohk.atala.api.http.model.Pagination
import io.iohk.atala.api.http.model.PaginationInput
import io.iohk.atala.api.http.{ErrorResponse, RequestContext}
import io.iohk.atala.connect.controller.ConnectionController.toHttpError
import io.iohk.atala.connect.controller.http.{
Expand Down Expand Up @@ -48,10 +48,13 @@ class ConnectionControllerImpl(
}

override def getConnections(
pagination: Pagination
paginationInput: PaginationInput,
thid: Option[String]
)(implicit rc: RequestContext): IO[ErrorResponse, ConnectionsPage] = {
val result = for {
connections <- service.getConnectionRecords()
connections <- thid match
case None => service.getConnectionRecords()
case Some(thid) => service.getConnectionRecordByThreadId(thid).map(_.toSeq)
} yield ConnectionsPage(contents = connections.map(Connection.fromDomain))

result.mapError(toHttpError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ object ConnectionEndpoints {
.description("Gets an existing connection record by its unique identifier")
.tag("Connections Management")

val getConnections: PublicEndpoint[(RequestContext, PaginationInput), ErrorResponse, ConnectionsPage, Any] =
val getConnections
: PublicEndpoint[(RequestContext, PaginationInput, Option[String]), ErrorResponse, ConnectionsPage, Any] =
endpoint.get
.in(extractFromRequest[RequestContext](RequestContext.apply))
.in("connections")
.in(paginationInput)
.in(query[Option[String]]("thid").description("The thid of a DIDComm communication."))
.out(jsonBody[ConnectionsPage].description("The list of connection records."))
.errorOut(basicFailures)
.name("getConnections")
Expand Down

0 comments on commit 9a97c7a

Please sign in to comment.