Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cmd/root/acp.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package root

import (
"path/filepath"

"github.com/spf13/cobra"

"github.com/docker/cagent/pkg/acp"
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/paths"
"github.com/docker/cagent/pkg/telemetry"
)

type acpFlags struct {
runConfig config.RuntimeConfig
sessionDB string
}

func newACPCmd() *cobra.Command {
Expand All @@ -28,6 +32,7 @@ func newACPCmd() *cobra.Command {
}

addRuntimeConfigFlags(cmd, &flags.runConfig)
cmd.Flags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database")

return cmd
}
Expand All @@ -38,5 +43,11 @@ func (f *acpFlags) runACPCommand(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
agentFilename := args[0]

return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig)
// Expand tilde in session database path
sessionDB, err := expandTilde(f.sessionDB)
if err != nil {
return err
}

return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig, sessionDB)
}
8 changes: 7 additions & 1 deletion cmd/root/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {

slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String())

sessionStore, err := session.NewSQLiteSessionStore(f.sessionDB)
// Expand tilde in session database path
sessionDB, err := expandTilde(f.sessionDB)
if err != nil {
return err
}

sessionStore, err := session.NewSQLiteSessionStore(sessionDB)
if err != nil {
return fmt.Errorf("creating session store: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,13 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes
return nil, nil, err
}

sessStore, err := session.NewSQLiteSessionStore(f.sessionDB)
// Expand tilde in session database path
sessionDB, err := expandTilde(f.sessionDB)
if err != nil {
return nil, nil, err
}

sessStore, err := session.NewSQLiteSessionStore(sessionDB)
if err != nil {
return nil, nil, fmt.Errorf("creating session store: %w", err)
}
Expand Down
92 changes: 92 additions & 0 deletions cmd/root/tilde_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package root

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/paths"
)

func TestExpandTilde(t *testing.T) {
t.Parallel()

homeDir := paths.GetHomeDir()
require.NotEmpty(t, homeDir, "Home directory should be available for tests")

tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "expands_tilde_prefix",
input: "~/session.db",
expected: filepath.Join(homeDir, "session.db"),
},
{
name: "expands_tilde_with_nested_path",
input: "~/.cagent/session.db",
expected: filepath.Join(homeDir, ".cagent", "session.db"),
},
{
name: "expands_tilde_with_deep_path",
input: "~/path/to/some/file.db",
expected: filepath.Join(homeDir, "path", "to", "some", "file.db"),
},
{
name: "absolute_path_unchanged",
input: "/absolute/path/session.db",
expected: "/absolute/path/session.db",
},
{
name: "relative_path_unchanged",
input: "relative/path/session.db",
expected: "relative/path/session.db",
},
{
name: "tilde_in_middle_unchanged",
input: "/some/~/path/session.db",
expected: "/some/~/path/session.db",
},
{
name: "tilde_without_slash_unchanged",
input: "~something",
expected: "~something",
},
{
name: "just_tilde_slash_expands",
input: "~/",
expected: homeDir,
},
{
name: "empty_string_unchanged",
input: "",
expected: "",
},
{
name: "dot_path_unchanged",
input: "./session.db",
expected: "./session.db",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

result, err := expandTilde(tt.input)

if tt.wantErr {
require.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
81 changes: 64 additions & 17 deletions pkg/acp/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"slices"
"strings"
"sync"

"github.com/coder/acp-go-sdk"
"github.com/google/uuid"

"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/runtime"
Expand All @@ -26,9 +26,10 @@ import (

// Agent implements the ACP Agent interface for cagent
type Agent struct {
agentSource config.Source
runConfig *config.RuntimeConfig
sessions map[string]*Session
agentSource config.Source
runConfig *config.RuntimeConfig
sessionStore session.Store
sessions map[string]*Session

conn *acp.AgentSideConnection
team *team.Team
Expand All @@ -47,11 +48,12 @@ type Session struct {
}

// NewAgent creates a new ACP agent
func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig) *Agent {
func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig, sessionStore session.Store) *Agent {
return &Agent{
agentSource: agentSource,
runConfig: runConfig,
sessions: make(map[string]*Session),
agentSource: agentSource,
runConfig: runConfig,
sessionStore: sessionStore,
sessions: make(map[string]*Session),
}
}

Expand Down Expand Up @@ -108,30 +110,75 @@ func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (a
}

// NewSession implements [acp.Agent]
func (a *Agent) NewSession(_ context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) {
sid := uuid.New().String()
slog.Debug("ACP NewSession called", "session_id", sid, "cwd", params.Cwd)
func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) {
slog.Debug("ACP NewSession called", "cwd", params.Cwd)

// Log warning if MCP servers are provided (not yet supported)
if len(params.McpServers) > 0 {
slog.Warn("MCP servers provided by client are not yet supported", "count", len(params.McpServers))
}

rt, err := runtime.New(a.team, runtime.WithCurrentAgent("root"))
// Validate and normalize working directory
var workingDir string
if wd := strings.TrimSpace(params.Cwd); wd != "" {
absWd, err := filepath.Abs(wd)
if err != nil {
return acp.NewSessionResponse{}, fmt.Errorf("invalid working directory: %w", err)
}
info, err := os.Stat(absWd)
if err != nil {
return acp.NewSessionResponse{}, fmt.Errorf("working directory does not exist: %w", err)
}
if !info.IsDir() {
return acp.NewSessionResponse{}, fmt.Errorf("working directory must be a directory")
}
workingDir = absWd
}

rt, err := runtime.New(a.team,
runtime.WithCurrentAgent("root"),
runtime.WithSessionStore(a.sessionStore),
)
if err != nil {
return acp.NewSessionResponse{}, fmt.Errorf("failed to create runtime: %w", err)
}

// Get root agent config for session settings
rootAgent, err := a.team.Agent("root")
if err != nil {
return acp.NewSessionResponse{}, fmt.Errorf("failed to get root agent: %w", err)
}

// Build session options (title will be set after we have the session ID)
sessOpts := []session.Opt{
session.WithMaxIterations(rootAgent.MaxIterations()),
session.WithThinking(rootAgent.ThinkingConfigured()),
}
if workingDir != "" {
sessOpts = append(sessOpts, session.WithWorkingDir(workingDir))
}

// Create session - use its auto-generated ID
sess := session.New(sessOpts...)
sess.Title = "ACP Session " + sess.ID

// Persist session to the store
if err := a.sessionStore.AddSession(ctx, sess); err != nil {
return acp.NewSessionResponse{}, fmt.Errorf("failed to persist session: %w", err)
}

slog.Debug("ACP session created", "session_id", sess.ID)

a.mu.Lock()
a.sessions[sid] = &Session{
id: sid,
sess: session.New(session.WithTitle("ACP Session " + sid)),
a.sessions[sess.ID] = &Session{
id: sess.ID,
sess: sess,
rt: rt,
workingDir: params.Cwd,
workingDir: workingDir,
}
a.mu.Unlock()

return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil
return acp.NewSessionResponse{SessionId: acp.SessionId(sess.ID)}, nil
}

// Authenticate implements [acp.Agent]
Expand Down
Loading
Loading