From 06a29ce91e950ba6c428274e6c4fbc3d398158fd Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 17 Nov 2025 21:40:44 +0100 Subject: [PATCH 1/4] Update error handling in `Protocol` for unknown message IDs --- .../modelcontextprotocol/kotlin/sdk/shared/Protocol.kt | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 64fe355c..c42a22c2 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -344,7 +344,11 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio if (handler != null) { messageId?.let { msg -> _progressHandlers.update { it.remove(msg) } } } else { - onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}")) + onError( + IllegalStateException( + "Received a response for an unknown message ID: ${McpJson.encodeToString(error ?: response)}", + ), + ) return } @@ -352,12 +356,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio handler(response, null) } else { check(error != null) - val error = McpException( + val mcpException = McpException( code = error.error.code, message = error.error.message, data = error.error.data, ) - handler(null, error) + handler(null, mcpException) } } From 78f431a7632d942496917133cb5b66d572648dc5 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 17 Nov 2025 22:18:48 +0100 Subject: [PATCH 2/4] Refactor request handling to enhance progress token management in `Protocol` --- .../kotlin/sdk/shared/Protocol.kt | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index c42a22c2..db8b6efd 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -407,18 +407,30 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio assertCapabilityForMethod(request.method) } - val message = request.toJSON() - val messageId = message.id + val jsonRpcRequest = request.toJSON().run { + options?.onProgress?.let { progressHandler -> + logger.trace { "Registering progress handler for request id: $id" } + _progressHandlers.update { current -> + current.put(id, progressHandler) + } - if (options?.onProgress != null) { - logger.trace { "Registering progress handler for request id: $messageId" } - _progressHandlers.update { current -> - current.put(messageId, options.onProgress) - } + val paramsObject = (this.params as? JsonObject) ?: JsonObject(emptyMap()) + val metaObject = request.params?.meta?.json ?: JsonObject(emptyMap()) + + val updatedMeta = JsonObject( + metaObject + ("progressToken" to McpJson.encodeToJsonElement(id)), + ) + val updatedParams = JsonObject( + paramsObject + ("_meta" to updatedMeta), + ) + + this.copy(params = updatedParams) + } ?: this } + val jsonRpcRequestId = jsonRpcRequest.id _responseHandlers.update { current -> - current.put(messageId) { response, error -> + current.put(jsonRpcRequestId) { response, error -> if (error != null) { result.completeExceptionally(error) return@put @@ -434,12 +446,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } val cancel: suspend (Throwable) -> Unit = { reason: Throwable -> - _responseHandlers.update { current -> current.remove(messageId) } - _progressHandlers.update { current -> current.remove(messageId) } + _responseHandlers.update { current -> current.remove(jsonRpcRequestId) } + _progressHandlers.update { current -> current.remove(jsonRpcRequestId) } val notification = CancelledNotification( params = CancelledNotificationParams( - requestId = messageId, + requestId = jsonRpcRequestId, reason = reason.message ?: "Unknown", ), ) @@ -456,8 +468,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT try { withTimeout(timeout) { - logger.trace { "Sending request message with id: $messageId" } - this@Protocol.transport?.send(message) + logger.trace { "Sending request message with id: $jsonRpcRequestId" } + this@Protocol.transport?.send(jsonRpcRequest) } return result.await() } catch (cause: TimeoutCancellationException) { From 205fbbc1375c5ebdc5611c0e51773ae06511bcd8 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Mon, 17 Nov 2025 23:16:42 +0100 Subject: [PATCH 3/4] Add unit tests for `Protocol` to validate progress token handling and meta behavior --- kotlin-sdk-core/build.gradle.kts | 1 + .../kotlin/sdk/shared/ProtocolTest.kt | 199 ++++++++++++++++++ 2 files changed, 200 insertions(+) create mode 100644 kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt diff --git a/kotlin-sdk-core/build.gradle.kts b/kotlin-sdk-core/build.gradle.kts index 4e2531b4..dd095e67 100644 --- a/kotlin-sdk-core/build.gradle.kts +++ b/kotlin-sdk-core/build.gradle.kts @@ -122,6 +122,7 @@ kotlin { commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt new file mode 100644 index 00000000..51451fe7 --- /dev/null +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt @@ -0,0 +1,199 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest +import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta +import kotlinx.coroutines.async +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonObjectBuilder +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.int +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class ProtocolTest { + private lateinit var protocol: TestProtocol + private lateinit var transport: RecordingTransport + + @BeforeTest + fun setUp() { + protocol = TestProtocol() + transport = RecordingTransport() + } + + @Test + fun `should preserve existing meta when adding progress token`() = runTest { + protocol.connect(transport) + val request = ReadResourceRequest( + ReadResourceRequestParams( + uri = "test://resource", + meta = metaOf { + put("customField", JsonPrimitive("customValue")) + put("anotherField", JsonPrimitive(123)) + }, + ), + ) + + val inFlight = async { + protocol.request( + request = request, + options = RequestOptions(onProgress = {}), + ) + } + + val sent = transport.awaitRequest() + val params = requireNotNull(sent.params).jsonObject + val meta = params["_meta"]!!.jsonObject + + assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) + assertEquals("customValue", meta["customField"]!!.jsonPrimitive.content) + assertEquals(123, meta["anotherField"]!!.jsonPrimitive.int) + assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + + transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) + inFlight.await() + } + + @Test + fun `should create meta with progress token when none exists`() = runTest { + protocol.connect(transport) + val request = ReadResourceRequest( + ReadResourceRequestParams( + uri = "test://resource", + meta = null, + ), + ) + + val inFlight = async { + protocol.request( + request = request, + options = RequestOptions(onProgress = {}), + ) + } + + val sent = transport.awaitRequest() + val params = requireNotNull(sent.params).jsonObject + val meta = params["_meta"]!!.jsonObject + + assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) + assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + + transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) + inFlight.await() + } + + @Test + fun `should not modify meta when onProgress is absent`() = runTest { + protocol.connect(transport) + val originalMeta = metaJson { + put("customField", JsonPrimitive("customValue")) + } + val request = ReadResourceRequest( + ReadResourceRequestParams( + uri = "test://resource", + meta = RequestMeta(originalMeta), + ), + ) + + val inFlight = async { + protocol.request(request) + } + + val sent = transport.awaitRequest() + val params = requireNotNull(sent.params).jsonObject + val meta = params["_meta"]!!.jsonObject + + assertEquals(originalMeta, meta) + assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) + + transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) + inFlight.await() + } + + @Test + fun `should create params object when request params are null`() = runTest { + protocol.connect(transport) + val request = CustomRequest( + method = Method.Custom("example"), + params = null, + ) + + val inFlight = async { + protocol.request( + request = request, + options = RequestOptions(onProgress = {}), + ) + } + + val sent = transport.awaitRequest() + val params = requireNotNull(sent.params).jsonObject + val meta = params["_meta"]!!.jsonObject + + assertEquals(setOf("_meta"), params.keys) + assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + + transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) + inFlight.await() + } +} + +private class TestProtocol : Protocol(null) { + override fun assertCapabilityForMethod(method: Method) {} + override fun assertNotificationCapability(method: Method) {} + override fun assertRequestHandlerCapability(method: Method) {} +} + +private class RecordingTransport : Transport { + private val sentMessages = Channel(Channel.UNLIMITED) + private var onMessageCallback: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseCallback: (() -> Unit)? = null + + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + sentMessages.send(message) + } + + override suspend fun close() { + onCloseCallback?.invoke() + } + + override fun onClose(block: () -> Unit) { + onCloseCallback = block + } + + override fun onError(block: (Throwable) -> Unit) {} + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageCallback = block + } + + suspend fun awaitRequest(): JSONRPCRequest { + val message = sentMessages.receive() + return message as? JSONRPCRequest + ?: error("Expected JSONRPCRequest but received ${message::class.simpleName}") + } + + suspend fun deliver(message: JSONRPCMessage) { + val callback = onMessageCallback ?: error("onMessage callback not registered") + callback(message) + } +} + +private fun metaOf(builderAction: JsonObjectBuilder.() -> Unit): RequestMeta = RequestMeta(metaJson(builderAction)) + +private fun metaJson(builderAction: JsonObjectBuilder.() -> Unit): JsonObject = buildJsonObject(builderAction) From a130902f11cea6fadbb00290a796e3105fa5cbd1 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Thu, 20 Nov 2025 18:41:43 +0100 Subject: [PATCH 4/4] Refactor `ProtocolTest` to use Kotest matchers for improved readability and null safety --- .../kotlin/sdk/shared/ProtocolTest.kt | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt index 51451fe7..f53366c3 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt @@ -1,5 +1,8 @@ package io.modelcontextprotocol.kotlin.sdk.shared +import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage @@ -23,7 +26,6 @@ import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import kotlin.test.BeforeTest import kotlin.test.Test -import kotlin.test.assertEquals class ProtocolTest { private lateinit var protocol: TestProtocol @@ -56,13 +58,13 @@ class ProtocolTest { } val sent = transport.awaitRequest() - val params = requireNotNull(sent.params).jsonObject - val meta = params["_meta"]!!.jsonObject + val params = sent.params?.jsonObject.shouldNotBeNull() + val meta = params["_meta"]?.jsonObject.shouldNotBeNull() - assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) - assertEquals("customValue", meta["customField"]!!.jsonPrimitive.content) - assertEquals(123, meta["anotherField"]!!.jsonPrimitive.int) - assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + params["uri"]?.jsonPrimitive?.content shouldBe "test://resource" + meta["customField"]?.jsonPrimitive?.content shouldBe "customValue" + meta["anotherField"]?.jsonPrimitive?.int shouldBe 123 + meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id) transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) inFlight.await() @@ -86,11 +88,11 @@ class ProtocolTest { } val sent = transport.awaitRequest() - val params = requireNotNull(sent.params).jsonObject - val meta = params["_meta"]!!.jsonObject + val params = sent.params?.jsonObject.shouldNotBeNull() + val meta = params["_meta"]?.jsonObject.shouldNotBeNull() - assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) - assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + params["uri"]?.jsonPrimitive?.content shouldBe "test://resource" + meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id) transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) inFlight.await() @@ -114,11 +116,11 @@ class ProtocolTest { } val sent = transport.awaitRequest() - val params = requireNotNull(sent.params).jsonObject - val meta = params["_meta"]!!.jsonObject + val params = sent.params?.jsonObject.shouldNotBeNull() + val meta = params["_meta"]?.jsonObject.shouldNotBeNull() - assertEquals(originalMeta, meta) - assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content) + meta shouldBe originalMeta + params["uri"]?.jsonPrimitive?.content shouldBe "test://resource" transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) inFlight.await() @@ -140,11 +142,11 @@ class ProtocolTest { } val sent = transport.awaitRequest() - val params = requireNotNull(sent.params).jsonObject - val meta = params["_meta"]!!.jsonObject + val params = sent.params?.jsonObject.shouldNotBeNull() + val meta = params["_meta"]?.jsonObject.shouldNotBeNull() - assertEquals(setOf("_meta"), params.keys) - assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"]) + params.keys shouldContainExactly setOf("_meta") + meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id) transport.deliver(JSONRPCResponse(sent.id, EmptyResult())) inFlight.await()