From 93eecd8ec697b7d3af4795b04de4ba27a4c3a9aa Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 11:40:06 +0200 Subject: [PATCH 1/3] add /snapshot command to list and reset snapshots --- docs/features/tui/index.md | 12 +- pkg/app/undo.go | 71 ++++++++++ pkg/hooks/builtins/builtins.go | 19 +++ pkg/hooks/builtins/snapshot.go | 87 +++++++++++++ pkg/hooks/builtins/snapshot_test.go | 116 +++++++++++++++++ pkg/runtime/snapshot.go | 43 +++++- pkg/tui/commands/commands.go | 46 +++++++ pkg/tui/commands/commands_test.go | 54 ++++++++ pkg/tui/dialog/snapshot.go | 195 ++++++++++++++++++++++++++++ pkg/tui/handlers.go | 27 ++++ pkg/tui/messages/session.go | 9 ++ pkg/tui/tui.go | 6 + 12 files changed, 681 insertions(+), 4 deletions(-) create mode 100644 pkg/tui/dialog/snapshot.go diff --git a/docs/features/tui/index.md b/docs/features/tui/index.md index a5d758308..196e18d00 100644 --- a/docs/features/tui/index.md +++ b/docs/features/tui/index.md @@ -40,7 +40,8 @@ Type `/` during a session to see available commands, or press Ctrl+Ctrl+↑/ (or j/k) to highlight an entry, then press Enter to reset the workspace to that point. Pick `` to revert every snapshot and bring the workspace back to its pre-agent state. Esc closes the dialog without changing anything. (`/snapshots` is accepted as an alias.) + +Neither command removes messages from the session transcript — they only touch files on disk. Both commands (and the matching command-palette entries) are hidden when snapshots are turned off. Omit `snapshot` or set it to `false` to leave automatic snapshots off; agents can still configure snapshot hooks manually. ## File Attachments diff --git a/pkg/app/undo.go b/pkg/app/undo.go index cb650e7d0..082c0d1f4 100644 --- a/pkg/app/undo.go +++ b/pkg/app/undo.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/session" ) @@ -14,10 +15,37 @@ type UndoSnapshotResult struct { RestoredFiles int } +// SnapshotInfo summarises one snapshot checkpoint for display. +type SnapshotInfo struct { + // Files is the number of files captured in the checkpoint. + Files int +} + type snapshotUndoer interface { UndoLastSnapshot(ctx context.Context, sess *session.Session) (files int, ok bool, err error) } +type snapshotLister interface { + ListSnapshots(sess *session.Session) []builtins.SnapshotInfo +} + +type snapshotResetter interface { + ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (files int, ok bool, err error) +} + +type snapshotsEnabledChecker interface { + SnapshotsEnabled() bool +} + +// SnapshotsEnabled reports whether automatic shadow-git snapshots are active +// for the current runtime. Returns false for runtimes that don't support +// snapshots at all (e.g. remote runtimes) or when the feature is turned off +// in the user config. +func (a *App) SnapshotsEnabled() bool { + checker, ok := a.runtime.(snapshotsEnabledChecker) + return ok && checker.SnapshotsEnabled() +} + func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) { if a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo @@ -35,3 +63,46 @@ func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) } return UndoSnapshotResult{RestoredFiles: files}, nil } + +// ListSnapshots returns the snapshot checkpoints recorded for the current +// session in chronological order (oldest first). Returns nil when the runtime +// does not support snapshots or none have been captured yet. +func (a *App) ListSnapshots() []SnapshotInfo { + if a.session == nil { + return nil + } + lister, ok := a.runtime.(snapshotLister) + if !ok { + return nil + } + raw := lister.ListSnapshots(a.session) + if len(raw) == 0 { + return nil + } + out := make([]SnapshotInfo, len(raw)) + for i, s := range raw { + out[i] = SnapshotInfo{Files: s.Files} + } + return out +} + +// ResetSnapshot reverts every checkpoint past index keep so the workspace +// returns to the state captured at that snapshot. keep == 0 resets to the +// original pre-agent state. Returns ErrNothingToUndo when nothing changes. +func (a *App) ResetSnapshot(ctx context.Context, keep int) (UndoSnapshotResult, error) { + if a.session == nil { + return UndoSnapshotResult{}, ErrNothingToUndo + } + resetter, ok := a.runtime.(snapshotResetter) + if !ok { + return UndoSnapshotResult{}, ErrNothingToUndo + } + files, ok, err := resetter.ResetSnapshot(ctx, a.session, keep) + if err != nil { + return UndoSnapshotResult{}, fmt.Errorf("restoring snapshot: %w", err) + } + if !ok { + return UndoSnapshotResult{}, ErrNothingToUndo + } + return UndoSnapshotResult{RestoredFiles: files}, nil +} diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index 694ecb856..8df4204fe 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -80,6 +80,25 @@ func (s *State) UndoLastSnapshot(ctx context.Context, sessionID, cwd string) (fi return s.snapshot.undoLast(ctx, sessionID, cwd) } +// ListSnapshots returns the completed snapshot checkpoints for a session in +// chronological order (oldest first). Returns nil when no snapshots exist. +func (s *State) ListSnapshots(sessionID string) []SnapshotInfo { + if s == nil || s.snapshot == nil || sessionID == "" { + return nil + } + return s.snapshot.listSnapshots(sessionID) +} + +// ResetSnapshot reverts every checkpoint past index keep so the workspace +// returns to the state captured at that snapshot. keep == 0 resets to the +// original (pre-agent) state. +func (s *State) ResetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { + if s == nil || s.snapshot == nil || sessionID == "" || cwd == "" { + return 0, false, nil + } + return s.snapshot.resetSnapshot(ctx, sessionID, cwd, keep) +} + // Register installs the stock builtin hooks on r and returns a [State] // handle the caller can use for stateful builtin operations. func Register(r *hooks.Registry) (*State, error) { diff --git a/pkg/hooks/builtins/snapshot.go b/pkg/hooks/builtins/snapshot.go index c9424a9ba..33b90fcc7 100644 --- a/pkg/hooks/builtins/snapshot.go +++ b/pkg/hooks/builtins/snapshot.go @@ -13,6 +13,12 @@ import ( // Snapshot is the registered name of the snapshot builtin. const Snapshot = "snapshot" +// SnapshotInfo summarises one completed snapshot checkpoint for display. +type SnapshotInfo struct { + // Files is the number of unique files captured in the checkpoint. + Files int +} + type snapshotBuiltin struct { manager *snapshot.Manager mu sync.Mutex @@ -204,6 +210,87 @@ func (b *snapshotBuiltin) undoLast(ctx context.Context, sessionID, cwd string) ( return len(checkpoint.files), true, nil } +// listSnapshots returns the completed checkpoints for a session in chronological +// order (oldest first). The returned slice may be empty. +func (b *snapshotBuiltin) listSnapshots(sessionID string) []SnapshotInfo { + b.mu.Lock() + defer b.mu.Unlock() + s := b.session[sessionID] + if s == nil || len(s.history) == 0 { + return nil + } + out := make([]SnapshotInfo, len(s.history)) + for i, c := range s.history { + out[i] = SnapshotInfo{Files: countUnique(c.files)} + } + return out +} + +// truncateAfter pops and returns checkpoints with index >= keep, leaving the +// surviving prefix in the session history. keep is clamped into [0, len]. +func (b *snapshotBuiltin) truncateAfter(sessionID string, keep int) []snapshotCheckpoint { + b.mu.Lock() + defer b.mu.Unlock() + s := b.session[sessionID] + if s == nil || len(s.history) == 0 { + return nil + } + if keep < 0 { + keep = 0 + } + if keep >= len(s.history) { + return nil + } + tail := append([]snapshotCheckpoint(nil), s.history[keep:]...) + for i := keep; i < len(s.history); i++ { + s.history[i] = snapshotCheckpoint{} + } + s.history = s.history[:keep] + return tail +} + +// resetSnapshot reverts every checkpoint with index >= keep so the workspace +// returns to the state captured at snapshot keep. keep == 0 means "reset to +// the original state". A keep value greater than or equal to the snapshot +// count is a no-op. Reverted checkpoints are dropped from the session history. +func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { + tail := b.truncateAfter(sessionID, keep) + if len(tail) == 0 { + return 0, false, nil + } + repo, err := b.manager.Open(ctx, cwd) + if err != nil { + return 0, true, err + } + patches := make([]snapshot.Patch, 0, len(tail)) + seen := map[string]bool{} + unique := 0 + for _, c := range tail { + patches = append(patches, snapshot.Patch{Hash: c.hash, Files: c.files}) + for _, f := range c.files { + if !seen[f] { + seen[f] = true + unique++ + } + } + } + if err := repo.Revert(ctx, patches); err != nil { + return 0, true, err + } + return unique, true, nil +} + +func countUnique(files []string) int { + if len(files) == 0 { + return 0 + } + seen := make(map[string]bool, len(files)) + for _, f := range files { + seen[f] = true + } + return len(seen) +} + func logPatch(ctx context.Context, scope, sessionID, label string, patch snapshot.Patch, after string) { if len(patch.Files) == 0 { return diff --git a/pkg/hooks/builtins/snapshot_test.go b/pkg/hooks/builtins/snapshot_test.go index 240acdc7d..d748e1172 100644 --- a/pkg/hooks/builtins/snapshot_test.go +++ b/pkg/hooks/builtins/snapshot_test.go @@ -82,6 +82,122 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { assert.NoFileExists(t, changedPath) } +func TestSnapshotBuiltinListAndReset(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not available") + } + paths.SetDataDir(t.TempDir()) + t.Cleanup(func() { paths.SetDataDir("") }) + + r := hooks.NewRegistry() + state, err := builtins.Register(r) + require.NoError(t, err) + fn, ok := r.LookupBuiltin(builtins.Snapshot) + require.True(t, ok) + + dir := snapshotBuiltinRepo(t) + + // Initially: no checkpoints. + assert.Empty(t, state.ListSnapshots("s")) + + // Capture three snapshots: each turn modifies one file. + recordTurn := func(t *testing.T, name, contents string) { + t.Helper() + _, err := fn(t.Context(), &hooks.Input{ + SessionID: "s", + Cwd: dir, + HookEventName: hooks.EventTurnStart, + }, nil) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(contents), 0o644)) + _, err = fn(t.Context(), &hooks.Input{ + SessionID: "s", + Cwd: dir, + HookEventName: hooks.EventTurnEnd, + Reason: "continue", + }, nil) + require.NoError(t, err) + } + + recordTurn(t, "a.txt", "a") + recordTurn(t, "b.txt", "b") + recordTurn(t, "c.txt", "c") + + snaps := state.ListSnapshots("s") + require.Len(t, snaps, 3) + assert.Equal(t, 1, snaps[0].Files) + assert.Equal(t, 1, snaps[1].Files) + assert.Equal(t, 1, snaps[2].Files) + + // Reset to snapshot 2: revert turn 3 only, leaving a.txt and b.txt intact. + files, restored, err := state.ResetSnapshot(t.Context(), "s", dir, 2) + require.NoError(t, err) + assert.True(t, restored) + assert.Equal(t, 1, files) + assert.FileExists(t, filepath.Join(dir, "a.txt")) + assert.FileExists(t, filepath.Join(dir, "b.txt")) + assert.NoFileExists(t, filepath.Join(dir, "c.txt")) + require.Len(t, state.ListSnapshots("s"), 2) + + // Reset to original: revert remaining checkpoints, deleting all three files. + files, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 0) + require.NoError(t, err) + assert.True(t, restored) + assert.Equal(t, 2, files) + assert.NoFileExists(t, filepath.Join(dir, "a.txt")) + assert.NoFileExists(t, filepath.Join(dir, "b.txt")) + assert.Empty(t, state.ListSnapshots("s")) + + // Subsequent reset is a no-op (nothing to revert). + _, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 0) + require.NoError(t, err) + assert.False(t, restored) +} + +func TestSnapshotBuiltinResetKeepBeyondHistoryIsNoop(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not available") + } + paths.SetDataDir(t.TempDir()) + t.Cleanup(func() { paths.SetDataDir("") }) + + r := hooks.NewRegistry() + state, err := builtins.Register(r) + require.NoError(t, err) + fn, ok := r.LookupBuiltin(builtins.Snapshot) + require.True(t, ok) + + dir := snapshotBuiltinRepo(t) + _, err = fn(t.Context(), &hooks.Input{ + SessionID: "s", + Cwd: dir, + HookEventName: hooks.EventTurnStart, + }, nil) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a"), 0o644)) + _, err = fn(t.Context(), &hooks.Input{ + SessionID: "s", + Cwd: dir, + HookEventName: hooks.EventTurnEnd, + Reason: "continue", + }, nil) + require.NoError(t, err) + + // keep == len(history) means "keep everything" — no checkpoints reverted. + files, restored, err := state.ResetSnapshot(t.Context(), "s", dir, 1) + require.NoError(t, err) + assert.False(t, restored) + assert.Equal(t, 0, files) + assert.FileExists(t, filepath.Join(dir, "a.txt")) + require.Len(t, state.ListSnapshots("s"), 1) + + // keep way past the end is also a no-op. + _, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 99) + require.NoError(t, err) + assert.False(t, restored) + require.Len(t, state.ListSnapshots("s"), 1) +} + func snapshotBuiltinRepo(t *testing.T) string { t.Helper() dir := t.TempDir() diff --git a/pkg/runtime/snapshot.go b/pkg/runtime/snapshot.go index aa552566c..47dc419b6 100644 --- a/pkg/runtime/snapshot.go +++ b/pkg/runtime/snapshot.go @@ -4,6 +4,7 @@ import ( "context" "os" + "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/session" ) @@ -14,11 +15,51 @@ func WithSnapshots(enabled bool) Opt { } } +// SnapshotsEnabled reports whether automatic snapshot hooks are active for +// this runtime. Used by the TUI to hide snapshot-related commands when the +// feature is off. +func (r *LocalRuntime) SnapshotsEnabled() bool { + return r != nil && r.snapshotsEnabled +} + // UndoLastSnapshot restores files recorded for the latest completed snapshot hook checkpoint. func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Session) (files int, ok bool, err error) { if r == nil || sess == nil { return 0, false, nil } + cwd := r.snapshotCwd(sess) + if cwd == "" { + return 0, false, nil + } + return r.builtinsState.UndoLastSnapshot(ctx, sess.ID, cwd) +} + +// ListSnapshots returns the completed snapshot checkpoints recorded for the +// session, oldest first. Returns nil when none exist. +func (r *LocalRuntime) ListSnapshots(sess *session.Session) []builtins.SnapshotInfo { + if r == nil || sess == nil { + return nil + } + return r.builtinsState.ListSnapshots(sess.ID) +} + +// ResetSnapshot reverts every checkpoint past index keep so the workspace +// returns to the state captured at that snapshot. keep == 0 resets to the +// original (pre-agent) state. +func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (files int, ok bool, err error) { + if r == nil || sess == nil { + return 0, false, nil + } + cwd := r.snapshotCwd(sess) + if cwd == "" { + return 0, false, nil + } + return r.builtinsState.ResetSnapshot(ctx, sess.ID, cwd, keep) +} + +// snapshotCwd resolves the working directory used to open the shadow +// repository for snapshot operations. +func (r *LocalRuntime) snapshotCwd(sess *session.Session) string { cwd := sess.WorkingDir if cwd == "" { cwd = r.workingDir @@ -26,5 +67,5 @@ func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Sessi if cwd == "" { cwd, _ = os.Getwd() } - return r.builtinsState.UndoLastSnapshot(ctx, sess.ID, cwd) + return cwd } diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 006e04c16..4aef75cfb 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -106,6 +106,29 @@ func builtInSessionCommands() []Item { return core.CmdHandler(messages.UndoSnapshotMsg{}) }, }, + { + ID: "session.snapshot", + Label: "Snapshot", + SlashCommand: "/snapshot", + Description: "List captured snapshots and reset to one", + Category: "Session", + Immediate: true, + Execute: func(string) tea.Cmd { + return core.CmdHandler(messages.ShowSnapshotsDialogMsg{}) + }, + }, + { + ID: "session.snapshots", + Label: "Snapshots", + SlashCommand: "/snapshots", + Hidden: true, + Description: "Alias for /snapshot", + Category: "Session", + Immediate: true, + Execute: func(string) tea.Cmd { + return core.CmdHandler(messages.ShowSnapshotsDialogMsg{}) + }, + }, { ID: "session.cost", Label: "Cost", @@ -392,10 +415,33 @@ func sortByLabel(items []Item) []Item { return items } +// snapshotCommandIDs is the set of IDs that depend on the snapshot feature. +// They are stripped from the palette and the slash-command parser when +// snapshots are turned off. +var snapshotCommandIDs = map[string]bool{ + "session.undo": true, + "session.snapshot": true, + "session.snapshots": true, +} + +// removeByIDs returns items whose IDs are not in ids. +func removeByIDs(items []Item, ids map[string]bool) []Item { + out := make([]Item, 0, len(items)) + for _, item := range items { + if !ids[item.ID] { + out = append(out, item) + } + } + return out +} + // BuildCommandCategories builds the list of command categories for the command palette func BuildCommandCategories(ctx context.Context, application *app.App) []Category { // Get session commands and filter based on model capabilities sessionCommands := builtInSessionCommands() + if !application.SnapshotsEnabled() { + sessionCommands = removeByIDs(sessionCommands, snapshotCommandIDs) + } categories := []Category{ { diff --git a/pkg/tui/commands/commands_test.go b/pkg/tui/commands/commands_test.go index b44928854..5627ef87a 100644 --- a/pkg/tui/commands/commands_test.go +++ b/pkg/tui/commands/commands_test.go @@ -107,6 +107,24 @@ func TestParseSlashCommand_OtherCommands(t *testing.T) { assert.True(t, ok) }) + t.Run("snapshot command", func(t *testing.T) { + t.Parallel() + cmd := parser.Parse("/snapshot") + require.NotNil(t, cmd) + msg := cmd() + _, ok := msg.(messages.ShowSnapshotsDialogMsg) + assert.True(t, ok) + }) + + t.Run("snapshots alias", func(t *testing.T) { + t.Parallel() + cmd := parser.Parse("/snapshots") + require.NotNil(t, cmd) + msg := cmd() + _, ok := msg.(messages.ShowSnapshotsDialogMsg) + assert.True(t, ok) + }) + t.Run("unknown command returns nil", func(t *testing.T) { t.Parallel() cmd := parser.Parse("/unknown") @@ -150,3 +168,39 @@ func TestParseSlashCommand_Compact(t *testing.T) { assert.Equal(t, "focus on the API design", compactMsg.AdditionalPrompt) }) } + +func TestRemoveByIDsDropsSnapshotCommands(t *testing.T) { + t.Parallel() + + items := builtInSessionCommands() + require.NotEmpty(t, items) + + hasID := func(items []Item, id string) bool { + for _, it := range items { + if it.ID == id { + return true + } + } + return false + } + + require.True(t, hasID(items, "session.undo")) + require.True(t, hasID(items, "session.snapshot")) + require.True(t, hasID(items, "session.snapshots")) + + filtered := removeByIDs(items, snapshotCommandIDs) + assert.False(t, hasID(filtered, "session.undo")) + assert.False(t, hasID(filtered, "session.snapshot")) + assert.False(t, hasID(filtered, "session.snapshots")) + // Other commands are untouched. + assert.True(t, hasID(filtered, "session.exit")) + assert.True(t, hasID(filtered, "session.new")) + + // Build a parser that mirrors the disabled-snapshots state and verify + // that the snapshot slash commands no longer resolve. + parser := NewParser(Category{Name: "Session", Commands: filtered}) + assert.Nil(t, parser.Parse("/undo")) + assert.Nil(t, parser.Parse("/snapshot")) + assert.Nil(t, parser.Parse("/snapshots")) + require.NotNil(t, parser.Parse("/exit")) +} diff --git a/pkg/tui/dialog/snapshot.go b/pkg/tui/dialog/snapshot.go new file mode 100644 index 000000000..9f95f1284 --- /dev/null +++ b/pkg/tui/dialog/snapshot.go @@ -0,0 +1,195 @@ +package dialog + +import ( + "fmt" + "strings" + + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + + "github.com/docker/docker-agent/pkg/app" + "github.com/docker/docker-agent/pkg/tui/core" + "github.com/docker/docker-agent/pkg/tui/core/layout" + "github.com/docker/docker-agent/pkg/tui/messages" + "github.com/docker/docker-agent/pkg/tui/styles" +) + +// Layout constants for the snapshots dialog. +const ( + snapshotsDialogWidthPercent = 60 + snapshotsDialogMinWidth = 40 + snapshotsDialogMaxWidth = 70 + snapshotsDialogHeightPercent = 70 + snapshotsDialogMaxHeight = 30 +) + +// snapshotsKeyMap defines the navigation keys for the snapshots dialog. +type snapshotsKeyMap struct { + Up key.Binding + Down key.Binding + Top key.Binding + Bottom key.Binding + Enter key.Binding + Escape key.Binding +} + +func defaultSnapshotsKeyMap() snapshotsKeyMap { + return snapshotsKeyMap{ + Up: key.NewBinding(key.WithKeys("up", "k"), key.WithHelp("↑/k", "up")), + Down: key.NewBinding(key.WithKeys("down", "j"), key.WithHelp("↓/j", "down")), + Top: key.NewBinding(key.WithKeys("home", "g"), key.WithHelp("g", "first")), + Bottom: key.NewBinding(key.WithKeys("end", "G"), key.WithHelp("G", "last")), + Enter: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "reset")), + Escape: key.NewBinding(key.WithKeys("esc", "q"), key.WithHelp("esc", "close")), + } +} + +// snapshotsDialog lists every captured snapshot and lets the user reset the +// workspace to the state of any of them (or to the original pre-agent state). +type snapshotsDialog struct { + BaseDialog + + snapshots []app.SnapshotInfo + keyMap snapshotsKeyMap + selected int // 0 = ; 1..len(snapshots) = snapshot N +} + +// NewSnapshotsDialog creates a snapshots dialog showing every captured +// checkpoint. Pass the snapshots in chronological order (oldest first). +func NewSnapshotsDialog(snapshots []app.SnapshotInfo) Dialog { + return &snapshotsDialog{ + snapshots: snapshots, + keyMap: defaultSnapshotsKeyMap(), + } +} + +func (d *snapshotsDialog) Init() tea.Cmd { return nil } + +func (d *snapshotsDialog) Update(msg tea.Msg) (layout.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + cmd := d.SetSize(msg.Width, msg.Height) + return d, cmd + + case tea.KeyPressMsg: + if cmd := HandleQuit(msg); cmd != nil { + return d, cmd + } + switch { + case key.Matches(msg, d.keyMap.Escape): + return d, core.CmdHandler(CloseDialogMsg{}) + case key.Matches(msg, d.keyMap.Up): + if d.selected > 0 { + d.selected-- + } + return d, nil + case key.Matches(msg, d.keyMap.Down): + if d.selected < d.maxIndex() { + d.selected++ + } + return d, nil + case key.Matches(msg, d.keyMap.Top): + d.selected = 0 + return d, nil + case key.Matches(msg, d.keyMap.Bottom): + d.selected = d.maxIndex() + return d, nil + case key.Matches(msg, d.keyMap.Enter): + if len(d.snapshots) == 0 { + return d, core.CmdHandler(CloseDialogMsg{}) + } + return d, tea.Sequence( + core.CmdHandler(CloseDialogMsg{}), + core.CmdHandler(messages.ResetSnapshotMsg{Keep: d.selected}), + ) + } + } + return d, nil +} + +// maxIndex is the index of the last selectable item. With N snapshots there +// are N+1 selectable items ( + each snapshot). +func (d *snapshotsDialog) maxIndex() int { + return len(d.snapshots) +} + +func (d *snapshotsDialog) Position() (row, col int) { + dialogWidth, maxHeight := d.dialogSize() + return CenterPosition(d.Width(), d.Height(), dialogWidth, maxHeight) +} + +func (d *snapshotsDialog) dialogSize() (dialogWidth, dialogHeight int) { + dialogWidth = d.ComputeDialogWidth(snapshotsDialogWidthPercent, snapshotsDialogMinWidth, snapshotsDialogMaxWidth) + dialogHeight = min(d.Height()*snapshotsDialogHeightPercent/100, snapshotsDialogMaxHeight) + return dialogWidth, dialogHeight +} + +func (d *snapshotsDialog) View() string { + dialogWidth, _ := d.dialogSize() + contentWidth := d.ContentWidth(dialogWidth, 2) + + builder := NewContent(contentWidth). + AddTitle("Snapshots"). + AddSeparator(). + AddSpace() + + if len(d.snapshots) == 0 { + empty := styles.DialogContentStyle.Italic(true). + Foreground(styles.TextMuted). + Width(contentWidth). + Align(lipgloss.Center). + Render("No snapshots taken yet.") + content := builder. + AddContent(empty). + AddSpace(). + AddHelpKeys("esc", "close"). + Build() + return styles.DialogStyle.Width(dialogWidth).Render(content) + } + + count := fmt.Sprintf("%d snapshot", len(d.snapshots)) + if len(d.snapshots) != 1 { + count += "s" + } + count += " captured" + + content := builder. + AddContent(styles.DialogOptionsStyle.Width(contentWidth).Render(count)). + AddSpace(). + AddContent(d.renderList(contentWidth)). + AddSpace(). + AddHelpKeys("↑/↓", "navigate", "enter", "reset", "esc", "close"). + Build() + + return styles.DialogStyle.Width(dialogWidth).Render(content) +} + +// renderList renders every selectable item ( + each snapshot). +func (d *snapshotsDialog) renderList(contentWidth int) string { + rows := make([]string, 0, len(d.snapshots)+1) + rows = append(rows, d.renderRow("", "restore the initial state", d.selected == 0, contentWidth)) + for i, snap := range d.snapshots { + filesLabel := fmt.Sprintf("%d file", snap.Files) + if snap.Files != 1 { + filesLabel += "s" + } + label := fmt.Sprintf("Snapshot %d", i+1) + rows = append(rows, d.renderRow(label, filesLabel, i+1 == d.selected, contentWidth)) + } + return lipgloss.JoinVertical(lipgloss.Left, rows...) +} + +// renderRow renders a single item line with name on the left and description +// on the right, highlighting the row when selected. +func (d *snapshotsDialog) renderRow(name, desc string, selected bool, width int) string { + nameStyle, descStyle := styles.PaletteUnselectedActionStyle, styles.PaletteUnselectedDescStyle + if selected { + nameStyle, descStyle = styles.PaletteSelectedActionStyle, styles.PaletteSelectedDescStyle + } + + left := nameStyle.Render(" " + name + " ") + right := descStyle.Render(" " + desc + " ") + gap := max(1, width-lipgloss.Width(left)-lipgloss.Width(right)) + return left + descStyle.Render(strings.Repeat(" ", gap)) + right +} diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index a3cc10ea9..1f5047ddc 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -263,6 +263,33 @@ func (m *appModel) handleUndoSnapshot() (tea.Model, tea.Cmd) { return m, notification.SuccessCmd(text) } +func (m *appModel) handleShowSnapshotsDialog() (tea.Model, tea.Cmd) { + snapshots := m.application.ListSnapshots() + return m, core.CmdHandler(dialog.OpenDialogMsg{ + Model: dialog.NewSnapshotsDialog(snapshots), + }) +} + +func (m *appModel) handleResetSnapshot(keep int) (tea.Model, tea.Cmd) { + if m.chatPage.IsWorking() { + return m, notification.WarningCmd("Wait for the current response to finish before resetting") + } + result, err := m.application.ResetSnapshot(context.Background(), keep) + if err != nil { + if errors.Is(err, app.ErrNothingToUndo) { + return m, notification.InfoCmd("Nothing to reset") + } + return m, notification.ErrorCmd(fmt.Sprintf("Failed to reset snapshot: %v", err)) + } + + target := "the original state" + if keep > 0 { + target = fmt.Sprintf("snapshot %d", keep) + } + text := fmt.Sprintf("Restored %d file%s to %s", result.RestoredFiles, plural(result.RestoredFiles), target) + return m, notification.SuccessCmd(text) +} + func plural(n int) string { if n == 1 { return "" diff --git a/pkg/tui/messages/session.go b/pkg/tui/messages/session.go index 430798001..616bf58fb 100644 --- a/pkg/tui/messages/session.go +++ b/pkg/tui/messages/session.go @@ -50,6 +50,15 @@ type ( // UndoSnapshotMsg restores files from the latest snapshot. UndoSnapshotMsg struct{} + // ShowSnapshotsDialogMsg requests opening the snapshots dialog. + ShowSnapshotsDialogMsg struct{} + + // ResetSnapshotMsg requests restoring the workspace to a snapshot. + // Keep is the number of snapshots to retain in chronological order: + // 0 reverts every snapshot (back to the original pre-agent state), + // N keeps snapshots 1..N and reverts any later ones. + ResetSnapshotMsg struct{ Keep int } + // ExportSessionMsg exports the session to the specified file. ExportSessionMsg struct{ Filename string } diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index ba62c0be7..08691c88f 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -823,6 +823,12 @@ func (m *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case messages.UndoSnapshotMsg: return m.handleUndoSnapshot() + case messages.ShowSnapshotsDialogMsg: + return m.handleShowSnapshotsDialog() + + case messages.ResetSnapshotMsg: + return m.handleResetSnapshot(msg.Keep) + case messages.EvalSessionMsg: return m.handleEvalSession(msg.Filename) From 6203b457908c94b968ece314983d96551f045877 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 11:46:14 +0200 Subject: [PATCH 2/3] rename /snapshot to /snapshots and use r to restore --- docs/features/tui/index.md | 6 +++--- pkg/tui/commands/commands.go | 15 +-------------- pkg/tui/commands/commands_test.go | 14 +------------- pkg/tui/dialog/snapshot.go | 30 +++++++++++++++--------------- 4 files changed, 20 insertions(+), 45 deletions(-) diff --git a/docs/features/tui/index.md b/docs/features/tui/index.md index 196e18d00..ad6a6b36f 100644 --- a/docs/features/tui/index.md +++ b/docs/features/tui/index.md @@ -41,7 +41,7 @@ Type `/` during a session to see available commands, or press Ctrl+Ctrl+↑/ (or j/k) to highlight an entry, then press Enter to reset the workspace to that point. Pick `` to revert every snapshot and bring the workspace back to its pre-agent state. Esc closes the dialog without changing anything. (`/snapshots` is accepted as an alias.) +- **`/snapshots`** opens a dialog showing how many snapshots have been captured and the number of files in each one. Use / (or j/k) to highlight an entry, then press r to reset the workspace to that point. Pick `` to revert every snapshot and bring the workspace back to its pre-agent state. Esc closes the dialog without changing anything. Neither command removes messages from the session transcript — they only touch files on disk. Both commands (and the matching command-palette entries) are hidden when snapshots are turned off. Omit `snapshot` or set it to `false` to leave automatic snapshots off; agents can still configure snapshot hooks manually. diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 4aef75cfb..25002009f 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -106,23 +106,11 @@ func builtInSessionCommands() []Item { return core.CmdHandler(messages.UndoSnapshotMsg{}) }, }, - { - ID: "session.snapshot", - Label: "Snapshot", - SlashCommand: "/snapshot", - Description: "List captured snapshots and reset to one", - Category: "Session", - Immediate: true, - Execute: func(string) tea.Cmd { - return core.CmdHandler(messages.ShowSnapshotsDialogMsg{}) - }, - }, { ID: "session.snapshots", Label: "Snapshots", SlashCommand: "/snapshots", - Hidden: true, - Description: "Alias for /snapshot", + Description: "List captured snapshots", Category: "Session", Immediate: true, Execute: func(string) tea.Cmd { @@ -420,7 +408,6 @@ func sortByLabel(items []Item) []Item { // snapshots are turned off. var snapshotCommandIDs = map[string]bool{ "session.undo": true, - "session.snapshot": true, "session.snapshots": true, } diff --git a/pkg/tui/commands/commands_test.go b/pkg/tui/commands/commands_test.go index 5627ef87a..74c751751 100644 --- a/pkg/tui/commands/commands_test.go +++ b/pkg/tui/commands/commands_test.go @@ -107,16 +107,7 @@ func TestParseSlashCommand_OtherCommands(t *testing.T) { assert.True(t, ok) }) - t.Run("snapshot command", func(t *testing.T) { - t.Parallel() - cmd := parser.Parse("/snapshot") - require.NotNil(t, cmd) - msg := cmd() - _, ok := msg.(messages.ShowSnapshotsDialogMsg) - assert.True(t, ok) - }) - - t.Run("snapshots alias", func(t *testing.T) { + t.Run("snapshots command", func(t *testing.T) { t.Parallel() cmd := parser.Parse("/snapshots") require.NotNil(t, cmd) @@ -185,12 +176,10 @@ func TestRemoveByIDsDropsSnapshotCommands(t *testing.T) { } require.True(t, hasID(items, "session.undo")) - require.True(t, hasID(items, "session.snapshot")) require.True(t, hasID(items, "session.snapshots")) filtered := removeByIDs(items, snapshotCommandIDs) assert.False(t, hasID(filtered, "session.undo")) - assert.False(t, hasID(filtered, "session.snapshot")) assert.False(t, hasID(filtered, "session.snapshots")) // Other commands are untouched. assert.True(t, hasID(filtered, "session.exit")) @@ -200,7 +189,6 @@ func TestRemoveByIDsDropsSnapshotCommands(t *testing.T) { // that the snapshot slash commands no longer resolve. parser := NewParser(Category{Name: "Session", Commands: filtered}) assert.Nil(t, parser.Parse("/undo")) - assert.Nil(t, parser.Parse("/snapshot")) assert.Nil(t, parser.Parse("/snapshots")) require.NotNil(t, parser.Parse("/exit")) } diff --git a/pkg/tui/dialog/snapshot.go b/pkg/tui/dialog/snapshot.go index 9f95f1284..713590ddb 100644 --- a/pkg/tui/dialog/snapshot.go +++ b/pkg/tui/dialog/snapshot.go @@ -26,22 +26,22 @@ const ( // snapshotsKeyMap defines the navigation keys for the snapshots dialog. type snapshotsKeyMap struct { - Up key.Binding - Down key.Binding - Top key.Binding - Bottom key.Binding - Enter key.Binding - Escape key.Binding + Up key.Binding + Down key.Binding + Top key.Binding + Bottom key.Binding + Restore key.Binding + Escape key.Binding } func defaultSnapshotsKeyMap() snapshotsKeyMap { return snapshotsKeyMap{ - Up: key.NewBinding(key.WithKeys("up", "k"), key.WithHelp("↑/k", "up")), - Down: key.NewBinding(key.WithKeys("down", "j"), key.WithHelp("↓/j", "down")), - Top: key.NewBinding(key.WithKeys("home", "g"), key.WithHelp("g", "first")), - Bottom: key.NewBinding(key.WithKeys("end", "G"), key.WithHelp("G", "last")), - Enter: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "reset")), - Escape: key.NewBinding(key.WithKeys("esc", "q"), key.WithHelp("esc", "close")), + Up: key.NewBinding(key.WithKeys("up", "k"), key.WithHelp("↑/k", "up")), + Down: key.NewBinding(key.WithKeys("down", "j"), key.WithHelp("↓/j", "down")), + Top: key.NewBinding(key.WithKeys("home", "g"), key.WithHelp("g", "first")), + Bottom: key.NewBinding(key.WithKeys("end", "G"), key.WithHelp("G", "last")), + Restore: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "restore")), + Escape: key.NewBinding(key.WithKeys("esc", "q"), key.WithHelp("esc", "close")), } } @@ -95,9 +95,9 @@ func (d *snapshotsDialog) Update(msg tea.Msg) (layout.Model, tea.Cmd) { case key.Matches(msg, d.keyMap.Bottom): d.selected = d.maxIndex() return d, nil - case key.Matches(msg, d.keyMap.Enter): + case key.Matches(msg, d.keyMap.Restore): if len(d.snapshots) == 0 { - return d, core.CmdHandler(CloseDialogMsg{}) + return d, nil } return d, tea.Sequence( core.CmdHandler(CloseDialogMsg{}), @@ -159,7 +159,7 @@ func (d *snapshotsDialog) View() string { AddSpace(). AddContent(d.renderList(contentWidth)). AddSpace(). - AddHelpKeys("↑/↓", "navigate", "enter", "reset", "esc", "close"). + AddHelpKeys("↑/↓", "navigate", "r", "restore", "esc", "close"). Build() return styles.DialogStyle.Width(dialogWidth).Render(content) From 12cb56156de2f7fef096e8c435844cd6d1b47186 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 12:12:54 +0200 Subject: [PATCH 3/3] simplify snapshot plumbing --- pkg/app/app_test.go | 15 +++ pkg/app/undo.go | 97 ++++++-------- pkg/hooks/builtins/snapshot.go | 69 ++++------ pkg/runtime/snapshot.go | 25 ++-- pkg/tui/dialog/snapshot.go | 222 ++++++++++++++------------------- 5 files changed, 188 insertions(+), 240 deletions(-) diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index 93f79d6e4..a1ded0fd9 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" @@ -80,6 +81,11 @@ func (m *mockRuntime) FollowUp(_ runtime.QueuedMessage) error { return nil } func (m *mockRuntime) UndoLastSnapshot(context.Context, *session.Session) (int, bool, error) { return m.undoFiles, m.undoOK, m.undoErr } +func (m *mockRuntime) SnapshotsEnabled() bool { return true } +func (m *mockRuntime) ListSnapshots(*session.Session) []builtins.SnapshotInfo { return nil } +func (m *mockRuntime) ResetSnapshot(context.Context, *session.Session, int) (int, bool, error) { + return m.undoFiles, m.undoOK, m.undoErr +} // Verify mockRuntime implements runtime.Runtime var _ runtime.Runtime = (*mockRuntime)(nil) @@ -254,6 +260,15 @@ func TestApp_UndoLastSnapshot_NoSnapshot(t *testing.T) { assert.ErrorIs(t, err, ErrNothingToUndo) } +func TestApp_SnapshotsEnabled_DoesNotRequireSession(t *testing.T) { + t.Parallel() + + // SnapshotsEnabled answers a runtime-capability question; it must not + // silently return false just because no session is attached. + app := &App{runtime: &mockRuntime{}, session: nil} + assert.True(t, app.SnapshotsEnabled()) +} + func TestApp_RegenerateSessionTitle(t *testing.T) { t.Parallel() diff --git a/pkg/app/undo.go b/pkg/app/undo.go index 082c0d1f4..8c8b86cd4 100644 --- a/pkg/app/undo.go +++ b/pkg/app/undo.go @@ -15,89 +15,72 @@ type UndoSnapshotResult struct { RestoredFiles int } -// SnapshotInfo summarises one snapshot checkpoint for display. -type SnapshotInfo struct { - // Files is the number of files captured in the checkpoint. - Files int -} - -type snapshotUndoer interface { +// snapshotRuntime is the subset of the runtime API that the App needs to +// drive snapshot commands. Runtimes that don't capture snapshots (e.g. +// remote runtimes) simply don't implement this interface and the related +// commands are then disabled in the UI. +type snapshotRuntime interface { + SnapshotsEnabled() bool UndoLastSnapshot(ctx context.Context, sess *session.Session) (files int, ok bool, err error) -} - -type snapshotLister interface { ListSnapshots(sess *session.Session) []builtins.SnapshotInfo -} - -type snapshotResetter interface { ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (files int, ok bool, err error) } -type snapshotsEnabledChecker interface { - SnapshotsEnabled() bool +// snapshotRuntime returns the runtime's snapshot interface, or nil when the +// runtime doesn't support snapshots at all (e.g. remote runtimes). +func (a *App) snapshotRuntime() snapshotRuntime { + r, _ := a.runtime.(snapshotRuntime) + return r } // SnapshotsEnabled reports whether automatic shadow-git snapshots are active -// for the current runtime. Returns false for runtimes that don't support -// snapshots at all (e.g. remote runtimes) or when the feature is turned off -// in the user config. +// for the current runtime. The answer doesn't depend on having an active +// session: it's a runtime/configuration capability check. func (a *App) SnapshotsEnabled() bool { - checker, ok := a.runtime.(snapshotsEnabledChecker) - return ok && checker.SnapshotsEnabled() + r := a.snapshotRuntime() + return r != nil && r.SnapshotsEnabled() } +// UndoLastSnapshot restores the files captured in the most recent snapshot. func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) { - if a.session == nil { - return UndoSnapshotResult{}, ErrNothingToUndo - } - undoer, ok := a.runtime.(snapshotUndoer) - if !ok { - return UndoSnapshotResult{}, ErrNothingToUndo - } - files, ok, err := undoer.UndoLastSnapshot(ctx, a.session) - if err != nil { - return UndoSnapshotResult{}, fmt.Errorf("restoring snapshot: %w", err) - } - if !ok { + r := a.snapshotRuntime() + if r == nil || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } - return UndoSnapshotResult{RestoredFiles: files}, nil + return snapshotResult(r.UndoLastSnapshot(ctx, a.session)) } -// ListSnapshots returns the snapshot checkpoints recorded for the current -// session in chronological order (oldest first). Returns nil when the runtime -// does not support snapshots or none have been captured yet. -func (a *App) ListSnapshots() []SnapshotInfo { - if a.session == nil { - return nil - } - lister, ok := a.runtime.(snapshotLister) - if !ok { +// ListSnapshots returns the file count of every snapshot captured during the +// current session, oldest first. Returns nil when no snapshots exist or when +// the runtime doesn't support them. +func (a *App) ListSnapshots() []int { + r := a.snapshotRuntime() + if r == nil || a.session == nil { return nil } - raw := lister.ListSnapshots(a.session) - if len(raw) == 0 { - return nil - } - out := make([]SnapshotInfo, len(raw)) - for i, s := range raw { - out[i] = SnapshotInfo{Files: s.Files} + infos := r.ListSnapshots(a.session) + counts := make([]int, len(infos)) + for i, info := range infos { + counts[i] = info.Files } - return out + return counts } // ResetSnapshot reverts every checkpoint past index keep so the workspace // returns to the state captured at that snapshot. keep == 0 resets to the -// original pre-agent state. Returns ErrNothingToUndo when nothing changes. +// original pre-agent state. func (a *App) ResetSnapshot(ctx context.Context, keep int) (UndoSnapshotResult, error) { - if a.session == nil { + r := a.snapshotRuntime() + if r == nil || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } - resetter, ok := a.runtime.(snapshotResetter) - if !ok { - return UndoSnapshotResult{}, ErrNothingToUndo - } - files, ok, err := resetter.ResetSnapshot(ctx, a.session, keep) + return snapshotResult(r.ResetSnapshot(ctx, a.session, keep)) +} + +// snapshotResult adapts the (files, ok, err) tuple returned by snapshot +// operations into the UndoSnapshotResult / ErrNothingToUndo shape callers +// expect. +func snapshotResult(files int, ok bool, err error) (UndoSnapshotResult, error) { if err != nil { return UndoSnapshotResult{}, fmt.Errorf("restoring snapshot: %w", err) } diff --git a/pkg/hooks/builtins/snapshot.go b/pkg/hooks/builtins/snapshot.go index 33b90fcc7..fb437cbab 100644 --- a/pkg/hooks/builtins/snapshot.go +++ b/pkg/hooks/builtins/snapshot.go @@ -216,45 +216,22 @@ func (b *snapshotBuiltin) listSnapshots(sessionID string) []SnapshotInfo { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] - if s == nil || len(s.history) == 0 { + if s == nil { return nil } out := make([]SnapshotInfo, len(s.history)) for i, c := range s.history { - out[i] = SnapshotInfo{Files: countUnique(c.files)} + out[i] = SnapshotInfo{Files: len(c.files)} } return out } -// truncateAfter pops and returns checkpoints with index >= keep, leaving the -// surviving prefix in the session history. keep is clamped into [0, len]. -func (b *snapshotBuiltin) truncateAfter(sessionID string, keep int) []snapshotCheckpoint { - b.mu.Lock() - defer b.mu.Unlock() - s := b.session[sessionID] - if s == nil || len(s.history) == 0 { - return nil - } - if keep < 0 { - keep = 0 - } - if keep >= len(s.history) { - return nil - } - tail := append([]snapshotCheckpoint(nil), s.history[keep:]...) - for i := keep; i < len(s.history); i++ { - s.history[i] = snapshotCheckpoint{} - } - s.history = s.history[:keep] - return tail -} - // resetSnapshot reverts every checkpoint with index >= keep so the workspace // returns to the state captured at snapshot keep. keep == 0 means "reset to // the original state". A keep value greater than or equal to the snapshot // count is a no-op. Reverted checkpoints are dropped from the session history. func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { - tail := b.truncateAfter(sessionID, keep) + tail := b.popHistoryTail(sessionID, keep) if len(tail) == 0 { return 0, false, nil } @@ -262,33 +239,41 @@ func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd stri if err != nil { return 0, true, err } - patches := make([]snapshot.Patch, 0, len(tail)) + patches := make([]snapshot.Patch, len(tail)) seen := map[string]bool{} - unique := 0 - for _, c := range tail { - patches = append(patches, snapshot.Patch{Hash: c.hash, Files: c.files}) + for i, c := range tail { + patches[i] = snapshot.Patch{Hash: c.hash, Files: c.files} for _, f := range c.files { - if !seen[f] { - seen[f] = true - unique++ - } + seen[f] = true } } if err := repo.Revert(ctx, patches); err != nil { return 0, true, err } - return unique, true, nil + return len(seen), true, nil } -func countUnique(files []string) int { - if len(files) == 0 { - return 0 +// popHistoryTail removes and returns checkpoints with index >= keep, leaving +// the surviving prefix in the session history. keep is clamped to [0, len]. +// The popped slots in the backing array are zeroed so the dropped file lists +// can be garbage-collected before the slice grows past them again. +func (b *snapshotBuiltin) popHistoryTail(sessionID string, keep int) []snapshotCheckpoint { + b.mu.Lock() + defer b.mu.Unlock() + s := b.session[sessionID] + if s == nil { + return nil } - seen := make(map[string]bool, len(files)) - for _, f := range files { - seen[f] = true + if keep < 0 { + keep = 0 + } + if keep >= len(s.history) { + return nil } - return len(seen) + tail := append([]snapshotCheckpoint(nil), s.history[keep:]...) + clear(s.history[keep:]) + s.history = s.history[:keep] + return tail } func logPatch(ctx context.Context, scope, sessionID, label string, patch snapshot.Patch, after string) { diff --git a/pkg/runtime/snapshot.go b/pkg/runtime/snapshot.go index 47dc419b6..b80beafce 100644 --- a/pkg/runtime/snapshot.go +++ b/pkg/runtime/snapshot.go @@ -23,10 +23,7 @@ func (r *LocalRuntime) SnapshotsEnabled() bool { } // UndoLastSnapshot restores files recorded for the latest completed snapshot hook checkpoint. -func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Session) (files int, ok bool, err error) { - if r == nil || sess == nil { - return 0, false, nil - } +func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Session) (int, bool, error) { cwd := r.snapshotCwd(sess) if cwd == "" { return 0, false, nil @@ -46,10 +43,7 @@ func (r *LocalRuntime) ListSnapshots(sess *session.Session) []builtins.SnapshotI // ResetSnapshot reverts every checkpoint past index keep so the workspace // returns to the state captured at that snapshot. keep == 0 resets to the // original (pre-agent) state. -func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (files int, ok bool, err error) { - if r == nil || sess == nil { - return 0, false, nil - } +func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, keep int) (int, bool, error) { cwd := r.snapshotCwd(sess) if cwd == "" { return 0, false, nil @@ -58,14 +52,17 @@ func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, } // snapshotCwd resolves the working directory used to open the shadow -// repository for snapshot operations. +// repository for snapshot operations. Returns "" when no candidate is usable. func (r *LocalRuntime) snapshotCwd(sess *session.Session) string { - cwd := sess.WorkingDir - if cwd == "" { - cwd = r.workingDir + if r == nil || sess == nil { + return "" } - if cwd == "" { - cwd, _ = os.Getwd() + if sess.WorkingDir != "" { + return sess.WorkingDir + } + if r.workingDir != "" { + return r.workingDir } + cwd, _ := os.Getwd() return cwd } diff --git a/pkg/tui/dialog/snapshot.go b/pkg/tui/dialog/snapshot.go index 713590ddb..01c26f61c 100644 --- a/pkg/tui/dialog/snapshot.go +++ b/pkg/tui/dialog/snapshot.go @@ -2,66 +2,38 @@ package dialog import ( "fmt" - "strings" - "charm.land/bubbles/v2/key" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" - "github.com/docker/docker-agent/pkg/app" "github.com/docker/docker-agent/pkg/tui/core" "github.com/docker/docker-agent/pkg/tui/core/layout" "github.com/docker/docker-agent/pkg/tui/messages" "github.com/docker/docker-agent/pkg/tui/styles" ) -// Layout constants for the snapshots dialog. const ( - snapshotsDialogWidthPercent = 60 - snapshotsDialogMinWidth = 40 - snapshotsDialogMaxWidth = 70 - snapshotsDialogHeightPercent = 70 - snapshotsDialogMaxHeight = 30 + snapshotsDialogWidthPercent = 60 + snapshotsDialogMinWidth = 40 + snapshotsDialogMaxWidth = 70 ) -// snapshotsKeyMap defines the navigation keys for the snapshots dialog. -type snapshotsKeyMap struct { - Up key.Binding - Down key.Binding - Top key.Binding - Bottom key.Binding - Restore key.Binding - Escape key.Binding -} - -func defaultSnapshotsKeyMap() snapshotsKeyMap { - return snapshotsKeyMap{ - Up: key.NewBinding(key.WithKeys("up", "k"), key.WithHelp("↑/k", "up")), - Down: key.NewBinding(key.WithKeys("down", "j"), key.WithHelp("↓/j", "down")), - Top: key.NewBinding(key.WithKeys("home", "g"), key.WithHelp("g", "first")), - Bottom: key.NewBinding(key.WithKeys("end", "G"), key.WithHelp("G", "last")), - Restore: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "restore")), - Escape: key.NewBinding(key.WithKeys("esc", "q"), key.WithHelp("esc", "close")), - } -} - // snapshotsDialog lists every captured snapshot and lets the user reset the -// workspace to the state of any of them (or to the original pre-agent state). +// workspace to one of them (or to the original pre-agent state). type snapshotsDialog struct { BaseDialog - snapshots []app.SnapshotInfo - keyMap snapshotsKeyMap - selected int // 0 = ; 1..len(snapshots) = snapshot N + // fileCounts holds the number of files captured in each snapshot, oldest + // first. An empty slice puts the dialog in its empty state. + fileCounts []int + // selected is the highlighted entry. 0 = , N = snapshot N. + selected int } -// NewSnapshotsDialog creates a snapshots dialog showing every captured -// checkpoint. Pass the snapshots in chronological order (oldest first). -func NewSnapshotsDialog(snapshots []app.SnapshotInfo) Dialog { - return &snapshotsDialog{ - snapshots: snapshots, - keyMap: defaultSnapshotsKeyMap(), - } +// NewSnapshotsDialog creates a snapshots dialog. fileCounts must be in +// chronological order (oldest first). +func NewSnapshotsDialog(fileCounts []int) Dialog { + return &snapshotsDialog{fileCounts: fileCounts} } func (d *snapshotsDialog) Init() tea.Cmd { return nil } @@ -76,120 +48,116 @@ func (d *snapshotsDialog) Update(msg tea.Msg) (layout.Model, tea.Cmd) { if cmd := HandleQuit(msg); cmd != nil { return d, cmd } - switch { - case key.Matches(msg, d.keyMap.Escape): - return d, core.CmdHandler(CloseDialogMsg{}) - case key.Matches(msg, d.keyMap.Up): - if d.selected > 0 { - d.selected-- - } - return d, nil - case key.Matches(msg, d.keyMap.Down): - if d.selected < d.maxIndex() { - d.selected++ - } - return d, nil - case key.Matches(msg, d.keyMap.Top): - d.selected = 0 - return d, nil - case key.Matches(msg, d.keyMap.Bottom): - d.selected = d.maxIndex() - return d, nil - case key.Matches(msg, d.keyMap.Restore): - if len(d.snapshots) == 0 { - return d, nil - } - return d, tea.Sequence( - core.CmdHandler(CloseDialogMsg{}), - core.CmdHandler(messages.ResetSnapshotMsg{Keep: d.selected}), - ) - } + cmd := d.handleKey(msg) + return d, cmd } return d, nil } -// maxIndex is the index of the last selectable item. With N snapshots there -// are N+1 selectable items ( + each snapshot). -func (d *snapshotsDialog) maxIndex() int { - return len(d.snapshots) +func (d *snapshotsDialog) handleKey(msg tea.KeyPressMsg) tea.Cmd { + switch msg.String() { + case "esc", "q": + return core.CmdHandler(CloseDialogMsg{}) + case "up", "k": + if d.selected > 0 { + d.selected-- + } + case "down", "j": + if d.selected < len(d.fileCounts) { + d.selected++ + } + case "home", "g": + d.selected = 0 + case "end", "G": + d.selected = len(d.fileCounts) + case "r": + if len(d.fileCounts) == 0 { + return nil + } + return tea.Sequence( + core.CmdHandler(CloseDialogMsg{}), + core.CmdHandler(messages.ResetSnapshotMsg{Keep: d.selected}), + ) + } + return nil } func (d *snapshotsDialog) Position() (row, col int) { - dialogWidth, maxHeight := d.dialogSize() - return CenterPosition(d.Width(), d.Height(), dialogWidth, maxHeight) -} - -func (d *snapshotsDialog) dialogSize() (dialogWidth, dialogHeight int) { - dialogWidth = d.ComputeDialogWidth(snapshotsDialogWidthPercent, snapshotsDialogMinWidth, snapshotsDialogMaxWidth) - dialogHeight = min(d.Height()*snapshotsDialogHeightPercent/100, snapshotsDialogMaxHeight) - return dialogWidth, dialogHeight + return d.CenterDialog(d.View()) } func (d *snapshotsDialog) View() string { - dialogWidth, _ := d.dialogSize() - contentWidth := d.ContentWidth(dialogWidth, 2) + width := d.ComputeDialogWidth(snapshotsDialogWidthPercent, snapshotsDialogMinWidth, snapshotsDialogMaxWidth) + inner := d.ContentWidth(width, 2) - builder := NewContent(contentWidth). - AddTitle("Snapshots"). - AddSeparator(). - AddSpace() + content := NewContent(inner).AddTitle("Snapshots").AddSeparator().AddSpace() - if len(d.snapshots) == 0 { - empty := styles.DialogContentStyle.Italic(true). - Foreground(styles.TextMuted). - Width(contentWidth). - Align(lipgloss.Center). - Render("No snapshots taken yet.") - content := builder. - AddContent(empty). - AddSpace(). - AddHelpKeys("esc", "close"). - Build() - return styles.DialogStyle.Width(dialogWidth).Render(content) + if len(d.fileCounts) > 0 { + count := pluralize(len(d.fileCounts), "snapshot", "snapshots") + " captured" + content = content. + AddContent(styles.DialogOptionsStyle.Width(inner).Render(count)). + AddSpace() } - count := fmt.Sprintf("%d snapshot", len(d.snapshots)) - if len(d.snapshots) != 1 { - count += "s" - } - count += " captured" - - content := builder. - AddContent(styles.DialogOptionsStyle.Width(contentWidth).Render(count)). + body := content. + AddContent(d.bodyContent(inner)). AddSpace(). - AddContent(d.renderList(contentWidth)). - AddSpace(). - AddHelpKeys("↑/↓", "navigate", "r", "restore", "esc", "close"). + AddHelpKeys(d.helpKeys()...). Build() - return styles.DialogStyle.Width(dialogWidth).Render(content) + return styles.DialogStyle.Width(width).Render(body) } -// renderList renders every selectable item ( + each snapshot). -func (d *snapshotsDialog) renderList(contentWidth int) string { - rows := make([]string, 0, len(d.snapshots)+1) - rows = append(rows, d.renderRow("", "restore the initial state", d.selected == 0, contentWidth)) - for i, snap := range d.snapshots { - filesLabel := fmt.Sprintf("%d file", snap.Files) - if snap.Files != 1 { - filesLabel += "s" - } - label := fmt.Sprintf("Snapshot %d", i+1) - rows = append(rows, d.renderRow(label, filesLabel, i+1 == d.selected, contentWidth)) +// bodyContent returns either the empty-state line or the snapshot list, +// depending on whether any snapshots were captured. +func (d *snapshotsDialog) bodyContent(inner int) string { + if len(d.fileCounts) == 0 { + return styles.DialogContentStyle. + Italic(true). + Foreground(styles.TextMuted). + Width(inner). + Align(lipgloss.Center). + Render("No snapshots taken yet.") + } + + rows := make([]string, 0, len(d.fileCounts)+1) + rows = append(rows, d.renderRow("", "restore the initial state", d.selected == 0, inner)) + for i, count := range d.fileCounts { + rows = append(rows, d.renderRow( + fmt.Sprintf("Snapshot %d", i+1), + pluralize(count, "file", "files"), + d.selected == i+1, + inner, + )) } return lipgloss.JoinVertical(lipgloss.Left, rows...) } -// renderRow renders a single item line with name on the left and description -// on the right, highlighting the row when selected. +func (d *snapshotsDialog) helpKeys() []string { + if len(d.fileCounts) == 0 { + return []string{"esc", "close"} + } + return []string{"↑/↓", "navigate", "r", "restore", "esc", "close"} +} + +// renderRow draws a single list entry with the name on the left and a short +// description right-aligned within width. func (d *snapshotsDialog) renderRow(name, desc string, selected bool, width int) string { nameStyle, descStyle := styles.PaletteUnselectedActionStyle, styles.PaletteUnselectedDescStyle if selected { nameStyle, descStyle = styles.PaletteSelectedActionStyle, styles.PaletteSelectedDescStyle } - left := nameStyle.Render(" " + name + " ") right := descStyle.Render(" " + desc + " ") - gap := max(1, width-lipgloss.Width(left)-lipgloss.Width(right)) - return left + descStyle.Render(strings.Repeat(" ", gap)) + right + gap := max(0, width-lipgloss.Width(left)) + return left + lipgloss.PlaceHorizontal(gap, lipgloss.Right, right, + lipgloss.WithWhitespaceStyle(descStyle)) +} + +func pluralize(n int, singular, plural string) string { + word := plural + if n == 1 { + word = singular + } + return fmt.Sprintf("%d %s", n, word) }