Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/tui/components/tool/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/docker/cagent/pkg/tui/components/tool/shell"
"github.com/docker/cagent/pkg/tui/components/tool/todotool"
"github.com/docker/cagent/pkg/tui/components/tool/transfertask"
"github.com/docker/cagent/pkg/tui/components/tool/webtool"
"github.com/docker/cagent/pkg/tui/components/tool/writefile"
"github.com/docker/cagent/pkg/tui/core/layout"
"github.com/docker/cagent/pkg/tui/service"
Expand All @@ -31,10 +32,18 @@ func NewFactory(registry *Registry) *Factory {
func (f *Factory) Create(msg *types.Message, sessionState *service.SessionState) layout.Model {
toolName := msg.ToolCall.Function.Name

// First try to match by exact tool name
if builder, ok := f.registry.Get(toolName); ok {
return builder(msg, sessionState)
}

// Then try to match by category
if msg.ToolDefinition.Category != "" {
if builder, ok := f.registry.Get("category:" + msg.ToolDefinition.Category); ok {
return builder(msg, sessionState)
}
}

return defaulttool.New(msg, sessionState)
}

Expand All @@ -57,6 +66,10 @@ func newDefaultRegistry() *Registry {
registry.Register(builtin.ToolNameListTodos, todotool.New)
registry.Register(builtin.ToolNameShell, shell.New)

// Register category-based handlers
registry.Register("category:api", webtool.New)
registry.Register(builtin.ToolNameFetch, webtool.New)

return registry
}

Expand Down
170 changes: 170 additions & 0 deletions pkg/tui/components/tool/webtool/webtool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package webtool

import (
"encoding/json"
"fmt"

tea "charm.land/bubbletea/v2"

"github.com/docker/cagent/pkg/tui/components/spinner"
"github.com/docker/cagent/pkg/tui/components/toolcommon"
"github.com/docker/cagent/pkg/tui/core/layout"
"github.com/docker/cagent/pkg/tui/service"
"github.com/docker/cagent/pkg/tui/styles"
"github.com/docker/cagent/pkg/tui/types"
)

type Component struct {
message *types.Message
spinner spinner.Spinner
width int
height int
}

func New(
msg *types.Message,
_ *service.SessionState,
) layout.Model {
return &Component{
message: msg,
spinner: spinner.New(spinner.ModeSpinnerOnly),
width: 80,
height: 1,
}
}

func (c *Component) SetSize(width, height int) tea.Cmd {
c.width = width
c.height = height
return nil
}

func (c *Component) Init() tea.Cmd {
if c.message.ToolStatus == types.ToolStatusPending || c.message.ToolStatus == types.ToolStatusRunning {
return c.spinner.Init()
}
return nil
}

func (c *Component) Update(msg tea.Msg) (layout.Model, tea.Cmd) {
if c.message.ToolStatus == types.ToolStatusPending || c.message.ToolStatus == types.ToolStatusRunning {
var cmd tea.Cmd
var model layout.Model
model, cmd = c.spinner.Update(msg)
c.spinner = model.(spinner.Spinner)
return c, cmd
}

return c, nil
}

func (c *Component) View() string {
msg := c.message

// Parse the arguments to extract info about the API call
var args map[string]any
var progressText string

if err := json.Unmarshal([]byte(msg.ToolCall.Function.Arguments), &args); err != nil {
// If we can't parse, show spinner while running
if msg.ToolStatus == types.ToolStatusRunning {
progressText = c.spinner.View()
}
return toolcommon.RenderTool(toolcommon.Icon(msg.ToolStatus), msg.ToolDefinition.DisplayName(), progressText, "", c.width)
}

// Extract argument summary for the tool call display
argsText := formatArgs(args)

// Build the display name with inline result
displayName := msg.ToolDefinition.DisplayName()
if argsText != "" {
displayName = displayName + "(" + styles.MutedStyle.Render(argsText) + ")"
}

// Add inline result/progress after the tool name
switch msg.ToolStatus {
case types.ToolStatusRunning:
// While running, show what we're calling
endpoint := extractEndpoint(args)
if endpoint != "" {
displayName += styles.MutedStyle.Render(": Calling " + endpoint)
}
case types.ToolStatusCompleted:
// When completed, show a brief summary inline
resultSummary := extractSummary(msg.Content)
displayName += styles.MutedStyle.Render(": " + resultSummary)
}

// Render everything on one line
return toolcommon.RenderTool(toolcommon.Icon(msg.ToolStatus), displayName, "", "", c.width)
}

// extractEndpoint tries to find the endpoint/URL being called
func extractEndpoint(args map[string]any) string {
if endpoint, ok := args["endpoint"].(string); ok {
return endpoint
}
if url, ok := args["url"].(string); ok {
return url
}
return ""
}

// formatArgs creates a concise string representation of the arguments
func formatArgs(args map[string]any) string {
if len(args) == 0 {
return ""
}

// Check for URL or URLs field (common in fetch tools)
if urlVal, ok := args["url"].(string); ok && urlVal != "" {
return urlVal
}
if urlsVal, ok := args["urls"].([]any); ok && len(urlsVal) > 0 {
// Extract just the URLs from the array
var urls []string
for _, u := range urlsVal {
if urlStr, ok := u.(string); ok {
urls = append(urls, urlStr)
}
}
if len(urls) == 1 {
return urls[0]
} else if len(urls) > 1 {
return fmt.Sprintf("%s (+%d more)", urls[0], len(urls)-1)
}
}

// Try to find common parameter names that might indicate what's being queried
for _, key := range []string{"query", "q", "search", "message", "prompt", "text"} {
if val, ok := args[key]; ok {
if str, ok := val.(string); ok && str != "" {
return str
}
}
}

// Fallback: show JSON
b, _ := json.Marshal(args)
return string(b)
}

// extractSummary tries to extract a meaningful summary from the API response
func extractSummary(content string) string {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On it

size := len(content)

// Convert to KB if >= 1024 bytes
if size >= 1024*1024 {
// Show in MB
mb := float64(size) / (1024 * 1024)
return fmt.Sprintf("Received %.1f MB", mb)
} else if size >= 1024 {
// Show in KB
kb := float64(size) / 1024
return fmt.Sprintf("Received %.1f KB", kb)
}

// Show in bytes for small responses
return fmt.Sprintf("Received %d bytes", size)
}