# Heuristic search

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import itertools
import random

def heuristic_search(max_iterations=100,
                     space=(0, 1, 20),
                     xi=None,
                     yi=None,
                     plot_reds=False,
                     plot_only_last=False,
                     seed=None):
    """Visualise a greedy neighbour-by-neighbour walk over a sampled surface.

    The surface ``z = sin(pi*x) * sin(pi*y)`` is sampled on a square grid
    whose two axes are both ``np.linspace(*space)``. Starting from grid
    index ``(xi, yi)``, each step moves to the highest-scoring candidate
    returned by ``max_z`` and the state is drawn as a 3-D matplotlib figure.

    Parameters
    ----------
    max_iterations : int
        Number of positions visited (the start counts as the first). One
        figure is drawn per position unless ``plot_only_last`` is set.
    space : tuple
        ``(start, stop, num)`` passed to ``np.linspace`` for both axes.
    xi, yi : int, optional
        Start column (x index) and row (y index). Drawn uniformly from
        ``0 .. num - 1`` when omitted.
    plot_reds : bool
        Also draw, in red, every candidate evaluated so far.
    plot_only_last : bool
        Draw only the final position instead of every position.
    seed : int, optional
        Seed for the random start position.

    Yields
    ------
    None
        Once per drawn figure, so ``next(gen)`` advances to the next drawing.
    """
    rng = np.random.default_rng(seed)

    def plot():
        """Draw the surface, the path walked so far and the current position."""
        plt.figure(figsize=(13, 13))
        ax = plt.axes(projection='3d')
        # positions visited so far
        ax.scatter(a, b, c, c='black', marker='o', depthshade=False, s=80)
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='jet', alpha=0.5)
        # every grid sample, i.e. what an exhaustive search would evaluate
        ax.scatter(X, Y, Z, marker='.', depthshade=False, s=10, alpha=0.5, c="gray")
        if plot_reds:
            # every candidate evaluated by max_z so far
            ax.scatter(reds[0], reds[1], reds[2], marker='o', depthshade=False, s=40, alpha=1.0, c="red")
        # current position, drawn larger
        ax.scatter(ca, cb, cc, c='black', marker='o', depthshade=False, s=200, alpha=1.0)

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.grid(False)
        ax.view_init(30, 20)

    def get_z(u, w):
        """Surface being searched; works on scalars and on numpy arrays."""
        return np.sin(np.pi*u)*np.sin(np.pi*w)
        # Alternative test surfaces:
        #   abs(u - w)
        #   np.cos(u) * np.cos(w) * np.exp(-np.sqrt(u**2 + w**2)) / 4
        #   u * np.exp(-u**2 - w**2)
        #   u**2 - w**2

    def max_z():
        """Return the (x, y) grid indices of the best candidate move.

        Candidates are the four diagonal neighbours ``(xi +/- 1, yi +/- 1)``.
        At a grid border the out-of-range offset is replaced by the current
        index, so axis-aligned moves (or staying put, in a corner) appear
        there. Away from the border the current cell is not a candidate, so
        the walk moves even when every candidate is lower. Every evaluated
        candidate is appended to ``reds``.
        """
        n_rows, n_cols = X.shape
        next_x = xi + 1 if xi + 1 < n_cols else xi
        prev_x = xi - 1 if xi - 1 >= 0 else xi
        # Meshgrid rows index y, so y is bounded by the row count.
        next_y = yi + 1 if yi + 1 < n_rows else yi
        prev_y = yi - 1 if yi - 1 >= 0 else yi

        best_z = -np.inf
        best_xi, best_yi = xi, yi
        for cand_xi, cand_yi in itertools.product((next_x, prev_x), (next_y, prev_y)):
            cand_x = X[0][cand_xi]
            cand_y = Y[cand_yi][0]
            cand_z = get_z(cand_x, cand_y)
            for k, value in enumerate((cand_x, cand_y, cand_z)):
                reds[k].append(value)
            # `>=`: among tied candidates the last one in product order wins.
            if cand_z >= best_z:
                best_xi, best_yi, best_z = cand_xi, cand_yi, cand_z
        return best_xi, best_yi

    x_space = np.linspace(*space)
    y_space = np.linspace(*space)
    X, Y = np.meshgrid(x_space, y_space)
    Z = get_z(X, Y)

    # Generator.integers excludes `high`, so the grid size (not size - 1)
    # makes every column/row a possible start.
    if xi is None:
        xi = rng.integers(0, X.shape[1])
    if yi is None:
        yi = rng.integers(0, X.shape[0])

    # x, y, z coordinates of every visited position
    a = []
    b = []
    c = []

    # current position
    ca = X[0][xi]
    cb = Y[yi][0]
    cc = get_z(ca, cb)

    # x, y, z coordinates of every candidate evaluated by max_z
    reds = [[], [], []]

    for iteration in range(1, max_iterations + 1):
        a.append(ca)
        b.append(cb)
        c.append(cc)

        if not plot_only_last or iteration == max_iterations:
            yield plot()

        if iteration == max_iterations:
            break  # a further step would never be drawn

        xi, yi = max_z()
        ca = X[0][xi]
        cb = Y[yi][0]
        cc = get_z(ca, cb)
        
        



## Schematic representation of a generic heuristic search

In [None]:
# Fixed start at grid indices xi=17, yi=9 on a 20-point grid over [-1, 0.4].
hs = heuristic_search(xi=17, yi=9, space=(-1, 0.4, 20))

In [None]:
# Advance the walk: each call draws one step (the first call draws the start).
next(hs)

The explored space is significantly smaller than that of an exhaustive search.

In [None]:
# Run the whole walk but draw only its final state, with every evaluated
# candidate shown in red.
hs = heuristic_search(
    xi=17,
    yi=9,
    space=(-1, 0.4, 20),
    plot_reds=True,
    plot_only_last=True,
)
for i in hs:
    pass

Some other caveats

In [None]:
# Caveat: the outcome of the walk depends on where it starts.
hs = heuristic_search(
    xi=10,
    yi=18,
    space=(-1, 0.4, 20),
)

In [None]:
# Advance the walk by one drawn step.
next(hs)

In [None]:
# Caveat: a large grid (100 x 100 samples) explored with few iterations (15).
hs = heuristic_search(
    space=(-1, 0.4, 100),
    max_iterations=15,
    plot_only_last=True,
)
for i in hs:
    pass

# Nearest-neighbor interchanges (NNI)

In [344]:
def nni(tree, seed=None, inplace=False, highlight=False, force=(-1, -1)):
    """Swap two randomly chosen, non-overlapping subtrees of ``tree``.

    A first subtree ``a`` is drawn from every non-root node. A second
    subtree ``b`` is then drawn from the non-root nodes that are not ``a``
    itself, not inside ``a``, not a sister of ``a`` (that interchange would
    be uninformative), and not an ancestor of ``a`` (ancestors contain
    ``a``, so swapping with one would make a node its own descendant).
    ``a`` and ``b`` then trade places under each other's parent.

    Parameters
    ----------
    tree : ToyTree
        Tree to modify; copied first unless ``inplace`` is True.
    seed : int, optional
        Seed for ``np.random.default_rng``.
    inplace : bool
        Modify ``tree`` directly instead of a copy.
    highlight : bool
        Colour the edges of the two swapped subtrees.
    force : tuple of int
        ``(a_idx, b_idx)``. A non-negative entry overrides the random choice
        for that subtree. NOTE: a forced ``b_idx`` is not checked against
        the exclusion rules above.

    Returns
    -------
    ToyTree or None
        The modified tree, or None when no valid second subtree exists.
    """
    tree = tree if inplace else tree.copy()
    rng = np.random.default_rng(seed)

    # randomly select first subtree (any non-root Node)
    f_idx = rng.choice(tree.nnodes - 1)
    if force[0] >= 0:
        f_idx = force[0]  # overwrite random selection

    subtree_a = tree[f_idx]
    tips_a = subtree_a.get_leaf_names()

    # Collect every ancestor of subtree_a (parent, grandparent, ...). Each
    # of them contains subtree_a, so none may be swapped with it.
    ancestors = set()
    anc = subtree_a._up
    while anc is not None:
        ancestors.add(anc._idx)
        anc = anc._up

    available_nodes = (
        set(range(tree.nnodes - 1))  # all nodes but the root
        - {i._idx for i in subtree_a._iter_descendants()}  # inside subtree_a
        - {i._idx for i in subtree_a._iter_sisters()}  # uninformative interchange
        - ancestors  # contain subtree_a (includes its parent)
        - {subtree_a._idx}  # subtree_a itself
    )

    if not available_nodes:
        print(f"No possible interchange if {subtree_a!r} is selected")
        return None

    # randomly select second subtree; sorted() gives a well-defined order
    s_idx = rng.choice(sorted(available_nodes))
    if force[1] >= 0:
        s_idx = force[1]  # overwrite random selection (not validated)

    subtree_b = tree[s_idx]
    tips_b = subtree_b.get_leaf_names()

    print(f"Interchanging {subtree_a!r} with {subtree_b!r}")

    parent_a = subtree_a._up
    parent_b = subtree_b._up

    # In each parent's children, replace one subtree with the other.
    a_children = list(parent_a.children)
    a_children.remove(subtree_a)
    a_children.append(subtree_b)
    parent_a._children = tuple(a_children)

    b_children = list(parent_b.children)
    b_children.remove(subtree_b)
    b_children.append(subtree_a)
    parent_b._children = tuple(b_children)

    # Re-link each subtree to the real node that now holds it, not to a copy,
    # so upward links stay inside this tree.
    subtree_a._up = parent_b
    subtree_b._up = parent_a

    tree._update()

    # optional: colour edges of both swapped subtrees
    if highlight:
        tree.style.edge_colors = ['black'] * tree.nnodes
        for color_idx, tips in enumerate([tips_a, tips_b]):
            descs = tree.get_mrca_node(*tips).get_descendants()
            for node in tree:
                if node in descs:
                    tree.style.edge_colors[node.idx] = toytree.color.COLORS1[color_idx]

    return tree

# Draw a balanced 7-tip tree labelled by node index, then one NNI-modified
# copy with the swapped subtrees highlighted.
tree = toytree.rtree.baltree(7)
tree.draw(node_sizes=15, node_labels="idx")
nni(tree, highlight=True).draw(
    node_labels="idx",
    node_sizes=15,
    use_edge_lengths=False,
);

Interchanging Node(9) with Node(0)


In [191]:
# Number of NNI neighbours of a tree with `tips` leaves: 2 * (tips - 3),
# the standard count for an unrooted binary tree.
tips = 5
total_neighbors = (tips - 3) * 2

# TODO: rearrangement strategy
# - lazy mode: keep the first better tree found (fast but inaccurate)
# - normal mode: retain all trees
# - ties: rearrange every tied tree again (second level of neighbours)

# Subtree pruning and regrafting (SPR)

In [86]:
from typing import Optional, TypeVar
import numpy as np
import toytree.core.tree

# Stand-in type for toytree's ToyTree, used in the annotations of spr below.
ToyTree = TypeVar("ToyTree")


def spr(
    tree: ToyTree,
    seed: Optional[int]=None,
    inplace: bool=False,
    highlight: bool=False,
    ) -> Optional[ToyTree]:
    """Return a rooted ToyTree one SPR move from the current tree.

    The returned tree will have a different topology from the starting
    tree, at an SPR distance of 1. It randomly samples a subtree to
    extract from the tree, and then reinserts the subtree at an edge
    that is not (1) one of its descendants; (2) its sister; (3) its
    parent; or (4) itself.

    Parameters
    ----------
    tree : ToyTree
        Tree to modify; copied first unless ``inplace`` is True.
    seed : int, optional
        Seed for ``np.random.default_rng``.
    inplace : bool
        Modify ``tree`` directly instead of a copy.
    highlight : bool
        Colour the edges of the moved subtree.

    Returns
    -------
    ToyTree or None
        The modified tree, or None (after printing a message) when the
        sampled subtree has no edge it can be regrafted onto.
    """
    tree = tree if inplace else tree.copy()
    rng = np.random.default_rng(seed)

    # randomly select a subtree (any non-root Node)
    sidx = rng.choice(tree.nnodes - 1)
    subtree = tree[sidx]
    tips = subtree.get_leaf_names()
    print(f"{subtree=}")

    # Nodes (each standing for the edge above it) where the subtree may be
    # regrafted: not the root, not inside the subtree, not the subtree
    # itself, and not its sister or parent (both would rebuild the same tree).
    edges = (
        set(range(tree.nnodes - 1))  # all nodes but the root
        - {i._idx for i in subtree._iter_descendants()}  # inside subtree
        - {i._idx for i in subtree._iter_sisters()}  # same place again
        - {subtree._up._idx}  # same place again if there is no sister
        - {subtree._idx}  # subtree itself
    )
    if not edges:
        print(f"No possible regraft edge if {subtree!r} is selected")
        return None

    # sample an edge by its descendant Node; sorted() gives a defined order
    new_sister = tree[rng.choice(sorted(edges))]
    print(f"{new_sister=}")

    # connect subtree to new sister by inserting a new Node
    new_node = toytree.Node("new")
    new_node_parent = new_sister._up
    old_node = subtree._up
    old_node_parent = old_node._up
    new_node._up = new_node_parent
    new_node._children = (subtree, new_sister)

    # Prune: detach the subtree from its parent, then splice that parent out
    # by handing its remaining children (and its edge length) to the
    # grandparent.
    old_node._remove_child(subtree)
    if old_node_parent:
        old_node_parent._remove_child(old_node)
        for child in old_node._children:
            old_node_parent._add_child(child)
            child._dist += old_node._dist
    # NOTE(review): when old_node is the root (no grandparent) it is kept and
    # may be left with a single child -- confirm a unary root is acceptable.

    # Regraft: connect subtree to tree above its new sister
    subtree._up = new_node
    new_sister._up = new_node
    if new_node_parent:
        new_node_parent._remove_child(new_sister)
        new_node_parent._add_child(new_node)
    tree._update()

    # optional: color edges of the subtree that was moved.
    if highlight:
        tree.style.edge_colors = ['black'] * tree.nnodes
        descs = tree.get_mrca_node(*tips).get_descendants()
        for node in tree:
            if node in descs:
                tree.style.edge_colors[node.idx] = toytree.color.COLORS1[0]
    return tree

In [345]:
# Draw a balanced 7-tip tree labelled by node index, then one SPR-modified
# copy with the moved subtree highlighted.
tree = toytree.rtree.baltree(7)
tree.draw(node_sizes=15, node_labels="idx")
spr(tree, highlight=True).draw(
    node_labels=True,
    node_sizes=15,
    use_edge_lengths=False,
)

subtree=Node(7)
new_sister=Node(9)


(<toyplot.canvas.Canvas at 0x7f2ae67d4280>,
 <toyplot.coordinates.Cartesian at 0x7f2ae66e20d0>,
 <toytree.core.drawing.toytree_mark.ToytreeMark at 0x7f2ae5a0ed30>)