In [1]:
import tskit
import pyslim
import dendropy
import glob
import re
import numpy as np
import pandas as pd
import json
import warnings
import functools
warnings.filterwarnings('ignore')

In [2]:
# variables
# rand_id and rep
rand_id = "V96218R2W5QAOLG"
rep = "0"
## newick string
tree_str = "((B,(D,E)C)A);"
## metadata paths
edges_path = "../../meta/mock_edges.tsv"
sims_sum_path = "../../output/rand_id_params.tsv"
sims_full_path = "../../output/sims_info.tsv"
sims_header_path = "../../output/header_sims_info.tsv"
trees_path = "../../output/"

In [3]:
def add_time(ts, dt):
    '''
    This function returns a tskit.TreeSequence in which `dt`
    has been added to the times in all nodes.
    '''
    #TODO: mutation time
    tables = ts.tables
    nodes_dict = tables.nodes.asdict()
    nodes_dict['time'] = nodes_dict['time'] + dt
    tables.nodes.set_columns(**nodes_dict)
    migrations_dict = tables.migrations.asdict()
    migrations_dict['time'] = migrations_dict['time'] + dt
    tables.migrations.set_columns(**migrations_dict)
    return pyslim.SlimTreeSequence.load_tables(tables)

def get_slim_gens(ts):
    return np.array([p.slim_generation for p in ts.slim_provenances])

def find_split_time(ts1, ts2):
    """
    Given two SLiM tree sequences with shared history, this
    function returns the  
    times (in time ago) for each tree.
    """
    slim_gens1 = get_slim_gens(ts1)
    slim_gens2 = get_slim_gens(ts2)
    # counting SLiM prov before chains diff
    j = 0
    # finding the first diff between provenance chains
    for p1, p2 in zip(ts1.provenances(), ts2.provenances()):
        if p1 != p2:
            break
        record = json.loads(p1.record)
        if record["software"]["name"] == "SLiM":
            j += 1
    if j == 0:
        raise ValueError("No shared SLiM provenance entries.")
    last_slim_gen = ts1.slim_provenances[j - 1].slim_generation
    T1 = abs(last_slim_gen - slim_gens1[-1])
    T2 = abs(last_slim_gen - slim_gens2[-1])
    return T1, T2


def match_nodes(tseqs, split_time):
    """
    Given two SLiM tree sequences, returns a dictionary relating
    the id in ts2 (key) to id in ts1 (item) for  node IDs in the
    two tree sequences that refer to the same node. If split time
    in ts2 (T2) is given, then only nodes before the split are
    considered. Note the only check of equivalency is the slim_id
    of the nodes.
    """
    node_mapping = np.full(tseqs[1].num_nodes, tskit.NULL)
    sids0 = np.array([n.metadata["slim_id"] for n in tseqs[0].nodes()])
    sids1 = np.array([n.metadata["slim_id"] for n in tseqs[1].nodes()])
    alive_at_split1 = tseqs[1].tables.nodes.time >= split_time
    sorted_ids0 = np.argsort(sids0)
    matches = np.searchsorted(
        sids0,
        sids1,
        side='left',
        sorter=sorted_ids0)
    is_1in0 = np.isin(sids1, sids0)
    both = np.logical_and(alive_at_split1, is_1in0)
    node_mapping[both] = sorted_ids0[matches[both]]
    return node_mapping

def sub_metadata(tseqs):
    """
    Work around current bug in `tskit.union`: subbing top-level metadata
    so they match.
    """
    tables0 = tseqs[0].tables
    tables0.metadata = tseqs[1].tables.metadata
    tseqs[0] = pyslim.SlimTreeSequence.load_tables(tables0)

In [4]:
def subtree(focal, edges, taxon_namespace, nodes = None):
    """
    Returns a dictionary of `dendropy.Node` objects from a `pandas.DataFrame`
    with two columns: `edge` and `parent`, which specifies 
    the edge-parent relationships. Only nodes below focal are returned.
    """
    if nodes == None:
        nodes = {}
    if not focal in nodes:
        nodes[focal] = dendropy.Node(taxon=taxon_namespace.get_taxon(focal))
    for i, row in edges.iterrows():
        if row.parent == focal:
            if not row.edge in nodes:
                nodes[row.edge] = dendropy.Node(taxon=taxon_namespace.get_taxon(row.edge))
            nodes[focal].add_child(nodes[row.edge])
            nodes = subtree(row.edge, edges, taxon_namespace, nodes)
    return nodes

def build_tree_from_df(edges):
    """
    Returns a `dendropy.Tree` from a `pandas.DataFrame` with edge-parent
    relationships.
    """
    root_name = edges.edge[edges.parent==""][0]
    taxon_namespace = dendropy.TaxonNamespace(edges.edge.values.tolist())
    nodes = subtree(root_name, edges, taxon_namespace)
    tree = dendropy.Tree(seed_node = nodes[root_name], taxon_namespace=taxon_namespace)
    return(tree)

def add_blen_from_meta(tree, meta, rand_id):
    """
    `meta` is a `pandas.DataFrame` with columns `edge`, `rand_id`, `gens` and
    `rescf`. This function adds branch lengths to the `dendropy.Tree` object 
    using the info in the `meta`.
    """
    rep = '0' # all reps should be the have the sam blens anyway!
    # traversing through tree -- annotating lengths
    for node in tree.postorder_node_iter():
        print(node.taxon.label)
        subset = meta[(meta.edge==node.taxon.label) & (meta.rand_id == rand_id)]
        assert subset.shape[0] == 1
        n_gens = np.floor(subset.gens.values[0]/subset.rescf.values[0])
        node.edge_length= n_gens
        print(node.edge_length)
        print(node.distance_from_tip())
    tree.calc_node_root_distances(return_leaf_distances_only=False)
    tree.calc_node_ages(ultrametricity_precision=False, is_force_max_age=True)
    return tree

In [5]:
## loading metadata
# edges contains all the edges and info about N and number of generations
edges = pd.read_csv(edges_path,sep="\t")
edges.parent= edges.parent.fillna("")
edges["edge"] = edges["edge"].str.replace('_','-')
edges["parent"] = edges["parent"].str.replace('_','-')
# sims_sum and sims_full relate rand_ids to simulation parameters
sims_sum = pd.read_csv(sims_sum_path,sep="\t")
sims_full= pd.read_csv(sims_full_path,sep="\t", header=None)
header = pd.read_csv("../../output/header_sims_info.tsv",sep="\t")
sims_full.columns = header.columns

In [6]:
# getting all output files and grouping by rand_id and rep
tree_files = glob.glob(trees_path+"*.trees")
pattern = f'{rand_id}_rep{rep}'
n_matches = sum(1 for file in tree_files if pattern in file)
# making sure we got all the files
assert n_matches == edges.shape[0]

In [7]:
# getting the phylo tree adn annotating with branch lengths,
tree = build_tree_from_df(edges)
tree = add_blen_from_meta(tree, sims_full, rand_id)

B
50.0
0.0
D
50.0
0.0
E
20.0
0.0
C
100.0
50.0
A
20.0
150.0


In [28]:
def union_tseqs(tree, rand_id, rep):
    """
    Given a `dendropy.tree` object with annotated `edge_lengths`, a `rand_id` 
    identifier and a replicate number `rep`, this performs the 
    `tskit.TableCollection.union` of all leaves in the phylogenetic tree.
    """
    tmp_tseqs = {}
    pops = []
    for node in tree.postorder_node_iter(filter_fn = lambda node: node.is_internal()):
        assert len(node.child_nodes()) == 2, "Polytomies are not supported."
        tseqs = []
        history_len = []
        print(node.taxon.label, "\t", node.age, sep="")
        for child in node.child_nodes():
            print("\t"+child.taxon.label+"\t"+str(child.root_distance)+"\t"+str(child.age))
            history_len.append(child.root_distance+child.age)
            if child.is_leaf():
                pops.append(child.taxon.label)
                tseqs.append(pyslim.load(trees_path+child.taxon.label+"_"+rand_id+"_rep"+rep+".trees"))
            else:
                tseqs.append(tmp_tseqs.pop(child.taxon.label))
        #check if times need be shifted
        print(f"Before shift\ttime 0: {tseqs[0].max_root_time}\ttime 1: {tseqs[1].max_root_time}")
        if history_len[1] > history_len[0]:
            tseqs[0] = add_time(tseqs[0], history_len[1]-history_len[0])
        elif history_len[0] > history_len[1]:
            tseqs[1] = add_time(tseqs[1], history_len[0]-history_len[1])
        print(f"After shift\ttime 0: {tseqs[0].max_root_time}\ttime 1: {tseqs[1].max_root_time}")
        node_mapping = match_nodes(tseqs, node.age)
        sub_metadata(tseqs)
        tmp_tseqs[node.taxon.label] = tseqs[0].union(tseqs[1], node_mapping)
    assert len(tmp_tseqs) == 1
    return tmp_tseqs[list(tmp_tseqs.keys())[0]], pops

In [29]:
tsu,  pops = union_tseqs(tree,rand_id,rep)

C	50.0
	D	150.0	0.0
	E	120.0	0.0
Before shift	time 0: 170.0	time 1: 140.0
After shift	time 0: 170.0	time 1: 170.0
A	150.0
	B	50.0	0.0
	C	100.0	50.0
Before shift	time 0: 70.0	time 1: 170.0
After shift	time 0: 170.0	time 1: 170.0


In [31]:
tsu

<tskit.trees.TreeSequence at 0x11a71fd90>

In [32]:
import sys

In [33]:
sys.getsizeof(tsu)

64

In [35]:
sys.getsizeof(tseqs[1])

64