diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt index e7061073..600859fb 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt @@ -1,7 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.client -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.boolean diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt index c987619d..860ed147 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt @@ -1,13 +1,13 @@ package io.modelcontextprotocol.kotlin.sdk.client -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest -import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaClientMetaParameterTest.kt new file mode 100644 index 00000000..b5411767 --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaClientMetaParameterTest.kt @@ -0,0 +1,275 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.boolean +import kotlinx.serialization.json.int +import kotlinx.serialization.json.jsonPrimitive +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Comprehensive test suite for MCP Client meta parameter functionality + * + * Tests cover: + * - Meta key validation according to MCP specification + * - JSON type conversion for various data types + * - Error handling for invalid meta keys + * - Integration with callTool method + */ +class OldSchemaClientMetaParameterTest { + + private lateinit var client: Client + private lateinit var mockTransport: OldSchemaMockTransport + private val clientInfo = Implementation("test-client", "1.0.0") + + @BeforeTest + fun setup() = runTest { + mockTransport = OldSchemaMockTransport() + client = Client(clientInfo = clientInfo) + mockTransport.setupInitializationResponse() + client.connect(mockTransport) + } + + @Test + fun `should accept valid meta keys without throwing exception`() = runTest { + val validMeta = buildMap { + put("simple-key", "value1") + put("api.example.com/version", "1.0") + put("com.company.app/setting", "enabled") + put("retry_count", 3) + put("user.preference", true) + put("valid123", "alphanumeric") + put("multi.dot.name", "multiple-dots") + put("under_score", "underscore") + put("hyphen-dash", "hyphen") + put("org.apache.kafka/consumer-config", "complex-valid-prefix") + } + + val result = runCatching { + client.callTool("test-tool", mapOf("arg" to "value"), validMeta) + } + + assertTrue(result.isSuccess, "Valid meta keys should not cause exceptions") + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + assertTrue(params.containsKey("_meta"), "Request should contain _meta field") + val metaField = params["_meta"] as JsonObject + + // Verify all meta keys are present + assertEquals(validMeta.size, metaField.size, "All meta keys should be included") + + // Verify specific key-value pairs + assertEquals("value1", metaField["simple-key"]?.jsonPrimitive?.content) + assertEquals("1.0", metaField["api.example.com/version"]?.jsonPrimitive?.content) + assertEquals("enabled", metaField["com.company.app/setting"]?.jsonPrimitive?.content) + assertEquals(3, metaField["retry_count"]?.jsonPrimitive?.int) + assertEquals(true, metaField["user.preference"]?.jsonPrimitive?.boolean) + assertEquals("alphanumeric", metaField["valid123"]?.jsonPrimitive?.content) + assertEquals("multiple-dots", metaField["multi.dot.name"]?.jsonPrimitive?.content) + assertEquals("underscore", metaField["under_score"]?.jsonPrimitive?.content) + assertEquals("hyphen", metaField["hyphen-dash"]?.jsonPrimitive?.content) + assertEquals("complex-valid-prefix", metaField["org.apache.kafka/consumer-config"]?.jsonPrimitive?.content) + } + } + + @Test + fun `should accept edge case valid prefixes and names`() = runTest { + val edgeCaseValidMeta = buildMap { + put("a/", "single-char-prefix-empty-name") // empty name is allowed + put("a1-b2/test", "alphanumeric-hyphen-prefix") + put("long.domain.name.here/config", "long-prefix") + put("x/a", "minimal-valid-key") + put("test123", "alphanumeric-name-only") + } + + val result = runCatching { + client.callTool("test-tool", emptyMap(), edgeCaseValidMeta) + } + + assertTrue(result.isSuccess, "Edge case valid meta keys should be accepted") + } + + @Test + fun `should reject mcp reserved prefix`() = runTest { + val invalidMeta = mapOf("mcp/internal" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject modelcontextprotocol reserved prefix`() = runTest { + val invalidMeta = mapOf("modelcontextprotocol/config" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject nested reserved prefixes`() = runTest { + val invalidKeys = listOf( + "api.mcp.io/setting", + "com.modelcontextprotocol.test/value", + "example.mcp/data", + "subdomain.mcp.com/config", + "app.modelcontextprotocol.dev/setting", + "test.mcp/value", + "service.modelcontextprotocol/data", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject nested reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject case-insensitive reserved prefixes`() = runTest { + val invalidKeys = listOf( + "MCP/internal", + "Mcp/config", + "mCp/setting", + "MODELCONTEXTPROTOCOL/data", + "ModelContextProtocol/value", + "modelContextProtocol/test", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject case-insensitive reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject invalid key formats`() = runTest { + val invalidKeys = listOf( + "", // empty key - not allowed at key level + "/invalid", // starts with slash + "-invalid", // starts with hyphen + ".invalid", // starts with dot + "in valid", // contains space + "api../test", // consecutive dots + "api./test", // label ends with dot + ) + + invalidKeys.forEach { key -> + assertFailsWith( + message = "Should reject invalid key format: '$key'", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + } + } + + @Test + fun `should convert various data types to JSON correctly`() = runTest { + val complexMeta = createComplexMetaData() + + val result = runCatching { + client.callTool( + "test-tool", + emptyMap(), + complexMeta, + ) + } + + assertTrue(result.isSuccess, "Complex data type conversion should not throw exceptions") + + mockTransport.lastJsonRpcRequest()?.let { request -> + assertEquals("tools/call", request.method) + val params = request.params as JsonObject + assertTrue(params.containsKey("_meta"), "Request should contain _meta field") + } + } + + @Test + fun `should handle nested map structures correctly`() = runTest { + val nestedMeta = buildNestedConfiguration() + + val result = runCatching { + client.callTool("test-tool", emptyMap(), nestedMeta) + } + + assertTrue(result.isSuccess) + + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.containsKey("config")) + } + } + + @Test + fun `should include empty meta object when meta parameter not provided`() = runTest { + client.callTool("test-tool", mapOf("arg" to "value")) + + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.isEmpty(), "Meta field should be empty when not provided") + } + } + + private fun createComplexMetaData(): Map = buildMap { + put("string", "text") + put("number", 42) + put("boolean", true) + put("null_value", null) + put("list", listOf(1, 2, 3)) + put("map", mapOf("nested" to "value")) + put("enum", "STRING") + put("int_array", intArrayOf(1, 2, 3)) + } + + private fun buildNestedConfiguration(): Map = buildMap { + put( + "config", + buildMap { + put( + "database", + buildMap { + put("host", "localhost") + put("port", 5432) + }, + ) + put("features", listOf("feature1", "feature2")) + }, + ) + } +} + +suspend fun OldSchemaMockTransport.lastJsonRpcRequest(): JSONRPCRequest? = + getSentMessages().lastOrNull() as? JSONRPCRequest diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockTransport.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockTransport.kt new file mode 100644 index 00000000..da813c7b --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockTransport.kt @@ -0,0 +1,94 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +class OldSchemaMockTransport : Transport { + private val _sentMessages = mutableListOf() + private val _receivedMessages = mutableListOf() + private val mutex = Mutex() + + suspend fun getSentMessages() = mutex.withLock { _sentMessages.toList() } + suspend fun getReceivedMessages() = mutex.withLock { _receivedMessages.toList() } + + private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseBlock: (() -> Unit)? = null + private var onErrorBlock: ((Throwable) -> Unit)? = null + + override suspend fun start() = Unit + + override suspend fun send(message: JSONRPCMessage) { + mutex.withLock { + _sentMessages += message + } + + // Auto-respond to initialization and tool calls + when (message) { + is JSONRPCRequest -> { + when (message.method) { + "initialize" -> { + val initResponse = JSONRPCResponse( + id = message.id, + result = InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ), + ) + onMessageBlock?.invoke(initResponse) + } + + "tools/call" -> { + val toolResponse = JSONRPCResponse( + id = message.id, + result = CallToolResult( + content = listOf(), + isError = false, + ), + ) + onMessageBlock?.invoke(toolResponse) + } + } + } + + else -> { + // Handle other message types if needed + } + } + } + + override suspend fun close() { + onCloseBlock?.invoke() + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageBlock = { message -> + mutex.withLock { + _receivedMessages += message + } + block(message) + } + } + + override fun onClose(block: () -> Unit) { + onCloseBlock = block + } + + override fun onError(block: (Throwable) -> Unit) { + onErrorBlock = block + } + + fun setupInitializationResponse() { + // This method helps set up the mock for proper initialization + } +} diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaStreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaStreamableHttpClientTransportTest.kt new file mode 100644 index 00000000..c5be21e8 --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaStreamableHttpClientTransportTest.kt @@ -0,0 +1,422 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.MockRequestHandler +import io.ktor.client.engine.mock.respond +import io.ktor.client.plugins.sse.SSE +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteReadChannel +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlin.test.Ignore +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class OldSchemaStreamableHttpClientTransportTest { + + private fun createTransport(handler: MockRequestHandler): StreamableHttpClientTransport { + val mockEngine = MockEngine(handler) + val httpClient = HttpClient(mockEngine) { + install(SSE) { + reconnectionTime = 1.seconds + } + } + + return StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp") + } + + @Test + fun testSendJsonRpcMessage() = runTest { + val message = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "test", + params = buildJsonObject { }, + ) + + val transport = createTransport { request -> + assertEquals(HttpMethod.Post, request.method) + assertEquals("http://localhost:8080/mcp", request.url.toString()) + assertEquals(ContentType.Application.Json, request.body.contentType) + + val body = (request.body as TextContent).text + val decodedMessage = McpJson.decodeFromString(body) + assertEquals(message, decodedMessage) + + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + + transport.start() + transport.send(message) + transport.close() + } + + @Test + fun testStoreSessionId() = runTest { + val initMessage = JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "initialize", + params = buildJsonObject { + put( + "clientInfo", + buildJsonObject { + put("name", JsonPrimitive("test-client")) + put("version", JsonPrimitive("1.0")) + }, + ) + put("protocolVersion", JsonPrimitive("2025-06-18")) + }, + ) + + val transport = createTransport { request -> + when (val msg = McpJson.decodeFromString((request.body as TextContent).text)) { + is JSONRPCRequest if msg.method == "initialize" -> respond( + content = "", + status = HttpStatusCode.OK, + headers = headersOf("mcp-session-id", "test-session-id"), + ) + + is JSONRPCNotification if msg.method == "test" -> { + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + + else -> error("Unexpected message: $msg") + } + } + + transport.start() + transport.send(initMessage) + + assertEquals("test-session-id", transport.sessionId) + + transport.send(JSONRPCNotification(method = "test")) + + transport.close() + } + + @Test + fun testTerminateSession() = runTest { +// transport.sessionId = "test-session-id" + + val transport = createTransport { request -> + assertEquals(HttpMethod.Delete, request.method) + assertEquals("test-session-id", request.headers["mcp-session-id"]) + respond( + content = "", + status = HttpStatusCode.OK, + ) + } + + transport.start() + transport.terminateSession() + + assertNull(transport.sessionId) + transport.close() + } + + @Test + fun testTerminateSessionHandle405() = runTest { +// transport.sessionId = "test-session-id" + + val transport = createTransport { request -> + assertEquals(HttpMethod.Delete, request.method) + respond( + content = "", + status = HttpStatusCode.MethodNotAllowed, + ) + } + + transport.start() + // Should not throw for 405 + transport.terminateSession() + + // Session ID should still be cleared + assertNull(transport.sessionId) + transport.close() + } + + @Test + fun testProtocolVersionHeader() = runTest { + val transport = createTransport { request -> + assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + transport.protocolVersion = "2025-06-18" + + transport.start() + transport.send(JSONRPCNotification(method = "test")) + transport.close() + } + + // Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + @Ignore + @Test + fun testNotificationSchemaE2E() = runTest { + val receivedMessages = mutableListOf() + var sseStarted = false + + val transport = createTransport { request -> + when (request.method) { + HttpMethod.Post if request.body.toString().contains("notifications/initialized") -> { + respond( + content = "", + status = HttpStatusCode.Accepted, + headers = headersOf("mcp-session-id", "notification-test-session"), + ) + } + + // Handle SSE connection + HttpMethod.Get -> { + sseStarted = true + val sseContent = buildString { + // Server sends various notifications + appendLine("event: message") + appendLine("id: 1") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""", + ) + appendLine() + + appendLine("event: message") + appendLine("id: 2") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resources/list_changed"}""") + appendLine() + + appendLine("event: message") + appendLine("id: 3") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/tools/list_changed"}""") + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } + + // Handle regular notifications + HttpMethod.Post -> { + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + + else -> respond("", HttpStatusCode.OK) + } + } + + transport.onMessage { message -> + receivedMessages.add(message) + } + + transport.start() + + // Test 1: Send initialized notification to trigger SSE + val initializedNotification = JSONRPCNotification( + method = "notifications/initialized", + params = buildJsonObject { + put("protocolVersion", JsonPrimitive("1.0")) + put( + "capabilities", + buildJsonObject { + put("tools", JsonPrimitive(true)) + put("resources", JsonPrimitive(true)) + }, + ) + }, + ) + + transport.send(initializedNotification) + + // Verify SSE was triggered + assertTrue(sseStarted, "SSE should start after initialized notification") + + // Test 2: Verify received notifications + assertEquals(3, receivedMessages.size) + assertTrue(receivedMessages.all { it is JSONRPCNotification }) + + val notifications = receivedMessages.filterIsInstance() + + // Verify progress notification + val progressNotif = notifications[0] + assertEquals("notifications/progress", progressNotif.method) + val progressParams = progressNotif.params as JsonObject + assertEquals("upload-123", (progressParams["progressToken"] as JsonPrimitive).content) + assertEquals(50, (progressParams["progress"] as JsonPrimitive).content.toInt()) + + // Verify list changed notifications + assertEquals("notifications/resources/list_changed", notifications[1].method) + assertEquals("notifications/tools/list_changed", notifications[2].method) + + // Test 3: Send various client notifications + val clientNotifications = listOf( + JSONRPCNotification( + method = "notifications/progress", + params = buildJsonObject { + put("progressToken", JsonPrimitive("download-456")) + put("progress", JsonPrimitive(75)) + }, + ), + JSONRPCNotification( + method = "notifications/cancelled", + params = buildJsonObject { + put("requestId", JsonPrimitive("req-789")) + put("reason", JsonPrimitive("user_cancelled")) + }, + ), + JSONRPCNotification( + method = "notifications/message", + params = buildJsonObject { + put("level", JsonPrimitive("info")) + put("message", JsonPrimitive("Operation completed")) + put( + "data", + buildJsonObject { + put("duration", JsonPrimitive(1234)) + }, + ) + }, + ), + ) + + // Send all client notifications + clientNotifications.forEach { notification -> + transport.send(notification) + } + + // Verify session ID is maintained + assertEquals("notification-test-session", transport.sessionId) + transport.close() + } + + // Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + @Ignore + @Test + fun testNotificationWithResumptionToken() = runTest { + var resumptionTokenReceived: String? = null + var lastEventIdSent: String? = null + + val transport = createTransport { request -> + // Capture Last-Event-ID header + lastEventIdSent = request.headers["Last-Event-ID"] + + when (request.method) { + HttpMethod.Get -> { + val sseContent = buildString { + appendLine("event: message") + appendLine("id: resume-100") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"$lastEventIdSent"}}""", + ) + appendLine() + } + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } + + else -> respond("", HttpStatusCode.Accepted) + } + } + + transport.start() + + // Send notification with resumption token + transport.send( + message = JSONRPCNotification( + method = "notifications/test", + params = buildJsonObject { + put("data", JsonPrimitive("test-data")) + }, + ), + resumptionToken = "previous-token-99", + onResumptionToken = { token -> + resumptionTokenReceived = token + }, + ) + + // Wait for response + delay(1.seconds) + + // Verify resumption token was sent in header + assertEquals("previous-token-99", lastEventIdSent) + + // Verify new resumption token was received + assertEquals("resume-100", resumptionTokenReceived) + transport.close() + } + + @Test + fun testClientConnectWithInvalidJson() = runTest { + // Transport under test: respond with invalid JSON for the initialize request + val transport = createTransport { _ -> + respond( + "this is not valid json", + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val client = Client( + clientInfo = Implementation( + name = "test-client", + version = "1.0", + ), + ) + + try { + // Real time-keeping is needed; otherwise Protocol will always throw TimeoutCancellationException in tests + assertFailsWith( + message = "Expected client.connect to fail on invalid JSON response", + ) { + withContext(Dispatchers.Default.limitedParallelism(1)) { + withTimeout(5.seconds) { + client.connect(transport) + } + } + } + } finally { + transport.close() + } + } +} diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt index a72bd853..fea1de19 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt @@ -12,12 +12,12 @@ import io.ktor.http.HttpStatusCode import io.ktor.http.content.TextContent import io.ktor.http.headersOf import io.ktor.utils.io.ByteReadChannel -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification -import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest -import io.modelcontextprotocol.kotlin.sdk.RequestId -import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.test.runTest diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt index 77ba553e..6f1e30fc 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt @@ -12,7 +12,7 @@ import org.junit.jupiter.api.TestInstance internal abstract class AbstractStreamableHttpClientTest { // start mokksy on random port - protected val mockMcp: MockMcp = MockMcp(verbose = true) + protected val mockMcp: OldSchemaMockMcp = OldSchemaMockMcp(verbose = true) @AfterEach fun afterEach() { diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt new file mode 100644 index 00000000..af2110a0 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt @@ -0,0 +1,229 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import dev.mokksy.mokksy.BuildingStep +import dev.mokksy.mokksy.Mokksy +import dev.mokksy.mokksy.StubConfiguration +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.sse.ServerSentEvent +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.putJsonObject + +const val MCP_SESSION_ID_HEADER = "Mcp-Session-Id" + +internal class MockMcp(verbose: Boolean = false) { + + private val mokksy: Mokksy = Mokksy(verbose = verbose) + + fun checkForUnmatchedRequests() { + mokksy.checkForUnmatchedRequests() + } + + val url = "${mokksy.baseUrl()}/mcp" + + @Suppress("LongParameterList") + fun onInitialize( + clientName: String? = null, + sessionId: String, + protocolVersion: String = "2025-03-26", + serverName: String = "Mock MCP Server", + serverVersion: String = "1.0.0", + capabilities: JsonObject = buildJsonObject { + putJsonObject("tools") { + put("listChanged", JsonPrimitive(false)) + } + }, + ) { + val predicates = if (clientName != null) { + arrayOf<(JSONRPCRequest?) -> Boolean>({ + it?.params?.jsonObject + ?.get("clientInfo")?.jsonObject + ?.get("name")?.jsonPrimitive + ?.contentOrNull == clientName + }) + } else { + emptyArray() + } + + handleWithResult( + jsonRpcMethod = "initialize", + sessionId = sessionId, + bodyPredicates = predicates, + // language=json + result = """ + { + "capabilities": $capabilities, + "protocolVersion": "$protocolVersion", + "serverInfo": { + "name": "$serverName", + "version": "$serverVersion" + }, + "_meta": { + "foo": "bar" + } + } + """.trimIndent(), + ) + } + + fun onJSONRPCRequest( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + ): BuildingStep = mokksy.method( + configuration = StubConfiguration(removeAfterMatch = true), + httpMethod = httpMethod, + requestType = JSONRPCRequest::class, + ) { + path("/mcp") + expectedSessionId?.let { + containsHeader(MCP_SESSION_ID_HEADER, it) + } + bodyMatchesPredicate( + description = "JSON-RPC version is '2.0'", + predicate = + { + it!!.jsonrpc == "2.0" + }, + ) + bodyMatchesPredicate( + description = "JSON-RPC Method should be '$jsonRpcMethod'", + predicate = + { + it!!.method == jsonRpcMethod + }, + ) + bodyPredicates.forEach { predicate -> + bodyMatchesPredicate(predicate = { predicate.invoke(it!!) }) + } + } + + @Suppress("LongParameterList") + fun handleWithResult( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + sessionId: String, + contentType: ContentType = ContentType.Application.Json, + statusCode: HttpStatusCode = HttpStatusCode.OK, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + result: () -> JsonObject, + ) { + onJSONRPCRequest( + httpMethod = httpMethod, + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + bodyPredicates = bodyPredicates, + ) respondsWith { + val requestId = when (request.body.id) { + is RequestId.NumberId -> (request.body.id as RequestId.NumberId).value.toString() + is RequestId.StringId -> "\"${(request.body.id as RequestId.StringId).value}\"" + } + val resultObject = result.invoke() + // language=json + body = """ + { + "jsonrpc": "2.0", + "id": $requestId, + "result": $resultObject + } + """.trimIndent() + this.contentType = contentType + headers += MCP_SESSION_ID_HEADER to sessionId + httpStatus = statusCode + } + } + + @Suppress("LongParameterList") + fun handleWithResult( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + sessionId: String, + contentType: ContentType = ContentType.Application.Json, + statusCode: HttpStatusCode = HttpStatusCode.OK, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + result: String, + ) { + handleWithResult( + httpMethod = httpMethod, + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + sessionId = sessionId, + contentType = contentType, + statusCode = statusCode, + bodyPredicates = bodyPredicates, + result = { + Json.parseToJsonElement(result).jsonObject + }, + ) + } + + @Suppress("LongParameterList") + fun handleJSONRPCRequest( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + sessionId: String, + contentType: ContentType = ContentType.Application.Json, + statusCode: HttpStatusCode = HttpStatusCode.OK, + vararg bodyPredicates: (JSONRPCRequest?) -> Boolean, + bodyBuilder: () -> String = { "" }, + ) { + onJSONRPCRequest( + httpMethod = httpMethod, + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + bodyPredicates = bodyPredicates, + ) respondsWith { + body = bodyBuilder.invoke() + this.contentType = contentType + headers += MCP_SESSION_ID_HEADER to sessionId + httpStatus = statusCode + } + } + + fun onSubscribe(httpMethod: HttpMethod = HttpMethod.Post, sessionId: String): BuildingStep = mokksy.method( + httpMethod = httpMethod, + name = "MCP GETs", + requestType = Any::class, + ) { + path("/mcp") + containsHeader(MCP_SESSION_ID_HEADER, sessionId) + containsHeader("Accept", "application/json,text/event-stream") + containsHeader("Cache-Control", "no-store") + } + + fun handleSubscribeWithGet(sessionId: String, block: () -> Flow) { + onSubscribe( + httpMethod = HttpMethod.Get, + sessionId = sessionId, + ) respondsWithSseStream { + headers += MCP_SESSION_ID_HEADER to sessionId + this.flow = block.invoke() + } + } + + fun mockUnsubscribeRequest(sessionId: String) { + mokksy.delete( + configuration = StubConfiguration(removeAfterMatch = true), + requestType = JSONRPCRequest::class, + ) { + path("/mcp") + containsHeader(MCP_SESSION_ID_HEADER, sessionId) + } respondsWith { + body = null + } + } +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockMcp.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockMcp.kt index eb2760b3..5d1a1441 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockMcp.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaMockMcp.kt @@ -19,8 +19,6 @@ import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.putJsonObject -const val MCP_SESSION_ID_HEADER = "Mcp-Session-Id" - /** * High-level helper for simulating an MCP server over Streaming HTTP transport with Server-Sent Events (SSE), * built on top of an HTTP server using the [Mokksy](https://mokksy.dev) library. @@ -30,7 +28,7 @@ const val MCP_SESSION_ID_HEADER = "Mcp-Session-Id" * @param verbose Whether to print detailed logs. Defaults to `false`. * @author Konstantin Pavlov */ -internal class MockMcp(verbose: Boolean = false) { +internal class OldSchemaMockMcp(verbose: Boolean = false) { private val mokksy: Mokksy = Mokksy(verbose = verbose) diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt new file mode 100644 index 00000000..37f5a307 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt @@ -0,0 +1,214 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.kotest.matchers.collections.shouldContain +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.sse.ServerSentEvent +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.Tool +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.junit.jupiter.api.TestInstance +import kotlin.test.Test +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalUuidApi::class) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Suppress("LongMethod") +internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { + + @Test + fun `test streamableHttpClient`() = runBlocking { + val client = Client( + clientInfo = Implementation( + name = "client1", + version = "1.0.0", + ), + options = ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + val sessionId = Uuid.random().toString() + + mockMcp.onInitialize( + clientName = "client1", + sessionId = sessionId, + ) + + mockMcp.handleJSONRPCRequest( + jsonRpcMethod = "notifications/initialized", + expectedSessionId = sessionId, + sessionId = sessionId, + statusCode = HttpStatusCode.Accepted, + ) + + mockMcp.handleSubscribeWithGet(sessionId) { + flow { + delay(500.milliseconds) + emit( + ServerSentEvent( + event = "message", + id = "1", + data = @Suppress("MaxLineLength") + //language=json + """{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""", + ), + ) + delay(200.milliseconds) + emit( + ServerSentEvent( + data = @Suppress("MaxLineLength") + //language=json + """{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""", + ), + ) + } + } + + // TODO: how to get notifications via Client API? + + mockMcp.handleWithResult( + jsonRpcMethod = "tools/list", + sessionId = sessionId, + // language=json + result = """ + { + "tools": [ + { + "name": "get_weather", + "title": "Weather Information Provider", + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + }, + "outputSchema": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "description": "Temperature, Celsius" + } + }, + "required": ["temperature"] + }, + "_meta": {} + } + ] + } + """.trimIndent(), + ) + + connect(client) + + val listToolsResult = client.listTools() + + listToolsResult.tools shouldContain Tool( + name = "get_weather", + title = "Weather Information Provider", + description = "Get current weather information for a location", + inputSchema = ToolSchema( + properties = buildJsonObject { + putJsonObject("location") { + put("type", "string") + put("description", "City name or zip code") + } + }, + required = listOf("location"), + ), + outputSchema = ToolSchema( + properties = buildJsonObject { + putJsonObject("temperature") { + put("type", "number") + put("description", "Temperature, Celsius") + } + }, + required = listOf("temperature"), + ), + annotations = null, + meta = EmptyJsonObject, + ) + + mockMcp.mockUnsubscribeRequest(sessionId = sessionId) + + client.close() + } + + @Test + fun `handle MethodNotAllowed`() = runBlocking { + checkSupportNonStreamingResponse( + ContentType.Text.EventStream, + HttpStatusCode.MethodNotAllowed, + ) + } + + @Test + fun `handle non-streaming response`() = runBlocking { + checkSupportNonStreamingResponse( + ContentType.Application.Json, + HttpStatusCode.OK, + ) + } + + private suspend fun checkSupportNonStreamingResponse(contentType: ContentType, statusCode: HttpStatusCode) { + val sessionId = "SID_${Uuid.random().toHexString()}" + val clientName = "client-${Uuid.random().toHexString()}" + val client = Client( + clientInfo = Implementation(name = clientName, version = "1.0.0"), + options = ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + mockMcp.onInitialize(clientName = clientName, sessionId = sessionId) + + mockMcp.handleJSONRPCRequest( + jsonRpcMethod = "notifications/initialized", + expectedSessionId = sessionId, + sessionId = sessionId, + statusCode = HttpStatusCode.Accepted, + ) + + mockMcp.onSubscribe( + httpMethod = HttpMethod.Get, + sessionId = sessionId, + ) respondsWith { + headers += MCP_SESSION_ID_HEADER to sessionId + body = null + httpStatus = statusCode + this.contentType = contentType + } + + mockMcp.handleWithResult(jsonRpcMethod = "ping", sessionId = sessionId) { + buildJsonObject {} + } + + mockMcp.mockUnsubscribeRequest(sessionId = sessionId) + + connect(client) + + delay(1.seconds) + + client.ping() // connection is still alive + + client.close() + } +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/OldSchemaAudioContentSerializationTest.kt similarity index 94% rename from kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt rename to kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/OldSchemaAudioContentSerializationTest.kt index 247388f7..738539f8 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/OldSchemaAudioContentSerializationTest.kt @@ -5,7 +5,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.McpJson import kotlin.test.Test import kotlin.test.assertEquals -class AudioContentSerializationTest { +class OldSchemaAudioContentSerializationTest { private val audioContentJson = """ { diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/models/ProgressNotificationsTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/models/ProgressNotificationsTest.kt new file mode 100644 index 00000000..9b52ef9a --- /dev/null +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/models/ProgressNotificationsTest.kt @@ -0,0 +1,72 @@ +package io.modelcontextprotocol.kotlin.sdk.models + +import io.kotest.matchers.shouldBe +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotification +import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import kotlin.test.Test + +class ProgressNotificationsTest { + + /** + * https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress#progress-flow + */ + @Test + fun `Read ProgressNotifications with string token`() { + //language=json + val json = """ + { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progressToken": "abc123", + "progress": 50, + "total": 100, + "message": "Reticulating splines..." + } + } + """.trimIndent() + + val result = McpJson.decodeFromString(json) + + result shouldBe ProgressNotification( + params = ProgressNotificationParams( + progressToken = RequestId.StringId("abc123"), + progress = 50.0, + message = "Reticulating splines...", + total = 100.0, + ), + ) + } + + /** + * https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress#progress-flow + */ + @Test + fun `Read ProgressNotifications with integer token`() { + //language=json + val json = """ + { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progressToken": 100500, + "progress": 50, + "total": 100, + "message": "Reticulating splines..." + } + } + """.trimIndent() + + val result = McpJson.decodeFromString(json) + result shouldBe ProgressNotification( + params = ProgressNotificationParams( + progressToken = RequestId.NumberId(100500), + progress = 50.0, + message = "Reticulating splines...", + total = 100.0, + ), + ) + } +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaReadBufferTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaReadBufferTest.kt new file mode 100644 index 00000000..0311c139 --- /dev/null +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaReadBufferTest.kt @@ -0,0 +1,62 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.ktor.utils.io.charsets.Charsets +import io.ktor.utils.io.core.toByteArray +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class OldSchemaReadBufferTest { + private val testMessage: JSONRPCMessage = JSONRPCNotification(method = "foobar") + + private val json = Json { + ignoreUnknownKeys = true + encodeDefaults = true + } + + @Test + fun `should have no messages after initialization`() { + val readBuffer = ReadBuffer() + assertNull(readBuffer.readMessage()) + } + + @Test + fun `should only yield a message after a newline`() { + val readBuffer = ReadBuffer() + + // Append message without a newline + val messageBytes = json.encodeToString(testMessage).encodeToByteArray() + readBuffer.append(messageBytes) + assertNull(readBuffer.readMessage()) + + // Append a newline and verify message is now available + readBuffer.append("\n".encodeToByteArray()) + assertEquals(testMessage, readBuffer.readMessage()) + assertNull(readBuffer.readMessage()) + } + + @Test + fun `skip empty line`() { + val readBuffer = ReadBuffer() + readBuffer.append("\n".toByteArray()) + assertNull(readBuffer.readMessage()) + } + + @Test + fun `should be reusable after clearing`() { + val readBuffer = ReadBuffer() + + readBuffer.append("foobar".toByteArray(Charsets.UTF_8)) + readBuffer.clear() + assertNull(readBuffer.readMessage()) + + val messageJson = serializeMessage(testMessage) + readBuffer.append(messageJson.toByteArray(Charsets.UTF_8)) + readBuffer.append("\n".toByteArray(Charsets.UTF_8)) + val message = readBuffer.readMessage() + assertEquals(testMessage, message) + } +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt index 8e6f4f65..a49ff3df 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.ktor.utils.io.charsets.Charsets import io.ktor.utils.io.core.toByteArray -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import kotlinx.serialization.json.Json import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaStdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaStdioServerTransportTest.kt new file mode 100644 index 00000000..32e25baf --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaStdioServerTransportTest.kt @@ -0,0 +1,140 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer +import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage +import io.modelcontextprotocol.kotlin.sdk.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import java.io.ByteArrayOutputStream +import java.io.PipedInputStream +import java.io.PipedOutputStream +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class OldSchemaStdioServerTransportTest { + private lateinit var input: PipedInputStream + private lateinit var inputWriter: PipedOutputStream + private lateinit var outputBuffer: ReadBuffer + private lateinit var output: ByteArrayOutputStream + + // We'll store the wrapped streams that meet the constructor requirements + private lateinit var bufferedInput: Source + private lateinit var printOutput: Sink + + @BeforeTest + fun setUp() { + // Simulate an input stream that we can push data into using inputWriter. + input = PipedInputStream() + inputWriter = PipedOutputStream(input) + + outputBuffer = ReadBuffer() + + // A custom ByteArrayOutputStream that appends all written data into outputBuffer + output = object : ByteArrayOutputStream() { + override fun write(b: ByteArray, off: Int, len: Int) { + super.write(b, off, len) + outputBuffer.append(b.copyOfRange(off, off + len)) + } + } + + bufferedInput = input.asSource().buffered() + + printOutput = output.asSink().buffered() + } + + @Test + fun `should start then close cleanly`() { + runBlocking { + val server = StdioServerTransport(bufferedInput, printOutput) + server.onError { error -> + throw error + } + + var didClose = false + server.onClose { + didClose = true + } + + server.start() + assertFalse(didClose, "Should not have closed yet") + + server.close() + assertTrue(didClose, "Should have closed after calling close()") + } + } + + @Test + fun `should not read until started`() = runTest { + val server = StdioServerTransport(bufferedInput, printOutput) + server.onError { error -> + throw error + } + + var didRead = false + val readMessage = CompletableDeferred() + + server.onMessage { message -> + didRead = true + readMessage.complete(message) + } + + val message = PingRequest().toJSON() + + // Push a message before the server started + val serialized = serializeMessage(message) + inputWriter.write(serialized) + inputWriter.flush() + + assertFalse(didRead, "Should not have read message before start") + + server.start() + val received = readMessage.await() + assertEquals(message, received) + } + + @Test + fun `should read multiple messages`() = runTest { + val server = StdioServerTransport(bufferedInput, printOutput) + server.onError { error -> + throw error + } + + val messages = listOf( + PingRequest().toJSON(), + InitializedNotification().toJSON(), + ) + + val readMessages = mutableListOf() + val finished = CompletableDeferred() + + server.onMessage { message -> + readMessages.add(message) + if (message == messages[1]) { + finished.complete(Unit) + } + } + + // Push both messages before starting the server + for (m in messages) { + inputWriter.write(serializeMessage(m)) + } + inputWriter.flush() + + server.start() + finished.await() + + assertEquals(messages, readMessages) + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index be1e64e8..ff46263d 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -1,11 +1,11 @@ package io.modelcontextprotocol.kotlin.sdk.server -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.PingRequest import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage -import io.modelcontextprotocol.kotlin.sdk.toJSON +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.toJSON import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt new file mode 100644 index 00000000..d863e008 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -0,0 +1,1039 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.ServerSession +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult +import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.types.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.Root +import io.modelcontextprotocol.kotlin.sdk.types.RootsListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.Tool +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.cancel +import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import kotlin.coroutines.cancellation.CancellationException +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail + +class ClientTest { + @Test + fun `should initialize with matching protocol version`() = runTest { + var initialised = false + val clientTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + initialised = true + val result = InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + + _onMessage.invoke(response) + } + + override suspend fun close() { + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions( + capabilities = ClientCapabilities( + sampling = EmptyJsonObject, + ), + ), + ) + + client.connect(clientTransport) + assertTrue(initialised) + } + + @Test + fun `should initialize with supported older protocol version`() = runTest { + val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1] + val clientTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + + val result = InitializeResult( + protocolVersion = oldVersion, + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + _onMessage.invoke(response) + } + + override suspend fun close() { + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions( + capabilities = ClientCapabilities( + sampling = EmptyJsonObject, + ), + ), + ) + + client.connect(clientTransport) + assertEquals( + Implementation("test", "1.0"), + client.serverVersion, + ) + } + + @Test + fun `should reject unsupported protocol version`() = runTest { + var closed = false + val clientTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + + val result = InitializeResult( + protocolVersion = "invalid-version", + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + + _onMessage.invoke(response) + } + + override suspend fun close() { + closed = true + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions(), + ) + + assertFailsWith("Server's protocol version is not supported: invalid-version") { + client.connect(clientTransport) + } + + assertTrue(closed) + } + + @Test + fun `should reject due to non cancellation exception`() = runTest { + var closed = false + val failingTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw IllegalStateException("Test error") + } + + override suspend fun close() { + closed = true + } + } + + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions(), + ) + + val exception = assertFailsWith { + client.connect(failingTransport) + } + + assertEquals("Error connecting to transport: Test error", exception.message) + + assertTrue(closed) + } + + @Test + fun `should respect server capabilities`() = runTest { + val serverOptions = ServerOptions( + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null), + ), + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions, + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null), + ), + serverInfo = Implementation(name = "test", version = "1.0"), + ) + } + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + ListResourcesResult(resources = emptyList(), nextCursor = null) + } + + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + ListToolsResult(tools = emptyList(), nextCursor = null) + } + // Server supports resources and tools, but not prompts + val caps = client.serverCapabilities + assertEquals(ServerCapabilities.Resources(null, null), caps?.resources) + assertEquals(ServerCapabilities.Tools(null), caps?.tools) + assertTrue(caps?.prompts == null) // or check that prompts are absent + + // These should not throw + client.listResources() + client.listTools() + + // This should fail because prompts are not supported + val ex = assertFailsWith { + client.listPrompts() + } + assertTrue(ex.message?.contains("Server does not support prompts") == true) + } + + @Test + fun `should respect client notification capabilities`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions(capabilities = ServerCapabilities()), + ) + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + server.createSession(serverTransport) + println("Server connected") + }, + ).joinAll() + + // This should not throw because the client supports roots.listChanged + client.sendRootsListChanged() + + // Create a new client without the roots.listChanged capability + val clientWithoutCapability = Client( + clientInfo = Implementation(name = "test client without capability", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(), + // enforceStrictCapabilities = true // TODO() + ), + ) + + clientWithoutCapability.connect(clientTransport) + // Using the same transport pair might not be realistic - in a real scenario you'd create another pair. + // Adjust if necessary. + + // This should fail + val ex = assertFailsWith { + clientWithoutCapability.sendRootsListChanged() + } + assertTrue(ex.message?.startsWith("Client does not support") == true) + } + + @Test + fun `should respect server notification capabilities`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions( + capabilities = ServerCapabilities( + logging = EmptyJsonObject, + resources = ServerCapabilities.Resources(listChanged = true, subscribe = null), + ), + ), + ) + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + // These should not throw + val jsonObject = buildJsonObject { + put("name", "John") + put("age", 30) + put("isStudent", false) + } + serverSession.sendLoggingMessage( + LoggingMessageNotification( + params = LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = jsonObject, + ), + ), + ) + serverSession.sendResourceListChanged() + + // This should fail because the server doesn't have the tools capability + val ex = assertFailsWith { + serverSession.sendToolListChanged() + } + assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true) + } + + @Test + fun `should handle client cancelling a request`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions( + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + listChanged = null, + subscribe = null, + ), + ), + ), + ) + + val def = CompletableDeferred() + val defTimeOut = CompletableDeferred() + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate delay + def.complete(Unit) + try { + delay(1000) + } catch (e: CancellationException) { + defTimeOut.complete(Unit) + throw e + } + ListResourcesResult(resources = emptyList()) + fail("Shouldn't have been called") + } + + val defCancel = CompletableDeferred() + val job = launch { + try { + client.listResources() + } catch (e: CancellationException) { + defCancel.complete(Unit) + assertEquals("Cancelled by test", e.message) + } + } + def.await() + runCatching { job.cancel("Cancelled by test") } + defCancel.await() + defTimeOut.await() + } + + @Test + fun `should handle request timeout`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions( + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + listChanged = null, + subscribe = null, + ), + ), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate a delayed response + // Wait ~100ms unless canceled + try { + withTimeout(100L) { + // Just delay here, if timeout is 0 on the client side, this won't return in time + delay(100) + } + } catch (_: Exception) { + // If aborted, just rethrow or return early + } + ListResourcesResult(resources = emptyList()) + } + + // Request with 1 msec timeout should fail immediately + val ex = assertFailsWith { + withTimeout(1) { + client.listResources() + } + } + assertTrue(ex is TimeoutCancellationException) + } + + @Test + fun `should only allow setRequestHandler for declared capabilities`() = runTest { + val client = Client( + clientInfo = Implementation( + name = "test client", + version = "1.0", + ), + options = ClientOptions( + capabilities = ClientCapabilities( + sampling = EmptyJsonObject, + ), + ), + ) + + client.setRequestHandler(Method.Defined.SamplingCreateMessage) { _, _ -> + CreateMessageResult( + model = "test-model", + role = Role.Assistant, + content = TextContent( + text = "Test response", + ), + ) + } + + assertFailsWith("Client does not support roots capability (required for RootsList)") { + client.setRequestHandler(Method.Defined.RootsList) { _, _ -> null } + } + } + + @Test + fun `JSONRPCRequest with ToolsList method and default params returns list of tools`() = runTest { + val serverOptions = ServerOptions( + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(null), + ), + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions, + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ), + ) + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null), + ), + serverInfo = Implementation(name = "test", version = "1.0"), + ) + } + + val serverListToolsResult = ListToolsResult( + tools = listOf( + Tool( + name = "testTool", + title = "testTool title", + description = "testTool description", + annotations = null, + inputSchema = ToolSchema(), + outputSchema = null, + ), + ), + nextCursor = null, + ) + + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + serverListToolsResult + } + + val serverCapabilities = client.serverCapabilities + assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) + + val request = JSONRPCRequest( + method = Method.Defined.ToolsList.value, + ) + clientTransport.send(request) + + assertIs(receivedMessage) + val receivedAsResponse = receivedMessage as JSONRPCResponse + assertEquals(request.id, receivedAsResponse.id) + assertEquals(request.jsonrpc, receivedAsResponse.jsonrpc) + assertEquals(serverListToolsResult, receivedAsResponse.result) + } + + @Test + fun `listRoots returns list of roots`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null), + ), + ), + ) + + val clientRoots = listOf( + Root(uri = "file:///test-root", name = "testRoot"), + ) + + client.addRoots(clientRoots) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities(), + ), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + val clientCapabilities = serverSession.clientCapabilities + assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots) + + val listRootsResult = serverSession.listRoots() + + assertEquals(listRootsResult.roots, clientRoots) + } + + @Test + fun `addRoot should throw when roots capability is not supported`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + // Verify that adding a root throws an exception + val exception = assertFailsWith { + client.addRoot(uri = "file:///test-root1", name = "testRoot1") + } + assertEquals("Client does not support roots capability.", exception.message) + } + + @Test + fun `removeRoot should throw when roots capability is not supported`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + // Verify that removing a root throws an exception + val exception = assertFailsWith { + client.removeRoot(uri = "file:///test-root1") + } + assertEquals("Client does not support roots capability.", exception.message) + } + + @Test + fun `removeRoot should remove a root`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null), + ), + ), + ) + + // Add some roots + client.addRoots( + listOf( + Root(uri = "file:///test-root1", name = "testRoot1"), + Root(uri = "file:///test-root2", name = "testRoot2"), + ), + ) + + // Remove a root + val result = client.removeRoot("file:///test-root1") + + // Verify the root was removed + assertTrue(result, "Root should be removed successfully") + } + + @Test + fun `removeRoots should remove multiple roots`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null), + ), + ), + ) + + // Add some roots + client.addRoots( + listOf( + Root(uri = "file:///test-root1", name = "testRoot1"), + Root(uri = "file:///test-root2", name = "testRoot2"), + ), + ) + + // Remove multiple roots + val result = client.removeRoots( + listOf("file:///test-root1", "file:///test-root2"), + ) + + // Verify the root was removed + assertEquals(2, result, "Both roots should be removed") + } + + @Test + fun `sendRootsListChanged should notify server`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities(), + ), + ) + + // Track notifications + var rootListChangedNotificationReceived = false + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + serverSession.setNotificationHandler( + Method.Defined.NotificationsRootsListChanged, + ) { + rootListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + + client.sendRootsListChanged() + + assertTrue( + rootListChangedNotificationReceived, + "Notification should be sent when sendRootsListChanged is called", + ) + } + + @Test + fun `should reject server elicitation when elicitation capability is not supported`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities(), + ), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + // Verify that creating an elicitation throws an exception + val exception = assertFailsWith { + serverSession.createElicitation( + message = "Please provide your GitHub username", + requestedSchema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + putJsonObject("name") { + put("type", "string") + } + }, + required = listOf("name"), + ), + ) + } + assertEquals( + "Client does not support elicitation (required for elicitation/create)", + exception.message, + ) + } + + @Test + fun `should handle logging setLevel request`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions( + capabilities = ServerCapabilities( + logging = EmptyJsonObject, + ), + ), + ) + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val receivedMessages = mutableListOf() + client.setNotificationHandler(Method.Defined.NotificationsMessage) { notification -> + receivedMessages.add(notification) + CompletableDeferred(Unit) + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + // Set logging level to warning + val minLevel = LoggingLevel.Warning + val result = client.setLoggingLevel(minLevel) + assertNull(result.meta) + + // Send messages of different levels + val testMessages = listOf( + LoggingLevel.Debug to "Debug - should be filtered", + LoggingLevel.Info to "Info - should be filtered", + LoggingLevel.Warning to "Warning - should pass", + LoggingLevel.Error to "Error - should pass", + ) + + testMessages.forEach { (level, message) -> + serverSession.sendLoggingMessage( + LoggingMessageNotification( + params = LoggingMessageNotificationParams( + level = level, + data = buildJsonObject { put("message", message) }, + ), + ), + ) + } + + delay(100) + + // Only warning and error should be received + assertEquals(2, receivedMessages.size, "Should receive only 2 messages (warning and error)") + + // Verify all received messages have severity >= minLevel + receivedMessages.forEach { message -> + val messageSeverity = message.params.level.ordinal + assertTrue( + messageSeverity >= minLevel.ordinal, + "Received message with level ${message.params.level} should have severity >= $minLevel", + ) + } + } + + @Test + fun `should handle server elicitation`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + elicitation = EmptyJsonObject, + ), + ), + ) + + val elicitationMessage = "Please provide your GitHub username" + val requestedSchema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + putJsonObject("name") { + put("type", "string") + } + }, + required = listOf("name"), + ) + + val elicitationResultAction = ElicitResult.Action.Accept + val elicitationResultContent = buildJsonObject { + put("name", "octocat") + } + + client.setElicitationHandler { request -> + assertEquals(elicitationMessage, request.params.message) + assertEquals(requestedSchema, request.params.requestedSchema) + + ElicitResult( + action = elicitationResultAction, + content = elicitationResultContent, + ) + } + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities(), + ), + ) + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.createSession(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + val result = serverSession.createElicitation( + message = elicitationMessage, + requestedSchema = requestedSchema, + ) + + assertEquals(elicitationResultAction, result.action) + assertEquals(elicitationResultContent, result.content) + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaSseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaSseTransportTest.kt new file mode 100644 index 00000000..3c2d75a7 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/OldSchemaSseTransportTest.kt @@ -0,0 +1,123 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.ktor.client.HttpClient +import io.ktor.server.application.install +import io.ktor.server.cio.CIO +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.modelcontextprotocol.kotlin.sdk.shared.BaseTransportTest +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Ignore +import kotlin.test.Test +import io.ktor.client.plugins.sse.SSE as ClientSSE +import io.ktor.server.sse.SSE as ServerSSE + +class OldSchemaSseTransportTest : BaseTransportTest() { + + private suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().first().port + + private lateinit var mcpServer: Server + + @BeforeTest + fun setUp() { + mcpServer = Server( + serverInfo = Implementation( + name = "test-server", + version = "1.0", + ), + options = ServerOptions(ServerCapabilities()), + ) + } + + @Test + @Ignore // Ignored because it doesn’t work with wasm/js in Ktor 3.2.3 + fun `should start then close cleanly`() = runTest { + val server = embeddedServer(CIO, port = 0) { + install(ServerSSE) + routing { + mcp { mcpServer } + } + }.startSuspend(wait = false) + + val actualPort = server.actualPort() + + val transport = HttpClient { + install(ClientSSE) + }.mcpSseTransport { + url { + host = "localhost" + this.port = actualPort + } + } + + try { + testTransportRead(transport) + } finally { + server.stopSuspend() + } + } + + @Ignore + @Test + fun `should read messages`() = runTest { + val server = embeddedServer(CIO, port = 0) { + install(ServerSSE) + routing { + mcp { mcpServer } + } + }.startSuspend(wait = false) + + val actualPort = server.actualPort() + + val transport = HttpClient { + install(ClientSSE) + }.mcpSseTransport { + url { + host = "localhost" + this.port = actualPort + } + } + + try { + testTransportRead(transport) + } finally { + server.stopSuspend() + } + } + + @Ignore + @Test + fun `test sse path not root path`() = runTest { + val server = embeddedServer(CIO, port = 0) { + install(ServerSSE) + routing { + mcp { mcpServer } + } + }.startSuspend(wait = false) + + val actualPort = server.actualPort() + + val transport = HttpClient { + install(ClientSSE) + }.mcpSseTransport { + url { + host = "localhost" + this.port = actualPort + pathSegments = listOf("sse") + } + } + + try { + testTransportRead(transport) + } finally { + server.stopSuspend() + } + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index e81332e6..d82ff47a 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -6,12 +6,12 @@ import io.ktor.server.cio.CIO import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.routing -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp import io.modelcontextprotocol.kotlin.sdk.shared.BaseTransportTest +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import kotlinx.coroutines.test.runTest import kotlin.test.BeforeTest import kotlin.test.Ignore diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt index 3e7cdb2a..d13f5848 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt @@ -1,9 +1,9 @@ package io.modelcontextprotocol.kotlin.sdk.integration -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport -import io.modelcontextprotocol.kotlin.sdk.toJSON +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.toJSON import kotlinx.coroutines.test.runTest import kotlin.test.BeforeTest import kotlin.test.Test diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/OldSchemaInMemoryTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/OldSchemaInMemoryTransportTest.kt new file mode 100644 index 00000000..5f31bc12 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/OldSchemaInMemoryTransportTest.kt @@ -0,0 +1,110 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.toJSON +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class OldSchemaInMemoryTransportTest { + private lateinit var clientTransport: InMemoryTransport + private lateinit var serverTransport: InMemoryTransport + + @BeforeTest + fun setUp() { + val (client, server) = InMemoryTransport.createLinkedPair() + clientTransport = client + serverTransport = server + } + + @Test + fun `should create linked pair`() { + assertNotNull(clientTransport) + assertNotNull(serverTransport) + } + + @Test + fun `should start without error`() = runTest { + clientTransport.start() + serverTransport.start() + // If no exception is thrown, the test passes + } + + @Test + fun `should send message from client to server`() = runTest { + val message = InitializedNotification() + + var receivedMessage: JSONRPCMessage? = null + serverTransport.onMessage { msg -> + receivedMessage = msg + } + + val rpcNotification = message.toJSON() + clientTransport.send(rpcNotification) + assertEquals(rpcNotification, receivedMessage) + } + + @Test + fun `should send message from server to client`() = runTest { + val message = InitializedNotification() + .toJSON() + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + serverTransport.send(message) + assertEquals(message, receivedMessage) + } + + @Test + fun `should handle close`() = runTest { + var clientClosed = false + var serverClosed = false + + clientTransport.onClose { + clientClosed = true + } + + serverTransport.onClose { + serverClosed = true + } + + clientTransport.close() + assertTrue(clientClosed) + assertTrue(serverClosed) + } + + @Test + fun `should throw error when sending after close`() = runTest { + clientTransport.close() + + assertFailsWith { + clientTransport.send( + InitializedNotification().toJSON(), + ) + } + } + + @Test + fun `should queue messages sent before start`() = runTest { + val message = InitializedNotification() + .toJSON() + + var receivedMessage: JSONRPCMessage? = null + serverTransport.onMessage { msg -> + receivedMessage = msg + } + + clientTransport.send(message) + serverTransport.start() + assertEquals(message, receivedMessage) + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 26404b62..1929facf 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -1,9 +1,9 @@ package io.modelcontextprotocol.kotlin.sdk.shared -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.PingRequest -import io.modelcontextprotocol.kotlin.sdk.toJSON +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.toJSON import kotlinx.coroutines.CompletableDeferred import kotlin.test.assertEquals import kotlin.test.assertFalse diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt index fd567c75..5c5fdd7c 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt @@ -1,6 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage /** * In-memory transport for creating clients and servers that talk to each other within the same process. diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaBaseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaBaseTransportTest.kt new file mode 100644 index 00000000..d2d0402e --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaBaseTransportTest.kt @@ -0,0 +1,63 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.test.fail + +abstract class OldSchemaBaseTransportTest { + + protected suspend fun testTransportOpenClose(transport: Transport) { + transport.onError { error -> + fail("Unexpected error: $error") + } + + var didClose = false + transport.onClose { didClose = true } + + transport.start() + assertFalse(didClose, "Transport should not be closed immediately after start") + + transport.close() + assertTrue(didClose, "Transport should be closed after close() call") + } + + protected suspend fun testTransportRead(transport: Transport) { + transport.onError { error -> + error.printStackTrace() + fail("Unexpected error: $error") + } + + val messages = listOf( + PingRequest().toJSON(), + InitializedNotification().toJSON(), + ) + + val readMessages = mutableListOf() + val finished = CompletableDeferred() + + transport.onMessage { message -> + readMessages.add(message) + if (message == messages.last()) { + finished.complete(Unit) + } + } + + transport.start() + + for (message in messages) { + transport.send(message) + } + + finished.await() + + assertEquals(messages, readMessages, "Assert messages received") + + transport.close() + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaInMemoryTransport.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaInMemoryTransport.kt new file mode 100644 index 00000000..e7889184 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/OldSchemaInMemoryTransport.kt @@ -0,0 +1,47 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage + +/** + * In-memory transport for creating clients and servers that talk to each other within the same process. + */ +class OldSchemaInMemoryTransport : AbstractTransport() { + private var otherTransport: OldSchemaInMemoryTransport? = null + private val messageQueue: MutableList = mutableListOf() + + /** + * Creates a pair of linked in-memory transports that can communicate with each other. + * One should be passed to a Client and one to a Server. + */ + companion object { + fun createLinkedPair(): Pair { + val clientTransport = OldSchemaInMemoryTransport() + val serverTransport = OldSchemaInMemoryTransport() + clientTransport.otherTransport = serverTransport + serverTransport.otherTransport = clientTransport + return Pair(clientTransport, serverTransport) + } + } + + override suspend fun start() { + // Process any messages that were queued before start was called + while (messageQueue.isNotEmpty()) { + messageQueue.removeFirstOrNull()?.let { message -> + _onMessage.invoke(message) // todo? + } + } + } + + override suspend fun close() { + val other = otherTransport + otherTransport = null + other?.close() + _onClose.invoke() + } + + override suspend fun send(message: JSONRPCMessage) { + val other = otherTransport ?: throw IllegalStateException("Not connected") + + other._onMessage.invoke(message) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt new file mode 100644 index 00000000..b66ac7bd --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -0,0 +1,694 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractPromptIntegrationTest : KotlinTestBase() { + + private val basicPromptName = "basic-prompt" + private val basicPromptDescription = "A basic prompt for testing" + + private val complexPromptName = "multimodal-prompt" + private val complexPromptDescription = "A prompt with multiple content types" + private val conversationPromptName = "conversation" + private val conversationPromptDescription = "A prompt with multiple messages and roles" + private val strictPromptName = "strict-prompt" + private val strictPromptDescription = "A prompt with required arguments" + + private val largePromptName = "large-prompt" + private val largePromptDescription = "A very large prompt for testing" + private val largePromptContent = "X".repeat(100_000) // 100KB of data + + private val specialCharsPromptName = "special-chars-prompt" + private val specialCharsPromptDescription = "A prompt with special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + // basic prompt with a name parameter + server.addPrompt( + name = basicPromptName, + description = basicPromptDescription, + arguments = listOf( + PromptArgument( + name = "name", + description = "The name to greet", + required = true, + ), + ), + ) { request -> + val name = request.params.arguments?.get("name") ?: "World" + + GetPromptResult( + description = basicPromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Hello, $name!"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent(text = "Greetings, $name! How can I assist you today?"), + ), + ), + ) + } + + // special chars prompt + server.addPrompt( + name = specialCharsPromptName, + description = specialCharsPromptDescription, + arguments = listOf( + PromptArgument( + name = "special", + description = "Special characters to include", + required = false, + ), + ), + ) { request -> + val special = request.params.arguments?.get("special") ?: specialCharsContent + + GetPromptResult( + description = specialCharsPromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Special characters: $special"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent(text = "Received special characters: $special"), + ), + ), + ) + } + + // very large prompt + server.addPrompt( + name = largePromptName, + description = largePromptDescription, + arguments = listOf( + PromptArgument( + name = "size", + description = "Size multiplier", + required = false, + ), + ), + ) { request -> + val size = request.params.arguments?.get("size")?.toIntOrNull() ?: 1 + val content = largePromptContent.repeat(size) + + GetPromptResult( + description = largePromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Generate a large response"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent(text = content), + ), + ), + ) + } + + // complext prompt + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument(name = "arg1", description = "Argument 1", required = true), + PromptArgument(name = "arg2", description = "Argument 2", required = true), + PromptArgument(name = "arg3", description = "Argument 3", required = true), + PromptArgument(name = "arg4", description = "Argument 4", required = false), + PromptArgument(name = "arg5", description = "Argument 5", required = false), + PromptArgument(name = "arg6", description = "Argument 6", required = false), + PromptArgument(name = "arg7", description = "Argument 7", required = false), + PromptArgument(name = "arg8", description = "Argument 8", required = false), + PromptArgument(name = "arg9", description = "Argument 9", required = false), + PromptArgument(name = "arg10", description = "Argument 10", required = false), + ), + ) { request -> + // validate required arguments + val requiredArgs = listOf("arg1", "arg2", "arg3") + for (argName in requiredArgs) { + if (request.params.arguments?.get(argName) == null) { + throw IllegalArgumentException("Missing required argument: $argName") + } + } + + val args = mutableMapOf() + for (i in 1..10) { + val argName = "arg$i" + val argValue = request.params.arguments?.get(argName) + if (argValue != null) { + args[argName] = argValue + } + } + + GetPromptResult( + description = complexPromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent( + text = "Arguments: ${ + args.entries.joinToString { + "${it.key}=${it.value}" + } + }", + ), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent(text = "Received ${args.size} arguments"), + ), + ), + ) + } + + // prompt with multiple messages and roles + server.addPrompt( + name = conversationPromptName, + description = conversationPromptDescription, + arguments = listOf( + PromptArgument( + name = "topic", + description = "The topic of the conversation", + required = false, + ), + ), + ) { request -> + val topic = request.params.arguments?.get("topic") ?: "weather" + + GetPromptResult( + description = conversationPromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Let's talk about the $topic."), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent( + text = "Sure, I'd love to discuss the $topic. What would you like to know?", + ), + ), + PromptMessage( + role = Role.User, + content = TextContent(text = "What's your opinion on the $topic?"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent( + text = "As an AI, I don't have personal opinions," + + " but I can provide information about $topic.", + ), + ), + PromptMessage( + role = Role.User, + content = TextContent(text = "That's helpful, thank you!"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent( + text = "You're welcome! Let me know if you have more questions about $topic.", + ), + ), + ), + ) + } + + // prompt with strict required arguments + server.addPrompt( + name = strictPromptName, + description = strictPromptDescription, + arguments = listOf( + PromptArgument( + name = "requiredArg1", + description = "First required argument", + required = true, + ), + PromptArgument( + name = "requiredArg2", + description = "Second required argument", + required = true, + ), + PromptArgument( + name = "optionalArg", + description = "Optional argument", + required = false, + ), + ), + ) { request -> + val args = request.params.arguments ?: emptyMap() + val arg1 = args["requiredArg1"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg1", + ) + val arg2 = args["requiredArg2"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg2", + ) + val optArg = args["optionalArg"] ?: "default" + + GetPromptResult( + description = strictPromptDescription, + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Required arguments: $arg1, $arg2. Optional: $optArg"), + ), + PromptMessage( + role = Role.Assistant, + content = TextContent(text = "I received your arguments: $arg1, $arg2, and $optArg"), + ), + ), + ) + } + } + + @Test + fun testListPrompts() = runBlocking(Dispatchers.IO) { + val result = client.listPrompts() + + assertNotNull(result, "List prompts result should not be null") + assertTrue(result.prompts.isNotEmpty(), "Prompts list should not be empty") + + val testPrompt = result.prompts.find { it.name == basicPromptName } + assertNotNull(testPrompt, "Test prompt should be in the list") + assertEquals( + basicPromptDescription, + testPrompt.description, + "Prompt description should match", + ) + + val arguments = testPrompt.arguments ?: error("Prompt arguments should not be null") + assertTrue(arguments.isNotEmpty(), "Prompt arguments should not be empty") + + val nameArg = arguments.find { it.name == "name" } + assertNotNull(nameArg, "Name argument should be in the list") + assertEquals( + "The name to greet", + nameArg.description, + "Argument description should match", + ) + assertEquals(true, nameArg.required, "Argument required flag should match") + } + + @Test + fun testGetPrompt() = runBlocking(Dispatchers.IO) { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + basicPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.User } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertNotNull(userContent.text, "User message text should not be null") + assertEquals( + "Hello, $testName!", + userContent.text, + "User message content should match", + ) + + val assistantMessage = result.messages.find { it.role == Role.Assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertNotNull(assistantContent.text, "Assistant message text should not be null") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + + @Test + fun testMissingRequiredArguments() = runBlocking(Dispatchers.IO) { + val promptsList = client.listPrompts() + assertNotNull(promptsList, "Prompts list should not be null") + val strictPrompt = promptsList.prompts.find { it.name == strictPromptName } + assertNotNull(strictPrompt, "Strict prompt should be in the list") + + val argsDef = strictPrompt.arguments ?: error("Prompt arguments should not be null") + val requiredArgs = argsDef.filter { it.required == true } + assertEquals( + 2, + requiredArgs.size, + "Strict prompt should have 2 required arguments", + ) + + // test missing required arg + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = strictPromptName, + arguments = mapOf("requiredArg1" to "value1"), + ), + ), + ) + } + } + + assertTrue(exception.message.contains("requiredArg2"), "Exception should mention the missing argument") + + // test with no args + val exception2 = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = strictPromptName, + arguments = emptyMap(), + ), + ), + ) + } + } + + assertTrue(exception2.message.contains("requiredArg"), "Exception should mention a missing required argument") + + // test with all required args + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = strictPromptName, + arguments = mapOf( + "requiredArg1" to "value1", + "requiredArg2" to "value2", + ), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.User } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = requireNotNull(userContent.text) + assertTrue(userText.contains("value1"), "Message should contain first argument") + assertTrue(userText.contains("value2"), "Message should contain second argument") + } + + @Test + fun testMultipleMessagesAndRoles() = runBlocking(Dispatchers.IO) { + val topic = "climate change" + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = conversationPromptName, + arguments = mapOf("topic" to topic), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + conversationPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(6, result.messages.size, "Prompt should have 6 messages") + + val userMessages = result.messages.filter { it.role == Role.User } + val assistantMessages = result.messages.filter { it.role == Role.Assistant } + + assertEquals(3, userMessages.size, "Should have 3 user messages") + assertEquals(3, assistantMessages.size, "Should have 3 assistant messages") + + for (i in 0 until result.messages.size) { + val expectedRole = if (i % 2 == 0) Role.User else Role.Assistant + assertEquals( + expectedRole, + result.messages[i].role, + "Message $i should have role $expectedRole", + ) + } + + for (message in result.messages) { + val content = message.content as? TextContent + assertNotNull(content, "Message content should be TextContent") + val text = requireNotNull(content.text) + + // Either the message contains the topic or it's a generic conversation message + val containsTopic = text.contains(topic) + val isGenericMessage = text.contains("thank you") || text.contains("welcome") + + assertTrue( + containsTopic || isGenericMessage, + "Message should either contain the topic or be a generic conversation message", + ) + } + } + + @Test + fun testBasicPrompt() = runBlocking(Dispatchers.IO) { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(basicPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.User } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertEquals("Hello, $testName!", userContent.text, "User message content should match") + + val assistantMessage = result.messages.find { it.role == Role.Assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + + @Test + fun testComplexPromptWithManyArguments() = runBlocking(Dispatchers.IO) { + val arguments = (1..10).associate { i -> "arg$i" to "value$i" } + + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = complexPromptName, + arguments = arguments, + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(complexPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.User } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + + // verify all arguments + val text = userContent.text + for (i in 1..10) { + assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") + } + + val assistantMessage = result.messages.find { it.role == Role.Assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Received 10 arguments", + assistantContent.text, + "Assistant message should indicate 10 arguments", + ) + } + + @Test + fun testLargePrompt() = runBlocking(Dispatchers.IO) { + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = largePromptName, + arguments = mapOf("size" to "1"), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(largePromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val assistantMessage = result.messages.find { it.role == Role.Assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val text = assistantContent.text + assertEquals(100_000, text.length, "Assistant message should be 100KB in size") + } + + @Test + fun testSpecialCharacters() = runBlocking(Dispatchers.IO) { + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = specialCharsPromptName, + arguments = mapOf("special" to specialCharsContent), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.User } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = userContent.text + assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") + + val assistantMessage = result.messages.find { it.role == Role.Assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = assistantContent.text + assertTrue( + assistantText.contains(specialCharsContent), + "Assistant message should contain special characters", + ) + } + + @Test + fun testConcurrentPromptRequests() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val promptName = when (index % 4) { + 0 -> basicPromptName + 1 -> complexPromptName + 2 -> largePromptName + else -> specialCharsPromptName + } + + val arguments = when (promptName) { + basicPromptName -> mapOf("name" to "User$index") + complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") + largePromptName -> mapOf("size" to "1") + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = promptName, + arguments = arguments, + ), + ), + ) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") + } + } + + @Test + fun testNonExistentPrompt() = runTest { + val nonExistentPromptName = "non-existent-prompt" + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = nonExistentPromptName, + arguments = mapOf("name" to "Test"), + ), + ), + ) + } + } + + val expectedMessage = "MCP error -32603: Prompt not found: non-existent-prompt" + + assertEquals( + RPCError.ErrorCode.INTERNAL_ERROR, + exception.code, + "Exception code should be INTERNAL_ERROR: ${RPCError.ErrorCode.INTERNAL_ERROR}", + ) + assertEquals(expectedMessage, exception.message, "Unexpected error message for non-existent prompt") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt new file mode 100644 index 00000000..fc088103 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -0,0 +1,317 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.SubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.types.SubscribeRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.UnsubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.types.UnsubscribeRequestParams +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.Ignore +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractResourceIntegrationTest : KotlinTestBase() { + + private val testResourceUri = "test://example.txt" + private val testResourceName = "Test Resource" + private val testResourceDescription = "A test resource for integration testing" + private val testResourceContent = "This is the content of the test resource." + + private val binaryResourceUri = "test://image.png" + private val binaryResourceName = "Binary Resource" + private val binaryResourceDescription = "A binary resource for testing" + private val binaryResourceContent = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + + private val largeResourceUri = "test://large.txt" + private val largeResourceName = "Large Resource" + private val largeResourceDescription = "A large text resource for testing" + private val largeResourceContent = "X".repeat(100_000) // 100KB of data + + private val dynamicResourceUri = "test://dynamic.txt" + private val dynamicResourceName = "Dynamic Resource" + private val dynamicResourceContent = AtomicBoolean(false) + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = binaryResourceUri, + name = binaryResourceName, + description = binaryResourceDescription, + mimeType = "image/png", + ) { request -> + ReadResourceResult( + contents = listOf( + BlobResourceContents( + blob = binaryResourceContent, + uri = request.params.uri, + mimeType = "image/png", + ), + ), + ) + } + + server.addResource( + uri = largeResourceUri, + name = largeResourceName, + description = largeResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = largeResourceContent, + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = dynamicResourceUri, + name = dynamicResourceName, + description = "A resource that can be updated", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = if (dynamicResourceContent.get()) "Updated content" else "Original content", + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + } + + @Test + fun testListResources() = runBlocking(Dispatchers.IO) { + val result = client.listResources() + + assertNotNull(result, "List resources result should not be null") + assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") + + val testResource = result.resources.find { it.uri == testResourceUri } + assertNotNull(testResource, "Test resource should be in the list") + assertEquals(testResourceName, testResource.name, "Resource name should match") + assertEquals(testResourceDescription, testResource.description, "Resource description should match") + } + + @Test + fun testReadResource() = runBlocking(Dispatchers.IO) { + val result = client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = testResourceUri))) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(testResourceContent, content.text, "Resource content should match") + } + + @Ignore("Blocked by https://github.com/modelcontextprotocol/kotlin-sdk/issues/249") + @Test + fun testSubscribeAndUnsubscribe() { + runBlocking(Dispatchers.IO) { + val subscribeResult = + client.subscribeResource(SubscribeRequest(SubscribeRequestParams(uri = testResourceUri))) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = + client.unsubscribeResource(UnsubscribeRequest(UnsubscribeRequestParams(uri = testResourceUri))) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } + } + + @Test + fun testBinaryResource() = runBlocking(Dispatchers.IO) { + val result = client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = binaryResourceUri))) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? BlobResourceContents + assertNotNull(content, "Resource content should be BlobResourceContents") + assertEquals(binaryResourceContent, content.blob, "Binary resource content should match") + assertEquals("image/png", content.mimeType, "MIME type should match") + } + + @Test + fun testLargeResource() = runBlocking(Dispatchers.IO) { + val result = client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = largeResourceUri))) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(100_000, content.text.length, "Large resource content length should match") + assertEquals("X".repeat(100_000), content.text, "Large resource content should match") + } + + @Test + fun testInvalidResourceUri() = runTest { + val invalidUri = "test://nonexistent.txt" + + val exception = assertThrows { + runBlocking { + client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = invalidUri))) + } + } + + val expectedMessage = "MCP error -32603: Resource not found: test://nonexistent.txt" + + assertEquals( + RPCError.ErrorCode.INTERNAL_ERROR, + exception.code, + "Exception code should be INTERNAL_ERROR: ${RPCError.ErrorCode.INTERNAL_ERROR}", + ) + assertEquals(expectedMessage, exception.message, "Unexpected error message for invalid resource URI") + } + + @Test + fun testDynamicResource() = runBlocking(Dispatchers.IO) { + val initialResult = + client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = dynamicResourceUri))) + assertNotNull(initialResult, "Initial read result should not be null") + val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Original content", initialContent, "Initial content should match") + + // update resource + dynamicResourceContent.set(true) + + val updatedResult = + client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = dynamicResourceUri))) + assertNotNull(updatedResult, "Updated read result should not be null") + val updatedContent = (updatedResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Updated content", updatedContent, "Updated content should match") + } + + @Test + fun testResourceAddAndRemove() = runBlocking(Dispatchers.IO) { + val initialList = client.listResources() + assertNotNull(initialList, "Initial list result should not be null") + val initialCount = initialList.resources.size + + val newResourceUri = "test://new-resource.txt" + server.addResource( + uri = newResourceUri, + name = "New Resource", + description = "A newly added resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "New resource content", + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + val updatedList = client.listResources() + assertNotNull(updatedList, "Updated list result should not be null") + val updatedCount = updatedList.resources.size + + assertEquals(initialCount + 1, updatedCount, "Resource count should increase by 1") + val newResource = updatedList.resources.find { it.uri == newResourceUri } + assertNotNull(newResource, "New resource should be in the list") + + server.removeResource(newResourceUri) + + val finalList = client.listResources() + assertNotNull(finalList, "Final list result should not be null") + val finalCount = finalList.resources.size + + assertEquals(initialCount, finalCount, "Resource count should return to initial value") + val removedResource = finalList.resources.find { it.uri == newResourceUri } + assertEquals(null, removedResource, "Resource should be removed from the list") + } + + @Test + fun testConcurrentResourceOperations() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val uri = when (index % 3) { + 0 -> testResourceUri + 1 -> binaryResourceUri + else -> largeResourceUri + } + + val result = client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = uri))) + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt new file mode 100644 index 00000000..7da82cc3 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -0,0 +1,794 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.kotest.assertions.json.shouldEqualJson +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock +import io.modelcontextprotocol.kotlin.sdk.types.ImageContent +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import java.text.DecimalFormat +import java.text.DecimalFormatSymbols +import java.util.Locale +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractToolIntegrationTest : KotlinTestBase() { + private val testToolName = "echo" + private val testToolDescription = "A simple echo tool that returns the input text" + private val complexToolName = "calculator" + private val complexToolDescription = "A calculator tool that performs operations on numbers" + private val errorToolName = "error-tool" + private val errorToolDescription = "A tool that demonstrates error handling" + private val multiContentToolName = "multi-content" + private val multiContentToolDescription = "A tool that returns multiple content types" + + private val basicToolName = "basic-tool" + private val basicToolDescription = "A basic tool for testing" + + private val largeToolName = "large-tool" + private val largeToolDescription = "A tool that returns a large response" + private val largeToolContent = "X".repeat(100_000) // 100KB of data + + private val slowToolName = "slow-tool" + private val slowToolDescription = "A tool that takes time to respond" + + private val specialCharsToolName = "special-chars-tool" + private val specialCharsToolDescription = "A tool that handles special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + setupEchoTool() + setupCalculatorTool() + setupErrorHandlingTool() + setupMultiContentTool() + } + + private fun setupEchoTool() { + server.addTool( + name = testToolName, + description = testToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.params.arguments?.get("text") as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + } + + private fun setupCalculatorTool() { + server.addTool( + name = basicToolName, + description = basicToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.params.arguments?.get("text") as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + + server.addTool( + name = specialCharsToolName, + description = specialCharsToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "special", + buildJsonObject { + put("type", "string") + put("description", "Special characters to process") + }, + ) + }, + ), + ) { request -> + val special = (request.params.arguments?.get("special") as? JsonPrimitive)?.content ?: specialCharsContent + + CallToolResult( + content = listOf(TextContent(text = "Received special characters: $special")), + structuredContent = buildJsonObject { + put("special", special) + put("length", special.length) + }, + ) + } + + server.addTool( + name = slowToolName, + description = slowToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "delay", + buildJsonObject { + put("type", "integer") + put("description", "Delay in milliseconds") + }, + ) + }, + ), + ) { request -> + val delay = (request.params.arguments?.get("delay") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 + + // simulate slow operation + runBlocking { + delay(delay.toLong()) + } + + CallToolResult( + content = listOf(TextContent(text = "Completed after ${delay}ms delay")), + structuredContent = buildJsonObject { + put("delay", delay) + }, + ) + } + + server.addTool( + name = largeToolName, + description = largeToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "size", + buildJsonObject { + put("type", "integer") + put("description", "Size multiplier") + }, + ) + }, + ), + ) { request -> + val size = (request.params.arguments?.get("size") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1 + val content = largeToolContent.take(largeToolContent.length.coerceAtMost(size * 1000)) + + CallToolResult( + content = listOf(TextContent(text = content)), + structuredContent = buildJsonObject { + put("size", content.length) + }, + ) + } + + server.addTool( + name = complexToolName, + description = complexToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "operation", + buildJsonObject { + put("type", "string") + put("description", "The operation to perform (add, subtract, multiply, divide)") + put( + "enum", + buildJsonArray { + add("add") + add("subtract") + add("multiply") + add("divide") + }, + ) + }, + ) + put( + "a", + buildJsonObject { + put("type", "number") + put("description", "First operand") + }, + ) + put( + "b", + buildJsonObject { + put("type", "number") + put("description", "Second operand") + }, + ) + put( + "precision", + buildJsonObject { + put("type", "integer") + put("description", "Number of decimal places (optional)") + put("default", 2) + }, + ) + put( + "showSteps", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to show calculation steps") + put("default", false) + }, + ) + put( + "tags", + buildJsonObject { + put("type", "array") + put("description", "Optional tags for the calculation") + put( + "items", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + required = listOf("operation", "a", "b"), + ), + ) { request -> + val operation = (request.params.arguments?.get("operation") as? JsonPrimitive)?.content ?: "add" + val a = (request.params.arguments?.get("a") as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val b = (request.params.arguments?.get("b") as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val precision = (request.params.arguments?.get("precision") as? JsonPrimitive)?.content?.toIntOrNull() ?: 2 + val showSteps = + (request.params.arguments?.get("showSteps") as? JsonPrimitive)?.content?.toBoolean() ?: false + val tags = (request.params.arguments?.get("tags") as? JsonArray)?.mapNotNull { + (it as? JsonPrimitive)?.content + } ?: emptyList() + + val result = when (operation) { + "add" -> a + b + "subtract" -> a - b + "multiply" -> a * b + "divide" -> if (b != 0.0) a / b else Double.POSITIVE_INFINITY + else -> 0.0 + } + + val pattern = if (precision > 0) "0." + "0".repeat(precision) else "0" + val symbols = DecimalFormatSymbols(Locale.US).apply { decimalSeparator = '.' } + val df = DecimalFormat(pattern, symbols).apply { isGroupingUsed = false } + val formattedResult = df.format(result) + + val textContent = if (showSteps) { + "Operation: $operation\nA: $a\nB: $b\nResult: $formattedResult\nTags: ${ + tags.joinToString(", ") + }" + } else { + "Result: $formattedResult" + } + + CallToolResult( + content = listOf(TextContent(text = textContent)), + structuredContent = buildJsonObject { + put("operation", operation) + put("a", a) + put("b", b) + put("result", result) + put("formattedResult", formattedResult) + put("precision", precision) + put("tags", buildJsonArray { tags.forEach { add(it) } }) + }, + ) + } + } + + private fun setupErrorHandlingTool() { + server.addTool( + name = errorToolName, + description = errorToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "errorType", + buildJsonObject { + put("type", "string") + put("description", "Type of error to simulate (none, exception, error)") + put( + "enum", + buildJsonArray { + add("none") + add("exception") + add("error") + }, + ) + }, + ) + put( + "message", + buildJsonObject { + put("type", "string") + put("description", "Custom error message") + put("default", "An error occurred") + }, + ) + }, + required = listOf("errorType"), + ), + ) { request -> + val errorType = (request.params.arguments?.get("errorType") as? JsonPrimitive)?.content ?: "none" + val message = (request.params.arguments?.get("message") as? JsonPrimitive)?.content ?: "An error occurred" + + when (errorType) { + "exception" -> throw IllegalArgumentException(message) + + "error" -> CallToolResult( + content = listOf(TextContent(text = "Error: $message")), + structuredContent = buildJsonObject { + put("error", true) + put("message", message) + }, + ) + + else -> CallToolResult( + content = listOf(TextContent(text = "No error occurred")), + structuredContent = buildJsonObject { + put("error", false) + put("message", "Success") + }, + ) + } + } + } + + private fun setupMultiContentTool() { + server.addTool( + name = multiContentToolName, + description = multiContentToolDescription, + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Text to include in the response") + }, + ) + put( + "includeImage", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to include an image in the response") + put("default", true) + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.params.arguments?.get("text") as? JsonPrimitive)?.content ?: "Default text" + val includeImage = + (request.params.arguments?.get("includeImage") as? JsonPrimitive)?.content?.toBoolean() ?: true + + val content = mutableListOf( + TextContent(text = "Text content: $text"), + ) + + if (includeImage) { + content.add( + ImageContent( + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==", + mimeType = "image/png", + ), + ) + } + + CallToolResult( + content = content, + structuredContent = buildJsonObject { + put("text", text) + put("includeImage", includeImage) + }, + ) + } + } + + @Test + fun testListTools(): Unit = runBlocking(Dispatchers.IO) { + val result = client.listTools() + + assertNotNull(result, "List utils result should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + val testTool = result.tools.find { it.name == testToolName } + + assertNotNull(testTool, "Test tool should be in the list") + assertEquals( + testToolDescription, + testTool.description, + "Tool description should match", + ) + } + + @Test + fun testCallTool(): Unit = runBlocking(Dispatchers.IO) { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) + + val result = client.callTool(testToolName, arguments) + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + {"result":"Hello, world!"} + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testComplexInputSchemaTool(): Unit = runBlocking(Dispatchers.IO) { + val toolsList = client.listTools() + assertNotNull(toolsList, "Tools list should not be null") + val calculatorTool = toolsList.tools.find { it.name == complexToolName } + assertNotNull(calculatorTool, "Calculator tool should be in the list") + + val arguments = mapOf( + "operation" to "multiply", + "a" to 5.5, + "b" to 2.0, + "precision" to 3, + "showSteps" to true, + "tags" to listOf("test", "calculator", "integration"), + ) + + val result = client.callTool(complexToolName, arguments) + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "operation" : "multiply", + "a" : 5.5, + "b" : 2.0, + "result" : 11.0, + "formattedResult" : "11.000", + "precision" : 3, + "tags" : ["test", "calculator", "integration"] + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testToolErrorHandling(): Unit = runBlocking(Dispatchers.IO) { + val successArgs = mapOf("errorType" to "none") + val successResult = client.callTool(errorToolName, successArgs) + + val actualContent = successResult.structuredContent.toString() + val expectedContent = """ + { + "error" : false, + "message" : "Success" + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + + val errorArgs = mapOf( + "errorType" to "error", + "message" to "Custom error message", + ) + val errorResult = client.callTool(errorToolName, errorArgs) + + val actualError = errorResult.structuredContent.toString() + val expectedError = """ + { + "error" : true, + "message" : "Custom error message" + } + """.trimIndent() + + actualError shouldEqualJson expectedError + + val exceptionArgs = mapOf( + "errorType" to "exception", + "message" to "My exception message", + ) + + val exceptionResult = client.callTool(errorToolName, exceptionArgs) + + assertTrue(exceptionResult.isError ?: false, "isError should be true for exception") + + val exceptionContent = exceptionResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(exceptionContent, "Error content should be present in the result") + + val exceptionText = exceptionContent.text + assertTrue( + exceptionText.contains("Error executing tool") && exceptionText.contains("My exception message"), + "Error message should contain the exception details", + ) + } + + @Test + fun testMultiContentTool(): Unit = runBlocking(Dispatchers.IO) { + val testText = "Test multi-content" + val arguments = mapOf( + "text" to testText, + "includeImage" to true, + ) + + val result = client.callTool(multiContentToolName, arguments) + + assertEquals( + 2, + result.content.size, + "Tool result should have 2 content items", + ) + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Result should contain TextContent") + assertNotNull(textContent.text, "Text content should not be null") + assertEquals( + "Text content: $testText", + textContent.text, + "Text content should match", + ) + + val imageContent = result.content.firstOrNull { it is ImageContent } as? ImageContent + assertNotNull(imageContent, "Result should contain ImageContent") + assertEquals("image/png", imageContent.mimeType, "Image MIME type should match") + assertTrue(imageContent.data.isNotEmpty(), "Image data should not be empty") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "text" : "Test multi-content", + "includeImage" : true + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + + val textOnlyArgs = mapOf( + "text" to testText, + "includeImage" to false, + ) + + val textOnlyResult = client.callTool(multiContentToolName, textOnlyArgs) + + assertEquals( + 1, + textOnlyResult.content.size, + "Text-only result should have 1 content item", + ) + } + + @Test + fun testComplexNestedSchema(): Unit = runBlocking(Dispatchers.IO) { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Galt")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } + + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } + + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } + + val result = client.callTool( + CallToolRequest( + CallToolRequestParams( + name = complexToolName, + arguments = arguments, + ), + ), + ) + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "operation": "add", + "a": 0.0, + "b": 0.0, + "result": 0.0, + "formattedResult": "0.00", + "precision": 2, + "tags": [] + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testLargeResponse(): Unit = runBlocking(Dispatchers.IO) { + val size = 10 + val arguments = mapOf("size" to size) + + val result = client.callTool(largeToolName, arguments) + + val content = result.content.firstOrNull() as TextContent + assertNotNull(content, "Tool result content should be TextContent") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "size" : 10000 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testSlowTool(): Unit = runBlocking(Dispatchers.IO) { + val delay = 500 + val arguments = mapOf("delay" to delay) + + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) + val endTime = System.currentTimeMillis() + + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "delay" : 500 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testSpecialCharacters() { + runBlocking(Dispatchers.IO) { + val arguments = mapOf("special" to specialCharsContent) + + val result = client.callTool(specialCharsToolName, arguments) + + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text + + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "special" : "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t", + "length" : 34 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + } + + @Test + fun testConcurrentToolCalls() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val toolName = when (index % 5) { + 0 -> basicToolName + 1 -> complexToolName + 2 -> largeToolName + 3 -> slowToolName + else -> specialCharsToolName + } + + val arguments = when (toolName) { + basicToolName -> mapOf("text" to "Concurrent call $index") + + complexToolName -> mapOf( + "user" to mapOf( + "name" to "User $index", + "age" to 20 + index, + "address" to mapOf( + "street" to "Street $index", + "city" to "City $index", + "country" to "Country $index", + ), + ), + ) + + largeToolName -> mapOf("size" to 1) + + slowToolName -> mapOf("delay" to 100) + + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.callTool(toolName, arguments) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.content.isNotEmpty(), "Result content should not be empty") + } + } + + @Test + fun testNonExistentTool() = runTest { + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("text" to "Test") + + val result = runBlocking { + client.callTool(nonExistentToolName, arguments) + } + + assertNotNull(result, "Tool call result should not be null") + assertTrue(result.isError ?: false, "isError should be true for non-existent tool") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + + val errorText = textContent.text + assertTrue( + errorText.contains("non-existent-tool") && errorText.contains("not found"), + "Error message should indicate the tool was not found", + ) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt new file mode 100644 index 00000000..e409ec6a --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -0,0 +1,185 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.install +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import org.awaitility.kotlin.await +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import java.io.PipedInputStream +import java.io.PipedOutputStream +import kotlin.time.Duration.Companion.seconds +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +@Retry(times = 3) +abstract class KotlinTestBase { + + protected val host = "localhost" + protected var port: Int = 0 + + protected lateinit var server: Server + protected lateinit var client: Client + protected lateinit var serverEngine: EmbeddedServer<*, *> + + // Transport selection + protected enum class TransportKind { SSE, STDIO } + protected open val transportKind: TransportKind = TransportKind.STDIO + + // STDIO-specific fields + private var stdioServerTransport: StdioServerTransport? = null + private var stdioClientInput: Source? = null + private var stdioClientOutput: Sink? = null + + protected abstract fun configureServerCapabilities(): ServerCapabilities + protected abstract fun configureServer() + + @BeforeEach + fun setUp() { + setupServer() + if (transportKind == TransportKind.SSE) { + await + .ignoreExceptions() + .until { + port = runBlocking { serverEngine.engine.resolvedConnectors().first().port } + port != 0 + } + } + runBlocking { + setupClient() + } + } + + protected suspend fun setupClient() { + when (transportKind) { + TransportKind.SSE -> { + val transport = SseClientTransport( + HttpClient(CIO) { + install(SSE) + }, + "http://$host:$port", + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + + TransportKind.STDIO -> { + val input = checkNotNull(stdioClientInput) { "STDIO client input not initialized" } + val output = checkNotNull(stdioClientOutput) { "STDIO client output not initialized" } + val transport = StdioClientTransport( + input = input, + output = output, + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + } + } + + protected fun setupServer() { + val capabilities = configureServerCapabilities() + + server = Server( + Implementation(name = "test-server", version = "1.0"), + ServerOptions(capabilities = capabilities), + ) + + configureServer() + + if (transportKind == TransportKind.SSE) { + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + install(ServerSSE) + routing { + mcp { server } + } + }.start(wait = false) + } else { + // Create in-memory stdio pipes: client->server and server->client + val clientToServerOut = PipedOutputStream() + val clientToServerIn = PipedInputStream(clientToServerOut) + + val serverToClientOut = PipedOutputStream() + val serverToClientIn = PipedInputStream(serverToClientOut) + + // Server transport reads from client and writes to client + val serverTransport = StdioServerTransport( + inputStream = clientToServerIn.asSource().buffered(), + outputStream = serverToClientOut.asSink().buffered(), + ) + stdioServerTransport = serverTransport + + // Prepare client-side streams for later client initialization + stdioClientInput = serverToClientIn.asSource().buffered() + stdioClientOutput = clientToServerOut.asSink().buffered() + + // Start server transport by connecting the server + runBlocking { + server.createSession(serverTransport) + } + } + } + + @AfterEach + fun tearDown() { + // close client + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + // stop server + if (transportKind == TransportKind.SSE) { + if (::serverEngine.isInitialized) { + try { + serverEngine.stop(500, 1000) + } catch (e: Exception) { + println("Warning: Error during server stop: ${e.message}") + } + } + } else { + stdioServerTransport?.let { + try { + runBlocking { it.close() } + } catch (e: Exception) { + println("Warning: Error during stdio server stop: ${e.message}") + } finally { + stdioServerTransport = null + stdioClientInput = null + stdioClientOutput = null + } + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/OldSchemaResourceIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/OldSchemaResourceIntegrationTestSse similarity index 100% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/OldSchemaResourceIntegrationTestSse.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/OldSchemaResourceIntegrationTestSse diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt new file mode 100644 index 00000000..10cde277 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractPromptIntegrationTest + +class SchemaPromptIntegrationTestSse : AbstractPromptIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt new file mode 100644 index 00000000..bf1240df --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractResourceIntegrationTest + +class ResourceIntegrationTestSse : AbstractResourceIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt new file mode 100644 index 00000000..dd007c6e --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +class ToolIntegrationTestSse : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt new file mode 100644 index 00000000..88be1e80 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractPromptIntegrationTest + +class PromptIntegrationTestStdio : AbstractPromptIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt new file mode 100644 index 00000000..88eca7b0 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractResourceIntegrationTest + +class ResourceIntegrationTestStdio : AbstractResourceIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt new file mode 100644 index 00000000..673d44bb --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +class ToolIntegrationTestStdio : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/websocket/WebSocketIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/websocket/WebSocketIntegrationTest.kt new file mode 100644 index 00000000..ab0d8480 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/websocket/WebSocketIntegrationTest.kt @@ -0,0 +1,199 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.websocket + +import io.ktor.client.HttpClient +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpWebSocketTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcpWebSocket +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlin.test.Test +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.client.plugins.websocket.WebSockets as ClientWebSocket +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.websocket.WebSockets as ServerWebSockets + +class WebSocketIntegrationTest { + + @Test + fun `client should be able to connect to websocket server 2`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient(serverPort = port) + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open WebSocket from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single websocket connection`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient("Client A", port) + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open WebSocket connection #1 from Client A and note the sessionId= value. + * 2. Open WebSocket connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple websocket connections`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + clientA = initClient("Client A", port) + clientB = initClient("Client B", port) + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + private suspend fun initClient(name: String = "", serverPort: Int): Client { + val client = Client( + Implementation(name = name, version = "1.0.0"), + ) + + val httpClient = HttpClient(ClientCIO) { + install(ClientWebSocket) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpWebSocketTransport { + url { + host = URL + port = serverPort + } + } + + client.connect(transport) + + return client + } + + private fun initServer(): EmbeddedServer { + val server = Server( + Implementation(name = "websocket-server", version = "1.0.0"), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), + ) + + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent("Prompt for client ${request.params.arguments?.get("client")}"), + ), + ), + description = "Prompt for ${request.params.name}", + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerWebSockets) + routing { + mcpWebSocket(block = { server }) + } + } + + return ktorServer.start(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = "prompt", + arguments = mapOf("client" to clientName), + ), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + private const val PORT = 0 + private const val URL = "127.0.0.1" + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt new file mode 100644 index 00000000..8f6aed63 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt @@ -0,0 +1,105 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +open class AbstractSseIntegrationTest { + + suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().single().port + + suspend fun initTestClient(serverPort: Int, name: String? = null): Client { + val client = Client( + Implementation(name = name ?: DEFAULT_CLIENT_NAME, version = VERSION), + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = serverPort + } + } + + client.connect(transport) + + return client + } + + suspend fun initTestServer( + name: String? = null, + ): EmbeddedServer { + val server = Server( + Implementation(name = name ?: DEFAULT_SERVER_NAME, version = VERSION), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), + ) { + addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent("Prompt for client ${request.params.arguments?.get("client")}"), + ), + ), + description = "Prompt for ${request.params.name}", + ) + } + } + + val ktorServer = embeddedServer( + ServerCIO, + host = URL, + port = PORT, + ) { + install(ServerSSE) + routing { + mcp { server } + } + } + + return ktorServer.startSuspend(wait = false) + } + + companion object { + private const val DEFAULT_CLIENT_NAME = "sse-test-client" + private const val DEFAULT_SERVER_NAME = "sse-test-server" + private const val VERSION = "1.0.0" + private const val URL = "127.0.0.1" + private const val PORT = 0 + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt new file mode 100644 index 00000000..0826eb21 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt @@ -0,0 +1,114 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlin.test.Test +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class SseIntegrationTest : AbstractSseIntegrationTest() { + + @Test + fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().single().port + client = initTestClient(serverPort = port) + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().single().port + client = initTestClient(port, "Client A") + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().first().port + + clientA = initTestClient(port, "Client A") + clientB = initTestClient(port, "Client B") + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } finally { + clientA?.close() + clientB?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = "prompt", + arguments = mapOf("client" to clientName), + ), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt new file mode 100644 index 00000000..5d045127 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt @@ -0,0 +1,77 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractKotlinClientTsServerTest : TsTestBase() { + protected abstract suspend fun useClient(block: suspend (Client) -> T): T + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun connectsAndPings() = runBlocking(Dispatchers.IO) { + useClient { client -> + assertNotNull(client, "Client should be initialized") + val ping = client.ping() + assertNotNull(ping, "Ping result should not be null") + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun listsTools() = runBlocking(Dispatchers.IO) { + useClient { client -> + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + val toolNames = result.tools.map { it.name } + assertTrue("greet" in toolNames, "Greet tool should be available") + assertTrue("multi-greet" in toolNames, "Multi-greet tool should be available") + // Some tests also check collect-user-info; keep base minimal and non-breaking + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun callGreet() = runBlocking(Dispatchers.IO) { + useClient { client -> + val testName = "TestUser" + val arguments = mapOf("name" to testName) + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + val callResult = result + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals("Hello, $testName!", textContent.text) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun multipleClients() = runBlocking(Dispatchers.IO) { + useClient { client1 -> + useClient { client2 -> + val tools1 = client1.listTools() + val tools2 = client2.listTools() + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + assertTrue("greet" in toolNames1, "Greet tool should be available to first client") + assertTrue("multi-greet" in toolNames1, "Multi-greet tool should be available to first client") + assertTrue("greet" in toolNames2, "Greet tool should be available to second client") + assertTrue("multi-greet" in toolNames2, "Multi-greet tool should be available to second client") + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaAbstractKotlinClientTsServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaAbstractKotlinClientTsServerTest.kt index 5d0406cb..64e114f3 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaAbstractKotlinClientTsServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaAbstractKotlinClientTsServerTest.kt @@ -11,7 +11,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -abstract class OldSchemaAbstractKotlinClientTsServerTest : TsTestBase() { +abstract class OldSchemaAbstractKotlinClientTsServerTest : OldSchemaTsTestBase() { protected abstract suspend fun useClient(block: suspend (Client) -> T): T @Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaTsTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaTsTestBase.kt new file mode 100644 index 00000000..98d83d3c --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/OldSchemaTsTestBase.kt @@ -0,0 +1,455 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse.KotlinServerForTsClient +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import org.awaitility.kotlin.await +import org.junit.jupiter.api.BeforeAll +import java.io.BufferedReader +import java.io.File +import java.io.InputStreamReader +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.TimeUnit +import kotlin.io.path.createTempDirectory +import kotlin.time.Duration.Companion.seconds + +enum class OldSchemaTransportKind { SSE, STDIO, DEFAULT } + +@Retry(times = 3) +abstract class OldSchemaTsTestBase { + + protected open val transportKind: OldSchemaTransportKind = OldSchemaTransportKind.DEFAULT + + protected val projectRoot: File get() = File(System.getProperty("user.dir")) + protected val tsClientDir: File + get() { + val base = File( + projectRoot, + "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript", + ) + + // Allow override via system property for CI: -Dts.transport=stdio|sse + val fromProp = System.getProperty("ts.transport")?.lowercase() + val overrideSubDir = when (fromProp) { + "stdio" -> "stdio" + "sse" -> "sse" + else -> null + } + + val subDirName = overrideSubDir ?: when (transportKind) { + OldSchemaTransportKind.STDIO -> "stdio" + OldSchemaTransportKind.SSE -> "sse" + OldSchemaTransportKind.DEFAULT -> null + } + if (subDirName != null) { + val sub = File(base, subDirName) + if (sub.exists()) return sub + } + return base + } + + companion object { + @JvmStatic + private val tempRootDir: File = createTempDirectory("typescript-sdk-").toFile().apply { deleteOnExit() } + + @JvmStatic + protected val sdkDir: File = File(tempRootDir, "typescript-sdk") + + @JvmStatic + @BeforeAll + fun setupTypeScriptSdk() { + println("Cloning TypeScript SDK repository") + + if (!sdkDir.exists()) { + val process = ProcessBuilder( + "git", + "clone", + "--depth", + "1", + "https://github.com/modelcontextprotocol/typescript-sdk.git", + sdkDir.absolutePath, + ) + .redirectErrorStream(true) + .start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw RuntimeException("Failed to clone TypeScript SDK repository: exit code $exitCode") + } + } + + println("Installing TypeScript SDK dependencies") + executeCommand("npm install", sdkDir, allowFailure = false, timeoutSeconds = null) + } + + @JvmStatic + protected fun killProcessOnPort(port: Int) { + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val killCommand = if (isWindows) { + "netstat -ano | findstr :$port | for /f \"tokens=5\" %a in ('more')" + + " do taskkill /F /PID %a 2>nul || echo No process found" + } else { + "lsof -ti:$port | xargs kill -9 2>/dev/null || true" + } + executeCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) + } + + @JvmStatic + protected fun findFreePort(): Int { + ServerSocket(0).use { socket -> + return socket.localPort + } + } + + @JvmStatic + protected fun executeCommand( + command: String, + workingDir: File, + allowFailure: Boolean = false, + timeoutSeconds: Long? = null, + ): String { + if (!workingDir.exists()) { + if (!workingDir.mkdirs()) { + throw RuntimeException("Failed to create working directory: ${workingDir.absolutePath}") + } + } + + if (!workingDir.isDirectory || !workingDir.canRead()) { + throw RuntimeException("Working directory is not accessible: ${workingDir.absolutePath}") + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && $command") + } else { + ProcessBuilder() + .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + } + + val process = processBuilder + .directory(workingDir) + .redirectErrorStream(true) + .start() + + val output = StringBuilder() + BufferedReader(InputStreamReader(process.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + println(line) + output.append(line).append("\n") + } + } + + if (timeoutSeconds == null) { + val exitCode = process.waitFor() + if (!allowFailure && exitCode != 0) { + throw RuntimeException( + "Command execution failed with exit code $exitCode: $command\n" + + "Working dir: ${workingDir.absolutePath}\nOutput:\n$output", + ) + } + } else { + process.waitFor(timeoutSeconds, TimeUnit.SECONDS) + } + + return output.toString() + } + } + + private fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { + if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { + process.destroyForcibly() + process.waitFor(2, TimeUnit.SECONDS) + return false + } + return true + } + + private fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val outputReader = Thread { + try { + process.inputStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process output: ${e.message}") + } + } + outputReader.isDaemon = true + return outputReader + } + + private fun createProcessErrorReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val errorReader = Thread { + try { + process.errorStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix][err] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process error stream: ${e.message}") + } + } + errorReader.isDaemon = true + return errorReader + } + + protected fun waitForPort(host: String = "localhost", port: Int, timeoutSeconds: Long = 10): Boolean = try { + await.atMost(timeoutSeconds, TimeUnit.SECONDS) + .pollDelay(200, TimeUnit.MILLISECONDS) + .pollInterval(100, TimeUnit.MILLISECONDS) + .until { + try { + Socket(host, port).use { true } + } catch (_: Exception) { + false + } + } + true + } catch (_: Exception) { + false + } + + protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = + executeCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) + + protected fun startTypeScriptServer(port: Int): Process { + killProcessOnPort(port) + + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val localServerPath = File(tsClientDir, "simpleStreamableHttp.ts").absolutePath + val processBuilder = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", + ) + } else { + ProcessBuilder() + .command( + "bash", + "-c", + "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", + ) + } + + processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath + + val process = processBuilder + .directory(tsClientDir) + .redirectErrorStream(true) + .start() + + createProcessOutputReader(process).start() + + if (!waitForPort(port = port, timeoutSeconds = 20)) { + throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") + } + return process + } + + protected fun stopProcess(process: Process, waitSeconds: Long = 3, name: String = "TypeScript server") { + process.destroy() + if (waitForProcessTermination(process, waitSeconds)) { + println("$name stopped gracefully") + } else { + println("$name did not stop gracefully, forced termination") + } + } + + // ===== SSE client helpers ===== + protected suspend fun newClient(serverUrl: String): Client = + HttpClient(CIO) { install(SSE) }.mcpStreamableHttp(serverUrl) + + protected suspend fun withClient(serverUrl: String, block: suspend (Client) -> T): T { + val client = newClient(serverUrl) + return try { + withTimeout(20.seconds) { block(client) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) { + // ignore errors + } + } + } + + // ===== STDIO client + server helpers ===== + protected fun startTypeScriptServerStdio(): Process { + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) + } + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val localServerPath = File(tsClientDir, "simpleStdio.ts").absolutePath + val processBuilder = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + "set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", + ) + } else { + ProcessBuilder() + .command( + "bash", + "-c", + "NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", + ) + } + processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath + val process = processBuilder + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + // For stdio transports, do NOT read from stdout (it's used for protocol). Read stderr for logs only. + createProcessErrorReader(process, prefix = "TS-SERVER-STDIO").start() + // Give the process a moment to start + await.atMost(2, TimeUnit.SECONDS) + .pollDelay(200, TimeUnit.MILLISECONDS) + .pollInterval(100, TimeUnit.MILLISECONDS) + .until { process.isAlive } + return process + } + + protected suspend fun newClientStdio(process: Process): Client { + val input: Source = process.inputStream.asSource().buffered() + val output: Sink = process.outputStream.asSink().buffered() + val transport = StdioClientTransport(input = input, output = output) + val client = Client(Implementation("test", "1.0")) + client.connect(transport) + return client + } + + protected suspend fun withClientStdio(block: suspend (Client, Process) -> T): T { + val proc = startTypeScriptServerStdio() + val client = newClientStdio(proc) + return try { + withTimeout(20.seconds) { block(client, proc) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) { + } + try { + stopProcess(proc, name = "TypeScript stdio server") + } catch (_: Exception) { + } + } + } + + // ===== Helpers to run TypeScript client over STDIO against Kotlin server over STDIO ===== + protected fun runStdioClient(vararg args: String): String { + // Start Node stdio client (it will speak MCP over its stdout/stdin) + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val clientPath = File(tsClientDir, "myClient.ts").absolutePath + + val process = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + ( + "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && " + + "set NODE_PATH=${sdkDir.absolutePath}\\node_modules && " + + "npx --prefix \"${sdkDir.absolutePath}\" tsx \"$clientPath\" " + + args.joinToString(" ") + ), + ) + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + } else { + ProcessBuilder() + .command( + "bash", + "-c", + ( + "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' " + + "NODE_PATH='${sdkDir.absolutePath}/node_modules' " + + "npx --prefix '${sdkDir.absolutePath}' tsx \"$clientPath\" " + + args.joinToString(" ") + ), + ) + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + } + + // Create Kotlin server and attach stdio transport to the process streams + val server: Server = KotlinServerForTsClient().createMcpServer() + val transport = StdioServerTransport( + inputStream = process.inputStream.asSource().buffered(), + outputStream = process.outputStream.asSink().buffered(), + ) + + // Connect server in a background thread to avoid blocking + val serverThread = Thread { + try { + runBlocking { server.createSession(transport) } + } catch (e: Exception) { + println("[STDIO-SERVER] Error connecting: ${e.message}") + } + } + serverThread.isDaemon = true + serverThread.start() + + // Read ONLY stderr from client for human-readable output + val output = StringBuilder() + val errReader = Thread { + try { + process.errorStream.bufferedReader().useLines { lines -> + lines.forEach { line -> + println("[TS-CLIENT-STDIO][err] $line") + output.append(line).append('\n') + } + } + } catch (e: Exception) { + println("Warning: Error reading stdio client stderr: ${e.message}") + } + } + errReader.isDaemon = true + errReader.start() + + // Wait up to 25s for client to exit + val finished = process.waitFor(25, TimeUnit.SECONDS) + if (!finished) { + println("Stdio client did not finish in time; destroying") + process.destroyForcibly() + } + + try { + runBlocking { transport.close() } + } catch (_: Exception) { + } + + return output.toString() + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt index 21593fe2..22af65a0 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt @@ -3,7 +3,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.sse.SSE -import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp @@ -11,6 +10,7 @@ import io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse.KotlinServe import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +import io.modelcontextprotocol.kotlin.sdk.types.Implementation import kotlinx.coroutines.withTimeout import kotlinx.io.Sink import kotlinx.io.Source diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt new file mode 100644 index 00000000..54c6e0f2 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt @@ -0,0 +1,215 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTsServerEdgeCasesTestSse : TsTestBase() { + + override val transportKind = TransportKind.SSE + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + + private lateinit var client: Client + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDown() { + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentTool(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val result = client.callTool(nonExistentToolName, arguments) + assertNotNull(result, "Tool call result should not be null") + + assertTrue(result.isError ?: false, "isError should be true for non-existent tool") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + + val errorText = textContent.text + assertTrue( + errorText.contains("non-existent-tool") && errorText.contains("not found"), + "Error message should indicate the tool was not found", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersInArguments(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayload(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testConcurrentRequests(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + val concurrentCount = 5 + val responses = coroutineScope { + val results = (1..concurrentCount).map { i -> + async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text + } + } + results.awaitAll() + } + + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testInvalidArguments(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + val result = client.callTool("greet", invalidArguments) + assertNotNull(result, "Tool call result should not be null") + + assertTrue(result.isError ?: false, "isError should be true for invalid arguments") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + + val errorText = textContent.text + assertTrue( + errorText.contains("Invalid arguments") && errorText.contains("greet"), + "Error message should indicate invalid arguments for tool greet", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleToolCalls(): Unit = runBlocking(Dispatchers.IO) { + withClient(serverUrl) { client -> + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt index 51fd2bf8..95a80f0b 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt @@ -1,14 +1,14 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaAbstractKotlinClientTsServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractKotlinClientTsServerTest import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind import kotlinx.coroutines.withTimeout import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import kotlin.time.Duration.Companion.seconds -class KotlinClientTsServerTestSse : OldSchemaAbstractKotlinClientTsServerTest() { +class KotlinClientTsServerTestSse : AbstractKotlinClientTsServerTest() { override val transportKind = TransportKind.SSE diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt new file mode 100644 index 00000000..3aa36bc7 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt @@ -0,0 +1,464 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.cio.CIO +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.request.header +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.response.respondTextWriter +import io.ktor.server.routing.delete +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.jsonPrimitive +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +private val logger = KotlinLogging.logger {} + +class KotlinServerForTsClient { + private val serverTransports = ConcurrentHashMap() + private val jsonFormat = Json { ignoreUnknownKeys = true } + private var server: EmbeddedServer<*, *>? = null + + fun start(port: Int = 3000) { + logger.info { "Starting HTTP server on port $port" } + + server = embeddedServer(CIO, port = port) { + routing { + get("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId == null) { + call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header") + return@get + } + val transport = serverTransports[sessionId] + if (transport == null) { + call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id") + return@get + } + transport.stream(call) + } + + post("/mcp") { + val sessionId = call.request.header("mcp-session-id") + val requestBody = call.receiveText() + + logger.debug { "Received request with sessionId: $sessionId" } + logger.trace { "Request body: $requestBody" } + + val jsonElement = try { + jsonFormat.parseToJsonElement(requestBody) + } catch (e: Exception) { + logger.error(e) { "Failed to parse request body as JSON" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32700), + "message" to JsonPrimitive("Parse error: ${e.message}"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + return@post + } + + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.debug { "Using existing transport for session: $sessionId" } + val transport = serverTransports[sessionId]!! + transport.handleRequest(call, jsonElement) + } else { + if (isInitializeRequest(jsonElement)) { + val newSessionId = UUID.randomUUID().toString() + logger.info { "Creating new session with ID: $newSessionId" } + + val transport = HttpServerTransport(newSessionId) + + serverTransports[newSessionId] = transport + + val mcpServer = createMcpServer() + + call.response.header("mcp-session-id", newSessionId) + + val serverThread = Thread { + runBlocking { + mcpServer.createSession(transport) + } + } + serverThread.start() + + Thread.sleep(500) + + transport.handleRequest(call, jsonElement) + } else { + logger.warn { "Invalid request: no session ID or not an initialization request" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32000), + "message" to + JsonPrimitive("Bad Request: No valid session ID provided"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + } + } + } + + delete("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.info { "Terminating session: $sessionId" } + val transport = serverTransports[sessionId]!! + serverTransports.remove(sessionId) + runBlocking { + transport.close() + } + call.respond(HttpStatusCode.OK) + } else { + logger.warn { "Invalid session termination request: $sessionId" } + call.respond(HttpStatusCode.BadRequest, "Invalid or missing session ID") + } + } + } + } + + server?.start(wait = false) + } + + fun stop() { + logger.info { "Stopping HTTP server" } + server?.stop(500, 1000) + server = null + } + + fun createMcpServer(): Server { + val server = Server( + Implementation( + name = "kotlin-http-server", + version = "1.0.0", + ), + ServerOptions( + capabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + tools = ServerCapabilities.Tools(listChanged = true), + ), + ), + ) + + server.addTool( + name = "greet", + description = "A simple greeting tool", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.params.arguments?.get("name") as? JsonPrimitive)?.content ?: "World" + CallToolResult( + content = listOf(TextContent("Hello, $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Hello, $name!")) + }, + ) + } + + server.addTool( + name = "multi-greet", + description = "A greeting tool that sends multiple notifications", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.params.arguments?.get("name") as? JsonPrimitive)?.content ?: "World" + + CallToolResult( + content = listOf(TextContent("Multiple greetings sent to $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Multiple greetings sent to $name!")) + put("notificationCount", JsonPrimitive(3)) + }, + ) + } + + server.addPrompt( + name = "greeting-template", + description = "A simple greeting prompt template", + arguments = listOf( + PromptArgument( + name = "name", + description = "Name to include in greeting", + required = true, + ), + ), + ) { request -> + GetPromptResult( + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent( + "Please greet ${request.params.arguments?.get("name") ?: "someone"} in a friendly manner.", + ), + ), + ), + description = "Greeting for ${request.params.name}", + ) + } + + server.addResource( + uri = "https://example.com/greetings/default", + name = "Default Greeting", + description = "A simple greeting resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents("Hello, world!", request.params.uri, "text/plain"), + ), + ) + } + + return server + } + + private fun isInitializeRequest(json: JsonElement): Boolean { + if (json !is JsonObject) return false + + val method = json["method"]?.jsonPrimitive?.contentOrNull + return method == "initialize" + } +} + +class HttpServerTransport(private val sessionId: String) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val pendingResponses = ConcurrentHashMap>() + private val messageQueue = Channel(Channel.UNLIMITED) + + suspend fun stream(call: ApplicationCall) { + logger.debug { "Starting SSE stream for session: $sessionId" } + call.response.header("Cache-Control", "no-cache") + call.response.header("Connection", "keep-alive") + call.respondTextWriter(ContentType.Text.EventStream) { + try { + while (true) { + val result = messageQueue.receiveCatching() + val msg = result.getOrNull() ?: break + val json = McpJson.encodeToString(msg) + write("event: message\n") + write("data: ") + write(json) + write("\n\n") + flush() + } + } catch (e: Exception) { + logger.warn(e) { "SSE stream terminated for session: $sessionId" } + } finally { + logger.debug { "SSE stream closed for session: $sessionId" } + } + } + } + + suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { + try { + logger.info { "Handling request body: $requestBody" } + val message = McpJson.decodeFromJsonElement(requestBody) + logger.info { "Decoded message: $message" } + + if (message is JSONRPCRequest) { + val id = message.id.toString() + logger.info { "Received request with ID: $id, method: ${message.method}" } + val responseDeferred = CompletableDeferred() + pendingResponses[id] = responseDeferred + logger.info { "Created deferred response for ID: $id" } + + logger.info { "Invoking onMessage handler" } + _onMessage.invoke(message) + logger.info { "onMessage handler completed" } + + try { + val response = withTimeoutOrNull(10000) { + responseDeferred.await() + } + + if (response != null) { + val jsonResponse = McpJson.encodeToString(response) + call.respondText(jsonResponse, ContentType.Application.Json) + } else { + logger.warn { "Timeout waiting for response to request ID: $id" } + call.respondText( + McpJson.encodeToString( + JSONRPCError( + id = message.id, + error = RPCError( + code = RPCError.ErrorCode.REQUEST_TIMEOUT, + message = "Request timed out", + ), + ), + ), + ContentType.Application.Json, + ) + } + } catch (_: CancellationException) { + logger.warn { "Request cancelled for ID: $id" } + pendingResponses.remove(id) + if (!call.response.isCommitted) { + call.respondText( + McpJson.encodeToString( + JSONRPCError( + id = message.id, + error = RPCError( + code = RPCError.ErrorCode.CONNECTION_CLOSED, + message = "Request cancelled", + ), + ), + ), + ContentType.Application.Json, + HttpStatusCode.ServiceUnavailable, + ) + } + } + } else { + call.respondText("", ContentType.Application.Json, HttpStatusCode.Accepted) + } + } catch (e: Exception) { + logger.error(e) { "Error handling request" } + if (!call.response.isCommitted) { + try { + val errorResponse = JSONRPCError( + id = RequestId(0), + error = RPCError( + code = RPCError.ErrorCode.INTERNAL_ERROR, + message = "Internal server error: ${e.message}", + ), + ) + + call.respondText( + McpJson.encodeToString(errorResponse), + ContentType.Application.Json, + HttpStatusCode.InternalServerError, + ) + } catch (responseEx: Exception) { + logger.error(responseEx) { "Failed to send error response" } + } + } + } + } + + override suspend fun start() { + logger.debug { "Starting HTTP server transport for session: $sessionId" } + } + + override suspend fun send(message: JSONRPCMessage) { + logger.info { "Sending message: $message" } + + if (message is JSONRPCResponse) { + val id = message.id.toString() + logger.info { "Sending response for request ID: $id" } + val deferred = pendingResponses.remove(id) + if (deferred != null) { + logger.info { "Found pending response for ID: $id, completing deferred" } + deferred.complete(message) + return + } else { + logger.warn { "No pending response found for ID: $id" } + } + } else if (message is JSONRPCRequest) { + logger.info { "Sending request with ID: ${message.id}" } + } else if (message is JSONRPCNotification) { + logger.info { "Sending notification: ${message.method}" } + } + + logger.info { "Queueing message for next client request" } + messageQueue.send(message) + } + + override suspend fun close() { + logger.debug { "Closing HTTP server transport for session: $sessionId" } + messageQueue.close() + _onClose.invoke() + } +} + +fun main() { + val server = KotlinServerForTsClient() + server.start() +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerEdgeCasesTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerEdgeCasesTestSse.kt index 44ff5a47..5af315b1 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerEdgeCasesTestSse.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerEdgeCasesTestSse.kt @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTsTestBase import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -22,9 +22,9 @@ import kotlin.test.assertNotNull import kotlin.test.assertTrue import kotlin.time.Duration.Companion.seconds -class OldSchemaKotlinClientTsServerEdgeCasesTestSse : TsTestBase() { +class OldSchemaKotlinClientTsServerEdgeCasesTestSse : OldSchemaTsTestBase() { - override val transportKind = TransportKind.SSE + override val transportKind = OldSchemaTransportKind.SSE private var port: Int = 0 private val host = "localhost" diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerTestSse.kt new file mode 100644 index 00000000..39ac18c9 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinClientTsServerTestSse.kt @@ -0,0 +1,49 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaAbstractKotlinClientTsServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTransportKind +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import kotlin.time.Duration.Companion.seconds + +class OldSchemaKotlinClientTsServerTestSse : OldSchemaAbstractKotlinClientTsServerTest() { + + override val transportKind = OldSchemaTransportKind.SSE + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUpSse() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDownSse() { + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + override suspend fun useClient(block: suspend (Client) -> T): T = withClient(serverUrl) { client -> + try { + withTimeout(20.seconds) { block(client) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) {} + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinServerForTsClientSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinServerForTsClientSse.kt index 5101bd89..d9a52216 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinServerForTsClientSse.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/OldSchemaKotlinServerForTsClientSse.kt @@ -59,8 +59,8 @@ import java.util.concurrent.ConcurrentHashMap private val logger = KotlinLogging.logger {} -class KotlinServerForTsClient { - private val serverTransports = ConcurrentHashMap() +class OldSchemaKotlinServerForTsClient { + private val serverTransports = ConcurrentHashMap() private val jsonFormat = Json { ignoreUnknownKeys = true } private var server: EmbeddedServer<*, *>? = null @@ -124,7 +124,7 @@ class KotlinServerForTsClient { val newSessionId = UUID.randomUUID().toString() logger.info { "Creating new session with ID: $newSessionId" } - val transport = HttpServerTransport(newSessionId) + val transport = OldSchemaHttpServerTransport(newSessionId) serverTransports[newSessionId] = transport @@ -309,7 +309,7 @@ class KotlinServerForTsClient { } } -class HttpServerTransport(private val sessionId: String) : AbstractTransport() { +class OldSchemaHttpServerTransport(private val sessionId: String) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val pendingResponses = ConcurrentHashMap>() private val messageQueue = Channel(Channel.UNLIMITED) @@ -460,6 +460,6 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() { } fun main() { - val server = KotlinServerForTsClient() + val server = OldSchemaKotlinServerForTsClient() server.start() } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt new file mode 100644 index 00000000..fcab03c3 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt @@ -0,0 +1,174 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class KotlinClientTsServerEdgeCasesTestStdio : TsTestBase() { + + override val transportKind = TransportKind.STDIO + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentToolOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val result = client.callTool(nonExistentToolName, arguments) + assertNotNull(result, "Tool call result should not be null") + + assertTrue(result.isError ?: false, "isError should be true for non-existent tool") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + + val errorText = textContent.text + assertTrue( + errorText.contains("non-existent-tool") && errorText.contains("not found"), + "Error message should indicate the tool was not found", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersInArgumentsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>.,?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayloadOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testConcurrentRequestsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val concurrentCount = 5 + val responses = coroutineScope { + val results = (1..concurrentCount).map { i -> + async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val callResult = result + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text + } + } + results.awaitAll() + } + + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testInvalidArgumentsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + val result = client.callTool("greet", invalidArguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result + assertTrue(callResult.isError ?: false, "isError should be true for invalid arguments") + + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + + val errorText = textContent.text + assertTrue( + errorText.contains("Invalid arguments") && errorText.contains("greet"), + "Error message should indicate invalid arguments for tool greet", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleToolCallsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt new file mode 100644 index 00000000..82737249 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt @@ -0,0 +1,23 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractKotlinClientTsServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind + +class KotlinClientTsServerTestStdio : AbstractKotlinClientTsServerTest() { + + override val transportKind = TransportKind.STDIO + + override suspend fun useClient(block: suspend (Client) -> T): T = withClientStdio { client, proc -> + try { + block(client) + } finally { + try { + client.close() + } catch (_: Exception) {} + try { + stopProcess(proc, name = "TypeScript stdio server") + } catch (_: Exception) {} + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerEdgeCasesTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerEdgeCasesTestStdio.kt index 0ed86763..3666622d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerEdgeCasesTestStdio.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerEdgeCasesTestStdio.kt @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTsTestBase import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -18,9 +18,9 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class OldSchemaKotlinClientTsServerEdgeCasesTestStdio : TsTestBase() { +class OldSchemaKotlinClientTsServerEdgeCasesTestStdio : OldSchemaTsTestBase() { - override val transportKind = TransportKind.STDIO + override val transportKind = OldSchemaTransportKind.STDIO @Test @Timeout(30, unit = TimeUnit.SECONDS) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerTestStdio.kt index e87a5c88..b1b5bae3 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerTestStdio.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/OldSchemaKotlinClientTsServerTestStdio.kt @@ -2,11 +2,11 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaAbstractKotlinClientTsServerTest -import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.OldSchemaTransportKind class OldSchemaKotlinClientTsServerTestStdio : OldSchemaAbstractKotlinClientTsServerTest() { - override val transportKind = TransportKind.STDIO + override val transportKind = OldSchemaTransportKind.STDIO override suspend fun useClient(block: suspend (Client) -> T): T = withClientStdio { client, proc -> try { diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractServerFeaturesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractServerFeaturesTest.kt new file mode 100644 index 00000000..21b46d2b --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractServerFeaturesTest.kt @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.BeforeEach + +abstract class AbstractServerFeaturesTest { + + protected lateinit var server: Server + protected lateinit var client: Client + + abstract fun getServerCapabilities(): ServerCapabilities + + protected open fun getServerInstructionsProvider(): (() -> String)? = null + + @BeforeEach + fun setUp() { + val serverOptions = ServerOptions( + capabilities = getServerCapabilities(), + ) + + server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = serverOptions, + instructionsProvider = getServerInstructionsProvider(), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + ) + + runBlocking { + // Connect client and server + launch { client.connect(clientTransport) } + launch { server.createSession(serverTransport) } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaServerInstructionsTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaServerInstructionsTest.kt new file mode 100644 index 00000000..26bb7bfc --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/OldSchemaServerInstructionsTest.kt @@ -0,0 +1,68 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertNull +import kotlin.test.assertEquals + +class OldSchemaServerInstructionsTest { + + @Test + fun `Server constructor should accept instructions provider parameter`() = runTest { + val serverInfo = Implementation(name = "test server", version = "1.0") + val serverOptions = ServerOptions(capabilities = ServerCapabilities()) + val instructions = "This is a test server. Use it for testing purposes only." + + val server = Server(serverInfo, serverOptions, { instructions }) + + // The instructions should be stored internally and used in handleInitialize + // We can't directly access the private field, but we can test it through initialization + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) + + server.createSession(serverTransport) + client.connect(clientTransport) + + assertEquals(instructions, client.serverInstructions) + } + + @Test + fun `Server constructor should accept instructions parameter`() = runTest { + val serverInfo = Implementation(name = "test server", version = "1.0") + val serverOptions = ServerOptions(capabilities = ServerCapabilities()) + val instructions = "This is a test server. Use it for testing purposes only." + + val server = Server(serverInfo, serverOptions, instructions) + + // The instructions should be stored internally and used in handleInitialize + // We can't directly access the private field, but we can test it through initialization + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) + + server.createSession(serverTransport) + client.connect(clientTransport) + + assertEquals(instructions, client.serverInstructions) + } + + @Test + fun `Server constructor should work without instructions parameter`() = runTest { + val serverInfo = Implementation(name = "test server", version = "1.0") + val serverOptions = ServerOptions(capabilities = ServerCapabilities()) + + // Test that server works when instructions parameter is omitted (defaults to null) + val server = Server(serverInfo, serverOptions) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) + + server.createSession(serverTransport) + client.connect(clientTransport) + + assertNull(client.serverInstructions) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt new file mode 100644 index 00000000..60de0cab --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt @@ -0,0 +1,100 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.Prompt +import io.modelcontextprotocol.kotlin.sdk.types.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ServerPromptsTest : AbstractServerFeaturesTest() { + + override fun getServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(false), + ) + + @Test + fun `removePrompt should remove a prompt`() = runTest { + // Add a prompt + val testPrompt = Prompt("test-prompt", "Test Prompt", null) + server.addPrompt(testPrompt) { + GetPromptResult( + description = "Test prompt description", + messages = listOf(), + ) + } + + // Remove the prompt + val result = server.removePrompt(testPrompt.name) + + // Verify the prompt was removed + assertTrue(result, "Prompt should be removed successfully") + } + + @Test + fun `removePrompts should remove multiple prompts and send notification`() = runTest { + // Add prompts + val testPrompt1 = Prompt("test-prompt-1", "Test Prompt 1", null) + val testPrompt2 = Prompt("test-prompt-2", "Test Prompt 2", null) + server.addPrompt(testPrompt1) { + GetPromptResult( + description = "Test prompt description 1", + messages = listOf(), + ) + } + server.addPrompt(testPrompt2) { + GetPromptResult( + description = "Test prompt description 2", + messages = listOf(), + ) + } + + // Remove the prompts + val result = server.removePrompts(listOf(testPrompt1.name, testPrompt2.name)) + + // Verify the prompts were removed + assertEquals(2, result, "Both prompts should be removed") + } + + @Test + fun `removePrompt should return false when prompt does not exist`() = runTest { + // Track notifications + var promptListChangedNotificationReceived = false + client.setNotificationHandler(Method.Defined.NotificationsPromptsListChanged) { + promptListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + + // Try to remove a non-existent prompt + val result = server.removePrompt("non-existent-prompt") + + // Verify the result + assertFalse(result, "Removing non-existent prompt should return false") + assertFalse(promptListChangedNotificationReceived, "No notification should be sent when prompt doesn't exist") + } + + @Test + fun `removePrompt should throw when prompts capability is not supported`() = runTest { + // Create server without prompts capability + val serverOptions = ServerOptions( + capabilities = ServerCapabilities(), + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions, + ) + + // Verify that removing a prompt throws an exception + val exception = assertThrows { + server.removePrompt("test-prompt") + } + assertEquals("Server does not support prompts capability.", exception.message) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerResourcesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerResourcesTest.kt new file mode 100644 index 00000000..ad73c834 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerResourcesTest.kt @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ServerResourcesTest : AbstractServerFeaturesTest() { + + override fun getServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + ) + + @Test + fun `removeResource should remove a resource and send notification`() = runTest { + // Add a resource + val testResourceUri = "test://resource" + server.addResource( + uri = testResourceUri, + name = "Test Resource", + description = "A test resource", + mimeType = "text/plain", + ) { + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Test resource content", + uri = testResourceUri, + mimeType = "text/plain", + ), + ), + ) + } + + // Remove the resource + val result = server.removeResource(testResourceUri) + + // Verify the resource was removed + assertTrue(result, "Resource should be removed successfully") + } + + @Test + fun `removeResources should remove multiple resources and send notification`() = runTest { + // Add resources + val testResourceUri1 = "test://resource1" + val testResourceUri2 = "test://resource2" + server.addResource( + uri = testResourceUri1, + name = "Test Resource 1", + description = "A test resource 1", + mimeType = "text/plain", + ) { + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Test resource content 1", + uri = testResourceUri1, + mimeType = "text/plain", + ), + ), + ) + } + server.addResource( + uri = testResourceUri2, + name = "Test Resource 2", + description = "A test resource 2", + mimeType = "text/plain", + ) { + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Test resource content 2", + uri = testResourceUri2, + mimeType = "text/plain", + ), + ), + ) + } + + // Remove the resources + val result = server.removeResources(listOf(testResourceUri1, testResourceUri2)) + + // Verify the resources were removed + assertEquals(2, result, "Both resources should be removed") + } + + @Test + fun `removeResource should return false when resource does not exist`() = runTest { + // Track notifications + var resourceListChangedNotificationReceived = false + client.setNotificationHandler( + Method.Defined.NotificationsResourcesListChanged, + ) { + resourceListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + + // Try to remove a non-existent resource + val result = server.removeResource("non-existent-resource") + + // Verify the result + assertFalse(result, "Removing non-existent resource should return false") + assertFalse( + resourceListChangedNotificationReceived, + "No notification should be sent when resource doesn't exist", + ) + } + + @Test + fun `removeResource should throw when resources capability is not supported`() = runTest { + // Create server without resources capability + val serverOptions = ServerOptions( + capabilities = ServerCapabilities(), + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions, + ) + + // Verify that removing a resource throws an exception + val exception = assertThrows { + server.removeResource("test://resource") + } + assertEquals("Server does not support resources capability.", exception.message) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerToolsTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerToolsTest.kt new file mode 100644 index 00000000..eae2eb23 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerToolsTest.kt @@ -0,0 +1,89 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ServerToolsTest : AbstractServerFeaturesTest() { + + override fun getServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(null), + ) + + @Test + fun `removeTool should remove a tool`() = runTest { + // Add a tool + server.addTool("test-tool", "Test Tool", ToolSchema()) { + CallToolResult(listOf(TextContent("Test result"))) + } + + // Remove the tool + val result = server.removeTool("test-tool") + + // Verify the tool was removed + assertTrue(result, "Tool should be removed successfully") + } + + @Test + fun `removeTool should return false when tool does not exist`() = runTest { + // Track notifications + var toolListChangedNotificationReceived = false + client.setNotificationHandler(Method.Defined.NotificationsToolsListChanged) { + toolListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + + // Try to remove a non-existent tool + val result = server.removeTool("non-existent-tool") + + // Verify the result + assertFalse(result, "Removing non-existent tool should return false") + assertFalse(toolListChangedNotificationReceived, "No notification should be sent when tool doesn't exist") + } + + @Test + fun `removeTool should throw when tools capability is not supported`() = runTest { + // Create server without tools capability + val serverOptions = ServerOptions( + capabilities = ServerCapabilities(), + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions, + ) + + // Verify that removing a tool throws an exception + val exception = assertThrows { + server.removeTool("test-tool") + } + assertEquals("Server does not support tools capability.", exception.message) + } + + @Test + fun `removeTools should remove multiple tools`() = runTest { + // Add tools + server.addTool("test-tool-1", "Test Tool 1") { + CallToolResult(listOf(TextContent("Test result 1"))) + } + server.addTool("test-tool-2", "Test Tool 2") { + CallToolResult(listOf(TextContent("Test result 2"))) + } + + // Remove the tools + val result = server.removeTools(listOf("test-tool-1", "test-tool-2")) + + // Verify the tools were removed + assertEquals(2, result, "Both tools should be removed") + } +}