From 954e17ab959ff91b66afe5ec2b3e15b6c0b102f1 Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Thu, 23 Oct 2025 18:06:50 +0300 Subject: [PATCH] Refactor MockTransport This replaces bespoke test transport with a reusable, configurable MockTransport in a shared testing package. The new MockTransport supports registering handlers for specific JSON-RPC methods (success and error), records sent/received messages, and provides an awaitMessage helper with polling and timeouts. Tests were added to validate handler registration, auto-responses, error responses, concurrency, and message awaiting behavior. Existing client tests were updated to configure MockTransport via lambdas instead of hardcoded logic. Additionally, coroutine test utilities were added to the core test dependencies. --- .../sdk/client/ClientMetaParameterTest.kt | 24 +- .../kotlin/sdk/client/MockTransport.kt | 94 --- kotlin-sdk-core/api/kotlin-sdk-core.api | 21 + kotlin-sdk-core/build.gradle.kts | 1 + .../kotlin/sdk/testing/MockTransport.kt | 213 +++++++ .../kotlin/sdk/testing/MockTransportTest.kt | 575 ++++++++++++++++++ 6 files changed, 833 insertions(+), 95 deletions(-) delete mode 100644 kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt create mode 100644 kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt create mode 100644 kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt index e7061073..6ae6ea04 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt @@ -1,7 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.testing.MockTransport import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.boolean @@ -31,7 +36,24 @@ class ClientMetaParameterTest { @BeforeTest fun setup() = runTest { - mockTransport = MockTransport() + mockTransport = MockTransport { + // configure mock transport behavior + onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ) + } + onMessageReplyResult(Method.Defined.ToolsCall) { + CallToolResult( + content = listOf(), + isError = false, + ) + } + } client = Client(clientInfo = clientInfo) mockTransport.setupInitializationResponse() client.connect(mockTransport) diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt deleted file mode 100644 index c987619d..00000000 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt +++ /dev/null @@ -1,94 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.client - -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest -import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.shared.Transport -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock - -class MockTransport : Transport { - private val _sentMessages = mutableListOf() - private val _receivedMessages = mutableListOf() - private val mutex = Mutex() - - suspend fun getSentMessages() = mutex.withLock { _sentMessages.toList() } - suspend fun getReceivedMessages() = mutex.withLock { _receivedMessages.toList() } - - private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null - private var onCloseBlock: (() -> Unit)? = null - private var onErrorBlock: ((Throwable) -> Unit)? = null - - override suspend fun start() = Unit - - override suspend fun send(message: JSONRPCMessage) { - mutex.withLock { - _sentMessages += message - } - - // Auto-respond to initialization and tool calls - when (message) { - is JSONRPCRequest -> { - when (message.method) { - "initialize" -> { - val initResponse = JSONRPCResponse( - id = message.id, - result = InitializeResult( - protocolVersion = "2024-11-05", - capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(listChanged = null), - ), - serverInfo = Implementation("mock-server", "1.0.0"), - ), - ) - onMessageBlock?.invoke(initResponse) - } - - "tools/call" -> { - val toolResponse = JSONRPCResponse( - id = message.id, - result = CallToolResult( - content = listOf(), - isError = false, - ), - ) - onMessageBlock?.invoke(toolResponse) - } - } - } - - else -> { - // Handle other message types if needed - } - } - } - - override suspend fun close() { - onCloseBlock?.invoke() - } - - override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - onMessageBlock = { message -> - mutex.withLock { - _receivedMessages += message - } - block(message) - } - } - - override fun onClose(block: () -> Unit) { - onCloseBlock = block - } - - override fun onError(block: (Throwable) -> Unit) { - onErrorBlock = block - } - - fun setupInitializationResponse() { - // This method helps set up the mock for proper initialization - } -} diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index d4d4b9d5..98d60937 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -3364,3 +3364,24 @@ public final class io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTranspo public static final field MCP_SUBPROTOCOL Ljava/lang/String; } +public class io/modelcontextprotocol/kotlin/sdk/testing/MockTransport : io/modelcontextprotocol/kotlin/sdk/shared/Transport { + public fun ()V + public fun (Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun awaitMessage-ePrTys8 (JJLjava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun awaitMessage-ePrTys8$default (Lio/modelcontextprotocol/kotlin/sdk/testing/MockTransport;JJLjava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getReceivedMessages (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSentMessages (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun onClose (Lkotlin/jvm/functions/Function0;)V + public fun onError (Lkotlin/jvm/functions/Function1;)V + public fun onMessage (Lkotlin/jvm/functions/Function2;)V + public final fun onMessageReply (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;)V + public final fun onMessageReplyError (Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun onMessageReplyError$default (Lio/modelcontextprotocol/kotlin/sdk/testing/MockTransport;Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public final fun onMessageReplyResult (Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;)V + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setupInitializationResponse ()V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + diff --git a/kotlin-sdk-core/build.gradle.kts b/kotlin-sdk-core/build.gradle.kts index 68e56185..3e81d3d0 100644 --- a/kotlin-sdk-core/build.gradle.kts +++ b/kotlin-sdk-core/build.gradle.kts @@ -124,6 +124,7 @@ kotlin { implementation(kotlin("test")) implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) + implementation(libs.kotlinx.coroutines.test) } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt new file mode 100644 index 00000000..06f243c7 --- /dev/null +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt @@ -0,0 +1,213 @@ +package io.modelcontextprotocol.kotlin.sdk.testing + +import io.ktor.util.collections.ConcurrentSet +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestResult +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.coroutines.delay +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.time.Clock +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds +import kotlin.time.ExperimentalTime + +private typealias RequestPredicate = (JSONRPCRequest) -> Boolean +private typealias RequestHandler = suspend (JSONRPCRequest) -> JSONRPCResponse + +/** + * A mock transport implementation for testing JSON-RPC communication. + * + * This class simulates transport that can be used to test server and client interactions by + * allowing the registration of handlers for incoming requests and the ability to record + * messages sent and received. + * + * The mock transport supports: + * - Recording all sent and received messages (via `getSentMessages` and `getReceivedMessages`) + * - Registering request handlers that respond to specific message predicates (e.g., by method) + * - Setting up responses that can be either successful or with errors + * - Waiting for specific messages to be received + * + * Note: This class is designed to be used as a test helper and should not be used in production. + */ +@Suppress("TooManyFunctions") +public open class MockTransport(configurer: MockTransport.() -> Unit = {}) : Transport { + private val _sentMessages = mutableListOf() + private val _receivedMessages = mutableListOf() + + private val requestHandlers = ConcurrentSet>() + private val mutex = Mutex() + + public suspend fun getSentMessages(): List = mutex.withLock { _sentMessages.toList() } + + public suspend fun getReceivedMessages(): List = mutex.withLock { _receivedMessages.toList() } + + private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseBlock: (() -> Unit)? = null + private var onErrorBlock: ((Throwable) -> Unit)? = null + + init { + configurer.invoke(this) + } + + override suspend fun start(): Unit = Unit + + override suspend fun send(message: JSONRPCMessage) { + mutex.withLock { + _sentMessages += message + } + + // Auto-respond to using preconfigured request handlers + when (message) { + is JSONRPCRequest -> { + val response = requestHandlers.firstOrNull { + it.first.invoke(message) + }?.second?.invoke(message) + + checkNotNull(response) { + "No request handler found for $message." + } + onMessageBlock?.invoke(response) + } + + else -> { + // TODO("Not implemented yet") + } + } + } + + override suspend fun close() { + onCloseBlock?.invoke() + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageBlock = { message -> + mutex.withLock { + _receivedMessages += message + } + block(message) + } + } + + override fun onClose(block: () -> Unit) { + onCloseBlock = block + } + + override fun onError(block: (Throwable) -> Unit) { + onErrorBlock = block + } + + public fun setupInitializationResponse() { + // This method helps set up the mock for proper initialization + } + + /** + * Registers a handler that will be called when a message matching the given predicate is received. + * + * The handler is expected to return a `RequestResult` which will be used as the response to the request. + * + * @param predicate A predicate that matches the incoming `JSONRPCMessage` + * for which the handler should be triggered. + * @param block A function that processes the incoming `JSONRPCMessage` and returns a `RequestResult` + * to be used as the response. + */ + public fun onMessageReply(predicate: RequestPredicate, block: RequestHandler) { + requestHandlers.add(Pair(predicate, block)) + } + + /** + * Registers a handler for responses to a specific method. + * + * This method allows registering a handler that will be called when a message with the specified method + * is received. The handler is expected to return a `RequestResult` which is the response to the request. + * + * @param method The method (from the `Method` enum) that the handler should respond to. + * @param block A function that processes the incoming `JSONRPCRequest` and returns a `RequestResult`. + * The returned `RequestResult` will be used as the result of the response. + */ + public fun onMessageReplyResult(method: Method, block: (JSONRPCRequest) -> T) { + onMessageReply( + predicate = { + it.method == method.value + }, + block = { + JSONRPCResponse( + id = it.id, + result = block.invoke(it), + ) + }, + ) + } + + /** + * Registers a handler that will be called when a request with the specified method is received + * and an error response is to be generated. + * + * This handler is used to respond to requests with a specific method by returning an error response. + * The handler is triggered when a request message with the given `method` is received. + * + * @param method The method (from the `Method` enum) that the handler should respond to with an error. + * @param block A function that processes the incoming `JSONRPCRequest` and returns a `JSONRPCError` + * to be used as the error response. + * The default block returns an internal error with the message "Expected error". + */ + public fun onMessageReplyError( + method: Method, + block: (JSONRPCRequest) -> JSONRPCError = { + JSONRPCError( + code = ErrorCode.Defined.InternalError, + message = "Expected error", + ) + }, + ) { + onMessageReply( + predicate = { + it.method == method.value + }, + block = { + JSONRPCResponse( + id = it.id, + error = block.invoke(it), + ) + }, + ) + } + + /** + * Waits for a JSON-RPC message that matches the given predicate in the received messages. + * + * @param poolInterval The interval at which the function polls the received messages. Default is 50 milliseconds. + * @param timeout The maximum time to wait for a matching message. Default is 3 seconds. + * @param timeoutMessage The error message to throw when the timeout is reached. + * Default is "No message received matching predicate". + * @param predicate A predicate function that returns true if the message matches the criteria. + * @return The first JSON-RPC message that matches the predicate. + */ + @OptIn(ExperimentalTime::class) + public suspend fun awaitMessage( + poolInterval: Duration = 50.milliseconds, + timeout: Duration = 3.seconds, + timeoutMessage: String = "No message received matching predicate", + predicate: (JSONRPCMessage) -> Boolean, + ): JSONRPCMessage { + val clock = Clock.System + val startTime = clock.now() + val finishTime = startTime + timeout + while (clock.now() < finishTime) { + val found = mutex.withLock { + _receivedMessages.firstOrNull { predicate(it) } + } + if (found != null) { + return found + } + delay(poolInterval) + } + error(timeoutMessage) + } +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt new file mode 100644 index 00000000..5cdc0b62 --- /dev/null +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt @@ -0,0 +1,575 @@ +package io.modelcontextprotocol.kotlin.sdk.testing + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import kotlinx.coroutines.async +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class MockTransportTest { + + private lateinit var transport: MockTransport + + @BeforeTest + fun beforeTest() { + transport = MockTransport { + // configure mock transport behavior + onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ) + } + } + + // Set up onMessage callback to add messages + transport.onMessage { } + } + + @Test + fun `awaitMessage should return message when predicate matches`() = runTest { + // Trigger the onMessage callback directly via send + launch { + transport.send( + JSONRPCRequest( + id = RequestId.StringId("some-id"), + method = "initialize", + ), + ) + delay(200) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "initialize", + params = buildJsonObject { + put("foo", JsonPrimitive("bar")) + }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("other-id"), + method = "initialize", + ), + ) + } + + // Wait for the auto-response + val message = transport.awaitMessage { + it is JSONRPCResponse && + it.id == RequestId.StringId("test-id") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("test-id"), message.id) + } + + @Test + fun `awaitMessage should timeout when no matching message arrives`() = runTest { + val exception = assertFailsWith { + transport.awaitMessage( + timeout = 100.milliseconds, + timeoutMessage = "Custom timeout message", + ) { false } // Predicate that never matches + } + + assertEquals("Custom timeout message", exception.message) + } + + @Test + fun `awaitMessage should filter messages by predicate`() = runTest { + transport.onMessageReply(predicate = { true }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + // Send multiple messages + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-1"), + method = "test1", + params = buildJsonObject { }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "test2", + params = buildJsonObject { }, + ), + ) + + // Wait for response with specific id - note: no auto-response for non-initialize/tools methods + // So this test will timeout unless we manually trigger a response + // Let's send an initialize to get a response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + val message = transport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-2") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("req-2"), message.id) + } + + @Test + fun `awaitMessage should return first matching message`() = runTest { + // Send initialize request to get auto-response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("init-1"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Wait for any response + val message = transport.awaitMessage { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("init-1"), message.id) + } + + @Test + fun `awaitMessage should handle concurrent access safely`() = runTest { + // Send a message that will trigger auto-response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("concurrent-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Launch multiple concurrent awaitMessage calls + val deferred1 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + val deferred2 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + val deferred3 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + // All should successfully find the message + val message1 = deferred1.await() + val message2 = deferred2.await() + val message3 = deferred3.await() + + assertNotNull(message1) + assertNotNull(message2) + assertNotNull(message3) + + // All should be the same message + assertTrue(message1 is JSONRPCResponse) + assertTrue(message2 is JSONRPCResponse) + assertTrue(message3 is JSONRPCResponse) + assertEquals(RequestId.StringId("concurrent-test"), message1.id) + assertEquals(RequestId.StringId("concurrent-test"), message2.id) + assertEquals(RequestId.StringId("concurrent-test"), message3.id) + } + + @Test + fun `awaitMessage should wait for message to arrive`() = runTest { + // Launch awaitMessage before message arrives + val deferred = async { + transport.awaitMessage(timeout = 2.seconds) { it is JSONRPCResponse } + } + + // Wait a bit before sending message + delay(100.milliseconds) + + // Now send the message + transport.send( + JSONRPCRequest( + id = RequestId.StringId("delayed"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Should successfully receive it + val message = deferred.await() + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("delayed"), message.id) + } + + @Test + fun `awaitMessage should use custom pool interval`() = runTest { + // Send message + transport.send( + JSONRPCRequest( + id = RequestId.StringId("pool-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Should work with custom pool interval + val message = transport.awaitMessage( + poolInterval = 10.milliseconds, + timeout = 1.seconds, + ) { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + } + + @Test + fun `awaitMessage should handle tools call auto-response`() = runTest { + transport.onMessageReplyResult(Method.Defined.ToolsCall) { + CallToolResult(content = listOf()) + } + + // Send tools/call request + transport.send( + JSONRPCRequest( + id = RequestId.StringId("tool-1"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Should receive auto-response + val message = transport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("tool-1") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("tool-1"), message.id) + } + + @Test + fun `awaitMessage should return existing message immediately`() = runTest { + // Send message first + transport.send( + JSONRPCRequest( + id = RequestId.StringId("existing"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Give it time to be received + delay(50.milliseconds) + + // Now await should return immediately without waiting + val message = transport.awaitMessage( + timeout = 100.milliseconds, + ) { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("existing"), message.id) + } + + @Test + fun `awaitMessage with complex predicate`() = runTest { + transport.onMessageReply(predicate = { true }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + // Send multiple requests + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-1"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Wait for response with specific criteria + val message = transport.awaitMessage { msg -> + msg is JSONRPCResponse && msg.id == RequestId.StringId("req-2") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("req-2"), message.id) + } + + @Test + fun `onMessageReply should register handler with custom predicate`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler that only responds to requests with "custom" method + customTransport.onMessageReply( + predicate = { request -> request.method == "custom-method" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = EmptyRequestResult(), + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("test-1"), + method = "custom-method", + params = buildJsonObject { }, + ), + ) + + // Verify response was received + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("test-1") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("test-1"), message.id) + assertNotNull(message.result) + } + + @Test + fun `onMessageReply should support multiple handlers with different predicates`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register first handler for "method-a" + customTransport.onMessageReply( + predicate = { it.method == "method-a" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = CallToolResult(content = listOf()), + ) + } + + // Register second handler for "method-b" + customTransport.onMessageReply( + predicate = { it.method == "method-b" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = EmptyRequestResult(), + ) + } + + // Test first handler + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("req-a"), + method = "method-a", + params = buildJsonObject { }, + ), + ) + + val messageA = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-a") + } + + assertTrue(messageA is JSONRPCResponse) + assertTrue(messageA.result is CallToolResult) + + // Test second handler + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("req-b"), + method = "method-b", + params = buildJsonObject { }, + ), + ) + + val messageB = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-b") + } + + assertTrue(messageB is JSONRPCResponse) + assertTrue(messageB.result is EmptyRequestResult) + } + + @Test + fun `onMessageReplyResult should create response with result for matching method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler using onMessageReplyResult + customTransport.onMessageReplyResult(Method.Defined.Initialize) { _ -> + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("test-server", "1.0.0"), + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("init-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Verify response with result + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("init-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("init-test"), message.id) + assertNotNull(message.result) + assertTrue(message.result is InitializeResult) + val result = message.result + assertEquals("2024-11-05", result.protocolVersion) + assertEquals("test-server", result.serverInfo.name) + } + + @Test + fun `onMessageReplyResult should only respond to specified method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler only for Initialize + customTransport.onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("test", "1.0"), + ) + } + + // Also register a catch-all handler for other methods + customTransport.onMessageReply(predicate = { it.method != "initialize" }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + + // Send non-matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("other-method"), + method = "other-method", + params = buildJsonObject { }, + ), + ) + + // Should get response from catch-all handler + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("other-method") + } + + assertTrue(message is JSONRPCResponse) + assertTrue(message.result is EmptyRequestResult) + } + + @Test + fun `onMessageReplyError should create response with error for matching method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register error handler with custom error + customTransport.onMessageReplyError(Method.Defined.ToolsCall) { _ -> + io.modelcontextprotocol.kotlin.sdk.JSONRPCError( + code = io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InvalidParams, + message = "Custom error message", + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("error-test"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Verify response with error + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("error-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("error-test"), message.id) + assertNotNull(message.error) + assertEquals(io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InvalidParams, message.error?.code) + assertEquals("Custom error message", message.error.message) + } + + @Test + fun `onMessageReplyError should use default error when block not provided`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register error handler without custom block (using default) + customTransport.onMessageReplyError(Method.Defined.Initialize) + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("default-error-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Verify response with default error + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("default-error-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("default-error-test"), message.id) + assertNotNull(message.error) + assertEquals(io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InternalError, message.error?.code) + assertEquals("Expected error", message.error.message) + } +}