diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 48e2ca7e..045addff 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -65,22 +65,35 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun addTools (Ljava/util/List;)V public final fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun createElicitation (Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ElicitRequestParams$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ElicitRequestParams$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun createSession (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected final fun getInstructionsProvider ()Lkotlin/jvm/functions/Function0; protected final fun getOptions ()Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions; public final fun getPrompts ()Ljava/util/Map; public final fun getResources ()Ljava/util/Map; protected final fun getServerInfo ()Lio/modelcontextprotocol/kotlin/sdk/types/Implementation; + public final fun getSessions ()Ljava/util/Map; public final fun getTools ()Ljava/util/Map; + public final fun listRoots (Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun onClose (Lkotlin/jvm/functions/Function0;)V public final fun onConnect (Lkotlin/jvm/functions/Function0;)V public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun removePrompt (Ljava/lang/String;)Z public final fun removePrompts (Ljava/util/List;)I public final fun removeResource (Ljava/lang/String;)Z public final fun removeResources (Ljava/util/List;)I public final fun removeTool (Ljava/lang/String;)Z public final fun removeTools (Ljava/util/List;)I + public final fun sendLoggingMessage (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { @@ -98,10 +111,13 @@ public class io/modelcontextprotocol/kotlin/sdk/server/ServerSession : io/modelc public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ElicitRequestParams$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun equals (Ljava/lang/Object;)Z public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/types/Implementation; protected final fun getInstructions ()Ljava/lang/String; protected final fun getServerInfo ()Lio/modelcontextprotocol/kotlin/sdk/types/Implementation; + public final fun getSessionId ()Ljava/lang/String; + public fun hashCode ()I public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public fun onClose ()V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 3ce426f1..7fa7a98f 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -2,9 +2,16 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions +import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions import io.modelcontextprotocol.kotlin.sdk.shared.Transport import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult +import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.types.Implementation @@ -14,22 +21,22 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.types.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.Prompt import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult import io.modelcontextprotocol.kotlin.sdk.types.Resource +import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent import io.modelcontextprotocol.kotlin.sdk.types.Tool import io.modelcontextprotocol.kotlin.sdk.types.ToolAnnotations import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema -import kotlinx.atomicfu.atomic -import kotlinx.atomicfu.update -import kotlinx.collections.immutable.persistentListOf import kotlinx.coroutines.CancellationException import kotlinx.serialization.json.JsonObject @@ -45,7 +52,7 @@ public class ServerOptions(public val capabilities: ServerCapabilities, enforceS ProtocolOptions(enforceStrictCapabilities = enforceStrictCapabilities) /** - * An MCP server on top of a pluggable transport. + * An MCP server is responsible for storing features and handling new connections. * * This server automatically responds to the initialization flow as initiated by the client. * You can register tools, prompts, and resources using [addTool], [addPrompt], and [addResource]. @@ -79,7 +86,13 @@ public open class Server( block: Server.() -> Unit = {}, ) : this(serverInfo, options, { instructions }, block) - private val sessions = atomic(persistentListOf()) + private val sessionRegistry = ServerSessionRegistry() + + /** + * Provides a snapshot of all sessions currently registered in the server + */ + public val sessions: Map + get() = sessionRegistry.sessions @Suppress("ktlint:standard:backing-property-naming") private var _onInitialized: (() -> Unit) = {} @@ -107,7 +120,10 @@ public open class Server( public suspend fun close() { logger.debug { "Closing MCP server" } - sessions.value.forEach { session -> session.close() } + sessions.forEach { (sessionId, session) -> + logger.info { "Closing session $sessionId" } + session.close() + } _onClose() } @@ -171,12 +187,12 @@ public open class Server( // Register cleanup handler to remove session from list when it closes session.onClose { logger.debug { "Removing closed session from active sessions list" } - sessions.update { list -> list.remove(session) } + sessionRegistry.removeSession(session.sessionId) } logger.debug { "Server session connecting to transport" } session.connect(transport) logger.debug { "Server session successfully connected to transport" } - sessions.update { sessions -> sessions.add(session) } + sessionRegistry.addSession(session) _onConnect() return session @@ -538,4 +554,125 @@ public open class Server( // If you have resource templates, return them here. For now, return empty. return ListResourceTemplatesResult(listOf()) } + + // Start the ServerSession redirection section + + /** + * Triggers [ServerSession.ping] request for session by provided [sessionId]. + * @param sessionId The session ID to ping + */ + public suspend fun ping(sessionId: String): EmptyResult = with(sessionRegistry.getSession(sessionId)) { + ping() + } + + /** + * Triggers [ServerSession.createMessage] request for session by provided [sessionId]. + * + * @param sessionId The session ID to create a message. + * @param params The parameters for creating a message. + * @param options Optional request options. + * @return The created message result. + * @throws IllegalStateException If the server does not support sampling or if the request fails. + */ + public suspend fun createMessage( + sessionId: String, + params: CreateMessageRequest, + options: RequestOptions? = null, + ): CreateMessageResult = with(sessionRegistry.getSession(sessionId)) { + request(params, options) + } + + /** + * Triggers [ServerSession.listRoots] request for session by provided [sessionId]. + * + * @param sessionId The session ID to list roots for. + * @param params JSON parameters for the request, usually empty. + * @param options Optional request options. + * @return The list of roots. + * @throws IllegalStateException If the server or client does not support roots. + */ + public suspend fun listRoots( + sessionId: String, + params: JsonObject = EmptyJsonObject, + options: RequestOptions? = null, + ): ListRootsResult = with(sessionRegistry.getSession(sessionId)) { + listRoots(params, options) + } + + /** + * Triggers [ServerSession.createElicitation] request for session by provided [sessionId]. + * + * @param sessionId The session ID to create elicitation for. + * @param message The elicitation message. + * @param requestedSchema The requested schema for the elicitation. + * @param options Optional request options. + * @return The created elicitation result. + * @throws IllegalStateException If the server does not support elicitation or if the request fails. + */ + public suspend fun createElicitation( + sessionId: String, + message: String, + requestedSchema: ElicitRequestParams.RequestedSchema, + options: RequestOptions? = null, + ): ElicitResult = with(sessionRegistry.getSession(sessionId)) { + createElicitation(message, requestedSchema, options) + } + + /** + * Triggers [ServerSession.sendLoggingMessage] for session by provided [sessionId]. + * + * @param sessionId The session ID to send the logging message to. + * @param notification The logging message notification. + */ + public suspend fun sendLoggingMessage(sessionId: String, notification: LoggingMessageNotification) { + with(sessionRegistry.getSession(sessionId)) { + sendLoggingMessage(notification) + } + } + + /** + * Triggers [ServerSession.sendResourceUpdated] for session by provided [sessionId]. + * + * @param sessionId The session ID to send the resource updated notification to. + * @param notification Details of the updated resource. + */ + public suspend fun sendResourceUpdated(sessionId: String, notification: ResourceUpdatedNotification) { + with(sessionRegistry.getSession(sessionId)) { + sendResourceUpdated(notification) + } + } + + /** + * Triggers [ServerSession.sendResourceListChanged] for session by provided [sessionId]. + * + * @param sessionId The session ID to send the resource list changed notification to. + */ + public suspend fun sendResourceListChanged(sessionId: String) { + with(sessionRegistry.getSession(sessionId)) { + sendResourceListChanged() + } + } + + /** + * Triggers [ServerSession.sendToolListChanged] for session by provided [sessionId]. + * + * @param sessionId The session ID to send the tool list changed notification to. + */ + public suspend fun sendToolListChanged(sessionId: String) { + with(sessionRegistry.getSession(sessionId)) { + sendToolListChanged() + } + } + + /** + * Triggers [ServerSession.sendPromptListChanged] for session by provided [sessionId]. + * + * @param sessionId The session ID to send the prompt list changed notification to. + */ + public suspend fun sendPromptListChanged(sessionId: String) { + with(sessionRegistry.getSession(sessionId)) { + sendPromptListChanged() + } + } + // End the ServerSession redirection section } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt index 53d47b68..e65db2f4 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -35,14 +35,24 @@ import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic import kotlinx.coroutines.CompletableDeferred import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid private val logger = KotlinLogging.logger {} +/** + * Represents a server session. + */ +@Suppress("TooManyFunctions") public open class ServerSession( protected val serverInfo: Implementation, options: ServerOptions, protected val instructions: String?, ) : Protocol(options) { + + @OptIn(ExperimentalUuidApi::class) + public val sessionId: String = Uuid.random().toString() + @Suppress("ktlint:standard:backing-property-naming") private var _onInitialized: (() -> Unit) = {} @@ -430,4 +440,12 @@ public open class ServerSession( * @return true if the message should be accepted (not filtered out), false otherwise. */ private fun isMessageAccepted(level: LoggingLevel): Boolean = !isMessageIgnored(level) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ServerSession) return false + return sessionId == other.sessionId + } + + override fun hashCode(): Int = sessionId.hashCode() } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionRegistry.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionRegistry.kt new file mode 100644 index 00000000..57980d38 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionRegistry.kt @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentMapOf + +internal typealias ServerSessionKey = String + +/** + * Represents a registry for managing server sessions. + */ +internal class ServerSessionRegistry { + + private val logger = KotlinLogging.logger {} + + /** + * Atomic variable used to maintain a thread-safe registry of sessions. + * Stores a persistent map where each session is identified by its unique key. + */ + private val registry = atomic(persistentMapOf()) + + /** + * Returns a read-only view of the current server sessions. + */ + internal val sessions: Map + get() = registry.value + + /** + * Returns a server session by its ID. + * @param sessionId The ID of the session to retrieve. + * @throws IllegalArgumentException If the session doesn't exist. + */ + internal fun getSession(sessionId: ServerSessionKey): ServerSession = + sessions[sessionId] ?: throw IllegalArgumentException("Session not found: $sessionId") + + /** + * Returns a server session by its ID, or null if it doesn't exist. + * @param sessionId The ID of the session to retrieve. + */ + internal fun getSessionOrNull(sessionId: ServerSessionKey): ServerSession? = sessions[sessionId] + + /** + * Registers a server session. + * @param session The session to register. + */ + internal fun addSession(session: ServerSession) { + logger.info { "Adding session: ${session.sessionId}" } + registry.update { sessions -> sessions.put(session.sessionId, session) } + } + + /** + * Removes a server session by its ID. + * @param sessionId The ID of the session to remove. + */ + internal fun removeSession(sessionId: ServerSessionKey) { + logger.info { "Removing session: $sessionId" } + registry.update { sessions -> sessions.remove(sessionId) } + } +}