Skip to content

Commit

Permalink
feat(pollux): validate the current record state on each protocol acti…
Browse files Browse the repository at this point in the history
…on received (#193)

* chore(pollux): switch back to shared 0.1.0

* fix(pollux): ensure record has the right state on each action

* chore(pollux): bump 'shared' dep version to 0.2.0

* chore(connect): update 'update_at' attribute in connect record appropriately

* fix(connect): do not accept the same invitation twice + check protocol state is valid on each update

* fix(prism-agent): fail on DIDComm sender side when receiving non-success HTTP response status from peer

* fix(pollux): check DID string provided in 'subjectId' is valid and supported

* chore(connect): bump version to 0.3.0-SNAPSHOT

* chore(connect): bump mercury version to 0.7.0-SNAPSHOT

* chore(connect): bump mercury version to 0.8.0

* chore(pollux): bump mercury version to 0.8.0

* chore(pollux): simplify CredentialService error structure

* chore(pollux): rename IssueCredentialError to CredentialServiceError

* chore(connect): rename ConnectError to ConnectionServiceError

* chore(prism-agent): undo change in MercuryUtils to allow pollux/connect merge & release

* chore(pollux): undo changes in version.sbt (connect & pollux)
  • Loading branch information
bvoiturier committed Dec 2, 2022
1 parent 9d65380 commit 6fffde2
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package io.iohk.atala.connect.core.model.error

import java.util.UUID

sealed trait ConnectionError
sealed trait ConnectionServiceError

object ConnectionError {
final case class RepositoryError(cause: Throwable) extends ConnectionError
final case class RecordIdNotFound(recordId: UUID) extends ConnectionError
final case class ThreadIdNotFound(thid: UUID) extends ConnectionError
final case class InvitationParsingError(cause: Throwable) extends ConnectionError
final case class UnexpectedError(msg: String) extends ConnectionError
object ConnectionServiceError {
final case class RepositoryError(cause: Throwable) extends ConnectionServiceError
final case class RecordIdNotFound(recordId: UUID) extends ConnectionServiceError
final case class ThreadIdNotFound(thid: UUID) extends ConnectionServiceError
final case class InvitationParsingError(cause: Throwable) extends ConnectionServiceError
final case class UnexpectedError(msg: String) extends ConnectionServiceError
final case class InvalidFlowStateError(msg: String) extends ConnectionServiceError
final case class InvitationAlreadyReceived(msg: String) extends ConnectionServiceError
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.iohk.atala.connect.core.service

import io.iohk.atala.connect.core.model.ConnectionRecord
import io.iohk.atala.connect.core.model.error.ConnectionError
import io.iohk.atala.connect.core.model.error.ConnectionServiceError
import zio._
import java.util.UUID
import io.iohk.atala.mercury.model.DidId
Expand All @@ -11,26 +11,26 @@ import io.iohk.atala.mercury.protocol.connection.ConnectionResponse

trait ConnectionService {

def createConnectionInvitation(label: Option[String]): IO[ConnectionError, ConnectionRecord]
def createConnectionInvitation(label: Option[String]): IO[ConnectionServiceError, ConnectionRecord]

def receiveConnectionInvitation(invitation: String): IO[ConnectionError, ConnectionRecord]
def receiveConnectionInvitation(invitation: String): IO[ConnectionServiceError, ConnectionRecord]

def acceptConnectionInvitation(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]]
def acceptConnectionInvitation(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def markConnectionRequestSent(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]]
def markConnectionRequestSent(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionError, Option[ConnectionRecord]]
def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, Option[ConnectionRecord]]

def acceptConnectionRequest(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]]
def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def markConnectionResponseSent(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]]
def markConnectionResponseSent(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def receiveConnectionResponse(response: ConnectionResponse): IO[ConnectionError, Option[ConnectionRecord]]
def receiveConnectionResponse(response: ConnectionResponse): IO[ConnectionServiceError, Option[ConnectionRecord]]

def getConnectionRecords(): IO[ConnectionError, Seq[ConnectionRecord]]
def getConnectionRecords(): IO[ConnectionServiceError, Seq[ConnectionRecord]]

def getConnectionRecord(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]]
def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def deleteConnectionRecord(recordId: UUID): IO[ConnectionError, Int]
def deleteConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Int]

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package io.iohk.atala.connect.core.service
import io.iohk.atala.connect.core.repository.ConnectionRepository
import io.iohk.atala.mercury.DidComm
import zio._
import io.iohk.atala.connect.core.model.error.ConnectionError
import io.iohk.atala.connect.core.model.error.ConnectionError._
import io.iohk.atala.connect.core.model.error.ConnectionServiceError
import io.iohk.atala.connect.core.model.error.ConnectionServiceError._
import io.iohk.atala.connect.core.model.ConnectionRecord
import io.iohk.atala.connect.core.model.ConnectionRecord._
import io.iohk.atala.mercury.protocol.connection.ConnectionRequest
Expand All @@ -22,7 +22,7 @@ private class ConnectionServiceImpl(
didComm: DidComm
) extends ConnectionService {

override def createConnectionInvitation(label: Option[String]): IO[ConnectionError, ConnectionRecord] =
override def createConnectionInvitation(label: Option[String]): IO[ConnectionServiceError, ConnectionRecord] =
for {
recordId <- ZIO.succeed(UUID.randomUUID)
invitation <- ZIO.succeed(createDidCommInvitation(recordId, didComm.myDid))
Expand All @@ -49,29 +49,36 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

override def getConnectionRecords(): IO[ConnectionError, Seq[ConnectionRecord]] = {
override def getConnectionRecords(): IO[ConnectionServiceError, Seq[ConnectionRecord]] = {
for {
records <- connectionRepository
.getConnectionRecords()
.mapError(RepositoryError.apply)
} yield records
}

override def getConnectionRecord(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]] = {
override def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] = {
for {
record <- connectionRepository
.getConnectionRecord(recordId)
.mapError(RepositoryError.apply)
} yield record
}

override def deleteConnectionRecord(recordId: UUID): IO[ConnectionError, Int] = ???
override def deleteConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Int] = ???

override def receiveConnectionInvitation(invitation: String): IO[ConnectionError, ConnectionRecord] =
override def receiveConnectionInvitation(invitation: String): IO[ConnectionServiceError, ConnectionRecord] =
for {
invitation <- ZIO
.fromEither(io.circe.parser.decode[Invitation](Base64Utils.decodeUrlToString(invitation)))
.mapError(err => InvitationParsingError(err))
_ <- connectionRepository
.getConnectionRecordByThreadId(UUID.fromString(invitation.id))
.mapError(RepositoryError.apply)
.flatMap {
case None => ZIO.unit
case Some(_) => ZIO.fail(InvitationAlreadyReceived(invitation.id))
}
record <- ZIO.succeed(
ConnectionRecord(
id = UUID.randomUUID(),
Expand All @@ -97,14 +104,9 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

override def acceptConnectionInvitation(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]] =
override def acceptConnectionInvitation(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
maybeRecord <- connectionRepository
.getConnectionRecord(recordId)
.mapError(RepositoryError.apply)
record <- ZIO
.fromOption(maybeRecord)
.mapError(_ => RecordIdNotFound(recordId))
record <- getRecordWithState(recordId, ProtocolState.InvitationReceived)
request = createDidCommConnectionRequest(record)
count <- connectionRepository
.updateWithConnectionRequest(recordId, request, ProtocolState.ConnectionRequestPending)
Expand All @@ -117,16 +119,16 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

override def markConnectionRequestSent(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]] =
override def markConnectionRequestSent(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] =
updateConnectionProtocolState(
recordId,
ProtocolState.ConnectionRequestPending,
ProtocolState.ConnectionRequestSent
)

override def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionError, Option[ConnectionRecord]] =
override def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordFromThreadId(request.thid)
record <- getRecordFromThreadIdAndState(request.thid, ProtocolState.InvitationGenerated)
_ <- connectionRepository
.updateWithConnectionRequest(record.id, request, ProtocolState.ConnectionRequestReceived)
.flatMap {
Expand All @@ -139,14 +141,9 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

override def acceptConnectionRequest(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]] =
override def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
maybeRecord <- connectionRepository
.getConnectionRecord(recordId)
.mapError(RepositoryError.apply)
record <- ZIO
.fromOption(maybeRecord)
.mapError(_ => RecordIdNotFound(recordId))
record <- getRecordWithState(recordId, ProtocolState.ConnectionRequestReceived)
response = createDidCommConnectionResponse(record)
count <- connectionRepository
.updateWithConnectionResponse(recordId, response, ProtocolState.ConnectionResponsePending)
Expand All @@ -159,16 +156,16 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

override def markConnectionResponseSent(recordId: UUID): IO[ConnectionError, Option[ConnectionRecord]] =
override def markConnectionResponseSent(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] =
updateConnectionProtocolState(
recordId,
ProtocolState.ConnectionResponsePending,
ProtocolState.ConnectionResponseSent
)

override def receiveConnectionResponse(response: ConnectionResponse): IO[ConnectionError, Option[ConnectionRecord]] =
override def receiveConnectionResponse(response: ConnectionResponse): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordFromThreadId(response.thid)
record <- getRecordFromThreadIdAndState(response.thid, ProtocolState.ConnectionRequestSent)
_ <- connectionRepository
.updateWithConnectionResponse(record.id, response, ProtocolState.ConnectionResponseReceived)
.flatMap {
Expand All @@ -181,6 +178,24 @@ private class ConnectionServiceImpl(
.mapError(RepositoryError.apply)
} yield record

private[this] def getRecordWithState(
recordId: UUID,
state: ProtocolState
): IO[ConnectionServiceError, ConnectionRecord] = {
for {
maybeRecord <- connectionRepository
.getConnectionRecord(recordId)
.mapError(RepositoryError.apply)
record <- ZIO
.fromOption(maybeRecord)
.mapError(_ => RecordIdNotFound(recordId))
_ <- record.protocolState match {
case s if s == state => ZIO.unit
case state => ZIO.fail(InvalidFlowStateError(s"Invalid protocol state for operation: $state"))
}
} yield record
}

private[this] def createDidCommInvitation(thid: UUID, from: DidId): Invitation = {
Invitation(
id = thid.toString,
Expand All @@ -205,7 +220,7 @@ private class ConnectionServiceImpl(
recordId: UUID,
from: ProtocolState,
to: ProtocolState
): IO[ConnectionError, Option[ConnectionRecord]] = {
): IO[ConnectionServiceError, Option[ConnectionRecord]] = {
for {
_ <- connectionRepository
.updateConnectionProtocolState(recordId, from, to)
Expand All @@ -220,9 +235,10 @@ private class ConnectionServiceImpl(
} yield record
}

private[this] def getRecordFromThreadId(
thid: Option[String]
): IO[ConnectionError, ConnectionRecord] = {
private[this] def getRecordFromThreadIdAndState(
thid: Option[String],
state: ProtocolState
): IO[ConnectionServiceError, ConnectionRecord] = {
for {
thid <- ZIO
.fromOption(thid)
Expand All @@ -234,6 +250,10 @@ private class ConnectionServiceImpl(
record <- ZIO
.fromOption(maybeRecord)
.mapError(_ => ThreadIdNotFound(thid))
_ <- record.protocolState match {
case s if s == state => ZIO.unit
case state => ZIO.fail(InvalidFlowStateError(s"Invalid protocol state for operation: $state"))
}
} yield record
}

Expand Down
2 changes: 1 addition & 1 deletion connect/lib/project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object Dependencies {
val doobie = "1.0.0-RC2"
val zioCatsInterop = "3.3.0"
val iris = "0.1.0"
val mercury = "0.7.0"
val mercury = "0.8.0"
val flyway = "9.7.0"
val shared = "0.2.0"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
val cxnIO = sql"""
| UPDATE public.connection_records
| SET
| protocol_state = $to
| protocol_state = $to,
| updated_at = ${Instant.now}
| WHERE
| id = $id
| AND protocol_state = $from
Expand All @@ -175,7 +176,8 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
| UPDATE public.connection_records
| SET
| connection_request = $request,
| protocol_state = $state
| protocol_state = $state,
| updated_at = ${Instant.now}
| WHERE
| id = $recordId
""".stripMargin.update
Expand All @@ -193,7 +195,8 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
| UPDATE public.connection_records
| SET
| connection_response = $response,
| protocol_state = $state
| protocol_state = $state,
| updated_at = ${Instant.now}
| WHERE
| id = $recordId
""".stripMargin.update
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.iohk.atala.pollux.core.model.error

import java.util.UUID
import io.iohk.atala.pollux.vc.jwt.W3cCredentialPayload

sealed trait CredentialServiceError

object CredentialServiceError {
final case class RepositoryError(cause: Throwable) extends CredentialServiceError
final case class RecordIdNotFound(recordId: UUID) extends CredentialServiceError
final case class ThreadIdNotFound(thid: UUID) extends CredentialServiceError
final case class InvalidFlowStateError(msg: String) extends CredentialServiceError
final case class UnexpectedError(msg: String) extends CredentialServiceError
final case class UnsupportedDidFormat(did: String) extends CredentialServiceError
final case class CreateCredentialPayloadFromRecordError(cause: Throwable) extends CredentialServiceError
final case class CredentialIdNotDefined(credential: W3cCredentialPayload) extends CredentialServiceError
final case class IrisError(cause: Throwable) extends CredentialServiceError
}

This file was deleted.

This file was deleted.

This file was deleted.

0 comments on commit 6fffde2

Please sign in to comment.