Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion internal/actions/flatten/flatten.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,18 @@ func Action(ctx *app.Context, opts Options, handler Handler) error {
// Execute the flatten
handler.OnStep(StepFlattening, basehandler.StatusStarted, "Moving branches...")

oldUpstreamByBranch := make(map[string]string, len(filteredPlan.RebaseSpecs))
for _, spec := range filteredPlan.RebaseSpecs {
oldUpstreamByBranch[spec.Branch] = spec.OldUpstream
}

// Update parent pointers for all planned moves
for _, move := range filteredPlan.Moves {
moveBranch := eng.GetBranch(move.Branch)
newParentBranch := eng.GetBranch(move.NewParent)
oldUpstream := oldUpstreamByBranch[move.Branch]

if err := eng.ReparentBranch(gctx, moveBranch, newParentBranch); err != nil {
if err := eng.ReparentBranchWithRevision(gctx, moveBranch, newParentBranch, oldUpstream); err != nil {
handler.OnStep(StepFlattening, basehandler.StatusFailed, err.Error())
return fmt.Errorf("failed to set parent for %s to %s: %w", move.Branch, move.NewParent, err)
}
Expand Down
57 changes: 54 additions & 3 deletions internal/engine/engine_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,48 @@ func TestSetParent(t *testing.T) {
})
}

func TestReparentBranch(t *testing.T) {
t.Parallel()

t.Run("falls back to merge-base when recorded divergence is missing and old parent has advanced", func(t *testing.T) {
t.Parallel()
s := scenario.NewScenario(t, testhelpers.BasicSceneSetup)

// Create main -> branch1 -> branch2, then advance branch1 without restacking branch2.
s.CreateBranch("branch1").
Commit("branch1 change").
CreateBranch("branch2").
Commit("branch2 change").
Checkout("branch1").
Commit("branch1 follow-up").
Checkout("main")

require.NoError(t, s.Engine.TrackBranch(context.Background(), "branch1", "main"))
require.NoError(t, s.Engine.TrackBranch(context.Background(), "branch2", "branch1"))

impl := s.Engine.(interface{ Git() git.Runner })
meta, err := impl.Git().ReadMetadata("branch2")
require.NoError(t, err)

// Simulate stale metadata by clearing the recorded divergence point.
require.NoError(t, impl.Git().WriteMetadata("branch2", meta.WithParentBranchRevision(nil)))

err = s.Engine.ReparentBranch(context.Background(), s.Engine.GetBranch("branch2"), s.Engine.GetBranch("main"))
require.NoError(t, err)

updatedMeta, err := impl.Git().ReadMetadata("branch2")
require.NoError(t, err)
require.NotNil(t, updatedMeta.GetParentBranchName())
require.Equal(t, "main", *updatedMeta.GetParentBranchName())
require.NotNil(t, updatedMeta.GetParentBranchRevision())

expectedRev, err := impl.Git().GetMergeBase("branch2", "main")
require.NoError(t, err)
require.Equal(t, expectedRev, *updatedMeta.GetParentBranchRevision())
require.Equal(t, "main", s.Engine.GetBranch("branch2").GetParent().GetName())
})
}

func TestDeleteBranch(t *testing.T) {
t.Parallel()
t.Run("deletes branch and updates children", func(t *testing.T) {
Expand Down Expand Up @@ -270,18 +312,22 @@ func TestDeleteBranch(t *testing.T) {
require.Contains(t, mainChildNames, "C3")
})

t.Run("returns error when child cannot be reparented", func(t *testing.T) {
t.Run("continues reparenting later children when one child cannot be reparented", func(t *testing.T) {
t.Parallel()
s := scenario.NewScenario(t, testhelpers.BasicSceneSetup)

// Create main -> branch1 -> branch2 and track both.
// Create main -> branch1 -> [branch2, branch3] and track all three.
s.CreateBranch("branch1").
Commit("branch1 change").
CreateBranch("branch2").
Commit("branch2 change").
Checkout("branch1").
CreateBranch("branch3").
Commit("branch3 change").
Checkout("main")
require.NoError(t, s.Engine.TrackBranch(context.Background(), "branch1", "main"))
require.NoError(t, s.Engine.TrackBranch(context.Background(), "branch2", "branch1"))
require.NoError(t, s.Engine.TrackBranch(context.Background(), "branch3", "branch1"))

// Force branch2 onto an unrelated orphan root so merge-base(branch2, main) fails.
s.Checkout("branch2")
Expand All @@ -296,8 +342,13 @@ func TestDeleteBranch(t *testing.T) {

err = s.Engine.DeleteBranch(context.Background(), s.Engine.GetBranch("branch1"))
require.Error(t, err)
require.Contains(t, err.Error(), "failed to reparent")
require.Contains(t, err.Error(), "failed to reparent child branch")
require.Contains(t, err.Error(), "branch2")

// branch3 should still be repaired even though branch2 failed first.
branch3Parent := s.Engine.GetBranch("branch3").GetParent()
require.NotNil(t, branch3Parent)
require.Equal(t, "main", branch3Parent.GetName())
})

t.Run("fails when trying to delete trunk", func(t *testing.T) {
Expand Down
127 changes: 86 additions & 41 deletions internal/engine/engine_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package engine

import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -367,12 +368,28 @@ func (e *engineImpl) SetParent(ctx context.Context, branch Branch, parentBranch
})
}

// setParentName updates only the parent branch name in metadata without
// modifying ParentBranchRevision.
func (e *engineImpl) setParentName(ctx context.Context, branch Branch, parentBranch Branch) error {
branchName := branch.GetName()
parentBranchName := parentBranch.GetName()
// resolveReparentRevision determines the revision to store when changing a
// branch's parent while keeping the branch's own commits intact.
//
// Prefer the recorded divergence point when it is still valid. If metadata is
// stale or missing, fall back to the actual merge-base with the new parent.
func (e *engineImpl) resolveReparentRevision(branchName, newParentName string, meta *git.Meta) (string, error) {
if rev := meta.GetParentBranchRevision(); rev != nil && *rev != "" {
if isAncestor, err := e.git.IsAncestor(*rev, branchName); err == nil && isAncestor {
return *rev, nil
}
}

parentRev, err := e.git.GetMergeBase(branchName, newParentName)
if err != nil {
return "", fmt.Errorf("failed to get merge base: %w", err)
}
return parentRev, nil
}

// reparentBranchWithRevision updates the parent name and parent revision in a
// single transaction so callers never observe a partially-applied reparent.
func (e *engineImpl) reparentBranchWithRevision(ctx context.Context, branchName, parentBranchName, parentRev string) error {
if branchName == parentBranchName {
return fmt.Errorf("branch %s cannot be its own parent", branchName)
}
Expand All @@ -384,71 +401,99 @@ func (e *engineImpl) setParentName(ctx context.Context, branch Branch, parentBra
}

meta = meta.WithParentBranchName(&parentBranchName)
meta = meta.WithParentBranchRevision(&parentRev)

tx := e.BeginTx(fmt.Sprintf("set parent name: %s -> %s", branchName, parentBranchName))
tx := e.BeginTx(fmt.Sprintf("reparent: %s -> %s", branchName, parentBranchName))
if err := tx.UpdateMeta(branchName, meta); err != nil {
return err
}
return tx.Commit(ctx)
})
}

// setParentPreservingDivergence updates a branch's parent while preserving
// the divergence point if it remains a valid ancestor. Uses setParentName
// (not SetParent) so the existing ParentBranchRevision is never overwritten
// with an incorrect merge-base value.
func (e *engineImpl) setParentPreservingDivergence(ctx context.Context, branch Branch, newParent Branch, oldDivergencePoint string) error {
if err := e.setParentName(ctx, branch, newParent); err != nil {
return err
// ReparentBranchWithRevision changes a branch's parent using an explicit
// divergence point chosen by the caller. This is used by operations that
// precompute and validate an old upstream before applying metadata changes.
func (e *engineImpl) ReparentBranchWithRevision(ctx context.Context, branch Branch, newParent Branch, parentRev string) error {
branchName := branch.GetName()
parentBranchName := newParent.GetName()
if branchName == parentBranchName {
return fmt.Errorf("branch %s cannot be its own parent", branchName)
}
if parentRev == "" {
return fmt.Errorf("parent revision cannot be empty when reparenting %s", branchName)
}

if oldDivergencePoint == "" {
return nil
return e.reparentBranchWithRevision(ctx, branchName, parentBranchName, parentRev)
}

// ReparentBranch changes a branch's parent while automatically preserving its
// divergence point. If the recorded divergence point is stale or invalid, it
// falls back to the merge-base with the new parent.
func (e *engineImpl) ReparentBranch(ctx context.Context, branch Branch, newParent Branch) error {
branchName := branch.GetName()
parentBranchName := newParent.GetName()
if branchName == parentBranchName {
return fmt.Errorf("branch %s cannot be its own parent", branchName)
}

// Set the correct divergence point so restacking replays only this
// branch's commits, not commits from the old parent.
isAncestor, err := e.git.IsAncestor(oldDivergencePoint, branch.GetName())
meta, err := e.readMetadata(branchName)
if err != nil {
return fmt.Errorf("failed to check ancestry of divergence point %s for %s: %w", oldDivergencePoint, branch.GetName(), err)
return fmt.Errorf("failed to read metadata: %w", err)
}
if !isAncestor {
return fmt.Errorf("divergence point %s is not an ancestor of %s: cannot preserve divergence point", oldDivergencePoint, branch.GetName())

parentRev, err := e.resolveReparentRevision(branchName, parentBranchName, meta)
if err != nil {
return fmt.Errorf("failed to determine reparent revision for %s: %w", branchName, err)
}

return e.updateParentRevision(ctx, branch.GetName(), oldDivergencePoint)
return e.reparentBranchWithRevision(ctx, branchName, parentBranchName, parentRev)
}

// ReparentBranch changes a branch's parent while automatically preserving its
// divergence point. This is the preferred way to reparent an existing branch
// when the branch's own commits should not change.
func (e *engineImpl) ReparentBranch(ctx context.Context, branch Branch, newParent Branch) error {
div, err := e.GetDivergencePoint(branch.GetName())
if err != nil {
return fmt.Errorf("failed to determine divergence point for %s: %w", branch.GetName(), err)
}
return e.setParentPreservingDivergence(ctx, branch, newParent, div)
type reparentUpdate struct {
branchName string
parentRev string
}

// ReparentBranches changes multiple branches to the same new parent while
// preserving each branch's divergence point. Divergence points are captured
// for all branches before any reparenting begins, ensuring correctness when
// branches in the list are related to each other.
// preserving divergence points. All target revisions are captured before any
// metadata writes begin so related branches keep the correct boundaries.
//
// The operation is best-effort: every branch is attempted and any failures are
// returned together so callers can repair as much of the stack as possible.
func (e *engineImpl) ReparentBranches(ctx context.Context, branchNames []string, newParent Branch) error {
divPoints := make(map[string]string, len(branchNames))
parentBranchName := newParent.GetName()
updates := make([]reparentUpdate, 0, len(branchNames))
var reparentErrs []error

for _, name := range branchNames {
div, err := e.GetDivergencePoint(name)
meta, err := e.readMetadata(name)
if err != nil {
return fmt.Errorf("failed to determine divergence point for %s: %w", name, err)
reparentErrs = append(reparentErrs, fmt.Errorf("failed to read metadata for %s: %w", name, err))
continue
}

parentRev, err := e.resolveReparentRevision(name, parentBranchName, meta)
if err != nil {
reparentErrs = append(reparentErrs, fmt.Errorf("failed to determine reparent revision for %s: %w", name, err))
continue
}
divPoints[name] = div

updates = append(updates, reparentUpdate{
branchName: name,
parentRev: parentRev,
})
}

for _, name := range branchNames {
if err := e.setParentPreservingDivergence(ctx, e.GetBranch(name), newParent, divPoints[name]); err != nil {
return fmt.Errorf("failed to reparent %s to %s: %w", name, newParent.GetName(), err)
for _, update := range updates {
if err := e.reparentBranchWithRevision(ctx, update.branchName, parentBranchName, update.parentRev); err != nil {
reparentErrs = append(reparentErrs, fmt.Errorf("failed to reparent %s to %s: %w", update.branchName, parentBranchName, err))
}
}

if len(reparentErrs) > 0 {
return errors.Join(reparentErrs...)
}
return nil
}

Expand Down
11 changes: 8 additions & 3 deletions internal/engine/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,17 @@ type BranchTracking interface {
TrackBranch(ctx context.Context, branchName string, parentBranchName string) error
UntrackBranch(branchName string) error
SetParent(ctx context.Context, branch Branch, parentBranch Branch) error
// ReparentBranchWithRevision changes a branch's parent using a caller-supplied
// divergence point. Use this when an action has already validated the exact
// old upstream that must be preserved.
ReparentBranchWithRevision(ctx context.Context, branch Branch, newParent Branch, parentRev string) error
// ReparentBranch changes a branch's parent while automatically preserving
// its divergence point. Preferred over SetParent for existing branches.
// its divergence point when still valid, otherwise falling back to the
// merge-base with the new parent. Preferred over SetParent for existing branches.
ReparentBranch(ctx context.Context, branch Branch, newParent Branch) error
// ReparentBranches changes multiple branches to the same new parent while
// preserving divergence points. All divergence points are captured before
// any reparenting begins.
// preserving divergence points. All target revisions are captured before any
// reparenting begins, and failures are aggregated after attempting every branch.
ReparentBranches(ctx context.Context, branchNames []string, newParent Branch) error
SetScope(ctx context.Context, branch Branch, scope Scope) error
SetBranchType(branch Branch, branchType git.BranchType) error
Expand Down
Loading