From fab3ffac5713563f1e2c8793f13cd9f8128fa10e Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Tue, 22 Jul 2025 17:42:11 +0300 Subject: [PATCH 01/22] Introduce Kotlin<->TypeScript integration tests Signed-off-by: Sergey Karpov --- kotlin-sdk-test/build.gradle.kts | 8 + .../sdk/client/ClientIntegrationTest.kt | 41 -- .../sdk/integration/kotlin/KotlinTestBase.kt | 102 ++++ .../integration/kotlin/PromptEdgeCasesTest.kt | 412 ++++++++++++++ .../kotlin/PromptIntegrationTest.kt | 470 ++++++++++++++++ .../kotlin/ResourceEdgeCasesTest.kt | 285 ++++++++++ .../kotlin/ResourceIntegrationTest.kt | 94 ++++ .../integration/kotlin/ToolEdgeCasesTest.kt | 505 ++++++++++++++++++ .../integration/kotlin/ToolIntegrationTest.kt | 473 ++++++++++++++++ ...tlinClientTypeScriptServerEdgeCasesTest.kt | 258 +++++++++ .../KotlinClientTypeScriptServerTest.kt | 172 ++++++ .../TypeScriptClientKotlinServerTest.kt | 198 +++++++ .../typescript/TypeScriptEdgeCasesTest.kt | 183 +++++++ .../typescript/TypeScriptTestBase.kt | 165 ++++++ .../utils/KotlinServerForTypeScriptClient.kt | 424 +++++++++++++++ .../kotlin/sdk/integration/utils/Retry.kt | 97 ++++ .../kotlin/sdk/integration/utils/TestUtils.kt | 91 ++++ .../kotlin/sdk/integration/utils/myClient.ts | 108 ++++ 18 files changed, 4045 insertions(+), 41 deletions(-) delete mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index 62d9f365..9f87efd6 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -17,5 +17,13 @@ kotlin { implementation(libs.kotlinx.coroutines.test) } } + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(libs.kotlin.logging) + implementation(libs.ktor.server.cio) + implementation(libs.ktor.client.cio) + } + } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt deleted file mode 100644 index 562601aa..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt +++ /dev/null @@ -1,41 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.client - -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import kotlinx.coroutines.test.runTest -import kotlinx.io.asSink -import kotlinx.io.asSource -import kotlinx.io.buffered -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import java.net.Socket - -class ClientIntegrationTest { - - fun createTransport(): StdioClientTransport { - val socket = Socket("localhost", 3000) - - return StdioClientTransport( - socket.inputStream.asSource().buffered(), - socket.outputStream.asSink().buffered(), - ) - } - - @Disabled("This test requires a running server") - @Test - fun testRequestTools() = runTest { - val client = Client( - Implementation("test", "1.0"), - ) - - val transport = createTransport() - try { - client.connect(transport) - - val response: ListToolsResult = client.listTools() - println(response.tools) - } finally { - transport.close() - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt new file mode 100644 index 00000000..c367cf12 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -0,0 +1,102 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.install +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import kotlin.time.Duration.Companion.seconds +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +@Retry(times = 3) +abstract class KotlinTestBase { + + protected val host = "localhost" + protected abstract val port: Int + + protected lateinit var server: Server + protected lateinit var client: Client + protected lateinit var serverEngine: EmbeddedServer<*, *> + + protected abstract fun configureServerCapabilities(): ServerCapabilities + protected abstract fun configureServer() + + @BeforeEach + fun setUp() { + setupServer() + runBlocking { + setupClient() + } + } + + protected suspend fun setupClient() { + val transport = SseClientTransport( + HttpClient(CIO) { + install(SSE) + }, + "http://$host:$port", + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + + protected fun setupServer() { + val capabilities = configureServerCapabilities() + + server = Server( + Implementation(name = "test-server", version = "1.0"), + ServerOptions(capabilities = capabilities), + ) + + configureServer() + + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + install(ServerSSE) + routing { + mcp { server } + } + }.start(wait = false) + } + + @AfterEach + fun tearDown() { + // close client + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + // stop server + if (::serverEngine.isInitialized) { + try { + serverEngine.stop(500, 1000) + } catch (e: Exception) { + println("Warning: Error during server stop: ${e.message}") + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt new file mode 100644 index 00000000..f5e736d7 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt @@ -0,0 +1,412 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class PromptEdgeCasesTest : KotlinTestBase() { + + override val port = 3008 + + private val basicPromptName = "basic-prompt" + private val basicPromptDescription = "A basic prompt for testing" + + private val complexPromptName = "complex-prompt" + private val complexPromptDescription = "A complex prompt with many arguments" + + private val largePromptName = "large-prompt" + private val largePromptDescription = "A very large prompt for testing" + private val largePromptContent = "X".repeat(100_000) // 100KB of data + + private val specialCharsPromptName = "special-chars-prompt" + private val specialCharsPromptDescription = "A prompt with special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + server.addPrompt( + name = basicPromptName, + description = basicPromptDescription, + arguments = listOf( + PromptArgument( + name = "name", + description = "The name to greet", + required = true, + ), + ), + ) { request -> + val name = request.arguments?.get("name") ?: "World" + + GetPromptResult( + description = basicPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Hello, $name!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Greetings, $name! How can I assist you today?"), + ), + ), + ) + } + + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument(name = "arg1", description = "Argument 1", required = true), + PromptArgument(name = "arg2", description = "Argument 2", required = true), + PromptArgument(name = "arg3", description = "Argument 3", required = true), + PromptArgument(name = "arg4", description = "Argument 4", required = false), + PromptArgument(name = "arg5", description = "Argument 5", required = false), + PromptArgument(name = "arg6", description = "Argument 6", required = false), + PromptArgument(name = "arg7", description = "Argument 7", required = false), + PromptArgument(name = "arg8", description = "Argument 8", required = false), + PromptArgument(name = "arg9", description = "Argument 9", required = false), + PromptArgument(name = "arg10", description = "Argument 10", required = false), + ), + ) { request -> + // validate required arguments + val requiredArgs = listOf("arg1", "arg2", "arg3") + for (argName in requiredArgs) { + if (request.arguments?.get(argName) == null) { + throw IllegalArgumentException("Missing required argument: $argName") + } + } + + val args = mutableMapOf() + for (i in 1..10) { + val argName = "arg$i" + val argValue = request.arguments?.get(argName) + if (argValue != null) { + args[argName] = argValue + } + } + + GetPromptResult( + description = complexPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent( + text = "Arguments: ${ + args.entries.joinToString { + "${it.key}=${it.value}" + } + }", + ), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received ${args.size} arguments"), + ), + ), + ) + } + + // Very large prompt + server.addPrompt( + name = largePromptName, + description = largePromptDescription, + arguments = listOf( + PromptArgument( + name = "size", + description = "Size multiplier", + required = false, + ), + ), + ) { request -> + val size = request.arguments?.get("size")?.toIntOrNull() ?: 1 + val content = largePromptContent.repeat(size) + + GetPromptResult( + description = largePromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Generate a large response"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = content), + ), + ), + ) + } + + server.addPrompt( + name = specialCharsPromptName, + description = specialCharsPromptDescription, + arguments = listOf( + PromptArgument( + name = "special", + description = "Special characters to include", + required = false, + ), + ), + ) { request -> + val special = request.arguments?.get("special") ?: specialCharsContent + + GetPromptResult( + description = specialCharsPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Special characters: $special"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received special characters: $special"), + ), + ), + ) + } + } + + @Test + fun testBasicPrompt() { + runTest { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(basicPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertEquals("Hello, $testName!", userContent.text, "User message content should match") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + } + + @Test + fun testComplexPromptWithManyArguments() { + runTest { + val arguments = (1..10).associate { i -> "arg$i" to "value$i" } + + val result = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = arguments, + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(complexPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + + // verify all arguments + val text = userContent.text ?: "" + for (i in 1..10) { + assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") + } + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Received 10 arguments", + assistantContent.text, + "Assistant message should indicate 10 arguments", + ) + } + } + + @Test + fun testLargePrompt() { + runTest { + val result = client.getPrompt( + GetPromptRequest( + name = largePromptName, + arguments = mapOf("size" to "1"), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(largePromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val text = assistantContent.text ?: "" + assertEquals(100_000, text.length, "Assistant message should be 100KB in size") + } + } + + @Test + fun testSpecialCharacters() { + runTest { + val result = client.getPrompt( + GetPromptRequest( + name = specialCharsPromptName, + arguments = mapOf("special" to specialCharsContent), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = userContent.text ?: "" + assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = assistantContent.text ?: "" + assertTrue( + assistantText.contains(specialCharsContent), + "Assistant message should contain special characters", + ) + } + } + + @Test + fun testMissingRequiredArguments() { + runTest { + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf("arg4" to "value4", "arg5" to "value5"), + ), + ) + } + } + + assertTrue( + exception.message?.contains("arg1") == true || + exception.message?.contains("arg2") == true || + exception.message?.contains("arg3") == true || + exception.message?.contains("required") == true, + "Exception should mention missing required arguments", + ) + } + } + + @Test + fun testConcurrentPromptRequests() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val promptName = when (index % 4) { + 0 -> basicPromptName + 1 -> complexPromptName + 2 -> largePromptName + else -> specialCharsPromptName + } + + val arguments = when (promptName) { + basicPromptName -> mapOf("name" to "User$index") + complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") + largePromptName -> mapOf("size" to "1") + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.getPrompt( + GetPromptRequest( + name = promptName, + arguments = arguments, + ), + ) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") + } + } + } + + @Test + fun testNonExistentPrompt() { + runTest { + val nonExistentPromptName = "non-existent-prompt" + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = nonExistentPromptName, + arguments = mapOf("name" to "Test"), + ), + ) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("does not exist") == true || + exception.message?.contains("unknown") == true || + exception.message?.contains("error") == true, + "Exception should indicate prompt not found", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt new file mode 100644 index 00000000..a609c2ba --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt @@ -0,0 +1,470 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.ImageContent +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class PromptIntegrationTest : KotlinTestBase() { + + override val port = 3004 + private val testPromptName = "greeting" + private val testPromptDescription = "A simple greeting prompt" + private val complexPromptName = "multimodal-prompt" + private val complexPromptDescription = "A prompt with multiple content types" + private val conversationPromptName = "conversation" + private val conversationPromptDescription = "A prompt with multiple messages and roles" + private val strictPromptName = "strict-prompt" + private val strictPromptDescription = "A prompt with required arguments" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + // simple prompt with a name parameter + server.addPrompt( + name = testPromptName, + description = testPromptDescription, + arguments = listOf( + PromptArgument( + name = "name", + description = "The name to greet", + required = true, + ), + ), + ) { request -> + val name = request.arguments?.get("name") ?: "World" + + GetPromptResult( + description = testPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Hello, $name!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Greetings, $name! How can I assist you today?"), + ), + ), + ) + } + + // prompt with multiple content types + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument( + name = "topic", + description = "The topic to discuss", + required = false, + ), + PromptArgument( + name = "includeImage", + description = "Whether to include an image", + required = false, + ), + ), + ) { request -> + val topic = request.arguments?.get("topic") ?: "general knowledge" + val includeImage = request.arguments?.get("includeImage")?.toBoolean() ?: true + + val messages = mutableListOf() + + messages.add( + PromptMessage( + role = Role.user, + content = TextContent(text = "I'd like to discuss $topic."), + ), + ) + + val assistantContents = mutableListOf() + assistantContents.add(TextContent(text = "I'd be happy to discuss $topic with you.")) + + if (includeImage) { + assistantContents.add( + ImageContent( + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BmMIQAAAABJRU5ErkJggg==", + mimeType = "image/png", + ), + ) + } + + messages.add( + PromptMessage( + role = Role.assistant, + content = assistantContents[0], + ), + ) + + GetPromptResult( + description = complexPromptDescription, + messages = messages, + ) + } + + // prompt with multiple messages and roles + server.addPrompt( + name = conversationPromptName, + description = conversationPromptDescription, + arguments = listOf( + PromptArgument( + name = "topic", + description = "The topic of the conversation", + required = false, + ), + ), + ) { request -> + val topic = request.arguments?.get("topic") ?: "weather" + + GetPromptResult( + description = conversationPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Let's talk about the $topic."), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "Sure, I'd love to discuss the $topic. What would you like to know?", + ), + ), + PromptMessage( + role = Role.user, + content = TextContent(text = "What's your opinion on the $topic?"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "As an AI, I don't have personal opinions," + + " but I can provide information about $topic.", + ), + ), + PromptMessage( + role = Role.user, + content = TextContent(text = "That's helpful, thank you!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "You're welcome! Let me know if you have more questions about $topic.", + ), + ), + ), + ) + } + + // prompt with strict required arguments + server.addPrompt( + name = strictPromptName, + description = strictPromptDescription, + arguments = listOf( + PromptArgument( + name = "requiredArg1", + description = "First required argument", + required = true, + ), + PromptArgument( + name = "requiredArg2", + description = "Second required argument", + required = true, + ), + PromptArgument( + name = "optionalArg", + description = "Optional argument", + required = false, + ), + ), + ) { request -> + val args = request.arguments ?: emptyMap() + val arg1 = args["requiredArg1"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg1", + ) + val arg2 = args["requiredArg2"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg2", + ) + val optArg = args["optionalArg"] ?: "default" + + GetPromptResult( + description = strictPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Required arguments: $arg1, $arg2. Optional: $optArg"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "I received your arguments: $arg1, $arg2, and $optArg"), + ), + ), + ) + } + } + + @Test + fun testListPrompts() = runTest { + val result = client.listPrompts() + + assertNotNull(result, "List prompts result should not be null") + assertTrue(result.prompts.isNotEmpty(), "Prompts list should not be empty") + + val testPrompt = result.prompts.find { it.name == testPromptName } + assertNotNull(testPrompt, "Test prompt should be in the list") + assertEquals( + testPromptDescription, + testPrompt.description, + "Prompt description should match", + ) + + val arguments = testPrompt.arguments ?: error("Prompt arguments should not be null") + assertTrue(arguments.isNotEmpty(), "Prompt arguments should not be empty") + + val nameArg = arguments.find { it.name == "name" } + assertNotNull(nameArg, "Name argument should be in the list") + assertEquals( + "The name to greet", + nameArg.description, + "Argument description should match", + ) + assertEquals(true, nameArg.required, "Argument required flag should match") + } + + @Test + fun testGetPrompt() = runTest { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = testPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + testPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertNotNull(userContent.text, "User message text should not be null") + assertEquals( + "Hello, $testName!", + userContent.text, + "User message content should match", + ) + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertNotNull(assistantContent.text, "Assistant message text should not be null") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + + @Test + fun testMissingRequiredArguments() = runTest { + val promptsList = client.listPrompts() + assertNotNull(promptsList, "Prompts list should not be null") + val strictPrompt = promptsList.prompts.find { it.name == strictPromptName } + assertNotNull(strictPrompt, "Strict prompt should be in the list") + + val argsDef = strictPrompt.arguments ?: error("Prompt arguments should not be null") + val requiredArgs = argsDef.filter { it.required == true } + assertEquals( + 2, + requiredArgs.size, + "Strict prompt should have 2 required arguments", + ) + + // test missing required arg + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = mapOf("requiredArg1" to "value1"), + ), + ) + } + } + + assertEquals( + true, + exception.message?.contains("requiredArg2"), + "Exception should mention the missing argument", + ) + + // test with no args + val exception2 = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = emptyMap(), + ), + ) + } + } + + assertEquals( + exception2.message?.contains("requiredArg"), + true, + "Exception should mention a missing required argument", + ) + + // test with all required args + val result = client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = mapOf( + "requiredArg1" to "value1", + "requiredArg2" to "value2", + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = requireNotNull(userContent.text) + assertTrue(userText.contains("value1"), "Message should contain first argument") + assertTrue(userText.contains("value2"), "Message should contain second argument") + } + + @Test + fun testComplexContentTypes() = runTest { + val topic = "artificial intelligence" + val result = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf( + "topic" to topic, + "includeImage" to "true", + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + complexPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText2 = requireNotNull(userContent.text) + assertTrue(userText2.contains(topic), "User message should contain the topic") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = requireNotNull(assistantContent.text) + assertTrue( + assistantText.contains(topic), + "Assistant message should contain the topic", + ) + + val resultNoImage = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf( + "topic" to topic, + "includeImage" to "false", + ), + ), + ) + + assertNotNull(resultNoImage, "Get prompt result (no image) should not be null") + assertEquals(2, resultNoImage.messages.size, "Prompt should have 2 messages") + } + + @Test + fun testMultipleMessagesAndRoles() = runTest { + val topic = "climate change" + val result = client.getPrompt( + GetPromptRequest( + name = conversationPromptName, + arguments = mapOf("topic" to topic), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + conversationPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(6, result.messages.size, "Prompt should have 6 messages") + + val userMessages = result.messages.filter { it.role == Role.user } + val assistantMessages = result.messages.filter { it.role == Role.assistant } + + assertEquals(3, userMessages.size, "Should have 3 user messages") + assertEquals(3, assistantMessages.size, "Should have 3 assistant messages") + + for (i in 0 until result.messages.size) { + val expectedRole = if (i % 2 == 0) Role.user else Role.assistant + assertEquals( + expectedRole, + result.messages[i].role, + "Message $i should have role $expectedRole", + ) + } + + for (message in result.messages) { + val content = message.content as? TextContent + assertNotNull(content, "Message content should be TextContent") + val text = requireNotNull(content.text) + + // Either the message contains the topic or it's a generic conversation message + val containsTopic = text.contains(topic) + val isGenericMessage = text.contains("thank you") || text.contains("welcome") + + assertTrue( + containsTopic || isGenericMessage, + "Message should either contain the topic or be a generic conversation message", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt new file mode 100644 index 00000000..232ac025 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt @@ -0,0 +1,285 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ResourceEdgeCasesTest : KotlinTestBase() { + + override val port = 3007 + + private val testResourceUri = "test://example.txt" + private val testResourceName = "Test Resource" + private val testResourceDescription = "A test resource for integration testing" + private val testResourceContent = "This is the content of the test resource." + + private val binaryResourceUri = "test://image.png" + private val binaryResourceName = "Binary Resource" + private val binaryResourceDescription = "A binary resource for testing" + private val binaryResourceContent = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + + private val largeResourceUri = "test://large.txt" + private val largeResourceName = "Large Resource" + private val largeResourceDescription = "A large text resource for testing" + private val largeResourceContent = "X".repeat(100_000) // 100KB of data + + private val dynamicResourceUri = "test://dynamic.txt" + private val dynamicResourceName = "Dynamic Resource" + private val dynamicResourceContent = AtomicBoolean(false) + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = binaryResourceUri, + name = binaryResourceName, + description = binaryResourceDescription, + mimeType = "image/png", + ) { request -> + ReadResourceResult( + contents = listOf( + BlobResourceContents( + blob = binaryResourceContent, + uri = request.uri, + mimeType = "image/png", + ), + ), + ) + } + + server.addResource( + uri = largeResourceUri, + name = largeResourceName, + description = largeResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = largeResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = dynamicResourceUri, + name = dynamicResourceName, + description = "A resource that can be updated", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = if (dynamicResourceContent.get()) "Updated content" else "Original content", + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> + EmptyRequestResult() + } + + server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> + EmptyRequestResult() + } + } + + @Test + fun testBinaryResource() { + runTest { + val result = client.readResource(ReadResourceRequest(uri = binaryResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? BlobResourceContents + assertNotNull(content, "Resource content should be BlobResourceContents") + assertEquals(binaryResourceContent, content.blob, "Binary resource content should match") + assertEquals("image/png", content.mimeType, "MIME type should match") + } + } + + @Test + fun testLargeResource() { + runTest { + val result = client.readResource(ReadResourceRequest(uri = largeResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(100_000, content.text.length, "Large resource content length should match") + assertEquals("X".repeat(100_000), content.text, "Large resource content should match") + } + } + + @Test + fun testInvalidResourceUri() { + runTest { + val invalidUri = "test://nonexistent.txt" + + val exception = assertThrows { + runBlocking { + client.readResource(ReadResourceRequest(uri = invalidUri)) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("invalid") == true || + exception.message?.contains("error") == true, + "Exception should indicate resource not found or invalid URI", + ) + } + } + + @Test + fun testDynamicResource() { + runTest { + val initialResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(initialResult, "Initial read result should not be null") + val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Original content", initialContent, "Initial content should match") + + // update resource + dynamicResourceContent.set(true) + + val updatedResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(updatedResult, "Updated read result should not be null") + val updatedContent = (updatedResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Updated content", updatedContent, "Updated content should match") + } + } + + @Test + fun testResourceAddAndRemove() { + runTest { + val initialList = client.listResources() + assertNotNull(initialList, "Initial list result should not be null") + val initialCount = initialList.resources.size + + val newResourceUri = "test://new-resource.txt" + server.addResource( + uri = newResourceUri, + name = "New Resource", + description = "A newly added resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "New resource content", + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + val updatedList = client.listResources() + assertNotNull(updatedList, "Updated list result should not be null") + val updatedCount = updatedList.resources.size + + assertEquals(initialCount + 1, updatedCount, "Resource count should increase by 1") + val newResource = updatedList.resources.find { it.uri == newResourceUri } + assertNotNull(newResource, "New resource should be in the list") + + server.removeResource(newResourceUri) + + val finalList = client.listResources() + assertNotNull(finalList, "Final list result should not be null") + val finalCount = finalList.resources.size + + assertEquals(initialCount, finalCount, "Resource count should return to initial value") + val removedResource = finalList.resources.find { it.uri == newResourceUri } + assertEquals(null, removedResource, "Resource should be removed from the list") + } + } + + @Test + fun testConcurrentResourceOperations() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val uri = when (index % 3) { + 0 -> testResourceUri + 1 -> binaryResourceUri + else -> largeResourceUri + } + + val result = client.readResource(ReadResourceRequest(uri = uri)) + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") + } + } + } + + @Test + fun testSubscribeAndUnsubscribe() { + runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt new file mode 100644 index 00000000..c467b2a1 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt @@ -0,0 +1,94 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ResourceIntegrationTest : KotlinTestBase() { + + override val port = 3005 + private val testResourceUri = "test://example.txt" + private val testResourceName = "Test Resource" + private val testResourceDescription = "A test resource for integration testing" + private val testResourceContent = "This is the content of the test resource." + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> + EmptyRequestResult() + } + + server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> + EmptyRequestResult() + } + } + + @Test + fun testListResources() = runTest { + val result = client.listResources() + + assertNotNull(result, "List resources result should not be null") + assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") + + val testResource = result.resources.find { it.uri == testResourceUri } + assertNotNull(testResource, "Test resource should be in the list") + assertEquals(testResourceName, testResource.name, "Resource name should match") + assertEquals(testResourceDescription, testResource.description, "Resource description should match") + } + + @Test + fun testReadResource() = runTest { + val result = client.readResource(ReadResourceRequest(uri = testResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(testResourceContent, content.text, "Resource content should match") + } + + @Test + fun testSubscribeAndUnsubscribe() { + runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt new file mode 100644 index 00000000..0cb8c506 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -0,0 +1,505 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ToolEdgeCasesTest : KotlinTestBase() { + + override val port = 3009 + + private val basicToolName = "basic-tool" + private val basicToolDescription = "A basic tool for testing" + + private val complexToolName = "complex-tool" + private val complexToolDescription = "A complex tool with nested schema" + + private val largeToolName = "large-tool" + private val largeToolDescription = "A tool that returns a large response" + private val largeToolContent = "X".repeat(100_000) // 100KB of data + + private val slowToolName = "slow-tool" + private val slowToolDescription = "A tool that takes time to respond" + + private val specialCharsToolName = "special-chars-tool" + private val specialCharsToolDescription = "A tool that handles special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + server.addTool( + name = basicToolName, + description = basicToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + + server.addTool( + name = complexToolName, + description = complexToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "user", + buildJsonObject { + put("type", "object") + put("description", "User information") + put( + "properties", + buildJsonObject { + put( + "name", + buildJsonObject { + put("type", "string") + put("description", "User's name") + }, + ) + put( + "age", + buildJsonObject { + put("type", "integer") + put("description", "User's age") + }, + ) + put( + "address", + buildJsonObject { + put("type", "object") + put("description", "User's address") + put( + "properties", + buildJsonObject { + put( + "street", + buildJsonObject { + put("type", "string") + }, + ) + put( + "city", + buildJsonObject { + put("type", "string") + }, + ) + put( + "country", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + ) + }, + ) + }, + ) + put( + "options", + buildJsonObject { + put("type", "array") + put("description", "Additional options") + put( + "items", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + required = listOf("user"), + ), + ) { request -> + val user = request.arguments["user"] as? JsonObject + val name = (user?.get("name") as? JsonPrimitive)?.content ?: "Unknown" + val age = (user?.get("age") as? JsonPrimitive)?.content?.toIntOrNull() ?: 0 + + val address = user?.get("address") as? JsonObject + val street = (address?.get("street") as? JsonPrimitive)?.content ?: "Unknown" + val city = (address?.get("city") as? JsonPrimitive)?.content ?: "Unknown" + val country = (address?.get("country") as? JsonPrimitive)?.content ?: "Unknown" + + val options = (request.arguments["options"] as? JsonArray)?.mapNotNull { + (it as? JsonPrimitive)?.content + } ?: emptyList() + + val summary = + "User: $name, Age: $age, Address: $street, $city, $country, Options: ${options.joinToString(", ")}" + + CallToolResult( + content = listOf(TextContent(text = summary)), + structuredContent = buildJsonObject { + put("name", name) + put("age", age) + put( + "address", + buildJsonObject { + put("street", street) + put("city", city) + put("country", country) + }, + ) + put( + "options", + buildJsonArray { + options.forEach { add(it) } + }, + ) + }, + ) + } + + server.addTool( + name = largeToolName, + description = largeToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "size", + buildJsonObject { + put("type", "integer") + put("description", "Size multiplier") + }, + ) + }, + ), + ) { request -> + val size = (request.arguments["size"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1 + val content = largeToolContent.take(largeToolContent.length.coerceAtMost(size * 1000)) + + CallToolResult( + content = listOf(TextContent(text = content)), + structuredContent = buildJsonObject { + put("size", content.length) + }, + ) + } + + server.addTool( + name = slowToolName, + description = slowToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "delay", + buildJsonObject { + put("type", "integer") + put("description", "Delay in milliseconds") + }, + ) + }, + ), + ) { request -> + val delay = (request.arguments["delay"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 + + // simulate slow operation + runBlocking { + delay(delay.toLong()) + } + + CallToolResult( + content = listOf(TextContent(text = "Completed after ${delay}ms delay")), + structuredContent = buildJsonObject { + put("delay", delay) + }, + ) + } + + server.addTool( + name = specialCharsToolName, + description = specialCharsToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "special", + buildJsonObject { + put("type", "string") + put("description", "Special characters to process") + }, + ) + }, + ), + ) { request -> + val special = (request.arguments["special"] as? JsonPrimitive)?.content ?: specialCharsContent + + CallToolResult( + content = listOf(TextContent(text = "Received special characters: $special")), + structuredContent = buildJsonObject { + put("special", special) + put("length", special.length) + }, + ) + } + } + + @Test + fun testBasicTool() { + runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) + + val result = client.callTool(basicToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "result", testText) + } + } + + @Test + fun testComplexNestedSchema() { + runTest { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Doe")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } + + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } + + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } + + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains("John Doe"), "Result should contain the name") + assertTrue(text.contains("30"), "Result should contain the age") + assertTrue(text.contains("123 Main St"), "Result should contain the street") + assertTrue(text.contains("New York"), "Result should contain the city") + assertTrue(text.contains("USA"), "Result should contain the country") + assertTrue(text.contains("option1"), "Result should contain option1") + assertTrue(text.contains("option2"), "Result should contain option2") + assertTrue(text.contains("option3"), "Result should contain option3") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "name", "John Doe") + assertJsonProperty(structuredContent, "age", 30) + + val address = structuredContent["address"] as? JsonObject + assertNotNull(address, "Address should be present in structured content") + assertJsonProperty(address, "street", "123 Main St") + assertJsonProperty(address, "city", "New York") + assertJsonProperty(address, "country", "USA") + + val options = structuredContent["options"] as? JsonArray + assertNotNull(options, "Options should be present in structured content") + assertEquals(3, options.size, "Options should have 3 items") + } + } + + @Test + fun testLargeResponse() { + runTest { + val size = 10 + val arguments = mapOf("size" to size) + + val result = client.callTool(largeToolName, arguments) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertEquals(10000, text.length, "Response should be 10KB in size") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "size", 10000) + } + } + + @Test + fun testSlowTool() { + runTest { + val delay = 500 + val arguments = mapOf("delay" to delay) + + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) + val endTime = System.currentTimeMillis() + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains("${delay}ms"), "Result should mention the delay") + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "delay", delay) + } + } + + @Test + fun testSpecialCharacters() { + runTest { + val arguments = mapOf("special" to specialCharsContent) + + val result = client.callTool(specialCharsToolName, arguments) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + + val structuredContent = toolResult.structuredContent as JsonObject + val special = structuredContent["special"]?.toString()?.trim('"') + + assertNotNull(special, "Special characters should be in structured content") + assertTrue(text.contains(specialCharsContent), "Special characters should be in the content") + } + } + + @Test + fun testConcurrentToolCalls() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val toolName = when (index % 5) { + 0 -> basicToolName + 1 -> complexToolName + 2 -> largeToolName + 3 -> slowToolName + else -> specialCharsToolName + } + + val arguments = when (toolName) { + basicToolName -> mapOf("text" to "Concurrent call $index") + + complexToolName -> mapOf( + "user" to mapOf( + "name" to "User $index", + "age" to 20 + index, + "address" to mapOf( + "street" to "Street $index", + "city" to "City $index", + "country" to "Country $index", + ), + ), + ) + + largeToolName -> mapOf("size" to 1) + + slowToolName -> mapOf("delay" to 100) + + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.callTool(toolName, arguments) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.content.isNotEmpty(), "Result content should not be empty") + } + } + } + + @Test + fun testNonExistentTool() { + runTest { + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("text" to "Test") + + val exception = assertThrows { + runBlocking { + client.callTool(nonExistentToolName, arguments) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("does not exist") == true || + exception.message?.contains("unknown") == true || + exception.message?.contains("error") == true, + "Exception should indicate tool not found", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt new file mode 100644 index 00000000..c6262a13 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -0,0 +1,473 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.ImageContent +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ToolIntegrationTest : KotlinTestBase() { + + override val port = 3006 + private val testToolName = "echo" + private val testToolDescription = "A simple echo tool that returns the input text" + private val complexToolName = "calculator" + private val complexToolDescription = "A calculator tool that performs operations on numbers" + private val errorToolName = "error-tool" + private val errorToolDescription = "A tool that demonstrates error handling" + private val multiContentToolName = "multi-content" + private val multiContentToolDescription = "A tool that returns multiple content types" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + setupEchoTool() + setupCalculatorTool() + setupErrorHandlingTool() + setupMultiContentTool() + } + + private fun setupEchoTool() { + server.addTool( + name = testToolName, + description = testToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + } + + private fun setupCalculatorTool() { + server.addTool( + name = complexToolName, + description = complexToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "operation", + buildJsonObject { + put("type", "string") + put("description", "The operation to perform (add, subtract, multiply, divide)") + put( + "enum", + buildJsonArray { + add("add") + add("subtract") + add("multiply") + add("divide") + }, + ) + }, + ) + put( + "a", + buildJsonObject { + put("type", "number") + put("description", "First operand") + }, + ) + put( + "b", + buildJsonObject { + put("type", "number") + put("description", "Second operand") + }, + ) + put( + "precision", + buildJsonObject { + put("type", "integer") + put("description", "Number of decimal places (optional)") + put("default", 2) + }, + ) + put( + "showSteps", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to show calculation steps") + put("default", false) + }, + ) + put( + "tags", + buildJsonObject { + put("type", "array") + put("description", "Optional tags for the calculation") + put( + "items", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + required = listOf("operation", "a", "b"), + ), + ) { request -> + val operation = (request.arguments["operation"] as? JsonPrimitive)?.content ?: "add" + val a = (request.arguments["a"] as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val b = (request.arguments["b"] as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val precision = (request.arguments["precision"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 2 + val showSteps = (request.arguments["showSteps"] as? JsonPrimitive)?.content?.toBoolean() ?: false + val tags = (request.arguments["tags"] as? JsonArray)?.mapNotNull { + (it as? JsonPrimitive)?.content + } ?: emptyList() + + val result = when (operation) { + "add" -> a + b + "subtract" -> a - b + "multiply" -> a * b + "divide" -> if (b != 0.0) a / b else Double.POSITIVE_INFINITY + else -> 0.0 + } + + val formattedResult = "%.${precision}f".format(result) + + val textContent = if (showSteps) { + "Operation: $operation\nA: $a\nB: $b\nResult: $formattedResult\nTags: ${ + tags.joinToString(", ") + }" + } else { + "Result: $formattedResult" + } + + CallToolResult( + content = listOf(TextContent(text = textContent)), + structuredContent = buildJsonObject { + put("operation", operation) + put("a", a) + put("b", b) + put("result", result) + put("formattedResult", formattedResult) + put("precision", precision) + put("tags", buildJsonArray { tags.forEach { add(it) } }) + }, + ) + } + } + + private fun setupErrorHandlingTool() { + server.addTool( + name = errorToolName, + description = errorToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "errorType", + buildJsonObject { + put("type", "string") + put("description", "Type of error to simulate (none, exception, error)") + put( + "enum", + buildJsonArray { + add("none") + add("exception") + add("error") + }, + ) + }, + ) + put( + "message", + buildJsonObject { + put("type", "string") + put("description", "Custom error message") + put("default", "An error occurred") + }, + ) + }, + required = listOf("errorType"), + ), + ) { request -> + val errorType = (request.arguments["errorType"] as? JsonPrimitive)?.content ?: "none" + val message = (request.arguments["message"] as? JsonPrimitive)?.content ?: "An error occurred" + + when (errorType) { + "exception" -> throw IllegalArgumentException(message) + + "error" -> CallToolResult( + content = listOf(TextContent(text = "Error: $message")), + structuredContent = buildJsonObject { + put("error", true) + put("message", message) + }, + ) + + else -> CallToolResult( + content = listOf(TextContent(text = "No error occurred")), + structuredContent = buildJsonObject { + put("error", false) + put("message", "Success") + }, + ) + } + } + } + + private fun setupMultiContentTool() { + server.addTool( + name = multiContentToolName, + description = multiContentToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Text to include in the response") + }, + ) + put( + "includeImage", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to include an image in the response") + put("default", true) + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "Default text" + val includeImage = (request.arguments["includeImage"] as? JsonPrimitive)?.content?.toBoolean() ?: true + + val content = mutableListOf( + TextContent(text = "Text content: $text"), + ) + + if (includeImage) { + content.add( + ImageContent( + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==", + mimeType = "image/png", + ), + ) + } + + CallToolResult( + content = content, + structuredContent = buildJsonObject { + put("text", text) + put("includeImage", includeImage) + }, + ) + } + } + + @Test + fun testListTools() = runTest { + val result = client.listTools() + + assertNotNull(result, "List utils result should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + val testTool = result.tools.find { it.name == testToolName } + assertNotNull(testTool, "Test tool should be in the list") + assertEquals( + testToolDescription, + testTool.description, + "Tool description should match", + ) + } + + @Test + fun testCallTool() = runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) + + val result = client.callTool(testToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "result", testText) + } + + @Test + fun testComplexInputSchemaTool() { + runTest { + val toolsList = client.listTools() + assertNotNull(toolsList, "Tools list should not be null") + val calculatorTool = toolsList.tools.find { it.name == complexToolName } + assertNotNull(calculatorTool, "Calculator tool should be in the list") + + val arguments = mapOf( + "operation" to "multiply", + "a" to 5.5, + "b" to 2.0, + "precision" to 3, + "showSteps" to true, + "tags" to listOf("test", "calculator", "integration"), + ) + + val result = client.callTool(complexToolName, arguments) + + val toolResult = assertCallToolResult(result) + + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val contentText = requireNotNull(content.text) + + assertTrue(contentText.contains("Operation"), "Result should contain operation") + assertTrue( + contentText.contains("multiply"), + "Result should contain multiply operation", + ) + assertTrue(contentText.contains("5.5"), "Result should contain first operand") + assertTrue(contentText.contains("2.0"), "Result should contain second operand") + assertTrue(contentText.contains("11"), "Result should contain result value") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "operation", "multiply") + assertJsonProperty(structuredContent, "result", 11.0) + + val formattedResult = structuredContent["formattedResult"]?.toString()?.trim('"') ?: "" + assertTrue( + formattedResult == "11.000" || formattedResult == "11,000", + "Formatted result should be either '11.000' or '11,000', but was '$formattedResult'", + ) + assertJsonProperty(structuredContent, "precision", 3) + + val tags = structuredContent["tags"] as? JsonArray + assertNotNull(tags, "Tags should be present") + } + } + + @Test + fun testToolErrorHandling() = runTest { + val successArgs = mapOf("errorType" to "none") + val successResult = client.callTool(errorToolName, successArgs) + + val successToolResult = assertCallToolResult(successResult, "No error: ") + assertTextContent(successToolResult.content.firstOrNull(), "No error occurred") + + val noErrorStructured = successToolResult.structuredContent as JsonObject + assertJsonProperty(noErrorStructured, "error", false) + + val errorArgs = mapOf( + "errorType" to "error", + "message" to "Custom error message", + ) + val errorResult = client.callTool(errorToolName, errorArgs) + + val errorToolResult = assertCallToolResult(errorResult, "Error: ") + assertTextContent(errorToolResult.content.firstOrNull(), "Error: Custom error message") + + val errorStructured = errorToolResult.structuredContent as JsonObject + assertJsonProperty(errorStructured, "error", true) + assertJsonProperty(errorStructured, "message", "Custom error message") + + val exceptionArgs = mapOf( + "errorType" to "exception", + "message" to "Exception message", + ) + + val exception = assertThrows { + runBlocking { + client.callTool(errorToolName, exceptionArgs) + } + } + + assertEquals( + exception.message?.contains("Exception message"), + true, + "Exception message should contain 'Exception message'", + ) + } + + @Test + fun testMultiContentTool() = runTest { + val testText = "Test multi-content" + val arguments = mapOf( + "text" to testText, + "includeImage" to true, + ) + + val result = client.callTool(multiContentToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertEquals( + 2, + toolResult.content.size, + "Tool result should have 2 content items", + ) + + val textContent = toolResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Result should contain TextContent") + assertNotNull(textContent.text, "Text content should not be null") + assertEquals( + "Text content: $testText", + textContent.text, + "Text content should match", + ) + + val imageContent = toolResult.content.firstOrNull { it is ImageContent } as? ImageContent + assertNotNull(imageContent, "Result should contain ImageContent") + assertEquals("image/png", imageContent.mimeType, "Image MIME type should match") + assertTrue(imageContent.data.isNotEmpty(), "Image data should not be empty") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "text", testText) + assertJsonProperty(structuredContent, "includeImage", true) + + val textOnlyArgs = mapOf( + "text" to testText, + "includeImage" to false, + ) + + val textOnlyResult = client.callTool(multiContentToolName, textOnlyArgs) + + val textOnlyToolResult = assertCallToolResult(textOnlyResult, "Text-only: ") + assertEquals( + 1, + textOnlyToolResult.content.size, + "Text-only result should have 1 content item", + ) + + assertTextContent(textOnlyToolResult.content.firstOrNull(), "Text content: $testText") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt new file mode 100644 index 00000000..3905a20a --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -0,0 +1,258 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + + private lateinit var client: Client + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDown() { + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentTool() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) + } + + val errorMessage = exception.message ?: "" + assertTrue( + errorMessage.contains("not found") || + errorMessage.contains("unknown") || + errorMessage.contains("error"), + "Exception should indicate tool not found: $errorMessage", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersInArguments() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayload() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testConcurrentRequests() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val concurrentCount = 5 + val results = mutableListOf>() + + for (i in 1..concurrentCount) { + val deferred = async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text ?: "" + } + results.add(deferred) + } + + val responses = results.awaitAll() + + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testInvalidArguments() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + try { + val result = client.callTool("greet", invalidArguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + } catch (e: Exception) { + assertTrue( + e.message?.contains("invalid") == true || + e.message?.contains("error") == true, + "Exception should indicate invalid arguments: ${e.message}", + ) + } + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleToolCalls() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt new file mode 100644 index 00000000..f4cf8ffc --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -0,0 +1,172 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + + private lateinit var client: Client + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDown() { + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testKotlinClientConnectsToTypeScriptServer() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + assertNotNull(client, "Client should be initialized") + + val pingResult = client.ping() + assertNotNull(pingResult, "Ping result should not be null") + + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testListTools() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + // Verify specific utils are available + val toolNames = result.tools.map { it.name } + assertTrue(toolNames.contains("greet"), "Greet tool should be available") + assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") + assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") + + println("Available utils: ${toolNames.joinToString()}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testToolCall() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val testName = "TestUser" + val arguments = mapOf("name" to testName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals( + "Hello, $testName!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleClients() { + runBlocking { + withContext(Dispatchers.IO) { + // First client connection + val client1 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools1 = client1.listTools() + assertNotNull(tools1, "Tools list for first client should not be null") + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + + val client2 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools2 = client2.listTools() + assertNotNull(tools2, "Tools list for second client should not be null") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + + assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") + assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") + assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") + assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") + + client1.close() + client2.close() + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt new file mode 100644 index 00000000..e88b0a4c --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -0,0 +1,198 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertTrue + +class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { + + private var port: Int = 0 + private lateinit var serverUrl: String + private var httpServer: KotlinServerForTypeScriptClient? = null + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://localhost:$port/mcp" + killProcessOnPort(port) + httpServer = KotlinServerForTypeScriptClient() + httpServer?.start(port) + if (!waitForPort(port = port)) { + throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") + } + println("Kotlin server started on port $port") + } + + @AfterEach + fun tearDown() { + try { + httpServer?.stop() + println("HTTP server stopped") + } catch (e: Exception) { + println("Error during server shutdown: ${e.message}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testToolCall() { + val testName = "TestUser" + val command = "npx tsx myClient.ts $serverUrl greet $testName" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Hello, $testName!"), + "Tool response should contain the greeting with the provided name", + ) + assertTrue(output.contains("Tool result:"), "Output should indicate a successful tool call") + assertTrue(output.contains("Text content:"), "Output should contain the text content section") + assertTrue(output.contains("Structured content:"), "Output should contain the structured content section") + assertTrue( + output.contains("\"greeting\": \"Hello, $testName!\""), + "Structured content should contain the greeting", + ) + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testToolCallWithSessionManagement() { + val testName = "SessionTest" + val command = "npx tsx myClient.ts $serverUrl greet $testName" + val output = executeCommand(command, tsClientDir) + + assertTrue(output.contains("Connected to server"), "Client should connect to server") + assertTrue( + output.contains("Hello, $testName!"), + "Tool response should contain the greeting with the provided name", + ) + assertTrue(output.contains("Tool result:"), "Output should indicate a successful tool call") + assertTrue(output.contains("Disconnected from server"), "Client should disconnect cleanly") + + val multiGreetName = "NotificationTest" + val multiGreetCommand = "npx tsx myClient.ts $serverUrl multi-greet $multiGreetName" + val multiGreetOutput = executeCommand(multiGreetCommand, tsClientDir) + + assertTrue(multiGreetOutput.contains("Connected to server"), "Client should connect to server") + assertTrue( + multiGreetOutput.contains("Multiple greetings") || multiGreetOutput.contains("greeting"), + "Tool response should contain greeting message", + ) + assertTrue(multiGreetOutput.contains("Disconnected from server"), "Client should disconnect cleanly") + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun testMultipleClientSequence() { + val testName1 = "FirstClient" + val command1 = "npx tsx myClient.ts $serverUrl greet $testName1" + val output1 = executeCommand(command1, tsClientDir) + + assertTrue(output1.contains("Connected to server"), "First client should connect to server") + assertTrue(output1.contains("Hello, $testName1!"), "Tool response should contain the greeting for first client") + assertTrue(output1.contains("Disconnected from server"), "First client should disconnect cleanly") + + val testName2 = "SecondClient" + val command2 = "npx tsx myClient.ts $serverUrl multi-greet $testName2" + val output2 = executeCommand(command2, tsClientDir) + + assertTrue(output2.contains("Connected to server"), "Second client should connect to server") + assertTrue( + output2.contains("Multiple greetings") || output2.contains("greeting"), + "Tool response should contain greeting message", + ) + assertTrue(output2.contains("Disconnected from server"), "Second client should disconnect cleanly") + + val command3 = "npx tsx myClient.ts $serverUrl" + val output3 = executeCommand(command3, tsClientDir) + + assertTrue(output3.contains("Connected to server"), "Third client should connect to server") + assertTrue(output3.contains("Available utils:"), "Third client should list available utils") + assertTrue(output3.contains("greet"), "Greet tool should be available to third client") + assertTrue(output3.contains("multi-greet"), "Multi-greet tool should be available to third client") + assertTrue(output3.contains("Disconnected from server"), "Third client should disconnect cleanly") + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleClientParallel() { + val clientCount = 3 + val clients = listOf( + "FirstClient" to "greet", + "SecondClient" to "multi-greet", + "ThirdClient" to "", + ) + + val threads = mutableListOf() + val outputs = mutableListOf>() + val exceptions = mutableListOf() + + for (i in 0 until clientCount) { + val (clientName, toolName) = clients[i] + val thread = Thread { + try { + val command = if (toolName.isEmpty()) { + "npx tsx myClient.ts $serverUrl" + } else { + "npx tsx myClient.ts $serverUrl $toolName $clientName" + } + + val output = executeCommand(command, tsClientDir) + synchronized(outputs) { + outputs.add(i to output) + } + } catch (e: Exception) { + synchronized(exceptions) { + exceptions.add(e) + } + } + } + threads.add(thread) + thread.start() + Thread.sleep(500) + } + + threads.forEach { it.join() } + + if (exceptions.isNotEmpty()) { + println( + "Exceptions occurred in parallel clients: ${ + exceptions.joinToString { + it.message ?: it.toString() + } + }", + ) + } + + val sortedOutputs = outputs.sortedBy { it.first }.map { it.second } + + sortedOutputs.forEachIndexed { index, output -> + val clientName = clients[index].first + val toolName = clients[index].second + + when (toolName) { + "greet" -> { + val containsGreeting = output.contains("Hello, $clientName!") || + output.contains("\"greeting\": \"Hello, $clientName!\"") + assertTrue( + containsGreeting, + "Tool response should contain the greeting for $clientName", + ) + } + + "multi-greet" -> { + val containsGreeting = output.contains("Multiple greetings") || + output.contains("greeting") || + output.contains("greet") + assertTrue( + containsGreeting, + "Tool response should contain greeting message for $clientName", + ) + } + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt new file mode 100644 index 00000000..77241281 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -0,0 +1,183 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.io.File +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TypeScriptEdgeCasesTest : TypeScriptTestBase() { + + private var port: Int = 0 + private lateinit var serverUrl: String + private var httpServer: KotlinServerForTypeScriptClient? = null + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://localhost:$port/mcp" + killProcessOnPort(port) + httpServer = KotlinServerForTypeScriptClient() + httpServer?.start(port) + if (!waitForPort(port = port)) { + throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") + } + println("Kotlin server started on port $port") + } + + @AfterEach + fun tearDown() { + try { + httpServer?.stop() + println("HTTP server stopped") + } catch (e: Exception) { + println("Error during server shutdown: ${e.message}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testErrorHandling() { + val nonExistentToolCommand = "npx tsx myClient.ts $serverUrl non-existent-tool" + val nonExistentToolOutput = executeCommandAllowingFailure(nonExistentToolCommand, tsClientDir) + + assertTrue( + nonExistentToolOutput.contains("Tool \"non-existent-tool\" not found"), + "Client should handle non-existent tool gracefully", + ) + + val invalidUrlCommand = "npx tsx myClient.ts http://localhost:${port + 1000}/mcp greet TestUser" + val invalidUrlOutput = executeCommandAllowingFailure(invalidUrlCommand, tsClientDir) + + assertTrue( + invalidUrlOutput.contains("Error:") || invalidUrlOutput.contains("ECONNREFUSED"), + "Client should handle connection errors gracefully", + ) + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharacters() { + val specialChars = "!@#$+-[].,?" + + val tempFile = File.createTempFile("special_chars", ".txt") + tempFile.writeText(specialChars) + tempFile.deleteOnExit() + + val specialCharsContent = tempFile.readText() + val specialCharsCommand = "npx tsx myClient.ts $serverUrl greet \"$specialCharsContent\"" + val specialCharsOutput = executeCommand(specialCharsCommand, tsClientDir) + + assertTrue( + specialCharsOutput.contains("Hello, $specialChars!"), + "Tool should handle special characters in arguments", + ) + assertTrue( + specialCharsOutput.contains("Disconnected from server"), + "Client should disconnect cleanly after handling special characters", + ) + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayload() { + val largeName = "A".repeat(10 * 1024) + + val tempFile = File.createTempFile("large_name", ".txt") + tempFile.writeText(largeName) + tempFile.deleteOnExit() + + val largeNameContent = tempFile.readText() + val largePayloadCommand = "npx tsx myClient.ts $serverUrl greet \"$largeNameContent\"" + val largePayloadOutput = executeCommand(largePayloadCommand, tsClientDir) + + tempFile.delete() + + assertTrue( + largePayloadOutput.contains("Hello,") && largePayloadOutput.contains("A".repeat(20)), + "Tool should handle large payloads", + ) + assertTrue( + largePayloadOutput.contains("Disconnected from server"), + "Client should disconnect cleanly after handling large payload", + ) + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testComplexConcurrentRequests() { + val commands = listOf( + "npx tsx myClient.ts $serverUrl greet \"Client1\"", + "npx tsx myClient.ts $serverUrl multi-greet \"Client2\"", + "npx tsx myClient.ts $serverUrl greet \"Client3\"", + "npx tsx myClient.ts $serverUrl", + "npx tsx myClient.ts $serverUrl multi-greet \"Client5\"", + ) + + val threads = commands.mapIndexed { index, command -> + Thread { + println("Starting client $index") + val output = executeCommand(command, tsClientDir) + println("Client $index completed") + + assertTrue( + output.contains("Connected to server"), + "Client $index should connect to server", + ) + assertTrue( + output.contains("Disconnected from server"), + "Client $index should disconnect cleanly", + ) + + when { + command.contains("greet \"Client1\"") -> + assertTrue(output.contains("Hello, Client1!"), "Client 1 should receive correct greeting") + + command.contains("multi-greet \"Client2\"") -> + assertTrue(output.contains("Multiple greetings"), "Client 2 should receive multiple greetings") + + command.contains("greet \"Client3\"") -> + assertTrue(output.contains("Hello, Client3!"), "Client 3 should receive correct greeting") + + !command.contains("greet") && !command.contains("multi-greet") -> + assertTrue(output.contains("Available utils:"), "Client 4 should list available tools") + + command.contains("multi-greet \"Client5\"") -> + assertTrue(output.contains("Multiple greetings"), "Client 5 should receive multiple greetings") + } + }.apply { start() } + } + + threads.forEach { it.join() } + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun testRapidSequentialRequests() { + val outputs = (1..10).map { i -> + val command = "npx tsx myClient.ts $serverUrl greet \"RapidClient$i\"" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Connected to server"), + "Client $i should connect to server", + ) + assertTrue( + output.contains("Hello, RapidClient$i!"), + "Client $i should receive correct greeting", + ) + assertTrue( + output.contains("Disconnected from server"), + "Client $i should disconnect cleanly", + ) + + output + } + + assertEquals(10, outputs.size, "All 10 rapid requests should complete successfully") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt new file mode 100644 index 00000000..5210340c --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -0,0 +1,165 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import org.junit.jupiter.api.BeforeAll +import java.io.BufferedReader +import java.io.File +import java.io.InputStreamReader +import java.net.ServerSocket +import java.net.Socket +import java.nio.file.Files +import java.util.concurrent.TimeUnit + +@Retry(times = 3) +abstract class TypeScriptTestBase { + + protected val projectRoot: File get() = File(System.getProperty("user.dir")) + protected val tsClientDir: File + get() = File( + projectRoot, + "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils", + ) + + companion object { + @JvmStatic + private val tempRootDir: File = Files.createTempDirectory("typescript-sdk-").toFile().apply { deleteOnExit() } + + @JvmStatic + protected val sdkDir: File = File(tempRootDir, "typescript-sdk") + + @JvmStatic + @BeforeAll + fun setupTypeScriptSdk() { + println("Cloning TypeScript SDK repository") + if (!sdkDir.exists()) { + val cloneCommand = + "git clone --depth 1 https://github.com/modelcontextprotocol/typescript-sdk.git ${sdkDir.absolutePath}" + val process = ProcessBuilder() + .command("bash", "-c", cloneCommand) + .redirectErrorStream(true) + .start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw RuntimeException("Failed to clone TypeScript SDK repository: exit code $exitCode") + } + } + + println("Installing TypeScript SDK dependencies") + executeCommand("npm install", sdkDir) + } + + @JvmStatic + protected fun executeCommand(command: String, workingDir: File): String = + runCommand(command, workingDir, allowFailure = false, timeoutSeconds = null) + + @JvmStatic + protected fun killProcessOnPort(port: Int) { + executeCommand("lsof -ti:$port | xargs kill -9 2>/dev/null || true", File(".")) + } + + @JvmStatic + protected fun findFreePort(): Int { + ServerSocket(0).use { socket -> + return socket.localPort + } + } + + private fun runCommand( + command: String, + workingDir: File, + allowFailure: Boolean, + timeoutSeconds: Long?, + ): String { + val process = ProcessBuilder() + .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + .directory(workingDir) + .redirectErrorStream(true) + .start() + + val output = StringBuilder() + BufferedReader(InputStreamReader(process.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + println(line) + output.append(line).append("\n") + } + } + + if (timeoutSeconds == null) { + val exitCode = process.waitFor() + if (!allowFailure && exitCode != 0) { + throw RuntimeException( + "Command execution failed with exit code $exitCode: $command\nOutput:\n$output", + ) + } + } else { + process.waitFor(timeoutSeconds, TimeUnit.SECONDS) + } + + return output.toString() + } + } + + protected fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { + if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { + process.destroyForcibly() + process.waitFor(2, TimeUnit.SECONDS) + return false + } + return true + } + + protected fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val outputReader = Thread { + try { + process.inputStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process output: ${e.message}") + } + } + outputReader.isDaemon = true + return outputReader + } + + protected fun waitForPort(host: String = "localhost", port: Int, timeoutSeconds: Long = 10): Boolean { + val deadline = System.currentTimeMillis() + timeoutSeconds * 1000 + while (System.currentTimeMillis() < deadline) { + try { + Socket(host, port).use { return true } + } catch (_: Exception) { + Thread.sleep(100) + } + } + return false + } + + protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = + runCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) + + protected fun startTypeScriptServer(port: Int): Process { + killProcessOnPort(port) + val processBuilder = ProcessBuilder() + .command("bash", "-c", "MCP_PORT=$port npx tsx src/examples/server/simpleStreamableHttp.ts") + .directory(sdkDir) + .redirectErrorStream(true) + val process = processBuilder.start() + if (!waitForPort(port = port)) { + throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") + } + createProcessOutputReader(process).start() + return process + } + + protected fun stopProcess(process: Process, waitSeconds: Long = 3, name: String = "TypeScript server") { + process.destroy() + if (waitForProcessTermination(process, waitSeconds)) { + println("$name stopped gracefully") + } else { + println("$name did not stop gracefully, forced termination") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt new file mode 100644 index 00000000..ff625608 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -0,0 +1,424 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.cio.CIO +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.request.header +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.routing.delete +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.jsonPrimitive +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +private val logger = KotlinLogging.logger {} + +class KotlinServerForTypeScriptClient { + private val serverTransports = ConcurrentHashMap() + private val jsonFormat = Json { ignoreUnknownKeys = true } + private var server: EmbeddedServer<*, *>? = null + + fun start(port: Int = 3000) { + logger.info { "Starting HTTP server on port $port" } + + server = embeddedServer(CIO, port = port) { + routing { + post("/mcp") { + val sessionId = call.request.header("mcp-session-id") + val requestBody = call.receiveText() + + logger.debug { "Received request with sessionId: $sessionId" } + logger.trace { "Request body: $requestBody" } + + val jsonElement = try { + jsonFormat.parseToJsonElement(requestBody) + } catch (e: Exception) { + logger.error(e) { "Failed to parse request body as JSON" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32700), + "message" to JsonPrimitive("Parse error: ${e.message}"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + return@post + } + + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.debug { "Using existing transport for session: $sessionId" } + val transport = serverTransports[sessionId]!! + transport.handleRequest(call, jsonElement) + } else { + if (isInitializeRequest(jsonElement)) { + val newSessionId = UUID.randomUUID().toString() + logger.info { "Creating new session with ID: $newSessionId" } + + val transport = HttpServerTransport(newSessionId) + + serverTransports[newSessionId] = transport + + val mcpServer = createMcpServer() + + call.response.header("mcp-session-id", newSessionId) + + val serverThread = Thread { + runBlocking { + mcpServer.connect(transport) + } + } + serverThread.start() + + Thread.sleep(500) + + transport.handleRequest(call, jsonElement) + } else { + logger.warn { "Invalid request: no session ID or not an initialization request" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32000), + "message" to + JsonPrimitive("Bad Request: No valid session ID provided"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + } + } + } + + delete("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.info { "Terminating session: $sessionId" } + val transport = serverTransports[sessionId]!! + serverTransports.remove(sessionId) + runBlocking { + transport.close() + } + call.respond(HttpStatusCode.OK) + } else { + logger.warn { "Invalid session termination request: $sessionId" } + call.respond(HttpStatusCode.BadRequest, "Invalid or missing session ID") + } + } + } + } + + server?.start(wait = false) + } + + fun stop() { + logger.info { "Stopping HTTP server" } + server?.stop(500, 1000) + server = null + } + + private fun createMcpServer(): Server { + val server = Server( + Implementation( + name = "kotlin-http-server", + version = "1.0.0", + ), + ServerOptions( + capabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + tools = ServerCapabilities.Tools(listChanged = true), + ), + ), + ) + + server.addTool( + name = "greet", + description = "A simple greeting tool", + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World" + CallToolResult( + content = listOf(TextContent("Hello, $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Hello, $name!")) + }, + ) + } + + server.addTool( + name = "multi-greet", + description = "A greeting tool that sends multiple notifications", + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World" + + CallToolResult( + content = listOf(TextContent("Multiple greetings sent to $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Multiple greetings sent to $name!")) + put("notificationCount", JsonPrimitive(3)) + }, + ) + } + + server.addPrompt( + name = "greeting-template", + description = "A simple greeting prompt template", + arguments = listOf( + PromptArgument( + name = "name", + description = "Name to include in greeting", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Greeting for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent( + "Please greet ${request.arguments?.get("name") ?: "someone"} in a friendly manner.", + ), + ), + ), + ) + } + + server.addResource( + uri = "https://example.com/greetings/default", + name = "Default Greeting", + description = "A simple greeting resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents("Hello, world!", request.uri, "text/plain"), + ), + ) + } + + return server + } + + private fun isInitializeRequest(json: JsonElement): Boolean { + if (json !is JsonObject) return false + + val method = json["method"]?.jsonPrimitive?.contentOrNull + return method == "initialize" + } +} + +class HttpServerTransport(private val sessionId: String) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val pendingResponses = ConcurrentHashMap>() + private val messageQueue = Channel(Channel.UNLIMITED) + + suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { + try { + logger.info { "Handling request body: $requestBody" } + val message = McpJson.decodeFromJsonElement(requestBody) + logger.info { "Decoded message: $message" } + + if (message is JSONRPCRequest) { + val id = message.id.toString() + logger.info { "Received request with ID: $id, method: ${message.method}" } + val responseDeferred = CompletableDeferred() + pendingResponses[id] = responseDeferred + logger.info { "Created deferred response for ID: $id" } + + logger.info { "Invoking onMessage handler" } + _onMessage.invoke(message) + logger.info { "onMessage handler completed" } + + try { + val response = withTimeoutOrNull(10000) { + responseDeferred.await() + } + + if (response != null) { + val jsonResponse = McpJson.encodeToString(response) + call.respondText(jsonResponse, ContentType.Application.Json) + } else { + logger.warn { "Timeout waiting for response to request ID: $id" } + call.respondText( + McpJson.encodeToString( + JSONRPCResponse( + id = message.id, + error = JSONRPCError( + code = ErrorCode.Defined.RequestTimeout, + message = "Request timed out", + ), + ), + ), + ContentType.Application.Json, + ) + } + } catch (_: CancellationException) { + logger.warn { "Request cancelled for ID: $id" } + pendingResponses.remove(id) + if (!call.response.isCommitted) { + call.respondText( + McpJson.encodeToString( + JSONRPCResponse( + id = message.id, + error = JSONRPCError( + code = ErrorCode.Defined.ConnectionClosed, + message = "Request cancelled", + ), + ), + ), + ContentType.Application.Json, + HttpStatusCode.ServiceUnavailable, + ) + } + } + } else { + call.respondText("", ContentType.Application.Json, HttpStatusCode.Accepted) + } + } catch (e: Exception) { + logger.error(e) { "Error handling request" } + if (!call.response.isCommitted) { + try { + val errorResponse = JSONRPCResponse( + id = RequestId.NumberId(0), + error = JSONRPCError( + code = ErrorCode.Defined.InternalError, + message = "Internal server error: ${e.message}", + ), + ) + + call.respondText( + McpJson.encodeToString(errorResponse), + ContentType.Application.Json, + HttpStatusCode.InternalServerError, + ) + } catch (responseEx: Exception) { + logger.error(responseEx) { "Failed to send error response" } + } + } + } + } + + override suspend fun start() { + logger.debug { "Starting HTTP server transport for session: $sessionId" } + } + + override suspend fun send(message: JSONRPCMessage) { + logger.info { "Sending message: $message" } + + if (message is JSONRPCResponse) { + val id = message.id.toString() + logger.info { "Sending response for request ID: $id" } + val deferred = pendingResponses.remove(id) + if (deferred != null) { + logger.info { "Found pending response for ID: $id, completing deferred" } + deferred.complete(message) + return + } else { + logger.warn { "No pending response found for ID: $id" } + } + } else if (message is JSONRPCRequest) { + logger.info { "Sending request with ID: ${message.id}" } + } else if (message is JSONRPCNotification) { + logger.info { "Sending notification: ${message.method}" } + } + + logger.info { "Queueing message for next client request" } + messageQueue.send(message) + } + + override suspend fun close() { + logger.debug { "Closing HTTP server transport for session: $sessionId" } + messageQueue.close() + _onClose.invoke() + } +} + +fun main() { + val server = KotlinServerForTypeScriptClient() + server.start() +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt new file mode 100644 index 00000000..32f20534 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.InvocationInterceptor +import org.junit.jupiter.api.extension.InvocationInterceptor.Invocation +import org.junit.jupiter.api.extension.ReflectiveInvocationContext +import org.opentest4j.TestAbortedException +import java.lang.reflect.AnnotatedElement +import java.lang.reflect.Method +import java.util.Optional + +@Target(AnnotationTarget.CLASS) +@Retention(AnnotationRetention.RUNTIME) +@ExtendWith(RetryExtension::class) +annotation class Retry(val times: Int = 3, val delayMs: Long = 1000) + +class RetryExtension : InvocationInterceptor { + override fun interceptTestMethod( + invocation: Invocation, + invocationContext: ReflectiveInvocationContext, + extensionContext: ExtensionContext, + ) { + executeWithRetry(invocation, extensionContext) + } + + private fun resolveRetryAnnotation(extensionContext: ExtensionContext): Retry? { + val classAnn = extensionContext.testClass.flatMap { findRetry(it) } + return classAnn.orElse(null) + } + + private fun findRetry(element: AnnotatedElement): Optional = + Optional.ofNullable(element.getAnnotation(Retry::class.java)) + + private fun executeWithRetry(invocation: Invocation, extensionContext: ExtensionContext) { + val retry = resolveRetryAnnotation(extensionContext) + if (retry == null || retry.times <= 1) { + invocation.proceed() + return + } + + val maxAttempts = retry.times + val delay = retry.delayMs + var lastError: Throwable? = null + + for (attempt in 1..maxAttempts) { + if (attempt > 1 && delay > 0) { + try { + Thread.sleep(delay) + } catch (_: InterruptedException) { + Thread.currentThread().interrupt() + break + } + } + + try { + if (attempt == 1) { + invocation.proceed() + } else { + val instance = extensionContext.requiredTestInstance + val testMethod = extensionContext.requiredTestMethod + testMethod.isAccessible = true + testMethod.invoke(instance) + } + return + } catch (t: Throwable) { + if (t is TestAbortedException) throw t + lastError = if (t is java.lang.reflect.InvocationTargetException) t.targetException ?: t else t + if (attempt == maxAttempts) { + println( + "[Retry] Giving up after $attempt attempts for ${ + describeTest( + extensionContext, + ) + }: ${lastError.message}", + ) + throw lastError + } + println( + "[Retry] Failure on attempt $attempt/$maxAttempts for ${ + describeTest( + extensionContext, + ) + }: ${lastError.message}", + ) + } + } + + throw lastError ?: IllegalStateException("Unexpected state in retry logic") + } + + private fun describeTest(ctx: ExtensionContext): String { + val methodName = ctx.testMethod.map(Method::getName).orElse("") + val className = ctx.testClass.map { it.name }.orElse("") + return "$className#$methodName" + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt new file mode 100644 index 00000000..bed66cd4 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt @@ -0,0 +1,91 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +object TestUtils { + fun runTest(block: suspend () -> T): T = runBlocking { + withContext(Dispatchers.IO) { + block() + } + } + + fun assertTextContent(content: PromptMessageContent?, expectedText: String) { + assertNotNull(content, "Content should not be null") + assertTrue(content is TextContent, "Content should be TextContent") + assertNotNull(content.text, "Text content should not be null") + assertEquals(expectedText, content.text, "Text content should match") + } + + fun assertCallToolResult(result: Any?, message: String = ""): CallToolResultBase { + assertNotNull(result, "${message}Call tool result should not be null") + assertTrue(result is CallToolResultBase, "${message}Result should be CallToolResultBase") + assertTrue(result.content.isNotEmpty(), "${message}Tool result content should not be empty") + assertNotNull(result.structuredContent, "${message}Tool result structured content should not be null") + + return result + } + + /** + * Asserts that a JSON property has the expected string value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: String, + message: String = "", + ) { + assertEquals(expectedValue, json[property]?.toString()?.trim('"'), "${message}$property should match") + } + + /** + * Asserts that a JSON property has the expected numeric value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: Number, + message: String = "", + ) { + when (expectedValue) { + is Int -> assertEquals( + expectedValue, + (json[property] as? JsonPrimitive)?.content?.toIntOrNull(), + "${message}$property should match", + ) + + is Double -> assertEquals( + expectedValue, + (json[property] as? JsonPrimitive)?.content?.toDoubleOrNull(), + "${message}$property should match", + ) + + else -> assertEquals( + expectedValue.toString(), + json[property]?.toString()?.trim('"'), + "${message}$property should match", + ) + } + } + + /** + * Asserts that a JSON property has the expected boolean value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: Boolean, + message: String = "", + ) { + assertEquals(expectedValue.toString(), json[property].toString(), "${message}$property should match") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts new file mode 100644 index 00000000..3b5ea75c --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts @@ -0,0 +1,108 @@ +// @ts-ignore +const args = process.argv.slice(2); +const serverUrl = args[0] || 'http://localhost:3001/mcp'; +const toolName = args[1]; +const toolArgs = args.slice(2); +const PROTOCOL_VERSION = "2024-11-05"; + +// @ts-ignore +async function main() { + // @ts-ignore + const sdkDir = process.env.TYPESCRIPT_SDK_DIR; + let Client: any; + let StreamableHTTPClientTransport: any; + if (sdkDir) { + // @ts-ignore + ({Client} = await import(`${sdkDir}/src/client`)); + // @ts-ignore + ({StreamableHTTPClientTransport} = await import(`${sdkDir}/src/client/streamableHttp.js`)); + } else { + // @ts-ignore + ({Client} = await import("../../../../../../../resources/typescript-sdk/src/client")); + // @ts-ignore + ({StreamableHTTPClientTransport} = await import("../../../../../../../resources/typescript-sdk/src/client/streamableHttp.js")); + } + if (!toolName) { + console.log('Usage: npx tsx myClient.ts [server-url] [tool-args...]'); + console.log('Using default server URL:', serverUrl); + console.log('Available utils will be listed after connection'); + } + + console.log(`Connecting to server at ${serverUrl}`); + if (toolName) { + console.log(`Will call tool: ${toolName} with args: ${toolArgs.join(', ')}`); + } + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(new URL(serverUrl)); + + try { + await client.connect(transport, {protocolVersion: PROTOCOL_VERSION}); + console.log('Connected to server'); + + const toolsResult = await client.listTools(); + const tools = toolsResult.tools; + console.log('Available utils:', tools.map((t: { name: any; }) => t.name).join(', ')); + + if (!toolName) { + await client.close(); + return; + } + + const tool = tools.find((t: { name: string; }) => t.name === toolName); + if (!tool) { + console.error(`Tool "${toolName}" not found`); + // @ts-ignore + process.exit(1); + } + + const toolArguments = {}; + + if (toolName === "greet" && toolArgs.length > 0) { + toolArguments["name"] = toolArgs[0]; + } else if (tool.input && tool.input.properties) { + const propNames = Object.keys(tool.input.properties); + if (propNames.length > 0 && toolArgs.length > 0) { + toolArguments[propNames[0]] = toolArgs[0]; + } + } + + console.log(`Calling tool ${toolName} with arguments:`, toolArguments); + + const result = await client.callTool({ + name: toolName, + arguments: toolArguments + }); + console.log('Tool result:', result); + + if (result.content) { + for (const content of result.content) { + if (content.type === 'text') { + console.log('Text content:', content.text); + } + } + } + + if (result.structuredContent) { + console.log('Structured content:', JSON.stringify(result.structuredContent, null, 2)); + } + + } catch (error) { + console.error('Error:', error); + // @ts-ignore + process.exit(1); + } finally { + await client.close(); + console.log('Disconnected from server'); + } +} + +main().catch(error => { + console.error('Unhandled error:', error); + // @ts-ignore + process.exit(1); +}); \ No newline at end of file From 5c29005d230260fc2e4c1676632cf18b45f5dbba Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Tue, 5 Aug 2025 21:51:01 +0300 Subject: [PATCH 02/22] Add notifications test Signed-off-by: Sergey Karpov --- .../TypeScriptClientKotlinServerTest.kt | 18 +++++ .../utils/KotlinServerForTypeScriptClient.kt | 66 +++++++++++++++++++ .../kotlin/sdk/integration/utils/myClient.ts | 16 ++++- 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt index e88b0a4c..01cbdabc 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -57,6 +57,24 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { ) } + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNotifications() { + val name = "NotifUser" + val command = "npx tsx myClient.ts $serverUrl multi-greet $name" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Multiple greetings") || output.contains("greeting"), + "Tool response should contain greeting message", + ) + // verify that the server sent 3 notifications + assertTrue( + output.contains("\"notificationCount\": 3") || output.contains("notificationCount: 3"), + "Structured content should indicate that 3 notifications were emitted by the server.\nOutput:\n$output", + ) + } + @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testToolCallWithSessionManagement() { diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt index ff625608..c3574e72 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -12,7 +12,9 @@ import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.response.respondText +import io.ktor.server.response.respondTextWriter import io.ktor.server.routing.delete +import io.ktor.server.routing.get import io.ktor.server.routing.post import io.ktor.server.routing.routing import io.modelcontextprotocol.kotlin.sdk.CallToolResult @@ -66,6 +68,20 @@ class KotlinServerForTypeScriptClient { server = embeddedServer(CIO, port = port) { routing { + get("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId == null) { + call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header") + return@get + } + val transport = serverTransports[sessionId] + if (transport == null) { + call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id") + return@get + } + transport.stream(call) + } + post("/mcp") { val sessionId = call.request.header("mcp-session-id") val requestBody = call.receiveText() @@ -235,6 +251,32 @@ class KotlinServerForTypeScriptClient { ) { request -> val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World" + server.sendToolListChanged() + server.sendLoggingMessage( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( + level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + data = JsonPrimitive("Preparing greeting for $name") + ) + ) + ) + server.sendLoggingMessage( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( + level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + data = JsonPrimitive("Halfway there for $name") + ) + ) + ) + server.sendLoggingMessage( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( + io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( + level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + data = JsonPrimitive("Done sending greetings to $name") + ) + ) + ) + CallToolResult( content = listOf(TextContent("Multiple greetings sent to $name!")), structuredContent = buildJsonObject { @@ -297,6 +339,30 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() { private val pendingResponses = ConcurrentHashMap>() private val messageQueue = Channel(Channel.UNLIMITED) + suspend fun stream(call: ApplicationCall) { + logger.debug { "Starting SSE stream for session: $sessionId" } + call.response.header("Cache-Control", "no-cache") + call.response.header("Connection", "keep-alive") + call.respondTextWriter(ContentType.Text.EventStream) { + try { + while (true) { + val result = messageQueue.receiveCatching() + val msg = result.getOrNull() ?: break + val json = McpJson.encodeToString(msg) + write("event: message\n") + write("data: ") + write(json) + write("\n\n") + flush() + } + } catch (e: Exception) { + logger.warn(e) { "SSE stream terminated for session: $sessionId" } + } finally { + logger.debug { "SSE stream closed for session: $sessionId" } + } + } + } + suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { try { logger.info { "Handling request body: $requestBody" } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts index 3b5ea75c..39a0e324 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts @@ -44,6 +44,20 @@ async function main() { await client.connect(transport, {protocolVersion: PROTOCOL_VERSION}); console.log('Connected to server'); + try { + if (typeof (client as any).on === 'function') { + (client as any).on('notification', (n: any) => { + try { + const method = (n && (n.method || (n.params && n.params.method))) || 'unknown'; + console.log('Notification:', method, JSON.stringify(n)); + } catch { + console.log('Notification: '); + } + }); + } + } catch { + } + const toolsResult = await client.listTools(); const tools = toolsResult.tools; console.log('Available utils:', tools.map((t: { name: any; }) => t.name).join(', ')); @@ -105,4 +119,4 @@ main().catch(error => { console.error('Unhandled error:', error); // @ts-ignore process.exit(1); -}); \ No newline at end of file +}); From 6c7220153dd629321132b3a9b24e5b91eed14592 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Tue, 5 Aug 2025 21:51:31 +0300 Subject: [PATCH 03/22] fixup! Introduce Kotlin integration tests Signed-off-by: Sergey Karpov --- .../TypeScriptClientKotlinServerTest.kt | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt index 01cbdabc..9459dcd1 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -75,33 +75,6 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { ) } - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCallWithSessionManagement() { - val testName = "SessionTest" - val command = "npx tsx myClient.ts $serverUrl greet $testName" - val output = executeCommand(command, tsClientDir) - - assertTrue(output.contains("Connected to server"), "Client should connect to server") - assertTrue( - output.contains("Hello, $testName!"), - "Tool response should contain the greeting with the provided name", - ) - assertTrue(output.contains("Tool result:"), "Output should indicate a successful tool call") - assertTrue(output.contains("Disconnected from server"), "Client should disconnect cleanly") - - val multiGreetName = "NotificationTest" - val multiGreetCommand = "npx tsx myClient.ts $serverUrl multi-greet $multiGreetName" - val multiGreetOutput = executeCommand(multiGreetCommand, tsClientDir) - - assertTrue(multiGreetOutput.contains("Connected to server"), "Client should connect to server") - assertTrue( - multiGreetOutput.contains("Multiple greetings") || multiGreetOutput.contains("greeting"), - "Tool response should contain greeting message", - ) - assertTrue(multiGreetOutput.contains("Disconnected from server"), "Client should disconnect cleanly") - } - @Test @Timeout(120, unit = TimeUnit.SECONDS) fun testMultipleClientSequence() { From fb5588e1a07b6b6322473d292e31a5fc73f721e3 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Tue, 5 Aug 2025 22:14:18 +0300 Subject: [PATCH 04/22] fixup! Introduce Kotlin integration tests Signed-off-by: Sergey Karpov --- .../integration/utils/KotlinServerForTypeScriptClient.kt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt index c3574e72..e2669526 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -256,7 +256,7 @@ class KotlinServerForTypeScriptClient { io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, - data = JsonPrimitive("Preparing greeting for $name") + data = JsonPrimitive("Preparing greeting for $name"), ) ) ) @@ -264,7 +264,7 @@ class KotlinServerForTypeScriptClient { io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, - data = JsonPrimitive("Halfway there for $name") + data = JsonPrimitive("Halfway there for $name"), ) ) ) @@ -272,7 +272,7 @@ class KotlinServerForTypeScriptClient { io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, - data = JsonPrimitive("Done sending greetings to $name") + data = JsonPrimitive("Done sending greetings to $name"), ) ) ) From 3da8edae5585a9ef3caa541fb9a41216eca96d91 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Tue, 5 Aug 2025 22:14:40 +0300 Subject: [PATCH 05/22] fixup! Introduce Kotlin integration tests Signed-off-by: Sergey Karpov --- .../utils/KotlinServerForTypeScriptClient.kt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt index e2669526..58cf62b5 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -257,24 +257,24 @@ class KotlinServerForTypeScriptClient { io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, data = JsonPrimitive("Preparing greeting for $name"), - ) - ) + ), + ), ) server.sendLoggingMessage( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, data = JsonPrimitive("Halfway there for $name"), - ) - ) + ), + ), ) server.sendLoggingMessage( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, data = JsonPrimitive("Done sending greetings to $name"), - ) - ) + ), + ), ) CallToolResult( From 4567040cd37122973c46fff32d5fb2278449e87e Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Wed, 6 Aug 2025 12:38:58 +0300 Subject: [PATCH 06/22] cleanup Signed-off-by: Sergey Karpov --- .../utils/KotlinServerForTypeScriptClient.kt | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt index 58cf62b5..535304a8 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -26,6 +26,8 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.PromptArgument import io.modelcontextprotocol.kotlin.sdk.PromptMessage import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult @@ -253,25 +255,25 @@ class KotlinServerForTypeScriptClient { server.sendToolListChanged() server.sendLoggingMessage( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( - level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, data = JsonPrimitive("Preparing greeting for $name"), ), ), ) server.sendLoggingMessage( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( - level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, data = JsonPrimitive("Halfway there for $name"), ), ), ) server.sendLoggingMessage( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification( - io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params( - level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info, + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, data = JsonPrimitive("Done sending greetings to $name"), ), ), From fb1deea56b30ef184bb02785c7649fc67599cee0 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Fri, 8 Aug 2025 01:49:14 +0300 Subject: [PATCH 07/22] Fix tests on Windows --- .../typescript/TypeScriptTestBase.kt | 65 +++++++++++++++---- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index 5210340c..477c9b26 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -31,11 +31,16 @@ abstract class TypeScriptTestBase { @BeforeAll fun setupTypeScriptSdk() { println("Cloning TypeScript SDK repository") + if (!sdkDir.exists()) { - val cloneCommand = - "git clone --depth 1 https://github.com/modelcontextprotocol/typescript-sdk.git ${sdkDir.absolutePath}" - val process = ProcessBuilder() - .command("bash", "-c", cloneCommand) + val process = ProcessBuilder( + "git", + "clone", + "--depth", + "1", + "https://github.com/modelcontextprotocol/typescript-sdk.git", + sdkDir.absolutePath, + ) .redirectErrorStream(true) .start() val exitCode = process.waitFor() @@ -54,7 +59,13 @@ abstract class TypeScriptTestBase { @JvmStatic protected fun killProcessOnPort(port: Int) { - executeCommand("lsof -ti:$port | xargs kill -9 2>/dev/null || true", File(".")) + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val killCommand = if (isWindows) { + "netstat -ano | findstr :$port | for /f \"tokens=5\" %a in ('more') do taskkill /F /PID %a 2>nul || echo No process found" + } else { + "lsof -ti:$port | xargs kill -9 2>/dev/null || true" + } + runCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) } @JvmStatic @@ -70,8 +81,26 @@ abstract class TypeScriptTestBase { allowFailure: Boolean, timeoutSeconds: Long?, ): String { - val process = ProcessBuilder() - .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + if (!workingDir.exists()) { + if (!workingDir.mkdirs()) { + throw RuntimeException("Failed to create working directory: ${workingDir.absolutePath}") + } + } + + if (!workingDir.isDirectory || !workingDir.canRead()) { + throw RuntimeException("Working directory is not accessible: ${workingDir.absolutePath}") + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && $command") + } else { + ProcessBuilder() + .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + } + + val process = processBuilder .directory(workingDir) .redirectErrorStream(true) .start() @@ -89,7 +118,7 @@ abstract class TypeScriptTestBase { val exitCode = process.waitFor() if (!allowFailure && exitCode != 0) { throw RuntimeException( - "Command execution failed with exit code $exitCode: $command\nOutput:\n$output", + "Command execution failed with exit code $exitCode: $command\nWorking dir: ${workingDir.absolutePath}\nOutput:\n$output", ) } } else { @@ -142,11 +171,25 @@ abstract class TypeScriptTestBase { protected fun startTypeScriptServer(port: Int): Process { killProcessOnPort(port) - val processBuilder = ProcessBuilder() - .command("bash", "-c", "MCP_PORT=$port npx tsx src/examples/server/simpleStreamableHttp.ts") + + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException("TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}") + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set MCP_PORT=$port && npx tsx src/examples/server/simpleStreamableHttp.ts") + } else { + ProcessBuilder() + .command("bash", "-c", "MCP_PORT=$port npx tsx src/examples/server/simpleStreamableHttp.ts") + } + + val process = processBuilder .directory(sdkDir) .redirectErrorStream(true) - val process = processBuilder.start() + .start() + if (!waitForPort(port = port)) { throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") } From f11372732002b7de14eeaa6edaf2b93859faddb4 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Fri, 8 Aug 2025 02:10:38 +0300 Subject: [PATCH 08/22] fixup! Fix tests on Windows Signed-off-by: Sergey Karpov --- .../kotlin/sdk/integration/typescript/TypeScriptTestBase.kt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index 477c9b26..0cfcef60 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -173,7 +173,9 @@ abstract class TypeScriptTestBase { killProcessOnPort(port) if (!sdkDir.exists() || !sdkDir.isDirectory) { - throw IllegalStateException("TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}") + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) } val isWindows = System.getProperty("os.name").lowercase().contains("windows") From d5955096a0b4f7e5628a80528b61d61211d4e03c Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Fri, 8 Aug 2025 13:54:37 +0300 Subject: [PATCH 09/22] fixup! Fix tests on Windows --- .../typescript/TypeScriptEdgeCasesTest.kt | 4 ++++ .../kotlin/sdk/integration/utils/myClient.ts | 13 ++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 77241281..3652a2ef 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -5,6 +5,8 @@ import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.condition.EnabledOnOs +import org.junit.jupiter.api.condition.OS import java.io.File import java.util.concurrent.TimeUnit import kotlin.test.assertEquals @@ -84,6 +86,8 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) + @EnabledOnOs(OS.MAC, OS.LINUX) + // skip on windows as it can't handle long commands fun testLargePayload() { val largeName = "A".repeat(10 * 1024) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts index 39a0e324..42a14f5f 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts @@ -8,14 +8,21 @@ const PROTOCOL_VERSION = "2024-11-05"; // @ts-ignore async function main() { // @ts-ignore - const sdkDir = process.env.TYPESCRIPT_SDK_DIR; + const sdkDirRaw = process.env.TYPESCRIPT_SDK_DIR; + const sdkDir = sdkDirRaw ? sdkDirRaw.trim() : undefined; let Client: any; let StreamableHTTPClientTransport: any; if (sdkDir) { // @ts-ignore - ({Client} = await import(`${sdkDir}/src/client`)); + const path = await import('path'); // @ts-ignore - ({StreamableHTTPClientTransport} = await import(`${sdkDir}/src/client/streamableHttp.js`)); + const { pathToFileURL } = await import('url'); + const clientUrl = pathToFileURL(path.join(sdkDir, 'src', 'client', 'index.ts')).href; + const streamUrl = pathToFileURL(path.join(sdkDir, 'src', 'client', 'streamableHttp.js')).href; + // @ts-ignore + ({ Client } = await import(clientUrl)); + // @ts-ignore + ({ StreamableHTTPClientTransport } = await import(streamUrl)); } else { // @ts-ignore ({Client} = await import("../../../../../../../resources/typescript-sdk/src/client")); From 473c000d51c21c9dc533eac382795d954ccf54f5 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Fri, 8 Aug 2025 14:30:56 +0300 Subject: [PATCH 10/22] fixup! Fix tests on Windows Signed-off-by: Sergey Karpov --- .../sdk/integration/typescript/TypeScriptEdgeCasesTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 3652a2ef..86c9b8fe 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -84,10 +84,10 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { ) } + // skip on windows as it can't handle long commands @Test @Timeout(30, unit = TimeUnit.SECONDS) @EnabledOnOs(OS.MAC, OS.LINUX) - // skip on windows as it can't handle long commands fun testLargePayload() { val largeName = "A".repeat(10 * 1024) From e5a78ceb84b93bbf07a10f5d85f04e0bd730899b Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 15:01:42 +0300 Subject: [PATCH 11/22] fixup! Introduce Kotlin integration tests --- .../KotlinServerForTypeScriptClient.kt | 2 +- .../TypeScriptClientKotlinServerTest.kt | 1 - .../typescript/TypeScriptEdgeCasesTest.kt | 1 - .../typescript/TypeScriptTestBase.kt | 2 +- .../{utils => typescript}/myClient.ts | 0 .../typescript/simpleStreamableHttp.ts | 672 ++++++++++++++++++ 6 files changed, 674 insertions(+), 4 deletions(-) rename kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/{utils => typescript}/KotlinServerForTypeScriptClient.kt (99%) rename kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/{utils => typescript}/myClient.ts (100%) create mode 100644 kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt similarity index 99% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt index 535304a8..5757fcbc 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt @@ -1,4 +1,4 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.utils +package io.modelcontextprotocol.kotlin.sdk.integration.typescript import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.ContentType diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt index 9459dcd1..967e1d06 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -1,6 +1,5 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 86c9b8fe..4baadb8a 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -1,6 +1,5 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index 0cfcef60..bfd76893 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -17,7 +17,7 @@ abstract class TypeScriptTestBase { protected val tsClientDir: File get() = File( projectRoot, - "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils", + "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript", ) companion object { diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/myClient.ts similarity index 100% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/myClient.ts diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts new file mode 100644 index 00000000..3271e621 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts @@ -0,0 +1,672 @@ +import express, { Request, Response } from 'express'; +import { randomUUID } from 'node:crypto'; +import { z } from 'zod'; +import { McpServer } from '../../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; +import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../../server/auth/router.js'; +import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; +import { CallToolResult, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, ReadResourceResult, ResourceLink } from '../../types.js'; +import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; +import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; +import { OAuthMetadata } from 'src/shared/auth.js'; +import { checkResourceAllowed } from 'src/shared/auth-utils.js'; + +import cors from 'cors'; + +// Check for OAuth flag +const useOAuth = process.argv.includes('--oauth'); +const strictOAuth = process.argv.includes('--oauth-strict'); + +// Create an MCP server with implementation details +const getServer = () => { + const server = new McpServer({ + name: 'simple-streamable-http-server', + version: '1.0.0' + }, { capabilities: { logging: {} } }); + + // Register a simple tool that returns a greeting + server.registerTool( + 'greet', + { + title: 'Greeting Tool', // Display name for UI + description: 'A simple greeting tool', + inputSchema: { + name: z.string().describe('Name to greet'), + }, + }, + async ({ name }): Promise => { + return { + content: [ + { + type: 'text', + text: `Hello, ${name}!`, + }, + ], + }; + } + ); + + // Register a tool that sends multiple greetings with notifications (with annotations) + server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet'), + }, + { + title: 'Multiple Greeting Tool', + readOnlyHint: true, + openWorldHint: false + }, + async ({ name }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await server.sendLoggingMessage({ + level: "debug", + data: `Starting multi-greet for ${name}` + }, extra.sessionId); + + await sleep(1000); // Wait 1 second before first greeting + + await server.sendLoggingMessage({ + level: "info", + data: `Sending first greeting to ${name}` + }, extra.sessionId); + + await sleep(1000); // Wait another second before second greeting + + await server.sendLoggingMessage({ + level: "info", + data: `Sending second greeting to ${name}` + }, extra.sessionId); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!`, + } + ], + }; + } + ); + // Register a tool that demonstrates elicitation (user input collection) + // This creates a closure that captures the server instance + server.tool( + 'collect-user-info', + 'A tool that collects user information through elicitation', + { + infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect'), + }, + async ({ infoType }): Promise => { + let message: string; + let requestedSchema: { + type: 'object'; + properties: Record; + required?: string[]; + }; + + switch (infoType) { + case 'contact': + message = 'Please provide your contact information'; + requestedSchema = { + type: 'object', + properties: { + name: { + type: 'string', + title: 'Full Name', + description: 'Your full name', + }, + email: { + type: 'string', + title: 'Email Address', + description: 'Your email address', + format: 'email', + }, + phone: { + type: 'string', + title: 'Phone Number', + description: 'Your phone number (optional)', + }, + }, + required: ['name', 'email'], + }; + break; + case 'preferences': + message = 'Please set your preferences'; + requestedSchema = { + type: 'object', + properties: { + theme: { + type: 'string', + title: 'Theme', + description: 'Choose your preferred theme', + enum: ['light', 'dark', 'auto'], + enumNames: ['Light', 'Dark', 'Auto'], + }, + notifications: { + type: 'boolean', + title: 'Enable Notifications', + description: 'Would you like to receive notifications?', + default: true, + }, + frequency: { + type: 'string', + title: 'Notification Frequency', + description: 'How often would you like notifications?', + enum: ['daily', 'weekly', 'monthly'], + enumNames: ['Daily', 'Weekly', 'Monthly'], + }, + }, + required: ['theme'], + }; + break; + case 'feedback': + message = 'Please provide your feedback'; + requestedSchema = { + type: 'object', + properties: { + rating: { + type: 'integer', + title: 'Rating', + description: 'Rate your experience (1-5)', + minimum: 1, + maximum: 5, + }, + comments: { + type: 'string', + title: 'Comments', + description: 'Additional comments (optional)', + maxLength: 500, + }, + recommend: { + type: 'boolean', + title: 'Would you recommend this?', + description: 'Would you recommend this to others?', + }, + }, + required: ['rating', 'recommend'], + }; + break; + default: + throw new Error(`Unknown info type: ${infoType}`); + } + + try { + // Use the underlying server instance to elicit input from the client + const result = await server.server.elicitInput({ + message, + requestedSchema, + }); + + if (result.action === 'accept') { + return { + content: [ + { + type: 'text', + text: `Thank you! Collected ${infoType} information: ${JSON.stringify(result.content, null, 2)}`, + }, + ], + }; + } else if (result.action === 'decline') { + return { + content: [ + { + type: 'text', + text: `No information was collected. User declined ${infoType} information request.`, + }, + ], + }; + } else { + return { + content: [ + { + type: 'text', + text: `Information collection was cancelled by the user.`, + }, + ], + }; + } + } catch (error) { + return { + content: [ + { + type: 'text', + text: `Error collecting ${infoType} information: ${error}`, + }, + ], + }; + } + } + ); + + // Register a simple prompt with title + server.registerPrompt( + 'greeting-template', + { + title: 'Greeting Template', // Display name for UI + description: 'A simple greeting prompt template', + argsSchema: { + name: z.string().describe('Name to include in greeting'), + }, + }, + async ({ name }): Promise => { + return { + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please greet ${name} in a friendly manner.`, + }, + }, + ], + }; + } + ); + + // Register a tool specifically for testing resumability + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(50), + }, + async ({ interval, count }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + while (count === 0 || counter < count) { + counter++; + try { + await server.sendLoggingMessage( { + level: "info", + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + }, extra.sessionId); + } + catch (error) { + console.error("Error sending notification:", error); + } + // Wait for the specified interval + await sleep(interval); + } + + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms`, + } + ], + }; + } + ); + + // Create a simple resource at a fixed URI + server.registerResource( + 'greeting-resource', + 'https://example.com/greetings/default', + { + title: 'Default Greeting', // Display name for UI + description: 'A simple greeting resource', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + text: 'Hello, world!', + }, + ], + }; + } + ); + + // Create additional resources for ResourceLink demonstration + server.registerResource( + 'example-file-1', + 'file:///example/file1.txt', + { + title: 'Example File 1', + description: 'First example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file1.txt', + text: 'This is the content of file 1', + }, + ], + }; + } + ); + + server.registerResource( + 'example-file-2', + 'file:///example/file2.txt', + { + title: 'Example File 2', + description: 'Second example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file2.txt', + text: 'This is the content of file 2', + }, + ], + }; + } + ); + + // Register a tool that returns ResourceLinks + server.registerTool( + 'list-files', + { + title: 'List Files with ResourceLinks', + description: 'Returns a list of files as ResourceLinks without embedding their content', + inputSchema: { + includeDescriptions: z.boolean().optional().describe('Whether to include descriptions in the resource links'), + }, + }, + async ({ includeDescriptions = true }): Promise => { + const resourceLinks: ResourceLink[] = [ + { + type: 'resource_link', + uri: 'https://example.com/greetings/default', + name: 'Default Greeting', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'A simple greeting resource' }) + }, + { + type: 'resource_link', + uri: 'file:///example/file1.txt', + name: 'Example File 1', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'First example file for ResourceLink demonstration' }) + }, + { + type: 'resource_link', + uri: 'file:///example/file2.txt', + name: 'Example File 2', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'Second example file for ResourceLink demonstration' }) + } + ]; + + return { + content: [ + { + type: 'text', + text: 'Here are the available files as resource links:', + }, + ...resourceLinks, + { + type: 'text', + text: '\nYou can read any of these resources using their URI.', + } + ], + }; + } + ); + + return server; +}; + +const MCP_PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; +const AUTH_PORT = process.env.MCP_AUTH_PORT ? parseInt(process.env.MCP_AUTH_PORT, 10) : 3001; + +const app = express(); +app.use(express.json()); + +// Allow CORS all domains, expose the Mcp-Session-Id header +app.use(cors({ + origin: '*', // Allow all origins + exposedHeaders: ["Mcp-Session-Id"] +})); + +// Set up OAuth if enabled +let authMiddleware = null; +if (useOAuth) { + // Create auth middleware for MCP endpoints + const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); + const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); + + const oauthMetadata: OAuthMetadata = setupAuthServer({ authServerUrl, mcpServerUrl, strictResource: strictOAuth }); + + const tokenVerifier = { + verifyAccessToken: async (token: string) => { + const endpoint = oauthMetadata.introspection_endpoint; + + if (!endpoint) { + throw new Error('No token verification endpoint available in metadata'); + } + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + token: token + }).toString() + }); + + + if (!response.ok) { + throw new Error(`Invalid or expired token: ${await response.text()}`); + } + + const data = await response.json(); + + if (strictOAuth) { + if (!data.aud) { + throw new Error(`Resource Indicator (RFC8707) missing`); + } + if (!checkResourceAllowed({ requestedResource: data.aud, configuredResource: mcpServerUrl })) { + throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + } + } + + // Convert the response to AuthInfo format + return { + token, + clientId: data.client_id, + scopes: data.scope ? data.scope.split(' ') : [], + expiresAt: data.exp, + }; + } + } + // Add metadata routes to the main MCP server + app.use(mcpAuthMetadataRouter({ + oauthMetadata, + resourceServerUrl: mcpServerUrl, + scopesSupported: ['mcp:tools'], + resourceName: 'MCP Demo Server', + })); + + authMiddleware = requireBearerAuth({ + verifier: tokenVerifier, + requiredScopes: [], + resourceMetadataUrl: getOAuthProtectedResourceMetadataUrl(mcpServerUrl), + }); +} + +// Map to store transports by session ID +const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + +// MCP POST endpoint with optional auth +const mcpPostHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) { + console.log(`Received MCP request for session: ${sessionId}`); + } else { + console.log('Request body:', req.body); + } + + if (useOAuth && req.auth) { + console.log('Authenticated user:', req.auth); + } + try { + let transport: StreamableHTTPServerTransport; + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + const eventStore = new InMemoryEventStore(); + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, // Enable resumability + onsessioninitialized: (sessionId) => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); + + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + const sid = transport.sessionId; + if (sid && transports[sid]) { + console.log(`Transport closed for session ${sid}, removing from transports map`); + delete transports[sid]; + } + }; + + // Connect the transport to the MCP server BEFORE handling the request + // so responses can flow back through the same transport + const server = getServer(); + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided', + }, + id: null, + }); + return; + } + + // Handle the request with existing transport - no need to reconnect + // The existing transport is already connected to the server + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error', + }, + id: null, + }); + } + } +}; + +// Set up routes with conditional auth middleware +if (useOAuth && authMiddleware) { + app.post('/mcp', authMiddleware, mcpPostHandler); +} else { + app.post('/mcp', mcpPostHandler); +} + +// Handle GET requests for SSE streams (using built-in support from StreamableHTTP) +const mcpGetHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + if (useOAuth && req.auth) { + console.log('Authenticated SSE connection from user:', req.auth); + } + + // Check for Last-Event-ID header for resumability + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`Establishing new SSE stream for session ${sessionId}`); + } + + const transport = transports[sessionId]; + await transport.handleRequest(req, res); +}; + +// Set up GET route with conditional auth middleware +if (useOAuth && authMiddleware) { + app.get('/mcp', authMiddleware, mcpGetHandler); +} else { + app.get('/mcp', mcpGetHandler); +} + +// Handle DELETE requests for session termination (according to MCP spec) +const mcpDeleteHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + console.log(`Received session termination request for session ${sessionId}`); + + try { + const transport = transports[sessionId]; + await transport.handleRequest(req, res); + } catch (error) { + console.error('Error handling session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } + } +}; + +// Set up DELETE route with conditional auth middleware +if (useOAuth && authMiddleware) { + app.delete('/mcp', authMiddleware, mcpDeleteHandler); +} else { + app.delete('/mcp', mcpDeleteHandler); +} + +app.listen(MCP_PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); +}); + +// Handle server shutdown +process.on('SIGINT', async () => { + console.log('Shutting down server...'); + + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } + } + console.log('Server shutdown complete'); + process.exit(0); +}); From 2eacbe388ce83ec827566ba4181275ed3abcf691 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 15:19:37 +0300 Subject: [PATCH 12/22] fixup! Introduce Kotlin integration tests --- .../typescript/TypeScriptTestBase.kt | 22 +- .../typescript/simpleStreamableHttp.ts | 1298 +++++++++-------- 2 files changed, 681 insertions(+), 639 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index bfd76893..a4bc6243 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -179,23 +179,35 @@ abstract class TypeScriptTestBase { } val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val localServerPath = File(tsClientDir, "simpleStreamableHttp.ts").absolutePath val processBuilder = if (isWindows) { ProcessBuilder() - .command("cmd.exe", "/c", "set MCP_PORT=$port && npx tsx src/examples/server/simpleStreamableHttp.ts") + .command( + "cmd.exe", + "/c", + "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"" + ) } else { ProcessBuilder() - .command("bash", "-c", "MCP_PORT=$port npx tsx src/examples/server/simpleStreamableHttp.ts") + .command( + "bash", + "-c", + "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"" + ) } + processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath + val process = processBuilder - .directory(sdkDir) + .directory(tsClientDir) .redirectErrorStream(true) .start() - if (!waitForPort(port = port)) { + createProcessOutputReader(process).start() + + if (!waitForPort(port = port, timeoutSeconds = 20)) { throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") } - createProcessOutputReader(process).start() return process } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts index 3271e621..67088cff 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts @@ -1,672 +1,702 @@ -import express, { Request, Response } from 'express'; -import { randomUUID } from 'node:crypto'; -import { z } from 'zod'; -import { McpServer } from '../../server/mcp.js'; -import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; -import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../../server/auth/router.js'; -import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; -import { CallToolResult, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, ReadResourceResult, ResourceLink } from '../../types.js'; -import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; -import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; -import { OAuthMetadata } from 'src/shared/auth.js'; -import { checkResourceAllowed } from 'src/shared/auth-utils.js'; - +// @ts-nocheck +import express, {Request, Response} from 'express'; +import {randomUUID} from 'node:crypto'; +import {z} from 'zod'; import cors from 'cors'; +import path from 'node:path'; +import {pathToFileURL} from 'node:url'; + +const SDK_DIR = process.env.TYPESCRIPT_SDK_DIR; +if (!SDK_DIR) { + throw new Error('TYPESCRIPT_SDK_DIR environment variable is not set. It should point to the cloned typescript-sdk directory.'); +} -// Check for OAuth flag -const useOAuth = process.argv.includes('--oauth'); -const strictOAuth = process.argv.includes('--oauth-strict'); +async function importFromSdk(rel: string): Promise { + const full = path.resolve(SDK_DIR!, rel); + const url = pathToFileURL(full).href; + return (await import(url)) as unknown as T; +} + +async function main() { + const {McpServer} = await importFromSdk('src/server/mcp.ts'); + const {StreamableHTTPServerTransport} = await importFromSdk('src/server/streamableHttp.ts'); + const { + getOAuthProtectedResourceMetadataUrl, + mcpAuthMetadataRouter + } = await importFromSdk('src/server/auth/router.ts'); + const {requireBearerAuth} = await importFromSdk('src/server/auth/middleware/bearerAuth.ts'); + const { + isInitializeRequest, + } = await importFromSdk('src/types.ts'); + const {InMemoryEventStore} = await importFromSdk('src/examples/shared/inMemoryEventStore.ts'); + const {setupAuthServer} = await importFromSdk('src/examples/server/demoInMemoryOAuthProvider.ts'); + const {OAuthMetadata} = await importFromSdk('src/shared/auth.ts'); + const {checkResourceAllowed} = await importFromSdk('src/shared/auth-utils.ts'); + + // Check for OAuth flag + const useOAuth = process.argv.includes('--oauth'); + const strictOAuth = process.argv.includes('--oauth-strict'); // Create an MCP server with implementation details -const getServer = () => { - const server = new McpServer({ - name: 'simple-streamable-http-server', - version: '1.0.0' - }, { capabilities: { logging: {} } }); - - // Register a simple tool that returns a greeting - server.registerTool( - 'greet', - { - title: 'Greeting Tool', // Display name for UI - description: 'A simple greeting tool', - inputSchema: { - name: z.string().describe('Name to greet'), - }, - }, - async ({ name }): Promise => { - return { - content: [ - { - type: 'text', - text: `Hello, ${name}!`, - }, - ], - }; - } - ); - - // Register a tool that sends multiple greetings with notifications (with annotations) - server.tool( - 'multi-greet', - 'A tool that sends different greetings with delays between them', - { - name: z.string().describe('Name to greet'), - }, - { - title: 'Multiple Greeting Tool', - readOnlyHint: true, - openWorldHint: false - }, - async ({ name }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - - await server.sendLoggingMessage({ - level: "debug", - data: `Starting multi-greet for ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait 1 second before first greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending first greeting to ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait another second before second greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending second greeting to ${name}` - }, extra.sessionId); - - return { - content: [ - { - type: 'text', - text: `Good morning, ${name}!`, - } - ], - }; - } - ); - // Register a tool that demonstrates elicitation (user input collection) - // This creates a closure that captures the server instance - server.tool( - 'collect-user-info', - 'A tool that collects user information through elicitation', - { - infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect'), - }, - async ({ infoType }): Promise => { - let message: string; - let requestedSchema: { - type: 'object'; - properties: Record; - required?: string[]; - }; - - switch (infoType) { - case 'contact': - message = 'Please provide your contact information'; - requestedSchema = { - type: 'object', - properties: { - name: { - type: 'string', - title: 'Full Name', - description: 'Your full name', - }, - email: { - type: 'string', - title: 'Email Address', - description: 'Your email address', - format: 'email', - }, - phone: { - type: 'string', - title: 'Phone Number', - description: 'Your phone number (optional)', - }, + const getServer = () => { + const server = new McpServer({ + name: 'simple-streamable-http-server', + version: '1.0.0' + }, {capabilities: {logging: {}}}); + + // Register a simple tool that returns a greeting + server.registerTool( + 'greet', + { + title: 'Greeting Tool', // Display name for UI + description: 'A simple greeting tool', + inputSchema: { + name: z.string().describe('Name to greet'), + }, }, - required: ['name', 'email'], - }; - break; - case 'preferences': - message = 'Please set your preferences'; - requestedSchema = { - type: 'object', - properties: { - theme: { - type: 'string', - title: 'Theme', - description: 'Choose your preferred theme', - enum: ['light', 'dark', 'auto'], - enumNames: ['Light', 'Dark', 'Auto'], - }, - notifications: { - type: 'boolean', - title: 'Enable Notifications', - description: 'Would you like to receive notifications?', - default: true, - }, - frequency: { - type: 'string', - title: 'Notification Frequency', - description: 'How often would you like notifications?', - enum: ['daily', 'weekly', 'monthly'], - enumNames: ['Daily', 'Weekly', 'Monthly'], - }, + async ({name}): Promise => { + return { + content: [ + { + type: 'text', + text: `Hello, ${name}!`, + }, + ], + }; + } + ); + + // Register a tool that sends multiple greetings with notifications (with annotations) + server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet'), }, - required: ['theme'], - }; - break; - case 'feedback': - message = 'Please provide your feedback'; - requestedSchema = { - type: 'object', - properties: { - rating: { - type: 'integer', - title: 'Rating', - description: 'Rate your experience (1-5)', - minimum: 1, - maximum: 5, - }, - comments: { - type: 'string', - title: 'Comments', - description: 'Additional comments (optional)', - maxLength: 500, - }, - recommend: { - type: 'boolean', - title: 'Would you recommend this?', - description: 'Would you recommend this to others?', - }, + { + title: 'Multiple Greeting Tool', + readOnlyHint: true, + openWorldHint: false }, - required: ['rating', 'recommend'], - }; - break; - default: - throw new Error(`Unknown info type: ${infoType}`); - } - - try { - // Use the underlying server instance to elicit input from the client - const result = await server.server.elicitInput({ - message, - requestedSchema, - }); - - if (result.action === 'accept') { - return { - content: [ - { - type: 'text', - text: `Thank you! Collected ${infoType} information: ${JSON.stringify(result.content, null, 2)}`, - }, - ], - }; - } else if (result.action === 'decline') { - return { - content: [ - { - type: 'text', - text: `No information was collected. User declined ${infoType} information request.`, - }, - ], - }; - } else { - return { - content: [ - { - type: 'text', - text: `Information collection was cancelled by the user.`, - }, - ], - }; - } - } catch (error) { - return { - content: [ + async ({name}, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await server.sendLoggingMessage({ + level: "debug", + data: `Starting multi-greet for ${name}` + }, extra.sessionId); + + await sleep(1000); // Wait 1 second before first greeting + + await server.sendLoggingMessage({ + level: "info", + data: `Sending first greeting to ${name}` + }, extra.sessionId); + + await sleep(1000); // Wait another second before second greeting + + await server.sendLoggingMessage({ + level: "info", + data: `Sending second greeting to ${name}` + }, extra.sessionId); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!`, + } + ], + }; + } + ); + // Register a tool that demonstrates elicitation (user input collection) + // This creates a closure that captures the server instance + server.tool( + 'collect-user-info', + 'A tool that collects user information through elicitation', { - type: 'text', - text: `Error collecting ${infoType} information: ${error}`, + infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect'), }, - ], - }; - } - } - ); - - // Register a simple prompt with title - server.registerPrompt( - 'greeting-template', - { - title: 'Greeting Template', // Display name for UI - description: 'A simple greeting prompt template', - argsSchema: { - name: z.string().describe('Name to include in greeting'), - }, - }, - async ({ name }): Promise => { - return { - messages: [ - { - role: 'user', - content: { - type: 'text', - text: `Please greet ${name} in a friendly manner.`, + async ({infoType}): Promise => { + let message: string; + let requestedSchema: { + type: 'object'; + properties: Record; + required?: string[]; + }; + + switch (infoType) { + case 'contact': + message = 'Please provide your contact information'; + requestedSchema = { + type: 'object', + properties: { + name: { + type: 'string', + title: 'Full Name', + description: 'Your full name', + }, + email: { + type: 'string', + title: 'Email Address', + description: 'Your email address', + format: 'email', + }, + phone: { + type: 'string', + title: 'Phone Number', + description: 'Your phone number (optional)', + }, + }, + required: ['name', 'email'], + }; + break; + case 'preferences': + message = 'Please set your preferences'; + requestedSchema = { + type: 'object', + properties: { + theme: { + type: 'string', + title: 'Theme', + description: 'Choose your preferred theme', + enum: ['light', 'dark', 'auto'], + enumNames: ['Light', 'Dark', 'Auto'], + }, + notifications: { + type: 'boolean', + title: 'Enable Notifications', + description: 'Would you like to receive notifications?', + default: true, + }, + frequency: { + type: 'string', + title: 'Notification Frequency', + description: 'How often would you like notifications?', + enum: ['daily', 'weekly', 'monthly'], + enumNames: ['Daily', 'Weekly', 'Monthly'], + }, + }, + required: ['theme'], + }; + break; + case 'feedback': + message = 'Please provide your feedback'; + requestedSchema = { + type: 'object', + properties: { + rating: { + type: 'integer', + title: 'Rating', + description: 'Rate your experience (1-5)', + minimum: 1, + maximum: 5, + }, + comments: { + type: 'string', + title: 'Comments', + description: 'Additional comments (optional)', + maxLength: 500, + }, + recommend: { + type: 'boolean', + title: 'Would you recommend this?', + description: 'Would you recommend this to others?', + }, + }, + required: ['rating', 'recommend'], + }; + break; + default: + throw new Error(`Unknown info type: ${infoType}`); + } + + try { + // Use the underlying server instance to elicit input from the client + const result = await server.server.elicitInput({ + message, + requestedSchema, + }); + + if (result.action === 'accept') { + return { + content: [ + { + type: 'text', + text: `Thank you! Collected ${infoType} information: ${JSON.stringify(result.content, null, 2)}`, + }, + ], + }; + } else if (result.action === 'decline') { + return { + content: [ + { + type: 'text', + text: `No information was collected. User declined ${infoType} information request.`, + }, + ], + }; + } else { + return { + content: [ + { + type: 'text', + text: `Information collection was cancelled by the user.`, + }, + ], + }; + } + } catch (error) { + return { + content: [ + { + type: 'text', + text: `Error collecting ${infoType} information: ${error}`, + }, + ], + }; + } + } + ); + + // Register a simple prompt with title + server.registerPrompt( + 'greeting-template', + { + title: 'Greeting Template', // Display name for UI + description: 'A simple greeting prompt template', + argsSchema: { + name: z.string().describe('Name to include in greeting'), + }, }, - }, - ], - }; - } - ); - - // Register a tool specifically for testing resumability - server.tool( - 'start-notification-stream', - 'Starts sending periodic notifications for testing resumability', - { - interval: z.number().describe('Interval in milliseconds between notifications').default(100), - count: z.number().describe('Number of notifications to send (0 for 100)').default(50), - }, - async ({ interval, count }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - let counter = 0; - - while (count === 0 || counter < count) { - counter++; - try { - await server.sendLoggingMessage( { - level: "info", - data: `Periodic notification #${counter} at ${new Date().toISOString()}` - }, extra.sessionId); - } - catch (error) { - console.error("Error sending notification:", error); - } - // Wait for the specified interval - await sleep(interval); - } - - return { - content: [ - { - type: 'text', - text: `Started sending periodic notifications every ${interval}ms`, - } - ], - }; - } - ); - - // Create a simple resource at a fixed URI - server.registerResource( - 'greeting-resource', - 'https://example.com/greetings/default', - { - title: 'Default Greeting', // Display name for UI - description: 'A simple greeting resource', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'https://example.com/greetings/default', - text: 'Hello, world!', - }, - ], - }; - } - ); - - // Create additional resources for ResourceLink demonstration - server.registerResource( - 'example-file-1', - 'file:///example/file1.txt', - { - title: 'Example File 1', - description: 'First example file for ResourceLink demonstration', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'file:///example/file1.txt', - text: 'This is the content of file 1', - }, - ], - }; - } - ); - - server.registerResource( - 'example-file-2', - 'file:///example/file2.txt', - { - title: 'Example File 2', - description: 'Second example file for ResourceLink demonstration', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'file:///example/file2.txt', - text: 'This is the content of file 2', - }, - ], - }; - } - ); - - // Register a tool that returns ResourceLinks - server.registerTool( - 'list-files', - { - title: 'List Files with ResourceLinks', - description: 'Returns a list of files as ResourceLinks without embedding their content', - inputSchema: { - includeDescriptions: z.boolean().optional().describe('Whether to include descriptions in the resource links'), - }, - }, - async ({ includeDescriptions = true }): Promise => { - const resourceLinks: ResourceLink[] = [ - { - type: 'resource_link', - uri: 'https://example.com/greetings/default', - name: 'Default Greeting', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'A simple greeting resource' }) - }, - { - type: 'resource_link', - uri: 'file:///example/file1.txt', - name: 'Example File 1', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'First example file for ResourceLink demonstration' }) - }, - { - type: 'resource_link', - uri: 'file:///example/file2.txt', - name: 'Example File 2', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'Second example file for ResourceLink demonstration' }) - } - ]; - - return { - content: [ - { - type: 'text', - text: 'Here are the available files as resource links:', - }, - ...resourceLinks, - { - type: 'text', - text: '\nYou can read any of these resources using their URI.', - } - ], - }; - } - ); - - return server; -}; - -const MCP_PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; -const AUTH_PORT = process.env.MCP_AUTH_PORT ? parseInt(process.env.MCP_AUTH_PORT, 10) : 3001; - -const app = express(); -app.use(express.json()); + async ({name}): Promise => { + return { + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please greet ${name} in a friendly manner.`, + }, + }, + ], + }; + } + ); + + // Register a tool specifically for testing resumability + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(50), + }, + async ({interval, count}, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + while (count === 0 || counter < count) { + counter++; + try { + await server.sendLoggingMessage({ + level: "info", + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + }, extra.sessionId); + } catch (error) { + console.error("Error sending notification:", error); + } + // Wait for the specified interval + await sleep(interval); + } + + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms`, + } + ], + }; + } + ); + + // Create a simple resource at a fixed URI + server.registerResource( + 'greeting-resource', + 'https://example.com/greetings/default', + { + title: 'Default Greeting', // Display name for UI + description: 'A simple greeting resource', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + text: 'Hello, world!', + }, + ], + }; + } + ); + + // Create additional resources for ResourceLink demonstration + server.registerResource( + 'example-file-1', + 'file:///example/file1.txt', + { + title: 'Example File 1', + description: 'First example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file1.txt', + text: 'This is the content of file 1', + }, + ], + }; + } + ); + + server.registerResource( + 'example-file-2', + 'file:///example/file2.txt', + { + title: 'Example File 2', + description: 'Second example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file2.txt', + text: 'This is the content of file 2', + }, + ], + }; + } + ); + + // Register a tool that returns ResourceLinks + server.registerTool( + 'list-files', + { + title: 'List Files with ResourceLinks', + description: 'Returns a list of files as ResourceLinks without embedding their content', + inputSchema: { + includeDescriptions: z.boolean().optional().describe('Whether to include descriptions in the resource links'), + }, + }, + async ({includeDescriptions = true}): Promise => { + const resourceLinks: ResourceLink[] = [ + { + type: 'resource_link', + uri: 'https://example.com/greetings/default', + name: 'Default Greeting', + mimeType: 'text/plain', + ...(includeDescriptions && {description: 'A simple greeting resource'}) + }, + { + type: 'resource_link', + uri: 'file:///example/file1.txt', + name: 'Example File 1', + mimeType: 'text/plain', + ...(includeDescriptions && {description: 'First example file for ResourceLink demonstration'}) + }, + { + type: 'resource_link', + uri: 'file:///example/file2.txt', + name: 'Example File 2', + mimeType: 'text/plain', + ...(includeDescriptions && {description: 'Second example file for ResourceLink demonstration'}) + } + ]; + + return { + content: [ + { + type: 'text', + text: 'Here are the available files as resource links:', + }, + ...resourceLinks, + { + type: 'text', + text: '\nYou can read any of these resources using their URI.', + } + ], + }; + } + ); + + return server; + }; + + const MCP_PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; + const AUTH_PORT = process.env.MCP_AUTH_PORT ? parseInt(process.env.MCP_AUTH_PORT, 10) : 3001; + + const app = express(); + app.use(express.json()); // Allow CORS all domains, expose the Mcp-Session-Id header -app.use(cors({ - origin: '*', // Allow all origins - exposedHeaders: ["Mcp-Session-Id"] -})); + app.use(cors({ + origin: '*', // Allow all origins + exposedHeaders: ["Mcp-Session-Id"] + })); // Set up OAuth if enabled -let authMiddleware = null; -if (useOAuth) { - // Create auth middleware for MCP endpoints - const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); - const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); - - const oauthMetadata: OAuthMetadata = setupAuthServer({ authServerUrl, mcpServerUrl, strictResource: strictOAuth }); - - const tokenVerifier = { - verifyAccessToken: async (token: string) => { - const endpoint = oauthMetadata.introspection_endpoint; - - if (!endpoint) { - throw new Error('No token verification endpoint available in metadata'); - } - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - body: new URLSearchParams({ - token: token - }).toString() - }); - - - if (!response.ok) { - throw new Error(`Invalid or expired token: ${await response.text()}`); - } - - const data = await response.json(); - - if (strictOAuth) { - if (!data.aud) { - throw new Error(`Resource Indicator (RFC8707) missing`); - } - if (!checkResourceAllowed({ requestedResource: data.aud, configuredResource: mcpServerUrl })) { - throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + let authMiddleware = null; + if (useOAuth) { + // Create auth middleware for MCP endpoints + const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); + const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); + + const oauthMetadata: OAuthMetadata = setupAuthServer({ + authServerUrl, + mcpServerUrl, + strictResource: strictOAuth + }); + + const tokenVerifier = { + verifyAccessToken: async (token: string) => { + const endpoint = oauthMetadata.introspection_endpoint; + + if (!endpoint) { + throw new Error('No token verification endpoint available in metadata'); + } + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + token: token + }).toString() + }); + + + if (!response.ok) { + throw new Error(`Invalid or expired token: ${await response.text()}`); + } + + const data = await response.json(); + + if (strictOAuth) { + if (!data.aud) { + throw new Error(`Resource Indicator (RFC8707) missing`); + } + if (!checkResourceAllowed({requestedResource: data.aud, configuredResource: mcpServerUrl})) { + throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + } + } + + // Convert the response to AuthInfo format + return { + token, + clientId: data.client_id, + scopes: data.scope ? data.scope.split(' ') : [], + expiresAt: data.exp, + }; + } } - } - - // Convert the response to AuthInfo format - return { - token, - clientId: data.client_id, - scopes: data.scope ? data.scope.split(' ') : [], - expiresAt: data.exp, - }; + // Add metadata routes to the main MCP server + app.use(mcpAuthMetadataRouter({ + oauthMetadata, + resourceServerUrl: mcpServerUrl, + scopesSupported: ['mcp:tools'], + resourceName: 'MCP Demo Server', + })); + + authMiddleware = requireBearerAuth({ + verifier: tokenVerifier, + requiredScopes: [], + resourceMetadataUrl: getOAuthProtectedResourceMetadataUrl(mcpServerUrl), + }); } - } - // Add metadata routes to the main MCP server - app.use(mcpAuthMetadataRouter({ - oauthMetadata, - resourceServerUrl: mcpServerUrl, - scopesSupported: ['mcp:tools'], - resourceName: 'MCP Demo Server', - })); - - authMiddleware = requireBearerAuth({ - verifier: tokenVerifier, - requiredScopes: [], - resourceMetadataUrl: getOAuthProtectedResourceMetadataUrl(mcpServerUrl), - }); -} // Map to store transports by session ID -const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; // MCP POST endpoint with optional auth -const mcpPostHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (sessionId) { - console.log(`Received MCP request for session: ${sessionId}`); - } else { - console.log('Request body:', req.body); - } - - if (useOAuth && req.auth) { - console.log('Authenticated user:', req.auth); - } - try { - let transport: StreamableHTTPServerTransport; - if (sessionId && transports[sessionId]) { - // Reuse existing transport - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - // New initialization request - const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - eventStore, // Enable resumability - onsessioninitialized: (sessionId) => { - // Store the transport by session ID when session is initialized - // This avoids race conditions where requests might come in before the session is stored - console.log(`Session initialized with ID: ${sessionId}`); - transports[sessionId] = transport; - } - }); - - // Set up onclose handler to clean up transport when closed - transport.onclose = () => { - const sid = transport.sessionId; - if (sid && transports[sid]) { - console.log(`Transport closed for session ${sid}, removing from transports map`); - delete transports[sid]; + const mcpPostHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) { + console.log(`Received MCP request for session: ${sessionId}`); + } else { + console.log('Request body:', req.body); } - }; - // Connect the transport to the MCP server BEFORE handling the request - // so responses can flow back through the same transport - const server = getServer(); - await server.connect(transport); + if (useOAuth && req.auth) { + console.log('Authenticated user:', req.auth); + } + try { + let transport: StreamableHTTPServerTransport; + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + const eventStore = new InMemoryEventStore(); + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, // Enable resumability + onsessioninitialized: (sessionId) => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); + + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + const sid = transport.sessionId; + if (sid && transports[sid]) { + console.log(`Transport closed for session ${sid}, removing from transports map`); + delete transports[sid]; + } + }; + + // Connect the transport to the MCP server BEFORE handling the request + // so responses can flow back through the same transport + const server = getServer(); + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided', + }, + id: null, + }); + return; + } + + // Handle the request with existing transport - no need to reconnect + // The existing transport is already connected to the server + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error', + }, + id: null, + }); + } + } + }; - await transport.handleRequest(req, res, req.body); - return; // Already handled +// Set up routes with conditional auth middleware + if (useOAuth && authMiddleware) { + app.post('/mcp', authMiddleware, mcpPostHandler); } else { - // Invalid request - no session ID or not initialization request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; + app.post('/mcp', mcpPostHandler); } - // Handle the request with existing transport - no need to reconnect - // The existing transport is already connected to the server - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); - } - } -}; +// Handle GET requests for SSE streams (using built-in support from StreamableHTTP) + const mcpGetHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } -// Set up routes with conditional auth middleware -if (useOAuth && authMiddleware) { - app.post('/mcp', authMiddleware, mcpPostHandler); -} else { - app.post('/mcp', mcpPostHandler); -} + if (useOAuth && req.auth) { + console.log('Authenticated SSE connection from user:', req.auth); + } -// Handle GET requests for SSE streams (using built-in support from StreamableHTTP) -const mcpGetHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } - - if (useOAuth && req.auth) { - console.log('Authenticated SSE connection from user:', req.auth); - } - - // Check for Last-Event-ID header for resumability - const lastEventId = req.headers['last-event-id'] as string | undefined; - if (lastEventId) { - console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); - } else { - console.log(`Establishing new SSE stream for session ${sessionId}`); - } - - const transport = transports[sessionId]; - await transport.handleRequest(req, res); -}; + // Check for Last-Event-ID header for resumability + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`Establishing new SSE stream for session ${sessionId}`); + } + + const transport = transports[sessionId]; + await transport.handleRequest(req, res); + }; // Set up GET route with conditional auth middleware -if (useOAuth && authMiddleware) { - app.get('/mcp', authMiddleware, mcpGetHandler); -} else { - app.get('/mcp', mcpGetHandler); -} + if (useOAuth && authMiddleware) { + app.get('/mcp', authMiddleware, mcpGetHandler); + } else { + app.get('/mcp', mcpGetHandler); + } // Handle DELETE requests for session termination (according to MCP spec) -const mcpDeleteHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } - - console.log(`Received session termination request for session ${sessionId}`); - - try { - const transport = transports[sessionId]; - await transport.handleRequest(req, res); - } catch (error) { - console.error('Error handling session termination:', error); - if (!res.headersSent) { - res.status(500).send('Error processing session termination'); - } - } -}; + const mcpDeleteHandler = async (req: Request, res: Response) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + console.log(`Received session termination request for session ${sessionId}`); + + try { + const transport = transports[sessionId]; + await transport.handleRequest(req, res); + } catch (error) { + console.error('Error handling session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } + } + }; // Set up DELETE route with conditional auth middleware -if (useOAuth && authMiddleware) { - app.delete('/mcp', authMiddleware, mcpDeleteHandler); -} else { - app.delete('/mcp', mcpDeleteHandler); -} + if (useOAuth && authMiddleware) { + app.delete('/mcp', authMiddleware, mcpDeleteHandler); + } else { + app.delete('/mcp', mcpDeleteHandler); + } -app.listen(MCP_PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); -}); + app.listen(MCP_PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); + }); // Handle server shutdown -process.on('SIGINT', async () => { - console.log('Shutting down server...'); - - // Close all active transports to properly clean up resources - for (const sessionId in transports) { - try { - console.log(`Closing transport for session ${sessionId}`); - await transports[sessionId].close(); - delete transports[sessionId]; - } catch (error) { - console.error(`Error closing transport for session ${sessionId}:`, error); - } - } - console.log('Server shutdown complete'); - process.exit(0); + process.on('SIGINT', async () => { + console.log('Shutting down server...'); + + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } + } + console.log('Server shutdown complete'); + process.exit(0); + }); + +} + +main().catch((err) => { + console.error('Failed to start server:', err); + process.exit(1); }); From 4363945092455b10ed3596ac083f1a3fe5f44ae7 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 15:21:33 +0300 Subject: [PATCH 13/22] fixup! Introduce Kotlin integration tests --- .../kotlin/sdk/integration/typescript/TypeScriptTestBase.kt | 4 ++-- .../kotlin/sdk/integration/typescript/simpleStreamableHttp.ts | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index a4bc6243..b5c3165c 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -185,14 +185,14 @@ abstract class TypeScriptTestBase { .command( "cmd.exe", "/c", - "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"" + "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", ) } else { ProcessBuilder() .command( "bash", "-c", - "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"" + "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", ) } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts index 67088cff..c750aedb 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts @@ -30,7 +30,6 @@ async function main() { } = await importFromSdk('src/types.ts'); const {InMemoryEventStore} = await importFromSdk('src/examples/shared/inMemoryEventStore.ts'); const {setupAuthServer} = await importFromSdk('src/examples/server/demoInMemoryOAuthProvider.ts'); - const {OAuthMetadata} = await importFromSdk('src/shared/auth.ts'); const {checkResourceAllowed} = await importFromSdk('src/shared/auth-utils.ts'); // Check for OAuth flag From 3f756aad63ed4432e0d16aa71a0b2698e76a8bf7 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 17:40:27 +0300 Subject: [PATCH 14/22] fixup! Introduce Kotlin integration tests --- ...tlinClientTypeScriptServerEdgeCasesTest.kt | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index 3905a20a..bb281df7 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -208,20 +208,15 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), ) - try { - val result = client.callTool("greet", invalidArguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - } catch (e: Exception) { - assertTrue( - e.message?.contains("invalid") == true || - e.message?.contains("error") == true, - "Exception should indicate invalid arguments: ${e.message}", - ) + val exception = assertThrows { + client.callTool("greet", invalidArguments) } + + assertTrue( + exception.message?.contains("invalid") == true || + exception.message?.contains("error") == true, + "Exception should indicate invalid arguments: ${exception.message}", + ) } } } From 0afe0a39e1a40ac7b690e480cc4d4a357aad2a37 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 18:11:20 +0300 Subject: [PATCH 15/22] fixup! Introduce Kotlin integration tests --- .../integration/kotlin/PromptEdgeCasesTest.kt | 354 +++++++++--------- .../kotlin/ResourceEdgeCasesTest.kt | 213 +++++------ .../integration/kotlin/ToolEdgeCasesTest.kt | 321 ++++++++-------- .../integration/kotlin/ToolIntegrationTest.kt | 13 +- ...tlinClientTypeScriptServerEdgeCasesTest.kt | 278 +++++++------- .../KotlinClientTypeScriptServerTest.kt | 163 ++++---- .../TypeScriptClientKotlinServerTest.kt | 9 +- .../typescript/TypeScriptEdgeCasesTest.kt | 11 +- 8 files changed, 652 insertions(+), 710 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt index f5e736d7..559129c3 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt @@ -123,7 +123,7 @@ class PromptEdgeCasesTest : KotlinTestBase() { ) } - // Very large prompt + // very large prompt server.addPrompt( name = largePromptName, description = largePromptDescription, @@ -183,230 +183,210 @@ class PromptEdgeCasesTest : KotlinTestBase() { } @Test - fun testBasicPrompt() { - runTest { - val testName = "Alice" - val result = client.getPrompt( - GetPromptRequest( - name = basicPromptName, - arguments = mapOf("name" to testName), - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(basicPromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") - assertEquals("Hello, $testName!", userContent.text, "User message content should match") - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - assertEquals( - "Greetings, $testName! How can I assist you today?", - assistantContent.text, - "Assistant message content should match", - ) - } + fun testBasicPrompt() = runTest { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(basicPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertEquals("Hello, $testName!", userContent.text, "User message content should match") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) } @Test - fun testComplexPromptWithManyArguments() { - runTest { - val arguments = (1..10).associate { i -> "arg$i" to "value$i" } - - val result = client.getPrompt( - GetPromptRequest( - name = complexPromptName, - arguments = arguments, - ), - ) + fun testComplexPromptWithManyArguments() = runTest { + val arguments = (1..10).associate { i -> "arg$i" to "value$i" } - assertNotNull(result, "Get prompt result should not be null") - assertEquals(complexPromptDescription, result.description, "Prompt description should match") + val result = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = arguments, + ), + ) - assertEquals(2, result.messages.size, "Prompt should have 2 messages") + assertNotNull(result, "Get prompt result should not be null") + assertEquals(complexPromptDescription, result.description, "Prompt description should match") - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") - // verify all arguments - val text = userContent.text ?: "" - for (i in 1..10) { - assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") - } + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - assertEquals( - "Received 10 arguments", - assistantContent.text, - "Assistant message should indicate 10 arguments", - ) + // verify all arguments + val text = userContent.text ?: "" + for (i in 1..10) { + assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") } + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Received 10 arguments", + assistantContent.text, + "Assistant message should indicate 10 arguments", + ) } @Test - fun testLargePrompt() { - runTest { - val result = client.getPrompt( - GetPromptRequest( - name = largePromptName, - arguments = mapOf("size" to "1"), - ), - ) + fun testLargePrompt() = runTest { + val result = client.getPrompt( + GetPromptRequest( + name = largePromptName, + arguments = mapOf("size" to "1"), + ), + ) - assertNotNull(result, "Get prompt result should not be null") - assertEquals(largePromptDescription, result.description, "Prompt description should match") + assertNotNull(result, "Get prompt result should not be null") + assertEquals(largePromptDescription, result.description, "Prompt description should match") - assertEquals(2, result.messages.size, "Prompt should have 2 messages") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - val text = assistantContent.text ?: "" - assertEquals(100_000, text.length, "Assistant message should be 100KB in size") - } + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val text = assistantContent.text ?: "" + assertEquals(100_000, text.length, "Assistant message should be 100KB in size") } @Test - fun testSpecialCharacters() { - runTest { - val result = client.getPrompt( - GetPromptRequest( - name = specialCharsPromptName, - arguments = mapOf("special" to specialCharsContent), - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") - val userText = userContent.text ?: "" - assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - val assistantText = assistantContent.text ?: "" - assertTrue( - assistantText.contains(specialCharsContent), - "Assistant message should contain special characters", - ) - } + fun testSpecialCharacters() = runTest { + val result = client.getPrompt( + GetPromptRequest( + name = specialCharsPromptName, + arguments = mapOf("special" to specialCharsContent), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = userContent.text ?: "" + assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = assistantContent.text ?: "" + assertTrue( + assistantText.contains(specialCharsContent), + "Assistant message should contain special characters", + ) } @Test - fun testMissingRequiredArguments() { - runTest { - val exception = assertThrows { - runBlocking { - client.getPrompt( - GetPromptRequest( - name = complexPromptName, - arguments = mapOf("arg4" to "value4", "arg5" to "value5"), - ), - ) - } + fun testMissingRequiredArguments() = runTest { + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf("arg4" to "value4", "arg5" to "value5"), + ), + ) } - - assertTrue( - exception.message?.contains("arg1") == true || - exception.message?.contains("arg2") == true || - exception.message?.contains("arg3") == true || - exception.message?.contains("required") == true, - "Exception should mention missing required arguments", - ) } + + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=Missing required argument: arg1, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for missing required argument") } @Test - fun testConcurrentPromptRequests() { - runTest { - val concurrentCount = 10 - val results = mutableListOf() + fun testConcurrentPromptRequests() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val promptName = when (index % 4) { + 0 -> basicPromptName + 1 -> complexPromptName + 2 -> largePromptName + else -> specialCharsPromptName + } - runBlocking { - repeat(concurrentCount) { index -> - launch { - val promptName = when (index % 4) { - 0 -> basicPromptName - 1 -> complexPromptName - 2 -> largePromptName - else -> specialCharsPromptName - } - - val arguments = when (promptName) { - basicPromptName -> mapOf("name" to "User$index") - complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") - largePromptName -> mapOf("size" to "1") - else -> mapOf("special" to "!@#$%^&*()") - } - - val result = client.getPrompt( - GetPromptRequest( - name = promptName, - arguments = arguments, - ), - ) - - synchronized(results) { - results.add(result) - } + val arguments = when (promptName) { + basicPromptName -> mapOf("name" to "User$index") + complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") + largePromptName -> mapOf("size" to "1") + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.getPrompt( + GetPromptRequest( + name = promptName, + arguments = arguments, + ), + ) + + synchronized(results) { + results.add(result) } } } + } - assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - results.forEach { result -> - assertNotNull(result, "Result should not be null") - assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") - } + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") } } @Test - fun testNonExistentPrompt() { - runTest { - val nonExistentPromptName = "non-existent-prompt" + fun testNonExistentPrompt() = runTest { + val nonExistentPromptName = "non-existent-prompt" - val exception = assertThrows { - runBlocking { - client.getPrompt( - GetPromptRequest( - name = nonExistentPromptName, - arguments = mapOf("name" to "Test"), - ), - ) - } + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = nonExistentPromptName, + arguments = mapOf("name" to "Test"), + ), + ) } - - assertTrue( - exception.message?.contains("not found") == true || - exception.message?.contains("does not exist") == true || - exception.message?.contains("unknown") == true || - exception.message?.contains("error") == true, - "Exception should indicate prompt not found", - ) } + + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt") } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt index 232ac025..d4cf5187 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt @@ -129,157 +129,142 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testBinaryResource() { - runTest { - val result = client.readResource(ReadResourceRequest(uri = binaryResourceUri)) + fun testBinaryResource() = runTest { + val result = client.readResource(ReadResourceRequest(uri = binaryResourceUri)) - assertNotNull(result, "Read resource result should not be null") - assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") - val content = result.contents.firstOrNull() as? BlobResourceContents - assertNotNull(content, "Resource content should be BlobResourceContents") - assertEquals(binaryResourceContent, content.blob, "Binary resource content should match") - assertEquals("image/png", content.mimeType, "MIME type should match") - } + val content = result.contents.firstOrNull() as? BlobResourceContents + assertNotNull(content, "Resource content should be BlobResourceContents") + assertEquals(binaryResourceContent, content.blob, "Binary resource content should match") + assertEquals("image/png", content.mimeType, "MIME type should match") } @Test - fun testLargeResource() { - runTest { - val result = client.readResource(ReadResourceRequest(uri = largeResourceUri)) + fun testLargeResource() = runTest { + val result = client.readResource(ReadResourceRequest(uri = largeResourceUri)) - assertNotNull(result, "Read resource result should not be null") - assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") - val content = result.contents.firstOrNull() as? TextResourceContents - assertNotNull(content, "Resource content should be TextResourceContents") - assertEquals(100_000, content.text.length, "Large resource content length should match") - assertEquals("X".repeat(100_000), content.text, "Large resource content should match") - } + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(100_000, content.text.length, "Large resource content length should match") + assertEquals("X".repeat(100_000), content.text, "Large resource content should match") } @Test - fun testInvalidResourceUri() { - runTest { - val invalidUri = "test://nonexistent.txt" + fun testInvalidResourceUri() = runTest { + val invalidUri = "test://nonexistent.txt" - val exception = assertThrows { - runBlocking { - client.readResource(ReadResourceRequest(uri = invalidUri)) - } + val exception = assertThrows { + runBlocking { + client.readResource(ReadResourceRequest(uri = invalidUri)) } - - assertTrue( - exception.message?.contains("not found") == true || - exception.message?.contains("invalid") == true || - exception.message?.contains("error") == true, - "Exception should indicate resource not found or invalid URI", - ) } + + val msg = exception.message ?: "" + val expectedMessage = + "JSONRPCError(code=InternalError, message=Resource not found: test://nonexistent.txt, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for invalid resource URI") } @Test - fun testDynamicResource() { - runTest { - val initialResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) - assertNotNull(initialResult, "Initial read result should not be null") - val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text - assertEquals("Original content", initialContent, "Initial content should match") - - // update resource - dynamicResourceContent.set(true) - - val updatedResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) - assertNotNull(updatedResult, "Updated read result should not be null") - val updatedContent = (updatedResult.contents.firstOrNull() as? TextResourceContents)?.text - assertEquals("Updated content", updatedContent, "Updated content should match") - } + fun testDynamicResource() = runTest { + val initialResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(initialResult, "Initial read result should not be null") + val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Original content", initialContent, "Initial content should match") + + // update resource + dynamicResourceContent.set(true) + + val updatedResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(updatedResult, "Updated read result should not be null") + val updatedContent = (updatedResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Updated content", updatedContent, "Updated content should match") } @Test - fun testResourceAddAndRemove() { - runTest { - val initialList = client.listResources() - assertNotNull(initialList, "Initial list result should not be null") - val initialCount = initialList.resources.size - - val newResourceUri = "test://new-resource.txt" - server.addResource( - uri = newResourceUri, - name = "New Resource", - description = "A newly added resource", - mimeType = "text/plain", - ) { request -> - ReadResourceResult( - contents = listOf( - TextResourceContents( - text = "New resource content", - uri = request.uri, - mimeType = "text/plain", - ), + fun testResourceAddAndRemove() = runTest { + val initialList = client.listResources() + assertNotNull(initialList, "Initial list result should not be null") + val initialCount = initialList.resources.size + + val newResourceUri = "test://new-resource.txt" + server.addResource( + uri = newResourceUri, + name = "New Resource", + description = "A newly added resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "New resource content", + uri = request.uri, + mimeType = "text/plain", ), - ) - } + ), + ) + } - val updatedList = client.listResources() - assertNotNull(updatedList, "Updated list result should not be null") - val updatedCount = updatedList.resources.size + val updatedList = client.listResources() + assertNotNull(updatedList, "Updated list result should not be null") + val updatedCount = updatedList.resources.size - assertEquals(initialCount + 1, updatedCount, "Resource count should increase by 1") - val newResource = updatedList.resources.find { it.uri == newResourceUri } - assertNotNull(newResource, "New resource should be in the list") + assertEquals(initialCount + 1, updatedCount, "Resource count should increase by 1") + val newResource = updatedList.resources.find { it.uri == newResourceUri } + assertNotNull(newResource, "New resource should be in the list") - server.removeResource(newResourceUri) + server.removeResource(newResourceUri) - val finalList = client.listResources() - assertNotNull(finalList, "Final list result should not be null") - val finalCount = finalList.resources.size + val finalList = client.listResources() + assertNotNull(finalList, "Final list result should not be null") + val finalCount = finalList.resources.size - assertEquals(initialCount, finalCount, "Resource count should return to initial value") - val removedResource = finalList.resources.find { it.uri == newResourceUri } - assertEquals(null, removedResource, "Resource should be removed from the list") - } + assertEquals(initialCount, finalCount, "Resource count should return to initial value") + val removedResource = finalList.resources.find { it.uri == newResourceUri } + assertEquals(null, removedResource, "Resource should be removed from the list") } @Test - fun testConcurrentResourceOperations() { - runTest { - val concurrentCount = 10 - val results = mutableListOf() + fun testConcurrentResourceOperations() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val uri = when (index % 3) { + 0 -> testResourceUri + 1 -> binaryResourceUri + else -> largeResourceUri + } - runBlocking { - repeat(concurrentCount) { index -> - launch { - val uri = when (index % 3) { - 0 -> testResourceUri - 1 -> binaryResourceUri - else -> largeResourceUri - } - - val result = client.readResource(ReadResourceRequest(uri = uri)) - synchronized(results) { - results.add(result) - } + val result = client.readResource(ReadResourceRequest(uri = uri)) + synchronized(results) { + results.add(result) } } } + } - assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - results.forEach { result -> - assertNotNull(result, "Result should not be null") - assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") - } + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") } } @Test - fun testSubscribeAndUnsubscribe() { - runTest { - val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) - assertNotNull(subscribeResult, "Subscribe result should not be null") + fun testSubscribeAndUnsubscribe() = runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") - val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) - assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") - } + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index 0cb8c506..83f32ef8 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -281,225 +281,208 @@ class ToolEdgeCasesTest : KotlinTestBase() { } @Test - fun testBasicTool() { - runTest { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) + fun testBasicTool() = runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) - val result = client.callTool(basicToolName, arguments) + val result = client.callTool(basicToolName, arguments) - val toolResult = assertCallToolResult(result) - assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + val toolResult = assertCallToolResult(result) + assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") - val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "result", testText) - } + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "result", testText) } @Test - fun testComplexNestedSchema() { - runTest { - val userJson = buildJsonObject { - put("name", JsonPrimitive("John Doe")) - put("age", JsonPrimitive(30)) - put( - "address", - buildJsonObject { - put("street", JsonPrimitive("123 Main St")) - put("city", JsonPrimitive("New York")) - put("country", JsonPrimitive("USA")) - }, - ) - } - - val optionsJson = buildJsonArray { - add(JsonPrimitive("option1")) - add(JsonPrimitive("option2")) - add(JsonPrimitive("option3")) - } - - val arguments = buildJsonObject { - put("user", userJson) - put("options", optionsJson) - } - - val result = client.callTool( - CallToolRequest( - name = complexToolName, - arguments = arguments, - ), + fun testComplexNestedSchema() = runTest { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Doe")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, ) + } + + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" - - assertTrue(text.contains("John Doe"), "Result should contain the name") - assertTrue(text.contains("30"), "Result should contain the age") - assertTrue(text.contains("123 Main St"), "Result should contain the street") - assertTrue(text.contains("New York"), "Result should contain the city") - assertTrue(text.contains("USA"), "Result should contain the country") - assertTrue(text.contains("option1"), "Result should contain option1") - assertTrue(text.contains("option2"), "Result should contain option2") - assertTrue(text.contains("option3"), "Result should contain option3") - - val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "name", "John Doe") - assertJsonProperty(structuredContent, "age", 30) - - val address = structuredContent["address"] as? JsonObject - assertNotNull(address, "Address should be present in structured content") - assertJsonProperty(address, "street", "123 Main St") - assertJsonProperty(address, "city", "New York") - assertJsonProperty(address, "country", "USA") - - val options = structuredContent["options"] as? JsonArray - assertNotNull(options, "Options should be present in structured content") - assertEquals(3, options.size, "Options should have 3 items") + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) } + + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains("John Doe"), "Result should contain the name") + assertTrue(text.contains("30"), "Result should contain the age") + assertTrue(text.contains("123 Main St"), "Result should contain the street") + assertTrue(text.contains("New York"), "Result should contain the city") + assertTrue(text.contains("USA"), "Result should contain the country") + assertTrue(text.contains("option1"), "Result should contain option1") + assertTrue(text.contains("option2"), "Result should contain option2") + assertTrue(text.contains("option3"), "Result should contain option3") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "name", "John Doe") + assertJsonProperty(structuredContent, "age", 30) + + val address = structuredContent["address"] as? JsonObject + assertNotNull(address, "Address should be present in structured content") + assertJsonProperty(address, "street", "123 Main St") + assertJsonProperty(address, "city", "New York") + assertJsonProperty(address, "country", "USA") + + val options = structuredContent["options"] as? JsonArray + assertNotNull(options, "Options should be present in structured content") + assertEquals(3, options.size, "Options should have 3 items") } @Test - fun testLargeResponse() { - runTest { - val size = 10 - val arguments = mapOf("size" to size) + fun testLargeResponse() = runTest { + val size = 10 + val arguments = mapOf("size" to size) - val result = client.callTool(largeToolName, arguments) + val result = client.callTool(largeToolName, arguments) - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" - assertEquals(10000, text.length, "Response should be 10KB in size") + assertEquals(10000, text.length, "Response should be 10KB in size") - val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "size", 10000) - } + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "size", 10000) } @Test - fun testSlowTool() { - runTest { - val delay = 500 - val arguments = mapOf("delay" to delay) + fun testSlowTool() = runTest { + val delay = 500 + val arguments = mapOf("delay" to delay) - val startTime = System.currentTimeMillis() - val result = client.callTool(slowToolName, arguments) - val endTime = System.currentTimeMillis() + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) + val endTime = System.currentTimeMillis() - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" - assertTrue(text.contains("${delay}ms"), "Result should mention the delay") - assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + assertTrue(text.contains("${delay}ms"), "Result should mention the delay") + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") - val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "delay", delay) - } + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "delay", delay) } @Test - fun testSpecialCharacters() { - runTest { - val arguments = mapOf("special" to specialCharsContent) + fun testSpecialCharacters() = runTest { + val arguments = mapOf("special" to specialCharsContent) - val result = client.callTool(specialCharsToolName, arguments) + val result = client.callTool(specialCharsToolName, arguments) - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" - assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") - val structuredContent = toolResult.structuredContent as JsonObject - val special = structuredContent["special"]?.toString()?.trim('"') + val structuredContent = toolResult.structuredContent as JsonObject + val special = structuredContent["special"]?.toString()?.trim('"') - assertNotNull(special, "Special characters should be in structured content") - assertTrue(text.contains(specialCharsContent), "Special characters should be in the content") - } + assertNotNull(special, "Special characters should be in structured content") + assertTrue(text.contains(specialCharsContent), "Special characters should be in the content") } @Test - fun testConcurrentToolCalls() { - runTest { - val concurrentCount = 10 - val results = mutableListOf() + fun testConcurrentToolCalls() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val toolName = when (index % 5) { + 0 -> basicToolName + 1 -> complexToolName + 2 -> largeToolName + 3 -> slowToolName + else -> specialCharsToolName + } - runBlocking { - repeat(concurrentCount) { index -> - launch { - val toolName = when (index % 5) { - 0 -> basicToolName - 1 -> complexToolName - 2 -> largeToolName - 3 -> slowToolName - else -> specialCharsToolName - } - - val arguments = when (toolName) { - basicToolName -> mapOf("text" to "Concurrent call $index") - - complexToolName -> mapOf( - "user" to mapOf( - "name" to "User $index", - "age" to 20 + index, - "address" to mapOf( - "street" to "Street $index", - "city" to "City $index", - "country" to "Country $index", - ), + val arguments = when (toolName) { + basicToolName -> mapOf("text" to "Concurrent call $index") + + complexToolName -> mapOf( + "user" to mapOf( + "name" to "User $index", + "age" to 20 + index, + "address" to mapOf( + "street" to "Street $index", + "city" to "City $index", + "country" to "Country $index", ), - ) + ), + ) - largeToolName -> mapOf("size" to 1) + largeToolName -> mapOf("size" to 1) - slowToolName -> mapOf("delay" to 100) + slowToolName -> mapOf("delay" to 100) - else -> mapOf("special" to "!@#$%^&*()") - } + else -> mapOf("special" to "!@#$%^&*()") + } - val result = client.callTool(toolName, arguments) + val result = client.callTool(toolName, arguments) - synchronized(results) { - results.add(result) - } + synchronized(results) { + results.add(result) } } } + } - assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - results.forEach { result -> - assertNotNull(result, "Result should not be null") - assertTrue(result.content.isNotEmpty(), "Result content should not be empty") - } + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.content.isNotEmpty(), "Result content should not be empty") } } @Test - fun testNonExistentTool() { - runTest { - val nonExistentToolName = "non-existent-tool" - val arguments = mapOf("text" to "Test") - - val exception = assertThrows { - runBlocking { - client.callTool(nonExistentToolName, arguments) - } - } + fun testNonExistentTool() = runTest { + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("text" to "Test") - assertTrue( - exception.message?.contains("not found") == true || - exception.message?.contains("does not exist") == true || - exception.message?.contains("unknown") == true || - exception.message?.contains("error") == true, - "Exception should indicate tool not found", - ) + val exception = assertThrows { + runBlocking { + client.callTool(nonExistentToolName, arguments) + } } + + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=Tool not found: non-existent-tool, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for non-existent tool") } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index c6262a13..5fae911f 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -403,20 +403,19 @@ class ToolIntegrationTest : KotlinTestBase() { val exceptionArgs = mapOf( "errorType" to "exception", - "message" to "Exception message", + "message" to "My exception message", ) - val exception = assertThrows { + val exception = assertThrows { runBlocking { client.callTool(errorToolName, exceptionArgs) } } - assertEquals( - exception.message?.contains("Exception message"), - true, - "Exception message should contain 'Exception message'", - ) + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=My exception message, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for exception") } @Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index bb281df7..ec83c0d3 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -7,6 +7,7 @@ import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async @@ -70,183 +71,182 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testNonExistentTool() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val nonExistentToolName = "non-existent-tool" - val arguments = mapOf("name" to "TestUser") - - val exception = assertThrows { - client.callTool(nonExistentToolName, arguments) - } + fun testNonExistentTool() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val errorMessage = exception.message ?: "" - assertTrue( - errorMessage.contains("not found") || - errorMessage.contains("unknown") || - errorMessage.contains("error"), - "Exception should indicate tool not found: $errorMessage", - ) + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) } + + val expectedMessage = + "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" + assertEquals( + expectedMessage, + exception.message, + "Unexpected error message for non-existent tool", + ) } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testSpecialCharactersInArguments() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" - val arguments = mapOf("name" to specialChars) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - - val text = textContent.text ?: "" - assertTrue( - text.contains(specialChars), - "Tool response should contain the special characters", - ) - } + fun testSpecialCharactersInArguments() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testLargePayload() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val largeName = "A".repeat(10 * 1024) - val arguments = mapOf("name" to largeName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - - val text = textContent.text ?: "" - assertTrue( - text.contains("Hello,") && text.contains("A"), - "Tool response should contain the greeting with the large name", - ) - } + fun testLargePayload() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) } } @Test @Timeout(60, unit = TimeUnit.SECONDS) - fun testConcurrentRequests() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val concurrentCount = 5 - val results = mutableListOf>() - - for (i in 1..concurrentCount) { - val deferred = async { - val name = "ConcurrentClient$i" - val arguments = mapOf("name" to name) + fun testConcurrentRequests() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val concurrentCount = 5 + val results = mutableListOf>() + + for (i in 1..concurrentCount) { + val deferred = async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null for client $i") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present for client $i") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") - textContent.text ?: "" - } - results.add(deferred) + textContent.text ?: "" } + results.add(deferred) + } - val responses = results.awaitAll() + val responses = results.awaitAll() - for (i in 1..concurrentCount) { - val expectedName = "ConcurrentClient$i" - val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } - assertEquals( - 1, - matchingResponses.size, - "Should have exactly one response for $expectedName", - ) - } + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) } } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testInvalidArguments() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val invalidArguments = mapOf( - "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), - ) - - val exception = assertThrows { - client.callTool("greet", invalidArguments) - } - - assertTrue( - exception.message?.contains("invalid") == true || - exception.message?.contains("error") == true, - "Exception should indicate invalid arguments: ${exception.message}", - ) + fun testInvalidArguments() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + val exception = assertThrows { + client.callTool("greet", invalidArguments) } + + val msg = exception.message ?: "" + val expectedMessage = """ + JSONRPCError(code=InvalidParams, message=MCP error -32602: Invalid arguments for tool greet: [ + { + "code": "invalid_type", + "expected": "string", + "received": "object", + "path": [ + "name" + ], + "message": "Expected string, received object" + } + ], data={}) + """.trimIndent() + + assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleToolCalls() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - repeat(10) { i -> - val name = "SequentialClient$i" - val arguments = mapOf("name" to name) + fun testMultipleToolCalls() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null for call $i") + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present for call $i") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") - assertEquals( - "Hello, $name!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) - } + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt index f4cf8ffc..b7e77652 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -7,6 +7,7 @@ import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext @@ -64,109 +65,101 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testKotlinClientConnectsToTypeScriptServer() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + fun testKotlinClientConnectsToTypeScriptServer() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - assertNotNull(client, "Client should be initialized") + assertNotNull(client, "Client should be initialized") - val pingResult = client.ping() - assertNotNull(pingResult, "Ping result should not be null") + val pingResult = client.ping() + assertNotNull(pingResult, "Ping result should not be null") - val serverImpl = client.serverVersion - assertNotNull(serverImpl, "Server implementation should not be null") - println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") - } + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testListTools() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val result = client.listTools() - assertNotNull(result, "Tools list should not be null") - assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") - - // Verify specific utils are available - val toolNames = result.tools.map { it.name } - assertTrue(toolNames.contains("greet"), "Greet tool should be available") - assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") - assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") - - println("Available utils: ${toolNames.joinToString()}") - } + fun testListTools() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + // Verify specific utils are available + val toolNames = result.tools.map { it.name } + assertTrue(toolNames.contains("greet"), "Greet tool should be available") + assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") + assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") + + println("Available utils: ${toolNames.joinToString()}") } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCall() { - runBlocking { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val testName = "TestUser" - val arguments = mapOf("name" to testName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - assertEquals( - "Hello, $testName!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) - } + fun testToolCall() = runTest { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val testName = "TestUser" + val arguments = mapOf("name" to testName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals( + "Hello, $testName!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) } } @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleClients() { - runBlocking { - withContext(Dispatchers.IO) { - // First client connection - val client1 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools1 = client1.listTools() - assertNotNull(tools1, "Tools list for first client should not be null") - assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") - - val client2 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools2 = client2.listTools() - assertNotNull(tools2, "Tools list for second client should not be null") - assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") - - val toolNames1 = tools1.tools.map { it.name } - val toolNames2 = tools2.tools.map { it.name } - - assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") - assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") - assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") - assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") - - client1.close() - client2.close() - } + fun testMultipleClients() = runTest { + withContext(Dispatchers.IO) { + // First client connection + val client1 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools1 = client1.listTools() + assertNotNull(tools1, "Tools list for first client should not be null") + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + + val client2 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools2 = client2.listTools() + assertNotNull(tools2, "Tools list for second client should not be null") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + + assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") + assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") + assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") + assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") + + client1.close() + client2.close() } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt index 967e1d06..351406ef 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -38,7 +39,7 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCall() { + fun testToolCall() = runTest { val testName = "TestUser" val command = "npx tsx myClient.ts $serverUrl greet $testName" val output = executeCommand(command, tsClientDir) @@ -58,7 +59,7 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testNotifications() { + fun testNotifications() = runTest { val name = "NotifUser" val command = "npx tsx myClient.ts $serverUrl multi-greet $name" val output = executeCommand(command, tsClientDir) @@ -76,7 +77,7 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { @Test @Timeout(120, unit = TimeUnit.SECONDS) - fun testMultipleClientSequence() { + fun testMultipleClientSequence() = runTest { val testName1 = "FirstClient" val command1 = "npx tsx myClient.ts $serverUrl greet $testName1" val output1 = executeCommand(command1, tsClientDir) @@ -108,7 +109,7 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleClientParallel() { + fun testMultipleClientParallel() = runTest { val clientCount = 3 val clients = listOf( "FirstClient" to "greet", diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 4baadb8a..9020d439 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -42,7 +43,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testErrorHandling() { + fun testErrorHandling() = runTest { val nonExistentToolCommand = "npx tsx myClient.ts $serverUrl non-existent-tool" val nonExistentToolOutput = executeCommandAllowingFailure(nonExistentToolCommand, tsClientDir) @@ -62,7 +63,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testSpecialCharacters() { + fun testSpecialCharacters() = runTest { val specialChars = "!@#$+-[].,?" val tempFile = File.createTempFile("special_chars", ".txt") @@ -87,7 +88,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) @EnabledOnOs(OS.MAC, OS.LINUX) - fun testLargePayload() { + fun testLargePayload() = runTest { val largeName = "A".repeat(10 * 1024) val tempFile = File.createTempFile("large_name", ".txt") @@ -112,7 +113,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(60, unit = TimeUnit.SECONDS) - fun testComplexConcurrentRequests() { + fun testComplexConcurrentRequests() = runTest { val commands = listOf( "npx tsx myClient.ts $serverUrl greet \"Client1\"", "npx tsx myClient.ts $serverUrl multi-greet \"Client2\"", @@ -160,7 +161,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(120, unit = TimeUnit.SECONDS) - fun testRapidSequentialRequests() { + fun testRapidSequentialRequests() = runTest { val outputs = (1..10).map { i -> val command = "npx tsx myClient.ts $serverUrl greet \"RapidClient$i\"" val output = executeCommand(command, tsClientDir) From aee13bb52199033a6cc06b7f1ad33788aea3b34c Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 18:27:21 +0300 Subject: [PATCH 16/22] fixup! Introduce Kotlin integration tests --- .../sdk/integration/typescript/TypeScriptEdgeCasesTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 9020d439..5e47307a 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -56,7 +56,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { val invalidUrlOutput = executeCommandAllowingFailure(invalidUrlCommand, tsClientDir) assertTrue( - invalidUrlOutput.contains("Error:") || invalidUrlOutput.contains("ECONNREFUSED"), + invalidUrlOutput.contains("Invalid URL") && invalidUrlOutput.contains("ERR_INVALID_URL"), "Client should handle connection errors gracefully", ) } From 38b5fb22928ce2f232fe298e7d3eee32f7021824 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 18:56:24 +0300 Subject: [PATCH 17/22] fixup! Introduce Kotlin integration tests --- kotlin-sdk-test/build.gradle.kts | 1 + .../integration/kotlin/ToolEdgeCasesTest.kt | 45 +++++++------- .../integration/kotlin/ToolIntegrationTest.kt | 47 +++++++++------ .../typescript/TypeScriptEdgeCasesTest.kt | 6 +- .../kotlin/sdk/integration/utils/TestUtils.kt | 58 ++++--------------- 5 files changed, 69 insertions(+), 88 deletions(-) diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index 9f87efd6..cbe22a7f 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -15,6 +15,7 @@ kotlin { implementation(kotlin("test")) implementation(libs.ktor.server.test.host) implementation(libs.kotlinx.coroutines.test) + implementation(libs.kotest.assertions.json) } } jvmTest { diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index 83f32ef8..69df3b2d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -7,7 +7,7 @@ import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonEquals import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.delay @@ -291,7 +291,8 @@ class ToolEdgeCasesTest : KotlinTestBase() { assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "result", testText) + val expected = buildJsonObject { put("result", testText) } + assertJsonEquals(expected, structuredContent) } @Test @@ -342,18 +343,19 @@ class ToolEdgeCasesTest : KotlinTestBase() { assertTrue(text.contains("option3"), "Result should contain option3") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "name", "John Doe") - assertJsonProperty(structuredContent, "age", 30) - - val address = structuredContent["address"] as? JsonObject - assertNotNull(address, "Address should be present in structured content") - assertJsonProperty(address, "street", "123 Main St") - assertJsonProperty(address, "city", "New York") - assertJsonProperty(address, "country", "USA") - - val options = structuredContent["options"] as? JsonArray - assertNotNull(options, "Options should be present in structured content") - assertEquals(3, options.size, "Options should have 3 items") + val expectedStructured = buildJsonObject { + put("name", "John Doe") + put("age", 30) + put("address", buildJsonObject { + put("street", "123 Main St") + put("city", "New York") + put("country", "USA") + }) + put("options", buildJsonArray { + add("option1"); add("option2"); add("option3") + }) + } + assertJsonEquals(expectedStructured, structuredContent) } @Test @@ -371,7 +373,8 @@ class ToolEdgeCasesTest : KotlinTestBase() { assertEquals(10000, text.length, "Response should be 10KB in size") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "size", 10000) + val expected = buildJsonObject { put("size", 10000) } + assertJsonEquals(expected, structuredContent) } @Test @@ -392,7 +395,8 @@ class ToolEdgeCasesTest : KotlinTestBase() { assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "delay", delay) + val expected = buildJsonObject { put("delay", delay) } + assertJsonEquals(expected, structuredContent) } @Test @@ -409,10 +413,11 @@ class ToolEdgeCasesTest : KotlinTestBase() { assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") val structuredContent = toolResult.structuredContent as JsonObject - val special = structuredContent["special"]?.toString()?.trim('"') - - assertNotNull(special, "Special characters should be in structured content") - assertTrue(text.contains(specialCharsContent), "Special characters should be in the content") + val expected = buildJsonObject { + put("special", specialCharsContent) + put("length", specialCharsContent.length) + } + assertJsonEquals(expected, structuredContent) } @Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index 5fae911f..6018c0d6 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -7,7 +7,7 @@ import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonEquals import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.runBlocking @@ -324,7 +324,8 @@ class ToolIntegrationTest : KotlinTestBase() { assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "result", testText) + val expected = buildJsonObject { put("result", testText) } + assertJsonEquals(expected, structuredContent) } @Test @@ -362,18 +363,18 @@ class ToolIntegrationTest : KotlinTestBase() { assertTrue(contentText.contains("11"), "Result should contain result value") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "operation", "multiply") - assertJsonProperty(structuredContent, "result", 11.0) - - val formattedResult = structuredContent["formattedResult"]?.toString()?.trim('"') ?: "" - assertTrue( - formattedResult == "11.000" || formattedResult == "11,000", - "Formatted result should be either '11.000' or '11,000', but was '$formattedResult'", - ) - assertJsonProperty(structuredContent, "precision", 3) + val actualWithoutFormatted = buildJsonObject { + structuredContent.filterKeys { it != "formattedResult" && it != "tags" }.forEach { (k, v) -> put(k, v) } + } + val expectedWithoutFormatted = buildJsonObject { + put("operation", "multiply") + put("a", 5.5) + put("b", 2.0) + put("result", 11.0) + put("precision", 3) + } - val tags = structuredContent["tags"] as? JsonArray - assertNotNull(tags, "Tags should be present") + assertJsonEquals(expectedWithoutFormatted, actualWithoutFormatted) } } @@ -386,7 +387,11 @@ class ToolIntegrationTest : KotlinTestBase() { assertTextContent(successToolResult.content.firstOrNull(), "No error occurred") val noErrorStructured = successToolResult.structuredContent as JsonObject - assertJsonProperty(noErrorStructured, "error", false) + val expectedNoError = buildJsonObject { + put("error", false) + put("message", "Success") + } + assertJsonEquals(expectedNoError, noErrorStructured) val errorArgs = mapOf( "errorType" to "error", @@ -398,8 +403,11 @@ class ToolIntegrationTest : KotlinTestBase() { assertTextContent(errorToolResult.content.firstOrNull(), "Error: Custom error message") val errorStructured = errorToolResult.structuredContent as JsonObject - assertJsonProperty(errorStructured, "error", true) - assertJsonProperty(errorStructured, "message", "Custom error message") + val expectedError = buildJsonObject { + put("error", true) + put("message", "Custom error message") + } + assertJsonEquals(expectedError, errorStructured) val exceptionArgs = mapOf( "errorType" to "exception", @@ -450,8 +458,11 @@ class ToolIntegrationTest : KotlinTestBase() { assertTrue(imageContent.data.isNotEmpty(), "Image data should not be empty") val structuredContent = toolResult.structuredContent as JsonObject - assertJsonProperty(structuredContent, "text", testText) - assertJsonProperty(structuredContent, "includeImage", true) + val expectedStructured = buildJsonObject { + put("text", testText) + put("includeImage", true) + } + assertJsonEquals(expectedStructured, structuredContent) val textOnlyArgs = mapOf( "text" to testText, diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 5e47307a..7f27c5b7 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -43,7 +43,7 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testErrorHandling() = runTest { + fun testInvalidURL() = runTest { val nonExistentToolCommand = "npx tsx myClient.ts $serverUrl non-existent-tool" val nonExistentToolOutput = executeCommandAllowingFailure(nonExistentToolCommand, tsClientDir) @@ -56,7 +56,9 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { val invalidUrlOutput = executeCommandAllowingFailure(invalidUrlCommand, tsClientDir) assertTrue( - invalidUrlOutput.contains("Invalid URL") && invalidUrlOutput.contains("ERR_INVALID_URL"), + invalidUrlOutput.contains("Invalid URL") || + invalidUrlOutput.contains("ERR_INVALID_URL") || + invalidUrlOutput.contains("ECONNREFUSED"), "Client should handle connection errors gracefully", ) } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt index bed66cd4..0bca29c8 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt @@ -6,8 +6,10 @@ import io.modelcontextprotocol.kotlin.sdk.TextContent import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext +import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive +import io.kotest.assertions.json.* import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -35,57 +37,17 @@ object TestUtils { return result } - /** - * Asserts that a JSON property has the expected string value. - */ - fun assertJsonProperty( - json: JsonObject, - property: String, - expectedValue: String, - message: String = "", - ) { - assertEquals(expectedValue, json[property]?.toString()?.trim('"'), "${message}$property should match") + // Use Kotest JSON assertions to compare whole JSON structures. + fun assertJsonEquals(expectedJson: String, actual: JsonElement, message: String = "") { + val prefix = if (message.isNotEmpty()) "$message\n" else "" + (actual.toString()).shouldEqualJson(prefix + expectedJson) } - /** - * Asserts that a JSON property has the expected numeric value. - */ - fun assertJsonProperty( - json: JsonObject, - property: String, - expectedValue: Number, - message: String = "", - ) { - when (expectedValue) { - is Int -> assertEquals( - expectedValue, - (json[property] as? JsonPrimitive)?.content?.toIntOrNull(), - "${message}$property should match", - ) - - is Double -> assertEquals( - expectedValue, - (json[property] as? JsonPrimitive)?.content?.toDoubleOrNull(), - "${message}$property should match", - ) - - else -> assertEquals( - expectedValue.toString(), - json[property]?.toString()?.trim('"'), - "${message}$property should match", - ) - } + fun assertJsonEquals(expected: JsonElement, actual: JsonElement) { + (actual.toString()).shouldEqualJson(expected.toString()) } - /** - * Asserts that a JSON property has the expected boolean value. - */ - fun assertJsonProperty( - json: JsonObject, - property: String, - expectedValue: Boolean, - message: String = "", - ) { - assertEquals(expectedValue.toString(), json[property].toString(), "${message}$property should match") + fun assertIsJsonArray(actual: JsonElement) { + actual.toString().shouldBeJsonArray() } } From 6ad991f08a5c8b9b566d533bee05002ca244ee33 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 20:41:34 +0300 Subject: [PATCH 18/22] fixup! Introduce Kotlin integration tests --- .../kotlin/ResourceEdgeCasesTest.kt | 12 +- .../integration/kotlin/ToolEdgeCasesTest.kt | 221 +++++++++--------- .../integration/kotlin/ToolIntegrationTest.kt | 138 +++++------ ...tlinClientTypeScriptServerEdgeCasesTest.kt | 200 ++++++++-------- .../KotlinClientTypeScriptServerTest.kt | 148 ++++++------ .../kotlin/sdk/integration/utils/TestUtils.kt | 40 ---- 6 files changed, 345 insertions(+), 414 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt index d4cf5187..165e6936 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt @@ -260,11 +260,13 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testSubscribeAndUnsubscribe() = runTest { - val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) - assertNotNull(subscribeResult, "Subscribe result should not be null") + fun testSubscribeAndUnsubscribe() { + runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") - val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) - assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index 69df3b2d..0d740a7d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -1,14 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin +import io.kotest.assertions.json.shouldEqualJson import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonEquals -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -281,143 +279,150 @@ class ToolEdgeCasesTest : KotlinTestBase() { } @Test - fun testBasicTool() = runTest { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) + fun testBasicTool() { + runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) - val result = client.callTool(basicToolName, arguments) + val result = client.callTool(basicToolName, arguments) as CallToolResultBase - val toolResult = assertCallToolResult(result) - assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + val expectedToolResult = "[TextContent(text=Echo: Hello, world!, annotations=null)]" + assertEquals(expectedToolResult, result.content.toString(), "Unexpected tool result") - val structuredContent = toolResult.structuredContent as JsonObject - val expected = buildJsonObject { put("result", testText) } - assertJsonEquals(expected, structuredContent) + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "result" : "Hello, world!" + } + """.trimIndent() + + actualContent.shouldEqualJson(expectedContent) + } } @Test - fun testComplexNestedSchema() = runTest { - val userJson = buildJsonObject { - put("name", JsonPrimitive("John Doe")) - put("age", JsonPrimitive(30)) - put( - "address", - buildJsonObject { - put("street", JsonPrimitive("123 Main St")) - put("city", JsonPrimitive("New York")) - put("country", JsonPrimitive("USA")) - }, - ) - } + fun testComplexNestedSchema() { + runTest { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Galt")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } - val optionsJson = buildJsonArray { - add(JsonPrimitive("option1")) - add(JsonPrimitive("option2")) - add(JsonPrimitive("option3")) - } + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } - val arguments = buildJsonObject { - put("user", userJson) - put("options", optionsJson) - } + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } - val result = client.callTool( - CallToolRequest( - name = complexToolName, - arguments = arguments, - ), - ) - - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" - - assertTrue(text.contains("John Doe"), "Result should contain the name") - assertTrue(text.contains("30"), "Result should contain the age") - assertTrue(text.contains("123 Main St"), "Result should contain the street") - assertTrue(text.contains("New York"), "Result should contain the city") - assertTrue(text.contains("USA"), "Result should contain the country") - assertTrue(text.contains("option1"), "Result should contain option1") - assertTrue(text.contains("option2"), "Result should contain option2") - assertTrue(text.contains("option3"), "Result should contain option3") - - val structuredContent = toolResult.structuredContent as JsonObject - val expectedStructured = buildJsonObject { - put("name", "John Doe") - put("age", 30) - put("address", buildJsonObject { - put("street", "123 Main St") - put("city", "New York") - put("country", "USA") - }) - put("options", buildJsonArray { - add("option1"); add("option2"); add("option3") - }) + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) as CallToolResultBase + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "name" : "John Galt", + "age" : 30, + "address" : { + "street" : "123 Main St", + "city" : "New York", + "country" : "USA" + }, + "options" : [ "option1", "option2", "option3" ] + } + """.trimIndent() + + actualContent.shouldEqualJson(expectedContent) } - assertJsonEquals(expectedStructured, structuredContent) } @Test - fun testLargeResponse() = runTest { - val size = 10 - val arguments = mapOf("size" to size) + fun testLargeResponse() { + runTest { + val size = 10 + val arguments = mapOf("size" to size) - val result = client.callTool(largeToolName, arguments) + val result = client.callTool(largeToolName, arguments) as CallToolResultBase - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val content = result.content.firstOrNull() as TextContent + assertNotNull(content, "Tool result content should be TextContent") - assertEquals(10000, text.length, "Response should be 10KB in size") + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "size" : 10000 + } + """.trimIndent() - val structuredContent = toolResult.structuredContent as JsonObject - val expected = buildJsonObject { put("size", 10000) } - assertJsonEquals(expected, structuredContent) + actualContent.shouldEqualJson(expectedContent) + } } @Test - fun testSlowTool() = runTest { - val delay = 500 - val arguments = mapOf("delay" to delay) + fun testSlowTool() { + runTest { + val delay = 500 + val arguments = mapOf("delay" to delay) - val startTime = System.currentTimeMillis() - val result = client.callTool(slowToolName, arguments) - val endTime = System.currentTimeMillis() + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) as CallToolResultBase + val endTime = System.currentTimeMillis() - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") - assertTrue(text.contains("${delay}ms"), "Result should mention the delay") - assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") - val structuredContent = toolResult.structuredContent as JsonObject - val expected = buildJsonObject { put("delay", delay) } - assertJsonEquals(expected, structuredContent) + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "delay" : 500 + } + """.trimIndent() + + actualContent.shouldEqualJson(expectedContent) + } } @Test - fun testSpecialCharacters() = runTest { - val arguments = mapOf("special" to specialCharsContent) + fun testSpecialCharacters() { + runTest { + val arguments = mapOf("special" to specialCharsContent) - val result = client.callTool(specialCharsToolName, arguments) + val result = client.callTool(specialCharsToolName, arguments) as CallToolResultBase - val toolResult = assertCallToolResult(result) - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" - assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "special" : "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t", + "length" : 34 + } + """.trimIndent() - val structuredContent = toolResult.structuredContent as JsonObject - val expected = buildJsonObject { - put("special", specialCharsContent) - put("length", specialCharsContent.length) + actualContent.shouldEqualJson(expectedContent) } - assertJsonEquals(expected, structuredContent) } @Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index 6018c0d6..9fc29ee1 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -1,18 +1,16 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin +import io.kotest.assertions.json.shouldEqualJson import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase import io.modelcontextprotocol.kotlin.sdk.ImageContent import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonEquals -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.runBlocking import kotlinx.serialization.json.JsonArray -import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.add import kotlinx.serialization.json.buildJsonArray @@ -305,6 +303,7 @@ class ToolIntegrationTest : KotlinTestBase() { assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") val testTool = result.tools.find { it.name == testToolName } + assertNotNull(testTool, "Test tool should be in the list") assertEquals( testToolDescription, @@ -314,18 +313,20 @@ class ToolIntegrationTest : KotlinTestBase() { } @Test - fun testCallTool() = runTest { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) + fun testCallTool() { + runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) - val result = client.callTool(testToolName, arguments) + val result = client.callTool(testToolName, arguments) as CallToolResultBase - val toolResult = assertCallToolResult(result) - assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + val actualContent = result.structuredContent.toString() + val expectedContent = """ + {"result":"Hello, world!"} + """.trimIndent() - val structuredContent = toolResult.structuredContent as JsonObject - val expected = buildJsonObject { put("result", testText) } - assertJsonEquals(expected, structuredContent) + actualContent.shouldEqualJson(expectedContent) + } } @Test @@ -345,36 +346,22 @@ class ToolIntegrationTest : KotlinTestBase() { "tags" to listOf("test", "calculator", "integration"), ) - val result = client.callTool(complexToolName, arguments) - - val toolResult = assertCallToolResult(result) - - val content = toolResult.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val contentText = requireNotNull(content.text) - - assertTrue(contentText.contains("Operation"), "Result should contain operation") - assertTrue( - contentText.contains("multiply"), - "Result should contain multiply operation", - ) - assertTrue(contentText.contains("5.5"), "Result should contain first operand") - assertTrue(contentText.contains("2.0"), "Result should contain second operand") - assertTrue(contentText.contains("11"), "Result should contain result value") - - val structuredContent = toolResult.structuredContent as JsonObject - val actualWithoutFormatted = buildJsonObject { - structuredContent.filterKeys { it != "formattedResult" && it != "tags" }.forEach { (k, v) -> put(k, v) } - } - val expectedWithoutFormatted = buildJsonObject { - put("operation", "multiply") - put("a", 5.5) - put("b", 2.0) - put("result", 11.0) - put("precision", 3) - } - - assertJsonEquals(expectedWithoutFormatted, actualWithoutFormatted) + val result = client.callTool(complexToolName, arguments) as CallToolResultBase + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "operation" : "multiply", + "a" : 5.5, + "b" : 2.0, + "result" : 11.0, + "formattedResult" : "11,000", + "precision" : 3, + "tags" : [ ] + } + """.trimIndent() + + actualContent.shouldEqualJson(expectedContent) } } @@ -383,31 +370,31 @@ class ToolIntegrationTest : KotlinTestBase() { val successArgs = mapOf("errorType" to "none") val successResult = client.callTool(errorToolName, successArgs) - val successToolResult = assertCallToolResult(successResult, "No error: ") - assertTextContent(successToolResult.content.firstOrNull(), "No error occurred") + val actualContent = successResult?.structuredContent.toString() + val expectedContent = """ + { + "error" : false, + "message" : "Success" + } + """.trimIndent() - val noErrorStructured = successToolResult.structuredContent as JsonObject - val expectedNoError = buildJsonObject { - put("error", false) - put("message", "Success") - } - assertJsonEquals(expectedNoError, noErrorStructured) + actualContent.shouldEqualJson(expectedContent) val errorArgs = mapOf( "errorType" to "error", "message" to "Custom error message", ) - val errorResult = client.callTool(errorToolName, errorArgs) + val errorResult = client.callTool(errorToolName, errorArgs) as CallToolResultBase - val errorToolResult = assertCallToolResult(errorResult, "Error: ") - assertTextContent(errorToolResult.content.firstOrNull(), "Error: Custom error message") + val actualError = errorResult.structuredContent.toString() + val expectedError = """ + { + "error" : true, + "message" : "Custom error message" + } + """.trimIndent() - val errorStructured = errorToolResult.structuredContent as JsonObject - val expectedError = buildJsonObject { - put("error", true) - put("message", "Custom error message") - } - assertJsonEquals(expectedError, errorStructured) + actualError.shouldEqualJson(expectedError) val exceptionArgs = mapOf( "errorType" to "exception", @@ -434,16 +421,15 @@ class ToolIntegrationTest : KotlinTestBase() { "includeImage" to true, ) - val result = client.callTool(multiContentToolName, arguments) + val result = client.callTool(multiContentToolName, arguments) as CallToolResultBase - val toolResult = assertCallToolResult(result) assertEquals( 2, - toolResult.content.size, + result.content.size, "Tool result should have 2 content items", ) - val textContent = toolResult.content.firstOrNull { it is TextContent } as? TextContent + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent assertNotNull(textContent, "Result should contain TextContent") assertNotNull(textContent.text, "Text content should not be null") assertEquals( @@ -452,32 +438,32 @@ class ToolIntegrationTest : KotlinTestBase() { "Text content should match", ) - val imageContent = toolResult.content.firstOrNull { it is ImageContent } as? ImageContent + val imageContent = result.content.firstOrNull { it is ImageContent } as? ImageContent assertNotNull(imageContent, "Result should contain ImageContent") assertEquals("image/png", imageContent.mimeType, "Image MIME type should match") assertTrue(imageContent.data.isNotEmpty(), "Image data should not be empty") - val structuredContent = toolResult.structuredContent as JsonObject - val expectedStructured = buildJsonObject { - put("text", testText) - put("includeImage", true) - } - assertJsonEquals(expectedStructured, structuredContent) + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "text" : "Test multi-content", + "includeImage" : true + } + """.trimIndent() + + actualContent.shouldEqualJson(expectedContent) val textOnlyArgs = mapOf( "text" to testText, "includeImage" to false, ) - val textOnlyResult = client.callTool(multiContentToolName, textOnlyArgs) + val textOnlyResult = client.callTool(multiContentToolName, textOnlyArgs) as CallToolResultBase - val textOnlyToolResult = assertCallToolResult(textOnlyResult, "Text-only: ") assertEquals( 1, - textOnlyToolResult.content.size, + textOnlyResult.content.size, "Text-only result should have 1 content item", ) - - assertTextContent(textOnlyToolResult.content.firstOrNull(), "Text content: $testText") } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index ec83c0d3..48bc1a65 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -9,11 +9,9 @@ import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.Deferred -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive @@ -72,92 +70,86 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testNonExistentTool() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val nonExistentToolName = "non-existent-tool" - val arguments = mapOf("name" to "TestUser") + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") - val exception = assertThrows { - client.callTool(nonExistentToolName, arguments) - } - - val expectedMessage = - "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" - assertEquals( - expectedMessage, - exception.message, - "Unexpected error message for non-existent tool", - ) + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) } + + val expectedMessage = + "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" + assertEquals( + expectedMessage, + exception.message, + "Unexpected error message for non-existent tool", + ) } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testSpecialCharactersInArguments() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" - val arguments = mapOf("name" to specialChars) + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") - val text = textContent.text ?: "" - assertTrue( - text.contains(specialChars), - "Tool response should contain the special characters", - ) - } + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testLargePayload() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val largeName = "A".repeat(10 * 1024) - val arguments = mapOf("name" to largeName) + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") - val text = textContent.text ?: "" - assertTrue( - text.contains("Hello,") && text.contains("A"), - "Tool response should contain the greeting with the large name", - ) - } + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) } @Test @Timeout(60, unit = TimeUnit.SECONDS) fun testConcurrentRequests() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val concurrentCount = 5 - val results = mutableListOf>() + val concurrentCount = 5 + val results = mutableListOf>() - for (i in 1..concurrentCount) { + for (i in 1..concurrentCount) { + runBlocking { val deferred = async { val name = "ConcurrentClient$i" val arguments = mapOf("name" to name) @@ -173,39 +165,38 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { } results.add(deferred) } + } - val responses = results.awaitAll() + val responses = results.awaitAll() - for (i in 1..concurrentCount) { - val expectedName = "ConcurrentClient$i" - val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } - assertEquals( - 1, - matchingResponses.size, - "Should have exactly one response for $expectedName", - ) - } + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testInvalidArguments() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - val invalidArguments = mapOf( - "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), - ) + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) - val exception = assertThrows { - client.callTool("greet", invalidArguments) - } + val exception = assertThrows { + client.callTool("greet", invalidArguments) + } - val msg = exception.message ?: "" - val expectedMessage = """ + val msg = exception.message ?: "" + val expectedMessage = """ JSONRPCError(code=InvalidParams, message=MCP error -32602: Invalid arguments for tool greet: [ { "code": "invalid_type", @@ -219,35 +210,32 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { ], data={}) """.trimIndent() - assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") - } + assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testMultipleToolCalls() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - repeat(10) { i -> - val name = "SequentialClient$i" - val arguments = mapOf("name" to name) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null for call $i") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present for call $i") - - assertEquals( - "Hello, $name!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) - } + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt index b7e77652..481ec937 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -8,9 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach @@ -66,100 +64,92 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testKotlinClientConnectsToTypeScriptServer() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) - assertNotNull(client, "Client should be initialized") + assertNotNull(client, "Client should be initialized") - val pingResult = client.ping() - assertNotNull(pingResult, "Ping result should not be null") + val pingResult = client.ping() + assertNotNull(pingResult, "Ping result should not be null") - val serverImpl = client.serverVersion - assertNotNull(serverImpl, "Server implementation should not be null") - println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") - } + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testListTools() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val result = client.listTools() - assertNotNull(result, "Tools list should not be null") - assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") - - // Verify specific utils are available - val toolNames = result.tools.map { it.name } - assertTrue(toolNames.contains("greet"), "Greet tool should be available") - assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") - assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") - - println("Available utils: ${toolNames.joinToString()}") - } + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + // Verify specific utils are available + val toolNames = result.tools.map { it.name } + assertTrue(toolNames.contains("greet"), "Greet tool should be available") + assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") + assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") + + println("Available utils: ${toolNames.joinToString()}") } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testToolCall() = runTest { - withContext(Dispatchers.IO) { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val testName = "TestUser" - val arguments = mapOf("name" to testName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - assertEquals( - "Hello, $testName!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) - } + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val testName = "TestUser" + val arguments = mapOf("name" to testName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals( + "Hello, $testName!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testMultipleClients() = runTest { - withContext(Dispatchers.IO) { - // First client connection - val client1 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools1 = client1.listTools() - assertNotNull(tools1, "Tools list for first client should not be null") - assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") - - val client2 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools2 = client2.listTools() - assertNotNull(tools2, "Tools list for second client should not be null") - assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") - - val toolNames1 = tools1.tools.map { it.name } - val toolNames2 = tools2.tools.map { it.name } - - assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") - assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") - assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") - assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") - - client1.close() - client2.close() - } + // First client connection + val client1 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools1 = client1.listTools() + assertNotNull(tools1, "Tools list for first client should not be null") + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + + val client2 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools2 = client2.listTools() + assertNotNull(tools2, "Tools list for second client should not be null") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + + assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") + assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") + assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") + assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") + + client1.close() + client2.close() } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt index 0bca29c8..46515610 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt @@ -1,18 +1,8 @@ package io.modelcontextprotocol.kotlin.sdk.integration.utils -import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase -import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent -import io.modelcontextprotocol.kotlin.sdk.TextContent import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.JsonPrimitive -import io.kotest.assertions.json.* -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue object TestUtils { fun runTest(block: suspend () -> T): T = runBlocking { @@ -20,34 +10,4 @@ object TestUtils { block() } } - - fun assertTextContent(content: PromptMessageContent?, expectedText: String) { - assertNotNull(content, "Content should not be null") - assertTrue(content is TextContent, "Content should be TextContent") - assertNotNull(content.text, "Text content should not be null") - assertEquals(expectedText, content.text, "Text content should match") - } - - fun assertCallToolResult(result: Any?, message: String = ""): CallToolResultBase { - assertNotNull(result, "${message}Call tool result should not be null") - assertTrue(result is CallToolResultBase, "${message}Result should be CallToolResultBase") - assertTrue(result.content.isNotEmpty(), "${message}Tool result content should not be empty") - assertNotNull(result.structuredContent, "${message}Tool result structured content should not be null") - - return result - } - - // Use Kotest JSON assertions to compare whole JSON structures. - fun assertJsonEquals(expectedJson: String, actual: JsonElement, message: String = "") { - val prefix = if (message.isNotEmpty()) "$message\n" else "" - (actual.toString()).shouldEqualJson(prefix + expectedJson) - } - - fun assertJsonEquals(expected: JsonElement, actual: JsonElement) { - (actual.toString()).shouldEqualJson(expected.toString()) - } - - fun assertIsJsonArray(actual: JsonElement) { - actual.toString().shouldBeJsonArray() - } } From 4ad0c7d017357e9e1b42c1b67fc523cb66af6d70 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 20:55:25 +0300 Subject: [PATCH 19/22] fixup! Introduce Kotlin integration tests --- ...tlinClientTypeScriptServerEdgeCasesTest.kt | 234 ++++++++---------- .../KotlinClientTypeScriptServerTest.kt | 129 +++++----- .../typescript/TypeScriptTestBase.kt | 27 +- 3 files changed, 193 insertions(+), 197 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index 48bc1a65..2eab45d5 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -70,172 +70,156 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testNonExistentTool() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + withClient(serverUrl) { client -> + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") - val nonExistentToolName = "non-existent-tool" - val arguments = mapOf("name" to "TestUser") + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) + } - val exception = assertThrows { - client.callTool(nonExistentToolName, arguments) + val expectedMessage = + "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" + assertEquals( + expectedMessage, + exception.message, + "Unexpected error message for non-existent tool", + ) } - - val expectedMessage = - "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" - assertEquals( - expectedMessage, - exception.message, - "Unexpected error message for non-existent tool", - ) } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testSpecialCharactersInArguments() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" - val arguments = mapOf("name" to specialChars) + withClient(serverUrl) { client -> + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") - val text = textContent.text ?: "" - assertTrue( - text.contains(specialChars), - "Tool response should contain the special characters", - ) + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testLargePayload() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + withClient(serverUrl) { client -> + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) - val largeName = "A".repeat(10 * 1024) - val arguments = mapOf("name" to largeName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") - val text = textContent.text ?: "" - assertTrue( - text.contains("Hello,") && text.contains("A"), - "Tool response should contain the greeting with the large name", - ) + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } } @Test @Timeout(60, unit = TimeUnit.SECONDS) fun testConcurrentRequests() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val concurrentCount = 5 - val results = mutableListOf>() - - for (i in 1..concurrentCount) { - runBlocking { - val deferred = async { - val name = "ConcurrentClient$i" - val arguments = mapOf("name" to name) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null for client $i") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present for client $i") - - textContent.text ?: "" + withClient(serverUrl) { client -> + val concurrentCount = 5 + val responses = kotlinx.coroutines.coroutineScope { + val results = (1..concurrentCount).map { i -> + async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text ?: "" + } } - results.add(deferred) + results.awaitAll() } - } - val responses = results.awaitAll() - - for (i in 1..concurrentCount) { - val expectedName = "ConcurrentClient$i" - val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } - assertEquals( - 1, - matchingResponses.size, - "Should have exactly one response for $expectedName", - ) + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testInvalidArguments() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + withClient(serverUrl) { client -> + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) - val invalidArguments = mapOf( - "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), - ) + val exception = assertThrows { + client.callTool("greet", invalidArguments) + } - val exception = assertThrows { - client.callTool("greet", invalidArguments) + val msg = exception.message ?: "" + val expectedMessage = """ + JSONRPCError(code=InvalidParams, message=MCP error -32602: Invalid arguments for tool greet: [ + { + "code": "invalid_type", + "expected": "string", + "received": "object", + "path": [ + "name" + ], + "message": "Expected string, received object" + } + ], data={}) + """.trimIndent() + + assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") } - - val msg = exception.message ?: "" - val expectedMessage = """ - JSONRPCError(code=InvalidParams, message=MCP error -32602: Invalid arguments for tool greet: [ - { - "code": "invalid_type", - "expected": "string", - "received": "object", - "path": [ - "name" - ], - "message": "Expected string, received object" - } - ], data={}) - """.trimIndent() - - assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testMultipleToolCalls() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - repeat(10) { i -> - val name = "SequentialClient$i" - val arguments = mapOf("name" to name) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null for call $i") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present for call $i") - - assertEquals( - "Hello, $name!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) + withClient(serverUrl) { client -> + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt index 481ec937..5e060798 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -64,92 +64,81 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testKotlinClientConnectsToTypeScriptServer() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) + withClient(serverUrl) { client -> + assertNotNull(client, "Client should be initialized") - assertNotNull(client, "Client should be initialized") + val pingResult = client.ping() + assertNotNull(pingResult, "Ping result should not be null") - val pingResult = client.ping() - assertNotNull(pingResult, "Ping result should not be null") - - val serverImpl = client.serverVersion - assertNotNull(serverImpl, "Server implementation should not be null") - println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testListTools() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val result = client.listTools() - assertNotNull(result, "Tools list should not be null") - assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") - - // Verify specific utils are available - val toolNames = result.tools.map { it.name } - assertTrue(toolNames.contains("greet"), "Greet tool should be available") - assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") - assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") - - println("Available utils: ${toolNames.joinToString()}") + withClient(serverUrl) { client -> + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + // Verify specific utils are available + val toolNames = result.tools.map { it.name } + assertTrue(toolNames.contains("greet"), "Greet tool should be available") + assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") + assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") + + println("Available utils: ${toolNames.joinToString()}") + } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testToolCall() = runTest { - client = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val testName = "TestUser" - val arguments = mapOf("name" to testName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - assertEquals( - "Hello, $testName!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) + withClient(serverUrl) { client -> + val testName = "TestUser" + val arguments = mapOf("name" to testName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals( + "Hello, $testName!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } } @Test @Timeout(30, unit = TimeUnit.SECONDS) fun testMultipleClients() = runTest { - // First client connection - val client1 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools1 = client1.listTools() - assertNotNull(tools1, "Tools list for first client should not be null") - assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") - - val client2 = HttpClient(CIO) { - install(SSE) - }.mcpStreamableHttp(serverUrl) - - val tools2 = client2.listTools() - assertNotNull(tools2, "Tools list for second client should not be null") - assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") - - val toolNames1 = tools1.tools.map { it.name } - val toolNames2 = tools2.tools.map { it.name } - - assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") - assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") - assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") - assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") - - client1.close() - client2.close() + val client1 = newClient(serverUrl) + val client2 = newClient(serverUrl) + try { + val tools1 = client1.listTools() + assertNotNull(tools1, "Tools list for first client should not be null") + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + + val tools2 = client2.listTools() + assertNotNull(tools2, "Tools list for second client should not be null") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + + assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") + assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") + assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") + assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") + } finally { + client1.close() + client2.close() + } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index b5c3165c..78daa4bf 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -1,6 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import kotlinx.coroutines.withTimeout import org.junit.jupiter.api.BeforeAll import java.io.BufferedReader import java.io.File @@ -9,6 +15,7 @@ import java.net.ServerSocket import java.net.Socket import java.nio.file.Files import java.util.concurrent.TimeUnit +import kotlin.time.Duration.Companion.seconds @Retry(times = 3) abstract class TypeScriptTestBase { @@ -129,7 +136,7 @@ abstract class TypeScriptTestBase { } } - protected fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { + private fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { process.destroyForcibly() process.waitFor(2, TimeUnit.SECONDS) @@ -138,7 +145,7 @@ abstract class TypeScriptTestBase { return true } - protected fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { + private fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { val outputReader = Thread { try { process.inputStream.bufferedReader().useLines { lines -> @@ -219,4 +226,20 @@ abstract class TypeScriptTestBase { println("$name did not stop gracefully, forced termination") } } + + private suspend fun newClient(serverUrl: String): Client = + HttpClient(CIO) { install(SSE) }.mcpStreamableHttp(serverUrl) + + protected suspend fun withClient(serverUrl: String, block: suspend (Client) -> T): T { + val client = newClient(serverUrl) + return try { + withTimeout(20.seconds) { block(client) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) { + // ignore errors + } + } + } } From 60a760e3b66cd3a457be50d6d8f0dbf9b715dd88 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 21:06:35 +0300 Subject: [PATCH 20/22] fixup! Introduce Kotlin integration tests --- .../integration/kotlin/ToolEdgeCasesTest.kt | 10 +++++----- .../integration/kotlin/ToolIntegrationTest.kt | 10 +++++----- .../typescript/TypeScriptTestBase.kt | 19 ++++++++----------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index 0d740a7d..9799c343 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -296,7 +296,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -348,7 +348,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -370,7 +370,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -396,7 +396,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -421,7 +421,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index 9fc29ee1..5a3f6930 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -325,7 +325,7 @@ class ToolIntegrationTest : KotlinTestBase() { {"result":"Hello, world!"} """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -361,7 +361,7 @@ class ToolIntegrationTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent } } @@ -378,7 +378,7 @@ class ToolIntegrationTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent val errorArgs = mapOf( "errorType" to "error", @@ -394,7 +394,7 @@ class ToolIntegrationTest : KotlinTestBase() { } """.trimIndent() - actualError.shouldEqualJson(expectedError) + actualError shouldEqualJson expectedError val exceptionArgs = mapOf( "errorType" to "exception", @@ -451,7 +451,7 @@ class ToolIntegrationTest : KotlinTestBase() { } """.trimIndent() - actualContent.shouldEqualJson(expectedContent) + actualContent shouldEqualJson expectedContent val textOnlyArgs = mapOf( "text" to testText, diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt index 78daa4bf..a19f00ec 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -57,13 +57,9 @@ abstract class TypeScriptTestBase { } println("Installing TypeScript SDK dependencies") - executeCommand("npm install", sdkDir) + executeCommand("npm install", sdkDir, allowFailure = false, timeoutSeconds = null) } - @JvmStatic - protected fun executeCommand(command: String, workingDir: File): String = - runCommand(command, workingDir, allowFailure = false, timeoutSeconds = null) - @JvmStatic protected fun killProcessOnPort(port: Int) { val isWindows = System.getProperty("os.name").lowercase().contains("windows") @@ -72,7 +68,7 @@ abstract class TypeScriptTestBase { } else { "lsof -ti:$port | xargs kill -9 2>/dev/null || true" } - runCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) + executeCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) } @JvmStatic @@ -82,11 +78,12 @@ abstract class TypeScriptTestBase { } } - private fun runCommand( + @JvmStatic + protected fun executeCommand( command: String, workingDir: File, - allowFailure: Boolean, - timeoutSeconds: Long?, + allowFailure: Boolean = false, + timeoutSeconds: Long? = null, ): String { if (!workingDir.exists()) { if (!workingDir.mkdirs()) { @@ -174,7 +171,7 @@ abstract class TypeScriptTestBase { } protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = - runCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) + executeCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) protected fun startTypeScriptServer(port: Int): Process { killProcessOnPort(port) @@ -227,7 +224,7 @@ abstract class TypeScriptTestBase { } } - private suspend fun newClient(serverUrl: String): Client = + protected suspend fun newClient(serverUrl: String): Client = HttpClient(CIO) { install(SSE) }.mcpStreamableHttp(serverUrl) protected suspend fun withClient(serverUrl: String, block: suspend (Client) -> T): T { From bae5f68296d275aeb860858e599968d769fee4a2 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 21:08:02 +0300 Subject: [PATCH 21/22] fixup! Introduce Kotlin integration tests --- .../kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt | 6 +++--- .../KotlinClientTypeScriptServerEdgeCasesTest.kt | 7 +------ .../typescript/KotlinClientTypeScriptServerTest.kt | 4 ---- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index 9799c343..a0dc9ba0 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -368,7 +368,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { { "size" : 10000 } - """.trimIndent() + """.trimIndent() actualContent shouldEqualJson expectedContent } @@ -394,7 +394,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { { "delay" : 500 } - """.trimIndent() + """.trimIndent() actualContent shouldEqualJson expectedContent } @@ -419,7 +419,7 @@ class ToolEdgeCasesTest : KotlinTestBase() { "special" : "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t", "length" : 34 } - """.trimIndent() + """.trimIndent() actualContent shouldEqualJson expectedContent } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index 2eab45d5..25ead220 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -1,14 +1,9 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.sse.SSE import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest -import kotlinx.coroutines.Deferred import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.runBlocking @@ -193,7 +188,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { "message": "Expected string, received object" } ], data={}) - """.trimIndent() + """.trimIndent() assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt index 5e060798..13aaa73d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -1,12 +1,8 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.sse.SSE import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout From b881a26b75ebd130710efc2a86581234ced99307 Mon Sep 17 00:00:00 2001 From: Sergey Karpov Date: Mon, 1 Sep 2025 21:59:08 +0300 Subject: [PATCH 22/22] fixup! Introduce Kotlin integration tests --- .../kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index 5a3f6930..044237a2 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -18,6 +18,9 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows +import java.text.DecimalFormat +import java.text.DecimalFormatSymbols +import java.util.Locale import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -161,7 +164,10 @@ class ToolIntegrationTest : KotlinTestBase() { else -> 0.0 } - val formattedResult = "%.${precision}f".format(result) + val pattern = if (precision > 0) "0." + "0".repeat(precision) else "0" + val symbols = DecimalFormatSymbols(Locale.US).apply { decimalSeparator = ',' } + val df = DecimalFormat(pattern, symbols).apply { isGroupingUsed = false } + val formattedResult = df.format(result) val textContent = if (showSteps) { "Operation: $operation\nA: $a\nB: $b\nResult: $formattedResult\nTags: ${