diff --git a/config/detekt/detekt.yml b/config/detekt/detekt.yml index 8c19749c..2cb135ab 100644 --- a/config/detekt/detekt.yml +++ b/config/detekt/detekt.yml @@ -14,6 +14,10 @@ complexity: LongMethod: excludes: *testFolders +coroutines: + InjectDispatcher: + excludes: *testFolders + empty-blocks: EmptyFunctionBlock: excludes: *testFolders diff --git a/integration-test/detekt-baseline-test.xml b/integration-test/detekt-baseline-test.xml index 710b7293..cd94eac6 100644 --- a/integration-test/detekt-baseline-test.xml +++ b/integration-test/detekt-baseline-test.xml @@ -8,17 +8,6 @@ AbstractClassCanBeConcreteClass:BaseTransportTest.kt:BaseTransportTest$BaseTransportTest CyclomaticComplexMethod:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$private fun setupCalculatorTool ForbiddenComment:StdioClientTransportTest.kt:StdioClientTransportTest$// TODO: fix running on windows - InjectDispatcher:AbstractKotlinClientTsServerTest.kt:AbstractKotlinClientTsServerTest$IO - InjectDispatcher:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$IO - InjectDispatcher:AbstractResourceIntegrationTest.kt:AbstractResourceIntegrationTest$IO - InjectDispatcher:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$IO - InjectDispatcher:KotlinClientTsServerEdgeCasesTestSse.kt:KotlinClientTsServerEdgeCasesTestSse$IO - InjectDispatcher:KotlinClientTsServerEdgeCasesTestStdio.kt:KotlinClientTsServerEdgeCasesTestStdio$IO - InjectDispatcher:SseIntegrationTest.kt:SseIntegrationTest$IO - InjectDispatcher:StdioClientTransportTest.kt:StdioClientTransportTest$IO - InjectDispatcher:StreamableHttpIntegrationTest.kt:StreamableHttpIntegrationTest$IO - InjectDispatcher:TsEdgeCasesTestSse.kt:TsEdgeCasesTestSse$IO - InjectDispatcher:WebSocketIntegrationTest.kt:WebSocketIntegrationTest$IO MatchingDeclarationName:PromptIntegrationTestSse.kt:SchemaPromptIntegrationTestSse : AbstractPromptIntegrationTest SleepInsteadOfDelay:KotlinServerForTsClientSse.kt:KotlinServerForTsClient$sleep(500) ThrowsCount:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$override fun configureServer diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt index 1e88558f..ec420414 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt @@ -42,7 +42,6 @@ import io.ktor.server.sse.SSE as ServerSSE /** * Base class for MCP authentication integration tests. */ -@Suppress("InjectDispatcher") abstract class AbstractAuthenticationTest { protected companion object { diff --git a/kotlin-sdk-client/detekt-baseline-test.xml b/kotlin-sdk-client/detekt-baseline-test.xml index b16e8eb1..dd6adfab 100644 --- a/kotlin-sdk-client/detekt-baseline-test.xml +++ b/kotlin-sdk-client/detekt-baseline-test.xml @@ -4,8 +4,6 @@ AbstractClassCanBeConcreteClass:AbstractStreamableHttpClientTest.kt:AbstractStreamableHttpClientTest$AbstractStreamableHttpClientTest ForbiddenComment:StreamableHttpClientTest.kt:StreamableHttpClientTest$// TODO: how to get notifications via Client API? - InjectDispatcher:StdioClientTransportErrorHandlingTest.kt:StdioClientTransportErrorHandlingTest$IO - InjectDispatcher:StreamableHttpClientTransportTest.kt:StreamableHttpClientTransportTest$Default LongParameterList:MockMcp.kt:MockMcp$fun handleJSONRPCRequest LongParameterList:MockMcp.kt:MockMcp$fun handleWithResult LongParameterList:MockMcp.kt:MockMcp$fun onInitialize diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index c0b0199d..ad5208c6 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -187,6 +187,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;)V + public synthetic fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index 8fd46dd2..e9df85ec 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -8,11 +8,13 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.isActive import kotlinx.coroutines.launch @@ -25,34 +27,72 @@ import kotlinx.io.readByteArray import kotlinx.io.writeString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi -import kotlin.coroutines.CoroutineContext private const val READ_BUFFER_SIZE = 8192L /** * A server transport that communicates with a client via standard I/O. * - * Reads from input [Source] and writes to output [Sink]. + * [StdioServerTransport] manages the communication between a JSON-RPC server and its clients + * by reading incoming messages from the specified [Source] (input stream) and writing outgoing + * messages to the [Sink] (output stream). * - * @constructor Creates a new instance of [StdioServerTransport]. - * @param inputStream The input [Source] used to receive data. - * @param outputStream The output [Sink] used to send data. + * Example: + * ```kotlin + * val transport = StdioServerTransport( + * source = System.`in`.asInput() + * sink = System.out.asSink() + * ) + * ``` + * + * @constructor Creates an instance of [StdioServerTransport] with the specified parameters. + * @property source The source for reading incoming messages (e.g., stdin or other readable stream). + * @param sink The sink for writing outgoing messages (e.g., stdout or other writable stream). + * @property readBufferSize The maximum size of the read buffer, defaults to a pre-configured constant. + * @property readChannel The channel for receiving raw byte arrays from the input stream. + * @property writeChannel The channel for sending serialized JSON-RPC messages to the output stream. + * @property readingJobDispatcher The dispatcher to use for the message-reading coroutine. + * @property writingJobDispatcher The dispatcher to use for the message-writing coroutine. + * @property processingJobDispatcher The dispatcher to handle processing of read messages. + * @param coroutineScope Optional coroutine scope to use for managing internal jobs. A new scope + * will be created if not provided. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() { +@Suppress("LongParameterList") +public class StdioServerTransport( + private val source: Source, + sink: Sink, + private val readBufferSize: Long = READ_BUFFER_SIZE, + private val readChannel: Channel = Channel(Channel.UNLIMITED), + private val writeChannel: Channel = Channel(Channel.UNLIMITED), + private var readingJobDispatcher: CoroutineDispatcher = IODispatcher, + private var writingJobDispatcher: CoroutineDispatcher = IODispatcher, + private var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default, + coroutineScope: CoroutineScope? = null, +) : AbstractTransport() { + + private val scope: CoroutineScope + private val sink: Sink + + init { + require(readBufferSize > 0) { "readBufferSize must be > 0" } + val parentJob = coroutineScope?.coroutineContext?.get(Job) + scope = CoroutineScope(SupervisorJob(parentJob)) + this.sink = sink.buffered() + } + + /** + * Creates a new instance of [StdioServerTransport] + * with the given [inputStream] [Source] and [outputStream] [Sink]. + */ + public constructor(inputStream: Source, outputStream: Sink) : this( + source = inputStream, + sink = outputStream, + ) private val logger = KotlinLogging.logger {} private val readBuffer = ReadBuffer() private val initialized: AtomicBoolean = AtomicBoolean(false) - private var readingJob: Job? = null - private var sendingJob: Job? = null - private var processingJob: Job? = null - - private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob() - private val scope = CoroutineScope(coroutineContext) - private val readChannel = Channel(Channel.UNLIMITED) - private val writeChannel = Channel(Channel.UNLIMITED) - private val outputSink = outputStream.buffered() override suspend fun start() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { @@ -60,90 +100,82 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: } // Launch a coroutine to read from stdin - readingJob = launchReadingJob() + launchReadingJob() // Launch a coroutine to process messages from readChannel - processingJob = launchProcessingJob() + launchProcessingJob() // Launch a coroutine to handle message sending - sendingJob = launchSendingJob() + launchSendingJob() } - private fun launchReadingJob(): Job { - val job = scope.launch { - val buf = Buffer() - @Suppress("TooGenericExceptionCaught") - try { - while (isActive) { - val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE) - if (bytesRead == -1L) { - // EOF reached - break - } - if (bytesRead > 0) { - val chunk = buf.readByteArray() - readChannel.send(chunk) - } + private fun launchReadingJob(): Job = scope.launch(readingJobDispatcher) { + val buf = Buffer() + @Suppress("TooGenericExceptionCaught") + try { + while (isActive) { + val bytesRead = source.readAtMostTo(buf, readBufferSize) + if (bytesRead == -1L) { + // EOF reached + break + } + if (bytesRead > 0) { + val chunk = buf.readByteArray() + readChannel.send(chunk) } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error reading from stdin" } - _onError.invoke(e) - } finally { - // Reached EOF or error, close connection - close() } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error reading from stdin" } + _onError.invoke(e) + } finally { + // Reached EOF or error, close connection + close() } - job.invokeOnCompletion { cause -> - logJobCompletion("Message reading", cause) - } - return job + }.apply { + invokeOnCompletion { logJobCompletion("Message reading", it) } } - private fun launchProcessingJob(): Job { - val job = scope.launch { - @Suppress("TooGenericExceptionCaught") - try { - for (chunk in readChannel) { - readBuffer.append(chunk) - processReadBuffer() - } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - _onError.invoke(e) + private fun launchProcessingJob(): Job = scope.launch(processingJobDispatcher) { + @Suppress("TooGenericExceptionCaught") + try { + for (chunk in readChannel) { + readBuffer.append(chunk) + processReadBuffer() } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError.invoke(e) } - job.invokeOnCompletion { cause -> + }.apply { + invokeOnCompletion { cause -> logJobCompletion("Processing", cause) } - return job } - private fun launchSendingJob(): Job { - val job = scope.launch { - @Suppress("TooGenericExceptionCaught") - try { - for (message in writeChannel) { - val json = serializeMessage(message) - outputSink.writeString(json) - outputSink.flush() - } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error writing to stdout" } - _onError.invoke(e) + private fun launchSendingJob(): Job = scope.launch(writingJobDispatcher) { + @Suppress("TooGenericExceptionCaught") + try { + for (message in writeChannel) { + val json = serializeMessage(message) + sink.writeString(json) + sink.flush() } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error writing to stdout" } + _onError.invoke(e) } - job.invokeOnCompletion { cause -> + }.apply { + invokeOnCompletion { cause -> logJobCompletion("Message sending", cause) if (cause is CancellationException) { - readingJob?.cancel(cause) + readChannel.cancel(cause) } } - return job } private suspend fun processReadBuffer() { @@ -189,24 +221,22 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: withContext(NonCancellable) { writeChannel.close() - sendingJob?.cancelAndJoin() runCatching { - inputStream.close() + source.close() }.onFailure { logger.warn(it) { "Failed to close stdin" } } - readingJob?.cancel() readChannel.close() - processingJob?.cancelAndJoin() - readBuffer.clear() - runCatching { - outputSink.flush() - outputSink.close() + sink.flush() + sink.close() }.onFailure { logger.warn(it) { "Failed to close stdout" } } + scope.cancel() + scope.coroutineContext[Job]?.join() + invokeOnCloseCallback() } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index dcbb7a4b..5a84c33f 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -14,7 +14,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.channels.BufferOverflow +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.withTimeout import kotlinx.io.Buffer import kotlinx.io.RawSink @@ -69,6 +73,37 @@ class StdioServerTransportTest { printOutput = output.asSink().buffered() } + @Test + fun `should construct with builder`() = runIntegrationTest { + val received = CompletableDeferred() + + // Set every configuration parameter explicitly with non-default values, + // then verify a message round-trips correctly. + val server = StdioServerTransport( + source = bufferedInput, + sink = printOutput, + readBufferSize = 16L, // non-default: smaller read chunk + readingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Read"), // non-default: limited parallelism + writingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Write"), // non-default: limited parallelism + processingJobDispatcher = Dispatchers.IO.limitedParallelism(2, name = "Worker"), // non-default + readChannel = Channel(capacity = 8, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded + writeChannel = Channel(capacity = 16, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded + coroutineScope = CoroutineScope(Dispatchers.Default), // non-default: parent scope provided + ) + server.onError { throw it } + server.onMessage { received.complete(it) } + + server.start() + + val message = PingRequest().toJSON() + inputWriter.write(serializeMessage(message)) + inputWriter.flush() + + received.await() shouldBe message + + server.close() + } + @Test fun `should be safe to close before start`() = runIntegrationTest { val server = StdioServerTransport(bufferedInput, printOutput)