Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ OpenMP
Ornithischia
Orthrozanclus
Osteichthyan
PCSA
PLATNICK
PLEIJEL
PUJADE
Expand Down Expand Up @@ -253,6 +254,8 @@ pscore
rRNA
rearranger
reconnections
reconverged
reconverges
regraft
regrafting
regrafts
Expand Down
97 changes: 84 additions & 13 deletions src/ts_constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,74 @@ bool violates_constraint_posthoc(const TreeState& tree,

namespace {

// Topology-only SPR: move `clip` to the edge (above, below).
// Unlike spr_clip/spr_regraft, this handles root-child clips correctly
// and doesn't save/restore state (caller must rebuild postorder and rescore).
void topology_spr(TreeState& tree, int clip, int above, int below) {
const int root = tree.n_tip;
int nx = tree.parent[clip];
int ns = (tree.left[nx - root] == clip)
? tree.right[nx - root]
: tree.left[nx - root];

if (nx != root) {
// --- Normal case: detach nx, connect ns to grandparent ---
int nz = tree.parent[nx];
tree.parent[ns] = nz;
if (nz >= tree.n_tip) {
int nzi = nz - tree.n_tip;
if (tree.left[nzi] == nx)
tree.left[nzi] = ns;
else
tree.right[nzi] = ns;
}

// Insert nx between above and below
if (above >= tree.n_tip) {
int ai = above - tree.n_tip;
if (tree.left[ai] == below)
tree.left[ai] = nx;
else
tree.right[ai] = nx;
}
tree.parent[nx] = above;
int nxi = nx - tree.n_tip;
tree.left[nxi] = clip;
tree.right[nxi] = below;
tree.parent[clip] = nx;
tree.parent[below] = nx;
} else {
// --- Root-child case: clip is a direct child of root ---
// Can't float root (identity is fixed at n_tip).
// Absorb ns into root and repurpose ns as the insertion node.
if (ns < tree.n_tip) return; // ns is a tip — degenerate, bail out

int nsi = ns - tree.n_tip;
int ns_left = tree.left[nsi];
int ns_right = tree.right[nsi];

// Root absorbs ns's children
tree.left[0] = ns_left;
tree.right[0] = ns_right;
tree.parent[ns_left] = root;
tree.parent[ns_right] = root;

// Insert ns between above and below, with clip as its other child
if (above >= tree.n_tip) {
int ai = above - tree.n_tip;
if (tree.left[ai] == below)
tree.left[ai] = ns;
else
tree.right[ai] = ns;
}
tree.parent[ns] = above;
tree.left[nsi] = clip;
tree.right[nsi] = below;
tree.parent[clip] = ns;
tree.parent[below] = ns;
}
}

// Collect (above, below) edge pairs within the subtree rooted at node.
// Iterative DFS; does NOT include the edge above `node` itself.
void collect_edges_in_subtree(const TreeState& tree, int sub_root,
Expand Down Expand Up @@ -601,36 +669,38 @@ static int impose_one_pass(TreeState& tree, ConstraintData& cd,
find_maximal_subtrees(tree, root, best_node, node_tips,
move_in_mask, n_words, move_in_roots);

// Bail out if too many moves needed
// Safety cap: abandon this pass if the repair is unexpectedly large.
int n_moves = static_cast<int>(
move_out_roots.size() + move_in_roots.size());
if (total_moves + n_moves > tree.n_tip / 4) break;
if (total_moves + n_moves > tree.n_tip / 4 + 2) {
return -1; // Distinguish "bailed out" from "no violations" (0)
}

// --- Execute SPR moves ---
// Enumerate target edges AFTER each clip (post-clip tree is valid).
// --- Execute topology moves ---
// Uses topology_spr() which handles root-child moves correctly
// (unlike spr_clip which can't detach root children).
// Rebuild postorder after each move so edge enumeration is valid.
for (int M : move_out_roots) {
if (tree.parent[M] == root) continue;
tree.spr_clip(M);
tree.build_postorder();
std::vector<std::pair<int,int>> targets;
collect_edges_outside_subtree(tree, best_node, targets);
if (targets.empty()) { tree.spr_unclip(); continue; }
if (targets.empty()) continue;
auto [above, below] =
targets[std::uniform_int_distribution<int>(
0, static_cast<int>(targets.size()) - 1)(rng)];
tree.spr_regraft(above, below);
topology_spr(tree, M, above, below);
++total_moves;
}

for (int M : move_in_roots) {
if (tree.parent[M] == root) continue;
tree.spr_clip(M);
tree.build_postorder();
std::vector<std::pair<int,int>> targets;
collect_edges_in_subtree(tree, best_node, targets);
if (targets.empty()) { tree.spr_unclip(); continue; }
if (targets.empty()) continue;
auto [above, below] =
targets[std::uniform_int_distribution<int>(
0, static_cast<int>(targets.size()) - 1)(rng)];
tree.spr_regraft(above, below);
topology_spr(tree, M, above, below);
++total_moves;
}
}
Expand All @@ -652,8 +722,9 @@ int impose_constraint(TreeState& tree, ConstraintData& cd)
// is bounded by n_splits. Cap at n_splits + 1 for safety.
for (int pass = 0; pass <= cd.n_splits; ++pass) {
int moves = impose_one_pass(tree, cd, rng);
if (moves < 0) break; // Bailed out — too many moves needed
if (moves == 0) break; // No violations found — done
total_moves += moves;
if (moves == 0) break; // No violations found — done
}

tree.build_postorder();
Expand Down
3 changes: 1 addition & 2 deletions src/ts_driven.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ ReplicateResult run_single_replicate(
}
case StartStrategy::RANDOM_TREE:
if (cd && cd->active) {
// Fall back to constraint-aware Wagner when constraints active
random_wagner_tree(result.tree, ds, cd);
random_constrained_tree(result.tree, ds, *cd);
} else {
random_topology_tree(result.tree, ds);
}
Expand Down
Loading
Loading