diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 04c0e0b4..72d98b21 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -30,11 +30,11 @@ func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.Cal }, nil, nil } -func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func PromptHi(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []*mcp.PromptMessage{ - {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + params.Arguments["name"]}}, + {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}}, }, }, nil } diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index c1052c25..1449076e 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -126,6 +126,6 @@ func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { } } -func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func testPromptHandler(context.Context, *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { panic("not implemented") } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9c578392..0fce4c39 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -46,11 +46,11 @@ var codeReviewPrompt = &Prompt{ Arguments: []*PromptArgument{{Name: "Code", Required: true}}, } -func codReviewPromptHandler(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func codReviewPromptHandler(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { return &GetPromptResult{ Description: "Code review prompt", Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + req.Params.Arguments["Code"]}}, }, }, nil } @@ -102,7 +102,7 @@ func TestEndToEnd(t *testing.T) { return nil, nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) - s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { + s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *GetPromptRequest) (*GetPromptResult, error) { return nil, errTestFailure }) s.AddResource(resource1, readHandler) diff --git a/mcp/prompt.go b/mcp/prompt.go index 0ecf5528..62f38a36 100644 --- a/mcp/prompt.go +++ b/mcp/prompt.go @@ -9,7 +9,7 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) +type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error) type serverPrompt struct { prompt *Prompt diff --git a/mcp/server.go b/mcp/server.go index 740b2b9d..e7a1de5c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -446,7 +446,7 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req.Session, req.Params) + return prompt.handler(ctx, req) } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) {