Skip to content

Commit

Permalink
feat: revocation credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianIOHK committed Apr 17, 2024
1 parent 2710997 commit 9971eba
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 33 deletions.
Expand Up @@ -50,7 +50,7 @@ class AnoncredsTests {
polluxMock = PolluxMock()
mediationHandlerMock = MediationHandlerMock()
// Pairing will be removed in the future
connectionManager = ConnectionManager(mercuryMock, castorMock, plutoMock, mediationHandlerMock, mutableListOf())
connectionManager = ConnectionManager(mercuryMock, castorMock, plutoMock, mediationHandlerMock, mutableListOf(), polluxMock)
json = Json {
ignoreUnknownKeys = true
prettyPrint = true
Expand Down
Expand Up @@ -327,4 +327,6 @@ interface Pluto {
* or null if no metadata is found.
*/
fun getCredentialMetadata(linkSecretName: String): Flow<CredentialRequestMeta?>

fun revokeCredential(credentialId: String)
}
Expand Up @@ -1020,4 +1020,8 @@ class PlutoImpl(private val connection: DbConnection) : Pluto {
)
}
}

override fun revokeCredential(credentialId: String) {
getInstance().storableCredentialQueries.revokeCredentialById(credentialId)
}
}
Expand Up @@ -2,15 +2,23 @@

package io.iohk.atala.prism.walletsdk.prismagent

import io.iohk.atala.prism.apollo.base64.base64UrlDecoded
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Castor
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Mercury
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Pluto
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Pollux
import io.iohk.atala.prism.walletsdk.domain.models.AttachmentBase64
import io.iohk.atala.prism.walletsdk.domain.models.CredentialType
import io.iohk.atala.prism.walletsdk.domain.models.DID
import io.iohk.atala.prism.walletsdk.domain.models.DIDPair
import io.iohk.atala.prism.walletsdk.domain.models.Message
import io.iohk.atala.prism.walletsdk.pollux.models.JWTCredential
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.ConnectionsManager
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.DIDCommConnection
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
import io.iohk.atala.prism.walletsdk.prismagent.protocols.issueCredential.IssueCredential
import io.iohk.atala.prism.walletsdk.prismagent.protocols.revocation.RevocationNotification
import java.time.Duration
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
Expand All @@ -36,6 +44,7 @@ class ConnectionManager(
private val pluto: Pluto,
internal val mediationHandler: MediationHandler,
private var pairings: MutableList<DIDPair>,
private val pollux: Pollux,
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
) : ConnectionsManager, DIDCommConnection {

Expand Down Expand Up @@ -73,22 +82,23 @@ class ConnectionManager(
mediationHandler.listenUnreadMessages(
serviceEndpointUrl
) { arrayMessages ->
// Process the received messages
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
arrayMessages.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}
// If there are any messages, mark them as read and store them
scope.launch {
if (messagesIds.isNotEmpty()) {
mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
}
processMessages(arrayMessages)
// // Process the received messages
// val messagesIds = mutableListOf<String>()
// val messages = mutableListOf<Message>()
// arrayMessages.map { pair ->
// messagesIds.add(pair.first)
// messages.add(pair.second)
// }
// // If there are any messages, mark them as read and store them
// scope.launch {
// if (messagesIds.isNotEmpty()) {
// mediationHandler.registerMessagesAsRead(
// messagesIds.toTypedArray()
// )
// pluto.storeMessages(messages)
// }
// }
}
}

Expand All @@ -97,18 +107,19 @@ class ConnectionManager(
while (true) {
// Continuously await and process new messages
awaitMessages().collect { array ->
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
array.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}
if (messagesIds.isNotEmpty()) {
mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
processMessages(array)
// val messagesIds = mutableListOf<String>()
// val messages = mutableListOf<Message>()
// array.map { pair ->
// messagesIds.add(pair.first)
// messages.add(pair.second)
// }
// if (messagesIds.isNotEmpty()) {
// mediationHandler.registerMessagesAsRead(
// messagesIds.toTypedArray()
// )
// pluto.storeMessages(messages)
// }
}
// Wait for the specified request interval before fetching new messages
delay(Duration.ofSeconds(requestInterval.toLong()).toMillis())
Expand Down Expand Up @@ -198,6 +209,47 @@ class ConnectionManager(
return null
}

internal fun processMessages(arrayMessages: Array<Pair<String, Message>>) {
scope.launch {
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
arrayMessages.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}

val allMessages = pluto.getAllMessages().first()

val revokedMessages = messages.filter { it.piuri == ProtocolType.PrismRevocation.value }
revokedMessages.forEach { msg ->
val revokedMessage = RevocationNotification.fromMessage(msg)
val threadId = revokedMessage.body.threadId
val matchingMessages =
allMessages.filter { it.piuri == ProtocolType.DidcommIssueCredential.value && it.thid == threadId }
if (matchingMessages.isNotEmpty()) {
matchingMessages.forEach { message ->
val issueMessage = IssueCredential.fromMessage(message)
if (pollux.extractCredentialFormatFromMessage(issueMessage.attachments) == CredentialType.JWT) {
val attachment = issueMessage.attachments.firstOrNull()?.data as? AttachmentBase64
attachment?.let {
val credentialId = it.base64.base64UrlDecoded
pluto.revokeCredential(credentialId)
}
}
}
}
}

// If there are any messages, mark them as read and store them
if (messagesIds.isNotEmpty()) {
mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
}
}

/**
* Awaits a response to a specified message ID from the connection.
*
Expand Down
Expand Up @@ -222,7 +222,7 @@ class PrismAgent {
this.logger = logger
// Pairing will be removed in the future
this.connectionManager =
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf())
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
}

init {
Expand Down Expand Up @@ -455,7 +455,7 @@ class PrismAgent {
fun setupMediatorHandler(mediatorHandler: MediationHandler) {
stop()
this.connectionManager =
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf())
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
}

/**
Expand Down
Expand Up @@ -35,6 +35,7 @@ enum class ProtocolType(val value: String) {
PickupStatus("https://didcomm.org/messagepickup/3.0/status"),
PickupReceived("https://didcomm.org/messagepickup/3.0/messages-received"),
LiveDeliveryChange("https://didcomm.org/messagepickup/3.0/live-delivery-change"),
PrismRevocation("https://atalaprism.io/revocation_notification/1.0/revoke"),
None("");

companion object {
Expand Down
@@ -0,0 +1,57 @@
package io.iohk.atala.prism.walletsdk.prismagent.protocols.revocation

import io.iohk.atala.prism.walletsdk.domain.models.DID
import io.iohk.atala.prism.walletsdk.domain.models.Message
import io.iohk.atala.prism.walletsdk.prismagent.PrismAgentError
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
import kotlinx.serialization.SerialName
import java.util.UUID
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

class RevocationNotification(
val id: String = UUID.randomUUID().toString(),
val body: Body,
val from: DID,
val to: DID
) {
val type = ProtocolType.PrismRevocation

fun makeMessage(): Message {
return Message(
id = id,
piuri = type.value,
from = from,
to = to,
body = Json.encodeToString(body)
)
}

@Serializable
data class Body @JvmOverloads constructor(
@SerialName("issueCredentialProtocolThreadId")
val threadId: String,
val comment: String?
)

companion object {
fun fromMessage(message: Message): RevocationNotification {
require(
message.piuri == ProtocolType.PrismRevocation.value &&
message.from != null &&
message.to != null
) {
throw PrismAgentError.InvalidMessageType(
type = message.piuri,
shouldBe = ProtocolType.PrismRevocation.value
)
}
return RevocationNotification(
body = Json.decodeFromString(message.body),
from = message.from,
to = message.to
)
}
}
}
Expand Up @@ -23,3 +23,8 @@ SELECT StorableCredential.*, AvailableClaims.claim AS claims
FROM StorableCredential
LEFT JOIN AvailableClaims ON StorableCredential.id = AvailableClaims.credentialId
GROUP BY StorableCredential.id;

revokeCredentialById:
UPDATE StorableCredential
SET revoked = 1
WHERE id = :id;
Expand Up @@ -180,4 +180,8 @@ class PlutoMock : Pluto {
override fun getCredentialMetadata(linkSecretName: String): Flow<CredentialRequestMeta?> {
TODO("Not yet implemented")
}

override fun revokeCredential(credentialId: String) {
TODO("Not yet implemented")
}
}
Expand Up @@ -2,15 +2,21 @@

package io.iohk.atala.prism.walletsdk.prismagent

import io.iohk.atala.prism.apollo.base64.base64UrlEncoded
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Castor
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Mercury
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Pluto
import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Pollux
import io.iohk.atala.prism.walletsdk.domain.models.AttachmentBase64
import io.iohk.atala.prism.walletsdk.domain.models.AttachmentDescriptor
import io.iohk.atala.prism.walletsdk.domain.models.CredentialType
import io.iohk.atala.prism.walletsdk.domain.models.Curve
import io.iohk.atala.prism.walletsdk.domain.models.DID
import io.iohk.atala.prism.walletsdk.domain.models.DIDDocument
import io.iohk.atala.prism.walletsdk.domain.models.DIDUrl
import io.iohk.atala.prism.walletsdk.domain.models.Message
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.TestCoroutineDispatcher
Expand All @@ -22,9 +28,16 @@ import org.mockito.Mockito.`when`
import org.mockito.MockitoAnnotations
import org.mockito.kotlin.any
import java.util.UUID
import kotlinx.coroutines.flow.Flow
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.anyList
import org.mockito.kotlin.anyArray
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.mock
import kotlin.test.assertNotNull
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals

class ConnectionManagerTest {

Expand All @@ -37,6 +50,9 @@ class ConnectionManagerTest {
@Mock
lateinit var plutoMock: Pluto

@Mock
lateinit var polluxMock: Pollux

@Mock
lateinit var basicMediatorHandlerMock: MediationHandler

Expand All @@ -53,6 +69,7 @@ class ConnectionManagerTest {
pluto = plutoMock,
mediationHandler = basicMediatorHandlerMock,
pairings = mutableListOf(),
pollux = polluxMock,
scope = CoroutineScope(testDispatcher)
)
}
Expand Down Expand Up @@ -168,4 +185,55 @@ class ConnectionManagerTest {
verify(basicMediatorHandlerMock).pickupUnreadMessages(10)
verify(basicMediatorHandlerMock).registerMessagesAsRead(arrayOf("1234"))
}

@Test
fun testConnectionManager_whenProcessMessageRevoke_thenAllCorrect() = runTest {
val threadId = UUID.randomUUID().toString()
val attachments: Array<AttachmentDescriptor> =
arrayOf(
AttachmentDescriptor(
mediaType = "application/json",
format = CredentialType.JWT.type,
data = AttachmentBase64(base64 = "asdfasdfasdfasdfasdfasdfasdfasdfasdf".base64UrlEncoded)
)
)
val listMessages = listOf(
Message(
piuri = ProtocolType.DidcommconnectionRequest.value,
body = ""
),
Message(
piuri = ProtocolType.DidcommIssueCredential.value,
thid = threadId,
from = DID("did:peer:asdf897a6sdf"),
to = DID("did:peer:f706sg678ha"),
attachments = attachments,
body = """{}"""
)
)
val messageList: Flow<List<Message>> = flow {
emit(listMessages)
}
`when`(plutoMock.getAllMessages()).thenReturn(messageList)
`when`(polluxMock.extractCredentialFormatFromMessage(any())).thenReturn(CredentialType.JWT)

val messages = arrayOf(
Pair(
threadId, Message(
piuri = ProtocolType.PrismRevocation.value,
from = DID("did:peer:0978aszdf7890asg"),
to = DID("did:peer:asdf9068asdf"),
body = """{"threadId":"$threadId","comment":null}"""
)
)
)


connectionManager.processMessages(messages)
val argumentCaptor = argumentCaptor<String>()
verify(plutoMock).revokeCredential(argumentCaptor.capture())
assertEquals("asdfasdfasdfasdfasdfasdfasdfasdfasdf", argumentCaptor.firstValue)
verify(basicMediatorHandlerMock).registerMessagesAsRead(anyArray())
verify(plutoMock).storeMessages(anyList())
}
}

0 comments on commit 9971eba

Please sign in to comment.