Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
Expand All @@ -54,6 +55,7 @@ import kotlinx.atomicfu.update
import kotlinx.collections.immutable.minus
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.toPersistentSet
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonObject
import kotlin.coroutines.cancellation.CancellationException

Expand Down Expand Up @@ -196,11 +198,15 @@ public open class Client(private val clientInfo: Implementation, options: Client
logger.error(error) { "Failed to initialize client: ${error.message}" }
close()

if (error !is CancellationException) {
throw IllegalStateException("Error connecting to transport: ${error.message}", error)
}
when (error) {
is CancellationException,
is McpException,
is StreamableHttpError,
is SerializationException,
-> throw error

throw error
else -> throw IllegalStateException("Error connecting to transport: ${error.message}", error)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.Role
import io.modelcontextprotocol.kotlin.sdk.types.Root
Expand Down Expand Up @@ -232,6 +233,82 @@ class ClientTest {
assertTrue(closed)
}

@Test
fun `should rethrow McpException as is`() = runTest {
var closed = false
val failingTransport = object : AbstractTransport() {
override suspend fun start() {}

override suspend fun send(message: JSONRPCMessage) {
if (message !is JSONRPCRequest) return
check(message.method == Method.Defined.Initialize.value)
throw McpException(
code = -32600,
message = "Invalid Request",
)
}

override suspend fun close() {
closed = true
}
}

val client = Client(
clientInfo = Implementation(
name = "test client",
version = "1.0",
),
options = ClientOptions(),
)

val exception = assertFailsWith<McpException> {
client.connect(failingTransport)
}

assertEquals(-32600, exception.code)
assertEquals("MCP error -32600: Invalid Request", exception.message)

assertTrue(closed)
}

@Test
fun `should rethrow StreamableHttpError as is`() = runTest {
var closed = false
val failingTransport = object : AbstractTransport() {
override suspend fun start() {}

override suspend fun send(message: JSONRPCMessage) {
if (message !is JSONRPCRequest) return
check(message.method == Method.Defined.Initialize.value)
throw StreamableHttpError(
code = 500,
message = "Internal Server Error",
)
}

override suspend fun close() {
closed = true
}
}

val client = Client(
clientInfo = Implementation(
name = "test client",
version = "1.0",
),
options = ClientOptions(),
)

val exception = assertFailsWith<StreamableHttpError> {
client.connect(failingTransport)
}

assertEquals(500, exception.code)
assertEquals("Streamable HTTP error: Internal Server Error", exception.message)

assertTrue(closed)
}

@Test
fun `should respect server capabilities`() = runTest {
val serverOptions = ServerOptions(
Expand Down Expand Up @@ -922,7 +999,7 @@ class ClientTest {
println("Client connected")
},
launch {
serverSessionResult.complete(server.connect(serverTransport))
serverSessionResult.complete(server.createSession(serverTransport))
println("Server connected")
},
).joinAll()
Expand Down
Loading