diff --git a/mcp/server.go b/mcp/server.go index ed4ec720..df59110e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -315,8 +315,11 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() if !ok { - // TODO: surface the error code over the wire, instead of flattening it into the string. - return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, req.Params.Name) + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } } return prompt.handler(ctx, req.Session, req.Params) } @@ -340,7 +343,10 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } } return st.handler(ctx, req) } diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..eb13e44f 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -319,6 +319,8 @@ const ( CodeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. CodeUnsupportedMethod = -31001 + // The error code for invalid parameters + CodeInvalidParams = -32602 ) // notifySessions calls Notify on all the sessions. diff --git a/mcp/tool.go b/mcp/tool.go index 15f17e11..6d6c7204 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -12,6 +12,7 @@ import ( "reflect" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) // A ToolHandler handles a call to tools/call. @@ -69,9 +70,16 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool Session: req.Session, Params: params, }) - // TODO(rfindley): investigate why server errors are embedded in this strange way, - // rather than returned as jsonrpc2 server errors. + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc2.WireError); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec return &CallToolResult{ Content: []Content{&TextContent{Text: err.Error()}}, IsError: true, diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 609536cc..4c73ec63 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -7,12 +7,16 @@ package mcp import ( "context" "encoding/json" + "errors" + "fmt" "reflect" + "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) // testToolHandler is used for type inference in TestNewServerTool. @@ -132,3 +136,91 @@ func TestUnmarshalSchema(t *testing.T) { } } + +func TestToolErrorHandling(t *testing.T) { + // Construct server and add both tools at the top level + server := NewServer(testImpl, nil) + + // Create a tool that returns a structured error + structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: "internal server error", + } + } + + // Create a tool that returns a regular error + regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { + return nil, fmt.Errorf("tool execution failed") + } + + AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler) + AddTool(server, &Tool{Name: "regular_error_tool", Description: "returns regular error"}, regularErrorHandler) + + // Connect server and client once + ct, st := NewInMemoryTransports() + _, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatal(err) + } + + client := NewClient(testImpl, nil) + cs, err := client.Connect(context.Background(), ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that structured JSON-RPC errors are returned directly + t.Run("structured_error", func(t *testing.T) { + // Call the tool + _, err = cs.CallTool(context.Background(), &CallToolParams{ + Name: "error_tool", + Arguments: map[string]any{}, + }) + + // Should get the structured error directly + if err == nil { + t.Fatal("expected error, got nil") + } + + var wireErr *jsonrpc2.WireError + if !errors.As(err, &wireErr) { + t.Fatalf("expected WireError, got %[1]T: %[1]v", err) + } + + if wireErr.Code != CodeInvalidParams { + t.Errorf("expected error code %d, got %d", CodeInvalidParams, wireErr.Code) + } + }) + + // Test that regular errors are embedded in tool results + t.Run("regular_error", func(t *testing.T) { + // Call the tool + result, err := cs.CallTool(context.Background(), &CallToolParams{ + Name: "regular_error_tool", + Arguments: map[string]any{}, + }) + + // Should not get an error at the protocol level + if err != nil { + t.Fatalf("unexpected protocol error: %v", err) + } + + // Should get a result with IsError=true + if !result.IsError { + t.Error("expected IsError=true, got false") + } + + // Should have error message in content + if len(result.Content) == 0 { + t.Error("expected error content, got empty") + } + + if textContent, ok := result.Content[0].(*TextContent); !ok { + t.Error("expected TextContent") + } else if !strings.Contains(textContent.Text, "tool execution failed") { + t.Errorf("expected error message in content, got: %s", textContent.Text) + } + }) +}