Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions pkg/tools/codemode/codemode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,32 @@ func TestCodeModeTool_Tools(t *testing.T) {
"value": {
"type": "string",
"description": "The value returned by the script"
},
"tool_calls": {
"type": "array",
"description": "The list of tool calls made during script execution, only included on failure",
"items": {
"type": "object",
"additionalProperties": false,
"required": ["name", "arguments"],
"properties": {
"name": {
"type": "string",
"description": "The name of the tool that was called"
},
"arguments": {
"description": "The arguments passed to the tool"
},
"result": {
"type": "string",
"description": "The raw response returned by the tool"
},
"error": {
"type": "string",
"description": "The error message, if the tool call failed"
}
}
}
}
},
"additionalProperties": false
Expand Down Expand Up @@ -182,3 +208,164 @@ func (t *testToolSet) Stop(context.Context) error {
t.stop++
return nil
}

// TestCodeModeTool_SuccessNoToolCalls verifies that successful execution does not include tool calls.
func TestCodeModeTool_SuccessNoToolCalls(t *testing.T) {
tool := Wrap(&testToolSet{
tools: []tools.Tool{
{
Name: "get_data",
Handler: builtin.NewHandler(func(ctx context.Context, args map[string]any) (*tools.ToolCallResult, error) {
return tools.ResultSuccess("data"), nil
}),
},
},
})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
require.Len(t, allTools, 1)

result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Arguments: `{"script":"return get_data();"}`,
},
})
require.NoError(t, err)

var scriptResult ScriptResult
err = json.Unmarshal([]byte(result.Output), &scriptResult)
require.NoError(t, err)

// Success case should not include tool calls
assert.Equal(t, "data", scriptResult.Value)
assert.Empty(t, scriptResult.ToolCalls, "successful execution should not include tool_calls")
}

// TestCodeModeTool_FailureIncludesToolCalls verifies that failed execution includes tool call history.
func TestCodeModeTool_FailureIncludesToolCalls(t *testing.T) {
tool := Wrap(&testToolSet{
tools: []tools.Tool{
{
Name: "first_tool",
Handler: builtin.NewHandler(func(ctx context.Context, args map[string]any) (*tools.ToolCallResult, error) {
return tools.ResultSuccess("first result"), nil
}),
},
{
Name: "second_tool",
Handler: builtin.NewHandler(func(ctx context.Context, args map[string]any) (*tools.ToolCallResult, error) {
return tools.ResultSuccess("second result"), nil
}),
},
},
})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
require.Len(t, allTools, 1)

// Script calls tools successfully but then throws a runtime error
result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Arguments: `{"script":"var a = first_tool(); var b = second_tool(); throw new Error('runtime error');"}`,
},
})
require.NoError(t, err)

var scriptResult ScriptResult
err = json.Unmarshal([]byte(result.Output), &scriptResult)
require.NoError(t, err)

// Failure case should include tool calls
assert.Contains(t, scriptResult.Value, "runtime error")
require.Len(t, scriptResult.ToolCalls, 2, "failed execution should include tool_calls")

// Verify first tool call
assert.Equal(t, "first_tool", scriptResult.ToolCalls[0].Name)
assert.Equal(t, "first result", scriptResult.ToolCalls[0].Result)
assert.Empty(t, scriptResult.ToolCalls[0].Error)

// Verify second tool call
assert.Equal(t, "second_tool", scriptResult.ToolCalls[1].Name)
assert.Equal(t, "second result", scriptResult.ToolCalls[1].Result)
assert.Empty(t, scriptResult.ToolCalls[1].Error)
}

// TestCodeModeTool_FailureIncludesToolError verifies that tool errors are captured in tool call history.
func TestCodeModeTool_FailureIncludesToolError(t *testing.T) {
tool := Wrap(&testToolSet{
tools: []tools.Tool{
{
Name: "failing_tool",
Handler: builtin.NewHandler(func(ctx context.Context, args map[string]any) (*tools.ToolCallResult, error) {
return nil, assert.AnError
}),
},
},
})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
require.Len(t, allTools, 1)

result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Arguments: `{"script":"return failing_tool();"}`,
},
})
require.NoError(t, err)

var scriptResult ScriptResult
err = json.Unmarshal([]byte(result.Output), &scriptResult)
require.NoError(t, err)

// Script fails due to tool error
assert.Contains(t, scriptResult.Value, "assert.AnError")
require.Len(t, scriptResult.ToolCalls, 1, "failed execution should include tool_calls")

// Verify the tool call recorded the error
assert.Equal(t, "failing_tool", scriptResult.ToolCalls[0].Name)
assert.Empty(t, scriptResult.ToolCalls[0].Result)
assert.Contains(t, scriptResult.ToolCalls[0].Error, "assert.AnError")
}

// TestCodeModeTool_FailureIncludesToolArguments verifies that tool arguments are captured.
func TestCodeModeTool_FailureIncludesToolArguments(t *testing.T) {
type TestArgs struct {
Value string `json:"value" jsonschema:"Test value"`
}

tool := Wrap(&testToolSet{
tools: []tools.Tool{
{
Name: "tool_with_args",
Handler: builtin.NewHandler(func(ctx context.Context, args map[string]any) (*tools.ToolCallResult, error) {
return tools.ResultSuccess("result"), nil
}),
Parameters: tools.MustSchemaFor[TestArgs](),
},
},
})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
require.Len(t, allTools, 1)

result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Arguments: `{"script":"tool_with_args({'value': 'test123'}); throw new Error('forced error');"}`,
},
})
require.NoError(t, err)

var scriptResult ScriptResult
err = json.Unmarshal([]byte(result.Output), &scriptResult)
require.NoError(t, err)

// Verify the tool call captured the arguments
require.Len(t, scriptResult.ToolCalls, 1)
assert.Equal(t, "tool_with_args", scriptResult.ToolCalls[0].Name)
assert.Equal(t, map[string]any{"value": "test123"}, scriptResult.ToolCalls[0].Arguments)
assert.Equal(t, "result", scriptResult.ToolCalls[0].Result)
}
59 changes: 51 additions & 8 deletions pkg/tools/codemode/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,32 @@ import (
)

// ScriptResult is the JSON-serialisable outcome of running a script.
//
// Value holds the script's return value, or the error text when execution
// fails. StdOut and StdErr hold the captured console output. ToolCalls is
// filled in only when the script fails and is omitted from the JSON when
// empty.
//
// NOTE(review): Value, StdOut and StdErr are listed twice below; presumably
// this view shows both the removed and the re-aligned lines of the diff.
// Confirm against the actual file.
type ScriptResult struct {
Value string `json:"value" jsonschema:"The value returned by the script"`
StdOut string `json:"stdout" jsonschema:"The standard output of the console"`
StdErr string `json:"stderr" jsonschema:"The standard error of the console"`
Value string `json:"value" jsonschema:"The value returned by the script"`
StdOut string `json:"stdout" jsonschema:"The standard output of the console"`
StdErr string `json:"stderr" jsonschema:"The standard error of the console"`
ToolCalls []ToolCallInfo `json:"tool_calls,omitempty" jsonschema:"The list of tool calls made during script execution, only included on failure"`
}

// ToolCallInfo describes a single tool invocation made by a script: which
// tool was called, the arguments it received, and either the output it
// produced or the error it failed with.
//
// Name and Arguments are always serialised. Result and Error are dropped
// from the JSON when empty.
type ToolCallInfo struct {
Name string `json:"name" jsonschema:"The name of the tool that was called"`
Arguments any `json:"arguments" jsonschema:"The arguments passed to the tool"`
Result string `json:"result,omitempty" jsonschema:"The raw response returned by the tool"`
Error string `json:"error,omitempty" jsonschema:"The error message, if the tool call failed"`
}

// toolCallTracker accumulates the tool calls made during one script
// execution so that they can be reported back if the script fails.
type toolCallTracker struct {
// calls holds the recorded invocations, appended in call order.
calls []ToolCallInfo
}

// record appends info to the end of the tracker's call history.
func (tr *toolCallTracker) record(info ToolCallInfo) {
	tr.calls = append(tr.calls, info)
}

func (c *codeModeTool) runJavascript(ctx context.Context, script string) (ScriptResult, error) {
vm := goja.New()
tracker := &toolCallTracker{}

// Inject console object to help the LLM debug its own code.
var (
Expand All @@ -36,7 +55,7 @@ func (c *codeModeTool) runJavascript(ctx context.Context, script string) (Script
}

for _, tool := range allTools {
_ = vm.Set(tool.Name, callTool(ctx, tool))
_ = vm.Set(tool.Name, callTool(ctx, tool, tracker))
}
}

Expand All @@ -46,10 +65,12 @@ func (c *codeModeTool) runJavascript(ctx context.Context, script string) (Script
// Run the script.
v, err := vm.RunString(script)
if err != nil {
// Script execution failed - include tool call history to help LLM understand what went wrong
return ScriptResult{
StdOut: stdOut.String(),
StdErr: stdErr.String(),
Value: err.Error(),
StdOut: stdOut.String(),
StdErr: stdErr.String(),
Value: err.Error(),
ToolCalls: tracker.calls,
}, nil
}

Expand All @@ -58,20 +79,26 @@ func (c *codeModeTool) runJavascript(ctx context.Context, script string) (Script
value = fmt.Sprintf("%v", result)
}

// Success case - don't include tool calls to avoid unnecessary overhead
return ScriptResult{
StdOut: stdOut.String(),
StdErr: stdErr.String(),
Value: value,
}, nil
}

func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) {
func callTool(ctx context.Context, tool tools.Tool, tracker *toolCallTracker) func(args map[string]any) (string, error) {
return func(args map[string]any) (string, error) {
var toolArgs struct {
Required []string `json:"required"`
}

if err := tools.ConvertSchema(tool.Parameters, &toolArgs); err != nil {
tracker.record(ToolCallInfo{
Name: tool.Name,
Arguments: args,
Error: err.Error(),
})
return "", err
}

Expand All @@ -84,6 +111,11 @@ func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (s

arguments, err := json.Marshal(nonNilArgs)
if err != nil {
tracker.record(ToolCallInfo{
Name: tool.Name,
Arguments: nonNilArgs,
Error: err.Error(),
})
return "", err
}

Expand All @@ -94,9 +126,20 @@ func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (s
},
})
if err != nil {
tracker.record(ToolCallInfo{
Name: tool.Name,
Arguments: nonNilArgs,
Error: err.Error(),
})
return "", err
}

tracker.record(ToolCallInfo{
Name: tool.Name,
Arguments: nonNilArgs,
Result: result.Output,
})

return result.Output, nil
}
}