From a46f4e02aff8b5e2355c2afc1dd904ccf4731099 Mon Sep 17 00:00:00 2001 From: Christopher Petito Date: Wed, 4 Feb 2026 23:57:56 +0100 Subject: [PATCH] Persist ACP sessions to default sqlite db unless specified with --session-db flag Generally makes agent and session setup follow more the patterns used for the api and run commands Also forces tilde expansion on the run and api commands when passing the session-db as --session-db=~/somepath or --session-db "~/somepath" Signed-off-by: Christopher Petito --- cmd/root/acp.go | 13 +++- cmd/root/api.go | 8 +- cmd/root/run.go | 8 +- cmd/root/tilde_test.go | 92 ++++++++++++++++++++++ pkg/acp/agent.go | 81 +++++++++++++++---- pkg/acp/agent_test.go | 173 +++++++++++++++++++++++++++++++++++++++++ pkg/acp/run.go | 18 ++++- 7 files changed, 370 insertions(+), 23 deletions(-) create mode 100644 cmd/root/tilde_test.go create mode 100644 pkg/acp/agent_test.go diff --git a/cmd/root/acp.go b/cmd/root/acp.go index ba069da25..07c2b02fd 100644 --- a/cmd/root/acp.go +++ b/cmd/root/acp.go @@ -1,15 +1,19 @@ package root import ( + "path/filepath" + "github.com/spf13/cobra" "github.com/docker/cagent/pkg/acp" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/paths" "github.com/docker/cagent/pkg/telemetry" ) type acpFlags struct { runConfig config.RuntimeConfig + sessionDB string } func newACPCmd() *cobra.Command { @@ -28,6 +32,7 @@ func newACPCmd() *cobra.Command { } addRuntimeConfigFlags(cmd, &flags.runConfig) + cmd.Flags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database") return cmd } @@ -38,5 +43,11 @@ func (f *acpFlags) runACPCommand(cmd *cobra.Command, args []string) error { ctx := cmd.Context() agentFilename := args[0] - return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return err + } + + return acp.Run(ctx, agentFilename, cmd.InOrStdin(), cmd.OutOrStdout(), &f.runConfig, sessionDB) } diff --git a/cmd/root/api.go b/cmd/root/api.go index 7296fa70d..c50d971d8 100644 --- a/cmd/root/api.go +++ b/cmd/root/api.go @@ -136,7 +136,13 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error { slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String()) - sessionStore, err := session.NewSQLiteSessionStore(f.sessionDB) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return err + } + + sessionStore, err := session.NewSQLiteSessionStore(sessionDB) if err != nil { return fmt.Errorf("creating session store: %w", err) } diff --git a/cmd/root/run.go b/cmd/root/run.go index 604d34683..552d330ff 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -334,7 +334,13 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes return nil, nil, err } - sessStore, err := session.NewSQLiteSessionStore(f.sessionDB) + // Expand tilde in session database path + sessionDB, err := expandTilde(f.sessionDB) + if err != nil { + return nil, nil, err + } + + sessStore, err := session.NewSQLiteSessionStore(sessionDB) if err != nil { return nil, nil, fmt.Errorf("creating session store: %w", err) } diff --git a/cmd/root/tilde_test.go b/cmd/root/tilde_test.go new file mode 100644 index 000000000..494e3d57a --- /dev/null +++ b/cmd/root/tilde_test.go @@ -0,0 +1,92 @@ +package root + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/paths" +) + +func TestExpandTilde(t *testing.T) { + t.Parallel() + + homeDir := paths.GetHomeDir() + require.NotEmpty(t, homeDir, "Home directory should be available for tests") + + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "expands_tilde_prefix", + input: "~/session.db", + expected: filepath.Join(homeDir, "session.db"), + }, + { + name: "expands_tilde_with_nested_path", + input: "~/.cagent/session.db", + expected: filepath.Join(homeDir, ".cagent", "session.db"), + }, + { + name: "expands_tilde_with_deep_path", + input: "~/path/to/some/file.db", + expected: filepath.Join(homeDir, "path", "to", "some", "file.db"), + }, + { + name: "absolute_path_unchanged", + input: "/absolute/path/session.db", + expected: "/absolute/path/session.db", + }, + { + name: "relative_path_unchanged", + input: "relative/path/session.db", + expected: "relative/path/session.db", + }, + { + name: "tilde_in_middle_unchanged", + input: "/some/~/path/session.db", + expected: "/some/~/path/session.db", + }, + { + name: "tilde_without_slash_unchanged", + input: "~something", + expected: "~something", + }, + { + name: "just_tilde_slash_expands", + input: "~/", + expected: homeDir, + }, + { + name: "empty_string_unchanged", + input: "", + expected: "", + }, + { + name: "dot_path_unchanged", + input: "./session.db", + expected: "./session.db", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := expandTilde(tt.input) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/acp/agent.go b/pkg/acp/agent.go index b516480ec..f747b21d3 100644 --- a/pkg/acp/agent.go +++ b/pkg/acp/agent.go @@ -6,13 +6,13 @@ import ( "encoding/json" "fmt" "log/slog" + "os" "path/filepath" "slices" "strings" "sync" "github.com/coder/acp-go-sdk" - "github.com/google/uuid" "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/runtime" @@ -26,9 +26,10 @@ import ( // Agent implements the ACP Agent interface for cagent type Agent struct { - agentSource config.Source - runConfig *config.RuntimeConfig - sessions map[string]*Session + agentSource config.Source + runConfig *config.RuntimeConfig + sessionStore session.Store + sessions map[string]*Session conn *acp.AgentSideConnection team *team.Team @@ -47,11 +48,12 @@ type Session struct { } // NewAgent creates a new ACP agent -func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig) *Agent { +func NewAgent(agentSource config.Source, runConfig *config.RuntimeConfig, sessionStore session.Store) *Agent { return &Agent{ - agentSource: agentSource, - runConfig: runConfig, - sessions: make(map[string]*Session), + agentSource: agentSource, + runConfig: runConfig, + sessionStore: sessionStore, + sessions: make(map[string]*Session), } } @@ -108,30 +110,75 @@ func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (a } // NewSession implements [acp.Agent] -func (a *Agent) NewSession(_ context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { - sid := uuid.New().String() - slog.Debug("ACP NewSession called", "session_id", sid, "cwd", params.Cwd) +func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { + slog.Debug("ACP NewSession called", "cwd", params.Cwd) // Log warning if MCP servers are provided (not yet supported) if len(params.McpServers) > 0 { slog.Warn("MCP servers provided by client are not yet supported", "count", len(params.McpServers)) } - rt, err := runtime.New(a.team, runtime.WithCurrentAgent("root")) + // Validate and normalize working directory + var workingDir string + if wd := strings.TrimSpace(params.Cwd); wd != "" { + absWd, err := filepath.Abs(wd) + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("invalid working directory: %w", err) + } + info, err := os.Stat(absWd) + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("working directory does not exist: %w", err) + } + if !info.IsDir() { + return acp.NewSessionResponse{}, fmt.Errorf("working directory must be a directory") + } + workingDir = absWd + } + + rt, err := runtime.New(a.team, + runtime.WithCurrentAgent("root"), + runtime.WithSessionStore(a.sessionStore), + ) if err != nil { return acp.NewSessionResponse{}, fmt.Errorf("failed to create runtime: %w", err) } + // Get root agent config for session settings + rootAgent, err := a.team.Agent("root") + if err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("failed to get root agent: %w", err) + } + + // Build session options (title will be set after we have the session ID) + sessOpts := []session.Opt{ + session.WithMaxIterations(rootAgent.MaxIterations()), + session.WithThinking(rootAgent.ThinkingConfigured()), + } + if workingDir != "" { + sessOpts = append(sessOpts, session.WithWorkingDir(workingDir)) + } + + // Create session - use its auto-generated ID + sess := session.New(sessOpts...) + sess.Title = "ACP Session " + sess.ID + + // Persist session to the store + if err := a.sessionStore.AddSession(ctx, sess); err != nil { + return acp.NewSessionResponse{}, fmt.Errorf("failed to persist session: %w", err) + } + + slog.Debug("ACP session created", "session_id", sess.ID) + a.mu.Lock() - a.sessions[sid] = &Session{ - id: sid, - sess: session.New(session.WithTitle("ACP Session " + sid)), + a.sessions[sess.ID] = &Session{ + id: sess.ID, + sess: sess, rt: rt, - workingDir: params.Cwd, + workingDir: workingDir, } a.mu.Unlock() - return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil + return acp.NewSessionResponse{SessionId: acp.SessionId(sess.ID)}, nil } // Authenticate implements [acp.Agent] diff --git a/pkg/acp/agent_test.go b/pkg/acp/agent_test.go new file mode 100644 index 000000000..8d701d883 --- /dev/null +++ b/pkg/acp/agent_test.go @@ -0,0 +1,173 @@ +package acp + +import ( + "context" + "io" + "path/filepath" + "testing" + + acpsdk "github.com/coder/acp-go-sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/agent" + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/model/provider/base" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/team" + "github.com/docker/cagent/pkg/tools" +) + +// mockStream simulates a chat completion stream for testing. +type mockStream struct { + responses []chat.MessageStreamResponse + idx int +} + +func (m *mockStream) Recv() (chat.MessageStreamResponse, error) { + if m.idx >= len(m.responses) { + return chat.MessageStreamResponse{}, io.EOF + } + resp := m.responses[m.idx] + m.idx++ + return resp, nil +} + +func (m *mockStream) Close() {} + +// mockProvider returns a predetermined stream for testing. +type mockProvider struct { + id string + stream chat.MessageStream +} + +func (m *mockProvider) ID() string { return m.id } + +func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return m.stream, nil +} + +func (m *mockProvider) BaseConfig() base.Config { return base.Config{} } + +func (m *mockProvider) MaxTokens() int { return 0 } + +// TestACPSessionPersistence verifies that ACP sessions are persisted to the SQLite store. +func TestACPSessionPersistence(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Create a temp SQLite session DB + dbPath := filepath.Join(t.TempDir(), "session.db") + sessStore, err := session.NewSQLiteSessionStore(dbPath) + require.NoError(t, err) + + // Close the store at the end + if closer, ok := sessStore.(io.Closer); ok { + defer closer.Close() + } + + // Create a mock provider that returns a simple assistant message + stream := &mockStream{ + responses: []chat.MessageStreamResponse{ + { + Choices: []chat.MessageStreamChoice{{ + Index: 0, + Delta: chat.MessageDelta{Content: "Hello from the agent!"}, + }}, + }, + { + Choices: []chat.MessageStreamChoice{{ + Index: 0, + FinishReason: chat.FinishReasonStop, + }}, + Usage: &chat.Usage{InputTokens: 10, OutputTokens: 5}, + }, + }, + } + prov := &mockProvider{id: "test/mock-model", stream: stream} + + // Create a minimal team with a root agent + root := agent.New("root", "You are a test agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + + // Create the ACP agent with the session store + // Note: we set team directly to avoid Initialize requiring full config loading + acpAgent := &Agent{ + agentSource: nil, // Not needed since team is pre-set + runConfig: &config.RuntimeConfig{}, + sessionStore: sessStore, + sessions: make(map[string]*Session), + team: tm, + } + + // Create a new session via ACP with a real temp directory + workingDir := t.TempDir() + newSessResp, err := acpAgent.NewSession(ctx, acpsdk.NewSessionRequest{ + Cwd: workingDir, + }) + require.NoError(t, err) + acpSessionID := string(newSessResp.SessionId) + require.NotEmpty(t, acpSessionID) + + // Get the session and add a user message + acpAgent.mu.Lock() + acpSess := acpAgent.sessions[acpSessionID] + acpAgent.mu.Unlock() + require.NotNil(t, acpSess) + + // Use the actual session ID for lookups (should match the ACP session ID after fix) + sessionID := acpSess.sess.ID + + // Add user message to the session + acpSess.sess.AddMessage(session.UserMessage("Hello, agent!")) + + // Run the runtime directly (bypasses ACP connection which we don't have in test) + // This tests that the session store is properly used by the runtime + eventsChan := acpSess.rt.RunStream(ctx, acpSess.sess) + + // Drain events + for range eventsChan { + // Just consume all events + } + + // Verify the session is persisted via GetSessionSummaries + summaries, err := sessStore.GetSessionSummaries(ctx) + require.NoError(t, err) + + // Find our session in the summaries + var found bool + for _, s := range summaries { + if s.ID == sessionID { + found = true + assert.Contains(t, s.Title, "ACP Session") + break + } + } + assert.True(t, found, "ACP session should appear in GetSessionSummaries") + + // Also verify full session retrieval + loadedSess, err := sessStore.GetSession(ctx, sessionID) + require.NoError(t, err) + assert.Equal(t, sessionID, loadedSess.ID) + assert.Contains(t, loadedSess.Title, "ACP Session") + assert.Equal(t, workingDir, loadedSess.WorkingDir) + + // Verify messages were persisted (user + assistant) + assert.GreaterOrEqual(t, len(loadedSess.Messages), 2, "Session should have at least user and assistant messages") + + // Find user message + var hasUserMsg, hasAssistantMsg bool + for _, item := range loadedSess.Messages { + if item.Message != nil { + if item.Message.Message.Role == chat.MessageRoleUser { + hasUserMsg = true + } + if item.Message.Message.Role == chat.MessageRoleAssistant { + hasAssistantMsg = true + } + } + } + assert.True(t, hasUserMsg, "Session should have a user message") + assert.True(t, hasAssistantMsg, "Session should have an assistant message") +} diff --git a/pkg/acp/run.go b/pkg/acp/run.go index 4c5a3563e..84f57b216 100644 --- a/pkg/acp/run.go +++ b/pkg/acp/run.go @@ -2,23 +2,35 @@ package acp import ( "context" + "fmt" "io" "log/slog" acpsdk "github.com/coder/acp-go-sdk" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/session" ) -func Run(ctx context.Context, agentFilename string, stdin io.Reader, stdout io.Writer, runConfig *config.RuntimeConfig) error { - slog.Debug("Starting ACP server", "agent", agentFilename) +func Run(ctx context.Context, agentFilename string, stdin io.Reader, stdout io.Writer, runConfig *config.RuntimeConfig, sessionDB string) error { + slog.Debug("Starting ACP server", "agent", agentFilename, "session_db", sessionDB) agentSource, err := config.Resolve(agentFilename, nil) if err != nil { return err } - acpAgent := NewAgent(agentSource, runConfig) + // Create SQLite session store for persistent sessions + sessStore, err := session.NewSQLiteSessionStore(sessionDB) + if err != nil { + return fmt.Errorf("creating session store: %w", err) + } + // Close the store on shutdown if it implements io.Closer + if closer, ok := sessStore.(io.Closer); ok { + defer closer.Close() + } + + acpAgent := NewAgent(agentSource, runConfig, sessStore) conn := acpsdk.NewAgentSideConnection(acpAgent, stdout, stdin) conn.SetLogger(slog.Default()) acpAgent.SetAgentConnection(conn)