# Heuristic search

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import itertools
import random

def heuristic_search(max_iterations=100,
                     space=(0,1,20),
                     xi=None,
                     yi=None,
                     plot_reds=False,
                     plot_only_last=False,
                     seed=None,
                     z_func=None):
    """Visualise a greedy neighbour-hopping walk over a sampled 2-D surface.

    Both axes are sampled with ``np.linspace(*space)`` and the surface
    ``z = z_func(x, y)`` (default ``sin(pi*x) * sin(pi*y)``) is evaluated on
    the resulting grid. Starting from grid index ``(xi, yi)``, every step
    evaluates the four diagonal neighbours of the current point and moves to
    the one with the highest ``z``. A neighbour index that would leave the
    grid is clamped to the current index. On ties the last candidate
    evaluated wins. The current point is not compared against its
    neighbours (unless border clamping makes it one of them), so the walk
    does not stop at a peak: it always takes ``max_iterations`` steps.

    Parameters
    ----------
    max_iterations : int
        Number of steps to take. Unless ``plot_only_last`` is set, this is
        also the number of figures drawn.
    space : tuple
        ``(start, stop, num)`` passed to ``np.linspace`` for both axes.
    xi, yi : int or None
        Starting column / row index into the grid. When ``None``, the index
        is drawn uniformly from every valid index of that axis.
    plot_reds : bool
        Also draw, in red, every neighbour evaluated so far.
    plot_only_last : bool
        Draw only the final state instead of one figure per step.
    seed : int or None
        Seed for the random starting indices.
    z_func : callable or None
        ``z_func(x, y)`` surface to search; it must work element-wise on
        NumPy arrays. Other surfaces used with this demo include
        ``abs(u - w)``, ``np.cos(u) * np.cos(w) * np.exp(-np.sqrt(u**2 + w**2)) / 4``,
        ``u * np.exp(-u**2 - w**2)`` and ``u**2 - w**2``.

    Yields
    ------
    None
        One value per step, yielded after that step's figure is drawn.
        Nothing is yielded when ``plot_only_last`` is set; in that case,
        exhausting the generator draws the final figure.
    """
    rng = np.random.default_rng(seed)
    iteration = 0

    def get_z(u, w):
        """Height of the surface at (u, w); works element-wise on arrays."""
        if z_func is not None:
            return z_func(u, w)
        return np.sin(np.pi*u)*np.sin(np.pi*w)

    def plot():
        """Draw the surface, the visited path and the current point."""
        plt.figure(figsize=(13, 13))
        ax = plt.axes(projection='3d')
        # every point visited so far
        ax.scatter(a, b, c, c='black', marker='o', depthshade=False, s=80)
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='jet', alpha=0.5)
        ax.scatter(X, Y, Z, marker='.', depthshade=False, s=10, alpha=0.5, c="gray")
        if plot_reds:
            # every neighbour evaluated so far
            ax.scatter(reds[0], reds[1], reds[2], marker='o', depthshade=False, s=40, alpha=1.0, c="red")
        # current point
        ax.scatter(ca, cb, cc, c='black', marker='o', depthshade=False, s=200, alpha=1.0)

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.grid(False)
        ax.view_init(30, 20)

    def max_z():
        """Return the grid index of the best diagonal neighbour of (xi, yi).

        Every candidate evaluated is appended to ``reds`` for plotting.
        """
        n_rows, n_cols = X.shape  # rows follow y, columns follow x
        # clamp neighbour indices that would fall off the grid
        next_x = xi + 1 if xi + 1 < n_cols else xi
        prev_x = xi - 1 if xi - 1 >= 0 else xi
        next_y = yi + 1 if yi + 1 < n_rows else yi
        prev_y = yi - 1 if yi - 1 >= 0 else yi

        best_z = -np.inf
        nxi, nyi = xi, yi
        for cx, cy in itertools.product((next_x, prev_x), (next_y, prev_y)):
            u, w = X[0][cx], Y[cy][0]
            nz = get_z(u, w)
            for k, value in enumerate((u, w, nz)):
                reds[k].append(value)
            # ">=" means the last of several equally good candidates wins
            if nz >= best_z:
                nxi, nyi, best_z = cx, cy, nz
        return nxi, nyi

    x_space = np.linspace(*space)
    y_space = np.linspace(*space)
    X, Y = np.meshgrid(x_space, y_space)
    Z = get_z(X, Y)

    # Generator.integers excludes `high`; using the axis length makes every
    # index (including the last one) a possible random start.
    if xi is None:
        xi = rng.integers(0, len(x_space))
    if yi is None:
        yi = rng.integers(0, len(y_space))

    a, b, c = [], [], []   # x, y, z of every visited point
    reds = [[], [], []]    # x, y, z of every evaluated neighbour

    ca = X[0][xi]
    cb = Y[yi][0]
    cc = get_z(ca, cb)

    # `<` so that exactly max_iterations steps are taken.
    while iteration < max_iterations:
        iteration += 1
        a.append(ca)
        b.append(cb)
        c.append(cc)

        if not plot_only_last:
            yield plot()
        elif iteration == max_iterations:
            plot()

        xi, yi = max_z()
        ca = X[0][xi]
        cb = Y[yi][0]
        cc = get_z(ca, cb)
        
        



## Schematic representation of a generic heuristic search

In [None]:
# Search a 20 x 20 grid over [-1, 0.4], starting at column 17, row 9.
hs = heuristic_search(xi=17, yi=9, space=(-1, 0.4, 20))

In [None]:
next(hs)  # take one step of the search and draw it

The explored space is significantly smaller than in an exhaustive search.

In [None]:
# Run the whole search, drawing only the final state and marking every
# evaluated neighbour in red.
hs = heuristic_search(
    xi=17,
    yi=9,
    space=(-1, 0.4, 20),
    plot_reds=True,
    plot_only_last=True,
)
for _ in hs:
    pass

Some other caveats

In [None]:
# Starting-point limitation: the result depends on where the walk starts
# (here column 10, row 18).
hs = heuristic_search(xi=10, yi=18, space=(-1, 0.4, 20))

In [None]:
next(hs)  # first step from the alternative starting point

In [None]:
# Big space, few iterations: on a 100 x 100 grid, 15 steps from a random
# start can visit only a handful of points. Only the final state is drawn.
hs = heuristic_search(
    space=(-1, 0.4, 100),
    max_iterations=15,
    plot_only_last=True,
)
for _ in hs:
    pass

# Nearest-neighbor interchanges (NNI)

In [441]:
def nni(tree,
        seed=None,
        highlight=False,
        force=(-1, -1)
       ) -> "ToyTree":
    """Return a copy of ``tree`` with two non-nested subtrees swapped.

    The first subtree is drawn from the non-root nodes. The second is drawn
    from every non-root node that is neither nested with the first (the
    node itself, a descendant or an ancestor) nor one of its sisters.
    Swapping sisters would leave the topology unchanged. The two subtrees
    then exchange parents.

    Note: the partner is drawn from every non-nested, non-sister node. This
    is a broader set than the textbook NNI neighbourhood, which only swaps
    subtrees on either side of a single internal edge.

    Parameters
    ----------
    tree : ToyTree
        Tree to rearrange; it is copied, never modified.
    seed : int or None
        Seed for ``np.random.default_rng``.
    highlight : bool
        Colour the two swapped subtrees for drawing.
    force : (int, int)
        Node indices to use instead of the random picks. A negative entry
        keeps the random choice for that position.

    Returns
    -------
    ToyTree or None
        The rearranged copy. Returns ``None`` (after printing a message)
        when the first subtree has no valid partner.

    Raises
    ------
    ValueError
        If a forced index is the root, or the forced partner is nested with
        the first subtree.
    """
    tree = tree.copy()
    rng = np.random.default_rng(seed)
    # Like the original range(tree.nnodes - 1) selections, this relies on
    # the root having the largest index (nnodes - 1).
    n_nonroot = tree.nnodes - 1

    # Randomly select the first subtree (any non-root node). The draw is
    # always made so a forced index does not shift the RNG stream.
    f_idx = rng.choice(n_nonroot)
    if force[0] >= 0:
        f_idx = force[0]
        if f_idx >= n_nonroot:
            raise ValueError(f"force[0]={f_idx} is not a non-root node index")
    subtree_a = tree[f_idx]

    # Nodes nested with subtree_a: itself, its descendants and ALL of its
    # ancestors. Swapping with any of them would make a node its own
    # ancestor (a cycle).
    nested = {node._idx for node in subtree_a._iter_descendants()}
    nested.add(subtree_a._idx)
    node = subtree_a._up
    while node is not None:
        nested.add(node._idx)
        node = node._up

    # Candidates for the second subtree: non-root, not nested with
    # subtree_a, and not a sister (a sister swap is uninformative).
    available_nodes = (
        set(range(n_nonroot))
        - nested
        - {node._idx for node in subtree_a._iter_sisters()}
    )

    if not available_nodes:
        print(f"No possible interchange if {subtree_a!r} is selected")
        return None

    # sorted() makes the seeded pick independent of set iteration order
    s_idx = rng.choice(sorted(available_nodes))
    if force[1] >= 0:
        s_idx = force[1]
        if s_idx >= n_nonroot or s_idx in nested:
            raise ValueError(
                f"force[1]={s_idx} is the root or is nested with {subtree_a!r}"
            )
    subtree_b = tree[s_idx]

    print(f"Interchanging {subtree_a!r} with {subtree_b!r}")

    # Read both real parents before anything is modified.
    parent_a = subtree_a._up
    parent_b = subtree_b._up

    # parent_a: drop subtree_a, adopt subtree_b
    children = list(parent_a.children)
    children.remove(subtree_a)
    children.append(subtree_b)
    parent_a._children = tuple(children)

    # parent_b: drop subtree_b, adopt subtree_a
    children = list(parent_b.children)
    children.remove(subtree_b)
    children.append(subtree_a)
    parent_b._children = tuple(children)

    # Point each subtree at its new parent. These are the parent nodes that
    # live in `tree`; the previous version assigned `parent.copy()`, which
    # left the moved subtrees linked to detached clones.
    subtree_a._up = parent_b
    subtree_b._up = parent_a

    tree._update()

    tree.style.node_sizes = 15
    tree.style.node_labels = "idx"

    if highlight:
        tree.style.edge_colors = ['black'] * tree.nnodes
        tree.style.node_colors = ['white'] * tree.nnodes
        tree.style.node_style.stroke_width = 1.5
        tree.style.node_sizes = 8
        tree.style.node_labels = "idx"
        tree.style.node_labels_style.font_size = 12
        tree.style.node_labels_style._toyplot_anchor_shift = -12
        tree.style.node_labels_style.baseline_shift = 9
        tree.style.use_edge_lengths = False

        # one colour per swapped subtree; idx values are read after _update()
        for color_idx, subtree in enumerate((subtree_a, subtree_b)):
            color = toytree.color.COLORS2[color_idx]
            for node in subtree._iter_descendants():
                tree.style.edge_colors[node.idx] = color
                tree.style.node_colors[node.idx] = color

    tree.style.use_edge_lengths = False

    return tree

# Random 5-tip tree, drawn with node indices.
tree = toytree.rtree.unittree(5, seed=42)
tree.draw(node_sizes=15, node_labels="idx")
# Swap node 6 with node 0; `force` overrides both random picks.
tresult = nni(tree, force=(6, 0), seed=42, highlight=False)
if tresult:
    tresult.draw()

Interchanging Node(6) with Node(0)


In [None]:
tresult.get_node_data()  # tabulate the node data of the rearranged tree

In [444]:
# Number of NNI neighbours of an unrooted binary tree with `tips` leaves:
# such a tree has tips - 3 internal edges, and each edge admits 2
# alternative topologies.
tips = 3
total_neighbors = 2 * (tips - 3) 



# TODO: rearrange
# - lazy mode: pick the first better tree found (fast but inaccurate)
# - normal mode: retain all trees
# - tie: take every tied tree and rearrange it again (2nd level of neighbours)

# Subtree pruning and regrafting (SPR)

In [396]:
from typing import Optional, TypeVar
import numpy as np
import toytree


# A `logger = logger.bind(name="toytree")` line used to sit here. It read a
# `logger` name that this notebook never defines (NameError), and nothing
# used its result, so it was dropped.
# Stand-in type used only in the annotations below (toytree's tree class).
ToyTree = TypeVar("ToyTree")

def spr(
    tree: ToyTree,
    seed: Optional[int]=None,
    highlight: bool=False,
    ) -> ToyTree:
    """Return a rooted ToyTree one SPR move away from ``tree``.

    A random non-root subtree is pruned and regrafted onto another edge,
    identified by the node below it. That edge is not (1) inside the
    subtree, (2) its sister's edge, (3) its parent's edge, or (4) its own
    edge. It may be the root's position, in which case a new root is
    created above the old one. The returned copy therefore has a different
    topology, at an SPR distance of 1.

    Pruning dissolves the subtree's former parent. That node's edge length
    is added to the surviving child's edge, so the child's distance from
    the node above keeps its value. The node inserted at the regraft point
    is a fresh ``toytree.Node("new")``.

    Parameters
    ----------
    tree : ToyTree
        Tree to rearrange; it is copied, never modified.
    seed : int or None
        Seed for ``np.random.default_rng``.
    highlight : bool
        Style the result so the moved subtree and the inserted node are
        coloured.

    Returns
    -------
    ToyTree
        The rearranged copy.

    Raises
    ------
    ValueError
        If no subtree of ``tree`` can be moved (e.g. a two-tip tree).

    Examples
    --------
    Draw one SPR neighbour of a random 7-tip tree::

        tree = toytree.rtree.unittree(7, seed=42)
        spr(tree, highlight=True, seed=1).draw()
    """
    tree = tree.copy()
    rng = np.random.default_rng(seed)

    def _regraft_edges(node):
        """Indices of the nodes whose parent edge may receive ``node``."""
        return (
            set(range(tree.nnodes))
            - {i._idx for i in node._iter_descendants()}
            - {i._idx for i in node._iter_sisters()}
            - {node._up._idx}
            - {node._idx}
        )

    # Randomly select a subtree (any non-root Node; the root has the
    # largest idx).
    sidx = rng.choice(tree.nnodes - 1)
    subtree = tree[sidx]
    edges = _regraft_edges(subtree)

    # A child of the root whose sisters are all tips has nowhere to go
    # (e.g. the cherry of a 3-tip tree), and rng.choice([]) would crash.
    # Re-draw among the subtrees that can move.
    if not edges:
        movable = [i for i in range(tree.nnodes - 1) if _regraft_edges(tree[i])]
        if not movable:
            raise ValueError("no SPR move is possible on this tree")
        sidx = rng.choice(movable)
        subtree = tree[sidx]
        edges = _regraft_edges(subtree)

    # Sample an edge by its descendant Node. The root is a valid choice
    # (regrafting above it creates a new root). sorted() makes the seeded
    # pick independent of set iteration order.
    new_sister = tree[rng.choice(sorted(edges))]
    print(f"Prunning {subtree!r} to regrafting in {new_sister!r}")

    # New node that will hold the subtree and its new sister.
    new_node = toytree.Node("new")
    new_node_parent = new_sister._up
    new_node._up = new_node_parent
    new_node._children = (subtree, new_sister)

    # Prune: detach the subtree and dissolve its old parent, attaching the
    # old parent's remaining child(ren) to the grandparent.
    old_node = subtree._up
    old_node_parent = old_node._up
    old_node._remove_child(subtree)
    if old_node_parent is not None:
        old_node_parent._remove_child(old_node)
        for child in old_node._children:
            old_node_parent._add_child(child)
            child._up = old_node_parent
            # old_node disappears, so its edge is merged into the child's
            # edge to keep the child's distance from old_node_parent.
            child._dist += old_node._dist

    # Regraft: hang new_node where new_sister used to be.
    if new_node_parent is not None:
        new_node_parent._remove_child(new_sister)
        new_node_parent._add_child(new_node)
    # Set the child->parent links after any _remove_child call above.
    subtree._up = new_node
    new_sister._up = new_node

    # If the subtree was pruned from the root, the root now has a single
    # child: promote that child to be the root.
    if len(tree.treenode.children) == 1:
        tree.treenode = tree.treenode.children[0]
        tree.treenode._up = None

    # If the subtree was regrafted above the old root, new_node is the root.
    elif new_sister == tree.treenode:
        tree.treenode = new_node

    tree._update()

    # optional: color edges of the subtree that was moved.
    if highlight:
        tree.style.edge_colors = ['black'] * tree.nnodes
        tree.style.node_colors = ['white'] * tree.nnodes
        tree.style.node_style.stroke_width = 1.5
        tree.style.node_sizes = 8
        tree.style.node_labels = "idx"
        tree.style.node_labels_style.font_size = 12
        tree.style.node_labels_style._toyplot_anchor_shift = -12
        tree.style.node_labels_style.baseline_shift = 9
        tree.style.use_edge_lengths = False

        color = toytree.color.COLORS2[3]
        for node in subtree._iter_descendants():
            tree.style.edge_colors[node.idx] = color
            tree.style.node_colors[node.idx] = color
        tree.style.node_colors[new_node.idx] = color
    return tree

In [409]:
# Random 7-tip tree, drawn with node indices.
tree = toytree.rtree.unittree(7, seed=42)
to = tree.draw(node_sizes=15, node_labels="idx")
# One unseeded SPR move, with the moved subtree highlighted.
tm = spr(tree, seed=None, highlight=True).draw()

Prunning Node(8) to regrafting in Node(4)


# Heuristic search for the best tree

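1. Pick one of the rearrangement algorithms (e.g. NNI or SPR).
2. Apply one rearrangement to the current tree.
3. Score the new tree with some optimality criterion.
4. If it scores better, use that tree as the new starting tree.
5. Repeat; stop when no rearrangement improves the score.

Optionally, run the search from multiple starting trees.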