diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index c86cad83..d861bd08 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -34,6 +34,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest @@ -54,6 +55,7 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus import kotlinx.collections.immutable.persistentMapOf import kotlinx.collections.immutable.toPersistentSet +import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonObject import kotlin.coroutines.cancellation.CancellationException @@ -196,11 +198,15 @@ public open class Client(private val clientInfo: Implementation, options: Client logger.error(error) { "Failed to initialize client: ${error.message}" } close() - if (error !is CancellationException) { - throw IllegalStateException("Error connecting to transport: ${error.message}", error) - } + when (error) { + is CancellationException, + is McpException, + is StreamableHttpError, + is SerializationException, + -> throw error - throw error + else -> throw IllegalStateException("Error connecting to transport: ${error.message}", error) + } } } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index d863e008..b4fbcb70 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -26,6 +26,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.Role import io.modelcontextprotocol.kotlin.sdk.types.Root @@ -232,6 +233,82 @@ class ClientTest { assertTrue(closed) } + @Test + fun `should rethrow McpException as is`() = runTest { + var closed = false + val failingTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw McpException( + code = -32600, + message = "Invalid Request", + ) + } + + override suspend fun close() { + closed = true + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions(), + ) + + val exception = assertFailsWith { + client.connect(failingTransport) + } + + assertEquals(-32600, exception.code) + assertEquals("MCP error -32600: Invalid Request", exception.message) + + assertTrue(closed) + } + + @Test + fun `should rethrow StreamableHttpError as is`() = runTest { + var closed = false + val failingTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw StreamableHttpError( + code = 500, + message = "Internal Server Error", + ) + } + + override suspend fun close() { + closed = true + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions(), + ) + + val exception = assertFailsWith { + client.connect(failingTransport) + } + + assertEquals(500, exception.code) + assertEquals("Streamable HTTP error: Internal Server Error", exception.message) + + assertTrue(closed) + } + @Test fun `should respect server capabilities`() = runTest { val serverOptions = ServerOptions( @@ -922,7 +999,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll()