Skip to content

Commit

Permalink
feat: support for mediator live mode (websocket) (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianIOHK committed Apr 17, 2024
1 parent a46db03 commit 2710997
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 35 deletions.
5 changes: 5 additions & 0 deletions atala-prism-sdk/build.gradle.kts
Expand Up @@ -106,6 +106,7 @@ kotlin {
implementation("io.ktor:ktor-client-content-negotiation:2.3.4")
implementation("io.ktor:ktor-serialization-kotlinx-json:2.3.4")
implementation("io.ktor:ktor-client-logging:2.3.4")
implementation("io.ktor:ktor-websockets:2.3.4")

implementation("io.iohk.atala.prism.didcomm:didpeer:$didpeerVersion")

Expand Down Expand Up @@ -135,12 +136,15 @@ kotlin {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.8.0")
implementation("io.ktor:ktor-client-mock:2.3.4")
implementation("junit:junit:4.13.2")
implementation("org.mockito:mockito-core:4.4.0")
implementation("org.mockito.kotlin:mockito-kotlin:4.0.0")
}
}
val jvmMain by getting {
dependencies {
implementation("io.ktor:ktor-client-okhttp:2.3.4")
implementation("app.cash.sqldelight:sqlite-driver:2.0.1")
implementation("io.ktor:ktor-client-java:2.3.4")
}
}
val jvmTest by getting
Expand All @@ -149,6 +153,7 @@ kotlin {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.8.0")
implementation("io.ktor:ktor-client-okhttp:2.3.4")
implementation("app.cash.sqldelight:android-driver:2.0.1")
implementation("io.ktor:ktor-client-android:2.3.4")
}
}
val androidInstrumentedTest by getting {
Expand Down
@@ -1,3 +1,5 @@
@file:Suppress("ktlint:standard:import-ordering")

package io.iohk.atala.prism.walletsdk.prismagent

import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Castor
Expand All @@ -9,8 +11,14 @@ import io.iohk.atala.prism.walletsdk.domain.models.Message
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.ConnectionsManager
import io.iohk.atala.prism.walletsdk.prismagent.connectionsmanager.DIDCommConnection
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
import java.time.Duration
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlin.jvm.Throws

/**
Expand All @@ -27,9 +35,99 @@ class ConnectionManager(
private val castor: Castor,
private val pluto: Pluto,
internal val mediationHandler: MediationHandler,
private var pairings: MutableList<DIDPair>
private var pairings: MutableList<DIDPair>,
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
) : ConnectionsManager, DIDCommConnection {

var fetchingMessagesJob: Job? = null

/**
* Starts the process of fetching messages at a regular interval.
*
* @param requestInterval The time interval (in seconds) between message fetch requests.
* Defaults to 5 seconds if not specified.
*/
@JvmOverloads
fun startFetchingMessages(requestInterval: Int = 5) {
// Check if the job for fetching messages is already running
if (fetchingMessagesJob == null) {
// Launch a coroutine in the provided scope
fetchingMessagesJob = scope.launch {
// Retrieve the current mediator DID
val currentMediatorDID = mediationHandler.mediatorDID
// Resolve the DID document for the mediator
val mediatorDidDoc = castor.resolveDID(currentMediatorDID.toString())
var serviceEndpoint: String? = null

// Loop through the services in the DID document to find a WebSocket endpoint
mediatorDidDoc.services.forEach {
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
serviceEndpoint = it.serviceEndpoint.uri
return@forEach // Exit loop once the WebSocket endpoint is found
}
}

// If a WebSocket service endpoint is found
serviceEndpoint?.let { serviceEndpointUrl ->
// Listen for unread messages on the WebSocket endpoint
mediationHandler.listenUnreadMessages(
serviceEndpointUrl
) { arrayMessages ->
// Process the received messages
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
arrayMessages.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}
// If there are any messages, mark them as read and store them
scope.launch {
if (messagesIds.isNotEmpty()) {
mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
}
}
}

// Fallback mechanism if no WebSocket service endpoint is available
if (serviceEndpoint == null) {
while (true) {
// Continuously await and process new messages
awaitMessages().collect { array ->
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
array.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}
if (messagesIds.isNotEmpty()) {
mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
}
// Wait for the specified request interval before fetching new messages
delay(Duration.ofSeconds(requestInterval.toLong()).toMillis())
}
}
}

// Start the coroutine if it's not already active
fetchingMessagesJob?.let {
if (it.isActive) return
it.start()
}
}
}

fun stopConnection() {
fetchingMessagesJob?.cancel()
}

/**
* Suspends the current coroutine and boots the registered mediator associated with the mediator handler.
* If no mediator is available, a [PrismAgentError.NoMediatorAvailableError] is thrown.
Expand Down
Expand Up @@ -67,11 +67,8 @@ import io.ktor.http.HttpMethod
import io.ktor.http.Url
import io.ktor.serialization.kotlinx.json.json
import java.net.UnknownHostException
import java.time.Duration
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
Expand Down Expand Up @@ -116,7 +113,6 @@ class PrismAgent {
val pluto: Pluto
val mercury: Mercury
val pollux: Pollux
var fetchingMessagesJob: Job? = null
val flowState = MutableSharedFlow<State>()

private val prismAgentScope: CoroutineScope = CoroutineScope(Dispatchers.Default)
Expand Down Expand Up @@ -298,7 +294,6 @@ class PrismAgent {
}
logger.info(message = "Stoping agent")
state = State.STOPPING
fetchingMessagesJob?.cancel()
state = State.STOPPED
logger.info(message = "Agent not running")
}
Expand Down Expand Up @@ -724,40 +719,15 @@ class PrismAgent {
*/
@JvmOverloads
fun startFetchingMessages(requestInterval: Int = 5) {
if (fetchingMessagesJob == null) {
logger.info(message = "Start streaming new unread messages")
fetchingMessagesJob = prismAgentScope.launch {
while (true) {
connectionManager.awaitMessages().collect { array ->
val messagesIds = mutableListOf<String>()
val messages = mutableListOf<Message>()
array.map { pair ->
messagesIds.add(pair.first)
messages.add(pair.second)
}
if (messagesIds.isNotEmpty()) {
connectionManager.mediationHandler.registerMessagesAsRead(
messagesIds.toTypedArray()
)
pluto.storeMessages(messages)
}
}
delay(Duration.ofSeconds(requestInterval.toLong()).toMillis())
}
}
}
fetchingMessagesJob?.let {
if (it.isActive) return
it.start()
}
connectionManager.startFetchingMessages(requestInterval)
}

/**
* Stop fetching messages
*/
fun stopFetchingMessages() {
logger.info(message = "Stop streaming new unread messages")
fetchingMessagesJob?.cancel()
connectionManager.stopConnection()
}

/**
Expand Down
@@ -1,3 +1,5 @@
@file:Suppress("ktlint:standard:import-ordering")

package io.iohk.atala.prism.walletsdk.prismagent.mediation

import io.iohk.atala.prism.walletsdk.domain.buildingblocks.Mercury
Expand All @@ -7,16 +9,24 @@ import io.iohk.atala.prism.walletsdk.domain.models.Mediator
import io.iohk.atala.prism.walletsdk.domain.models.Message
import io.iohk.atala.prism.walletsdk.domain.models.UnknownError
import io.iohk.atala.prism.walletsdk.prismagent.PrismAgentError
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationGrant
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationKeysUpdateList
import io.iohk.atala.prism.walletsdk.prismagent.protocols.mediation.MediationRequest
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupReceived
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupRequest
import io.iohk.atala.prism.walletsdk.prismagent.protocols.pickup.PickupRunner
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.websocket.WebSockets
import io.ktor.client.plugins.websocket.webSocket
import io.ktor.websocket.Frame
import io.ktor.websocket.readText
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import java.util.UUID
import kotlinx.coroutines.isActive

/**
* A class that provides an implementation of [MediationHandler] using a Pluto instance and a Mercury instance. It can
Expand Down Expand Up @@ -84,7 +94,8 @@ class BasicMediatorHandler(
val registeredMediator = bootRegisteredMediator()
if (registeredMediator == null) {
try {
val requestMessage = MediationRequest(from = host, to = mediatorDID).makeMessage()
val requestMessage =
MediationRequest(from = host, to = mediatorDID).makeMessage()
val message = mercury.sendMessageParseResponse(message = requestMessage)
?: throw UnknownError.SomethingWentWrongError(
message = "BasicMediatorHandler => mercury.sendMessageParseResponse returned null"
Expand Down Expand Up @@ -167,4 +178,75 @@ class BasicMediatorHandler(
} ?: throw PrismAgentError.NoMediatorAvailableError()
mercury.sendMessage(requestMessage)
}

/**
* Listens for unread messages from a specified WebSocket service endpoint.
*
* This function creates a WebSocket connection to the provided service endpoint URI
* and listens for incoming messages. Upon receiving messages, it processes and
* dispatches them to the specified callback function.
*
* @param serviceEndpointUri The URI of the service endpoint. It should be a valid WebSocket URI.
* @param onMessageCallback A callback function that is invoked when a message is received.
* This function is responsible for handling the incoming message.
*/
override suspend fun listenUnreadMessages(
serviceEndpointUri: String,
onMessageCallback: OnMessageCallback
) {
val client = HttpClient {
install(WebSockets)
install(HttpTimeout) {
requestTimeoutMillis = WEBSOCKET_TIMEOUT
connectTimeoutMillis = WEBSOCKET_TIMEOUT
socketTimeoutMillis = WEBSOCKET_TIMEOUT
}
}
if (serviceEndpointUri.contains("wss://") || serviceEndpointUri.contains("ws://")) {
client.webSocket(serviceEndpointUri) {
if (isActive) {
val liveDeliveryMessage = Message(
body = "{\"live_delivery\":true}",
piuri = ProtocolType.LiveDeliveryChange.value,
id = UUID.randomUUID().toString(),
from = mediator?.hostDID,
to = mediatorDID
)
val packedMessage = mercury.packMessage(liveDeliveryMessage)
send(Frame.Text(packedMessage))
}
while (isActive) {
try {
for (frame in incoming) {
if (frame is Frame.Text) {
val messages =
handleReceivedMessagesFromSockets(frame.readText())
onMessageCallback.onMessage(messages)
}
}
} catch (e: Exception) {
e.printStackTrace()
continue
}
}
}
}
}

private suspend fun handleReceivedMessagesFromSockets(text: String): Array<Pair<String, Message>> {
val decryptedMessage = mercury.unpackMessage(text)
if (decryptedMessage.piuri == ProtocolType.PickupStatus.value ||
decryptedMessage.piuri == ProtocolType.PickupDelivery.value
) {
return PickupRunner(decryptedMessage, mercury).run()
} else {
return emptyArray()
}
}
}

fun interface OnMessageCallback {
fun onMessage(messages: Array<Pair<String, Message>>)
}

const val WEBSOCKET_TIMEOUT: Long = 15_000
Expand Up @@ -57,4 +57,16 @@ interface MediationHandler {
* @param ids An array of message IDs to register as read.
*/
suspend fun registerMessagesAsRead(ids: Array<String>)

/**
* Listens for unread messages from a specified WebSocket service endpoint.
*
* @param serviceEndpointUri The URI of the service endpoint. It should be a valid WebSocket URI.
* @param onMessageCallback A callback function that is invoked when a message is received.
* This function is responsible for handling the incoming message.
*/
suspend fun listenUnreadMessages(
serviceEndpointUri: String,
onMessageCallback: OnMessageCallback
)
}
Expand Up @@ -34,6 +34,7 @@ enum class ProtocolType(val value: String) {
PickupDelivery("https://didcomm.org/messagepickup/3.0/delivery"),
PickupStatus("https://didcomm.org/messagepickup/3.0/status"),
PickupReceived("https://didcomm.org/messagepickup/3.0/messages-received"),
LiveDeliveryChange("https://didcomm.org/messagepickup/3.0/live-delivery-change"),
None("");

companion object {
Expand Down

0 comments on commit 2710997

Please sign in to comment.