In [1]:
from random import seed, choice

import os
import heapq
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from skbio.tree import TreeNode

### Overview
**Goal**: Generate a smaller group of clades from the original 10,575 genomes in the phylogeny, such that the total divergence spanned by the group is limited by a given threshold while the divergence between the clades within the group is maximized.

This approach is motivated by [WoL](https://github.com/biocore/wol/blob/master/code/notebooks/taxon_subsampling.ipynb) (Zhu et al.)

First, we limit the possible groupings of clades to be chosen from those that have a minimum relative evolutionary divergence (**RED**) ([Parks, et al., 2018](https://www.nature.com/articles/nbt.4229)) over a given threshold, `min_red`.
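Then, for each candidate grouping rooted at `clade_root`, we maximize the divergence between clades by selecting the $n$ = `n_clades` disjoint descendants of `clade_root` that minimize $\sum_{i=1}^n \text{RED}(\texttt{node}_i)$. (In the code below this is approximated greedily: the candidate whose children have the smallest summed RED is repeatedly replaced by its children until there are at least `n_clades` candidates.) Specific genomes can then be sampled based on criteria such as: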
1. Contains the most marker genes.
2. Contamination level is the lowest.
3. DNA quality score is the highest.
4. Random selection.
5. Sampling of all included genomes.

Qiyun has previously used the first three criteria in combination to select a single genome. Those three criteria could also be used to filter each clade down to a smaller list of genomes, all of which are then sampled.

**Effect of parameters**:
* Increasing `min_red` will limit the number of genomes a given group of clades can contain.
* Increasing `n_clades` will increase the number of subgroups of genomes that are created.
Increasing either parameter increases the resolution that downstream methods need in order to handle such groups of genomes properly.

In [2]:
# Seed the standard-library `random` generator (which `choice`, imported
# above, draws from) so that any random selections are reproducible.
# NOTE(review): NumPy's generator is not seeded here; seed it as well if
# numpy randomness is used later.
seed(42)
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [3]:
# ASTRAL tree with branch lengths (the `cons` variant) from the WoL repository.
tree_fp = 'https://raw.githubusercontent.com/biocore/wol/master/data/trees/astral/branch_length/cons/astral.cons.nid.nwk'
# To read a local copy instead, set: tree_fp = 'data/trees/astral.nid.nwk'
tree = TreeNode.read(tree_fp)
# Number of tips (genomes) in the tree.
tree.count(tips=True)

10575

In [4]:
# Per-node support values of the ASTRAL tree (columns EN, LPP, QT), indexed
# by node ID (e.g. N2). pandas infers the bz2 compression from the extension.
# NOTE(review): `dfs` is not used by the cells shown here.
supports_fp = 'https://raw.githubusercontent.com/biocore/wol/master/data/trees/astral/astral.supports.tsv.bz2'
dfs = pd.read_table(supports_fp, index_col=0)

In [5]:
# Preview the first rows of the node support table.
dfs.head()

Unnamed: 0_level_0,EN,LPP,QT
#node,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
N2,196.0,0.998406,0.450953
N3,196.0,0.998406,0.450953
N4,124.0,0.999993,0.53512
N5,195.0,0.914387,0.398648
N6,208.0,1.0,0.539983


In [6]:
# Genome metadata (assembly information and quality metrics such as
# completeness, contamination and marker counts), indexed by genome ID
# (e.g. G000005825), matching the tip names of the tree.
genomes_fp = 'https://biocore.github.io/wol/data/genomes/metadata.tsv.bz2'
dfg = pd.read_table(genomes_fp, index_col=0)

In [7]:
# Preview the genome metadata; the table is wide, so middle columns are elided.
dfg.head()

Unnamed: 0_level_0,asm_name,assembly_accession,bioproject,biosample,wgs_master,seq_rel_date,submitter,ftp_path,img_id,gtdb_id,...,coding_density,completeness,contamination,strain_heterogeneity,markers,5s_rrna,16s_rrna,23s_rrna,trnas,draft_quality
#genome,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
G000005825,ASM582v2,GCF_000005825.2,PRJNA224116,SAMN02603086,,2010/12/15,"Center for Genomic Sciences, Allegheny-Singer ...",ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000...,646311908,RS_GCF_000005825.2,...,85.144124,98.68,1.32,0.0,377,yes,yes,yes,20,high
G000006175,ASM617v2,GCF_000006175.1,PRJNA224116,SAMN00000040,,2010/06/03,US DOE Joint Genome Institute (JGI-PGF),ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000...,646564549,RS_GCF_000006175.1,...,80.167033,99.05,0.0,0.0,165,no,yes,yes,19,medium
G000006605,ASM660v1,GCF_000006605.1,PRJNA224116,SAMEA3283089,,2005/06/27,Bielefeld Univ,ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000...,637000085,RS_GCF_000006605.1,...,89.378688,100.0,0.68,0.0,319,yes,yes,yes,20,high
G000006725,ASM672v1,GCF_000006725.1,PRJNA224116,SAMN02603773,,2004/06/04,Sao Paulo state (Brazil) Consortium,ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000...,637000348,RS_GCF_000006725.1,...,82.59299,99.59,0.18,0.0,325,yes,yes,yes,20,high
G000006745,ASM674v1,GCF_000006745.1,PRJNA57623,SAMN02603969,,2001/01/09,TIGR,ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000...,637000333,RS_GCF_000006745.1,...,86.533164,99.86,0.03,0.0,360,yes,yes,yes,20,high


In [8]:
def calc_brlen_metrics(tree):
    """Calculate branch length-related metrics.

    Originally from Zhu et al.

    Parameters
    ----------
    tree : skbio.TreeNode
        Tree whose nodes must all be named. Missing branch lengths (``None``)
        are treated as 0.0 and are overwritten with 0.0 in place.

    Raises
    ------
    ValueError
        If any node of the tree has no name.

    Notes
    -----
    The following attributes are set on every node of the tree, in place:

        - depths: List of the sums of branch lengths from the node to each
          of its descendant tips ([0.0] for a tip).

        - taxa: Sorted list of the names of the node's descendant tips
          ([name] for a tip).

        - height: Sum of branch lengths from the root to the node.

        - red: Relative evolutionary divergence (RED), introduced by Parks,
          et al., 2018, Nat Biotechnol.

              RED = p + (d / u) * (1 - p)

          where p = RED of parent, d = branch length of the node, and
          u = d + mean(depths of the node), i.e. the mean distance from the
          parent to the tips descending from the node. The root has RED 0
          and tips have RED 1. If u is 0 (a zero-length branch leading to a
          zero-length subtree), the node takes its parent's RED instead of
          dividing by zero.
    """
    # Post-order: every child's depths/taxa exist before its parent's.
    for node in tree.postorder(include_self=True):
        if node.name is None:
            raise ValueError('Error: Found an unnamed node.')
        if node.length is None:
            node.length = 0.0
        if node.is_tip():
            node.depths = [0.0]
            node.taxa = [node.name]
        else:
            # Extend each child's tip depths by the child's own branch length.
            node.depths = [
                y + x.length for x in node.children for y in x.depths]
            node.taxa = sorted(set().union(*[x.taxa for x in node.children]))

    # Pre-order: every parent's height/RED exist before its children's.
    for node in tree.preorder(include_self=True):
        if node.is_root():
            node.height = 0.0
            node.red = 0.0
            continue
        node.height = node.parent.height + node.length
        if node.is_tip():
            node.red = 1.0
            continue
        # Mean distance from the parent to the tips under this node.
        u = node.length + sum(node.depths) / len(node.depths)
        if u == 0:
            # Zero-length branch over a zero-length subtree: no divergence
            # is added relative to the parent (avoids 0 / 0).
            node.red = node.parent.red
        else:
            node.red = node.parent.red + node.length / u \
                * (1 - node.parent.red)

In [9]:
calc_brlen_metrics(tree)

In [10]:
# Minimum RED a node must reach to root a candidate grouping (between 0 and 1).
min_red = 0.1
# Number of clades each candidate grouping is split into.
n_clades = 4


def find_clade_roots(tree, min_red):
    """Return the shallowest nodes of ``tree`` whose RED is >= ``min_red``.

    Walks down from the root: a node reaching ``min_red`` becomes a candidate
    grouping root and its subtree is not explored further; otherwise its
    children are examined. Nodes must carry the ``red`` attribute set by
    ``calc_brlen_metrics``.
    """
    roots = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if node.red >= min_red:
            roots.append(node)
        else:
            stack.extend(node.children)
    return roots


def split_clade(clade_root, n_clades, tree_dict):
    """Greedily split ``clade_root`` into at least ``n_clades`` disjoint clades.

    Each candidate clade is scored by the summed RED of its children. The
    lowest-scoring candidate is repeatedly replaced by its children until at
    least ``n_clades`` candidates exist.

    Parameters
    ----------
    clade_root : skbio.TreeNode
        Root of the grouping to split; nodes need the ``red`` attribute.
    n_clades : int
        Minimum number of clades to return.
    tree_dict : dict
        Mapping of node name to node, used to recover nodes from heap entries.

    Returns
    -------
    list of skbio.TreeNode
        The chosen clades in heap order, or ``[]`` if ``clade_root`` has fewer
        than ``n_clades`` tips.
    """
    if len(list(clade_root.tips())) < n_clades:
        return []

    # Heap entries are (score, name): ties are broken by name rather than by
    # comparing node objects.
    chosen = [(sum(child.red for child in clade_root), clade_root.name)]
    while len(chosen) < n_clades:
        _, name = heapq.heappop(chosen)
        node = tree_dict[name]
        if node.children:
            for child in node:
                heapq.heappush(
                    chosen, (sum(gc.red for gc in child), child.name))
        else:
            # A tip cannot be split. Re-queue it with an infinite score so a
            # splittable node is always popped before it. The former finite
            # sentinel (3) could be lower than the score of an internal node
            # with many children, which made this loop run forever.
            heapq.heappush(chosen, (float('inf'), node.name))
    return [tree_dict[name] for _, name in chosen]


clade_roots = find_clade_roots(tree, min_red)
print(clade_roots)

# Assumes node names are unique; a duplicated name maps to the last node
# visited by the traversal.
tree_dict = {node.name: node for node in tree.traverse()}

groupings = [split_clade(clade_root, n_clades, tree_dict)
             for clade_root in clade_roots]
print(groupings)

[<TreeNode, name: N7, internal node count: 8450, tips count: 8452>, <TreeNode, name: N6, internal node count: 1452, tips count: 1454>, <TreeNode, name: N5, internal node count: 663, tips count: 665>, <TreeNode, name: N4, internal node count: 2, tips count: 4>]
[[<TreeNode, name: N39, internal node count: 3704, tips count: 3706>, <TreeNode, name: N13, internal node count: 41, tips count: 43>, <TreeNode, name: N38, internal node count: 0, tips count: 2>, <TreeNode, name: N24, internal node count: 4699, tips count: 4701>], [<TreeNode, name: N34, internal node count: 886, tips count: 888>, <TreeNode, name: N19, internal node count: 502, tips count: 504>, <TreeNode, name: N33, internal node count: 59, tips count: 61>, <TreeNode, name: G001873755, internal node count: 0, tips count: 0>], [<TreeNode, name: N28, internal node count: 240, tips count: 242>, <TreeNode, name: N11, internal node count: 366, tips count: 368>, <TreeNode, name: N27, internal node count: 46, tips count: 48>, <TreeNode,

In [11]:
# Count the 'lv2_group' labels (from the genome metadata) of the genomes in
# each clade of the first grouping. Selecting rows and the single column in
# one .loc call avoids building an intermediate frame with every metadata
# column. Raises KeyError if a tip name is missing from dfg's index.
value_counts = [dfg.loc[node.taxa, 'lv2_group'].value_counts()
                for node in groupings[0]]

In [12]:
# Show each clade's label counts alphabetically, separated by a blank line.
for counts in value_counts:
    print(counts.sort_index(), end='\n\n')

Actinobacteria    1096
Bacteria            85
Bacteroidetes        1
Chloroflexi        146
Cyanobacteria      295
Firmicutes        1944
PVC                  1
Proteobacteria       3
Spirochaetes         1
Terrabacteria      134
Name: lv2_group, dtype: int64

Bacteria      39
Firmicutes     4
Name: lv2_group, dtype: int64

Bacteria    2
Name: lv2_group, dtype: int64

Actinobacteria       1
Bacteria           350
Bacteroidetes      833
Chlamydiae         106
Chloroflexi          1
FCB                140
PVC                165
Proteobacteria    2971
Spirochaetes       134
Name: lv2_group, dtype: int64

