diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index db8b6efd..40ca1132 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -27,11 +27,19 @@ import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.update import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.PersistentSet import kotlinx.collections.immutable.persistentMapOf +import kotlinx.collections.immutable.persistentSetOf import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.cancelChildren +import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.yield import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.ClassDiscriminatorMode import kotlinx.serialization.json.Json @@ -68,23 +76,55 @@ public val McpJson: Json by lazy { /** * Additional initialization options. + * + * @property enforceStrictCapabilities whether to restrict emitted requests to only those that the remote side has indicated + * that they can handle, through their advertised capabilities. + * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is + * considered a logic error to mis-specify those. + * + * Currently, this defaults to false, for backwards compatibility with SDK versions + * that did not advertise capabilities correctly. + * In the future, this will default to true. + * + * @property debouncedNotificationMethods an array of notification method names that should be automatically debounced. + * Any notifications with a method in this list will be coalesced if they occur in the same tick of the event loop. + * e.g., ['notifications/tools/list_changed'] */ public open class ProtocolOptions( - /** - * Whether to restrict emitted requests to only those that the remote side has indicated - * that they can handle, through their advertised capabilities. - * - * Note that this DOES NOT affect checking of _local_ side capabilities, as it is - * considered a logic error to mis-specify those. - * - * Currently, this defaults to false, for backwards compatibility with SDK versions - * that did not advertise capabilities correctly. - * In the future, this will default to true. - */ public var enforceStrictCapabilities: Boolean = false, + public val debouncedNotificationMethods: List = emptyList(), +) { + public operator fun component1(): Boolean = enforceStrictCapabilities + public operator fun component2(): List = debouncedNotificationMethods + + public open fun copy( + enforceStrictCapabilities: Boolean = this.enforceStrictCapabilities, + debouncedNotificationMethods: List = this.debouncedNotificationMethods, + ): ProtocolOptions = ProtocolOptions(enforceStrictCapabilities, debouncedNotificationMethods) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as ProtocolOptions + + return when { + enforceStrictCapabilities != other.enforceStrictCapabilities -> false + debouncedNotificationMethods != other.debouncedNotificationMethods -> false + else -> true + } + } - public var timeout: Duration = DEFAULT_REQUEST_TIMEOUT, -) + override fun hashCode(): Int { + var result = enforceStrictCapabilities.hashCode() + result = 31 * result + debouncedNotificationMethods.hashCode() + return result + } + + override fun toString(): String = + "ProtocolOptions(enforceStrictCapabilities=$enforceStrictCapabilities, debouncedNotificationMethods=$debouncedNotificationMethods)" +} /** * The default request timeout. @@ -153,6 +193,11 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio public val progressHandlers: Map get() = _progressHandlers.value + @Suppress("ktlint:standard:backing-property-naming") + private val _pendingDebouncedNotifications: AtomicRef> = atomic(persistentSetOf()) + private val notificationScopeJob = SupervisorJob() + private val notificationScope = CoroutineScope(notificationScopeJob + Dispatchers.Default) + /** * Callback for when the connection is closed for any reason. * @@ -224,6 +269,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val handlersToNotify = _responseHandlers.value.values.toList() _responseHandlers.getAndSet(persistentMapOf()) _progressHandlers.getAndSet(persistentMapOf()) + _pendingDebouncedNotifications.update { it.clear() } + notificationScopeJob.cancelChildren() transport = null onClose() @@ -489,13 +536,45 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio /** * Emits a notification, which is a one-way message that does not expect a response. */ - public suspend fun notification(notification: Notification) { + public suspend fun notification(notification: Notification, relatedRequestId: RequestId? = null) { logger.trace { "Sending notification: ${notification.method}" } val transport = this.transport ?: error("Not connected") assertNotificationCapability(notification.method) + val jsonRpcNotification = notification.toJSON() + + val isDebounced = + options?.debouncedNotificationMethods?.contains(notification.method) == true && + notification.params == null && + relatedRequestId == null + + if (isDebounced) { + if (notification.method in _pendingDebouncedNotifications.value) { + logger.trace { "Skipping debounced notification: ${notification.method}" } + return + } + + _pendingDebouncedNotifications.update { it.add(notification.method) } + + notificationScope.launch { + try { + yield() + } finally { + _pendingDebouncedNotifications.update { it.remove(notification.method) } + } + + val activeTransport = this@Protocol.transport ?: return@launch + + try { + activeTransport.send(jsonRpcNotification) + } catch (cause: Throwable) { + logger.error(cause) { "Error sending debounced notification: ${notification.method}" } + onError(cause) + } + } + return + } - val message = notification.toJSON() - transport.send(message) + transport.send(jsonRpcNotification) } /**