In [13]:
from ete3 import Tree
import pandas as pd
import numpy as np
import itertools
import os
from collections import defaultdict, Counter

from tqdm import tqdm
tqdm.pandas()

from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=30, use_memory_fs=False)

INFO: Pandarallel will run on 30 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [146]:
unresolved_plant_names_path = "../data/name_resolution/unresolved_plant_names.csv"
resolved_plant_names_path = "../data/name_resolution/resolved_plant_names.csv"
tree_path = f"../data/trees/resolved_ALLMB_name_resolution_on_none_with_added_ccdb_names.nwk"
families_mrcas_path = "../data/trees/optimal_ALLMB_roots_for_family_partition.csv"

expended_tree_path = f"../data/trees/ALLMB_expanded_by_unresolved_names.tre"
mrca_backbone_path = f"../data/trees/families_mrca_backbone_tree.nwk"
ott_family_trees_dir = "../data/trees/OTT/ott_family_trees/"
ploidb_family_trees_dir = "../../PloiDB/chromevol/with_model_weighting/by_family_on_unresolved_ALLMB_and_unresolved_ccdb/"
mrca_based_tree_path = "../data/trees/families_mrca_based_tree.nwk"
sp_to_fam_path = "../data/trees/species_family_classification.csv"

In [15]:
def process_tree(tree_path: str) -> Tree:
    tree = Tree(tree_path, format=1)
    for leaf in tree.get_leaves():
        leaf.name = leaf.name.replace("_"," ").lower()
    return tree

tree = process_tree(tree_path)

In [17]:
unresolved_plant_names = pd.read_csv(unresolved_plant_names_path).Name.dropna().str.lower().unique().tolist()
resolved_plant_names = pd.read_csv(resolved_plant_names_path).resolved_name.dropna().str.lower().unique().tolist()
plant_names = resolved_plant_names if resolve_plant_names else unresolved_plant_names

print(f"# unresolved_plant_names = {len(unresolved_plant_names):,}")
print(f"# resolved_plant_names = {len(resolved_plant_names):,}")

# unresolved_plant_names = 5,297
# resolved_plant_names = 3,874


In [18]:
tree_names = set(tree.get_leaf_names())

unresolved_plant_names_in_tree = list(tree_names.intersection(set(unresolved_plant_names)))
resolved_plant_names_in_tree = list(tree_names.intersection(set(resolved_plant_names)))
plant_names_in_tree = resolved_plant_names_in_tree if resolve_plant_names else unresolved_plant_names_in_tree

unresolved_plant_names_not_in_tree = list(set(unresolved_plant_names)-set(unresolved_plant_names_in_tree))
resolved_plant_names_not_in_tree = list(set(resolved_plant_names)-set(resolved_plant_names_in_tree))
plant_names_not_in_tree = resolved_plant_names_not_in_tree if resolve_plant_names else unresolved_plant_names_not_in_tree

print(f"# unresolved plant names that are present in the tree = {len(unresolved_plant_names_in_tree):,}")
print(f"# resolved plant names that are present in the tree = {len(resolved_plant_names_in_tree):,}")

# unresolved plant names that are present in the tree = 1,887
# resolved plant names that are present in the tree = 1,970


In [19]:
# compute names that can be added to the tree ad direct children of their genus ancestor
tree_genera = set([name.split(" ")[0] for name in tree.get_leaf_names()])
names_genera = set([name.split(" ")[0] for name in plant_names])

unresolved_missing_names_that_can_be_added = [name for name in unresolved_plant_names_not_in_tree if name.split(" ")[0] in tree_genera]
resolved_missing_names_that_can_be_added = [name for name in resolved_plant_names_not_in_tree if name.split(" ")[0] in tree_genera]

print(f"# out of {len(unresolved_plant_names_not_in_tree):,} missing unresolved names in the tree, {len(unresolved_missing_names_that_can_be_added):,} can be added to the tree as direct children of their genus ancestor")
print(f"# out of {len(resolved_plant_names_not_in_tree):,} missing resolved names in the tree, {len(resolved_missing_names_that_can_be_added):,} can be added to the tree as direct children of their genus ancestor")

# out of 3,410 missing unresolved names in the tree, 2,463 can be added to the tree as direct children of their genus ancestor
# out of 1,904 missing resolved names in the tree, 1,652 can be added to the tree as direct children of their genus ancestor


In [22]:
tree_with_addition = tree.copy()
names_to_keep = [name for name in tree.get_leaf_names() if name.split(" ")[0] in names_genera]
tree_with_addition.prune(names_to_keep, preserve_branch_length=True)

names_to_add_to_tree = resolved_missing_names_that_can_be_added if resolve_plant_names else unresolved_missing_names_that_can_be_added 
print(f"# names that will be added to the tree = {len(names_to_add_to_tree):,}")

genus_to_names_to_add = defaultdict(list)
for name in names_to_add_to_tree:
    genus = name.split(" ")[0]
    if genus in tree_genera:
        genus_to_names_to_add[genus].append(name)
print(f"# genera to add direct children to {len(genus_to_names_to_add):,}")

genus_to_tree_names = defaultdict(list)
for leaf_name in tree_with_addition.get_leaf_names():
    genus = leaf_name.split(" ")[0]
    genus_to_tree_names[genus].append(leaf_name)

print(f"computing lca per genus across {len(genus_to_names_to_add):,} genera")
genus_to_ancestor = dict()
for genus in genus_to_names_to_add:
    if genus in tree_genera:
        genus_names = genus_to_tree_names[genus]
        assert(len(genus_names) > 0)
        if len(genus_names) == 1:
            genus_to_ancestor[genus] = tree_with_addition.search_nodes(name=genus_names[0])[0].up
            continue
        try:
            genus_to_ancestor[genus] = tree_with_addition.get_common_ancestor(genus_names)
        except Exception as e:
            print(f"could not find the ancestor of genus {genus} spannig species {','.join(genus_names)} due to error {e}")

print(f"adding missing species under lca per genus across {len(genus_to_ancestor):,} genera")
for genus in genus_to_ancestor:
    ancestor = genus_to_ancestor[genus]
    names = set(genus_to_names_to_add[genus]) - set(ancestor.get_leaf_names())
    time_to_leaf = ancestor.get_distance(ancestor.get_leaf_names()[0])
    for name in names:
        leaf = ancestor.add_child(name=name, dist=time_to_leaf)

plant_names_for_tree = plant_names_in_tree + names_to_add_to_tree
tree_with_addition.prune(plant_names_for_tree, preserve_branch_length=True)

print(f"# leafs in new tree = {len(tree_with_addition.get_leaf_names()):,}")
tree_with_addition.write(outfile=expended_tree_path)

# names that will be added to the tree = 2,463
# genera to add direct children to 1,187
computing lca per genus across 1,187 genera
adding missing species under lca per genus across 1,187 genera
# leafs in new tree = 4,350


In [23]:
tree_with_addition = Tree(expended_tree_path)
names_counter = Counter(tree_with_addition.get_leaf_names())
duplicated_names = [name for name in names_counter if names_counter[name] > 1]

In [25]:
name_to_leaves = {name: [l for l in tree_with_addition.get_leaves() if l.name == name] for name in duplicated_names}
name_to_anc = {}
for name in name_to_leaves:
    leaves = name_to_leaves[name]
    anc = tree_with_addition.get_common_ancestor(leaves)
    name_to_anc[name] = anc
    if len(anc.get_leaves()) > len(leaves):
        # remove the less resolved leaf appearances
        leaf_to_nsiblings = {leaf: len(leaf.up.get_children()) for leaf in leaves}
        leaves.sort(key = lambda n: leaf_to_nsiblings[n])
        best_leaf = leaves[0]
        for leaf in leaves:
            if leaf != best_leaf:
                leaf.detach()
            
    else:
        anc.name = name
        anc.dist = anc.dist + anc.get_leaves()[0].dist
        for leaf in anc.get_leaves():
            leaf.detach()

In [26]:
names_counter = Counter(tree_with_addition.get_leaf_names())
duplicated_names = [name for name in names_counter if names_counter[name] > 1]
assert(len(duplicated_names) == 0)

In [27]:
tree_with_addition.write(outfile=expended_tree_path)

# alternative - build a tree from the MRCAs trees

## construct the mrca backbone

In [127]:
tree = Tree(tree_path, format=1)
families_mrcas = pd.read_csv(families_mrcas_path)

In [128]:
mrca_node_to_family = families_mrcas.set_index("node")["family"].to_dict()

In [129]:
mrca_tree = tree.copy()
for name in mrca_node_to_family:
    try:
        node = mrca_tree.search_nodes(name=name)[0]
        for child in node.get_children():
            if child.name in mrca_node_to_family:
                raiseValueError(f"child node of {node.name} corresponding to family {mrca_node_to_family[name]} has a child node {child.name} that is the mrca of family {mrca_node_to_family[child.name]}")
            child.detach()
        node.name = mrca_node_to_family[name]
    except:
        print(f"mrca node {name} of family {mrca_node_to_family[name]} is absent from the tree")

mrca node N4116 of family polygalaceae is absent from the tree
mrca node N63 of family cabombaceae is absent from the tree


In [130]:
families = set(mrca_node_to_family.values())
leaves_to_remove =  [l for l in mrca_tree.get_leaves() if l.name not in families]
while len(leaves_to_remove) > 0:
    for leaf in leaves_to_remove:
        leaf.detach()
    leaves_to_remove = [l for l in mrca_tree.get_leaves() if " " in l.name]

In [131]:
for lname in set(mrca_tree.get_leaf_names())-families:
    l = tree.search_nodes(name=lname)[0]
    l.detach()

In [132]:
mrca_tree.write(outfile=mrca_backbone_path, format=1)

## attach to each mrca the family tree given from ott

In [133]:
missing_ott_families = []
missing_families = []
family_to_tree_path = {}
for leaf in mrca_tree.get_leaves():
    family = leaf.name.lower()
    family_tree_path = f"{ott_family_trees_dir}{family}/tree.nwk"
    if not os.path.exists(family_tree_path):
        missing_ott_families.append(family)
        family_tree_path = f"{ploidb_family_trees_dir}{family}/tree.nwk"
        if not os.path.exists(family_tree_path):
            missing_families.append(family)
            continue
    family_to_tree_path[family] = family_tree_path
        
print(f"# families with no ott trees = {len(missing_ott_families)} out of {len(families)}")
print(f"# for {len(missing_families)} families, ploidb trees weren't found either, suggesting that there are only few members of the family in the tree (less that 5) and so they will be excluded anyway")

missing_families = set(missing_families)
for family_node in mrca_tree.get_leaves():
    if family_node.name in missing_families:
        family_node.detach()

# families with no ott trees = 116 out of 317
# for 81 families, ploidb trees weren't found either, suggesting that there are only few members of the family in the tree (less that 5) and so they will be excluded anyway


In [134]:
mrca_tree.get_leaf_names()[:10]

['ochnaceae',
 'linaceae',
 'euphorbiaceae',
 'chrysobalanaceae',
 'dichapetalaceae',
 'rhizophoraceae',
 'erythroxylaceae',
 'phyllanthaceae',
 'malpighiaceae',
 'elatinaceae']

In [135]:
for family_node in mrca_tree.get_leaves():
    family = family_node.name
    try:
        family_tree = Tree(family_to_tree_path[family])
        family_node.add_child(family_tree, dist=0)
    except:
        family_node.detach()

In [137]:
mrca_tree.write(outfile=mrca_based_tree_path)

In [144]:
sp_to_family = {}
for family in families:
    try:
        family_mrca = mrca_tree.search_nodes(name=family)[0]
        for leaf in family_mrca.get_leaves():
            sp_to_family[leaf.name] = family
    except:
        continue

In [147]:
pd.Series(sp_to_family).reset_index().rename(columns={"index":"Plant", 0: "Family"}).to_csv(sp_to_fam_path, index=False)