# Restore rooting, ordering and node labels of a tree

## Background

- Some programs discard tree metadata while processing a tree. For example, RAxML can estimate branch lengths for a given topology based on an alignment, but the resulting tree loses the original rooting, node labels and node ordering. The goal of this script is to restore that information.
- Inputs are the original tree (rooted and with node labels) and the reconstructed tree (unrooted and without node labels).
- Output is the resulting tree with rooting, ordering and node labels restored.

## Dependencies

In [1]:
from skbio import TreeNode

## Functions

### Prerequisite I: copying TreeNode directionally and recursively

 - Goal: at a given node, toward a given direction (to its parent / children), copy the entire tree structure.
 - scikit-bio's [`unrooted_copy`](https://github.com/biocore/scikit-bio/blob/master/skbio/tree/_tree.py#L584) function does something similar. However, it does not handle the tree root in the desired way.
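
The idea of a directional copy can be illustrated without scikit-bio. The sketch below is a simplified stand-in, not the `walk_copy` function of this notebook: it assumes the unrooted tree is stored as a hypothetical adjacency dict (`neighbors`), and it copies everything reachable from `node` without stepping back to `src`. Branch lengths and the special treatment of a bifurcating root are deliberately left out.

```python
def directional_copy(neighbors, node, src):
    """Copy the subtree seen from `node` when arriving from `src`.

    `neighbors` maps each node name to a list of adjacent node names
    (an assumed, illustrative representation of an unrooted tree).
    Returns a tuple of (name, [copied children]).
    """
    if src is not None and src not in neighbors[node]:
        raise ValueError('Source and node are not neighbors.')
    # every neighbor except the one we came from becomes a child
    children = [directional_copy(neighbors, nbr, node)
                for nbr in neighbors[node] if nbr != src]
    return (node, children)


# unrooted tree: tips a, b, c, d; internal nodes x and y share one branch
neighbors = {
    'a': ['x'], 'b': ['x'], 'c': ['y'], 'd': ['y'],
    'x': ['a', 'b', 'y'], 'y': ['c', 'd', 'x'],
}

# walking x -> y: y keeps c and d as children, x is left behind
down = directional_copy(neighbors, 'y', 'x')
# walking y -> x: x keeps a and b as children
up = directional_copy(neighbors, 'x', 'y')
```

Joining `down` and `up` under a new node would place a root on the branch between `x` and `y`, which is the role `walk_copy` plays for re-rooting later in this notebook.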

Solution I: tested and working, logically more natural, but less efficient.

In [2]:
%%script false
def walk_copy(node, src):
    """Directionally and recursively copy a tree node and its neighbors.

    Parameters
    ----------
    node : skbio.TreeNode
        node and its neighbors to be copied
    src : skbio.TreeNode
        node in the original tree that will become parent of self node
        in the new tree

    Returns
    -------
    skbio.TreeNode
        copied node and its neighbors

    Notes
    -----
    Unlike scikit-bio's `unrooted_copy` function, this function has special
    treatment at root: After manipulation, the original root is gone, and all
    basal siblings of the self node become immediate children of it.

    The function determines whether a tree is rooted or unrooted as follows:
    rooted: root has two children; unrooted: root has one or more than two
    children.

    Pseudocode:
    if node is root:
        if tree is rooted:
            raise error
        else:
            if src in node.children:
                append node.other_child
            else:
                raise error
    elif node is basal (i.e., child of root):
        if tree is rooted:
            if src in node.siblings:
                append node.children
            elif src in node.children:
                append node.sibling and node.other_children
            else:
                raise error
        else:
            if src is node.parent (i.e., root):
                append node.children
            elif src in node.children:
                append node.parent and node.other_children
            else:
                raise error
    else:
        if src is node.parent:
            append node.children
        elif src in node.children:
            append node.parent and node.other_children
        else:
            raise error
    """
    # create a new node
    res = TreeNode(node.name)
    
    parent = node.parent
    children = node.children
    other_children = [x for x in children if x is not src]
    siblings = [x for x in node.siblings()]
    other_siblings = [x for x in siblings if x is not src]

    # node is root
    if node.is_root():
        
        # rooted tree
        if len(node.children) == 2:
            raise ValueError('Cannot walk from root of a rooted tree.')

        # unrooted tree
        else:
            if src in node.children:
                res.length = src.length
                res.extend([walk_copy(x, node) for x in node.children
                            if x is not src])
            else:
                raise ValueError('Source and node are not neighbors.')

    # node is basal
    elif node.parent.is_root():
        
        # rooted tree
        if len(node.parent.children) == 2:
            if src in siblings:
                res.length = node.length + src.length
                res.extend([walk_copy(x, node) for x in node.children])
            elif src in node.children:
                res.length = src.length
                res.extend([walk_copy(x, node) for x in node.children
                            if x is not src])
                res.append(walk_copy(siblings[0], node))
            else:
                raise ValueError('Source and node are not neighbors.')

        # unrooted tree
        else:
            if src is node.parent:
                res.length = node.length
                res.extend([walk_copy(x, node) for x in node.children])
            elif src in node.children:
                res.length = src.length
                res.extend([walk_copy(x, node) for x in node.children
                            if x is not src])
                res.append(walk_copy(node.parent, node))
            else:
                raise ValueError('Source and node are not neighbors.')

    # node is derived
    else:
        if src is node.parent:
            res.length = node.length
            res.extend([walk_copy(x, node) for x in node.children])
        elif src in node.children:
            res.length = src.length
            res.extend([walk_copy(x, node) for x in node.children
                        if x is not src])
            res.append(walk_copy(node.parent, node))
        else:
            raise ValueError('Source and node are not neighbors.')

    return res

Solution II: simpler, more elegant, less straightforward, tested and working.

In [3]:
def walk_copy(node, src):
    """Directionally and recursively copy a tree node and its neighbors.

    Parameters
    ----------
    node : skbio.TreeNode
        node and its neighbors to be copied
    src : skbio.TreeNode
        an upstream node determining the direction of walking (src -> node)

    Returns
    -------
    skbio.TreeNode
        copied node and its neighbors

    Notes
    -----
    After manipulation, `src` will become the parent of `node`, and all other
    neighbors of `node` will become children of it.

    Unlike scikit-bio's `unrooted_copy` function, this function has special
    treatment at root: For an unrooted tree, its "root" will be retained as a
    regular node; for a rooted tree, its root will be deleted, and all basal
    nodes will become immediate children of the basal node where the source is
    located.

    The function determines whether a tree is rooted or unrooted as follows:
    rooted: root has two children; unrooted: root has one or more than two
    children.

    Logic (pseudocode):
    if node is root:
        if tree is rooted:
            raise error
        else:
            if src in node.children:
                append node.other_child
            else:
                raise error
    elif node is basal (i.e., child of root):
        if tree is rooted:
            if src in node.siblings:
                append node.children
            elif src in node.children:
                append node.sibling and node.other_children
            else:
                raise error
        else:
            if src is node.parent (i.e., root):
                append node.children
            elif src in node.children:
                append node.parent and node.other_children
            else:
                raise error
    else: (i.e., node is derived)
        if src is node.parent:
            append node.children
        elif src in node.children:
            append node.parent and node.other_children
        else:
            raise error
    """
    parent = node.parent
    children = node.children

    # position of node
    pos = ('root' if node.is_root() else 'basal' if parent.is_root()
           else 'derived')

    # whether tree is rooted
    rooted = (len(children) == 2 if pos == 'root'
              else len(parent.children) == 2 if pos == 'basal'
              else None)  # don't determine root status if node is derived
    if rooted:
        if pos == 'root':
            raise ValueError('Cannot walk from root of a rooted tree.')
        elif pos == 'basal':
            sibling = [x for x in node.siblings()][0]

    # direction of walking
    move = (('bottom' if src is sibling else 'top' if src in children
            else 'n/a') if rooted and pos == 'basal'
            else ('down' if src is parent else 'up' if src in children
            else 'n/a'))
    if move == 'n/a':
        raise ValueError('Source and node are not neighbors.')

    # create a new node
    res = TreeNode(node.name)

    # determine length of the new node
    res.length = (node.length if move == 'down'
                  else src.length + node.length if move == 'bottom'
                  else src.length)  # up or top

    # append children except for src (if applies)
    res.extend([walk_copy(c, node) for c in children if c is not src])

    # append parent if walking up (except at root)
    if move == 'up' and pos != 'root':
        res.append(walk_copy(parent, node))

    # append sibling if walking from one basal node to another
    if move == 'top':
        res.append(walk_copy(sibling, node))

    return res

### Prerequisite II: re-root in the middle of a branch

#### Background

scikit-bio has the [`root_at`](https://github.com/biocore/scikit-bio/blob/master/skbio/tree/_tree.py#L783) function (which calls the `unrooted_copy` function). However, it actually generates an unrooted tree, and it does not handle the original root in the desired way.

On a rooted tree, the root is gone (desired behavior) but there is a redundant node (`j`).

In [4]:
tree = TreeNode.read(['(((a:1.0,b:0.8)c:2.4,(d:0.8,e:0.6)f:1.2)g:0.4,(h:0.5,i:0.7)j:1.8)k;'])
print(tree.ascii_art())

                              /-a
                    /c-------|
                   |          \-b
          /g-------|
         |         |          /-d
         |          \f-------|
-k-------|                    \-e
         |
         |          /-h
          \j-------|
                    \-i


In [5]:
print(tree.root_at('g').ascii_art())

                    /-a
          /c-------|
         |          \-b
         |
         |          /-d
-root----|-f-------|
         |          \-e
         |
         |                    /-h
          \g------- /j-------|
                              \-i


On an unrooted tree, the root (`j`) is gone (NOT desired behavior).

In [6]:
tree = TreeNode.read(['(((a:0.6,b:0.5)g:0.3,c:0.8)h:0.4,(d:0.4,e:0.5)i:0.5,f:0.9)j;'])
print(tree.ascii_art())

                              /-a
                    /g-------|
          /h-------|          \-b
         |         |
         |          \-c
         |
-j-------|          /-d
         |-i-------|
         |          \-e
         |
          \-f


In [7]:
print(tree.root_at('g').ascii_art())

          /-a
         |
         |--b
-root----|
         |          /-c
         |         |
          \g-------|                    /-d
                   |          /i-------|
                    \h-------|          \-e
                             |
                              \-f


In either situation, the resulting tree is unrooted.
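
The rooted / unrooted convention used throughout this notebook (rooted means the root has exactly two children) can be checked directly on a Newick string. The sketch below is an illustration that does not use scikit-bio; `count_root_children` and `is_rooted` are hypothetical helper names, and the parser assumes unquoted labels that contain no parentheses or commas.

```python
def count_root_children(newick):
    """Count the immediate children of the root in a Newick string.

    Assumes labels are unquoted, i.e., contain no parentheses or commas.
    """
    text = newick.strip().rstrip(';')
    depth = 0
    children = 0
    for ch in text:
        if ch == '(':
            depth += 1
            if depth == 1:
                children = 1  # the outermost clade has at least one child
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 1:
            children += 1  # a comma at depth 1 separates two root children
    return children


def is_rooted(newick):
    """Notebook convention: a tree is rooted if its root has two children."""
    return count_root_children(newick) == 2


rooted = '(((a:1.0,b:0.8)c:2.4,(d:0.8,e:0.6)f:1.2)g:0.4,(h:0.5,i:0.7)j:1.8)k;'
unrooted = '(((a:0.6,b:0.5)g:0.3,c:0.8)h:0.4,(d:0.4,e:0.5)i:0.5,f:0.9)j;'
```

With scikit-bio the same check is simply `len(tree.children) == 2`, which is how the functions in this notebook test it.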

#### The function

In [8]:
def root_above(node, name=None):
    """Re-root a tree between a given node and its parent.

    Parameters
    ----------
    node : skbio.TreeNode
        node above which the new root will be placed
    name : str, optional
        name of the new root

    Returns
    -------
    skbio.TreeNode
        resulting rooted tree

    Notes
    -----
    Unlike scikit-bio's `root_at` function which actually generates an
    unrooted tree, this function generates a rooted tree (the root of
    which has exactly two children).
    """
    # walk down from self node
    left = walk_copy(node, node.parent)

    # walk up from parent node
    right = walk_copy(node.parent, node)

    # set basal branch lengths to be half of the original, i.e., midpoint
    left.length = right.length = node.length / 2

    # create new root
    return TreeNode(name, children=[left, right])


### Tests of the copying / rooting functions

#### Test on a rooted tree (two basal nodes)

In [9]:
tree = TreeNode.read(['(((a:1.0,b:0.8)c:2.4,(d:0.8,e:0.6)f:1.2)g:0.4,(h:0.5,i:0.7)j:1.8)k;'])
print(tree.ascii_art())

                              /-a
                    /c-------|
                   |          \-b
          /g-------|
         |         |          /-d
         |          \f-------|
-k-------|                    \-e
         |
         |          /-h
          \j-------|
                    \-i


Root between `c` and `g`

In [10]:
tree_cg = root_above(tree.find('c'))
print(str(tree_cg))
print(tree_cg.ascii_art())

((a:1.0,b:0.8)c:1.2,((d:0.8,e:0.6)f:1.2,(h:0.5,i:0.7)j:2.2)g:1.2);

                    /-a
          /c-------|
         |          \-b
         |
---------|                    /-d
         |          /f-------|
         |         |          \-e
          \g-------|
                   |          /-h
                    \j-------|
                              \-i


Verify that the new root is at the midpoint of the original branch

In [11]:
print('input:')
print('  g to c: %s' % tree.find('c').length)
print('output:')
print('  root to c: %s' % tree_cg.find('c').length)
print('  root to g: %s' % tree_cg.find('g').length)

input:
  g to c: 2.4
output:
  root to c: 1.2
  root to g: 1.2


Verify that the original root is eliminated and the basal branches are merged

In [12]:
print('input:')
print('  k to g: %s' % tree.find('g').length)
print('  k to j: %s' % tree.find('j').length)
print('output:')
print('  g to j: %s' % tree_cg.find('j').length)

input:
  k to g: 0.4
  k to j: 1.8
output:
  g to j: 2.2
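
Taken together, the two checks imply that re-rooting leaves tip-to-tip distances unchanged. The arithmetic below is a sketch that makes this explicit for the path from `a` to `h`; it makes no scikit-bio calls, and the numbers are copied by hand from the input Newick string and from the `tree_cg` output above.

```python
import math

# branch lengths (child -> parent) in the original tree rooted at k
before = {'a': 1.0, 'c': 2.4, 'g': 0.4, 'j': 1.8, 'h': 0.5}
# branch lengths in the tree re-rooted between c and g
after = {'a': 1.0, 'c': 1.2, 'g': 1.2, 'j': 2.2, 'h': 0.5}

# path a -> h before: a-c, c-g, g-k, k-j, j-h
dist_before = sum(before[x] for x in ('a', 'c', 'g', 'j', 'h'))
# path a -> h after: a-c, c-root, root-g, g-j, j-h
dist_after = sum(after[x] for x in ('a', 'c', 'g', 'j', 'h'))

# the midpoint split keeps the c-g branch at 2.4 in total
split_ok = math.isclose(after['c'] + after['g'], before['c'])
# the two former basal branches merge into the single g-j branch
merge_ok = math.isclose(after['j'], before['g'] + before['j'])
same_distance = math.isclose(dist_before, dist_after)
```

`math.isclose` is used instead of `==` because sums such as `0.4 + 1.8` are subject to floating-point rounding.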


Root between `i` and `j`

In [13]:
tree_ij = root_above(tree.find('i'))
print(str(tree_ij))
print(tree_ij.ascii_art())

(i:0.35,(h:0.5,((a:1.0,b:0.8)c:2.4,(d:0.8,e:0.6)f:1.2)g:2.2)j:0.35);

          /-i
         |
---------|          /-h
         |         |
          \j-------|                    /-a
                   |          /c-------|
                   |         |          \-b
                    \g-------|
                             |          /-d
                              \f-------|
                                        \-e


#### Test on a typical unrooted tree (three basal nodes)

In [14]:
tree = TreeNode.read(['(((a:0.6,b:0.5)g:0.3,c:0.8)h:0.4,(d:0.4,e:0.5)i:0.5,f:0.9)j;'])
print(tree.ascii_art())

                              /-a
                    /g-------|
          /h-------|          \-b
         |         |
         |          \-c
         |
-j-------|          /-d
         |-i-------|
         |          \-e
         |
          \-f


In [15]:
tree_ag = root_above(tree.find('a'))
str(tree_ag)

'(a:0.3,(b:0.5,(c:0.8,((d:0.4,e:0.5)i:0.5,f:0.9)j:0.4)h:0.3)g:0.3);\n'

In [16]:
print(tree_ag.ascii_art())

          /-a
         |
---------|          /-b
         |         |
          \g-------|          /-c
                   |         |
                    \h-------|                    /-d
                             |          /i-------|
                              \j-------|          \-e
                                       |
                                        \-f


In [17]:
tree_gh = root_above(tree.find('g'))
str(tree_gh)

'((a:0.6,b:0.5)g:0.15,(c:0.8,((d:0.4,e:0.5)i:0.5,f:0.9)j:0.4)h:0.15);\n'

In [18]:
print(tree_gh.ascii_art())

                    /-a
          /g-------|
         |          \-b
---------|
         |          /-c
         |         |
          \h-------|                    /-d
                   |          /i-------|
                    \j-------|          \-e
                             |
                              \-f


#### Test on a special unrooted tree (one basal node)

In [19]:
tree = TreeNode.read(['(((a:0.4,b:0.3)e:0.1,(c:0.4,d:0.1)f:0.2)g:0.6)h:0.2;'])
print(tree.ascii_art())

                              /-a
                    /e-------|
                   |          \-b
-h------- /g-------|
                   |          /-c
                    \f-------|
                              \-d


In [20]:
tree_ae = root_above(tree.find('a'))
print(str(tree_ae))
print(tree_ae.ascii_art())

(a:0.2,(b:0.3,((c:0.4,d:0.1)f:0.2,h:0.6)g:0.1)e:0.2);

          /-a
         |
---------|          /-b
         |         |
          \e-------|                    /-c
                   |          /f-------|
                    \g-------|          \-d
                             |
                              \-h


### Restore rooting

In [21]:
def restore_rooting(src, trg):
    """Restore rooting scenario in an unrooted tree based on a rooted tree.

    Parameters
    ----------
    src : skbio.TreeNode
        source tree from which the rooting is read
    trg : skbio.TreeNode
        target tree to which the rooting is applied

    Returns
    -------
    skbio.TreeNode
        resulting tree with rooting restored
    """
    if len(src.children) != 2:
        raise ValueError('Source tree must be rooted.')
    if len(trg.children) == 2:
        raise ValueError('Target tree must be unrooted.')
    if set(x.name for x in src.tips()) != set(x.name for x in trg.tips()):
        raise ValueError('Taxa in source and target trees do not match.')

    # create new tree
    res = trg.copy()

    # find the basal clade of the source tree that has fewer descendants,
    # hereafter the "outgroup" (the other basal clade being the "ingroup")
    counts = {x: x.count(tips=True) for x in src.children}
    outgroup = set(x.name for x in min(counts, key=counts.get).tips())

    # locate the lowest common ancestor (LCA) of outgroup in the target tree
    lca = res.lca(outgroup)

    # if LCA is the root (i.e., outgroup is split across basal clades),
    # re-root the tree next to an ingroup tip and locate LCA again
    if lca is res:
        for tip in res.tips():
            if tip.name not in outgroup:
                # `root_at` is a scikit-bio function that generates an unrooted
                # tree in which the node, its parent and its sibling(s) become
                # basal nodes
                res = res.root_at(tip.parent)
                break
        lca = res.lca(outgroup)

    # re-root the target tree between LCA of outgroup and LCA of ingroup
    return root_above(lca)
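
The outgroup selection step of `restore_rooting` can be illustrated on a toy structure. The sketch below uses nested tuples as a stand-in for a tree object; `tips` and `pick_outgroup` are hypothetical helpers, not scikit-bio functions.

```python
def tips(clade):
    """Return the set of tip names under a clade.

    A clade is either a tip name (str) or a tuple of child clades; this
    nested-tuple form is an illustrative stand-in for a tree object.
    """
    if isinstance(clade, str):
        return {clade}
    return set().union(*(tips(child) for child in clade))


def pick_outgroup(rooted):
    """Return the tip set of the smaller of the two basal clades."""
    if len(rooted) != 2:
        raise ValueError('Source tree must be rooted.')
    # on a tie, min() keeps the first clade it sees
    return min((tips(child) for child in rooted), key=len)


# same topology as the rooted example tree used in the tests above
source = ((('a', 'b'), ('d', 'e')), ('h', 'i'))
outgroup = pick_outgroup(source)   # the h/i clade has fewer tips
ingroup = tips(source) - outgroup
```

Choosing the smaller clade is a convenience: either side of the root defines the same bipartition of tips, so locating one side in the target tree is enough to place the root.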

### Restore node labels

In [22]:
def restore_node_labels(src, trg):
    """Restore internal node labels in one tree based on another tree.

    Parameters
    ----------
    src : skbio.TreeNode
        source tree from which internal node labels are read
    trg : skbio.TreeNode
        target tree to which internal node labels are written

    Returns
    -------
    skbio.TreeNode
        resulting tree with internal node labels added

    Notes
    -----
    Labels are assigned based on exact match of all descending taxa.
    Taxa in the source and target trees do not have to be identical.
    """
    # read descendants under each node label in source tree
    label2taxa = {}
    for node in src.non_tips(include_self=True):
        label = node.name
        if label is not None and label != '':
            if label in label2taxa:
                raise ValueError('Duplicated node label %s found.' % label)
            label2taxa[label] = set(x.name for x in node.tips())

    # identify and mark matching nodes per node label in target tree
    res = trg.copy()
    for node in res.non_tips(include_self=True):
        taxa = set(x.name for x in node.tips())
        for label in label2taxa:
            if label2taxa[label] == taxa:
                node.name = label
                break

    return res
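
The matching rule (identical sets of descending taxa) can be sketched on plain sets. This is a simplified variant, not the function above: it inverts the lookup into a dictionary keyed by `frozenset`, so each target clade needs one dictionary access instead of a scan over all labels. `transfer_labels` is a hypothetical name. If two source labels shared the same tip set, this variant would keep the last one, whereas the loop above stops at the first match.

```python
def transfer_labels(label2taxa, target_clades):
    """Assign source labels to target clades with identical tip sets.

    label2taxa : dict of str -> set of tip names (from the source tree)
    target_clades : iterable of sets of tip names (from the target tree)
    Returns a dict mapping each matched clade (as a frozenset) to its label.
    """
    # invert the lookup so each target clade needs one dictionary access
    taxa2label = {frozenset(taxa): label
                  for label, taxa in label2taxa.items()}
    matched = {}
    for clade in target_clades:
        key = frozenset(clade)
        if key in taxa2label:
            matched[key] = taxa2label[key]
    return matched


label2taxa = {'c': {'a', 'b'}, 'f': {'d', 'e'}, 'g': {'a', 'b', 'd', 'e'}}
# clades of a hypothetical target tree; {'b', 'd'} has no counterpart
target_clades = [{'a', 'b'}, {'b', 'd'}, {'a', 'b', 'd', 'e'}]
matched = transfer_labels(label2taxa, target_clades)
```

Clades without a counterpart in the source, such as `{'b', 'd'}` here, are simply left unlabeled.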

### Restore node ordering

In [23]:
def restore_node_order(src, trg):
    """Restore ordering of nodes in one tree based on another tree.

    Parameters
    ----------
    src : skbio.TreeNode
        source tree from which the node ordering is read
    trg : skbio.TreeNode
        target tree whose nodes are re-ordered accordingly

    Returns
    -------
    skbio.TreeNode
        resulting tree with nodes re-ordered

    Notes
    -----
    All tips and internal nodes must have unique identifiers.
    """
    res = trg.copy()
    for nsrc in src.traverse():
        if not nsrc.is_tip():
            if nsrc.name is None or nsrc.name == '':
                raise ValueError('There are empty node label(s) in the source '
                                 'tree.')
            ntrg = res.find(nsrc.name)
            name2child = {}
            for child in ntrg.children:
                name2child[child.name] = child
            ntrg.children = []
            for child in nsrc.children:
                ntrg.append(name2child[child.name])
    return res
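
The per-node step of the reordering can be sketched with plain lists of child names. `reorder_children` is a hypothetical helper, and the explicit check that both nodes have the same children is an assumption added for this sketch; it is not part of the function above.

```python
def reorder_children(src_children, trg_children):
    """Reorder one node's children in the target to follow the source.

    Both arguments are lists of child names belonging to the same node
    (an illustrative simplification; real nodes carry subtrees).
    """
    if set(src_children) != set(trg_children):
        raise ValueError('Children of source and target nodes differ.')
    # rank of each child name in the source ordering
    rank = {name: i for i, name in enumerate(src_children)}
    return sorted(trg_children, key=rank.__getitem__)


src_order = {'k': ['g', 'j'], 'g': ['c', 'f'], 'c': ['a', 'b']}
trg_order = {'k': ['j', 'g'], 'g': ['f', 'c'], 'c': ['a', 'b']}
restored = {node: reorder_children(src_order[node], kids)
            for node, kids in trg_order.items()}
```

Because children are matched by name, every tip and internal node needs a unique label, which is why node labels are restored before node order in this notebook.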

## Application on a 107-taxon tree

### Background

The 107 NCBI-defined reference genomes are a small subset of the 10K "Web of Life". We already have a reference topology describing the evolutionary relationships among them: `tree.nwk`. Now we want to infer meaningful branch lengths using a genome-scale alignment, `align.fa`. [RAxML](https://sco.h-its.org/exelixis/web/software/raxml/index.html) can achieve this:
```
raxmlHPC -m PROTGAMMALG -f e -p 12345 -s align.fa -t tree.nwk -n test
```
The output, namely `raxml.nwk`, has the re-estimated branch lengths but is unrooted and unlabeled.

### Input files

Original tree with root and node labels.

In [24]:
tori = TreeNode.read('tree.nwk')
tori.count(tips=True)

107

Verify that it is rooted.

In [25]:
len(tori.children)

2

Verify that internal node labels are present and unique.

In [26]:
labels = set()
for node in tori.non_tips():
    label = node.name
    if label is None or label == '':
        raise ValueError('Some internal nodes do not have labels.')
    if label in labels:
        raise ValueError('Duplicated label found: %s.' % label)
    labels.add(label)

RAxML output tree, which is unrooted and unlabeled.

In [27]:
tres = TreeNode.read('raxml.nwk')
tres.count(tips=True)

107

Verify that it is unrooted.

In [28]:
len(tres.children)

3

Verify that the two trees have the same set of taxa.

In [29]:
set(x.name for x in tori.tips()) == set(x.name for x in tres.tips())

True

### Restore rooting

In [30]:
tout = restore_rooting(tori, tres)
tout.count(tips=True)

107

In [31]:
tout.write('output.rooted.nwk')

'output.rooted.nwk'

### Restore node labels

In [32]:
tout2 = restore_node_labels(tori, tout)

In [33]:
tout2.write('output.rooted.labeled.nwk')

'output.rooted.labeled.nwk'

### Restore node ordering

In [34]:
tout3 = restore_node_order(tori, tout2)

In [35]:
tout3.write('output.rooted.labeled.ordered.nwk')

'output.rooted.labeled.ordered.nwk'