In [25]:
import os
import pickle
import sys
from collections import Counter

from ete3 import Tree, TreeStyle
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd
import numpy as np
from pandarallel import pandarallel
import matplotlib.pyplot as plt

# Project-local helpers live outside the notebook directory; each path is appended
# right before the import that needs it.
sys.path.append("/groups/itay_mayrose/halabikeren/tmp/ploidb/data_processing/")
from check_tree_monophyly import add_group_by_property

sys.path.append("/groups/itay_mayrose/halabikeren/tmp/ploidb/")
from services.pbs_service import PBSService

# Initialize pandarallel exactly once.
# use_memory_fs=False: transfer data to workers over pipes rather than the memory file system.
pandarallel.initialize(progress_bar=True, use_memory_fs=False)

pallete = px.colors.qualitative.Vivid  # qualitative colour palette (name kept for later cells)
pio.templates.default = "plotly_white"

INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [26]:
# --- Analysis configuration ---
group_by_options = ["genus", "family"]
tree_name = "ALLMB"  # or "ALLOTB"
add_missing_names = False
resolve_ccdb = False
resolve_tree = False

# Pick the name-resolution variant of the tree file.
if resolve_ccdb and resolve_tree:
    _name_resolution = "ccdb_and_tree"
elif resolve_ccdb:
    _name_resolution = "only_ccdb"
else:
    _name_resolution = "none"
_added_names_part = "and_wo_counts_" if add_missing_names else ""
tree_path = (
    f"../trees/resolved_{tree_name}_name_resolution_on_{_name_resolution}"
    f"_with_added_ccdb_{_added_names_part}names.nwk"
)

classification_path = "../trees/wfo_classification_data.csv"
classification_data = pd.read_csv(classification_path)

time_points_to_partition_by = [5, 10, 20]

# Cache files.
# NOTE(review): the partition-cache file name depends only on add_missing_names, not on
# tree_name or the resolution flags, so switching trees would reuse a stale pickle — confirm.
nodes_distances_path = f"./nodes_dist_{os.path.basename(tree_path).replace('nwk', 'csv')}"
_missing_data_part = "_with_missing_data" if add_missing_names else ""
time_points_to_partition_by_outpath = (
    f"../trees/time_points_to_internal_nodes_to_partition_by{_missing_data_part}.pkl"
)

In [27]:
# format=1 is ete3's newick flavour that keeps internal node names; later cells look nodes up by name.
tree = Tree(newick=tree_path, format=1)

## Partition the tree into clades by time point

In [15]:

def get_relevant_desendents(node: "Tree", dist_from_root: float, node_to_dist: dict) -> list[str]:
    """Collect the topmost named internal nodes at or beyond a distance cut-off from the root.

    Walks the subtree rooted at ``node`` in pre-order. A named internal node whose distance
    from the root (looked up in ``node_to_dist``) is >= ``dist_from_root`` is reported, and
    its descendants are not visited. Leaves are never reported; unnamed internal nodes are
    only traversed through.

    The walk uses an explicit stack rather than recursion, so very deep trees cannot hit
    Python's recursion limit. The output order matches a recursive pre-order walk.

    Args:
        node: root of the (sub)tree to scan.
        dist_from_root: cut-off distance from the tree root.
        node_to_dist: mapping from node name to its distance from the root.

    Returns:
        Names of the selected internal nodes, in pre-order.

    Raises:
        KeyError: if a named internal node visited before the cut-off is missing from ``node_to_dist``.
    """
    selected = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current.is_leaf():
            continue
        if current.name != "" and node_to_dist[current.name] >= dist_from_root:
            selected.append(current.name)
            continue  # do not descend below a selected node
        # Push children reversed so the leftmost child is processed first (pre-order).
        stack.extend(reversed(current.get_children()))
    return selected


def get_internal_nodes_to_partition_by(tree: "Tree", node_to_dist: dict, time_point: int):
    """Return the internal nodes whose clades partition ``tree`` at age ``time_point``.

    The tree height is taken as the root-to-tip distance of the first leaf. This assumes an
    ultrametric tree — TODO confirm for every input tree. The cut-off is
    ``height - time_point``. The selected nodes are the topmost named internal nodes whose
    distance from the root is at least that cut-off (see ``get_relevant_desendents``).

    Args:
        tree: rooted tree whose internal node names are keys of ``node_to_dist``.
        node_to_dist: mapping from node name to its distance from the root.
        time_point: age of the cut, in branch-length units.

    Returns:
        Names of the selected internal nodes, deduplicated, in pre-order of discovery
        (deterministic, unlike iterating a set).

    Raises:
        AssertionError: if a selected node is older than ``time_point``, beyond a small
            floating-point tolerance.
    """
    first_leaf = next(tree.iter_leaves())
    dist_from_root = tree.get_distance(first_leaf) - time_point
    print(f"dist_from_root={dist_from_root} for time_point={time_point}")
    internal_nodes_to_parition_by = list(
        dict.fromkeys(
            get_relevant_desendents(node=tree, dist_from_root=dist_from_root, node_to_dist=node_to_dist)
        )
    )

    # Resolve all selected names in a single traversal. Calling tree.search_nodes per name
    # would scan the whole tree each time. The default traversal order matches
    # search_nodes, so the first match wins as before.
    wanted = set(internal_nodes_to_parition_by)
    name_to_node = {}
    for candidate in tree.traverse():
        if candidate.name in wanted and candidate.name not in name_to_node:
            name_to_node[candidate.name] = candidate

    # Sanity check: a selected clade's age (distance down to one of its leaves) must not
    # exceed time_point. The tolerance absorbs rounding of branch lengths along different
    # root-to-tip paths.
    for node_name in internal_nodes_to_parition_by:
        node = name_to_node[node_name]
        node_age = node.get_distance(next(node.iter_leaves()))
        assert node_age <= time_point + 1e-6, f"node {node_name} has age {node_age} > {time_point}"
    return internal_nodes_to_parition_by


In [13]:
# Distance of every node from the root, cached on disk (path defined in the config cell).
if not os.path.exists(nodes_distances_path):
    # One level-order pass: parents are visited before their children, so each node's
    # distance is its parent's distance plus its own branch length. This is linear in
    # the number of nodes. Calling tree.get_distance(name) per node searches the whole
    # tree each time, which is quadratic overall. The root is at distance 0, as with
    # get_distance.
    dist_by_node_id = {}
    node_names, node_dists = [], []
    for node in tree.traverse():
        dist = 0.0 if node.up is None else dist_by_node_id[id(node.up)] + node.dist
        dist_by_node_id[id(node)] = dist
        node_names.append(node.name)
        node_dists.append(dist)
    nodes_distances = pd.DataFrame({"node": node_names, "distance_from_root": node_dists})
    print(f"# nodes to compute distance for = {nodes_distances.shape[0]:,}")
    nodes_distances.to_csv(nodes_distances_path, index=False)
else:
    nodes_distances = pd.read_csv(nodes_distances_path)

In [17]:
# Map each time point to the internal nodes whose clades partition the tree at that age,
# caching the mapping on disk.
# NOTE: pickle.load can execute arbitrary code — only load cache files this notebook wrote.
# NOTE(review): the cache path does not encode tree_name, so a cache built for another
# tree would be reused silently — confirm before switching trees.
time_point_to_internal_nodes_to_parition_by = {}
if os.path.exists(time_points_to_partition_by_outpath):
    with open(time_points_to_partition_by_outpath, "rb") as f:
        time_point_to_internal_nodes_to_parition_by = pickle.load(f)

# Compute only the requested time points not yet in the cache. A cache built for a
# different list of time points would otherwise raise KeyError in the report loop below.
missing_time_points = [
    time_point
    for time_point in time_points_to_partition_by
    if time_point not in time_point_to_internal_nodes_to_parition_by
]
if missing_time_points:
    node_to_dist = nodes_distances.set_index("node")["distance_from_root"].to_dict()
    missing_partitions = pd.DataFrame({"time_point": missing_time_points})
    missing_partitions["nodes_to_partition_by"] = missing_partitions.time_point.parallel_apply(
        lambda time_point: get_internal_nodes_to_partition_by(
            tree=tree, node_to_dist=node_to_dist, time_point=time_point
        )
    )
    time_point_to_internal_nodes_to_parition_by.update(
        missing_partitions.set_index("time_point")["nodes_to_partition_by"].to_dict()
    )
    with open(time_points_to_partition_by_outpath, "wb") as outfile:
        pickle.dump(obj=time_point_to_internal_nodes_to_parition_by, file=outfile)

for time_point in time_points_to_partition_by:
    print(
        f"# clades when partitioning by {time_point}M years = {len(time_point_to_internal_nodes_to_parition_by[time_point])}"
    )

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1), Label(value='0 / 1'))), HBox(c…

dist_from_root=315.050124 for time_point=10
dist_from_root=320.050124 for time_point=5
dist_from_root=305.050124 for time_point=20
# clades when partitioning by 5M years = 5541
# clades when partitioning by 10M years = 4038
# clades when partitioning by 20M years = 2329


In [24]:
# QA: every clade selected for a time point must be no older than that time point.
# Build the name -> node lookup once. Calling tree.search_nodes per name would scan the
# whole tree for each of thousands of names. setdefault keeps the first match in default
# traversal order, as search_nodes(...)[0] does.
name_to_node = {}
for candidate in tree.traverse():
    name_to_node.setdefault(candidate.name, candidate)

for time_point in time_points_to_partition_by:
    print(f"QA on time point {time_point}")
    for node_name in time_point_to_internal_nodes_to_parition_by[time_point]:
        node = name_to_node[node_name]
        node_age = node.get_distance(next(node.iter_leaves()))
        assert node_age <= time_point, f"node {node_name} has age {node_age} > {time_point}"

QA on time point 5
QA on time point 10
QA on time point 20
