diff --git a/pkg/codemode/codemode.go b/pkg/codemode/codemode.go index f817ee8b8..3fe732d75 100644 --- a/pkg/codemode/codemode.go +++ b/pkg/codemode/codemode.go @@ -28,7 +28,7 @@ Available tools/functions: ` -func Wrap(toolsets []tools.ToolSet) tools.ToolSet { +func Wrap(toolsets ...tools.ToolSet) tools.ToolSet { return &codeModeTool{ toolsets: toolsets, } diff --git a/pkg/codemode/codemode_test.go b/pkg/codemode/codemode_test.go index ee54db174..4143589fb 100644 --- a/pkg/codemode/codemode_test.go +++ b/pkg/codemode/codemode_test.go @@ -1,12 +1,15 @@ package codemode import ( + "context" "encoding/json" "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/tools" ) func TestCodeModeTool_Tools(t *testing.T) { @@ -74,11 +77,117 @@ func TestCodeModeTool_Instructions(t *testing.T) { } func TestCodeModeTool_StartStop(t *testing.T) { - tool := &codeModeTool{} + inner := &testToolSet{} + + tool := Wrap(inner) + + assert.Equal(t, 0, inner.start) + assert.Equal(t, 0, inner.stop) err := tool.Start(t.Context()) require.NoError(t, err) + assert.Equal(t, 1, inner.start) + assert.Equal(t, 0, inner.stop) err = tool.Stop(t.Context()) require.NoError(t, err) + assert.Equal(t, 1, inner.start) + assert.Equal(t, 1, inner.stop) } + +func TestCodeModeTool_CallHello(t *testing.T) { + tool := Wrap(&testToolSet{ + tools: []tools.Tool{{ + Name: "hello_world", + Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + return &tools.ToolCallResult{ + Output: "Hello, World!", + }, nil + }, + }}, + }) + + allTools, err := tool.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, allTools, 1) + + result, err := allTools[0].Handler(t.Context(), tools.ToolCall{ + Function: tools.FunctionCall{ + Arguments: `{"script":"return hello_world();"}`, + }, + }) + require.NoError(t, err) + + var scriptResult ScriptResult + err = json.Unmarshal([]byte(result.Output), &scriptResult) + require.NoError(t, err) + + require.Equal(t, "Hello, World!", scriptResult.Value) + require.Empty(t, scriptResult.StdErr) + require.Empty(t, scriptResult.StdOut) +} + +func TestCodeModeTool_CallEcho(t *testing.T) { + type EchoArgs struct { + Message string `json:"message" jsonschema:"Message to echo"` + } + + tool := Wrap(&testToolSet{ + tools: []tools.Tool{{ + Name: "echo", + Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { + return &tools.ToolCallResult{ + Output: "ECHO", + }, nil + }, + Parameters: tools.MustSchemaFor[EchoArgs](), + }}, + }) + + allTools, err := tool.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, allTools, 1) + + result, err := allTools[0].Handler(t.Context(), tools.ToolCall{ + Function: tools.FunctionCall{ + Arguments: `{"script":"return echo({'message':'ECHO'});"}`, + }, + }) + require.NoError(t, err) + + var scriptResult ScriptResult + err = json.Unmarshal([]byte(result.Output), &scriptResult) + require.NoError(t, err) + + require.Equal(t, "ECHO", scriptResult.Value) + require.Empty(t, scriptResult.StdErr) + require.Empty(t, scriptResult.StdOut) +} + +type testToolSet struct { + tools []tools.Tool + start int + stop int +} + +func (t *testToolSet) Tools(ctx context.Context) ([]tools.Tool, error) { + return t.tools, nil +} + +func (t *testToolSet) Instructions() string { + return "" +} + +func (t *testToolSet) Start(context.Context) error { + t.start++ + return nil +} + +func (t *testToolSet) Stop(context.Context) error { + t.stop++ + return nil +} + +func (t *testToolSet) SetElicitationHandler(tools.ElicitationHandler) {} + +func (t *testToolSet) SetOAuthSuccessHandler(func()) {} diff --git a/pkg/codemode/exec.go b/pkg/codemode/exec.go index 276abba7f..9bfde106c 100644 --- a/pkg/codemode/exec.go +++ b/pkg/codemode/exec.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "slices" "github.com/dop251/goja" @@ -66,11 +67,19 @@ func (c *codeModeTool) runJavascript(ctx context.Context, script string) (Script func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) { return func(args map[string]any) (string, error) { + var toolArgs struct { + Required []string `json:"required"` + } + + if err := tools.ConvertSchema(tool.Parameters, &toolArgs); err != nil { + return "", err + } + nonNilArgs := make(map[string]any) for k, v := range args { - // if slices.Contains(tool.Parameters.Required, k) || v != nil { - nonNilArgs[k] = v - // } + if slices.Contains(toolArgs.Required, k) || v != nil { + nonNilArgs[k] = v + } } arguments, err := json.Marshal(nonNilArgs) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 5f9973a1a..d7be85176 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -255,7 +255,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri // This allows the agent to call multiple tools in a single response. // It also allows to combine the results of multiple tools in a single response. return []tools.ToolSet{ - codemode.Wrap(t), + codemode.Wrap(t...), }, nil }