diff --git a/mcp/shared_test.go b/mcp/shared_test.go index f319d80e..0aea1947 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -7,6 +7,7 @@ package mcp import ( "context" "encoding/json" + "fmt" "strings" "testing" ) @@ -88,3 +89,146 @@ func TestToolValidate(t *testing.T) { }) } } + +// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. +// This addresses a vulnerability where missing or null parameters could crash the server. +func TestNilParamsHandling(t *testing.T) { + // Define test types for clarity + type TestArgs struct { + Name string `json:"name"` + Value int `json:"value"` + } + type TestParams = *CallToolParamsFor[TestArgs] + type TestResult = *CallToolResultFor[string] + + // Simple test handler + testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) { + result := "processed: " + params.Arguments.Name + return &CallToolResultFor[string]{StructuredContent: result}, nil + } + + methodInfo := newMethodInfo(testHandler, missingParamsOK) + + // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully + mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { + t.Helper() + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unmarshalParams panicked: %v", r) + } + }() + + params, err := methodInfo.unmarshalParams(rawMsg) + if err != nil { + t.Fatalf("unmarshalParams failed: %v", err) + } + + if expectNil { + if params != nil { + t.Fatalf("Expected nil params, got %v", params) + } + return params + } + + if params == nil { + t.Fatal("unmarshalParams returned unexpected nil") + } + + // Verify the result can be used safely + typedParams := params.(TestParams) + _ = typedParams.Name + _ = typedParams.Arguments.Name + _ = typedParams.Arguments.Value + + return params + } + + // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil + t.Run("missing_params", func(t *testing.T) { + mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag + }) + + t.Run("explicit_null", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag + }) + + t.Run("empty_object", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params + }) + + t.Run("valid_params", func(t *testing.T) { + rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) + params := mustNotPanic(t, rawMsg, false) + + // For valid params, also verify the values are parsed correctly + typedParams := params.(TestParams) + if typedParams.Name != "test" { + t.Errorf("Expected name 'test', got %q", typedParams.Name) + } + if typedParams.Arguments.Name != "hello" { + t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) + } + if typedParams.Arguments.Value != 42 { + t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) + } + }) +} + +// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix +func TestNilParamsEdgeCases(t *testing.T) { + type TestArgs struct { + Name string `json:"name"` + Value int `json:"value"` + } + type TestParams = *CallToolParamsFor[TestArgs] + + testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) { + return &CallToolResultFor[string]{StructuredContent: "test"}, nil + } + + methodInfo := newMethodInfo(testHandler, missingParamsOK) + + // These should fail normally, not be treated as nil params + invalidCases := []json.RawMessage{ + json.RawMessage(""), // empty string - should error + json.RawMessage("[]"), // array - should error + json.RawMessage(`"null"`), // string "null" - should error + json.RawMessage("0"), // number - should error + json.RawMessage("false"), // boolean - should error + } + + for i, rawMsg := range invalidCases { + t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { + params, err := methodInfo.unmarshalParams(rawMsg) + if err == nil && params == nil { + t.Error("Should not return nil params without error") + } + }) + } + + // Test that methods without missingParamsOK flag properly reject nil params + t.Run("reject_when_params_required", func(t *testing.T) { + methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag + + testCases := []struct { + name string + params json.RawMessage + }{ + {"nil_params", nil}, + {"null_params", json.RawMessage(`null`)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := methodInfoStrict.unmarshalParams(tc.params) + if err == nil { + t.Error("Expected error for required params, got nil") + } + if !strings.Contains(err.Error(), "missing required \"params\"") { + t.Errorf("Expected 'missing required params' error, got: %v", err) + } + }) + } + }) +}