Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/root/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra"

"github.com/docker/cagent/pkg/creator"
"github.com/docker/cagent/pkg/input"
"github.com/docker/cagent/pkg/runtime"
"github.com/docker/cagent/pkg/telemetry"
)
Expand Down Expand Up @@ -83,7 +84,7 @@ func NewNewCmd() *cobra.Command {
fmt.Print(blue("> "))

var err error
prompt, err = readLine(ctx, os.Stdin)
prompt, err = input.ReadLine(ctx, os.Stdin)
if err != nil {
return fmt.Errorf("failed to read purpose: %w", err)
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/content"
"github.com/docker/cagent/pkg/evaluation"
"github.com/docker/cagent/pkg/input"
"github.com/docker/cagent/pkg/remote"
"github.com/docker/cagent/pkg/runtime"
"github.com/docker/cagent/pkg/session"
Expand Down Expand Up @@ -594,7 +595,7 @@ func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime
fmt.Print(blue("> "))
firstQuestion = false

line, err := readLine(ctx, os.Stdin)
line, err := input.ReadLine(ctx, os.Stdin)
if err != nil {
return err
}
Expand Down Expand Up @@ -665,8 +666,8 @@ func runUserCommand(userInput string, sess *session.Session, rt runtime.Runtime,

// parseAttachCommand parses user input for /attach commands
// Returns the message text (with /attach commands removed) and the attachment path
func parseAttachCommand(input string) (messageText, attachPath string) {
lines := strings.Split(input, "\n")
func parseAttachCommand(userInput string) (messageText, attachPath string) {
lines := strings.Split(userInput, "\n")
var messageLines []string

for _, line := range lines {
Expand Down
35 changes: 4 additions & 31 deletions cmd/root/run_text_utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package root

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand All @@ -12,6 +11,7 @@ import (
"github.com/fatih/color"
"golang.org/x/term"

"github.com/docker/cagent/pkg/input"
"github.com/docker/cagent/pkg/tools"
)

Expand Down Expand Up @@ -98,7 +98,7 @@ func printToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall,
}

// Fallback: line-based scanner (requires Enter)
text, err := readLine(ctx, rd)
text, err := input.ReadLine(ctx, rd)
if err != nil {
return ConfirmationReject
}
Expand All @@ -125,7 +125,7 @@ func promptMaxIterationsContinue(ctx context.Context, maxIterations int) Confirm
fmt.Printf("%s\n", white("This can happen with smaller or less capable models."))
fmt.Printf("\n%s (y/n): ", blue("Do you want to continue for 10 more iterations?"))

response, err := readLine(ctx, os.Stdin)
response, err := input.ReadLine(ctx, os.Stdin)
if err != nil {
fmt.Printf("\n%s\n", red("Failed to read input, exiting..."))
return ConfirmationAbort
Expand All @@ -148,7 +148,7 @@ func promptOAuthAuthorization(ctx context.Context, serverURL string) Confirmatio
fmt.Printf("%s\n", white("Your browser will open automatically to complete the authorization."))
fmt.Printf("\n%s (y/n): ", blue("Do you want to authorize access?"))

response, err := readLine(ctx, os.Stdin)
response, err := input.ReadLine(ctx, os.Stdin)
if err != nil {
fmt.Printf("\n%s\n", red("Failed to read input, aborting authorization..."))
return ConfirmationAbort
Expand Down Expand Up @@ -291,30 +291,3 @@ func formatJSONValue(key string, value any) string {
return fmt.Sprintf("%s: %s", bold(key), string(jsonBytes))
}
}

func readLine(ctx context.Context, rd io.Reader) (string, error) {
lines := make(chan string)
errs := make(chan error)

go func() {
defer close(lines)
defer close(errs)

reader := bufio.NewReader(rd)
line, err := reader.ReadString('\n')
if err != nil {
errs <- err
} else {
lines <- line
}
}()

select {
case <-ctx.Done():
return "", ctx.Err()
case err := <-errs:
return "", err
case line := <-lines:
return line, nil
}
}
34 changes: 34 additions & 0 deletions pkg/input/readline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package input

import (
"bufio"
"context"
"io"
)

func ReadLine(ctx context.Context, rd io.Reader) (string, error) {
lines := make(chan string)
errs := make(chan error)

go func() {
defer close(lines)
defer close(errs)

reader := bufio.NewReader(rd)
line, err := reader.ReadString('\n')
if err != nil {
errs <- err
} else {
lines <- line
}
}()

select {
case <-ctx.Done():
return "", ctx.Err()
case err := <-errs:
return "", err
case line := <-lines:
return line, nil
}
}
64 changes: 64 additions & 0 deletions pkg/model/provider/dmr/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
Expand All @@ -15,9 +16,11 @@ import (
"strings"

"github.com/sashabaranov/go-openai"
"golang.org/x/term"

"github.com/docker/cagent/pkg/chat"
latest "github.com/docker/cagent/pkg/config/v2"
"github.com/docker/cagent/pkg/input"
"github.com/docker/cagent/pkg/model/provider/base"
"github.com/docker/cagent/pkg/model/provider/options"
"github.com/docker/cagent/pkg/tools"
Expand Down Expand Up @@ -51,6 +54,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
endpoint, engine, err := getDockerModelEndpointAndEngine(ctx)
if err != nil {
slog.Debug("docker model status query failed", "error", err)
} else {
// Auto-pull the model if needed
if err := pullDockerModelIfNeeded(ctx, cfg.Model); err != nil {
slog.Debug("docker model pull failed", "error", err)
return nil, err
}
}

clientConfig := openai.DefaultConfig("")
Expand Down Expand Up @@ -430,6 +439,61 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
return contextSize, runtimeFlags
}

func pullDockerModelIfNeeded(ctx context.Context, model string) error {
// Check if running in interactive mode (stdin is a terminal)
interactive := term.IsTerminal(int(os.Stdin.Fd()))
if !interactive {
// In non-interactive mode (CI / Servers), do not attempt to pull the model
return nil
}

if modelExists(ctx, model) {
slog.Debug("Model already exists, skipping pull", "model", model)
return nil
}

// Prompt user for confirmation in interactive mode
fmt.Printf("\nModel %s not found locally.\n", model)
fmt.Printf("Do you want to pull it now? ([y]es/[n]o): ")

response, err := input.ReadLine(ctx, os.Stdin)
if err != nil {
return fmt.Errorf("failed to read user input: %w", err)
}

response = strings.TrimSpace(strings.ToLower(response))
if response != "y" && response != "yes" {
return fmt.Errorf("model pull declined by user")
}

// Pull the model
slog.Info("Pulling DMR model", "model", model)
fmt.Printf("Pulling model %s...\n", model)
cmd := exec.CommandContext(ctx, "docker", "model", "pull", model)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to pull model %s: %w", model, err)
}

slog.Info("Model pulled successfully", "model", model)
fmt.Printf("Model %s pulled successfully.\n", model)

return nil
}

func modelExists(ctx context.Context, model string) bool {
cmd := exec.CommandContext(ctx, "docker", "model", "inspect", model)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no api for this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to go that route... But I might have because it also needs to detect if the platform supports DMR at all.

var stderr bytes.Buffer
cmd.Stdout = io.Discard
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
slog.Debug("Model does not exist", "model", model, "error", strings.TrimSpace(stderr.String()))
return false
}
return true
}

func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)

Expand Down
Loading