Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,19 @@ import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.update
import kotlinx.collections.immutable.PersistentMap
import kotlinx.collections.immutable.PersistentSet
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.persistentSetOf
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.yield
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.ClassDiscriminatorMode
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -68,23 +76,55 @@ public val McpJson: Json by lazy {

/**
* Additional initialization options.
*
* @property enforceStrictCapabilities whether to restrict emitted requests to only those that the remote side has indicated
* that they can handle, through their advertised capabilities.
*
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is
* considered a logic error to mis-specify those.
*
* Currently, this defaults to false, for backwards compatibility with SDK versions
* that did not advertise capabilities correctly.
* In the future, this will default to true.
*
* @property debouncedNotificationMethods an array of notification method names that should be automatically debounced.
* Any notifications with a method in this list will be coalesced if they occur in the same tick of the event loop.
* e.g., ['notifications/tools/list_changed']
*/
public open class ProtocolOptions(
/**
* Whether to restrict emitted requests to only those that the remote side has indicated
* that they can handle, through their advertised capabilities.
*
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is
* considered a logic error to mis-specify those.
*
* Currently, this defaults to false, for backwards compatibility with SDK versions
* that did not advertise capabilities correctly.
* In the future, this will default to true.
*/
public var enforceStrictCapabilities: Boolean = false,
public val debouncedNotificationMethods: List<Method> = emptyList(),
) {
public operator fun component1(): Boolean = enforceStrictCapabilities
public operator fun component2(): List<Method> = debouncedNotificationMethods

public open fun copy(
enforceStrictCapabilities: Boolean = this.enforceStrictCapabilities,
debouncedNotificationMethods: List<Method> = this.debouncedNotificationMethods,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should implement debouncing (dropping some events) from the start/ever. The specification states that the implementation should support rate limiting, not debouncing.
I would remove the word debounced from the name, because it's very implementation-specific

): ProtocolOptions = ProtocolOptions(enforceStrictCapabilities, debouncedNotificationMethods)

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as ProtocolOptions

return when {
enforceStrictCapabilities != other.enforceStrictCapabilities -> false
debouncedNotificationMethods != other.debouncedNotificationMethods -> false
else -> true
}
}

public var timeout: Duration = DEFAULT_REQUEST_TIMEOUT,
)
override fun hashCode(): Int {
var result = enforceStrictCapabilities.hashCode()
result = 31 * result + debouncedNotificationMethods.hashCode()
return result
}

override fun toString(): String =
"ProtocolOptions(enforceStrictCapabilities=$enforceStrictCapabilities, debouncedNotificationMethods=$debouncedNotificationMethods)"
}

/**
* The default request timeout.
Expand Down Expand Up @@ -153,6 +193,11 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
public val progressHandlers: Map<ProgressToken, ProgressCallback>
get() = _progressHandlers.value

@Suppress("ktlint:standard:backing-property-naming")
private val _pendingDebouncedNotifications: AtomicRef<PersistentSet<Method>> = atomic(persistentSetOf())
private val notificationScopeJob = SupervisorJob()
private val notificationScope = CoroutineScope(notificationScopeJob + Dispatchers.Default)

/**
* Callback for when the connection is closed for any reason.
*
Expand Down Expand Up @@ -224,6 +269,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
val handlersToNotify = _responseHandlers.value.values.toList()
_responseHandlers.getAndSet(persistentMapOf())
_progressHandlers.getAndSet(persistentMapOf())
_pendingDebouncedNotifications.update { it.clear() }
notificationScopeJob.cancelChildren()
transport = null
onClose()

Expand Down Expand Up @@ -489,13 +536,45 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
/**
* Emits a notification, which is a one-way message that does not expect a response.
*/
public suspend fun notification(notification: Notification) {
public suspend fun notification(notification: Notification, relatedRequestId: RequestId? = null) {
logger.trace { "Sending notification: ${notification.method}" }
val transport = this.transport ?: error("Not connected")
assertNotificationCapability(notification.method)
val jsonRpcNotification = notification.toJSON()

val isDebounced =
options?.debouncedNotificationMethods?.contains(notification.method) == true &&
notification.params == null &&
relatedRequestId == null

if (isDebounced) {
if (notification.method in _pendingDebouncedNotifications.value) {
logger.trace { "Skipping debounced notification: ${notification.method}" }
return
}

_pendingDebouncedNotifications.update { it.add(notification.method) }

notificationScope.launch {
try {
yield()
} finally {
_pendingDebouncedNotifications.update { it.remove(notification.method) }
}

val activeTransport = this@Protocol.transport ?: return@launch

try {
activeTransport.send(jsonRpcNotification)
} catch (cause: Throwable) {
logger.error(cause) { "Error sending debounced notification: ${notification.method}" }
onError(cause)
}
}
return
}

val message = notification.toJSON()
transport.send(message)
transport.send(jsonRpcNotification)
}

/**
Expand Down
Loading