diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index bc3f53467..0ba7ab3b8 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -166,9 +166,15 @@ private void handle(McpSchema.JSONRPCMessage message) { else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); handleIncomingRequest(request).onErrorResume(error -> { + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); return Mono.just(errorResponse); }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { logger.warn("Issue sending response to the client, ", t); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 86912b4bf..3de06f503 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -6,11 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.function.Function; import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.json.TypeRef; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,7 +18,6 @@ import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, @@ -39,35 +37,6 @@ class McpClientSessionTests { private static final String ECHO_METHOD = "echo"; - private McpClientSession session; - - private MockMcpClientTransport transport; - - @BeforeEach - void setUp() { - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params)))); - } - - @AfterEach - void tearDown() { - if (session != null) { - session.close(); - } - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The requestTimeout can not be null"); - - assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("transport can not be null"); - } - TypeRef responseType = new TypeRef<>() { }; @@ -76,6 +45,11 @@ void testSendRequest() { String testParam = "test parameter"; String responseData = "test response"; + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + // Create a Mono that will emit the response after the request is sent Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); // Verify response handling @@ -92,10 +66,17 @@ void testSendRequest() { assertThat(request.params()).isEqualTo(testParam); assertThat(response).isEqualTo(responseData); }).verifyComplete(); + + session.close(); } @Test void testSendRequestWithError() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); // Verify error handling @@ -107,20 +88,34 @@ void testSendRequestWithError() { transport.simulateIncomingMessage( new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); }).expectError(McpError.class).verify(); + + session.close(); } @Test void testRequestTimeout() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); // Verify timeout StepVerifier.create(responseMono) .expectError(java.util.concurrent.TimeoutException.class) .verify(TIMEOUT.plusSeconds(1)); + + session.close(); } @Test void testSendNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + Map params = Map.of("key", "value"); Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); @@ -132,6 +127,8 @@ void testSendNotification() { assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); assertThat(notification.params()).isEqualTo(params); }).verifyComplete(); + + session.close(); } @Test @@ -139,8 +136,8 @@ void testRequestHandling() { String echoMessage = "Hello MCP!"; Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -153,15 +150,18 @@ void testRequestHandling() { McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; assertThat(response.result()).isEqualTo(echoMessage); assertThat(response.error()).isNull(); + + session.close(); } @Test void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))), + Function.identity()); // Simulate incoming notification from the server Map notificationParams = Map.of("status", "ready"); @@ -173,10 +173,18 @@ void testNotificationHandling() { // Verify handler was called assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); + + session.close(); } @Test void testUnknownMethodHandling() { + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + // Simulate incoming request for unknown method McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", "test-id", null); @@ -188,10 +196,117 @@ void testUnknownMethodHandling() { McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; assertThat(response.error()).isNotNull(); assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + + session.close(); + } + + @Test + void testRequestHandlerThrowsMcpErrorWithJsonRpcError() { + // Setup: Create a request handler that throws McpError with custom error code and + // data + String testMethod = "test.customError"; + Map errorData = Map.of("customField", "customValue"); + McpClientSession.RequestHandler failingHandler = params -> Mono + .error(McpError.builder(123).message("Custom error message").data(errorData).build()); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain the custom error from McpError + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(123); + assertThat(response.error().message()).isEqualTo("Custom error message"); + assertThat(response.error().data()).isEqualTo(errorData); + + session.close(); + } + + @Test + void testRequestHandlerThrowsGenericException() { + // Setup: Create a request handler that throws a generic RuntimeException + String testMethod = "test.genericError"; + RuntimeException exception = new RuntimeException("Something went wrong"); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(exception); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with aggregated exception + // messages in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Something went wrong"); + // Verify data field contains aggregated exception messages + assertThat(response.error().data()).isNotNull(); + assertThat(response.error().data().toString()).contains("RuntimeException"); + assertThat(response.error().data().toString()).contains("Something went wrong"); + + session.close(); + } + + @Test + void testRequestHandlerThrowsExceptionWithCause() { + // Setup: Create a request handler that throws an exception with a cause chain + String testMethod = "test.chainedError"; + RuntimeException rootCause = new IllegalArgumentException("Root cause message"); + RuntimeException middleCause = new IllegalStateException("Middle cause message", rootCause); + RuntimeException topException = new RuntimeException("Top level message", middleCause); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(topException); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with full exception chain + // in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Top level message"); + // Verify data field contains the full exception chain + String dataString = response.error().data().toString(); + assertThat(dataString).contains("RuntimeException"); + assertThat(dataString).contains("Top level message"); + assertThat(dataString).contains("IllegalStateException"); + assertThat(dataString).contains("Middle cause message"); + assertThat(dataString).contains("IllegalArgumentException"); + assertThat(dataString).contains("Root cause message"); + + session.close(); } @Test void testGracefulShutdown() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + StepVerifier.create(session.closeGracefully()).verifyComplete(); }