From a2586354c40714b3dafed3b0e7fbf74492e15eb9 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 12 Aug 2025 22:15:19 +0000 Subject: [PATCH 1/2] mcp: add a test for streamable sampling during a tool call Add a test that attempts (and fails) to reproduce the bug reported in issue #285. For #285 --- mcp/streamable_test.go | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index fd1dc3e4..23b775bc 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) { + // Test that we can make sampling requests during tool handling. + // + // Try this on both the request context and a background context, so + // that messages may be delivered on either the POST or GET connection. + for _, ctx := range map[string]context.Context{ + "request context": ctx, + "background context": context.Background(), + } { + res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) + if err != nil { + return nil, err + } + if g, w := res.Model, "aModel"; g != w { + return nil, fmt.Errorf("got %q, want %q", g, w) + } + } + return &CallToolResultFor[any]{}, nil + }) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. @@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) { Endpoint: httpServer.URL, HTTPClient: httpClient, } - client := NewClient(testImpl, nil) + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } + + // 6. Run the "sampling" tool and verify that the streamable server can + // call tools. + result, err := session.CallTool(ctx, &CallToolParams{ + Name: "sample", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatal(err) + } + if result.IsError { + t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text) + } }) } } From a52fe65788388d7c39d87fe5456decf2bae631dc Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 13 Aug 2025 14:08:21 +0000 Subject: [PATCH 2/2] Update test function signature --- mcp/streamable_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 23b775bc..25dd224e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,7 +40,7 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so