diff --git a/pkg/inventory/server_tool.go b/pkg/inventory/server_tool.go index 752a4c2bd0..d359089b85 100644 --- a/pkg/inventory/server_tool.go +++ b/pkg/inventory/server_tool.go @@ -3,6 +3,7 @@ package inventory import ( "context" "encoding/json" + "fmt" "github.com/github/github-mcp-server/pkg/octicons" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -133,7 +134,12 @@ func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, hand return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var arguments In if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("invalid arguments: %s", err)}, + }, + IsError: true, + }, nil } resp, _, err := typedHandler(ctx, req, arguments) return resp, err @@ -157,7 +163,12 @@ func NewServerToolWithContextHandler[In any, Out any](tool mcp.Tool, toolset Too return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var arguments In if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("invalid arguments: %s", err)}, + }, + IsError: true, + }, nil } resp, _, err := handler(ctx, req, arguments) return resp, err diff --git a/pkg/inventory/server_tool_test.go b/pkg/inventory/server_tool_test.go new file mode 100644 index 0000000000..a84f90c301 --- /dev/null +++ b/pkg/inventory/server_tool_test.go @@ -0,0 +1,118 @@ +package inventory + +import ( + "context" + "encoding/json" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewServerTool_InvalidArguments_ReturnsIsError(t *testing.T) { + type expectedArgs struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + } + + tool := NewServerTool( + mcp.Tool{Name: "test_tool"}, + testToolsetMetadata("test"), + func(deps any) mcp.ToolHandlerFor[expectedArgs, *mcp.CallToolResult] { + return func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, *mcp.CallToolResult, error) { + t.Fatal("handler should not be called with invalid arguments") + return nil, nil, nil + } + }, + ) + + handler := tool.HandlerFunc(nil) + + badArgs, _ := json.Marshal(map[string]any{"owner": 12345, "repo": true}) + result, err := handler(context.Background(), &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "test_tool", + Arguments: badArgs, + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + assert.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "invalid arguments") +} + +func TestNewServerToolWithContextHandler_InvalidArguments_ReturnsIsError(t *testing.T) { + type expectedArgs struct { + Query string `json:"query"` + Limit int `json:"limit"` + } + + tool := NewServerToolWithContextHandler( + mcp.Tool{Name: "test_context_tool"}, + testToolsetMetadata("test"), + func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, any, error) { + t.Fatal("handler should not be called with invalid arguments") + return nil, nil, nil + }, + ) + + handler := tool.HandlerFunc(nil) + + result, err := handler(context.Background(), &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "test_context_tool", + Arguments: json.RawMessage(`{not valid json`), + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + assert.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "invalid arguments") +} + +func TestNewServerTool_ValidArguments_Succeeds(t *testing.T) { + type expectedArgs struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + } + + tool := NewServerTool( + mcp.Tool{Name: "test_tool"}, + testToolsetMetadata("test"), + func(deps any) mcp.ToolHandlerFor[expectedArgs, *mcp.CallToolResult] { + return func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, *mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "success: " + args.Owner + "/" + args.Repo}, + }, + }, nil, nil + } + }, + ) + + handler := tool.HandlerFunc(nil) + + goodArgs, _ := json.Marshal(map[string]any{"owner": "octocat", "repo": "hello-world"}) + result, err := handler(context.Background(), &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "test_tool", + Arguments: goodArgs, + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + assert.Equal(t, "success: octocat/hello-world", textContent.Text) +}