diff --git a/cmd/root/acp.go b/cmd/root/acp.go index ba069da25..07c2b02fd 100644 --- a/cmd/root/acp.go +++ b/cmd/root/acp.go @@ -1,15 +1,19 @@ package root import ( + "path/filepath" + "github.com/spf13/cobra" "github.com/docker/cagent/pkg/acp" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/paths" "github.com/docker/cagent/pkg/telemetry" ) type acpFlags struct { runConfig config.RuntimeConfig + sessionDB string } func newACPCmd() *cobra.Command { @@ -28,6 +32,7 @@ func newACPCmd() *cobra.Command { } addRuntimeConfigFlags(cmd, &flags.runConfig) + cmd.Flags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database") return cmd } @@ -38,5 +43,11 @@ func (f *acpFlags) runACPCommand(cmd *cobra.Command, args []string) error { ctx := cmd.Context() agentFilename := args[0] - return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return err + } + + return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig, sessionDB) } diff --git a/cmd/root/api.go b/cmd/root/api.go index 7296fa70d..c50d971d8 100644 --- a/cmd/root/api.go +++ b/cmd/root/api.go @@ -136,7 +136,13 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error { slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String()) - sessionStore, err := session.NewSQLiteSessionStore(f.sessionDB) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return err + } + + sessionStore, err := session.NewSQLiteSessionStore(sessionDB) if err != nil { return fmt.Errorf("creating session store: %w", err) } diff --git a/cmd/root/run.go b/cmd/root/run.go index 604d34683..552d330ff 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -334,7 +334,13 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes return nil, nil, err } - sessStore, err := session.NewSQLiteSessionStore(f.sessionDB) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return nil, nil, err + } + + sessStore, err := session.NewSQLiteSessionStore(sessionDB) if err != nil { return nil, nil, fmt.Errorf("creating session store: %w", err) } diff --git a/cmd/root/tilde_test.go b/cmd/root/tilde_test.go new file mode 100644 index 000000000..494e3d57a --- /dev/null +++ b/cmd/root/tilde_test.go @@ -0,0 +1,92 @@ +package root + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/paths" +) + +func TestExpandTilde(t *testing.T) { + t.Parallel() + + homeDir := paths.GetHomeDir() + require.NotEmpty(t, homeDir, "Home directory should be available for tests") + + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "expands_tilde_prefix", + input: "~/session.db", + expected: filepath.Join(homeDir, "session.db"), + }, + { + name: "expands_tilde_with_nested_path", + input: "~/.cagent/session.db", + expected: filepath.Join(homeDir, ".cagent", "session.db"), + }, + { + name: "expands_tilde_with_deep_path", + input: "~/path/to/some/file.db", + expected: filepath.Join(homeDir, "path", "to", "some", "file.db"), + }, + { + name: "absolute_path_unchanged", + input: "/absolute/path/session.db", + expected: "/absolute/path/session.db", + }, + { + name: "relative_path_unchanged", + input: "relative/path/session.db", + expected: "relative/path/session.db", + }, + { + name: "tilde_in_middle_unchanged", + input: "/some/~/path/session.db", + expected: "/some/~/path/session.db", + }, + { + name: "tilde_without_slash_unchanged", + input: "~something", + expected: "~something", + }, + { + name: "just_tilde_slash_expands", + input: "~/", + expected: homeDir, + }, + { + name: "empty_string_unchanged", + input: "", + expected: "", + }, + { + name: "dot_path_unchanged", + input: "./session.db", + expected: "./session.db", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := expandTilde(tt.input) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/acp/agent.go b/pkg/acp/agent.go index b516480ec..f747b21d3 100644 --- a/pkg/acp/agent.go +++ b/pkg/acp/agent.go @@ -6,13 +6,13 @@ import ( "encoding/json" "fmt" "log/slog" + "os" "path/filepath" "slices" "strings" "sync" "github.com/coder/acp-go-sdk" - "github.com/google/uuid" "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/runtime" @@ -26,9 +26,10 @@ import ( // Agent implements the ACP Agent interface for cagent type Agent struct { - agentSource config.Source - runConfig *config.RuntimeConfig - sessions map[string]*Session + agentSource config.Source + runConfig *config.RuntimeConfig + sessionStore session.Store + sessions map[string]*Session conn *acp.AgentSideConnection team *team.Team @@ -47,11 +48,12 @@ type Session struct { } // NewAgent creates a new ACP agent -func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig) *Agent { +func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig, sessionStore session.Store) *Agent { return &Agent{ - agentSource: agentSource, - runConfig: runConfig, - sessions: make(map[string]*Session), + agentSource: agentSource, + runConfig: runConfig, + sessionStore: sessionStore, + sessions: make(map[string]*Session), } } @@ -108,30 +110,75 @@ func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (a } // NewSession implements [acp.Agent] -func (a *Agent) NewSession(_ context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { - sid := uuid.New().String() - slog.Debug("ACP NewSession called", "session_id", sid, "cwd", params.Cwd) +func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { + slog.Debug("ACP NewSession called", "cwd", params.Cwd) // Log warning if MCP servers are provided (not yet supported) if len(params.McpServers) > 0 { slog.Warn("MCP servers provided by client are not yet supported", "count", len(params.McpServers)) } - rt, err := runtime.New(a.team, runtime.WithCurrentAgent("root")) + // Validate and normalize working directory + var workingDir string + if wd := strings.TrimSpace(params.Cwd); wd != "" { + absWd, err := filepath.Abs(wd) + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("invalid working directory: %w", err) + } + info, err := os.Stat(absWd) + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("working directory does not exist: %w", err) + } + if !info.IsDir() { + return acp.NewSessionResponse{}, fmt.Errorf("working directory must be a directory") + } + workingDir = absWd + } + + rt, err := runtime.New(a.team, + runtime.WithCurrentAgent("root"), + runtime.WithSessionStore(a.sessionStore), + ) if err != nil { return acp.NewSessionResponse{}, fmt.Errorf("failed to create runtime: %w", err) } + // Get root agent config for session settings + rootAgent, err := a.team.Agent("root") + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("failed to get root agent: %w", err) + } + + // Build session options (title will be set after we have the session ID) + sessOpts := []session.Opt{ + session.WithMaxIterations(rootAgent.MaxIterations()), + session.WithThinking(rootAgent.ThinkingConfigured()), + } + if workingDir != "" { + sessOpts = append(sessOpts, session.WithWorkingDir(workingDir)) + } + + // Create session - use its auto-generated ID + sess := session.New(sessOpts...) + sess.Title = "ACP Session " + sess.ID + + // Persist session to the store + if err := a.sessionStore.AddSession(ctx, sess); err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("failed to persist session: %w", err) + } + + slog.Debug("ACP session created", "session_id", sess.ID) + a.mu.Lock() - a.sessions[sid] = &Session{ - id: sid, - sess: session.New(session.WithTitle("ACP Session " + sid)), + a.sessions[sess.ID] = &Session{ + id: sess.ID, + sess: sess, rt: rt, - workingDir: params.Cwd, + workingDir: workingDir, } a.mu.Unlock() - return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil + return acp.NewSessionResponse{SessionId: acp.SessionId(sess.ID)}, nil } // Authenticate implements [acp.Agent] diff --git a/pkg/acp/agent_test.go b/pkg/acp/agent_test.go new file mode 100644 index 000000000..8d701d883 --- /dev/null +++ b/pkg/acp/agent_test.go @@ -0,0 +1,173 @@ +package acp + +import ( + "context" + "io" + "path/filepath" + "testing" + + acpsdk "github.com/coder/acp-go-sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/agent" + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/model/provider/base" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/team" + "github.com/docker/cagent/pkg/tools" +) + +// mockStream simulates a chat completion stream for testing. +type mockStream struct { + responses []chat.MessageStreamResponse + idx int +} + +func (m *mockStream) Recv() (chat.MessageStreamResponse, error) { + if m.idx >= len(m.responses) { + return chat.MessageStreamResponse{}, io.EOF + } + resp := m.responses[m.idx] + m.idx++ + return resp, nil +} + +func (m *mockStream) Close() {} + +// mockProvider returns a predetermined stream for testing. +type mockProvider struct { + id string + stream chat.MessageStream +} + +func (m *mockProvider) ID() string { return m.id } + +func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return m.stream, nil +} + +func (m *mockProvider) BaseConfig() base.Config { return base.Config{} } + +func (m *mockProvider) MaxTokens() int { return 0 } + +// TestACPSessionPersistence verifies that ACP sessions are persisted to the SQLite store. +func TestACPSessionPersistence(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Create a temp SQLite session DB + dbPath := filepath.Join(t.TempDir(), "session.db") + sessStore, err := session.NewSQLiteSessionStore(dbPath) + require.NoError(t, err) + + // Close the store at the end + if closer, ok := sessStore.(io.Closer); ok { + defer closer.Close() + } + + // Create a mock provider that returns a simple assistant message + stream := &mockStream{ + responses: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{{ + Index: 0, + Delta: chat.MessageDelta{Content: "Hello from the agent!"}, + }}, + }, + { + Choices: []chat.MessageStreamChoice{{ + Index: 0, + FinishReason: chat.FinishReasonStop, + }}, + Usage: &chat.Usage{InputTokens: 10, OutputTokens: 5}, + }, + }, + } + prov := &mockProvider{id: "test/mock-model", stream: stream} + + // Create a minimal team with a root agent + root := agent.New("root", "You are a test agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + + // Create the ACP agent with the session store + // Note: we set team directly to avoid Initialize requiring full config loading + acpAgent := &Agent{ + agentSource: nil, // Not needed since team is pre-set + runConfig: &config.RuntimeConfig{}, + sessionStore: sessStore, + sessions: make(map[string]*Session), + team: tm, + } + + // Create a new session via ACP with a real temp directory + workingDir := t.TempDir() + newSessResp, err := acpAgent.NewSession(ctx, acpsdk.NewSessionRequest{ + Cwd: workingDir, + }) + require.NoError(t, err) + acpSessionID := string(newSessResp.SessionId) + require.NotEmpty(t, acpSessionID) + + // Get the session and add a user message + acpAgent.mu.Lock() + acpSess := acpAgent.sessions[acpSessionID] + acpAgent.mu.Unlock() + require.NotNil(t, acpSess) + + // Use the actual session ID for lookups (should match the ACP session ID after fix) + sessionID := acpSess.sess.ID + + // Add user message to the session + acpSess.sess.AddMessage(session.UserMessage("Hello, agent!")) + + // Run the runtime directly (bypasses ACP connection which we don't have in test) + // This tests that the session store is properly used by the runtime + eventsChan := acpSess.rt.RunStream(ctx, acpSess.sess) + + // Drain events + for range eventsChan { + // Just consume all events + } + + // Verify the session is persisted via GetSessionSummaries + summaries, err := sessStore.GetSessionSummaries(ctx) + require.NoError(t, err) + + // Find our session in the summaries + var found bool + for _, s := range summaries { + if s.ID == sessionID { + found = true + assert.Contains(t, s.Title, "ACP Session") + break + } + } + assert.True(t, found, "ACP session should appear in GetSessionSummaries") + + // Also verify full session retrieval + loadedSess, err := sessStore.GetSession(ctx, sessionID) + require.NoError(t, err) + assert.Equal(t, sessionID, loadedSess.ID) + assert.Contains(t, loadedSess.Title, "ACP Session") + assert.Equal(t, workingDir, loadedSess.WorkingDir) + + // Verify messages were persisted (user + assistant) + assert.GreaterOrEqual(t, len(loadedSess.Messages), 2, "Session should have at least user and assistant messages") + + // Find user message + var hasUserMsg, hasAssistantMsg bool + for _, item := range loadedSess.Messages { + if item.Message != nil { + if item.Message.Message.Role == chat.MessageRoleUser { + hasUserMsg = true + } + if item.Message.Message.Role == chat.MessageRoleAssistant { + hasAssistantMsg = true + } + } + } + assert.True(t, hasUserMsg, "Session should have a user message") + assert.True(t, hasAssistantMsg, "Session should have an assistant message") +} diff --git a/pkg/acp/run.go b/pkg/acp/run.go index 4c5a3563e..84f57b216 100644 --- a/pkg/acp/run.go +++ b/pkg/acp/run.go @@ -2,23 +2,35 @@ package acp import ( "context" + "fmt" "io" "log/slog" acpsdk "github.com/coder/acp-go-sdk" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/session" ) -func Run(ctx context.Context, agentFilename string, stdin io.Reader, stdout io.Writer, runConfig *config.RuntimeConfig) error { - slog.Debug("Starting ACP server", "agent", agentFilename) +func Run(ctx context.Context, agentFilename string, stdin io.Reader, stdout io.Writer, runConfig *config.RuntimeConfig, sessionDB string) error { + slog.Debug("Starting ACP server", "agent", agentFilename, "session_db", sessionDB) agentSource, err := config.Resolve(agentFilename, nil) if err != nil { return err } - acpAgent := NewAgent(agentSource, runConfig) + // Create SQLite session store for persistent sessions + sessStore, err := session.NewSQLiteSessionStore(sessionDB) + if err != nil { + return fmt.Errorf("creating session store: %w", err) + } + // Close the store on shutdown if it implements io.Closer + if closer, ok := sessStore.(io.Closer); ok { + defer closer.Close() + } + + acpAgent := NewAgent(agentSource, runConfig, sessStore) conn := acpsdk.NewAgentSideConnection(acpAgent, stdout, stdin) conn.SetLogger(slog.Default()) acpAgent.SetAgentConnection(conn)