diff --git a/config/detekt/detekt.yml b/config/detekt/detekt.yml
index 8c19749c..2cb135ab 100644
--- a/config/detekt/detekt.yml
+++ b/config/detekt/detekt.yml
@@ -14,6 +14,10 @@ complexity:
LongMethod:
excludes: *testFolders
+coroutines:
+ InjectDispatcher:
+ excludes: *testFolders
+
empty-blocks:
EmptyFunctionBlock:
excludes: *testFolders
diff --git a/integration-test/detekt-baseline-test.xml b/integration-test/detekt-baseline-test.xml
index 710b7293..cd94eac6 100644
--- a/integration-test/detekt-baseline-test.xml
+++ b/integration-test/detekt-baseline-test.xml
@@ -8,17 +8,6 @@
AbstractClassCanBeConcreteClass:BaseTransportTest.kt:BaseTransportTest$BaseTransportTest
CyclomaticComplexMethod:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$private fun setupCalculatorTool
ForbiddenComment:StdioClientTransportTest.kt:StdioClientTransportTest$// TODO: fix running on windows
- InjectDispatcher:AbstractKotlinClientTsServerTest.kt:AbstractKotlinClientTsServerTest$IO
- InjectDispatcher:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$IO
- InjectDispatcher:AbstractResourceIntegrationTest.kt:AbstractResourceIntegrationTest$IO
- InjectDispatcher:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$IO
- InjectDispatcher:KotlinClientTsServerEdgeCasesTestSse.kt:KotlinClientTsServerEdgeCasesTestSse$IO
- InjectDispatcher:KotlinClientTsServerEdgeCasesTestStdio.kt:KotlinClientTsServerEdgeCasesTestStdio$IO
- InjectDispatcher:SseIntegrationTest.kt:SseIntegrationTest$IO
- InjectDispatcher:StdioClientTransportTest.kt:StdioClientTransportTest$IO
- InjectDispatcher:StreamableHttpIntegrationTest.kt:StreamableHttpIntegrationTest$IO
- InjectDispatcher:TsEdgeCasesTestSse.kt:TsEdgeCasesTestSse$IO
- InjectDispatcher:WebSocketIntegrationTest.kt:WebSocketIntegrationTest$IO
MatchingDeclarationName:PromptIntegrationTestSse.kt:SchemaPromptIntegrationTestSse : AbstractPromptIntegrationTest
SleepInsteadOfDelay:KotlinServerForTsClientSse.kt:KotlinServerForTsClient$sleep(500)
ThrowsCount:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$override fun configureServer
diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt
index 1e88558f..ec420414 100644
--- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt
+++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt
@@ -42,7 +42,6 @@ import io.ktor.server.sse.SSE as ServerSSE
/**
* Base class for MCP authentication integration tests.
*/
-@Suppress("InjectDispatcher")
abstract class AbstractAuthenticationTest {
protected companion object {
diff --git a/kotlin-sdk-client/detekt-baseline-test.xml b/kotlin-sdk-client/detekt-baseline-test.xml
index b16e8eb1..dd6adfab 100644
--- a/kotlin-sdk-client/detekt-baseline-test.xml
+++ b/kotlin-sdk-client/detekt-baseline-test.xml
@@ -4,8 +4,6 @@
AbstractClassCanBeConcreteClass:AbstractStreamableHttpClientTest.kt:AbstractStreamableHttpClientTest$AbstractStreamableHttpClientTest
ForbiddenComment:StreamableHttpClientTest.kt:StreamableHttpClientTest$// TODO: how to get notifications via Client API?
- InjectDispatcher:StdioClientTransportErrorHandlingTest.kt:StdioClientTransportErrorHandlingTest$IO
- InjectDispatcher:StreamableHttpClientTransportTest.kt:StreamableHttpClientTransportTest$Default
LongParameterList:MockMcp.kt:MockMcp$fun handleJSONRPCRequest
LongParameterList:MockMcp.kt:MockMcp$fun handleWithResult
LongParameterList:MockMcp.kt:MockMcp$fun onInitialize
diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api
index c0b0199d..ad5208c6 100644
--- a/kotlin-sdk-server/api/kotlin-sdk-server.api
+++ b/kotlin-sdk-server/api/kotlin-sdk-server.api
@@ -187,6 +187,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport
public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport {
public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V
+ public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;)V
+ public synthetic fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt
index 8fd46dd2..e9df85ec 100644
--- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt
+++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt
@@ -8,11 +8,13 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.SupervisorJob
-import kotlinx.coroutines.cancelAndJoin
+import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
@@ -25,34 +27,72 @@ import kotlinx.io.readByteArray
import kotlinx.io.writeString
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
-import kotlin.coroutines.CoroutineContext
private const val READ_BUFFER_SIZE = 8192L
/**
* A server transport that communicates with a client via standard I/O.
*
- * Reads from input [Source] and writes to output [Sink].
+ * [StdioServerTransport] manages the communication between a JSON-RPC server and its clients
+ * by reading incoming messages from the specified [Source] (input stream) and writing outgoing
+ * messages to the [Sink] (output stream).
*
- * @constructor Creates a new instance of [StdioServerTransport].
- * @param inputStream The input [Source] used to receive data.
- * @param outputStream The output [Sink] used to send data.
+ * Example:
+ * ```kotlin
+ * val transport = StdioServerTransport(
+ * source = System.`in`.asInput()
+ * sink = System.out.asSink()
+ * )
+ * ```
+ *
+ * @constructor Creates an instance of [StdioServerTransport] with the specified parameters.
+ * @property source The source for reading incoming messages (e.g., stdin or other readable stream).
+ * @param sink The sink for writing outgoing messages (e.g., stdout or other writable stream).
+ * @property readBufferSize The maximum size of the read buffer, defaults to a pre-configured constant.
+ * @property readChannel The channel for receiving raw byte arrays from the input stream.
+ * @property writeChannel The channel for sending serialized JSON-RPC messages to the output stream.
+ * @property readingJobDispatcher The dispatcher to use for the message-reading coroutine.
+ * @property writingJobDispatcher The dispatcher to use for the message-writing coroutine.
+ * @property processingJobDispatcher The dispatcher to handle processing of read messages.
+ * @param coroutineScope Optional coroutine scope to use for managing internal jobs. A new scope
+ * will be created if not provided.
*/
@OptIn(ExperimentalAtomicApi::class)
-public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() {
+@Suppress("LongParameterList")
+public class StdioServerTransport(
+ private val source: Source,
+ sink: Sink,
+ private val readBufferSize: Long = READ_BUFFER_SIZE,
+ private val readChannel: Channel = Channel(Channel.UNLIMITED),
+ private val writeChannel: Channel = Channel(Channel.UNLIMITED),
+ private var readingJobDispatcher: CoroutineDispatcher = IODispatcher,
+ private var writingJobDispatcher: CoroutineDispatcher = IODispatcher,
+ private var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default,
+ coroutineScope: CoroutineScope? = null,
+) : AbstractTransport() {
+
+ private val scope: CoroutineScope
+ private val sink: Sink
+
+ init {
+ require(readBufferSize > 0) { "readBufferSize must be > 0" }
+ val parentJob = coroutineScope?.coroutineContext?.get(Job)
+ scope = CoroutineScope(SupervisorJob(parentJob))
+ this.sink = sink.buffered()
+ }
+
+ /**
+ * Creates a new instance of [StdioServerTransport]
+ * with the given [inputStream] [Source] and [outputStream] [Sink].
+ */
+ public constructor(inputStream: Source, outputStream: Sink) : this(
+ source = inputStream,
+ sink = outputStream,
+ )
private val logger = KotlinLogging.logger {}
private val readBuffer = ReadBuffer()
private val initialized: AtomicBoolean = AtomicBoolean(false)
- private var readingJob: Job? = null
- private var sendingJob: Job? = null
- private var processingJob: Job? = null
-
- private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob()
- private val scope = CoroutineScope(coroutineContext)
- private val readChannel = Channel(Channel.UNLIMITED)
- private val writeChannel = Channel(Channel.UNLIMITED)
- private val outputSink = outputStream.buffered()
override suspend fun start() {
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
@@ -60,90 +100,82 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
}
// Launch a coroutine to read from stdin
- readingJob = launchReadingJob()
+ launchReadingJob()
// Launch a coroutine to process messages from readChannel
- processingJob = launchProcessingJob()
+ launchProcessingJob()
// Launch a coroutine to handle message sending
- sendingJob = launchSendingJob()
+ launchSendingJob()
}
- private fun launchReadingJob(): Job {
- val job = scope.launch {
- val buf = Buffer()
- @Suppress("TooGenericExceptionCaught")
- try {
- while (isActive) {
- val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE)
- if (bytesRead == -1L) {
- // EOF reached
- break
- }
- if (bytesRead > 0) {
- val chunk = buf.readByteArray()
- readChannel.send(chunk)
- }
+ private fun launchReadingJob(): Job = scope.launch(readingJobDispatcher) {
+ val buf = Buffer()
+ @Suppress("TooGenericExceptionCaught")
+ try {
+ while (isActive) {
+ val bytesRead = source.readAtMostTo(buf, readBufferSize)
+ if (bytesRead == -1L) {
+ // EOF reached
+ break
+ }
+ if (bytesRead > 0) {
+ val chunk = buf.readByteArray()
+ readChannel.send(chunk)
}
- } catch (e: CancellationException) {
- throw e
- } catch (e: Throwable) {
- logger.error(e) { "Error reading from stdin" }
- _onError.invoke(e)
- } finally {
- // Reached EOF or error, close connection
- close()
}
+ } catch (e: CancellationException) {
+ throw e
+ } catch (e: Throwable) {
+ logger.error(e) { "Error reading from stdin" }
+ _onError.invoke(e)
+ } finally {
+ // Reached EOF or error, close connection
+ close()
}
- job.invokeOnCompletion { cause ->
- logJobCompletion("Message reading", cause)
- }
- return job
+ }.apply {
+ invokeOnCompletion { logJobCompletion("Message reading", it) }
}
- private fun launchProcessingJob(): Job {
- val job = scope.launch {
- @Suppress("TooGenericExceptionCaught")
- try {
- for (chunk in readChannel) {
- readBuffer.append(chunk)
- processReadBuffer()
- }
- } catch (e: CancellationException) {
- throw e
- } catch (e: Throwable) {
- _onError.invoke(e)
+ private fun launchProcessingJob(): Job = scope.launch(processingJobDispatcher) {
+ @Suppress("TooGenericExceptionCaught")
+ try {
+ for (chunk in readChannel) {
+ readBuffer.append(chunk)
+ processReadBuffer()
}
+ } catch (e: CancellationException) {
+ throw e
+ } catch (e: Throwable) {
+ _onError.invoke(e)
}
- job.invokeOnCompletion { cause ->
+ }.apply {
+ invokeOnCompletion { cause ->
logJobCompletion("Processing", cause)
}
- return job
}
- private fun launchSendingJob(): Job {
- val job = scope.launch {
- @Suppress("TooGenericExceptionCaught")
- try {
- for (message in writeChannel) {
- val json = serializeMessage(message)
- outputSink.writeString(json)
- outputSink.flush()
- }
- } catch (e: CancellationException) {
- throw e
- } catch (e: Throwable) {
- logger.error(e) { "Error writing to stdout" }
- _onError.invoke(e)
+ private fun launchSendingJob(): Job = scope.launch(writingJobDispatcher) {
+ @Suppress("TooGenericExceptionCaught")
+ try {
+ for (message in writeChannel) {
+ val json = serializeMessage(message)
+ sink.writeString(json)
+ sink.flush()
}
+ } catch (e: CancellationException) {
+ throw e
+ } catch (e: Throwable) {
+ logger.error(e) { "Error writing to stdout" }
+ _onError.invoke(e)
}
- job.invokeOnCompletion { cause ->
+ }.apply {
+ invokeOnCompletion { cause ->
logJobCompletion("Message sending", cause)
if (cause is CancellationException) {
- readingJob?.cancel(cause)
+ readChannel.cancel(cause)
}
}
- return job
}
private suspend fun processReadBuffer() {
@@ -189,24 +221,22 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
withContext(NonCancellable) {
writeChannel.close()
- sendingJob?.cancelAndJoin()
runCatching {
- inputStream.close()
+ source.close()
}.onFailure { logger.warn(it) { "Failed to close stdin" } }
- readingJob?.cancel()
readChannel.close()
- processingJob?.cancelAndJoin()
-
readBuffer.clear()
-
runCatching {
- outputSink.flush()
- outputSink.close()
+ sink.flush()
+ sink.close()
}.onFailure { logger.warn(it) { "Failed to close stdout" } }
+ scope.cancel()
+ scope.coroutineContext[Job]?.join()
+
invokeOnCloseCallback()
}
}
diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt
index dcbb7a4b..5a84c33f 100644
--- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt
+++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt
@@ -14,7 +14,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON
import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.TimeoutCancellationException
+import kotlinx.coroutines.channels.BufferOverflow
+import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.withTimeout
import kotlinx.io.Buffer
import kotlinx.io.RawSink
@@ -69,6 +73,37 @@ class StdioServerTransportTest {
printOutput = output.asSink().buffered()
}
+ @Test
+ fun `should construct with builder`() = runIntegrationTest {
+ val received = CompletableDeferred()
+
+ // Set every configuration parameter explicitly with non-default values,
+ // then verify a message round-trips correctly.
+ val server = StdioServerTransport(
+ source = bufferedInput,
+ sink = printOutput,
+ readBufferSize = 16L, // non-default: smaller read chunk
+ readingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Read"), // non-default: limited parallelism
+ writingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Write"), // non-default: limited parallelism
+ processingJobDispatcher = Dispatchers.IO.limitedParallelism(2, name = "Worker"), // non-default
+ readChannel = Channel(capacity = 8, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
+ writeChannel = Channel(capacity = 16, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
+ coroutineScope = CoroutineScope(Dispatchers.Default), // non-default: parent scope provided
+ )
+ server.onError { throw it }
+ server.onMessage { received.complete(it) }
+
+ server.start()
+
+ val message = PingRequest().toJSON()
+ inputWriter.write(serializeMessage(message))
+ inputWriter.flush()
+
+ received.await() shouldBe message
+
+ server.close()
+ }
+
@Test
fun `should be safe to close before start`() = runIntegrationTest {
val server = StdioServerTransport(bufferedInput, printOutput)