diff --git a/pkg/tui/components/tool/factory.go b/pkg/tui/components/tool/factory.go index 3ef62e1d3..b47bdf934 100644 --- a/pkg/tui/components/tool/factory.go +++ b/pkg/tui/components/tool/factory.go @@ -9,6 +9,7 @@ import ( "github.com/docker/cagent/pkg/tui/components/tool/shell" "github.com/docker/cagent/pkg/tui/components/tool/todotool" "github.com/docker/cagent/pkg/tui/components/tool/transfertask" + "github.com/docker/cagent/pkg/tui/components/tool/webtool" "github.com/docker/cagent/pkg/tui/components/tool/writefile" "github.com/docker/cagent/pkg/tui/core/layout" "github.com/docker/cagent/pkg/tui/service" @@ -31,10 +32,18 @@ func NewFactory(registry *Registry) *Factory { func (f *Factory) Create(msg *types.Message, sessionState *service.SessionState) layout.Model { toolName := msg.ToolCall.Function.Name + // First try to match by exact tool name if builder, ok := f.registry.Get(toolName); ok { return builder(msg, sessionState) } + // Then try to match by category + if msg.ToolDefinition.Category != "" { + if builder, ok := f.registry.Get("category:" + msg.ToolDefinition.Category); ok { + return builder(msg, sessionState) + } + } + return defaulttool.New(msg, sessionState) } @@ -57,6 +66,10 @@ func newDefaultRegistry() *Registry { registry.Register(builtin.ToolNameListTodos, todotool.New) registry.Register(builtin.ToolNameShell, shell.New) + // Register category-based handlers + registry.Register("category:api", webtool.New) + registry.Register(builtin.ToolNameFetch, webtool.New) + return registry } diff --git a/pkg/tui/components/tool/webtool/webtool.go b/pkg/tui/components/tool/webtool/webtool.go new file mode 100644 index 000000000..972f10c31 --- /dev/null +++ b/pkg/tui/components/tool/webtool/webtool.go @@ -0,0 +1,170 @@ +package webtool + +import ( + "encoding/json" + "fmt" + + tea "charm.land/bubbletea/v2" + + "github.com/docker/cagent/pkg/tui/components/spinner" + "github.com/docker/cagent/pkg/tui/components/toolcommon" + "github.com/docker/cagent/pkg/tui/core/layout" + "github.com/docker/cagent/pkg/tui/service" + "github.com/docker/cagent/pkg/tui/styles" + "github.com/docker/cagent/pkg/tui/types" +) + +type Component struct { + message *types.Message + spinner spinner.Spinner + width int + height int +} + +func New( + msg *types.Message, + _ *service.SessionState, +) layout.Model { + return &Component{ + message: msg, + spinner: spinner.New(spinner.ModeSpinnerOnly), + width: 80, + height: 1, + } +} + +func (c *Component) SetSize(width, height int) tea.Cmd { + c.width = width + c.height = height + return nil +} + +func (c *Component) Init() tea.Cmd { + if c.message.ToolStatus == types.ToolStatusPending || c.message.ToolStatus == types.ToolStatusRunning { + return c.spinner.Init() + } + return nil +} + +func (c *Component) Update(msg tea.Msg) (layout.Model, tea.Cmd) { + if c.message.ToolStatus == types.ToolStatusPending || c.message.ToolStatus == types.ToolStatusRunning { + var cmd tea.Cmd + var model layout.Model + model, cmd = c.spinner.Update(msg) + c.spinner = model.(spinner.Spinner) + return c, cmd + } + + return c, nil +} + +func (c *Component) View() string { + msg := c.message + + // Parse the arguments to extract info about the API call + var args map[string]any + var progressText string + + if err := json.Unmarshal([]byte(msg.ToolCall.Function.Arguments), &args); err != nil { + // If we can't parse, show spinner while running + if msg.ToolStatus == types.ToolStatusRunning { + progressText = c.spinner.View() + } + return toolcommon.RenderTool(toolcommon.Icon(msg.ToolStatus), msg.ToolDefinition.DisplayName(), progressText, "", c.width) + } + + // Extract argument summary for the tool call display + argsText := formatArgs(args) + + // Build the display name with inline result + displayName := msg.ToolDefinition.DisplayName() + if argsText != "" { + displayName = displayName + "(" + styles.MutedStyle.Render(argsText) + ")" + } + + // Add inline result/progress after the tool name + switch msg.ToolStatus { + case types.ToolStatusRunning: + // While running, show what we're calling + endpoint := extractEndpoint(args) + if endpoint != "" { + displayName += styles.MutedStyle.Render(": Calling " + endpoint) + } + case types.ToolStatusCompleted: + // When completed, show a brief summary inline + resultSummary := extractSummary(msg.Content) + displayName += styles.MutedStyle.Render(": " + resultSummary) + } + + // Render everything on one line + return toolcommon.RenderTool(toolcommon.Icon(msg.ToolStatus), displayName, "", "", c.width) +} + +// extractEndpoint tries to find the endpoint/URL being called +func extractEndpoint(args map[string]any) string { + if endpoint, ok := args["endpoint"].(string); ok { + return endpoint + } + if url, ok := args["url"].(string); ok { + return url + } + return "" +} + +// formatArgs creates a concise string representation of the arguments +func formatArgs(args map[string]any) string { + if len(args) == 0 { + return "" + } + + // Check for URL or URLs field (common in fetch tools) + if urlVal, ok := args["url"].(string); ok && urlVal != "" { + return urlVal + } + if urlsVal, ok := args["urls"].([]any); ok && len(urlsVal) > 0 { + // Extract just the URLs from the array + var urls []string + for _, u := range urlsVal { + if urlStr, ok := u.(string); ok { + urls = append(urls, urlStr) + } + } + if len(urls) == 1 { + return urls[0] + } else if len(urls) > 1 { + return fmt.Sprintf("%s (+%d more)", urls[0], len(urls)-1) + } + } + + // Try to find common parameter names that might indicate what's being queried + for _, key := range []string{"query", "q", "search", "message", "prompt", "text"} { + if val, ok := args[key]; ok { + if str, ok := val.(string); ok && str != "" { + return str + } + } + } + + // Fallback: show JSON + b, _ := json.Marshal(args) + return string(b) +} + +// extractSummary tries to extract a meaningful summary from the API response +func extractSummary(content string) string { + size := len(content) + + // Convert to KB if >= 1024 bytes + if size >= 1024*1024 { + // Show in MB + mb := float64(size) / (1024 * 1024) + return fmt.Sprintf("Received %.1f MB", mb) + } else if size >= 1024 { + // Show in KB + kb := float64(size) / 1024 + return fmt.Sprintf("Received %.1f KB", kb) + } + + // Show in bytes for small responses + return fmt.Sprintf("Received %d bytes", size) +}