From fbdf207bb0974c1f5505761cd225fd3e34283203 Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:50:03 +0200 Subject: [PATCH] #107 fix The class has been enhanced to monitor the standard error (stderr) stream, enabling the detection of fatal errors that should terminate the connection. Its internal threading model was refactored to use structured concurrency, ensuring that the separate coroutines for standard input, output, and error are properly supervised and cancelled together. The API was expanded with new constructors that accept an error source and a predicate for identifying fatal error lines. Furthermore, the helper was updated to expose functionality, supporting the processing of line-based text from the error stream. --- AGENTS.md | 1 + gradle/libs.versions.toml | 11 +- kotlin-sdk-client/api/kotlin-sdk-client.api | 3 + kotlin-sdk-client/build.gradle.kts | 2 + .../kotlin/sdk/client/StdioClientTransport.kt | 230 +++++++++++++----- .../StreamableHttpClientTransportTest.kt | 4 - .../StdioClientTransportIntegrationTest.kt | 63 +++++ .../sdk/client/StdioClientTransportTest.kt | 217 +++++++++++++++++ .../kotlin/sdk/client/testUtils.kt | 90 +++++++ kotlin-sdk-core/api/kotlin-sdk-core.api | 2 + .../kotlin/sdk/shared/AbstractTransport.kt | 2 +- .../kotlin/sdk/shared/ReadBuffer.kt | 41 +++- 12 files changed, 598 insertions(+), 68 deletions(-) create mode 100644 kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportIntegrationTest.kt create mode 100644 kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt create mode 100644 kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/testUtils.kt diff --git a/AGENTS.md b/AGENTS.md index 5b49a6db..3b53b931 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -55,6 +55,7 @@ MCP Kotlin SDK — Kotlin Multiplatform implementation of the Model Context Prot ### Multiplatform Patterns - Use `expect`/`actual` pattern for platform-specific implementations in `utils.*` files. - Test changes on JVM first, then verify platform-specific behavior if needed. +- Use Kotlin 2.2 api and language level - Supported targets: JVM (1.8+), JS/Wasm, iOS, watchOS, tvOS. ### Serialization diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9ad38f9b..6e6013e2 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -6,21 +6,23 @@ atomicfu = "0.29.0" ktlint = "14.0.1" kover = "0.9.3" netty = "4.2.7.Final" + mavenPublish = "0.35.0" binaryCompatibilityValidatorPlugin = "0.18.1" openapi-generator = "7.17.0" # libraries version -serialization = "1.9.0" +awaitility = "4.3.0" collections-immutable = "0.4.0" coroutines = "1.10.2" +kotest = "6.0.4" kotlinx-io = "0.8.1" ktor = "3.2.3" logging = "7.0.13" -slf4j = "2.0.17" -kotest = "6.0.4" -awaitility = "4.3.0" +mockk = "1.14.6" mokksy = "0.6.2" +serialization = "1.9.0" +slf4j = "2.0.17" [libraries] # Plugins @@ -53,6 +55,7 @@ kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "coroutines" } ktor-client-mock = { group = "io.ktor", name = "ktor-client-mock", version.ref = "ktor" } ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } +mockk = { module = "io.mockk:mockk", version.ref = "mockk" } mokksy = { group = "dev.mokksy", name = "mokksy", version.ref = "mokksy" } netty-bom = { group = "io.netty", name = "netty-bom", version.ref = "netty" } slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index 9186512c..a810f1e6 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -72,6 +72,9 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;Lkotlinx/io/Source;)V + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;Lkotlinx/io/Source;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;Lkotlinx/io/Source;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/kotlin-sdk-client/build.gradle.kts b/kotlin-sdk-client/build.gradle.kts index 473c32b5..077ffdf8 100644 --- a/kotlin-sdk-client/build.gradle.kts +++ b/kotlin-sdk-client/build.gradle.kts @@ -45,6 +45,7 @@ kotlin { implementation(libs.ktor.server.websockets) implementation(libs.kotlinx.coroutines.test) implementation(libs.ktor.client.logging) + implementation(libs.kotest.assertions.core) } } @@ -53,6 +54,7 @@ kotlin { implementation(libs.mokksy) implementation(libs.awaitility) implementation(libs.ktor.client.apache5) + implementation(libs.mockk) implementation(dependencies.platform(libs.netty.bom)) runtimeOnly(libs.slf4j.simple) } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 547079fb..42aca5eb 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -7,15 +7,19 @@ import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableJob import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import kotlinx.coroutines.supervisorScope import kotlinx.io.Buffer import kotlinx.io.Sink import kotlinx.io.Source @@ -24,7 +28,7 @@ import kotlinx.io.readByteArray import kotlinx.io.writeString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi -import kotlin.coroutines.CoroutineContext +import kotlin.jvm.JvmOverloads /** * A transport implementation for JSON-RPC communication that leverages standard input and output streams. @@ -32,21 +36,34 @@ import kotlin.coroutines.CoroutineContext * This class reads from an input stream to process incoming JSON-RPC messages and writes JSON-RPC messages * to an output stream. * + * Uses structured concurrency principles: + * - Parent job controls all child coroutines + * - Proper cancellation propagation + * - Resource cleanup guaranteed via structured concurrency + * * @param input The input stream where messages are received. * @param output The output stream where messages are sent. + * @param error Optional error stream for stderr processing. + * @param processStdError Callback for stderr lines. Returns true for fatal errors. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioClientTransport(private val input: Source, private val output: Sink) : AbstractTransport() { +public class StdioClientTransport @JvmOverloads public constructor( + private val input: Source, + private val output: Sink, + private val error: Source? = null, + private val processStdError: (String) -> Boolean = { true }, +) : AbstractTransport() { private val logger = KotlinLogging.logger {} - private val ioCoroutineContext: CoroutineContext = IODispatcher - private val scope by lazy { - CoroutineScope(ioCoroutineContext + SupervisorJob()) - } - private var job: Job? = null + + // Structured concurrency: single parent job manages all I/O operations + private val parentJob: CompletableJob = SupervisorJob() + private val scope = CoroutineScope(IODispatcher + parentJob) + + // State management through job lifecycle, not atomic flags private val initialized: AtomicBoolean = AtomicBoolean(false) private val sendChannel = Channel(Channel.UNLIMITED) - private val readBuffer = ReadBuffer() + @Suppress("TooGenericExceptionCaught") override suspend fun start() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error("StdioClientTransport already started!") @@ -54,50 +71,57 @@ public class StdioClientTransport(private val input: Source, private val output: logger.debug { "Starting StdioClientTransport..." } - val outputStream = output.buffered() - - job = scope.launch(CoroutineName("StdioClientTransport.IO#${hashCode()}")) { - val readJob = launch { - logger.debug { "Read coroutine started." } - try { - input.use { - while (isActive) { - val buffer = Buffer() - val bytesRead = input.readAtMostTo(buffer, 8192) - if (bytesRead == -1L) break - if (bytesRead > 0L) { - readBuffer.append(buffer.readByteArray()) - processReadBuffer() - } - } + // Launch all I/O operations in the scope - structured concurrency ensures cleanup + scope.launch(CoroutineName("StdioClientTransport.IO#${hashCode()}")) { + try { + val outputStream = output.buffered() + val errorStream = error?.buffered() + + // Use supervisorScope so individual stream failures don't cancel siblings + supervisorScope { + // Launch stdin reader + val stdinJob = launch(CoroutineName("stdin-reader")) { + readStream(input, ::processReadBuffer) } - } catch (e: Exception) { - _onError.invoke(e) - logger.error(e) { "Error reading from input stream" } - } - } - val writeJob = launch { - logger.debug { "Write coroutine started." } - try { - sendChannel.consumeEach { message -> - val json = serializeMessage(message) - outputStream.writeString(json) - outputStream.flush() + // Launch stderr reader if present + val stderrJob = errorStream?.let { + launch(CoroutineName("stderr-reader")) { + readStream(it, ::processStderrBuffer) + } } - } catch (e: Throwable) { - if (isActive) { - _onError.invoke(e) - logger.error(e) { "Error writing to output stream" } + + // Launch writer + val writerJob = launch(CoroutineName("stdout-writer")) { + writeMessages(outputStream) } - } finally { - output.close() + + // Wait for both stdin and stderr to complete (reach EOF or get cancelled) + // When a process exits, both streams will be closed by the OS + logger.debug { "Waiting for stdin to complete..." } + stdinJob.join() + logger.debug { "stdin completed, waiting for stderr..." } + stderrJob?.join() + logger.debug { "stderr completed, cancelling writer..." } + + // Cancel writer (it may be blocked waiting for channel messages) + writerJob.cancelAndJoin() + logger.debug { "writer cancelled, supervisorScope complete" } } + } catch (e: CancellationException) { + logger.debug { "Transport cancelled: ${e.message}" } + throw e + } catch (e: Exception) { + logger.error(e) { "Transport error" } + _onError.invoke(e) + } finally { + // Cleanup: close all streams and notify + runCatching { input.close() } + runCatching { output.close() } + runCatching { error?.close() } + runCatching { sendChannel.close() } + _onClose.invoke() } - - readJob.join() - writeJob.cancelAndJoin() - _onClose.invoke() } } @@ -113,23 +137,115 @@ public class StdioClientTransport(private val input: Source, private val output: if (!initialized.compareAndSet(expectedValue = true, newValue = false)) { error("Transport is already closed") } - job?.cancelAndJoin() - input.close() - output.close() - readBuffer.clear() - sendChannel.close() - _onClose.invoke() + + logger.debug { "Closing StdioClientTransport..." } + + // Cancel scope - structured concurrency handles cleanup via finally blocks + parentJob.cancelAndJoin() + } + + /** + * Reads from a source stream and processes each chunk through the provided block. + * Cancellation-aware and properly propagates CancellationException. + */ + private suspend fun CoroutineScope.readStream(source: Source, block: suspend (ReadBuffer) -> Unit) { + logger.debug { "Stream reader started" } + + source.use { + val readBuffer = ReadBuffer() + while (this.isActive) { + val buffer = Buffer() + val bytesRead = it.readAtMostTo(buffer, 8192) + + if (bytesRead == -1L) { + logger.debug { "EOF reached" } + break + } + + if (bytesRead > 0L) { + readBuffer.append(buffer.readByteArray()) + block(readBuffer) + } + } + } } - private suspend fun processReadBuffer() { + /** + * Processes JSON-RPC messages from the read buffer. + * Each message is delivered to the onMessage callback. + */ + private suspend fun processReadBuffer(buffer: ReadBuffer) { while (true) { - val msg = readBuffer.readMessage() ?: break + val msg = buffer.readMessage() ?: break + + @Suppress("TooGenericExceptionCaught") try { _onMessage.invoke(msg) } catch (e: Throwable) { _onError.invoke(e) - logger.error(e) { "Error processing message." } + logger.error(e) { "Error processing message" } } } } + + /** + * Processes stderr lines from the read buffer. + * If processStdError returns true (fatal), cancels the scope. + */ + private suspend fun processStderrBuffer(buffer: ReadBuffer) { + val errorLine = buffer.readLine() + buffer.clear() + + if (errorLine != null) { + val isFatal = processStdError(errorLine) + + if (isFatal) { + logger.error { "Fatal stderr error: $errorLine" } + + val exception = McpException( + RPCError.ErrorCode.CONNECTION_CLOSED, + "Fatal error in stderr: $errorLine", + ) + + // Notify error handler + _onError.invoke(exception) + + // Close streams to trigger EOF - this will cause natural shutdown + // The stdin reader will complete, then we'll shut down gracefully + runCatching { input.close() } + runCatching { output.close() } + + // Exit the stderr reader loop + return + } else { + logger.warn { "Non-fatal stderr warning: $errorLine" } + } + } + } + + /** + * Writes JSON-RPC messages from the send channel to the output stream. + * Runs until the channel is closed or coroutine is cancelled. + */ + private suspend fun writeMessages(outputStream: Sink) { + logger.debug { "Writer started" } + + try { + for (message in sendChannel) { + if (!currentCoroutineContext().isActive) break + + val json = serializeMessage(message) + outputStream.writeString(json) + outputStream.flush() + } + } catch (e: Exception) { + if (currentCoroutineContext().isActive) { + _onError.invoke(e) + logger.error(e) { "Error writing to output stream" } + } + throw e + } + + logger.debug { "Writer finished" } + } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt index 9968ceb1..8b23653d 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt @@ -123,8 +123,6 @@ class StreamableHttpClientTransportTest { @Test fun testTerminateSession() = runTest { -// transport.sessionId = "test-session-id" - val transport = createTransport { request -> assertEquals(HttpMethod.Delete, request.method) assertEquals("test-session-id", request.headers["mcp-session-id"]) @@ -143,8 +141,6 @@ class StreamableHttpClientTransportTest { @Test fun testTerminateSessionHandle405() = runTest { -// transport.sessionId = "test-session-id" - val transport = createTransport { request -> assertEquals(HttpMethod.Delete, request.method) respond( diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportIntegrationTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportIntegrationTest.kt new file mode 100644 index 00000000..5ae622c8 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportIntegrationTest.kt @@ -0,0 +1,63 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import kotlinx.coroutines.runBlocking +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import java.util.concurrent.TimeUnit + +/** + * Integration tests for StdioClientTransport with real process I/O. + * + * These tests use real ProcessBuilder and shell commands, so they run sequentially + * to avoid resource contention issues with parallel execution. + */ +@Execution(ExecutionMode.SAME_THREAD) +class StdioClientTransportIntegrationTest { + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun `handle stdio error`(): Unit = runBlocking { + val processBuilder = if (System.getProperty("os.name").lowercase().contains("win")) { + ProcessBuilder("cmd", "/c", "pause 0.5 && echo simulated error 1>&2 && exit 1") + } else { + ProcessBuilder("sh", "-c", "sleep 0.5 && echo 'simulated error' >&2 && exit 1") + } + + val process = processBuilder.start() + + val stdin = process.inputStream.asSource().buffered() + val stdout = process.outputStream.asSink().buffered() + val stderr = process.errorStream.asSource().buffered() + + val transport = StdioClientTransport( + input = stdin, + output = stdout, + error = stderr, + ) { + println("💥Ah-oh!, error: \"$it\"") + true + } + + val client = Client( + clientInfo = Implementation( + name = "test-client", + version = "1.0", + ), + ) + + // The error in stderr should cause connecting to fail + assertThrows { + client.connect(transport) + } + + process.destroyForcibly() + } +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt new file mode 100644 index 00000000..dc7e93f4 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt @@ -0,0 +1,217 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.kotest.matchers.shouldBe +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import kotlinx.coroutines.runBlocking +import kotlinx.io.buffered +import org.awaitility.kotlin.await +import org.awaitility.kotlin.untilAsserted +import org.awaitility.kotlin.untilNotNull +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference + +/** + * Unit tests for StdioClientTransport stderr error handling behavior. + * + * This test suite verifies the transport correctly distinguishes between: + * - Fatal errors (processStdError returns true) - should terminate transport and invoke onError/onClose + * - Non-fatal warnings (processStdError returns false) - should continue operation without terminating + * + * Uses mock sources to simulate stdin/stderr streams without real process I/O. + */ +@Timeout(10, unit = TimeUnit.SECONDS) +class StdioClientTransportTest { + + private lateinit var transport: StdioClientTransport + + @Test + fun `should invoke onError and onClose when processStdError returns true for fatal error`(): Unit = runBlocking { + val errorDetected = AtomicBoolean(false) + val onErrorCalled = AtomicBoolean(false) + val onCloseCalled = AtomicBoolean(false) + val capturedError = AtomicReference() + + // Create input that blocks (simulates waiting for server response) + val inputSource = ControllableBlockingSource() + + // Create error stream that provides a fatal error message + val errorMessage = "fatal error: connection failed\n" + val errorSource = ByteArraySource(errorMessage.encodeToByteArray()) + + // Create simple output sink that accepts writes + val outputSink = NoOpSink() + + transport = StdioClientTransport( + input = inputSource.buffered(), + output = outputSink.buffered(), + error = errorSource.buffered(), + processStdError = { + errorDetected.set(true) + true // Fatal error - should terminate transport + }, + ) + + // Set up callbacks to track invocations + transport.onError { error -> + capturedError.set(error) + onErrorCalled.set(true) + } + transport.onClose { + onCloseCalled.set(true) + } + + // Start the transport + transport.start() + + // Use awaitility for elegant, readable async assertions + await untilAsserted { + errorDetected.get() shouldBe true + onErrorCalled.get() shouldBe true + onCloseCalled.get() shouldBe true + } + + // Verify the error is of expected type + val error = await untilNotNull { capturedError.get() } + (error is McpException) shouldBe true + + // Clean up + inputSource.unblock() + } + + @Test + @Suppress("MaxLineLength") + fun `should NOT invoke onError when processStdError returns false for non-fatal warning`(): Unit = runBlocking { + val warningDetected = AtomicBoolean(false) + val onErrorCalled = AtomicBoolean(false) + val onCloseCalled = AtomicBoolean(false) + val capturedWarningMessage = AtomicReference() + + // Use blocking input so stderr has time to be processed before EOF + val inputSource = ControllableBlockingSource() + + // Create error stream that provides a non-fatal warning + val warningMessage = "warning: deprecated feature used\n" + val errorSource = ByteArraySource(warningMessage.encodeToByteArray()) + + // Create simple output sink + val outputSink = NoOpSink() + + transport = StdioClientTransport( + input = inputSource.buffered(), + output = outputSink.buffered(), + error = errorSource.buffered(), + ) { msg -> + warningDetected.set(true) + capturedWarningMessage.set(msg) + false // Non-fatal warning - should NOT terminate transport + } + + // Set up callbacks to track invocations + transport.onError { + onErrorCalled.set(true) + } + transport.onClose { + onCloseCalled.set(true) + } + + // Start the transport + transport.start() + + // Wait for warning to be processed - use awaitility DSL + await untilAsserted { + warningDetected.get() shouldBe true + capturedWarningMessage.get() shouldBe "warning: deprecated feature used" + } + + // Verify warning did NOT trigger error callback + onErrorCalled.get() shouldBe false + + // Now unblock stdin to trigger close + inputSource.unblock() + + // onClose WILL be called due to EOF on stdin/stderr - this is expected behavior + // The key difference is that onError was NOT called + await untilAsserted { + onCloseCalled.get() shouldBe true + } + } + + @Test + fun `should handle empty stderr stream gracefully`(): Unit = runBlocking { + val onErrorCalled = AtomicBoolean(false) + val onCloseCalled = AtomicBoolean(false) + val processStdErrorCalled = AtomicBoolean(false) + + // Create empty streams + val inputSource = ByteArraySource().buffered() + val errorSource = ByteArraySource().buffered() + val outputSink = NoOpSink().buffered() + + transport = StdioClientTransport( + input = inputSource, + output = outputSink, + error = errorSource, + processStdError = { + processStdErrorCalled.set(true) + false + }, + ) + + transport.onError { onErrorCalled.set(true) } + transport.onClose { onCloseCalled.set(true) } + + transport.start() + + // Should close cleanly without processing any errors - use awaitility + await untilAsserted { + onCloseCalled.get() shouldBe true + processStdErrorCalled.get() shouldBe false + onErrorCalled.get() shouldBe false + } + } + + @Test + fun `should process first stderr line and discard remaining buffer`(): Unit = runBlocking { + val errorMessagesProcessed = mutableListOf() + val onCloseCalled = AtomicBoolean(false) + + // Create error stream with multiple lines + // NOTE: StdioClientTransport.kt:78 calls readBuffer.clear() after reading one line, + // so only the FIRST line will be processed - this is the actual implementation behavior + val multipleLines = """ + warning: first warning + warning: second warning will be discarded + warning: third warning will be discarded + + """.trimIndent() + val errorSource = ByteArraySource(multipleLines.encodeToByteArray()) + + val inputSource = ByteArraySource() + val outputSink = NoOpSink() + + transport = StdioClientTransport( + input = inputSource.buffered(), + output = outputSink.buffered(), + error = errorSource.buffered(), + processStdError = { msg -> + synchronized(errorMessagesProcessed) { + errorMessagesProcessed.add(msg) + } + false // Non-fatal + }, + ) + + transport.onClose { onCloseCalled.set(true) } + transport.start() + + // Wait for first message to be processed and transport to close - use awaitility + await untilAsserted { + onCloseCalled.get() shouldBe true + errorMessagesProcessed.size shouldBe 1 + errorMessagesProcessed[0] shouldBe "warning: first warning" + } + } +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/testUtils.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/testUtils.kt new file mode 100644 index 00000000..4a03b559 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/testUtils.kt @@ -0,0 +1,90 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import kotlinx.io.Buffer +import kotlinx.io.RawSink +import kotlinx.io.RawSource +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +/** + * RawSource that reads from a byte array. + * + * Useful for simulating stdin/stderr streams with predefined content. + * Returns EOF (-1) when all data has been read. + */ +class ByteArraySource(private val data: ByteArray = ByteArray(512)) : RawSource { + private var position = 0 + private var closed = false + + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + if (closed) return -1 + if (position >= data.size) return -1 + + val toRead = minOf(byteCount.toInt(), data.size - position) + sink.write(data, position, toRead) + position += toRead + return toRead.toLong() + } + + override fun close() { + closed = true + } +} + +/** + * RawSource that blocks until explicitly unblocked. + * + * This is useful for simulating a process that's waiting for data (e.g., stdin from a server + * that hasn't responded yet). + * + * IMPORTANT: Always call [unblock] in cleanup to prevent resource leaks. + */ +class ControllableBlockingSource : RawSource { + private val latch = CountDownLatch(1) + private var closed = false + + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + // Block until unblocked or closed + while (!closed && latch.count > 0) { + latch.await(100, TimeUnit.MILLISECONDS) + } + return -1 + } + + override fun close() { + closed = true + latch.countDown() + } + + /** + * Unblocks the source, allowing readAtMostTo to return EOF. + * Should be called in test cleanup. + */ + fun unblock() { + latch.countDown() + } +} + +/** + * RawSink that discards all data written to it (like /dev/null). + * + * Useful for test scenarios where we don't care about output data + * but need a valid sink for the transport. + */ +class NoOpSink : RawSink { + private var closed = false + + override fun write(source: Buffer, byteCount: Long) { + if (closed) error("Sink is closed") + // Discard the data + source.skip(byteCount) + } + + override fun flush() { + // No-op + } + + override fun close() { + closed = true + } +} diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 8694e2d5..f22fc8a9 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -466,6 +466,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer { public fun ()V public final fun append ([B)V public final fun clear ()V + public final fun isEmpty ()Z + public final fun readLine ()Ljava/lang/String; public final fun readMessage ()Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage; } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt index 55e09da8..cbab7aeb 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt @@ -1,6 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import kotlinx.coroutines.CompletableDeferred /** diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt index d991e7fe..98b50913 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt @@ -9,6 +9,13 @@ import kotlinx.io.readString /** * Buffers a continuous stdio stream into discrete JSON-RPC messages. + * + * This class accumulates bytes from a stream and extracts complete lines, + * parsing them as JSON-RPC messages. Handles line-buffering with proper + * CR/LF handling for cross-platform compatibility. + * + * Thread-safety: This class is NOT thread-safe. It should be used from + * a single coroutine or protected by external synchronization. */ public class ReadBuffer { @@ -16,14 +23,29 @@ public class ReadBuffer { private val buffer: Buffer = Buffer() + /** + * Returns true if there's no pending data in the buffer. + */ + public fun isEmpty(): Boolean = buffer.exhausted() + + /** + * Appends a chunk of bytes to the buffer. + * Call this when new data arrives from the stream. + */ public fun append(chunk: ByteArray) { buffer.write(chunk) } - public fun readMessage(): JSONRPCMessage? { + /** + * Reads a complete line from the buffer if available. + * Returns null if no complete line is present. + * + * Handles both CRLF and LF line endings. + */ + public fun readLine(): String? { if (buffer.exhausted()) return null var lfIndex = buffer.indexOf('\n'.code.toByte()) - val line = when (lfIndex) { + return when (lfIndex) { -1L -> return null 0L -> { @@ -42,6 +64,17 @@ public class ReadBuffer { string } } + } + + /** + * Reads and parses the next JSON-RPC message from the buffer. + * Returns null if no complete message is available. + * + * Attempts recovery if the line has a non-JSON prefix by looking for the first '{'. + * If deserialization fails completely, logs the error and returns null. + */ + public fun readMessage(): JSONRPCMessage? { + val line = readLine() ?: return null try { return deserializeMessage(line) } catch (e: Exception) { @@ -61,6 +94,10 @@ public class ReadBuffer { return null } + /** + * Clears all pending data from the buffer. + * Useful for discarding incomplete messages after errors. + */ public fun clear() { buffer.clear() }