Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kotlin-sdk-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ kotlin {
commonTest {
dependencies {
implementation(kotlin("test"))
implementation(libs.kotlinx.coroutines.test)
implementation(libs.kotest.assertions.core)
implementation(libs.kotest.assertions.json)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,20 +344,24 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
if (handler != null) {
messageId?.let { msg -> _progressHandlers.update { it.remove(msg) } }
} else {
onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}"))
onError(
IllegalStateException(
"Received a response for an unknown message ID: ${McpJson.encodeToString(error ?: response)}",
),
)
return
}

if (response != null) {
handler(response, null)
} else {
check(error != null)
val error = McpException(
val mcpException = McpException(
code = error.error.code,
message = error.error.message,
data = error.error.data,
)
handler(null, error)
handler(null, mcpException)
}
}

Expand Down Expand Up @@ -403,18 +407,30 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
assertCapabilityForMethod(request.method)
}

val message = request.toJSON()
val messageId = message.id
val jsonRpcRequest = request.toJSON().run {
options?.onProgress?.let { progressHandler ->
logger.trace { "Registering progress handler for request id: $id" }
_progressHandlers.update { current ->
current.put(id, progressHandler)
}

if (options?.onProgress != null) {
logger.trace { "Registering progress handler for request id: $messageId" }
_progressHandlers.update { current ->
current.put(messageId, options.onProgress)
}
val paramsObject = (this.params as? JsonObject) ?: JsonObject(emptyMap())
val metaObject = request.params?.meta?.json ?: JsonObject(emptyMap())

val updatedMeta = JsonObject(
metaObject + ("progressToken" to McpJson.encodeToJsonElement(id)),
)
val updatedParams = JsonObject(
paramsObject + ("_meta" to updatedMeta),
)

this.copy(params = updatedParams)
} ?: this
}
val jsonRpcRequestId = jsonRpcRequest.id

_responseHandlers.update { current ->
current.put(messageId) { response, error ->
current.put(jsonRpcRequestId) { response, error ->
if (error != null) {
result.completeExceptionally(error)
return@put
Expand All @@ -430,12 +446,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
}

val cancel: suspend (Throwable) -> Unit = { reason: Throwable ->
_responseHandlers.update { current -> current.remove(messageId) }
_progressHandlers.update { current -> current.remove(messageId) }
_responseHandlers.update { current -> current.remove(jsonRpcRequestId) }
_progressHandlers.update { current -> current.remove(jsonRpcRequestId) }

val notification = CancelledNotification(
params = CancelledNotificationParams(
requestId = messageId,
requestId = jsonRpcRequestId,
reason = reason.message ?: "Unknown",
),
)
Expand All @@ -452,8 +468,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
try {
withTimeout(timeout) {
logger.trace { "Sending request message with id: $messageId" }
this@Protocol.transport?.send(message)
logger.trace { "Sending request message with id: $jsonRpcRequestId" }
this@Protocol.transport?.send(jsonRpcRequest)
}
return result.await()
} catch (cause: TimeoutCancellationException) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.kotest.matchers.collections.shouldContainExactly
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest
import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.encodeToJsonElement
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlin.test.BeforeTest
import kotlin.test.Test

class ProtocolTest {
private lateinit var protocol: TestProtocol
private lateinit var transport: RecordingTransport

@BeforeTest
fun setUp() {
protocol = TestProtocol()
transport = RecordingTransport()
}

@Test
fun `should preserve existing meta when adding progress token`() = runTest {
protocol.connect(transport)
val request = ReadResourceRequest(
ReadResourceRequestParams(
uri = "test://resource",
meta = metaOf {
put("customField", JsonPrimitive("customValue"))
put("anotherField", JsonPrimitive(123))
},
),
)

val inFlight = async {
protocol.request<EmptyResult>(
request = request,
options = RequestOptions(onProgress = {}),
)
}

val sent = transport.awaitRequest()
val params = sent.params?.jsonObject.shouldNotBeNull()
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()

params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"
meta["customField"]?.jsonPrimitive?.content shouldBe "customValue"
meta["anotherField"]?.jsonPrimitive?.int shouldBe 123
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)

transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
inFlight.await()
}

@Test
fun `should create meta with progress token when none exists`() = runTest {
protocol.connect(transport)
val request = ReadResourceRequest(
ReadResourceRequestParams(
uri = "test://resource",
meta = null,
),
)

val inFlight = async {
protocol.request<EmptyResult>(
request = request,
options = RequestOptions(onProgress = {}),
)
}

val sent = transport.awaitRequest()
val params = sent.params?.jsonObject.shouldNotBeNull()
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()

params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)

transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
inFlight.await()
}

@Test
fun `should not modify meta when onProgress is absent`() = runTest {
protocol.connect(transport)
val originalMeta = metaJson {
put("customField", JsonPrimitive("customValue"))
}
val request = ReadResourceRequest(
ReadResourceRequestParams(
uri = "test://resource",
meta = RequestMeta(originalMeta),
),
)

val inFlight = async {
protocol.request<EmptyResult>(request)
}

val sent = transport.awaitRequest()
val params = sent.params?.jsonObject.shouldNotBeNull()
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()

meta shouldBe originalMeta
params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"

transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
inFlight.await()
}

@Test
fun `should create params object when request params are null`() = runTest {
protocol.connect(transport)
val request = CustomRequest(
method = Method.Custom("example"),
params = null,
)

val inFlight = async {
protocol.request<EmptyResult>(
request = request,
options = RequestOptions(onProgress = {}),
)
}

val sent = transport.awaitRequest()
val params = sent.params?.jsonObject.shouldNotBeNull()
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()

params.keys shouldContainExactly setOf("_meta")
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)

transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
inFlight.await()
}
}

private class TestProtocol : Protocol(null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good stuff, let's extract it

override fun assertCapabilityForMethod(method: Method) {}
override fun assertNotificationCapability(method: Method) {}
override fun assertRequestHandlerCapability(method: Method) {}
}

private class RecordingTransport : Transport {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good stuff, let's extract it

private val sentMessages = Channel<JSONRPCMessage>(Channel.UNLIMITED)
private var onMessageCallback: (suspend (JSONRPCMessage) -> Unit)? = null
private var onCloseCallback: (() -> Unit)? = null

override suspend fun start() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it make sense to add debug logging here


override suspend fun send(message: JSONRPCMessage) {
sentMessages.send(message)
}

override suspend fun close() {
onCloseCallback?.invoke()
}

override fun onClose(block: () -> Unit) {
onCloseCallback = block
}

override fun onError(block: (Throwable) -> Unit) {}

override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
onMessageCallback = block
}

suspend fun awaitRequest(): JSONRPCRequest {
val message = sentMessages.receive()
return message as? JSONRPCRequest
?: error("Expected JSONRPCRequest but received ${message::class.simpleName}")
}

suspend fun deliver(message: JSONRPCMessage) {
val callback = onMessageCallback ?: error("onMessage callback not registered")
callback(message)
}
}

private fun metaOf(builderAction: JsonObjectBuilder.() -> Unit): RequestMeta = RequestMeta(metaJson(builderAction))

private fun metaJson(builderAction: JsonObjectBuilder.() -> Unit): JsonObject = buildJsonObject(builderAction)
Loading