Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,11 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar
prompt, ok := s.prompts.get(req.Params.Name)
s.mu.Unlock()
if !ok {
// TODO: surface the error code over the wire, instead of flattening it into the string.
return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, req.Params.Name)
// Return a proper JSON-RPC error with the correct error code
return nil, &jsonrpc2.WireError{
Code: CodeInvalidParams,
Message: fmt.Sprintf("unknown prompt %q", req.Params.Name),
}
}
return prompt.handler(ctx, req.Session, req.Params)
}
Expand All @@ -340,7 +343,10 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam
st, ok := s.tools.get(req.Params.Name)
s.mu.Unlock()
if !ok {
return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name)
return nil, &jsonrpc2.WireError{
Code: CodeInvalidParams,
Message: fmt.Sprintf("unknown tool %q", req.Params.Name),
}
}
return st.handler(ctx, req)
}
Expand Down
2 changes: 2 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ const (
CodeResourceNotFound = -32002
// The error code if the method exists and was called properly, but the peer does not support it.
CodeUnsupportedMethod = -31001
// The error code for invalid parameters
CodeInvalidParams = -32602
)

// notifySessions calls Notify on all the sessions.
Expand Down
12 changes: 10 additions & 2 deletions mcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"reflect"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

// A ToolHandler handles a call to tools/call.
Expand Down Expand Up @@ -69,9 +70,16 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
Session: req.Session,
Params: params,
})
// TODO(rfindley): investigate why server errors are embedded in this strange way,
// rather than returned as jsonrpc2 server errors.
// Handle server errors appropriately:
// - If the handler returns a structured error (like jsonrpc2.WireError), return it directly
// - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true
// - This allows tools to distinguish between protocol errors and tool execution errors
if err != nil {
// Check if this is already a structured JSON-RPC error
if wireErr, ok := err.(*jsonrpc2.WireError); ok {
return nil, wireErr
}
// For regular errors, embed them in the tool result as per MCP spec
return &CallToolResult{
Content: []Content{&TextContent{Text: err.Error()}},
IsError: true,
Expand Down
92 changes: 92 additions & 0 deletions mcp/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

// testToolHandler is used for type inference in TestNewServerTool.
Expand Down Expand Up @@ -132,3 +136,91 @@ func TestUnmarshalSchema(t *testing.T) {

}
}

func TestToolErrorHandling(t *testing.T) {
// Construct server and add both tools at the top level
server := NewServer(testImpl, nil)

// Create a tool that returns a structured error
structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) {
return nil, &jsonrpc2.WireError{
Code: CodeInvalidParams,
Message: "internal server error",
}
}

// Create a tool that returns a regular error
regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) {
return nil, fmt.Errorf("tool execution failed")
}

AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler)
AddTool(server, &Tool{Name: "regular_error_tool", Description: "returns regular error"}, regularErrorHandler)

// Connect server and client once
ct, st := NewInMemoryTransports()
_, err := server.Connect(context.Background(), st, nil)
if err != nil {
t.Fatal(err)
}

client := NewClient(testImpl, nil)
cs, err := client.Connect(context.Background(), ct, nil)
if err != nil {
t.Fatal(err)
}
defer cs.Close()

// Test that structured JSON-RPC errors are returned directly
t.Run("structured_error", func(t *testing.T) {
// Call the tool
_, err = cs.CallTool(context.Background(), &CallToolParams{
Name: "error_tool",
Arguments: map[string]any{},
})

// Should get the structured error directly
if err == nil {
t.Fatal("expected error, got nil")
}

var wireErr *jsonrpc2.WireError
if !errors.As(err, &wireErr) {
t.Fatalf("expected WireError, got %[1]T: %[1]v", err)
}

if wireErr.Code != CodeInvalidParams {
t.Errorf("expected error code %d, got %d", CodeInvalidParams, wireErr.Code)
}
})

// Test that regular errors are embedded in tool results
t.Run("regular_error", func(t *testing.T) {
// Call the tool
result, err := cs.CallTool(context.Background(), &CallToolParams{
Name: "regular_error_tool",
Arguments: map[string]any{},
})

// Should not get an error at the protocol level
if err != nil {
t.Fatalf("unexpected protocol error: %v", err)
}

// Should get a result with IsError=true
if !result.IsError {
t.Error("expected IsError=true, got false")
}

// Should have error message in content
if len(result.Content) == 0 {
t.Error("expected error content, got empty")
}

if textContent, ok := result.Content[0].(*TextContent); !ok {
t.Error("expected TextContent")
} else if !strings.Contains(textContent.Text, "tool execution failed") {
t.Errorf("expected error message in content, got: %s", textContent.Text)
}
})
}