From 8bb5e142e46a7e592f8ae662fb430fdc91d84e79 Mon Sep 17 00:00:00 2001 From: noear Date: Thu, 17 Apr 2025 17:02:13 +0800 Subject: [PATCH 1/2] fixed: add WebRxSseClientTransport --- mcp-solon/mcp-solon-webrx/pom.xml | 12 +- .../transport/WebRxSseClientTransport.java | 319 ++++++++++++++++++ .../WebRxSseServerTransportProvider.java | 92 +++-- pom.xml | 2 +- 4 files changed, 386 insertions(+), 39 deletions(-) create mode 100644 mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransport.java diff --git a/mcp-solon/mcp-solon-webrx/pom.xml b/mcp-solon/mcp-solon-webrx/pom.xml index 1485b05..f6c337e 100644 --- a/mcp-solon/mcp-solon-webrx/pom.xml +++ b/mcp-solon/mcp-solon-webrx/pom.xml @@ -29,7 +29,7 @@ - io.modelcontextprotocol.sdk + io.modelcontextprotocol.sdk.j8 mcp-test 0.8.1 test @@ -51,7 +51,6 @@ org.noear solon-net-httputils ${solon.version} - test @@ -70,7 +69,14 @@ org.noear - solon-boot-jetty + solon-boot-smarthttp + ${solon.version} + test + + + + org.noear + solon-flow ${solon.version} test diff --git a/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransport.java b/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransport.java new file mode 100644 index 0000000..58debaa --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransport.java @@ -0,0 +1,319 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.noear.solon.net.http.HttpResponse; +import org.noear.solon.net.http.HttpUtilsBuilder; +import org.noear.solon.net.http.textstream.ServerSentEvent; +import org.noear.solon.rx.SimpleSubscriber; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * Server-Sent Events (SSE) implementation of the + * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE + * transport specification. + * + *

+ * This transport establishes a bidirectional communication channel where: + *

+ * + *

+ * The message flow follows these steps: + *

    + *
  1. The client establishes an SSE connection to the server's /sse endpoint
  2. + *
  3. The server sends an 'endpoint' event containing the URI for sending messages
  4. + *
+ * + * This implementation uses {@link HttpUtilsBuilder} for HTTP communications and supports JSON. and base JDK8 + * serialization/deserialization of messages. + * + * @author Christian Tzolov + * @author noear + * @see MCP + * HTTP with SSE Transport Specification + */ +public class WebRxSseClientTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebRxSseClientTransport.class); + + /** SSE event type for JSON-RPC messages */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + /** SSE event type for endpoint discovery */ + private static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** Default SSE endpoint path */ + private static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** HttpUtils instance builder */ + private final HttpUtilsBuilder webBuilder; + + /** SSE endpoint path */ + private final String sseEndpoint; + + /** JSON object mapper for message serialization/deserialization */ + protected ObjectMapper objectMapper; + + /** Flag indicating if the transport is in closing state */ + private volatile boolean isClosing = false; + + /** Latch for coordinating endpoint discovery */ + private final CountDownLatch closeLatch = new CountDownLatch(1); + + /** Holds the discovered message endpoint URL */ + private final AtomicReference messageEndpoint = new AtomicReference<>(); + + /** Holds the SSE connection future */ + private final AtomicReference> connectionFuture = new AtomicReference<>(); + + /** + * Creates a new transport instance with default HTTP client and object mapper. + * @param webBuilder the HttpUtilsBuilder to use for creating the HttpUtils instance + */ + public WebRxSseClientTransport(HttpUtilsBuilder webBuilder) { + this(webBuilder, new ObjectMapper()); + } + + /** + * Creates a new transport instance with custom HTTP client builder and object mapper. + * @param webBuilder the HttpUtilsBuilder to use for creating the HttpUtils instance + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper or clientBuilder is null + */ + public WebRxSseClientTransport(HttpUtilsBuilder webBuilder, ObjectMapper objectMapper) { + this(webBuilder, DEFAULT_SSE_ENDPOINT, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder and object mapper. + * @param webBuilder the HttpUtilsBuilder to use for creating the HttpUtils instance + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper or clientBuilder is null + */ + public WebRxSseClientTransport(HttpUtilsBuilder webBuilder, String sseEndpoint, + ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(webBuilder, "baseUri must not be empty"); + Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); + this.webBuilder = webBuilder; + this.sseEndpoint = sseEndpoint; + this.objectMapper = objectMapper; + } + + /** + * Creates a new builder for {@link WebRxSseClientTransport}. + * @param webBuilder the HttpUtilsBuilder to use for creating the HttpUtils instance + * @return a new builder instance + */ + public static Builder builder(HttpUtilsBuilder webBuilder) { + return new Builder(webBuilder); + } + + /** + * Builder for {@link WebRxSseClientTransport}. + */ + public static class Builder { + + private final HttpUtilsBuilder webBuilder; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Creates a new builder with the specified base URI. + * @param webBuilder the HttpUtilsBuilder to use for creating the HttpUtils instance + */ + public Builder(HttpUtilsBuilder webBuilder) { + Assert.notNull(webBuilder, "webBuilder must not be empty"); + this.webBuilder = webBuilder; + } + + /** + * Sets the SSE endpoint path. + * @param sseEndpoint the SSE endpoint path + * @return this builder + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "sseEndpoint must not be null"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the object mapper for JSON serialization/deserialization. + * @param objectMapper the object mapper + * @return this builder + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a new {@link WebRxSseClientTransport} instance. + * @return a new transport instance + */ + public WebRxSseClientTransport build() { + return new WebRxSseClientTransport(webBuilder, sseEndpoint, objectMapper); + } + + } + + /** + * Establishes the SSE connection with the server and sets up message handling. + * + *

+ * This method: + *

    + *
  • Initiates the SSE connection
  • + *
  • Handles endpoint discovery events
  • + *
  • Processes incoming JSON-RPC messages
  • + *
+ * @param handler the function to process received JSON-RPC messages + * @return a Mono that completes when the connection is established + */ + @Override + public Mono connect(Function, Mono> handler) { + CompletableFuture future = new CompletableFuture<>(); + connectionFuture.set(future); + + webBuilder.build(this.sseEndpoint) + .execAsSseStream("GET") + .subscribe(new SimpleSubscriber() + .doOnNext(event -> { + if (isClosing) { + return; + } + + try { + if (ENDPOINT_EVENT_TYPE.equals(event.getEvent())) { + String endpoint = event.data(); + messageEndpoint.set(endpoint); + closeLatch.countDown(); + future.complete(null); + } else if (MESSAGE_EVENT_TYPE.equals(event.getEvent())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); + handler.apply(Mono.just(message)).subscribe(); + } else { + logger.error("Received unrecognized SSE event type: {}", event.getEvent()); + } + } catch (IOException e) { + logger.error("Error processing SSE event", e); + future.completeExceptionally(e); + } + }).doOnError(error -> { + if (!isClosing) { + logger.warn("SSE connection error", error); + future.completeExceptionally(error); + } + })); + + return Mono.fromFuture(future); + } + + /** + * Sends a JSON-RPC message to the server. + * + *

+ * This method waits for the message endpoint to be discovered before sending the + * message. The message is serialized to JSON and sent as an HTTP POST request. + * @param message the JSON-RPC message to send + * @return a Mono that completes when the message is sent + * @throws McpError if the message endpoint is not available or the wait times out + */ + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (isClosing) { + return Mono.empty(); + } + + try { + if (!closeLatch.await(10, TimeUnit.SECONDS)) { + return Mono.error(new McpError("Failed to wait for the message endpoint")); + } + } catch (InterruptedException e) { + return Mono.error(new McpError("Failed to wait for the message endpoint")); + } + + String endpoint = messageEndpoint.get(); + if (endpoint == null) { + return Mono.error(new McpError("No message endpoint available")); + } + + try { + String jsonText = this.objectMapper.writeValueAsString(message); + CompletableFuture future = webBuilder.build(endpoint) + .header("Content-Type", "application/json") + .bodyOfJson(jsonText) + .execAsync("POST"); + + return Mono.fromFuture(future.thenAccept(response -> { + if (response.code() != 200 && response.code() != 201 && response.code() != 202 + && response.code() != 206) { + logger.error("Error sending message: {}", response.code()); + } + })); + } catch (IOException e) { + if (!isClosing) { + return Mono.error(new RuntimeException("Failed to serialize message", e)); + } + return Mono.empty(); + } + } + + /** + * Gracefully closes the transport connection. + * + *

+ * Sets the closing flag and cancels any pending connection future. This prevents new + * messages from being sent and allows ongoing operations to complete. + * @return a Mono that completes when the closing process is initiated + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + CompletableFuture future = connectionFuture.get(); + if (future != null && !future.isDone()) { + future.cancel(true); + } + }); + } + + /** + * Unmarshals data to the specified type using the configured object mapper. + * @param data the data to unmarshal + * @param typeRef the type reference for the target type + * @param the target type + * @return the unmarshalled object + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } +} \ No newline at end of file diff --git a/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/server/transport/WebRxSseServerTransportProvider.java b/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/server/transport/WebRxSseServerTransportProvider.java index 2746445..6a2ab8d 100644 --- a/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/server/transport/WebRxSseServerTransportProvider.java +++ b/mcp-solon/mcp-solon-webrx/src/main/java/io/modelcontextprotocol/server/transport/WebRxSseServerTransportProvider.java @@ -52,7 +52,7 @@ *

* This implementation is thread-safe and can handle multiple concurrent client * connections. It uses {@link ConcurrentHashMap} for session management and Project - * Reactor's non-blocking APIs for message processing and delivery. + * Reactor's non-blocking APIs for message processing and delivery. and base JDK8 * * @author Christian Tzolov * @author Alexandros Pappas @@ -92,7 +92,7 @@ public class WebRxSseServerTransportProvider implements McpServerTransportProvid * Map of active client sessions, keyed by session ID. */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - + private final ConcurrentHashMap sessionTransports = new ConcurrentHashMap<>(); /** * Flag indicating if the transport is shutting down. */ @@ -117,9 +117,25 @@ public WebRxSseServerTransportProvider(ObjectMapper objectMapper, String message this.sseEndpoint = sseEndpoint; } + public void sendHeartbeat(){ + for (WebRxMcpSessionTransport transport : sessionTransports.values()) { + transport.sendHeartbeat(); + } + } + public void toHttpHandler(SolonApp app) { - app.get(this.sseEndpoint, this::handleSseConnection); - app.post(this.messageEndpoint, this::handleMessage); + if (app != null) { + app.get(this.sseEndpoint, this::handleSseConnection); + app.post(this.messageEndpoint, this::handleMessage); + } + } + + public String getSseEndpoint() { + return sseEndpoint; + } + + public String getMessageEndpoint() { + return messageEndpoint; } /** @@ -160,7 +176,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * errors if any session fails to receive the message */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Map params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); @@ -169,11 +185,11 @@ public Mono notifyClients(String method, Map params) { logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); return Flux.fromStream(sessions.values().stream()) - .flatMap(session -> session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), - e.getMessage())) - .onErrorComplete()) - .then(); + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); } // FIXME: This javadoc makes claims about using isClosing flag but it's not actually @@ -195,9 +211,9 @@ public Mono notifyClients(String method, Map params) { @Override public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) - .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) - .flatMap(McpServerSession::closeGracefully) - .then(); + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpServerSession::closeGracefully) + .then(); } /** @@ -206,7 +222,7 @@ public Mono closeGracefully() { * @param ctx The incoming server context * @return A Mono which emits a response with the SSE event stream */ - private void handleSseConnection(Context ctx) throws Throwable{ + public void handleSseConnection(Context ctx) throws Throwable{ if (isClosing) { ctx.status(503); ctx.output("Server is shutting down"); @@ -221,6 +237,7 @@ private void handleSseConnection(Context ctx) throws Throwable{ logger.debug("Created new SSE connection for session: {}", sessionId); sessions.put(sessionId, session); + sessionTransports.put(sessionId, sessionTransport); // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); @@ -230,6 +247,7 @@ private void handleSseConnection(Context ctx) throws Throwable{ sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); sessions.remove(sessionId); + sessionTransports.remove(sessionId); }); }); @@ -249,25 +267,25 @@ private void handleSseConnection(Context ctx) throws Throwable{ *

  • Returns appropriate HTTP responses based on processing results
  • *
  • Handles various error conditions with appropriate error responses
  • * - * @param request The incoming server request containing the JSON-RPC message + * @param ctx The incoming server request context containing the JSON-RPC message * @return A Mono emitting the response indicating the message processing result */ - private void handleMessage(Context request) throws Throwable { + public void handleMessage(Context ctx) throws Throwable { if (isClosing) { - request.status(503); - request.output("Server is shutting down"); + ctx.status(503); + ctx.output("Server is shutting down"); return; } - if (Utils.isEmpty(request.param("sessionId"))) { - request.status(404); - request.render(new McpError("Session ID missing in message endpoint")); + if (Utils.isEmpty(ctx.param("sessionId"))) { + ctx.status(404); + ctx.render(new McpError("Session ID missing in message endpoint")); return; } - McpServerSession session = sessions.get(request.param("sessionId")); + McpServerSession session = sessions.get(ctx.param("sessionId")); - String body = request.body(); + String body = ctx.body(); try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); @@ -276,18 +294,18 @@ private void handleMessage(Context request) throws Throwable { return Mono.just(new Entity()); }) .onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - // TODO: instead of signalling the error, just respond with 200 OK - // - the error is signalled on the SSE connection - // return ServerResponse.ok().build(); - return Mono.just(new Entity().status(500).body(new McpError(error.getMessage()))); - }); - - request.returnValue(mono); + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return Mono.just(new Entity().status(500).body(new McpError(error.getMessage()))); + }); + + ctx.returnValue(mono); } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); - request.status(400); - request.render(new McpError("Invalid message format")); + ctx.status(400); + ctx.render(new McpError("Invalid message format")); } } @@ -299,6 +317,10 @@ public WebRxMcpSessionTransport(FluxSink sink) { this.sink = sink; } + public void sendHeartbeat() { + sink.next(new SseEvent().comment("heartbeat")); + } + @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromSupplier(() -> { @@ -310,8 +332,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } }).doOnNext(jsonText -> { SseEvent event = new SseEvent() - .name(MESSAGE_EVENT_TYPE) - .data(jsonText); + .name(MESSAGE_EVENT_TYPE) + .data(jsonText); sink.next(event); }).doOnError(e -> { // TODO log with sessionid diff --git a/pom.xml b/pom.xml index 8465175..7e6282f 100644 --- a/pom.xml +++ b/pom.xml @@ -92,7 +92,7 @@ 4.2.0 5.0.1 2.40.1 - 3.1.2 + 3.2.0 From 36618cfabcb8d4e1db10280e60c49dabf1e877d6 Mon Sep 17 00:00:00 2001 From: noear Date: Thu, 17 Apr 2025 17:02:37 +0800 Subject: [PATCH 2/2] test: add mcp-solon-webrx test --- .../WebRxSseIntegrationTests.java | 500 ++++++++++++++++++ .../client/WebRxSseMcpAsyncClientTests.java | 54 ++ .../client/WebRxSseMcpSyncClientTests.java | 54 ++ .../WebRxSseClientTransportTests.java | 259 +++++++++ .../server/WebRxSseMcpAsyncServerTests.java | 55 ++ .../server/WebRxSseMcpSyncServerTests.java | 55 ++ .../server/transport/BlockingInputStream.java | 69 +++ 7 files changed, 1046 insertions(+) create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/WebRxSseIntegrationTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpAsyncClientTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpSyncClientTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransportTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpAsyncServerTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpSyncServerTests.java create mode 100644 mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/WebRxSseIntegrationTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/WebRxSseIntegrationTests.java new file mode 100644 index 0000000..b322022 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/WebRxSseIntegrationTests.java @@ -0,0 +1,500 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebRxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import lombok.var; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.noear.liquor.eval.Maps; +import org.noear.solon.Solon; +import org.noear.solon.boot.http.HttpServerConfigure; +import org.noear.solon.net.http.HttpUtils; +import org.noear.solon.net.http.HttpUtilsBuilder; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public class WebRxSseIntegrationTests { + + private static final int PORT = 8182; + + // private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + + private WebRxSseServerTransportProvider mcpServerTransportProvider; + + ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpServerTransportProvider = new WebRxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + Solon.start(WebRxSseIntegrationTests.class, new String[]{"server.port=" + PORT}, app -> { + mcpServerTransportProvider.toHttpHandler(app); + app.onEvent(HttpServerConfigure.class, event -> { + event.enableDebug(true); + }); + }); + + clientBulders.put("httpclient", + McpClient.sync(WebRxSseClientTransport.builder(new HttpUtilsBuilder().baseUri("http://localhost:" + PORT)) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build())); + + } + + @AfterEach + public void after() { + if (Solon.app() != null) { + Solon.stopBlock(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Maps.of())); + } catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBulders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.getMessages()).hasSize(1); + assertThat(request.getMessages().get(0).getContent()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(Arrays.asList(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(Arrays.asList()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.getRole()).isEqualTo(Role.USER); + assertThat(result.getContent()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.getContent()).getText()).isEqualTo("Test message"); + assertThat(result.getModel()).isEqualTo("MockModelName"); + assertThat(result.getStopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Maps.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = Arrays.asList(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).getUri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(Arrays.asList(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(Arrays.asList(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Maps.of())); + } catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testRootsNotifciationWithEmptyRootsList(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(Arrays.asList()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testRootsWithMultipleHandlers(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = Arrays.asList(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + List roots = Arrays.asList(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = "{\n" + + "\t\t\t\t\"$schema\": \"http://json-schema.org/draft-07/schema#\",\n" + + "\t\t\t\t\"type\": \"object\",\n" + + "\t\t\t\t\"properties\": {}\n" + + "\t\t\t}"; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(Arrays.asList(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = HttpUtils.http("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .get(); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().getTools()).contains(tool1.getTool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Maps.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(Arrays.asList(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = HttpUtils + .http("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .get(); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = HttpUtils + .http("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .get(); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().getTools()).contains(tool1.getTool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(Arrays.asList(tool1.getTool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(Arrays.asList(tool2.getTool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"httpclient"}) + void testInitialize(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpAsyncClientTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpAsyncClientTests.java new file mode 100644 index 0000000..7cb1656 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpAsyncClientTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.WebRxSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.noear.solon.net.http.HttpUtilsBuilder; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import java.time.Duration; + +/** + * Tests for the {@link McpAsyncClient} with {@link WebRxSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebRxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebRxSseClientTransport.builder(new HttpUtilsBuilder().baseUri(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpSyncClientTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpSyncClientTests.java new file mode 100644 index 0000000..feb2a6f --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/WebRxSseMcpSyncClientTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.WebRxSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.noear.solon.net.http.HttpUtilsBuilder; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import java.time.Duration; + +/** + * Tests for the {@link McpSyncClient} with {@link WebRxSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebRxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebRxSseClientTransport.builder(new HttpUtilsBuilder().baseUri(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(1); + } + +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransportTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransportTests.java new file mode 100644 index 0000000..43e7d96 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/client/transport/WebRxSseClientTransportTests.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.noear.liquor.eval.Maps; +import org.noear.solon.net.http.HttpUtilsBuilder; +import org.noear.solon.net.http.textstream.ServerSentEvent; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Tests for the {@link HttpClientSseClientTransport} class. + * + * @author Christian Tzolov + */ +@Timeout(15) +class WebRxSseClientTransportTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private TestWebRxSseClientTransport transport; + + // Test class to access protected methods + static class TestWebRxSseClientTransport extends WebRxSseClientTransport { + + private final AtomicInteger inboundMessageCount = new AtomicInteger(0); + + private Sinks.Many events = Sinks.many().unicast().onBackpressureBuffer(); + + public TestWebRxSseClientTransport(String baseUri) { + super(new HttpUtilsBuilder().baseUri(baseUri)); + } + + public int getInboundMessageCount() { + return inboundMessageCount.get(); + } + + public void simulateEndpointEvent(String jsonMessage) { + events.tryEmitNext(new ServerSentEvent(null,"endpoint",jsonMessage,null)); + inboundMessageCount.incrementAndGet(); + } + + public void simulateMessageEvent(String jsonMessage) { + events.tryEmitNext(new ServerSentEvent(null,"message",jsonMessage,null)); + inboundMessageCount.incrementAndGet(); + } + + } + + void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @BeforeEach + void setUp() { + startContainer(); + transport = new TestWebRxSseClientTransport(host); + transport.connect(Function.identity()).block(); + } + + @AfterEach + void afterEach() { + if (transport != null) { + assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + cleanup(); + } + + void cleanup() { + container.stop(); + } + + @Test + void testMessageProcessing() { + // Create a test message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Maps.of("key", "value")); + + // Simulate receiving the message + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"test-method\",\n" + + "\t\t\t\t \"id\": \"test-id\",\n" + + "\t\t\t\t \"params\": {\"key\": \"value\"}\n" + + "\t\t\t\t}"); + + // Subscribe to messages and verify + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testResponseMessageProcessing() { + // Simulate receiving a response message + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"id\": \"test-id\",\n" + + "\t\t\t\t \"result\": {\"status\": \"success\"}\n" + + "\t\t\t\t}"); + + // Create and send a request message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Maps.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testErrorMessageProcessing() { + // Simulate receiving an error message + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"id\": \"test-id\",\n" + + "\t\t\t\t \"error\": {\n" + + "\t\t\t\t \"code\": -32600,\n" + + "\t\t\t\t \"message\": \"Invalid Request\"\n" + + "\t\t\t\t }\n" + + "\t\t\t\t}"); + + // Create and send a request message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Maps.of("key", "value")); + + // Verify message handling + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testNotificationMessageProcessing() { + // Simulate receiving a notification message (no id) + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"update\",\n" + + "\t\t\t\t \"params\": {\"status\": \"processing\"}\n" + + "\t\t\t\t}"); + + // Verify the notification was processed + assertThat(transport.getInboundMessageCount()).isEqualTo(1); + } + + @Test + void testGracefulShutdown() { + // Test graceful shutdown + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + // Create a test message + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Maps.of("key", "value")); + + // Verify message is not processed after shutdown + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Message count should remain 0 after shutdown + assertThat(transport.getInboundMessageCount()).isEqualTo(0); + } + + @Test + void testRetryBehavior() { + // Create a client that simulates connection failures + HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host"); + + // Verify that the transport attempts to reconnect + StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); + + // Clean up + failingTransport.closeGracefully().block(); + } + + @Test + void testMultipleMessageProcessing() { + // Simulate receiving multiple messages in sequence + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"method1\",\n" + + "\t\t\t\t \"id\": \"id1\",\n" + + "\t\t\t\t \"params\": {\"key\": \"value1\"}\n" + + "\t\t\t\t}"); + + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"method2\",\n" + + "\t\t\t\t \"id\": \"id2\",\n" + + "\t\t\t\t \"params\": {\"key\": \"value2\"}\n" + + "\t\t\t\t}"); + + // Create and send corresponding messages + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + Maps.of("key", "value1")); + + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + Maps.of("key", "value2")); + + // Verify both messages are processed + StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); + + // Verify message count + assertThat(transport.getInboundMessageCount()).isEqualTo(2); + } + + @Test + void testMessageOrderPreservation() { + // Simulate receiving messages in a specific order + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"first\",\n" + + "\t\t\t\t \"id\": \"1\",\n" + + "\t\t\t\t \"params\": {\"sequence\": 1}\n" + + "\t\t\t\t}"); + + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"second\",\n" + + "\t\t\t\t \"id\": \"2\",\n" + + "\t\t\t\t \"params\": {\"sequence\": 2}\n" + + "\t\t\t\t}"); + + transport.simulateMessageEvent("{\n" + + "\t\t\t\t \"jsonrpc\": \"2.0\",\n" + + "\t\t\t\t \"method\": \"third\",\n" + + "\t\t\t\t \"id\": \"3\",\n" + + "\t\t\t\t \"params\": {\"sequence\": 3}\n" + + "\t\t\t\t}"); + + // Verify message count and order + assertThat(transport.getInboundMessageCount()).isEqualTo(3); + } +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpAsyncServerTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpAsyncServerTests.java new file mode 100644 index 0000000..6950e20 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpAsyncServerTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.noear.solon.Solon; +import org.noear.solon.boot.http.HttpServerConfigure; + +/** + * Tests for {@link McpSyncServer} using {@link WebRxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebRxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = 8181; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebRxSseServerTransportProvider transportProvider; + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebRxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + Solon.start(WebRxSseMcpAsyncServerTests.class, new String[]{"-server.port=" + PORT}, app -> { + transportProvider.toHttpHandler(app); + app.onEvent(HttpServerConfigure.class, event -> { + event.enableDebug(true); + }); + }); + + return transportProvider; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (Solon.app() != null) { + Solon.stopBlock(); + } + } +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpSyncServerTests.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpSyncServerTests.java new file mode 100644 index 0000000..46c04f6 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/WebRxSseMcpSyncServerTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.noear.solon.Solon; +import org.noear.solon.boot.http.HttpServerConfigure; + +/** + * Tests for {@link McpSyncServer} using {@link WebRxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebRxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebRxSseServerTransportProvider transportProvider; + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebRxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + Solon.start(WebRxSseMcpSyncServerTests.class, new String[]{"-server.port=" + PORT}, app -> { + transportProvider.toHttpHandler(app); + app.onEvent(HttpServerConfigure.class, event -> { + event.enableDebug(true); + }); + }); + + return transportProvider; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (Solon.app() != null) { + Solon.stopBlock(); + } + } +} diff --git a/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java new file mode 100644 index 0000000..0ab72a9 --- /dev/null +++ b/mcp-solon/mcp-solon-webrx/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java @@ -0,0 +1,69 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +public class BlockingInputStream extends InputStream { + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + + private volatile boolean completed = false; + + private volatile boolean closed = false; + + @Override + public int read() throws IOException { + if (closed) { + throw new IOException("Stream is closed"); + } + + try { + Integer value = queue.poll(); + if (value == null) { + if (completed) { + return -1; + } + value = queue.take(); // Blocks until data is available + if (value == null && completed) { + return -1; + } + } + return value; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Read interrupted", e); + } + } + + public void write(int b) { + if (!closed && !completed) { + queue.offer(b); + } + } + + public void write(byte[] data) { + if (!closed && !completed) { + for (byte b : data) { + queue.offer((int) b & 0xFF); + } + } + } + + public void complete() { + this.completed = true; + } + + @Override + public void close() { + this.closed = true; + this.completed = true; + this.queue.clear(); + } + +} \ No newline at end of file