## nb2 RF new function

Version of RF function for toytree distance module.

In [1]:
import toytree
import pandas as pd
import numpy as np
import itertools

### Reference: 

1. ETE3 robinson foulds function
   - Github repo link: ete3 > coretype > tree.py, https://github.com/etetoolkit/ete/blob/master/ete3/coretype/tree.py
   - Documentation ete3 robinson foulds function (treenode robinsonfoulds object in ete3) http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#robinson-foulds-distance



2. Definition of RF 
    - https://rdrr.io/cran/phangorn/man/treedist.html
    - Number of internal nodes in tree 1 + number of internal nodes in tree 2 - 2*(number of internal splits shared by the two trees)

### RF function v1: Visualizing partitions with toytree drop tips

Planning how to implement RF method on one tree

In [152]:
# original 5-tip tree
tree1 = toytree.core.rtree.unittree(5, seed=123)
tree1.draw(ts='p');
tree1.newick

'((r4:0.666667,(r3:0.333333,r2:0.333333)0:0.333333)0:0.333333,(r1:0.666667,r0:0.666667)0:0.333333);'

In [95]:
# Overview: 
# 1) identify as internal branch checking if node is NOT leaf
# 2) use get_leaf_names to get all the tip labels associated with this node
#    - each node has its own group of leaves
# 3) two types of partitions
#    - breaking internal branch = drop_tips for all tips in group
#    - breaking branch near tips = drop_tips for one individual tip in group (doesn't matter in RF)


# Visualizing breaking internal branch
# e.g. break internal branch, partition/split tree seed 123 at 6-8 branch
# 2 resulting "trees" after partition
# tree 1 = drop all tips that are below that node
tree2 = tree1.drop_tips(["r0", "r1"])
tree2.draw(ts = 'p');

# tree 2 = drop all tips that are NOT below that node
tree3 = tree1.drop_tips(["r2", "r3", "r4"])
tree3.draw(ts = 'p');

In [96]:
# mini implementation for 5 tip tree

# store in list
possible_partitions = set()
children_groupings = set()
non_children_groupings = set()

num_internal_nodes = 0

# find all possible partitions due to breaking internal branches
# for each node
for node in tree1.idx_dict.values():
    # check if node is an internal node = NOT a leaf (a node on the tip)
    if not node.is_leaf():
        num_internal_nodes += 1
        # find leaves under each of these internal nodes
        children = tuple(node.get_leaf_names())
        print("children", children)

        # eliminate children group with all of the leaves
        if len(children) == len(tree1.get_tip_labels()):
            # save it first here with variable as...
            pass
        else:
            non_children = tuple(set(children).symmetric_difference(set(tree1.get_tip_labels())))
            print("non-children", non_children) 
            
            children_groupings.add(children)
            non_children_groupings.add(non_children)
            
            # for checking
            partition1 = tree1.drop_tips(children)
            partition2 = tree1.drop_tips(non_children)
            mtree = toytree.MultiTree([partition1, partition2])
            mtree.draw(ts='p');
            
            # store as newick
            partition1 = tree1.drop_tips(children).newick
            partition2 = tree1.drop_tips(non_children).newick
            possible_partitions.add(tuple([partition1, partition2]))
print(children_groupings)
print(non_children_groupings)
print("number of internal nodes:", num_internal_nodes)
            

# find all possible partitions due to breaking leaves (NOT necessary)
#for node in tree1.idx_dict.values():
#    if node.is_leaf():
#        children = node.get_leaf_names()
#        print(children)
#        partition1 = tree1.drop_tips(children)
#        partition1.draw(ts='p');
        
#        partition1 = tree1.drop_tips(children).newick
#        possible_partitions.add(partition1)

print(possible_partitions)
max_partitions = len(possible_partitions)
print("max number of partitions:", max_partitions)

children ('r3', 'r2')
non-children ('r4', 'r0', 'r1')
children ('r1', 'r0')
non-children ('r2', 'r4', 'r3')
children ('r4', 'r3', 'r2')
non-children ('r0', 'r1')
children ('r4', 'r3', 'r2', 'r1', 'r0')
{('r4', 'r3', 'r2'), ('r3', 'r2'), ('r1', 'r0')}
{('r2', 'r4', 'r3'), ('r0', 'r1'), ('r4', 'r0', 'r1')}
number of internal nodes: 4
{('(r1:0.666667,r0:0.666667);', '(r4:0.666667,(r3:0.333333,r2:0.333333)0:0.333333);'), ('(r4:0.666667,(r3:0.333333,r2:0.333333)0:0.333333);', '(r1:0.666667,r0:0.666667);'), ('((r1:0.666667,r0:0.666667)0:0.333333,r4:1);', '(r3:0.333333,r2:0.333333);')}
max number of partitions: 3


### Add a second 5-tip tree for RF comparison

In [153]:
# add a second 5-tip tree for RF comparison
tree5 = toytree.core.rtree.unittree(5, seed=323)
tree5.draw(ts='p');
tree5.newick

'((r4:0.666667,r3:0.666667)0:0.333333,(r2:0.666667,(r1:0.333333,r0:0.333333)0:0.333333)0:0.333333);'

In [91]:
# compare partitions
# symmetric difference not working
rf = len(possible_partitions_5.symmetric_difference(possible_partitions))
print("rf:", rf)

max_rf = max_partitions + max_partitions_5
print("max_rf:", max_rf)

norm_rf = rf/max_rf
print("norm_rf:", norm_rf)


# steel and penny definition
# RF = # of internal edges from tree1 + internal edges from tree 2 - (2*shared internal splits) 
max_rf = (len(possible_partitions_5)) + (len(possible_partitions))
rf_steel_penny = num_internal_nodes + num_internal_nodes_5 - int(2*(len(possible_partitions_5.symmetric_difference(possible_partitions))))
norm_rf_steel_penny = rf_steel_penny/max_rf
print(rf_steel_penny)
print(norm_rf_steel_penny)

rf: 16
max_rf: 16
norm_rf: 1.0
-14
-0.875


### ---
### RF function v2: numpy + binary 

### Counting number of internal edges
Use get_edges function

In [205]:
names = tree1.get_tip_labels()
print(names)

# create dictionary mapping numbers to tip labels
namedict = dict(enumerate(names))
print(namedict)
num_of_internal_edges = 0

# get all edges in terms of their associated nodes
for edge in tree1.get_edges():
    # check if second value of edge (associated node that is further down the tree) is in dictionary keys
    if edge[1] not in list(namedict.keys()):
        # number of internal edges
        print(edge)
        num_of_internal_edges += 1
print("# of internal edges:", num_of_internal_edges)
tree1.draw(ts='p');

['r0', 'r1', 'r2', 'r3', 'r4']
{0: 'r0', 1: 'r1', 2: 'r2', 3: 'r3', 4: 'r4'}
[7 5]
[8 6]
[8 7]
# of internal edges: 3


In [206]:
tree1.get_edges()

array([[6, 0],
       [6, 1],
       [5, 2],
       [5, 3],
       [7, 4],
       [7, 5],
       [8, 6],
       [8, 7]])

### Counting number of shared partitions/splits

In [211]:
# dictionary to associate names to numbers
ndict = {j: i for i, j in enumerate(names)}

# store binary outputs in set
final = set()
for node in tree1.treenode.traverse('preorder'):
    print(node)
    bits = np.zeros(len(tree1), dtype=float)
    for child in node.iter_leaf_names():
        bits[ndict[child]] = True
        print(bits)
    # skip all True (whole tree) or just one true (refers to scenario to just partition a tip on the end)
    if sum(bits) == 1 or sum(bits) == tree1.ntips:
        print("bits skip", bits)
        pass
    else: final.add(tuple(bits))
print(final)


      /-r4
   /-|
  |  |   /-r3
  |   \-|
--|      \-r2
  |
  |   /-r1
   \-|
      \-r0
[0. 0. 0. 0. 1.]
[0. 0. 0. 1. 1.]
[0. 0. 1. 1. 1.]
[0. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
bits skip [1. 1. 1. 1. 1.]

   /-r4
--|
  |   /-r3
   \-|
      \-r2
[0. 0. 0. 0. 1.]
[0. 0. 0. 1. 1.]
[0. 0. 1. 1. 1.]

--r4
[0. 0. 0. 0. 1.]
bits skip [0. 0. 0. 0. 1.]

   /-r3
--|
   \-r2
[0. 0. 0. 1. 0.]
[0. 0. 1. 1. 0.]

--r3
[0. 0. 0. 1. 0.]
bits skip [0. 0. 0. 1. 0.]

--r2
[0. 0. 1. 0. 0.]
bits skip [0. 0. 1. 0. 0.]

   /-r1
--|
   \-r0
[0. 1. 0. 0. 0.]
[1. 1. 0. 0. 0.]

--r1
[0. 1. 0. 0. 0.]
bits skip [0. 1. 0. 0. 0.]

--r0
[1. 0. 0. 0. 0.]
bits skip [1. 0. 0. 0. 0.]
{(1.0, 1.0, 0.0, 0.0, 0.0), (0.0, 0.0, 1.0, 1.0, 0.0), (0.0, 0.0, 1.0, 1.0, 1.0)}


In [212]:
names5 = tree5.get_tip_labels()
# dictionary to associate names to numbers
ndict5 = {j: i for i, j in enumerate(names5)}

final5 = set()
for node in tree5.treenode.traverse('preorder'):
    print(node)
    bits = np.zeros(len(tree5), dtype=float)
    for child in node.iter_leaf_names():
        bits[ndict[child]] = True
        print(bits)
    if sum(bits) == 1 or sum(bits) == tree5.ntips:
        print("bits skip", bits)
        pass
    else: 
        final5.add(tuple(bits))
print(final5)


      /-r4
   /-|
  |   \-r3
--|
  |   /-r2
   \-|
     |   /-r1
      \-|
         \-r0
[0. 0. 0. 0. 1.]
[0. 0. 0. 1. 1.]
[0. 0. 1. 1. 1.]
[0. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
bits skip [1. 1. 1. 1. 1.]

   /-r4
--|
   \-r3
[0. 0. 0. 0. 1.]
[0. 0. 0. 1. 1.]

--r4
[0. 0. 0. 0. 1.]
bits skip [0. 0. 0. 0. 1.]

--r3
[0. 0. 0. 1. 0.]
bits skip [0. 0. 0. 1. 0.]

   /-r2
--|
  |   /-r1
   \-|
      \-r0
[0. 0. 1. 0. 0.]
[0. 1. 1. 0. 0.]
[1. 1. 1. 0. 0.]

--r2
[0. 0. 1. 0. 0.]
bits skip [0. 0. 1. 0. 0.]

   /-r1
--|
   \-r0
[0. 1. 0. 0. 0.]
[1. 1. 0. 0. 0.]

--r1
[0. 1. 0. 0. 0.]
bits skip [0. 1. 0. 0. 0.]

--r0
[1. 0. 0. 0. 0.]
bits skip [1. 0. 0. 0. 0.]
{(0.0, 0.0, 0.0, 1.0, 1.0), (1.0, 1.0, 0.0, 0.0, 0.0), (1.0, 1.0, 1.0, 0.0, 0.0)}


In [199]:
# number of shared internal partitions
len(final.intersection(final5))

1

### Implementation: Class object robinson_foulds in toytree

In [221]:
class robinson_foulds():
    """Returns the Robinson-Foulds distance between two trees.

    Faster cleaner version of RF...

    Parameters
    ----------
    tree1: toytree.ToyTree
        A first toytree instance to compare to another tree.
    tree2: toytree.ToyTree
        A second toytree instance to compare to tree1.
    *args: 
        Additional args TBD.

    Examples
    ---------
    >>> tree1 = toytree.rtree.unittree(10, seed=123)
    >>> tree2 = toytree.rtree.unittree(10, seed=321)
    >>> toytree.distance.treedist.robinson_foulds(tree1, tree2)
    """

    def __init__(self, trees, sampmethod, consensustree=None):
        # store inputs
        self.trees = toytree.core.multitree.MultiTree(trees)
        self.treelist = self.trees.treelist
        self.sampmethod = sampmethod

        # store consensus tree
        self.consensustree = consensustree
        if self.consensustree == None:
            self.consensustree = self.trees.get_consensus_tree() 
        # append consensus tree as last in tree list
        self.trees.treelist.append(self.consensustree)

        # store output
        self.getrfout = {}
        self.samporder = []
        self.data = pd.DataFrame(columns = ['trees', 'RF'])
        

    def get_rf(self):
        """
        Function to get RFs depending on user input (pairwise/random sampling of trees
        vs. compare all trees with consensus tree)
        Returns result in a dictionary, with key as tree # and value as RF value. 
        """
        for idx in range(len(self.trees)):
            ttre = self.treelist[idx]
            
            
            # PART 1: count number of internal edges
            names = ttre.get_tip_labels()
            
            # create dictionary mapping numbers to tip labels
            namedict = dict(enumerate(names))
            # store final number of internal edges
            num_of_internal_edges = 0

            # get all edges in terms of their associated nodes
            for edge in ttre.get_edges():
            # check if second value of edge (associated node that is further down the tree) is in dictionary keys
                if edge[1] not in list(namedict.keys()):
                # number of internal edges
                    num_of_internal_edges += 1
                    
                    
            # PART 2: count number of internal partition
            # create dictionary mapping tip labels to names
            ndict = {j: i for i, j in enumerate(names)}
            
            # save possible internal partitions in set
            internal_partitions = set()
            # use binary notation to record possible partitions
            for node in ttre.treenode.traverse('preorder'):
                bits = np.zeros(len(ttre), dtype=float)
                for child in node.iter_leaf_names():
                    bits[ndict[child]] = True
                # skip all True (whole tree)
                # or skip just one true (refers to scenario where only one tip is partitioned on the end)
                if sum(bits) == 1 or sum(bits) == ttre.ntips:
                    pass
                else: 
                    internal_partitions.add(tuple(bits))
                        
            # save RF data for each tree
            # if last tree, this means this is the RF set for the consensus tree
            if idx == len(self.trees)-1:
                self.getrfout['consensus'] = num_of_internal_edges, internal_partitions
                # remove consensus tree from tree list
                del self.trees.treelist[-1]
            # if not, treat RF set as set for a normal tree that will soon be used for comparisons
            else:
                self.getrfout[idx] = num_of_internal_edges, internal_partitions
            
    
    def compare_rf(self):
        """
        Function to compile tree # and associated RFs into a final data frame as output with self.data
        """
        # follow sampling order if user wants to calculate distances in pairwise/random fashion
        if self.sampmethod == "pairwise" or self.sampmethod == "random":
            # generate sampling order depending on pairwise or random user input
            length = len(self.trees)

            samporder = Sample(length, self.sampmethod)
            self.samporder = samporder.sampling()
        
            # iterate over each pair of trees depending on sampling order
            for idx in range(len(self.trees)-1):      
                t0_ninternaledges = self.getrfout[self.samporder[idx]][0]
                t1_ninternaledges = self.getrfout[self.samporder[idx+1]][0]
                t0_partitions = self.getrfout[self.samporder[idx]][1]
                t1_partitions = self.getrfout[self.samporder[idx+1]][1]
                t0_t1_shared_partitions = len(t0_partitions.intersection(t1_partitions))
                
                rf = t0_ninternaledges + t1_ninternaledges - 2*(t0_t1_shared_partitions)
                max_rf = t0_ninternaledges + t1_ninternaledges
            
                self.data = self.data.append({'trees' : str(self.samporder[idx])+ ", " + str(self.samporder[idx+1]), 
                                              'RF' : rf,
                                              'max_RF': max_rf,
                                              'normalized_rf': rf/max_rf},
                                              ignore_index = True)
        # compares each tree with consensus
        else:
            consensus_ninternaledges = self.getrfout['consensus'][0]
            consensus_partitions = self.getrfout['consensus'][1]
            
            for idx in range(len(self.trees)):
                t0_ninternaledges = self.getrfout[idx][0]
                t0_partitions = self.getrfout[idx][1]
                con_t0_shared_partitions = len(consensus_partitions.intersection(t0_partitions))
                
                rf = consensus_ninternaledges + t0_ninternaledges - 2*(con_t0_shared_partitions)
                max_rf = consensus_ninternaledges + t0_ninternaledges
                
                self.data = self.data.append({'trees' : str(idx) + ", consensus", 
                                              'RF' : rf,
                                              'max_RF': max_rf,
                                              'normalized_rf': rf/max_rf},
                                              ignore_index = True)
        # return data frame as output
        return self.data        
        
        
    def run(self):
        """
        Define run function
        """
        self.get_rf()
        self.compare_rf()