Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions pkg/tools/builtin/deferred.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ type AddToolArgs struct {
Name string `json:"name" jsonschema:"The name of the tool to activate"`
}

func (d *DeferredToolset) handleSearchTool(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
var args SearchToolArgs
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
return nil, fmt.Errorf("failed to parse arguments: %w", err)
}

func (d *DeferredToolset) handleSearchTool(_ context.Context, args SearchToolArgs) (*tools.ToolCallResult, error) {
query := strings.ToLower(args.Query)

d.mu.RLock()
Expand Down Expand Up @@ -127,12 +122,7 @@ func (d *DeferredToolset) handleSearchTool(_ context.Context, toolCall tools.Too
}, nil
}

func (d *DeferredToolset) handleAddTool(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
var args AddToolArgs
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
return nil, fmt.Errorf("failed to parse arguments: %w", err)
}

func (d *DeferredToolset) handleAddTool(_ context.Context, args AddToolArgs) (*tools.ToolCallResult, error) {
d.mu.Lock()
defer d.mu.Unlock()

Expand Down Expand Up @@ -168,7 +158,7 @@ func (d *DeferredToolset) Tools(context.Context) ([]tools.Tool, error) {
Description: "Search for available deferred tools by name or description. Use this to discover tools that can be activated.",
Parameters: tools.MustSchemaFor[SearchToolArgs](),
OutputSchema: tools.MustSchemaFor[string](),
Handler: d.handleSearchTool,
Handler: NewHandler(d.handleSearchTool),
Annotations: tools.ToolAnnotations{
Title: "Search Tool",
ReadOnlyHint: true,
Expand All @@ -180,7 +170,7 @@ func (d *DeferredToolset) Tools(context.Context) ([]tools.Tool, error) {
Description: "Activate a deferred tool by name, making it available for use. Use search_tool first to find available tools.",
Parameters: tools.MustSchemaFor[AddToolArgs](),
OutputSchema: tools.MustSchemaFor[string](),
Handler: d.handleAddTool,
Handler: NewHandler(d.handleAddTool),
Annotations: tools.ToolAnnotations{
Title: "Add Tool",
ReadOnlyHint: true,
Expand Down
37 changes: 6 additions & 31 deletions pkg/tools/builtin/deferred_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package builtin

import (
"context"
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -51,32 +50,20 @@ func TestDeferredToolset_SearchTool(t *testing.T) {
require.NoError(t, err)

t.Run("search by name", func(t *testing.T) {
args, _ := json.Marshal(SearchToolArgs{Query: "create"})
result, err := dt.handleSearchTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleSearchTool(ctx, SearchToolArgs{Query: "create"})
require.NoError(t, err)
assert.Contains(t, result.Output, "create_file")
assert.NotContains(t, result.Output, "read_file")
})

t.Run("search by description", func(t *testing.T) {
args, _ := json.Marshal(SearchToolArgs{Query: "content"})
result, err := dt.handleSearchTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleSearchTool(ctx, SearchToolArgs{Query: "content"})
require.NoError(t, err)
assert.Contains(t, result.Output, "read_file")
})

t.Run("search no results", func(t *testing.T) {
args, _ := json.Marshal(SearchToolArgs{Query: "nonexistent"})
result, err := dt.handleSearchTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleSearchTool(ctx, SearchToolArgs{Query: "nonexistent"})
require.NoError(t, err)
assert.Contains(t, result.Output, "No deferred tools found")
})
Expand All @@ -100,11 +87,7 @@ func TestDeferredToolset_AddTool(t *testing.T) {
require.NoError(t, err)
assert.Len(t, initialTools, 2)
t.Run("add existing deferred tool", func(t *testing.T) {
args, _ := json.Marshal(AddToolArgs{Name: "tool1"})
result, err := dt.handleAddTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleAddTool(ctx, AddToolArgs{Name: "tool1"})
require.NoError(t, err)
assert.Contains(t, result.Output, "has been activated")

Expand All @@ -120,21 +103,13 @@ func TestDeferredToolset_AddTool(t *testing.T) {
})

t.Run("add already active tool", func(t *testing.T) {
args, _ := json.Marshal(AddToolArgs{Name: "tool1"})
result, err := dt.handleAddTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleAddTool(ctx, AddToolArgs{Name: "tool1"})
require.NoError(t, err)
assert.Contains(t, result.Output, "already active")
})

t.Run("add non-existent tool", func(t *testing.T) {
args, _ := json.Marshal(AddToolArgs{Name: "nonexistent"})
result, err := dt.handleAddTool(ctx, tools.ToolCall{
Function: tools.FunctionCall{Arguments: string(args)},
})

result, err := dt.handleAddTool(ctx, AddToolArgs{Name: "nonexistent"})
require.NoError(t, err)
assert.Contains(t, result.Output, "not found")
})
Expand Down
18 changes: 7 additions & 11 deletions pkg/tools/builtin/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,13 @@ type fetchHandler struct {
timeout time.Duration
}

func (h *fetchHandler) CallTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
var params struct {
URLs []string `json:"urls"`
Timeout int `json:"timeout,omitempty"`
Format string `json:"format,omitempty"`
}

if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil {
return nil, fmt.Errorf("invalid arguments: %w", err)
}
type FetchToolArgs struct {
URLs []string `json:"urls"`
Timeout int `json:"timeout,omitempty"`
Format string `json:"format,omitempty"`
}

func (h *fetchHandler) CallTool(ctx context.Context, params FetchToolArgs) (*tools.ToolCallResult, error) {
if len(params.URLs) == 0 {
return nil, fmt.Errorf("at least one URL is required")
}
Expand Down Expand Up @@ -338,7 +334,7 @@ func (t *FetchTool) Tools(context.Context) ([]tools.Tool, error) {
"required": []string{"urls", "format"},
},
OutputSchema: tools.MustSchemaFor[string](),
Handler: t.handler.CallTool,
Handler: NewHandler(t.handler.CallTool),
Annotations: tools.ToolAnnotations{
ReadOnlyHint: true,
Title: "Fetch URLs",
Expand Down
93 changes: 33 additions & 60 deletions pkg/tools/builtin/fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestFetch_Call_Success(t *testing.T) {

tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), fetch(t, url))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{URLs: []string{url}})
require.NoError(t, err)

assert.Contains(t, result.Output, "Successfully fetched")
Expand All @@ -117,7 +117,7 @@ func TestFetch_Call_MultipleURLs(t *testing.T) {

tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), fetch(t, url1, url2))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{URLs: []string{url1, url2}})
require.NoError(t, err)

var results []FetchResult
Expand All @@ -132,40 +132,35 @@ func TestFetch_Call_MultipleURLs(t *testing.T) {
func TestFetch_Call_InvalidURL(t *testing.T) {
tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), fetch(t, "invalid-url"))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{
"invalid-url",
},
})
require.NoError(t, err)

assert.Contains(t, result.Output, "Error fetching")
}

func TestFetch_Call_UnsupportedProtocol(t *testing.T) {
tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), fetch(t, "ftp://example.com"))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{
"ftp://example.com",
},
})
require.NoError(t, err)

assert.Contains(t, result.Output, "Error fetching")
assert.Contains(t, result.Output, "only HTTP and HTTPS URLs are supported")
}

func TestFetch_Call_NoURLs(t *testing.T) {
tool := NewFetchTool()

_, err := tool.handler.CallTool(t.Context(), fetch(t))
_, err := tool.handler.CallTool(t.Context(), FetchToolArgs{})
require.ErrorContains(t, err, "at least one URL is required")
}

func TestFetch_Call_InvalidJSON(t *testing.T) {
tool := NewFetchTool()

_, err := tool.handler.CallTool(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Arguments: "invalid json",
},
})
require.ErrorContains(t, err, "invalid arguments")
}

func TestFetch_Markdown(t *testing.T) {
url := runHTTPServer(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
Expand All @@ -174,10 +169,10 @@ func TestFetch_Markdown(t *testing.T) {

tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), toolCall(t, map[string]any{
"urls": []string{url},
"format": "markdown",
}))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{url},
Format: "markdown",
})
require.NoError(t, err)

assert.Contains(t, result.Output, "Successfully fetched")
Expand All @@ -194,10 +189,10 @@ func TestFetch_Text(t *testing.T) {

tool := NewFetchTool()

result, err := tool.handler.CallTool(t.Context(), toolCall(t, map[string]any{
"urls": []string{url},
"format": "text",
}))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{url},
Format: "text",
})
require.NoError(t, err)

assert.Contains(t, result.Output, "Successfully fetched")
Expand All @@ -215,27 +210,6 @@ func runHTTPServer(t *testing.T, handler http.HandlerFunc) string {
return server.URL
}

func fetch(t *testing.T, urls ...string) tools.ToolCall {
t.Helper()

return toolCall(t, map[string]any{
"urls": urls,
})
}

func toolCall(t *testing.T, args map[string]any) tools.ToolCall {
t.Helper()

argsJSON, err := json.Marshal(args)
require.NoError(t, err)

return tools.ToolCall{
Function: tools.FunctionCall{
Arguments: string(argsJSON),
},
}
}

func TestFetch_RobotsAllowed(t *testing.T) {
robotsContent := "User-agent: *\nAllow: /"

Expand All @@ -254,10 +228,10 @@ func TestFetch_RobotsAllowed(t *testing.T) {
})

tool := NewFetchTool()
result, err := tool.handler.CallTool(t.Context(), toolCall(t, map[string]any{
"urls": []string{url + "/allowed"},
"format": "text",
}))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{url + "/allowed"},
Format: "text",
})

require.NoError(t, err)
assert.Contains(t, result.Output, "Successfully fetched")
Expand All @@ -282,11 +256,10 @@ func TestFetch_RobotsBlocked(t *testing.T) {
})

tool := NewFetchTool()
result, err := tool.handler.CallTool(t.Context(), toolCall(t, map[string]any{
"urls": []string{url + "/blocked"},
"format": "text",
}))

result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{url + "/blocked"},
Format: "text",
})
require.NoError(t, err)
assert.Contains(t, result.Output, "Error fetching")
assert.Contains(t, result.Output, "URL blocked by robots.txt")
Expand All @@ -307,10 +280,10 @@ func TestFetch_RobotsMissing(t *testing.T) {
})

tool := NewFetchTool()
result, err := tool.handler.CallTool(t.Context(), toolCall(t, map[string]any{
"urls": []string{url + "/content"},
"format": "text",
}))
result, err := tool.handler.CallTool(t.Context(), FetchToolArgs{
URLs: []string{url + "/content"},
Format: "text",
})

require.NoError(t, err)
assert.Contains(t, result.Output, "Successfully fetched")
Expand Down
Loading