diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 3edfefd7..8c40de08 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -997,6 +997,11 @@ func TestElicitationUnsupportedMethod(t *testing.T) { } } +func anyPtr[T any](v T) *any { + var a any = v + return &a +} + func TestElicitationSchemaValidation(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() @@ -1118,6 +1123,37 @@ func TestElicitationSchemaValidation(t *testing.T) { }, }, }, + { + name: "enum with enum schema", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + OneOf: []*jsonschema.Schema{ + { + Const: anyPtr(map[string]string{ + "const": "high", + "title": "High Priority", + }), + }, + { + Const: anyPtr(map[string]string{ + "const": "medium", + "title": "Medium Priority", + }), + }, + { + Const: anyPtr(map[string]string{ + "const": "low", + "title": "Low Priority", + }), + }, + }, + }, + }, + }, + }, } for _, tc := range validSchemas { @@ -1377,6 +1413,96 @@ func TestElicitationSchemaValidation(t *testing.T) { } } +func TestElicitContentValidation(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + // Set up a client that exercises valid/invalid elicitation: the returned + // Content from the handler ("potato") is validated against the schemas + // defined in the testcases below. + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept", Content: map[string]any{"test": "potato"}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + testcases := []struct { + name string + schema *jsonschema.Schema + expectedError string + }{ + { + name: "string enum with schema not matching content", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": { + Type: "string", + OneOf: []*jsonschema.Schema{ + { + Const: anyPtr(map[string]string{ + "const": "high", + "title": "High Priority", + }), + }, + }, + }, + }, + }, + expectedError: "oneOf: did not validate against any of", + }, + { + name: "string enum with schema matching content", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": { + Type: "string", + OneOf: []*jsonschema.Schema{ + { + Const: anyPtr(map[string]string{ + "const": "potato", + "title": "Potato Priority", + }), + }, + }, + }, + }, + }, + expectedError: "", + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if tc.expectedError != "" { + if err == nil { + t.Errorf("expected error but got no error: %s", tc.expectedError) + return + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Errorf("error message %q does not contain expected text %q", err.Error(), tc.expectedError) + } + } + }) + } +} + func TestElicitationProgressToken(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports()