From 22a76decd8f1b71abbe995ace9180118e87ed619 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 19 Nov 2025 21:44:38 +0000 Subject: [PATCH] mcp: support url mode elicitation Add initial support for URL mode elicitation, by introducing the new types, as well as both server and client-side validation. A follow-up CL will add support for ErrURLElicitationRequired. For #623 --- docs/rough_edges.md | 7 ++ examples/server/everything/main.go | 48 +++++--- internal/docs/rough_edges.src.md | 7 ++ mcp/client.go | 87 +++++++++---- mcp/elicitation_test.go | 190 +++++++++++++++++++++++++++++ mcp/mcp_test.go | 18 +-- mcp/protocol.go | 44 ++++++- mcp/requests.go | 23 ++-- mcp/server.go | 31 +++++ mcp/shared.go | 3 + 10 files changed, 398 insertions(+), 60 deletions(-) create mode 100644 mcp/elicitation_test.go diff --git a/docs/rough_edges.md b/docs/rough_edges.md index a99e00a8..e7f0a17b 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -16,3 +16,10 @@ v2. landing after the SDK was at v1, we missed an opportunity to panic on invalid tool names. Instead, we have to simply produce an error log. In v2, we should panic. + +- Inconsistent naming. + - `ResourceUpdatedNotificationsParams` should probably have just been + `ResourceUpdatedParams`, as we don't include the word 'notification' in + other notification param types. + - Similarly, `ProgressNotificationParams` should probably have been + `ProgressParams`. diff --git a/examples/server/everything/main.go b/examples/server/everything/main.go index f255ab49..9bd6a7cf 100644 --- a/examples/server/everything/main.go +++ b/examples/server/everything/main.go @@ -7,6 +7,7 @@ package main import ( "context" + _ "embed" "encoding/base64" "flag" "fmt" @@ -17,6 +18,7 @@ import ( "os" "runtime" "strings" + "sync/atomic" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -51,12 +53,8 @@ func main() { CompletionHandler: complete, // support completions by setting this handler } - // Optionally add an icon to the server implementation. - icons, err := iconToBase64DataURL("./mcp.png") - if err != nil { - log.Fatalf("failed to read icon: %v", err) - } - + // Add an icon to the server implementation. + icons := mcpIcons() server := mcp.NewServer(&mcp.Implementation{Name: "everything", WebsiteURL: "https://example.com", Icons: icons}, opts) // Add tools that exercise different features of the protocol. @@ -67,7 +65,8 @@ func main() { mcp.AddTool(server, &mcp.Tool{Name: "ping"}, pingingTool) // performs a ping mcp.AddTool(server, &mcp.Tool{Name: "log"}, loggingTool) // performs a log mcp.AddTool(server, &mcp.Tool{Name: "sample"}, samplingTool) // performs sampling - mcp.AddTool(server, &mcp.Tool{Name: "elicit"}, elicitingTool) // performs elicitation + mcp.AddTool(server, &mcp.Tool{Name: "elicit (form)"}, elicitFormTool) // performs form elicitation + mcp.AddTool(server, &mcp.Tool{Name: "elicit (url)"}, elicitURLTool) // performs url elicitation mcp.AddTool(server, &mcp.Tool{Name: "roots"}, rootsTool) // lists roots // Add a basic prompt. @@ -235,7 +234,7 @@ func samplingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.Ca }, nil, nil } -func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { +func elicitFormTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { res, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ Message: "provide a random string", RequestedSchema: &jsonschema.Schema{ @@ -255,6 +254,26 @@ func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.C }, nil, nil } +var elicitations atomic.Int32 + +func elicitURLTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + elicitID := fmt.Sprintf("%d", elicitations.Add(1)) + _, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ + Message: "submit a string", + URL: fmt.Sprintf("http://localhost:6062?id=%s", elicitID), + ElicitationID: elicitID, + }) + if err != nil { + return nil, nil, fmt.Errorf("eliciting failed: %v", err) + } + // TODO: actually wait for the elicitation form to be submitted. + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "(elicitation pending)"}, + }, + }, nil, nil +} + func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { return &mcp.CompleteResult{ Completion: mcp.CompletionResultDetails{ @@ -264,15 +283,14 @@ func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResul }, nil } -func iconToBase64DataURL(path string) ([]mcp.Icon, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } +//go:embed mcp.png +var mcpIconData []byte + +func mcpIcons() []mcp.Icon { return []mcp.Icon{{ - Source: "data:image/png;base64," + base64.StdEncoding.EncodeToString(data), + Source: "data:image/png;base64," + base64.StdEncoding.EncodeToString(mcpIconData), MIMEType: "image/png", Sizes: []string{"48x48"}, Theme: "light", // or "dark" or empty - }}, nil + }} } diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index d8c2c996..b5dbadc1 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -15,3 +15,10 @@ v2. landing after the SDK was at v1, we missed an opportunity to panic on invalid tool names. Instead, we have to simply produce an error log. In v2, we should panic. + +- Inconsistent naming. + - `ResourceUpdatedNotificationsParams` should probably have just been + `ResourceUpdatedParams`, as we don't include the word 'notification' in + other notification param types. + - Similarly, `ProgressNotificationParams` should probably have been + `ProgressParams`. diff --git a/mcp/client.go b/mcp/client.go index b19539c1..cd784015 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -66,6 +66,11 @@ type ClientOptions struct { // Setting ElicitationHandler to a non-nil value causes the client to // advertise the elicitation capability. ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) + // ElicitationModes specifies the elicitation modes supported by the client. + // If ElicitationHandler is set and ElicitationModes is empty, it defaults to ["form"]. + ElicitationModes []string + // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. + ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) // Handlers for notifications from the server. ToolListChangedHandler func(context.Context, *ToolListChangedRequest) PromptListChangedHandler func(context.Context, *PromptListChangedRequest) @@ -123,6 +128,15 @@ func (c *Client) capabilities() *ClientCapabilities { } if c.opts.ElicitationHandler != nil { caps.Elicitation = &ElicitationCapabilities{} + modes := c.opts.ElicitationModes + if len(modes) == 0 || slices.Contains(modes, "form") { + // Technically, the empty ElicitationCapabilities value is equivalent to + // {"form":{}} for backward compatibility, but we explicitly set the form + // capability. + caps.Elicitation.Form = &FormElicitationCapabilities{} + } else if slices.Contains(modes, "url") { + caps.Elicitation.URL = &URLElicitationCapabilities{} + } } return caps } @@ -297,40 +311,55 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) ( func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { if c.opts.ElicitationHandler == nil { - // TODO: wrap or annotate this error? Pick a standard code? - return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support elicitation") - } - - // Validate that the requested schema only contains top-level properties without nesting - schema, err := validateElicitSchema(req.Params.RequestedSchema) - if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, err.Error()) + return nil, jsonrpc2.NewError(codeInvalidParams, "client does not support elicitation") } - res, err := c.opts.ElicitationHandler(ctx, req) - if err != nil { - return nil, err + // Validate the elicitation parameters based on the mode. + mode := req.Params.Mode + if mode == "" { + mode = "form" } - // Validate elicitation result content against requested schema - if schema != nil && res.Content != nil { - // TODO: is this the correct behavior if validation fails? - // It isn't the *server's* params that are invalid, so why would we return - // this code to the server? - resolved, err := schema.Resolve(nil) - if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + switch mode { + case "form": + if req.Params.URL != "" { + return nil, jsonrpc2.NewError(codeInvalidParams, "URL must not be set for form elicitation") } - if err := resolved.Validate(res.Content); err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + schema, err := validateElicitSchema(req.Params.RequestedSchema) + if err != nil { + return nil, jsonrpc2.NewError(codeInvalidParams, err.Error()) } - err = resolved.ApplyDefaults(&res.Content) + res, err := c.opts.ElicitationHandler(ctx, req) if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)) + return nil, err + } + // Validate elicitation result content against requested schema. + if schema != nil && res.Content != nil { + resolved, err := schema.Resolve(nil) + if err != nil { + return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + } + if err := resolved.Validate(res.Content); err != nil { + return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + } + err = resolved.ApplyDefaults(&res.Content) + if err != nil { + return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)) + } } + return res, nil + case "url": + if req.Params.RequestedSchema != nil { + return nil, jsonrpc2.NewError(codeInvalidParams, "requestedSchema must not be set for URL elicitation") + } + if req.Params.URL == "" { + return nil, jsonrpc2.NewError(codeInvalidParams, "URL must be set for URL elicitation") + } + // No schema validation for URL mode, just pass through to handler. + return c.opts.ElicitationHandler(ctx, req) + default: + return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("unsupported elicitation mode: %q", mode)) } - - return res, nil } // validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. @@ -528,6 +557,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), + notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -692,6 +722,13 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa return nil, nil } +func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) { + if h := c.opts.ElicitationCompleteHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + // NotifyProgress sends a progress notification from the client to the server // associated with this session. // This can be used if the client is performing a long-running task that was diff --git a/mcp/elicitation_test.go b/mcp/elicitation_test.go new file mode 100644 index 00000000..de5f65d3 --- /dev/null +++ b/mcp/elicitation_test.go @@ -0,0 +1,190 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/google/jsonschema-go/jsonschema" +) + +// TODO: migrate other elicitation tests here. + +func TestElicitationURLMode(t *testing.T) { + ctx := context.Background() + clientErr := errors.New("client failed to elicit") + + testCases := []struct { + name string + handler func(context.Context, *ElicitRequest) (*ElicitResult, error) + params *ElicitParams + wantResultAction string + wantErrMsg string + wantErrCode int64 + }{ + { + name: "success", + handler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + params: &ElicitParams{ + Mode: "url", + Message: "Please provide information via URL", + URL: "https://example.com/form", + }, + wantResultAction: "accept", + }, + { + name: "decline", + handler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "decline"}, nil + }, + params: &ElicitParams{ + Mode: "url", + Message: "Please provide information via URL", + URL: "https://example.com/form", + }, + wantResultAction: "decline", + }, + { + name: "client error", + handler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return nil, clientErr + }, + params: &ElicitParams{ + Mode: "url", + Message: "This should fail", + URL: "https://example.com/form", + }, + wantErrMsg: clientErr.Error(), + }, + { + name: "missing url", + handler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + params: &ElicitParams{ + Mode: "url", + Message: "URL is missing", + }, + wantErrMsg: "URL must be set for URL elicitation", + wantErrCode: codeInvalidParams, + }, + { + name: "schema not allowed", + handler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + params: &ElicitParams{ + Mode: "url", + Message: "Schema is not allowed", + URL: "https://example.com/form", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + }, + }, + wantErrMsg: "requestedSchema must not be set for URL elicitation", + wantErrCode: codeInvalidParams, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ct, st := NewInMemoryTransports() + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationModes: []string{"url"}, + ElicitationHandler: tc.handler, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + result, err := ss.Elicit(ctx, tc.params) + + if tc.wantErrMsg != "" { + if err == nil || !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("Elicit(...): got error %v, want containing %q", err, tc.wantErrMsg) + } + if tc.wantErrCode != 0 { + if code := errorCode(err); code != tc.wantErrCode { + t.Errorf("Elicit(...): got error code %d, want %d", code, tc.wantErrCode) + } + } + } else { + if err != nil { + t.Fatalf("Elicit failed: %v", err) + } + if result.Action != tc.wantResultAction { + t.Errorf("Elicit(...): got action %q, want %q", result.Action, tc.wantResultAction) + } + } + }) + } +} + +func TestElicitationCompleteNotification(t *testing.T) { + ctx := context.Background() + + var elicitationCompleteCh = make(chan *ElicitationCompleteParams, 1) + + c := NewClient(testImpl, &ClientOptions{ + ElicitationModes: []string{"url"}, + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + ElicitationCompleteHandler: func(_ context.Context, req *ElicitationCompleteNotificationRequest) { + elicitationCompleteCh <- req.Params + }, + }) + + cs, ss, cleanup := basicClientServerConnection(t, c, nil, nil) + _ = cs // Dummy usage to avoid "declared and not used" error. + defer cleanup() + + // 1. Server initiates a URL elicitation + elicitID := "testElicitationID-123" + resp, err := ss.Elicit(ctx, &ElicitParams{ + Mode: "url", + Message: "Please complete this form: ", + URL: "https://example.com/form?id=" + elicitID, + ElicitationID: elicitID, + }) + if err != nil { + t.Fatalf("Elicit failed: %v", err) + } + if resp.Action != "accept" { + t.Fatalf("Elicit action is %q, want %q", resp.Action, "accept") + } + + // 2. Server sends elicitation complete notification (simulating out-of-band completion) + err = handleNotify(ctx, notificationElicitationComplete, newServerRequest(ss, &ElicitationCompleteParams{ + ElicitationID: elicitID, + })) + if err != nil { + t.Fatalf("failed to send elicitation complete notification: %v", err) + } + + // 3. Client should receive the notification + select { + case gotParams := <-elicitationCompleteCh: + if gotParams.ElicitationID != elicitID { + t.Errorf("elicitationComplete notification ID mismatch: got %q, want %q", gotParams.ElicitationID, elicitID) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for elicitation complete notification") + } +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index ffaa43dd..c2c949e8 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -62,9 +62,10 @@ func TestEndToEnd(t *testing.T) { // Channels to check if notification callbacks happened. notificationChans := map[string]chan int{} - for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} { + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe", "elicitation_complete"} { notificationChans[name] = make(chan int, 1) } + waitForNotification := func(t *testing.T, name string) { t.Helper() select { @@ -150,6 +151,9 @@ func TestEndToEnd(t *testing.T) { ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) { notificationChans["resource_updated"] <- 0 }, + ElicitationCompleteHandler: func(_ context.Context, req *ElicitationCompleteNotificationRequest) { + notificationChans["elicitation_complete"] <- 0 + }, } c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -985,8 +989,8 @@ func TestElicitationUnsupportedMethod(t *testing.T) { if err == nil { t.Error("expected error when ElicitationHandler is not provided, got nil") } - if code := errorCode(err); code != codeUnsupportedMethod { - t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, codeUnsupportedMethod) + if code := errorCode(err); code != -1 { + t.Errorf("got error code %d, want -1", code) } if !strings.Contains(err.Error(), "does not support elicitation") { t.Errorf("error should mention unsupported elicitation, got: %v", err) @@ -1414,7 +1418,7 @@ func TestElicitationProgressToken(t *testing.T) { func TestElicitationCapabilityDeclaration(t *testing.T) { ctx := context.Background() - t.Run("with_handler", func(t *testing.T) { + t.Run("with handler", func(t *testing.T) { ct, st := NewInMemoryTransports() // Client with ElicitationHandler should declare capability @@ -1451,7 +1455,7 @@ func TestElicitationCapabilityDeclaration(t *testing.T) { } }) - t.Run("without_handler", func(t *testing.T) { + t.Run("without handler", func(t *testing.T) { ct, st := NewInMemoryTransports() // Client without ElicitationHandler should not declare capability @@ -1483,8 +1487,8 @@ func TestElicitationCapabilityDeclaration(t *testing.T) { if err == nil { t.Error("expected UnsupportedMethod error when no capability declared") } - if code := errorCode(err); code != codeUnsupportedMethod { - t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, codeUnsupportedMethod) + if code := errorCode(err); code != -1 { + t.Errorf("got error code %d, want -1", code) } }) } diff --git a/mcp/protocol.go b/mcp/protocol.go index 54ccdd9e..8a88f8e2 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -879,7 +879,20 @@ func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t type SamplingCapabilities struct{} // ElicitationCapabilities describes the capabilities for elicitation. -type ElicitationCapabilities struct{} +// +// If neither Form nor URL is set, the 'Form' capabilitiy is assumed. +type ElicitationCapabilities struct { + Form *FormElicitationCapabilities + URL *URLElicitationCapabilities +} + +// FormElicitationCapabilities describes capabilities for form elicitation. +type FormElicitationCapabilities struct { +} + +// URLElicitationCapabilities describes capabilities for url elicitation. +type URLElicitationCapabilities struct { +} // Describes a message issued to or received from an LLM API. type SamplingMessage struct { @@ -1067,6 +1080,10 @@ type ElicitParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // The mode of elicitation to use. + // + // If unset, will be inferred from the other fields. + Mode string `json:"mode"` // The message to present to the user. Message string `json:"message"` // A JSON schema object defining the requested elicitation schema. @@ -1080,7 +1097,17 @@ type ElicitParams struct { // map[string]any). // // Only top-level properties are allowed, without nesting. - RequestedSchema any `json:"requestedSchema"` + // + // This is only used for "form" elicitation. + RequestedSchema any `json:"requestedSchema,omitempty"` + // The URL to present to the user. + // + // This is only used for "url" elicitation. + URL string `json:"url,omitempty"` + // The ID of the elicitation. + // + // This is only used for "url" elicitation. + ElicitationID string `json:"elicitationId,omitempty"` } func (x *ElicitParams) isParams() {} @@ -1105,6 +1132,18 @@ type ElicitResult struct { func (*ElicitResult) isResult() {} +// ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. +type ElicitationCompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The ID of the elicitation that has completed. This must correspond to the + // elicitationId from the original elicitation/create request. + ElicitationID string `json:"elicitationId"` +} + +func (*ElicitationCompleteParams) isParams() {} + // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. type Implementation struct { @@ -1171,6 +1210,7 @@ const ( methodComplete = "completion/complete" methodCreateMessage = "sampling/createMessage" methodElicit = "elicitation/create" + notificationElicitationComplete = "notifications/elicitation/complete" methodGetPrompt = "prompts/get" methodInitialize = "initialize" notificationInitialized = "notifications/initialized" diff --git a/mcp/requests.go b/mcp/requests.go index 82b700f5..f64d6fb6 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -23,15 +23,16 @@ type ( ) type ( - CreateMessageRequest = ClientRequest[*CreateMessageParams] - ElicitRequest = ClientRequest[*ElicitParams] - initializedClientRequest = ClientRequest[*InitializedParams] - InitializeRequest = ClientRequest[*InitializeParams] - ListRootsRequest = ClientRequest[*ListRootsParams] - LoggingMessageRequest = ClientRequest[*LoggingMessageParams] - ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] - PromptListChangedRequest = ClientRequest[*PromptListChangedParams] - ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] - ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] - ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] + initializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams] ) diff --git a/mcp/server.go b/mcp/server.go index d1d6a1d3..254c2d5e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" + "errors" "fmt" "iter" "log/slog" @@ -1018,6 +1019,36 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli if err := ss.checkInitialized(methodElicit); err != nil { return nil, err } + if params == nil { + return nil, fmt.Errorf("%w: params cannot be nil", jsonrpc2.ErrInvalidParams) + } + + if params.Mode == "" { + params2 := *params + if params.URL != "" || params.ElicitationID != "" { + params2.Mode = "url" + } else { + params2.Mode = "form" + } + params = ¶ms2 + } + + if iparams := ss.InitializeParams(); iparams == nil || iparams.Capabilities == nil || iparams.Capabilities.Elicitation == nil { + return nil, fmt.Errorf("client does not support elicitation") + } + caps := ss.InitializeParams().Capabilities.Elicitation + switch params.Mode { + case "form": + if caps.Form == nil && caps.URL != nil { + // Note: if both 'Form' and 'URL' are nil, we assume the client supports + // form elicitation for backward compatibility. + return nil, errors.New(`client does not support "form" elicitation`) + } + case "url": + if caps.URL == nil { + return nil, errors.New(`client does not support "url" elicitation`) + } + } res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) if err != nil { diff --git a/mcp/shared.go b/mcp/shared.go index b710c952..3fac40b2 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -335,6 +335,9 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont const ( codeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. + // + // TODO(rfindley): this code is wrong, and we should fix it to be + // consistent with other SDKs. codeUnsupportedMethod = -31001 // The error code for invalid parameters codeInvalidParams = -32602