diff --git a/cmd/root/new.go b/cmd/root/new.go index c53ae702c..ee3ce2556 100644 --- a/cmd/root/new.go +++ b/cmd/root/new.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" + "github.com/docker/cagent/pkg/cli" "github.com/docker/cagent/pkg/creator" "github.com/docker/cagent/pkg/input" "github.com/docker/cagent/pkg/runtime" @@ -54,19 +55,19 @@ func NewNewCmd() *cobra.Command { switch { case os.Getenv("ANTHROPIC_API_KEY") != "": modelProvider = "anthropic" - fmt.Printf("%s\n\n", white("ANTHROPIC_API_KEY found, using Anthropic")) + fmt.Printf("%s\n\n", cli.White("ANTHROPIC_API_KEY found, using Anthropic")) case os.Getenv("OPENAI_API_KEY") != "": modelProvider = "openai" - fmt.Printf("%s\n\n", white("OPENAI_API_KEY found, using OpenAI")) + fmt.Printf("%s\n\n", cli.White("OPENAI_API_KEY found, using OpenAI")) case os.Getenv("GOOGLE_API_KEY") != "": modelProvider = "google" - fmt.Printf("%s\n\n", white("GOOGLE_API_KEY found, using Google")) + fmt.Printf("%s\n\n", cli.White("GOOGLE_API_KEY found, using Google")) default: modelProvider = "dmr" - fmt.Printf("%s\n\n", yellow("⚠️ No provider credentials found, defaulting to Docker Model Runner (DMR)")) + fmt.Printf("%s\n\n", cli.Yellow("⚠️ No provider credentials found, defaulting to Docker Model Runner (DMR)")) } if modelParam == "" { - fmt.Printf("%s\n\n", white("use \"--model provider/model\" to use a different model")) + fmt.Printf("%s\n\n", cli.White("use \"--model provider/model\" to use a different model")) } } else { // Using Models Gateway; default to Anthropic if not specified @@ -78,10 +79,10 @@ func NewNewCmd() *cobra.Command { if len(args) > 0 { prompt = strings.Join(args, " ") } else { - fmt.Printf("%s\n", blue("------- Welcome to %s! -------", bold(AppName))) - fmt.Printf("%s\n\n", white(" (Ctrl+C to exit)")) - fmt.Printf("%s\n\n", blue("What should your agent/agent team do? (describe its purpose)")) - fmt.Print(blue("> ")) + fmt.Printf("%s\n", cli.Blue("------- Welcome to %s! -------", cli.Bold(AppName))) + fmt.Printf("%s\n\n", cli.White(" (Ctrl+C to exit)")) + fmt.Printf("%s\n\n", cli.Blue("What should your agent/agent team do? (describe its purpose)")) + fmt.Print(cli.Blue("> ")) var err error prompt, err = input.ReadLine(ctx, os.Stdin) @@ -112,33 +113,33 @@ func NewNewCmd() *cobra.Command { fmt.Println() llmIsTyping = false } - printToolCall(e.ToolCall) + cli.PrintToolCall(e.ToolCall) case *runtime.ToolCallResponseEvent: if llmIsTyping { fmt.Println() llmIsTyping = false } - printToolCallResponse(e.ToolCall, e.Response) + cli.PrintToolCallResponse(e.ToolCall, e.Response) case *runtime.ErrorEvent: if llmIsTyping { fmt.Println() llmIsTyping = false } - printError(fmt.Errorf("%s", e.Error)) + cli.PrintError(fmt.Errorf("%s", e.Error)) case *runtime.MaxIterationsReachedEvent: if llmIsTyping { fmt.Println() llmIsTyping = false } - result := promptMaxIterationsContinue(ctx, e.MaxIterations) + result := cli.PromptMaxIterationsContinue(ctx, e.MaxIterations) switch result { - case ConfirmationApprove: + case cli.ConfirmationApprove: rt.Resume(ctx, string(runtime.ResumeTypeApprove)) - case ConfirmationReject: + case cli.ConfirmationReject: rt.Resume(ctx, string(runtime.ResumeTypeReject)) return nil - case ConfirmationAbort: + case cli.ConfirmationAbort: rt.Resume(ctx, string(runtime.ResumeTypeReject)) } } diff --git a/cmd/root/run.go b/cmd/root/run.go index c29c87f28..5d1937f50 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -3,7 +3,6 @@ package root import ( "bytes" "context" - "encoding/base64" "fmt" "io" "log/slog" @@ -13,16 +12,13 @@ import ( "time" tea "github.com/charmbracelet/bubbletea/v2" - "github.com/fatih/color" "github.com/spf13/cobra" "go.opentelemetry.io/otel" "github.com/docker/cagent/pkg/aliases" "github.com/docker/cagent/pkg/app" - "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/cli" "github.com/docker/cagent/pkg/content" - "github.com/docker/cagent/pkg/evaluation" - "github.com/docker/cagent/pkg/input" "github.com/docker/cagent/pkg/remote" "github.com/docker/cagent/pkg/runtime" "github.com/docker/cagent/pkg/session" @@ -268,12 +264,26 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error { fmt.Println("Dry run mode enabled. Agent initialized but will not execute.") return nil } - return runWithoutTUI(ctx, agentFilename, rt, sess, execArgs) + err := cli.Run(ctx, cli.Config{ + AppName: AppName, + AttachmentPath: attachmentPath, + }, agentFilename, rt, sess, execArgs) + if cliErr, ok := err.(cli.RuntimeError); ok { + return RuntimeError{Err: cliErr.Err} + } + return err } // For `cagent run --tui=false` if !useTUI { - return runWithoutTUI(ctx, agentFilename, rt, sess, args) + err := cli.Run(ctx, cli.Config{ + AppName: AppName, + AttachmentPath: attachmentPath, + }, agentFilename, rt, sess, args) + if cliErr, ok := err.(cli.RuntimeError); ok { + return RuntimeError{Err: cliErr.Err} + } + return err } // The default is to use the TUI @@ -311,355 +321,6 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error { return err } -func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime, sess *session.Session, args []string) error { - // Create a cancellable context for this agentic loop and wire Ctrl+C to cancel it - ctx, cancel := context.WithCancel(ctx) - - // Ensure telemetry is initialized and add to context so runtime can access it - telemetry.EnsureGlobalTelemetryInitialized() - if telemetryClient := telemetry.GetGlobalTelemetryClient(); telemetryClient != nil { - ctx = telemetry.WithClient(ctx, telemetryClient) - } - - sess.Title = "Running agent" - // If the last received event was an error, return it. That way the exit code - // will be non-zero if the agent failed. - var lastErr error - - oneLoop := func(text string, rd io.Reader) error { - userInput := strings.TrimSpace(text) - if userInput == "" { - return nil - } - - userInput = runtime.ResolveCommand(ctx, rt, userInput) - - handled, err := runUserCommand(userInput, sess, rt, ctx) - if err != nil { - return err - } - if handled { - return nil - } - - // Parse for /attach commands in the message - messageText, attachPath := parseAttachCommand(userInput) - - // Use either the per-message attachment or the global one - finalAttachPath := attachPath - if finalAttachPath == "" { - finalAttachPath = attachmentPath - } - - sess.AddMessage(createUserMessageWithAttachment(agentFilename, messageText, finalAttachPath)) - - firstLoop := true - lastAgent := rt.CurrentAgentName() - llmIsTyping := false - reasoningStarted := false // Track if we've printed "Thinking:" prefix - var lastConfirmedToolCallID string - for event := range rt.RunStream(ctx, sess) { - agentName := event.GetAgentName() - if agentName != "" && (firstLoop || lastAgent != agentName) { - if !firstLoop { - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - fmt.Println() - } - printAgentName(agentName) - firstLoop = false - lastAgent = agentName - reasoningStarted = false // Reset reasoning state on agent change - } - switch e := event.(type) { - case *runtime.AgentChoiceEvent: - agentChanged := lastAgent != e.AgentName - if !llmIsTyping { - // Only add newline if we're not already typing - if !agentChanged { - fmt.Println() - } - llmIsTyping = true - } - // Add newline when transitioning from reasoning to regular content - if reasoningStarted { - fmt.Println() - } - reasoningStarted = false // Reset when regular content starts - fmt.Printf("%s", e.Content) - case *runtime.AgentChoiceReasoningEvent: - if !reasoningStarted { - // First reasoning chunk: print prefix - prefix := "Thinking: " - if e.AgentName != "" && e.AgentName != "root" { - prefix = prefix + e.AgentName + ": " - } - fmt.Printf("\n%s", white(prefix)) - reasoningStarted = true - } - // Continue printing reasoning content - fmt.Printf("%s", white(e.Content)) - case *runtime.ToolCallConfirmationEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - result := printToolCallWithConfirmation(ctx, e.ToolCall, rd) - // If interrupted, skip resuming; the runtime will notice context cancellation and stop - if ctx.Err() != nil { - continue - } - lastConfirmedToolCallID = e.ToolCall.ID // Store the ID to avoid duplicate printing - switch result { - case ConfirmationApprove: - rt.Resume(ctx, string(runtime.ResumeTypeApprove)) - case ConfirmationApproveSession: - sess.ToolsApproved = true - rt.Resume(ctx, string(runtime.ResumeTypeApproveSession)) - case ConfirmationReject: - rt.Resume(ctx, string(runtime.ResumeTypeReject)) - lastConfirmedToolCallID = "" // Clear on reject since tool won't execute - case ConfirmationAbort: - // Stop the agent loop immediately - cancel() - continue - } - case *runtime.ToolCallEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - // Only print if this wasn't already shown during confirmation - if e.ToolCall.ID != lastConfirmedToolCallID { - printToolCall(e.ToolCall) - } - case *runtime.ToolCallResponseEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - printToolCallResponse(e.ToolCall, e.Response) - // Clear the confirmed ID after the tool completes - if e.ToolCall.ID == lastConfirmedToolCallID { - lastConfirmedToolCallID = "" - } - case *runtime.ErrorEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - lowerErr := strings.ToLower(e.Error) - if strings.Contains(lowerErr, "context cancel") && ctx.Err() != nil { // treat Ctrl+C cancellations as non-errors - lastErr = nil - } else { - lastErr = fmt.Errorf("%s", e.Error) - printError(lastErr) - } - case *runtime.MaxIterationsReachedEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - - result := promptMaxIterationsContinue(ctx, e.MaxIterations) - switch result { - case ConfirmationApprove: - rt.Resume(ctx, string(runtime.ResumeTypeApprove)) - case ConfirmationReject: - rt.Resume(ctx, string(runtime.ResumeTypeReject)) - return nil - case ConfirmationAbort: - rt.Resume(ctx, string(runtime.ResumeTypeReject)) - return nil - } - case *runtime.ElicitationRequestEvent: - if llmIsTyping { - fmt.Println() - llmIsTyping = false - } - - serverURL := e.Meta["cagent/server_url"].(string) - result := promptOAuthAuthorization(ctx, serverURL) - switch { - case ctx.Err() != nil: - return ctx.Err() - case result == ConfirmationApprove: - _ = rt.ResumeElicitation(ctx, "accept", nil) - case result == ConfirmationReject: - _ = rt.ResumeElicitation(ctx, "decline", nil) - return fmt.Errorf("OAuth authorization rejected by user") - } - } - } - - // If the loop ended due to Ctrl+C, inform the user succinctly - if ctx.Err() != nil { - fmt.Println(yellow("\n⚠️ agent stopped ⚠️")) - } - - // Wrap runtime errors to prevent duplicate error messages and usage display - if lastErr != nil { - return RuntimeError{Err: lastErr} - } - return nil - } - - if len(args) == 2 { - if args[1] == "-" { - buf, err := io.ReadAll(os.Stdin) - if err != nil { - return fmt.Errorf("failed to read from stdin: %w", err) - } - - if err := oneLoop(string(buf), os.Stdin); err != nil { - return err - } - } else { - if err := oneLoop(args[1], os.Stdin); err != nil { - return err - } - } - } else { - printWelcomeMessage() - firstQuestion := true - for { - if !firstQuestion { - fmt.Print("\n\n") - } - fmt.Print(blue("> ")) - firstQuestion = false - - line, err := input.ReadLine(ctx, os.Stdin) - if err != nil { - return err - } - - if err := oneLoop(line, os.Stdin); err != nil { - return err - } - } - } - - // Wrap runtime errors to prevent duplicate error messages and usage display - if lastErr != nil { - return RuntimeError{Err: lastErr} - } - return nil -} - -// TODO: This is a duplication of builtInSessionCommands() in pkg/tui/tui.go -func runUserCommand(userInput string, sess *session.Session, rt runtime.Runtime, ctx context.Context) (bool, error) { - yellow := color.New(color.FgYellow).SprintfFunc() - switch userInput { - case "/exit": - os.Exit(0) - case "/eval": - evalFile, err := evaluation.Save(sess) - if err == nil { - fmt.Printf("%s\n", yellow("Evaluation saved to file %s", evalFile)) - return true, err - } - return true, nil - case "/usage": - fmt.Printf("%s\n", yellow("Input tokens: %d", sess.InputTokens)) - fmt.Printf("%s\n", yellow("Output tokens: %d", sess.OutputTokens)) - return true, nil - case "/new": - // Reset session items - sess.Messages = []session.Item{} - return true, nil - case "/compact": - // Generate a summary of the session and compact the history - fmt.Printf("%s\n", yellow("Generating summary...")) - - // Create a channel to capture summary events - events := make(chan runtime.Event, 100) - - // Generate the summary - rt.Summarize(ctx, sess, events) - - // Process events and show the summary - close(events) - summaryGenerated := false - hasWarning := false - for event := range events { - switch e := event.(type) { - case *runtime.SessionSummaryEvent: - fmt.Printf("%s\n", yellow("Summary generated and added to session")) - fmt.Printf("Summary: %s\n", e.Summary) - summaryGenerated = true - case *runtime.WarningEvent: - fmt.Printf("%s\n", yellow("Warning: "+e.Message)) - hasWarning = true - } - } - - if !summaryGenerated && !hasWarning { - fmt.Printf("%s\n", yellow("No summary generated")) - } - - return true, nil - } - - return false, nil -} - -// parseAttachCommand parses user input for /attach commands -// Returns the message text (with /attach commands removed) and the attachment path -func parseAttachCommand(userInput string) (messageText, attachPath string) { - lines := strings.Split(userInput, "\n") - var messageLines []string - - for _, line := range lines { - // Look for /attach anywhere in the line - attachIndex := strings.Index(line, "/attach ") - if attachIndex != -1 { - // Extract the part before /attach - beforeAttach := line[:attachIndex] - - // Extract the part after /attach (starting after "/attach ") - afterAttachStart := attachIndex + 8 // Length of "/attach " - if afterAttachStart < len(line) { - afterAttach := line[afterAttachStart:] - - // Split on spaces to get the file path (first token) and any remaining text - tokens := strings.Fields(afterAttach) - if len(tokens) > 0 { - attachPath = tokens[0] - - // Reconstruct the line with /attach and file path removed - var remainingText string - if len(tokens) > 1 { - remainingText = strings.Join(tokens[1:], " ") - } - - // Combine the text before /attach and any text after the file path - var parts []string - if strings.TrimSpace(beforeAttach) != "" { - parts = append(parts, strings.TrimSpace(beforeAttach)) - } - if remainingText != "" { - parts = append(parts, remainingText) - } - reconstructedLine := strings.Join(parts, " ") - if reconstructedLine != "" { - messageLines = append(messageLines, reconstructedLine) - } - } - } - } else { - // Keep lines without /attach commands - messageLines = append(messageLines, line) - } - } - - // Join the message lines back together - messageText = strings.TrimSpace(strings.Join(messageLines, "\n")) - return messageText, attachPath -} - func fileExists(path string) bool { _, err := os.Stat(path) exists := err == nil @@ -697,82 +358,3 @@ func fromStore(reference string) (string, error) { return buf.String(), nil } - -// createUserMessageWithAttachment creates a user message with optional image attachment -func createUserMessageWithAttachment(agentFilename, userContent, attachmentPath string) *session.Message { - if attachmentPath == "" { - return session.UserMessage(agentFilename, userContent) - } - - // Convert file to data URL - dataURL, err := fileToDataURL(attachmentPath) - if err != nil { - fmt.Printf("Warning: Failed to attach file %s: %v\n", attachmentPath, err) - return session.UserMessage(agentFilename, userContent) - } - - // Ensure we have some text content when attaching a file - textContent := userContent - if strings.TrimSpace(textContent) == "" { - textContent = "Please analyze this attached file." - } - - // Create message with multi-content including text and image - multiContent := []chat.MessagePart{ - { - Type: chat.MessagePartTypeText, - Text: textContent, - }, - { - Type: chat.MessagePartTypeImageURL, - ImageURL: &chat.MessageImageURL{ - URL: dataURL, - Detail: chat.ImageURLDetailAuto, - }, - }, - } - - return session.UserMessage(agentFilename, "", multiContent...) -} - -// fileToDataURL converts a file to a data URL -func fileToDataURL(filePath string) (string, error) { - // Check if file exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - return "", fmt.Errorf("file does not exist: %s", filePath) - } - - // Read file content - fileBytes, err := os.ReadFile(filePath) - if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) - } - - // Determine MIME type based on file extension - ext := strings.ToLower(filepath.Ext(filePath)) - var mimeType string - switch ext { - case ".jpg", ".jpeg": - mimeType = "image/jpeg" - case ".png": - mimeType = "image/png" - case ".gif": - mimeType = "image/gif" - case ".webp": - mimeType = "image/webp" - case ".bmp": - mimeType = "image/bmp" - case ".svg": - mimeType = "image/svg+xml" - default: - return "", fmt.Errorf("unsupported image format: %s", ext) - } - - // Encode to base64 - encoded := base64.StdEncoding.EncodeToString(fileBytes) - - // Create data URL - dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, encoded) - - return dataURL, nil -} diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go new file mode 100644 index 000000000..1b72c9055 --- /dev/null +++ b/pkg/cli/runner.go @@ -0,0 +1,467 @@ +package cli + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/evaluation" + "github.com/docker/cagent/pkg/input" + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/telemetry" +) + +// RuntimeError wraps runtime errors to distinguish them from usage errors +type RuntimeError struct { + Err error +} + +func (e RuntimeError) Error() string { + return e.Err.Error() +} + +func (e RuntimeError) Unwrap() error { + return e.Err +} + +// Config holds configuration for running an agent in CLI mode +type Config struct { + AppName string + AttachmentPath string +} + +// Run executes an agent in non-TUI mode, handling user input and runtime events +func Run(ctx context.Context, cfg Config, agentFilename string, rt runtime.Runtime, sess *session.Session, args []string) error { + // Create a cancellable context for this agentic loop and wire Ctrl+C to cancel it + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Ensure telemetry is initialized and add to context so runtime can access it + telemetry.EnsureGlobalTelemetryInitialized() + if telemetryClient := telemetry.GetGlobalTelemetryClient(); telemetryClient != nil { + ctx = telemetry.WithClient(ctx, telemetryClient) + } + + sess.Title = "Running agent" + // If the last received event was an error, return it. That way the exit code + // will be non-zero if the agent failed. + var lastErr error + + oneLoop := func(text string, rd io.Reader) error { + userInput := strings.TrimSpace(text) + if userInput == "" { + return nil + } + + userInput = runtime.ResolveCommand(ctx, rt, userInput) + + handled, err := runUserCommand(userInput, sess, rt, ctx) + if err != nil { + return err + } + if handled { + return nil + } + + // Parse for /attach commands in the message + messageText, attachPath := parseAttachCommand(userInput) + + // Use either the per-message attachment or the global one + finalAttachPath := attachPath + if finalAttachPath == "" { + finalAttachPath = cfg.AttachmentPath + } + + sess.AddMessage(createUserMessageWithAttachment(agentFilename, messageText, finalAttachPath)) + + firstLoop := true + lastAgent := rt.CurrentAgentName() + llmIsTyping := false + reasoningStarted := false // Track if we've printed "Thinking:" prefix + var lastConfirmedToolCallID string + for event := range rt.RunStream(ctx, sess) { + agentName := event.GetAgentName() + if agentName != "" && (firstLoop || lastAgent != agentName) { + if !firstLoop { + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + fmt.Println() + } + PrintAgentName(agentName) + firstLoop = false + lastAgent = agentName + reasoningStarted = false // Reset reasoning state on agent change + } + switch e := event.(type) { + case *runtime.AgentChoiceEvent: + agentChanged := lastAgent != e.AgentName + if !llmIsTyping { + // Only add newline if we're not already typing + if !agentChanged { + fmt.Println() + } + llmIsTyping = true + } + // Add newline when transitioning from reasoning to regular content + if reasoningStarted { + fmt.Println() + } + reasoningStarted = false // Reset when regular content starts + fmt.Printf("%s", e.Content) + case *runtime.AgentChoiceReasoningEvent: + if !reasoningStarted { + // First reasoning chunk: print prefix + prefix := "Thinking: " + if e.AgentName != "" && e.AgentName != "root" { + prefix = prefix + e.AgentName + ": " + } + fmt.Printf("\n%s", White(prefix)) + reasoningStarted = true + } + // Continue printing reasoning content + fmt.Printf("%s", White(e.Content)) + case *runtime.ToolCallConfirmationEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + result := PrintToolCallWithConfirmation(ctx, e.ToolCall, rd) + // If interrupted, skip resuming; the runtime will notice context cancellation and stop + if ctx.Err() != nil { + continue + } + lastConfirmedToolCallID = e.ToolCall.ID // Store the ID to avoid duplicate printing + switch result { + case ConfirmationApprove: + rt.Resume(ctx, string(runtime.ResumeTypeApprove)) + case ConfirmationApproveSession: + sess.ToolsApproved = true + rt.Resume(ctx, string(runtime.ResumeTypeApproveSession)) + case ConfirmationReject: + rt.Resume(ctx, string(runtime.ResumeTypeReject)) + lastConfirmedToolCallID = "" // Clear on reject since tool won't execute + case ConfirmationAbort: + // Stop the agent loop immediately + cancel() + continue + } + case *runtime.ToolCallEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + // Only print if this wasn't already shown during confirmation + if e.ToolCall.ID != lastConfirmedToolCallID { + PrintToolCall(e.ToolCall) + } + case *runtime.ToolCallResponseEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + PrintToolCallResponse(e.ToolCall, e.Response) + // Clear the confirmed ID after the tool completes + if e.ToolCall.ID == lastConfirmedToolCallID { + lastConfirmedToolCallID = "" + } + case *runtime.ErrorEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + lowerErr := strings.ToLower(e.Error) + if strings.Contains(lowerErr, "context cancel") && ctx.Err() != nil { // treat Ctrl+C cancellations as non-errors + lastErr = nil + } else { + lastErr = fmt.Errorf("%s", e.Error) + PrintError(lastErr) + } + case *runtime.MaxIterationsReachedEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + + result := PromptMaxIterationsContinue(ctx, e.MaxIterations) + switch result { + case ConfirmationApprove: + rt.Resume(ctx, string(runtime.ResumeTypeApprove)) + case ConfirmationReject: + rt.Resume(ctx, string(runtime.ResumeTypeReject)) + return nil + case ConfirmationAbort: + rt.Resume(ctx, string(runtime.ResumeTypeReject)) + return nil + } + case *runtime.ElicitationRequestEvent: + if llmIsTyping { + fmt.Println() + llmIsTyping = false + } + + serverURL := e.Meta["cagent/server_url"].(string) + result := PromptOAuthAuthorization(ctx, serverURL) + switch { + case ctx.Err() != nil: + return ctx.Err() + case result == ConfirmationApprove: + _ = rt.ResumeElicitation(ctx, "accept", nil) + case result == ConfirmationReject: + _ = rt.ResumeElicitation(ctx, "decline", nil) + return fmt.Errorf("OAuth authorization rejected by user") + } + } + } + + // If the loop ended due to Ctrl+C, inform the user succinctly + if ctx.Err() != nil { + fmt.Println(Yellow("\n⚠️ agent stopped ⚠️")) + } + + // Wrap runtime errors to prevent duplicate error messages and usage display + if lastErr != nil { + return RuntimeError{Err: lastErr} + } + return nil + } + + if len(args) == 2 { + if args[1] == "-" { + buf, err := io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read from stdin: %w", err) + } + + if err := oneLoop(string(buf), os.Stdin); err != nil { + return err + } + } else { + if err := oneLoop(args[1], os.Stdin); err != nil { + return err + } + } + } else { + PrintWelcomeMessage(cfg.AppName) + firstQuestion := true + for { + if !firstQuestion { + fmt.Print("\n\n") + } + fmt.Print(Blue("> ")) + firstQuestion = false + + line, err := input.ReadLine(ctx, os.Stdin) + if err != nil { + return err + } + + if err := oneLoop(line, os.Stdin); err != nil { + return err + } + } + } + + // Wrap runtime errors to prevent duplicate error messages and usage display + if lastErr != nil { + return RuntimeError{Err: lastErr} + } + return nil +} + +// runUserCommand handles built-in session commands +// TODO: This is a duplication of builtInSessionCommands() in pkg/tui/tui.go +func runUserCommand(userInput string, sess *session.Session, rt runtime.Runtime, ctx context.Context) (bool, error) { + switch userInput { + case "/exit": + os.Exit(0) + case "/eval": + evalFile, err := evaluation.Save(sess) + if err == nil { + fmt.Printf("%s\n", Yellow("Evaluation saved to file %s", evalFile)) + return true, err + } + return true, nil + case "/usage": + fmt.Printf("%s\n", Yellow("Input tokens: %d", sess.InputTokens)) + fmt.Printf("%s\n", Yellow("Output tokens: %d", sess.OutputTokens)) + return true, nil + case "/new": + // Reset session items + sess.Messages = []session.Item{} + return true, nil + case "/compact": + // Generate a summary of the session and compact the history + fmt.Printf("%s\n", Yellow("Generating summary...")) + + // Create a channel to capture summary events + events := make(chan runtime.Event, 100) + + // Generate the summary + rt.Summarize(ctx, sess, events) + + // Process events and show the summary + close(events) + summaryGenerated := false + hasWarning := false + for event := range events { + switch e := event.(type) { + case *runtime.SessionSummaryEvent: + fmt.Printf("%s\n", Yellow("Summary generated and added to session")) + fmt.Printf("Summary: %s\n", e.Summary) + summaryGenerated = true + case *runtime.WarningEvent: + fmt.Printf("%s\n", Yellow("Warning: "+e.Message)) + hasWarning = true + } + } + + if !summaryGenerated && !hasWarning { + fmt.Printf("%s\n", Yellow("No summary generated")) + } + + return true, nil + } + + return false, nil +} + +// parseAttachCommand parses user input for /attach commands +// Returns the message text (with /attach commands removed) and the attachment path +func parseAttachCommand(userInput string) (messageText, attachPath string) { + lines := strings.Split(userInput, "\n") + var messageLines []string + + for _, line := range lines { + // Look for /attach anywhere in the line + attachIndex := strings.Index(line, "/attach ") + if attachIndex != -1 { + // Extract the part before /attach + beforeAttach := line[:attachIndex] + + // Extract the part after /attach (starting after "/attach ") + afterAttachStart := attachIndex + 8 // Length of "/attach " + if afterAttachStart < len(line) { + afterAttach := line[afterAttachStart:] + + // Split on spaces to get the file path (first token) and any remaining text + tokens := strings.Fields(afterAttach) + if len(tokens) > 0 { + attachPath = tokens[0] + + // Reconstruct the line with /attach and file path removed + var remainingText string + if len(tokens) > 1 { + remainingText = strings.Join(tokens[1:], " ") + } + + // Combine the text before /attach and any text after the file path + var parts []string + if strings.TrimSpace(beforeAttach) != "" { + parts = append(parts, strings.TrimSpace(beforeAttach)) + } + if remainingText != "" { + parts = append(parts, remainingText) + } + reconstructedLine := strings.Join(parts, " ") + if reconstructedLine != "" { + messageLines = append(messageLines, reconstructedLine) + } + } + } + } else { + // Keep lines without /attach commands + messageLines = append(messageLines, line) + } + } + + // Join the message lines back together + messageText = strings.TrimSpace(strings.Join(messageLines, "\n")) + return messageText, attachPath +} + +// createUserMessageWithAttachment creates a user message with optional image attachment +func createUserMessageWithAttachment(agentFilename, userContent, attachmentPath string) *session.Message { + if attachmentPath == "" { + return session.UserMessage(agentFilename, userContent) + } + + // Convert file to data URL + dataURL, err := fileToDataURL(attachmentPath) + if err != nil { + fmt.Printf("Warning: Failed to attach file %s: %v\n", attachmentPath, err) + return session.UserMessage(agentFilename, userContent) + } + + // Ensure we have some text content when attaching a file + textContent := userContent + if strings.TrimSpace(textContent) == "" { + textContent = "Please analyze this attached file." + } + + // Create message with multi-content including text and image + multiContent := []chat.MessagePart{ + { + Type: chat.MessagePartTypeText, + Text: textContent, + }, + { + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: dataURL, + Detail: chat.ImageURLDetailAuto, + }, + }, + } + + return session.UserMessage(agentFilename, "", multiContent...) +} + +// fileToDataURL converts a file to a data URL +func fileToDataURL(filePath string) (string, error) { + // Check if file exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return "", fmt.Errorf("file does not exist: %s", filePath) + } + + // Read file content + fileBytes, err := os.ReadFile(filePath) + if err != nil { + return "", fmt.Errorf("failed to read file: %w", err) + } + + // Determine MIME type based on file extension + ext := strings.ToLower(filepath.Ext(filePath)) + var mimeType string + switch ext { + case ".jpg", ".jpeg": + mimeType = "image/jpeg" + case ".png": + mimeType = "image/png" + case ".gif": + mimeType = "image/gif" + case ".webp": + mimeType = "image/webp" + case ".bmp": + mimeType = "image/bmp" + case ".svg": + mimeType = "image/svg+xml" + default: + return "", fmt.Errorf("unsupported image format: %s", ext) + } + + // Encode to base64 + encoded := base64.StdEncoding.EncodeToString(fileBytes) + + // Create data URL + dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, encoded) + + return dataURL, nil +} diff --git a/cmd/root/run_text_utils.go b/pkg/cli/text.go similarity index 86% rename from cmd/root/run_text_utils.go rename to pkg/cli/text.go index f679ead7d..8d6b425ab 100644 --- a/cmd/root/run_text_utils.go +++ b/pkg/cli/text.go @@ -1,4 +1,4 @@ -package root +package cli import ( "context" @@ -27,7 +27,7 @@ var ( bold = color.New(color.Bold).SprintfFunc() ) -// confirmation result types +// ConfirmationResult represents the result of a user confirmation prompt type ConfirmationResult string const ( @@ -37,21 +37,33 @@ const ( ConfirmationAbort ConfirmationResult = "abort" ) -// text utility functions +// Color formatting functions (exported for use by other packages) +var ( + Blue = blue + Yellow = yellow + Red = red + White = white + Green = green + Bold = bold +) -func printWelcomeMessage() { - fmt.Printf("\n%s\n%s\n\n", blue("------- Welcome to %s! -------", bold(AppName)), white("(Ctrl+C to stop the agent and exit)")) +// PrintWelcomeMessage prints the welcome message +func PrintWelcomeMessage(appName string) { + fmt.Printf("\n%s\n%s\n\n", blue("------- Welcome to %s! -------", bold(appName)), white("(Ctrl+C to stop the agent and exit)")) } -func printError(err error) { +// PrintError prints an error message +func PrintError(err error) { fmt.Println(red("❌ %s", err)) } -func printAgentName(agentName string) { +// PrintAgentName prints the agent name header +func PrintAgentName(agentName string) { fmt.Printf("\n%s\n", blue("--- Agent: %s ---", bold(agentName))) } -func printToolCall(toolCall tools.ToolCall, colorFunc ...func(format string, a ...any) string) { +// PrintToolCall prints a tool call +func PrintToolCall(toolCall tools.ToolCall, colorFunc ...func(format string, a ...any) string) { c := white if len(colorFunc) > 0 && colorFunc[0] != nil { c = colorFunc[0] @@ -59,9 +71,10 @@ func printToolCall(toolCall tools.ToolCall, colorFunc ...func(format string, a . fmt.Printf("\nCalling %s\n", c("%s%s", bold(toolCall.Function.Name), formatToolCallArguments(toolCall.Function.Arguments))) } -func printToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall, rd io.Reader) ConfirmationResult { +// PrintToolCallWithConfirmation prints a tool call and prompts for confirmation +func PrintToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall, rd io.Reader) ConfirmationResult { fmt.Printf("\n%s\n", bold(yellow("🛠️ Tool call requires confirmation 🛠️"))) - printToolCall(toolCall, color.New(color.FgWhite).SprintfFunc()) + PrintToolCall(toolCall, color.New(color.FgWhite).SprintfFunc()) fmt.Printf("\n%s", bold(yellow("Can I run this tool? ([y]es/[a]ll/[n]o): "))) // Try single-character input from stdin in raw mode (no Enter required) @@ -116,11 +129,13 @@ func printToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall, } } -func printToolCallResponse(toolCall tools.ToolCall, response string) { +// PrintToolCallResponse prints a tool call response +func PrintToolCallResponse(toolCall tools.ToolCall, response string) { fmt.Printf("\n%s\n", white("%s response%s", bold(toolCall.Function.Name), formatToolCallResponse(response))) } -func promptMaxIterationsContinue(ctx context.Context, maxIterations int) ConfirmationResult { +// PromptMaxIterationsContinue prompts the user to continue after max iterations +func PromptMaxIterationsContinue(ctx context.Context, maxIterations int) ConfirmationResult { fmt.Printf("\n%s\n", yellow("⚠️ Maximum iterations (%d) reached. The agent may be stuck in a loop.", maxIterations)) fmt.Printf("%s\n", white("This can happen with smaller or less capable models.")) fmt.Printf("\n%s (y/n): ", blue("Do you want to continue for 10 more iterations?")) @@ -141,7 +156,8 @@ func promptMaxIterationsContinue(ctx context.Context, maxIterations int) Confirm } } -func promptOAuthAuthorization(ctx context.Context, serverURL string) ConfirmationResult { +// PromptOAuthAuthorization prompts the user for OAuth authorization +func PromptOAuthAuthorization(ctx context.Context, serverURL string) ConfirmationResult { fmt.Printf("\n%s\n", yellow("🔐 OAuth Authorization Required")) fmt.Printf("%s %s (remote)\n", white("Server:"), blue(serverURL)) fmt.Printf("%s\n", white("This server requires OAuth authentication to access its tools."))