diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b4ea460..925571c 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -22,6 +22,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/mcp" "github.com/google/uuid" @@ -1352,6 +1353,152 @@ func TestStableRequestEncoding(t *testing.T) { } } +// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is +// correctly disabled based on the tool_choice parameter in the request. +// See https://github.com/coder/aibridge/issues/2 +func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { + t.Parallel() + + var ( + toolChoiceAuto = string(constant.ValueOf[constant.Auto]()) + toolChoiceAny = string(constant.ValueOf[constant.Any]()) + toolChoiceNone = string(constant.ValueOf[constant.None]()) + toolChoiceTool = string(constant.ValueOf[constant.Tool]()) + ) + + cases := []struct { + name string + toolChoice any // nil, or map with "type" key. + expectDisableParallel bool + expectToolChoiceTypeInRequest string + }{ + { + name: "no tool_choice defined defaults to auto", + toolChoice: nil, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "tool_choice auto", + toolChoice: map[string]any{"type": toolChoiceAuto}, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "tool_choice any", + toolChoice: map[string]any{"type": toolChoiceAny}, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + { + name: "tool_choice tool", + toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceTool, + }, + { + name: "tool_choice none", + toolChoice: map[string]any{"type": toolChoiceNone}, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceNone, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Configure the bridge. + mcpMgr := mcp.NewServerProxyManager(nil, testTracer) + require.NoError(t, mcpMgr.Init(ctx)) + + arc := txtar.Parse(antSimple) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + // Prepare request body with tool_choice set. + var reqJSON map[string]any + require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqJSON)) + if tc.toolChoice != nil { + reqJSON["tool_choice"] = tc.toolChoice + } + reqBody, err := json.Marshal(reqJSON) + require.NoError(t, err) + + var receivedRequest map[string]any + + // Create a mock server that captures the request body sent upstream. + mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the raw request body. + raw, err := io.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(t, err) + + require.NoError(t, json.Unmarshal(raw, &receivedRequest)) + + // Return a valid API response. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + mockSrv.Start() + t.Cleanup(mockSrv.Close) + + recorder := &mockRecorderClient{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(mockSrv.URL, apiKey), nil)} + bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) + require.NoError(t, err) + + // Invoke request to mocked API via aibridge. + bridgeSrv := httptest.NewUnstartedServer(bridge) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) + + req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Verify tool_choice in the upstream request. + require.NotNil(t, receivedRequest) + toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) + require.True(t, ok, "expected tool_choice in upstream request") + + // Verify the type matches expectation. + assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"]) + + // Verify name is preserved for tool_choice=tool. + if tc.expectToolChoiceTypeInRequest == toolChoiceTool { + assert.Equal(t, "some_tool", toolChoice["name"]) + } + + // Verify disable_parallel_tool_use based on expectations. + // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use + disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) + + if tc.expectDisableParallel { + require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") + assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") + } else { + assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") + } + }) + } +} + func TestEnvironmentDoNotLeak(t *testing.T) { // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 9fdb4c7..a9b8802 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -12,6 +12,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/coder/aibridge/mcp" @@ -96,11 +97,29 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() { } // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. - i.req.ToolChoice = anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{ - Type: "auto", - DisableParallelToolUse: anthropic.Bool(true), - }, + // https://github.com/coder/aibridge/issues/2 + toolChoiceType := i.req.ToolChoice.GetType() + var toolChoiceTypeStr string + if toolChoiceType != nil { + toolChoiceTypeStr = *toolChoiceType + } + + switch toolChoiceTypeStr { + // If no tool_choice was defined, assume auto. + // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. + case "", string(constant.ValueOf[constant.Auto]()): + // We only set OfAuto if no tool_choice was provided (the default). + // "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it. + if i.req.ToolChoice.OfAuto == nil { + i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{} + } + i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true) + case string(constant.ValueOf[constant.Any]()): + i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true) + case string(constant.ValueOf[constant.Tool]()): + i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true) + case string(constant.ValueOf[constant.None]()): + // No-op; if tool_choice=none then tools are not used at all. } }