Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config/detekt/detekt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ complexity:
LongMethod:
excludes: *testFolders

coroutines:
InjectDispatcher:
excludes: *testFolders

empty-blocks:
EmptyFunctionBlock:
excludes: *testFolders
Expand Down
11 changes: 0 additions & 11 deletions integration-test/detekt-baseline-test.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,6 @@
<ID>AbstractClassCanBeConcreteClass:BaseTransportTest.kt:BaseTransportTest$BaseTransportTest</ID>
<ID>CyclomaticComplexMethod:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$private fun setupCalculatorTool</ID>
<ID>ForbiddenComment:StdioClientTransportTest.kt:StdioClientTransportTest$// TODO: fix running on windows</ID>
<ID>InjectDispatcher:AbstractKotlinClientTsServerTest.kt:AbstractKotlinClientTsServerTest$IO</ID>
<ID>InjectDispatcher:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$IO</ID>
<ID>InjectDispatcher:AbstractResourceIntegrationTest.kt:AbstractResourceIntegrationTest$IO</ID>
<ID>InjectDispatcher:AbstractToolIntegrationTest.kt:AbstractToolIntegrationTest$IO</ID>
<ID>InjectDispatcher:KotlinClientTsServerEdgeCasesTestSse.kt:KotlinClientTsServerEdgeCasesTestSse$IO</ID>
<ID>InjectDispatcher:KotlinClientTsServerEdgeCasesTestStdio.kt:KotlinClientTsServerEdgeCasesTestStdio$IO</ID>
<ID>InjectDispatcher:SseIntegrationTest.kt:SseIntegrationTest$IO</ID>
<ID>InjectDispatcher:StdioClientTransportTest.kt:StdioClientTransportTest$IO</ID>
<ID>InjectDispatcher:StreamableHttpIntegrationTest.kt:StreamableHttpIntegrationTest$IO</ID>
<ID>InjectDispatcher:TsEdgeCasesTestSse.kt:TsEdgeCasesTestSse$IO</ID>
<ID>InjectDispatcher:WebSocketIntegrationTest.kt:WebSocketIntegrationTest$IO</ID>
<ID>MatchingDeclarationName:PromptIntegrationTestSse.kt:SchemaPromptIntegrationTestSse : AbstractPromptIntegrationTest</ID>
<ID>SleepInsteadOfDelay:KotlinServerForTsClientSse.kt:KotlinServerForTsClient$sleep(500)</ID>
<ID>ThrowsCount:AbstractPromptIntegrationTest.kt:AbstractPromptIntegrationTest$override fun configureServer</ID>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import io.ktor.server.sse.SSE as ServerSSE
/**
* Base class for MCP authentication integration tests.
*/
@Suppress("InjectDispatcher")
abstract class AbstractAuthenticationTest {

protected companion object {
Expand Down
2 changes: 0 additions & 2 deletions kotlin-sdk-client/detekt-baseline-test.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
<CurrentIssues>
<ID>AbstractClassCanBeConcreteClass:AbstractStreamableHttpClientTest.kt:AbstractStreamableHttpClientTest$AbstractStreamableHttpClientTest</ID>
<ID>ForbiddenComment:StreamableHttpClientTest.kt:StreamableHttpClientTest$// TODO: how to get notifications via Client API?</ID>
<ID>InjectDispatcher:StdioClientTransportErrorHandlingTest.kt:StdioClientTransportErrorHandlingTest$IO</ID>
<ID>InjectDispatcher:StreamableHttpClientTransportTest.kt:StreamableHttpClientTransportTest$Default</ID>
<ID>LongParameterList:MockMcp.kt:MockMcp$fun handleJSONRPCRequest</ID>
<ID>LongParameterList:MockMcp.kt:MockMcp$fun handleWithResult</ID>
<ID>LongParameterList:MockMcp.kt:MockMcp$fun onInitialize</ID>
Expand Down
2 changes: 2 additions & 0 deletions kotlin-sdk-server/api/kotlin-sdk-server.api
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport

public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport {
public fun <init> (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V
public fun <init> (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;)V
public synthetic fun <init> (Lkotlinx/io/Source;Lkotlinx/io/Sink;JLkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/CoroutineScope;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
Expand All @@ -25,125 +27,155 @@ import kotlinx.io.readByteArray
import kotlinx.io.writeString
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.coroutines.CoroutineContext

private const val READ_BUFFER_SIZE = 8192L

/**
* A server transport that communicates with a client via standard I/O.
*
* Reads from input [Source] and writes to output [Sink].
* [StdioServerTransport] manages the communication between a JSON-RPC server and its clients
* by reading incoming messages from the specified [Source] (input stream) and writing outgoing
* messages to the [Sink] (output stream).
*
* @constructor Creates a new instance of [StdioServerTransport].
* @param inputStream The input [Source] used to receive data.
* @param outputStream The output [Sink] used to send data.
* Example:
* ```kotlin
* val transport = StdioServerTransport(
* source = System.`in`.asInput()
* sink = System.out.asSink()
* )
* ```
*
* @constructor Creates an instance of [StdioServerTransport] with the specified parameters.
* @property source The source for reading incoming messages (e.g., stdin or other readable stream).
* @param sink The sink for writing outgoing messages (e.g., stdout or other writable stream).
* @property readBufferSize The maximum size of the read buffer, defaults to a pre-configured constant.
* @property readChannel The channel for receiving raw byte arrays from the input stream.
* @property writeChannel The channel for sending serialized JSON-RPC messages to the output stream.
* @property readingJobDispatcher The dispatcher to use for the message-reading coroutine.
* @property writingJobDispatcher The dispatcher to use for the message-writing coroutine.
* @property processingJobDispatcher The dispatcher to handle processing of read messages.
* @param coroutineScope Optional coroutine scope to use for managing internal jobs. A new scope
* will be created if not provided.
*/
@OptIn(ExperimentalAtomicApi::class)
public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() {
@Suppress("LongParameterList")
public class StdioServerTransport(
private val source: Source,
sink: Sink,
private val readBufferSize: Long = READ_BUFFER_SIZE,
private val readChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED),
private val writeChannel: Channel<JSONRPCMessage> = Channel(Channel.UNLIMITED),
private var readingJobDispatcher: CoroutineDispatcher = IODispatcher,
private var writingJobDispatcher: CoroutineDispatcher = IODispatcher,
private var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default,
coroutineScope: CoroutineScope? = null,
) : AbstractTransport() {

private val scope: CoroutineScope
private val sink: Sink

init {
require(readBufferSize > 0) { "readBufferSize must be > 0" }
val parentJob = coroutineScope?.coroutineContext?.get(Job)
scope = CoroutineScope(SupervisorJob(parentJob))
this.sink = sink.buffered()
}

/**
* Creates a new instance of [StdioServerTransport]
* with the given [inputStream] [Source] and [outputStream] [Sink].
*/
public constructor(inputStream: Source, outputStream: Sink) : this(
source = inputStream,
sink = outputStream,
)

private val logger = KotlinLogging.logger {}
private val readBuffer = ReadBuffer()
private val initialized: AtomicBoolean = AtomicBoolean(false)
private var readingJob: Job? = null
private var sendingJob: Job? = null
private var processingJob: Job? = null

private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob()
private val scope = CoroutineScope(coroutineContext)
private val readChannel = Channel<ByteArray>(Channel.UNLIMITED)
private val writeChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
private val outputSink = outputStream.buffered()

override suspend fun start() {
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
error("StdioServerTransport already started!")
}

// Launch a coroutine to read from stdin
readingJob = launchReadingJob()
launchReadingJob()

// Launch a coroutine to process messages from readChannel
processingJob = launchProcessingJob()
launchProcessingJob()

// Launch a coroutine to handle message sending
sendingJob = launchSendingJob()
launchSendingJob()
}

private fun launchReadingJob(): Job {
val job = scope.launch {
val buf = Buffer()
@Suppress("TooGenericExceptionCaught")
try {
while (isActive) {
val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE)
if (bytesRead == -1L) {
// EOF reached
break
}
if (bytesRead > 0) {
val chunk = buf.readByteArray()
readChannel.send(chunk)
}
private fun launchReadingJob(): Job = scope.launch(readingJobDispatcher) {
val buf = Buffer()
@Suppress("TooGenericExceptionCaught")
try {
while (isActive) {
val bytesRead = source.readAtMostTo(buf, readBufferSize)
if (bytesRead == -1L) {
// EOF reached
break
}
if (bytesRead > 0) {
val chunk = buf.readByteArray()
readChannel.send(chunk)
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
logger.error(e) { "Error reading from stdin" }
_onError.invoke(e)
} finally {
// Reached EOF or error, close connection
close()
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
logger.error(e) { "Error reading from stdin" }
_onError.invoke(e)
} finally {
// Reached EOF or error, close connection
close()
}
job.invokeOnCompletion { cause ->
logJobCompletion("Message reading", cause)
}
return job
}.apply {
invokeOnCompletion { logJobCompletion("Message reading", it) }
}

private fun launchProcessingJob(): Job {
val job = scope.launch {
@Suppress("TooGenericExceptionCaught")
try {
for (chunk in readChannel) {
readBuffer.append(chunk)
processReadBuffer()
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
_onError.invoke(e)
private fun launchProcessingJob(): Job = scope.launch(processingJobDispatcher) {
@Suppress("TooGenericExceptionCaught")
try {
for (chunk in readChannel) {
readBuffer.append(chunk)
processReadBuffer()
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
_onError.invoke(e)
}
job.invokeOnCompletion { cause ->
}.apply {
invokeOnCompletion { cause ->
logJobCompletion("Processing", cause)
}
return job
}

private fun launchSendingJob(): Job {
val job = scope.launch {
@Suppress("TooGenericExceptionCaught")
try {
for (message in writeChannel) {
val json = serializeMessage(message)
outputSink.writeString(json)
outputSink.flush()
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
logger.error(e) { "Error writing to stdout" }
_onError.invoke(e)
private fun launchSendingJob(): Job = scope.launch(writingJobDispatcher) {
@Suppress("TooGenericExceptionCaught")
try {
for (message in writeChannel) {
val json = serializeMessage(message)
sink.writeString(json)
sink.flush()
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
logger.error(e) { "Error writing to stdout" }
_onError.invoke(e)
}
job.invokeOnCompletion { cause ->
}.apply {
invokeOnCompletion { cause ->
logJobCompletion("Message sending", cause)
if (cause is CancellationException) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd try to control cancellation via channels

readingJob?.cancel(cause)
readChannel.cancel(cause)
}
}
return job
}

private suspend fun processReadBuffer() {
Expand Down Expand Up @@ -189,24 +221,22 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:

withContext(NonCancellable) {
writeChannel.close()
sendingJob?.cancelAndJoin()

runCatching {
inputStream.close()
source.close()
}.onFailure { logger.warn(it) { "Failed to close stdin" } }

readingJob?.cancel()
readChannel.close()

processingJob?.cancelAndJoin()

readBuffer.clear()

runCatching {
outputSink.flush()
outputSink.close()
sink.flush()
sink.close()
}.onFailure { logger.warn(it) { "Failed to close stdout" } }

scope.cancel()
scope.coroutineContext[Job]?.join()

invokeOnCloseCallback()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON
import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.withTimeout
import kotlinx.io.Buffer
import kotlinx.io.RawSink
Expand Down Expand Up @@ -69,6 +73,37 @@ class StdioServerTransportTest {
printOutput = output.asSink().buffered()
}

@Test
fun `should construct with builder`() = runIntegrationTest {
val received = CompletableDeferred<JSONRPCMessage>()

// Set every configuration parameter explicitly with non-default values,
// then verify a message round-trips correctly.
val server = StdioServerTransport(
source = bufferedInput,
sink = printOutput,
readBufferSize = 16L, // non-default: smaller read chunk
readingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Read"), // non-default: limited parallelism
writingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Write"), // non-default: limited parallelism
processingJobDispatcher = Dispatchers.IO.limitedParallelism(2, name = "Worker"), // non-default
readChannel = Channel(capacity = 8, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
writeChannel = Channel(capacity = 16, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
coroutineScope = CoroutineScope(Dispatchers.Default), // non-default: parent scope provided
)
server.onError { throw it }
server.onMessage { received.complete(it) }

server.start()

val message = PingRequest().toJSON()
inputWriter.write(serializeMessage(message))
inputWriter.flush()

received.await() shouldBe message

server.close()
}

@Test
fun `should be safe to close before start`() = runIntegrationTest {
val server = StdioServerTransport(bufferedInput, printOutput)
Expand Down
Loading