Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pkg/teamloader/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package teamloader

import (
"context"
"log/slog"
"slices"

"github.com/docker/cagent/pkg/tools"
)

func WithToolsFilter(inner tools.ToolSet, toolNames ...string) tools.ToolSet {
if len(toolNames) == 0 {
return inner
}

return &filterTools{
ToolSet: inner,
toolNames: toolNames,
}
}

type filterTools struct {
tools.ToolSet
toolNames []string
}

func (f *filterTools) Tools(ctx context.Context) ([]tools.Tool, error) {
allTools, err := f.ToolSet.Tools(ctx)
if err != nil {
return nil, err
}

var filtered []tools.Tool
for _, tool := range allTools {
if !slices.Contains(f.toolNames, tool.Name) {
slog.Debug("Filtering out tool", "tool", tool.Name)
continue
}

filtered = append(filtered, tool)
}

return filtered, nil
}
121 changes: 121 additions & 0 deletions pkg/teamloader/filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package teamloader

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/tools"
)

type mockToolSet struct {
tools.ToolSet
toolsFunc func(ctx context.Context) ([]tools.Tool, error)
}

func (m *mockToolSet) Tools(ctx context.Context) ([]tools.Tool, error) {
if m.toolsFunc != nil {
return m.toolsFunc(ctx)
}
return nil, nil
}

func TestWithToolsFilter_NilToolNames(t *testing.T) {
inner := &mockToolSet{}

wrapped := WithToolsFilter(inner)

assert.Same(t, inner, wrapped)
}

func TestWithToolsFilter_EmptyNames(t *testing.T) {
inner := &mockToolSet{}

wrapped := WithToolsFilter(inner, []string{}...)

assert.Same(t, inner, wrapped)
}

func TestWithToolsFilter_PickOne(t *testing.T) {
inner := &mockToolSet{
toolsFunc: func(context.Context) ([]tools.Tool, error) {
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}, {Name: "tool3"}}, nil
},
}

wrapped := WithToolsFilter(inner, "tool2")

result, err := wrapped.Tools(t.Context())
require.NoError(t, err)
require.Len(t, result, 1)
assert.Equal(t, "tool2", result[0].Name)
}

func TestWithToolsFilter_PickAll(t *testing.T) {
inner := &mockToolSet{
toolsFunc: func(context.Context) ([]tools.Tool, error) {
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}, {Name: "tool3"}}, nil
},
}

wrapped := WithToolsFilter(inner, "tool1", "tool2", "tool3")

result, err := wrapped.Tools(t.Context())
require.NoError(t, err)

require.Len(t, result, 3)
assert.Equal(t, "tool1", result[0].Name)
assert.Equal(t, "tool2", result[1].Name)
assert.Equal(t, "tool3", result[2].Name)
}

func TestWithToolsFilter_NoMatch(t *testing.T) {
inner := &mockToolSet{
toolsFunc: func(context.Context) ([]tools.Tool, error) {
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil
},
}

wrapped := WithToolsFilter(inner, "tool3", "tool4")

result, err := wrapped.Tools(t.Context())
require.NoError(t, err)
assert.Empty(t, result)
}

func TestWithToolsFilter_ErrorFromInner(t *testing.T) {
expectedErr := errors.New("mock error")
inner := &mockToolSet{
toolsFunc: func(context.Context) ([]tools.Tool, error) {
return nil, expectedErr
},
}

wrapped := WithToolsFilter(inner, "tool1")

result, err := wrapped.Tools(t.Context())
assert.Nil(t, result)
assert.ErrorIs(t, err, expectedErr)
}

func TestWithToolsFilter_CaseSensitive(t *testing.T) {
inner := &mockToolSet{
toolsFunc: func(ctx context.Context) ([]tools.Tool, error) {
return []tools.Tool{
{Name: "Tool1"},
{Name: "tool1"},
{Name: "TOOL1"},
}, nil
},
}

wrapped := WithToolsFilter(inner, "tool1")

result, err := wrapped.Tools(t.Context())
require.NoError(t, err)
require.Len(t, result, 1)
assert.Equal(t, "tool1", result[0].Name)
}
15 changes: 9 additions & 6 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
return nil, err
}

t = append(t, WithInstructions(tool, a.Instruction))
wrapped := WithToolsFilter(tool, toolset.Tools...)
wrapped = WithInstructions(wrapped, a.Instruction)

t = append(t, wrapped)
}

if !a.CodeModeTools && !runtimeConfig.GlobalCodeMode {
Expand Down Expand Up @@ -312,7 +315,7 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e
}
}

opts := []builtin.FileSystemOpt{builtin.WithAllowedTools(toolset.Tools)}
var opts []builtin.FileSystemOpt
if len(toolset.PostEdit) > 0 {
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
for i, pe := range toolset.PostEdit {
Expand Down Expand Up @@ -343,13 +346,13 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e

// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
if serverSpec.Type == "remote" {
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, toolset.Tools, runtimeConfig.RedirectURI)
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil
}

return mcp.NewGatewayToolset(mcpServerName, toolset.Config, toolset.Tools, envProvider), nil
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)

case toolset.Type == "mcp" && toolset.Command != "":
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env, toolset.Tools), nil
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil

case toolset.Type == "mcp" && toolset.Remote.URL != "":
headers := map[string]string{}
Expand All @@ -362,7 +365,7 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e
headers[k] = expanded
}

return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, toolset.Tools, runtimeConfig.RedirectURI)
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil

default:
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)
Expand Down
25 changes: 2 additions & 23 deletions pkg/tools/builtin/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"os/exec"
"path/filepath"
"regexp"
"slices"
"strings"
"time"

Expand All @@ -27,20 +26,13 @@ type FilesystemTool struct {
tools.ElicitationTool

allowedDirectories []string
allowedTools []string
postEditCommands []PostEditConfig
}

var _ tools.ToolSet = (*FilesystemTool)(nil)

type FileSystemOpt func(*FilesystemTool)

func WithAllowedTools(allowedTools []string) FileSystemOpt {
return func(t *FilesystemTool) {
t.allowedTools = allowedTools
}
}

func WithPostEditCommands(postEditCommands []PostEditConfig) FileSystemOpt {
return func(t *FilesystemTool) {
t.postEditCommands = postEditCommands
Expand Down Expand Up @@ -150,7 +142,7 @@ type EditFileArgs struct {
}

func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
tls := []tools.Tool{
return []tools.Tool{
{
Name: "create_directory",
Category: "filesystem",
Expand Down Expand Up @@ -337,20 +329,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
Title: "Write File",
},
},
}

if len(t.allowedTools) == 0 {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok for now but it's going to bite us one day, the system prompt for a toolset should know what tools it has so that it can tweak itself.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day, yes

return tls, nil
}

var allowedTools []tools.Tool
for _, tool := range tls {
if slices.Contains(t.allowedTools, tool.Name) {
allowedTools = append(allowedTools, tool)
}
}

return allowedTools, nil
}, nil
}

// executePostEditCommands executes any matching post-edit commands for the given file path
Expand Down
11 changes: 0 additions & 11 deletions pkg/tools/builtin/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1137,17 +1137,6 @@ func TestFilesystemTool_AddAllowedDirectory(t *testing.T) {
})
}

func TestFilesystemTool_FilterTools(t *testing.T) {
allowedDirs := []string{"/tmp"}
tool := NewFilesystemTool(allowedDirs, WithAllowedTools([]string{"list_allowed_directories"}))

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
require.Len(t, allTools, 1)
require.Equal(t, "list_allowed_directories", allTools[0].Name)
require.NotNil(t, allTools[0].Handler)
}

func TestMatchExcludePattern(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading