### Load in modules

In [1]:
from ete3 import Tree
import pandas as pd
import json
import copy
import copy


### Load in data

In [99]:
tree_path = './final.matOptimize.rerooted.pruned.nh'
label_path = './seq_pango_comparison.tsv'

## function to reduce PANGO lvl
def lin_reduce(lin, lvl=3):
    return '.'.join(lin.split('.')[:lvl])

## load in tree and annotations
tree = Tree(tree_path, format=1)
grping_df = pd.read_csv(label_path, sep='\t')
grping = dict(zip(grping_df.seqName.values, [lin_reduce(x) for x in grping_df.new.values]))


### Prep and label tree

In [101]:
tree_copy = tree.copy('newick-extended')
tree_grp_dist = {}

for node in tree_copy.traverse():
    depth = tree_copy.get_distance(node, topology_only=True)
    if node.is_leaf():
        leaf_grp = grping[node.name]
        node.add_features(grp=leaf_grp, depth=depth)
        if leaf_grp in tree_grp_dist:
            tree_grp_dist[leaf_grp]['count'] += 1
            if tree_grp_dist[leaf_grp]['depth'] > depth: tree_grp_dist[leaf_grp]['depth']  = depth
        else:
            tree_grp_dist[leaf_grp] = { 'count': 1, 'depth': depth }
    else:
        node.add_features(depth=depth)
    
print('Taxa count: %d' % len(tree_copy.get_leaves()))


Taxa count: 115622


In [102]:
## instead of traverse path from every foreign leaf to the MRCA, check for monophyletic groups
## for each monophyletic group (where exists), traverse from the top-most leaf in the group
def get_top_leaf(tree):
    if tree.is_leaf():
        return tree
    
    for node in tree.iter_descendants('levelorder'):
        if node.is_leaf():
            return node
        

In [103]:
DEFAULT_PARAMS = {
    'n_g_min': 10,
    'alpha': 0.6,
    'rho_f_max': 0.25,
    'rho_g_w': 0.6,
    'major_prop_min': 0.8,
    'minor_show_max': 3,
    'global_prop_indicate': {
        1.0: '*',
        0.9: '++',
        0.8: '+'
    }
}


In [104]:
####################     partition class     ####################

class Partition():
    
    def __init__(self, refs, root_node, params=DEFAULT_PARAMS):
        self.refs = { 'tree': refs['tree'], 'grp_dist': refs['grp_dist'] }
        self.actives = {'tree': root_node } ## need to create actives first to call on calc_grp_dist()
        self.actives['grp_dist'] = self.calc_grp_dist()
        self.count = len(root_node.get_leaves())
        self.is_candidate = True
        self.params = params
        self.set_name_partition()
        
        self.tree_up = root_node.up.name if root_node.up else None
        self.tree_up_dist = root_node.dist if root_node.up else 0
        
        
    def validate(self):
        ref_root_node = self.refs['tree'].search_nodes(name=self.actives['tree'].name)[0]
        ref_root_anc_node = ref_root_node.up if not ref_root_node.is_root() else None
        self.actives['tree'] = ref_root_node.detach()
        self.set_name_partition()
        self.update_ref_grp_dist()
        self.is_candidate = False

        ## strip remaining ref_tree of naked internal nodes
        if ref_root_anc_node:
            for leaf in ref_root_anc_node.iter_leaves():
                curr = leaf
                while curr.is_leaf() and not curr.is_root() and not hasattr(curr, 'grp'):
                    prev = curr
                    curr = curr.up
                    prev.detach()
                            
        
    def is_root(self):
        return self.actives['tree'].is_root()

        
    def get_leaves(self, grp_name=None, count_only=False, n_g_min_aware=False):        
        sf_leaves = sorted(list(filter(lambda leaf: (not grp_name or leaf.grp == grp_name) and \
                                       (not n_g_min_aware or self.refs['grp_dist'][leaf.grp]['count'] >= self.params['n_g_min']), \
                                       self.actives['tree'].get_leaves())), key=lambda leaf: leaf.depth, reverse=True)

        return (len(sf_leaves) if count_only else sf_leaves)
        
        
    def calc_grp_dist(self):
        grp_dist = {}
        for leaf in self.actives['tree'].iter_leaves():
            grp = leaf.grp
            depth = leaf.depth
            if grp in grp_dist:
                grp_dist[grp]['count'] += 1
                if depth > leaf.depth:
                    grp_dist[grp]['depth'] = leaf.depth
            else:
                grp_dist[grp] = { 'count': 1, 'depth': leaf.depth }
                    
        return grp_dist
    
    
    ## get majority grp (for naming)
    def get_majority_grp(self):
        sorted_grp_dist = self.sort_grp_dist(sort_by='size')
        return sorted_grp_dist[0]
    
    
    ## get sorted grp dist
    def sort_grp_dist(self, ref=False, sort_by='depth', n_g_min_aware=False):
        grp_dist = self.refs['grp_dist'] if ref else self.actives['grp_dist']
        pre_sort = [{ 'name': k, 'count': v['count'], 'depth': v['depth'] } for k, v in grp_dist.items()]
        
        if n_g_min_aware:
            pre_sort = list(filter(lambda grp: grp['count'] >= self.params['n_g_min'], pre_sort))
        
        sorted_by_size = sorted(pre_sort, key=lambda grp: grp['count'], reverse=sort_by == 'size')
        if sort_by == 'size':
            return sorted_by_size
        else:
            sorted_by_depth = sorted(sorted_by_size, key=lambda grp: grp['depth'], reverse=True)
            return sorted_by_depth
    
    
    ## get labelling name
    def set_name_partition(self):
        
        def modify_grp_name(grp):
            global_prop = grp['count']/tree_grp_dist[grp['name']]['count']
            for thres in self.params['global_prop_indicate'].keys():
                if global_prop >= thres:
                    return (grp['name'] + self.params['global_prop_indicate'][thres])

            return grp['name']
        
        sorted_grp_dist = self.sort_grp_dist(sort_by='size')
        
        is_single_majority = sorted_grp_dist[0]['count']/self.count >= self.params['major_prop_min']
        majority_grps = []
        minority_grps = []
        for grp in sorted_grp_dist:
            grp_prop = grp['count']/self.count
            if grp_prop >= self.params['major_prop_min']:
                majority_grps.append(grp)
            elif grp_prop < (1 - self.params['major_prop_min']):
                minority_grps.append(grp)
            else:
                if is_single_majority:
                    minority_grps.append(grp)
                else:
                    majority_grps.append(grp)
        
        majority_label = '/'.join([modify_grp_name(grp) for grp in majority_grps]) + \
                        ('/' if not is_single_majority and len(majority_grps) == 1 else '')
        minority_label = '(%s%s%s)' % ('/'.join([modify_grp_name(grp) for grp in minority_grps[:self.params['minor_show_max']]]), \
                        ('/...' if len(minority_grps) > self.params['minor_show_max'] else ''), \
                        ('[n=%d]' % len(minority_grps) if len(minority_grps) > self.params['minor_show_max'] else ''))
        
        self.name = (majority_label + (minority_label if len(minority_grps) else ''))
                
    
    def update_ref_grp_dist(self):
        for k, v in self.actives['grp_dist'].items():
            self.refs['grp_dist'][k]['count'] -= v['count']
            if self.refs['grp_dist'][k]['count'] == 0:
                del self.refs['grp_dist'][k]
            
            
    def calc_grp_obj(self, grp_name):
        n_g = self.actives['grp_dist'][grp_name]['count']
        n_c = self.get_leaves(count_only=True, n_g_min_aware=True)
        N_g = self.refs['grp_dist'][grp_name]['count']
        rho_c = n_g/n_c
        rho_g = n_g/N_g
                
        g_thres = n_g >= self.params['n_g_min']
        f_thres = (n_c - n_g) <= n_c*self.params['rho_f_max']*(1+rho_g*self.params['rho_g_w'])
        obj = self.params['alpha']*rho_c + (1 - self.params['alpha'])*rho_g
        
#         print(g_thres, n_g, self.params['n_g_min'])
#         print(f_thres, (n_c - n_g)/n_c, self.params['rho_f_max']*(1+rho_g*self.params['rho_g_w']), obj)
        
        return (g_thres and f_thres, obj)
            
            
    def find_partition(self, grp_name, node=None, starting_leaf=None):

        ## get farthest leaf in grp of interest
        root_node = node if node else self.refs['tree']
        grp_leaves = root_node.search_nodes(grp=grp_name)
        grp_mrca = root_node.get_common_ancestor(grp_leaves)
        depth_sorted_leaves = sorted(grp_leaves, key=lambda leaf: leaf.depth, reverse=True)
        
        ## start from deepest leaf, if no valid partition can be found then move onto to next deepest leaf
        iter_n = 10
        for i in range(min(iter_n, len(depth_sorted_leaves))):
            target_leaf = starting_leaf if starting_leaf else depth_sorted_leaves[i]

            ## evaluate obj as the path to the MRCA is traversed
            curr = target_leaf.up
            max_obj = -999
            max_obj_partition = None

            while True:
                curr_partition = Partition(self.refs, curr)
                pass_obj = curr_partition.calc_grp_obj(grp_name)

                if pass_obj[0] and pass_obj[1] >= max_obj:
                    max_obj_partition = curr_partition
                    max_obj_partition.obj = pass_obj[1]
                    max_obj = pass_obj[1]

                if curr.name == grp_mrca.name: break
                else: curr = curr.up
    #                 print(curr.name)

    #         print(max_obj)

            if starting_leaf or max_obj_partition:
                return max_obj_partition

        return None
    
    
    def check_foreign_partition(self, focus_grp):
        
        fgn_sorted_grps = filter(lambda grp: grp['name'] != focus_grp, 
                                 self.sort_grp_dist(sort_by='depth', n_g_min_aware=True))
        for grp in fgn_sorted_grps:
                        
            mono_leaves = map(lambda mono: get_top_leaf(mono),
                                 self.actives['tree'].get_monophyletic(values=[grp['name']], target_attr='grp'))            
            
            for leaf in mono_leaves:
                fgn_candidate = self.find_partition(grp['name'],
                                                    node=self.actives['tree'],
                                                    starting_leaf=leaf)

                if fgn_candidate:
                    return (fgn_candidate, grp['name'])    
            

In [105]:
refs = { 'tree': tree_copy.copy('newick-extended'), 'grp_dist': copy.deepcopy(tree_grp_dist) }
ref_partition = Partition(refs, tree_copy.copy('newick-extended'))


In [106]:
## main program starts here

partitions = []

start_time = time.time()
while True:
    
    print('Calculating most updated grp distribution...')
    depth_sorted_grps = ref_partition.sort_grp_dist(ref=True, sort_by='depth', n_g_min_aware=True)
    
    remaining_grp_exists = len(depth_sorted_grps)
    for grp in depth_sorted_grps:
                
        partition_stored = False
        focus_grp = grp['name']
        
        print(focus_grp)
        
        partition_candidate = ref_partition.find_partition(focus_grp)
        
        while partition_candidate:
            
            print('focus_grp: ', focus_grp)
            fgn_candidate = partition_candidate.check_foreign_partition(focus_grp)
            
            if fgn_candidate:
                partition_candidate, focus_grp = fgn_candidate
                
                continue
                
            else:
                partition_candidate.validate()
                partitions.append(partition_candidate)
                partition_stored = True

                print('\n%d' % (len(partitions) + 1))
                print(partition_candidate.name)
                print(partition_candidate.actives['tree'].name)
                print(partition_candidate.actives['grp_dist'])
                
                break
        
        if partition_stored:
            break ## break out to update global group distribution
    
    else:
        if remaining_grp_exists:
            continue
        else:
            break
            
print(time.time() - start_time)


Calculating most updated grp distribution...
BA.1.15
focus_grp:  BA.1.15

2
BA.1.15
node_11356
{'BA.1.15': {'count': 22, 'depth': '9.0'}}
Calculating most updated grp distribution...
BA.1.15
focus_grp:  BA.1.15

3
BA.1.15++(BA.1)
node_15136
{'BA.1.15': {'count': 12222, 'depth': '12.0'}, 'BA.1': {'count': 7, 'depth': '16.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.2.22
focus_grp:  BA.2.22

4
BA.2.22*(BA.2)
node_18103
{'BA.2.22': {'count': 39, 'depth': '7.0'}, 'BA.2': {'count': 1, 'depth': '9.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.8
focus_grp:  BA.1.8

5
BA.1.8*
node_14341
{'BA.1.8': {'count': 46, 'depth': '8.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.5
focus_grp:  BA.1.5

6
BA.1.5*
node_14305
{'BA.1.5': {'count': 81, 'depth': '7.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.12
focus_grp:  BA.1.12

7
BA.1.12++(BA.1)
node_497
{'BA.1.12': {'count': 292, 'depth': '8.0'}, 'BA.1': {'count': 4, 'depth': '10.0'}}
Calc

BA.1.16
BA.1.9
BA.2.13
BA.1.21
BA.1.19
BA.1.18
BA.2.23
BA.1.17
BA.1.14
BA.1.1
focus_grp:  BA.1.1
focus_grp:  BA.1

40
BA.1/(BA.1.17/BA.1.21)
node_6053
{'BA.1': {'count': 15, 'depth': '18.0'}, 'BA.1.17': {'count': 2, 'depth': '16.0'}, 'BA.1.21': {'count': 2, 'depth': '17.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.16
BA.1.9
BA.1.21
BA.2.13
BA.1.19
BA.1.18
BA.2.23
BA.1.17
BA.1.14
BA.1.1
focus_grp:  BA.1.1
focus_grp:  BA.1

41
BA.1(BA.1.1)
node_6079
{'BA.1': {'count': 38, 'depth': '10.0'}, 'BA.1.1': {'count': 4, 'depth': '10.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.16
BA.1.9
BA.1.21
BA.2.13
BA.1.19
BA.1.18
BA.2.23
BA.1.17
BA.1.14
BA.1.1
focus_grp:  BA.1.1
focus_grp:  BA.1

42
BA.1
node_8711
{'BA.1': {'count': 17, 'depth': '9.0'}}
Calculating most updated grp distribution...
BA.1.15
BA.1.16
BA.1.9
BA.1.21
BA.2.13
BA.1.19
BA.1.18
BA.2.23
BA.1.17
BA.1.14
BA.1.1
focus_grp:  BA.1.1
focus_grp:  BA.1

43
BA.1(BA.1.17/BA.1.16/BA.1.21/...[n=4])
node_11404
{

In [133]:
## rename partitions to ensure uniqueness
partition_names = []

count = 0
print('#partitions: %d\n' % len(partitions))
for i, p in enumerate(partitions):
    count += p.count
    name_count = partition_names.count(p.name)
    partition_names.append(p.name)
    if name_count:
        p.name = '%s_%d' % (p.name, name_count)
    print(i+1)
    print('%s\n%s' % (p.name, str(p.actives['grp_dist'])))
    print('\n')
    
print('#leaves: %d' % count)
    

#partitions: 51

1
BA.1.15
{'BA.1.15': {'count': 22, 'depth': '9.0'}}


2
BA.1.15++(BA.1)
{'BA.1.15': {'count': 12222, 'depth': '12.0'}, 'BA.1': {'count': 7, 'depth': '16.0'}}


3
BA.2.22*(BA.2)
{'BA.2.22': {'count': 39, 'depth': '7.0'}, 'BA.2': {'count': 1, 'depth': '9.0'}}


4
BA.1.8*
{'BA.1.8': {'count': 46, 'depth': '8.0'}}


5
BA.1.5*
{'BA.1.5': {'count': 81, 'depth': '7.0'}}


6
BA.1.12++(BA.1)
{'BA.1.12': {'count': 292, 'depth': '8.0'}, 'BA.1': {'count': 4, 'depth': '10.0'}}


7
BA.1.12
{'BA.1.12': {'count': 13, 'depth': '9.0'}}


8
BA.1.20++(BA.1/BA.1.15)
{'BA.1.20': {'count': 1167, 'depth': '8.0'}, 'BA.1': {'count': 22, 'depth': '8.0'}, 'BA.1.15': {'count': 1, 'depth': '9.0'}}


9
BA.1.16
{'BA.1.16': {'count': 29, 'depth': '9.0'}}


10
BA.1.16++(BA.1/BA.1.18/BA.1.21)
{'BA.1.16': {'count': 1740, 'depth': '9.0'}, 'BA.1': {'count': 10, 'depth': '9.0'}, 'BA.1.18': {'count': 1, 'depth': '11.0'}, 'BA.1.21': {'count': 1, 'depth': '11.0'}}


11
BA.2.12*
{'BA.2.12': {'count': 10, 'dept

## Construct partition-tree

In [148]:
## need empty partition as root partition
root_partition_node = Tree(name='root', format=1)

i_nodes = {}
p_nodes = {}
for p in partitions:
    temp_t = Tree(name=p.name, format=1)
    temp_t.dist = p.tree_up_dist
    temp_t.add_features(p=p)
    p_nodes[p] = temp_t
    
for p, n in p_nodes.items():
    curr_up = tree_copy.search_nodes(name=p.tree_up)[0]
    curr_p_tree = n
    anc_found = False
    while not anc_found:
        for pp, nn in p_nodes.items():
            if curr_up.is_root():
                root_partition_node.add_child(curr_p_tree)
                anc_found = True
                break
            elif curr_up.name in i_nodes:
                i_nodes[curr_up.name] = i_nodes[curr_up.name].add_child(curr_p_tree)
                anc_found = True
                break
            elif pp.actives['tree'].search_nodes(name=curr_up.name):
                nn.add_child(curr_p_tree)
                anc_found = True
                break

        if not anc_found:
            anc_part = Tree(name=curr_up.name, format=1)
            anc_part.add_child(curr_p_tree)
            curr_p_tree = anc_part
            i_nodes[curr_p_tree.name] = curr_p_tree
                
        if not curr_up.is_root():
            curr_up = curr_up.up


In [149]:
for node in root_partition_node.traverse():
    if node.name.startswith('BA'):
        count = node.p.count
        node.name = '%s|N=%d' % (node.name, count)
    

In [150]:
print(root_partition_node.get_ascii())


                                                       /-BA.2.22*(BA.2)|N=40
                                                      |
                                                      |--BA.2.12*|N=10
                                                      |
                                                      |--BA.2.2*|N=51
                                                      |
    /BA.2++(BA.2.13*/BA.2.23++/BA.2.5*/...[n=16])|N=4320-BA.2.3*(BA.2)|N=224
   |                                                  |
   |                                                  |--BA.2.10++(BA.2/BA.2.23)|N=530
   |                                                  |
   |                                                  |--BA.2.9++(BA.2)|N=1055
   |                                                  |
   |                                                   \-BA.2.1++|N=64
   |
   |                                                /-BA.1.15++(BA.1)|N=12229
   |                                            

In [98]:
root_partition_node.write(format=1, outfile='n115622.outliers_removed.partitionTree.tre')


### Label all nodes by partition

In [153]:
node_partition = {}
for p in partitions:
    for node in p.actives['tree'].traverse():
        node_partition[node.name] = p.name


In [158]:
with open('n115622.outliers_removed.node_partition.tsv', 'w+') as outfile:
    outfile.write('node\tpartition\n')
    outfile.write('\n'.join(['%s\t%s' % (k, v) for k, v in node_partition.items()]))
    

### Deconstruct partition-tree into nodes/links

In [210]:
part_nodes = [
    {
        'id': node.name,
        'num': node.p.count,
        'type': 3,
        'divs': [
            {
                'label': k,
                'num': v['count']
            } for k, v in node.p.actives['grp_dist'].items()
        ]
    } for node in part_tree.traverse()
]


In [211]:
part_links = [
    {
        'id': '%s_%s' % (node.up.name if node.up else '', node.name),
        'source': node.up.name if node.up else '',
        'target': node.name,
        'len': node.dist,
    } for node in part_tree.traverse()
]
part_links = list(filter(lambda link: link['source'], part_links))


### Get group distribution

In [170]:
with open('n115622.outliers_removed.partition_summary.tsv', 'w+') as outfile:
    for p in partitions:
        grp_dist = p.actives['grp_dist']
        total_count = sum([v['count'] for v in grp_dist.values()])
        outfile.write('%s (root: %s)\n' % (p.name, p.actives['tree'].name))
        for grp, v in grp_dist.items():
            outfile.write('%s (n=%d (%.2f%%))\n' % (grp, v['count'], (v['count']*100/total_count)))
        outfile.write('\n')


## Write to file

In [212]:
with open('test_n1307_nodes.json', 'w+') as outfile:
    outfile.write(json.dumps(part_nodes))
    

In [213]:
with open('test_n1307_links.json', 'w+') as outfile:
    outfile.write(json.dumps(part_links))
    

In [214]:
nodes_links = {
    'nodes': part_nodes,
    'links': part_links
}
with open('test_n1307_nodes_links.json', 'w+') as outfile:
    outfile.write(json.dumps(nodes_links))
    