Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/rough_edges.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,10 @@ v2.
landing after the SDK was at v1, we missed an opportunity to panic on invalid
tool names. Instead, we have to simply produce an error log. In v2, we should
panic.

- Inconsistent naming.
- `ResourceUpdatedNotificationsParams` should probably have just been
`ResourceUpdatedParams`, as we don't include the word 'notification' in
other notification param types.
- Similarly, `ProgressNotificationParams` should probably have been
`ProgressParams`.
48 changes: 33 additions & 15 deletions examples/server/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package main

import (
"context"
_ "embed"
"encoding/base64"
"flag"
"fmt"
Expand All @@ -17,6 +18,7 @@ import (
"os"
"runtime"
"strings"
"sync/atomic"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -51,12 +53,8 @@ func main() {
CompletionHandler: complete, // support completions by setting this handler
}

// Optionally add an icon to the server implementation.
icons, err := iconToBase64DataURL("./mcp.png")
if err != nil {
log.Fatalf("failed to read icon: %v", err)
}

// Add an icon to the server implementation.
icons := mcpIcons()
server := mcp.NewServer(&mcp.Implementation{Name: "everything", WebsiteURL: "https://example.com", Icons: icons}, opts)

// Add tools that exercise different features of the protocol.
Expand All @@ -67,7 +65,8 @@ func main() {
mcp.AddTool(server, &mcp.Tool{Name: "ping"}, pingingTool) // performs a ping
mcp.AddTool(server, &mcp.Tool{Name: "log"}, loggingTool) // performs a log
mcp.AddTool(server, &mcp.Tool{Name: "sample"}, samplingTool) // performs sampling
mcp.AddTool(server, &mcp.Tool{Name: "elicit"}, elicitingTool) // performs elicitation
mcp.AddTool(server, &mcp.Tool{Name: "elicit (form)"}, elicitFormTool) // performs form elicitation
mcp.AddTool(server, &mcp.Tool{Name: "elicit (url)"}, elicitURLTool) // performs url elicitation
mcp.AddTool(server, &mcp.Tool{Name: "roots"}, rootsTool) // lists roots

// Add a basic prompt.
Expand Down Expand Up @@ -235,7 +234,7 @@ func samplingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.Ca
}, nil, nil
}

func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) {
func elicitFormTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) {
res, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
Message: "provide a random string",
RequestedSchema: &jsonschema.Schema{
Expand All @@ -255,6 +254,26 @@ func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.C
}, nil, nil
}

var elicitations atomic.Int32

func elicitURLTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) {
elicitID := fmt.Sprintf("%d", elicitations.Add(1))
_, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
Message: "submit a string",
URL: fmt.Sprintf("http://localhost:6062?id=%s", elicitID),
ElicitationID: elicitID,
})
if err != nil {
return nil, nil, fmt.Errorf("eliciting failed: %v", err)
}
// TODO: actually wait for the elicitation form to be submitted.
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "(elicitation pending)"},
},
}, nil, nil
}

func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) {
return &mcp.CompleteResult{
Completion: mcp.CompletionResultDetails{
Expand All @@ -264,15 +283,14 @@ func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResul
}, nil
}

func iconToBase64DataURL(path string) ([]mcp.Icon, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
//go:embed mcp.png
var mcpIconData []byte

func mcpIcons() []mcp.Icon {
return []mcp.Icon{{
Source: "data:image/png;base64," + base64.StdEncoding.EncodeToString(data),
Source: "data:image/png;base64," + base64.StdEncoding.EncodeToString(mcpIconData),
MIMEType: "image/png",
Sizes: []string{"48x48"},
Theme: "light", // or "dark" or empty
}}, nil
}}
}
7 changes: 7 additions & 0 deletions internal/docs/rough_edges.src.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ v2.
landing after the SDK was at v1, we missed an opportunity to panic on invalid
tool names. Instead, we have to simply produce an error log. In v2, we should
panic.

- Inconsistent naming.
- `ResourceUpdatedNotificationsParams` should probably have just been
`ResourceUpdatedParams`, as we don't include the word 'notification' in
other notification param types.
- Similarly, `ProgressNotificationParams` should probably have been
`ProgressParams`.
87 changes: 62 additions & 25 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ type ClientOptions struct {
// Setting ElicitationHandler to a non-nil value causes the client to
// advertise the elicitation capability.
ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error)
// ElicitationModes specifies the elicitation modes supported by the client.
// If ElicitationHandler is set and ElicitationModes is empty, it defaults to ["form"].
ElicitationModes []string
// ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete.
ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest)
// Handlers for notifications from the server.
ToolListChangedHandler func(context.Context, *ToolListChangedRequest)
PromptListChangedHandler func(context.Context, *PromptListChangedRequest)
Expand Down Expand Up @@ -123,6 +128,15 @@ func (c *Client) capabilities() *ClientCapabilities {
}
if c.opts.ElicitationHandler != nil {
caps.Elicitation = &ElicitationCapabilities{}
modes := c.opts.ElicitationModes
if len(modes) == 0 || slices.Contains(modes, "form") {
// Technically, the empty ElicitationCapabilities value is equivalent to
// {"form":{}} for backward compatibility, but we explicitly set the form
// capability.
caps.Elicitation.Form = &FormElicitationCapabilities{}
} else if slices.Contains(modes, "url") {
caps.Elicitation.URL = &URLElicitationCapabilities{}
}
}
return caps
}
Expand Down Expand Up @@ -297,40 +311,55 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (

func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
if c.opts.ElicitationHandler == nil {
// TODO: wrap or annotate this error? Pick a standard code?
return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support elicitation")
}

// Validate that the requested schema only contains top-level properties without nesting
schema, err := validateElicitSchema(req.Params.RequestedSchema)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, err.Error())
return nil, jsonrpc2.NewError(codeInvalidParams, "client does not support elicitation")
}

res, err := c.opts.ElicitationHandler(ctx, req)
if err != nil {
return nil, err
// Validate the elicitation parameters based on the mode.
mode := req.Params.Mode
if mode == "" {
mode = "form"
}

// Validate elicitation result content against requested schema
if schema != nil && res.Content != nil {
// TODO: is this the correct behavior if validation fails?
// It isn't the *server's* params that are invalid, so why would we return
// this code to the server?
resolved, err := schema.Resolve(nil)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
switch mode {
case "form":
if req.Params.URL != "" {
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must not be set for form elicitation")
}
if err := resolved.Validate(res.Content); err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
schema, err := validateElicitSchema(req.Params.RequestedSchema)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, err.Error())
}
err = resolved.ApplyDefaults(&res.Content)
res, err := c.opts.ElicitationHandler(ctx, req)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
return nil, err
}
// Validate elicitation result content against requested schema.
if schema != nil && res.Content != nil {
resolved, err := schema.Resolve(nil)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
}
if err := resolved.Validate(res.Content); err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
}
err = resolved.ApplyDefaults(&res.Content)
if err != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
}
}
return res, nil
case "url":
if req.Params.RequestedSchema != nil {
return nil, jsonrpc2.NewError(codeInvalidParams, "requestedSchema must not be set for URL elicitation")
}
if req.Params.URL == "" {
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must be set for URL elicitation")
}
// No schema validation for URL mode, just pass through to handler.
return c.opts.ElicitationHandler(ctx, req)
default:
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("unsupported elicitation mode: %q", mode))
}

return res, nil
}

// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements.
Expand Down Expand Up @@ -528,6 +557,7 @@ var clientMethodInfos = map[string]methodInfo{
notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK),
notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification),
notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification),
notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK),
}

func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
Expand Down Expand Up @@ -692,6 +722,13 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
return nil, nil
}

func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) {
if h := c.opts.ElicitationCompleteHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

// NotifyProgress sends a progress notification from the client to the server
// associated with this session.
// This can be used if the client is performing a long-running task that was
Expand Down
Loading
Loading