In [1]:
import tskit
import pyslim
import msprime
import dendropy
import glob
import re
import numpy as np
import pandas as pd
import json
import warnings
import functools
warnings.filterwarnings('ignore')

In [53]:
def add_time(ts, dt):
    '''
    This function returns a tskit.TreeSequence in which `dt`
    has been added to the times in all nodes.
    '''
    tables = ts.tables
    nodes_dict = tables.nodes.asdict()
    nodes_dict['time'] = nodes_dict['time'] + dt
    tables.nodes.set_columns(**nodes_dict)
    migrations_dict = tables.migrations.asdict()
    migrations_dict['time'] = migrations_dict['time'] + dt
    tables.migrations.set_columns(**migrations_dict)
    mutations_dict = tables.mutations.asdict()
    if not np.any(np.isnan(mutations_dict['time'])):
        mutations_dict['time'] = mutations_dict['time'] + dt
        tables.mutations.set_columns(**mutations_dict)
    return pyslim.SlimTreeSequence.load_tables(tables)


def match_nodes(tseqs, split_time):
    """
    Given two SLiM tree sequences, returns a dictionary relating
    the id in ts2 (key) to id in ts1 (item) for  node IDs in the
    two tree sequences that refer to the same node. If split time
    in ts2 (T2) is given, then only nodes before the split are
    considered. Note the only check of equivalency is the slim_id
    of the nodes.
    """
    node_mapping = np.full(tseqs[1].num_nodes, tskit.NULL)
    sids0 = np.array([n.metadata["slim_id"] for n in tseqs[0].nodes()])
    sids1 = np.array([n.metadata["slim_id"] for n in tseqs[1].nodes()])
    alive_before_split1 = tseqs[1].tables.nodes.time >= split_time
    sorted_ids0 = np.argsort(sids0)
    matches = np.searchsorted(
        sids0,
        sids1,
        side='left',
        sorter=sorted_ids0)
    is_1in0 = np.isin(sids1, sids0)
    both = np.logical_and(alive_before_split1, is_1in0)
    node_mapping[both] = sorted_ids0[matches[both]]
    return node_mapping

def sub_metadata(tseqs):
    """
    Work around current bug in `tskit.union`: subbing top-level metadata
    so they match.
    """
    tables0 = tseqs[0].tables
    tables0.metadata = tseqs[1].tables.metadata
    tseqs[0] = pyslim.SlimTreeSequence.load_tables(tables0)

def msp_mutation_rate_map(intervals, total_rate, intervals_rate, length):
    """
    Takes a `pd.DataFrame` with three columns (?, start, end), with 0-indexed [start, end) intervals.
    Returns breaks and rates to use with `msprime.mutate`, in which the rate for `msprime` will be
    `total_rate-intervals_rate` within the intervals.
    """
    breaks = [0]
    rates = []
    for (i, c, start, end) in intervals.itertuples():
        if start not in breaks:
            breaks.append(start)
            rates.append(total_rate)
        breaks.append(end)
        rates.append(total_rate-intervals_rate)
    if not np.isclose(breaks[-1], length):
        breaks.append(length)
        rates.append(total_rate)
    return msprime.RateMap(breaks, rates)

In [4]:
def subtree(focal, edges, taxon_namespace, nodes = None):
    """
    Returns a dictionary of `dendropy.Node` objects from a `pandas.DataFrame`
    with two columns: `edge` and `parent`, which specifies 
    the edge-parent relationships. Only nodes below focal are returned.
    """
    if nodes == None:
        nodes = {}
    if not focal in nodes:
        nodes[focal] = dendropy.Node(taxon=taxon_namespace.get_taxon(focal))
    for i, row in edges.iterrows():
        if row.parent == focal:
            if not row.edge in nodes:
                nodes[row.edge] = dendropy.Node(taxon=taxon_namespace.get_taxon(row.edge))
            nodes[focal].add_child(nodes[row.edge])
            nodes = subtree(row.edge, edges, taxon_namespace, nodes)
    return nodes

def build_tree_from_df(edges):
    """
    Returns a `dendropy.Tree` from a `pandas.DataFrame` with edge-parent
    relationships.
    """
    root_name = edges.edge[edges.parent==""][0]
    taxon_namespace = dendropy.TaxonNamespace(edges.edge.values.tolist())
    nodes = subtree(root_name, edges, taxon_namespace)
    tree = dendropy.Tree(seed_node = nodes[root_name], taxon_namespace=taxon_namespace)
    return(tree)

def add_blen_from_meta(tree, meta, rand_id):
    """
    `meta` is a `pandas.DataFrame` with columns `edge`, `rand_id`, `gens` and
    `rescf`. This function adds branch lengths to the `dendropy.Tree` object 
    using the info in the `meta`.
    """
    rep = '0' # all reps should be the have the sam blens anyway!
    # traversing through tree -- annotating lengths
    for node in tree.postorder_node_iter():
        print(node.taxon.label)
        subset = meta[(meta.edge==node.taxon.label) & (meta.rand_id == rand_id)]
        assert subset.shape[0] == 1
        n_gens = np.floor(subset.gens.values[0]/subset.rescf.values[0])
        node.edge_length= n_gens
        #print(node.edge_length)
        #print(node.distance_from_tip())
    tree.calc_node_root_distances(return_leaf_distances_only=False)
    tree.calc_node_ages(ultrametricity_precision=False, is_force_max_age=True)
    return tree

In [5]:
def union_tseqs(tree, rand_id, rep):
    """
    Given a `dendropy.tree` object with annotated `edge_lengths`, a `rand_id` 
    identifier and a replicate number `rep`, this performs the 
    `tskit.TableCollection.union` of all leaves in the phylogenetic tree.
    """
    in_tseqs = {}
    for node in tree.postorder_node_iter(filter_fn = lambda node: node.is_internal()):
        assert len(node.child_nodes()) == 2, "Polytomies are not supported."
        tseqs = []
        pops = []
        history_len = []
        print(node.taxon.label, "\t", node.age, sep="")
        for child in node.child_nodes():
            print("\t"+child.taxon.label+"\t"+str(child.root_distance)+"\t"+str(child.age))
            history_len.append(child.root_distance+child.age)
            if child.is_leaf():
                pops.append(child.taxon.label)
                tseqs.append(pyslim.load(trees_path+child.taxon.label+"_"+rand_id+"_rep"+rep+".trees"))
            else:
                tseq, p = in_tseqs.pop(child.taxon.label)
                tseqs.append(tseq)
                pops += p
                del tseq
        #check if times need be shifted
        print(f"Before shift\ttime 0: {tseqs[0].max_root_time}\ttime 1: {tseqs[1].max_root_time}")
        if history_len[1] > history_len[0]:
            tseqs[0] = add_time(tseqs[0], history_len[1]-history_len[0])
        elif history_len[0] > history_len[1]:
            tseqs[1] = add_time(tseqs[1], history_len[0]-history_len[1])
        print(f"After shift\ttime 0: {tseqs[0].max_root_time}\ttime 1: {tseqs[1].max_root_time}")
        node_mapping = match_nodes(tseqs, node.age)
        sub_metadata(tseqs)
        in_tseqs[node.taxon.label] = (tseqs[0].union(tseqs[1], node_mapping), pops)
    assert len(in_tseqs) == 1
    return in_tseqs[list(in_tseqs.keys())[0]]

In [6]:
# variables

# rand_id and rep
rand_id = "TZPNGS0UY29NGB3"
rep = "0"
total_mut_rate = 1e-8
ex_mut_rate = 0
recapN = 10000

## metadata paths
edges_path = "../../meta/edges_meta.tsv"
sims_sum_path = "../../output/rand_id_params.tsv"
sims_full_path = "../../output/sims_info.tsv"
sims_header_path = "../../output/header_sims_info.tsv"
trees_path = "../../output/"
rec_hap_path = f"../../meta/maps/{rand_id}_recrate.hapmap"
ex_path = f"../../meta/maps/{rand_id}_exons.tsv"

In [7]:
## loading metadata
# edges contains all the edges and info about N and number of generations
edges = pd.read_csv(edges_path,sep="\t")
edges.parent= edges.parent.fillna("")
edges["edge"] = edges["edge"].str.replace('_','-')
edges["parent"] = edges["parent"].str.replace('_','-')
# sims_sum and sims_full relate rand_ids to simulation parameters
sims_sum = pd.read_csv(sims_sum_path,sep="\t")
sims_full= pd.read_csv(sims_full_path,sep="\t", header=None)
header = pd.read_csv("../../output/header_sims_info.tsv",sep="\t")
sims_full.columns = header.columns

In [8]:
# getting all output files and grouping by rand_id and rep
tree_files = glob.glob(trees_path+"*[0-9].trees")
pattern = f'{rand_id}_rep{rep}'
n_matches = sum(1 for file in tree_files if pattern in file)
# making sure we got all the files
assert n_matches == edges.shape[0]

In [9]:
# getting the phylo tree adn annotating with branch lengths,
tree = build_tree_from_df(edges)
tree = add_blen_from_meta(tree, sims_full, rand_id)

bornean-orangutan
sumatran-orangutan
orangutans
eastern-gorilla
western-gorila
gorilla
humans
bonobo
nigerian-chimp
western-chimp
nigerian-western
eastern-chimp
central-chimp
eastern-central
chimps
pan
human-pan
african-apes
great-apes


In [55]:
tsu,  pops = union_tseqs(tree,rand_id,rep)
tsu = pyslim.load_tables(tsu.tables)
print(tsu.slim_generation)
assert tsu.max_root_time.is_integer()
tsu = pyslim.annotate_defaults(tsu, tsu.metadata["SLiM"]["model_type"], int(tsu.max_root_time))
print(tsu.slim_generation)
slim_gen = tsu.slim_generation
# asserting within population coalescen
assert len(set([tsu.node(u).population for t in tsu.trees() for u in t.roots])) == 1
tsu.dump(f"../../output/{rand_id}_rep{rep}.union.trees")

orangutans	186.0
	bornean-orangutan	4283.0	0.0
	sumatran-orangutan	4283.0	0.0
Before shift	time 0: 6784.0	time 1: 6784.0
After shift	time 0: 6784.0	time 1: 6784.0
gorilla	78.0
	eastern-gorilla	5715.0	0.0
	western-gorila	5715.0	0.0
Before shift	time 0: 8216.0	time 1: 8216.0
After shift	time 0: 8216.0	time 1: 8216.0
nigerian-western	94.0
	nigerian-chimp	5191.0	0.0
	western-chimp	5191.0	0.0
Before shift	time 0: 7692.0	time 1: 7692.0
After shift	time 0: 7692.0	time 1: 7692.0
eastern-central	70.0
	eastern-chimp	5191.0	0.0
	central-chimp	5191.0	0.0
Before shift	time 0: 7692.0	time 1: 7692.0
After shift	time 0: 7692.0	time 1: 7692.0
chimps	171.0
	nigerian-western	5097.0	94.0
	eastern-central	5121.0	70.0
Before shift	time 0: 7692.0	time 1: 7692.0
After shift	time 0: 7692.0	time 1: 7692.0
pan	348.0
	bonobo	5191.0	0.0
	chimps	5020.0	171.0
Before shift	time 0: 7692.0	time 1: 7692.0
After shift	time 0: 7692.0	time 1: 7692.0
human-pan	1505.0
	humans	4984.0	0.0
	pan	4843.0	348.0
Before shift	time 0:

In [56]:
recomb_map = msprime.RecombinationMap.read_hapmap(rec_hap_path)
recap_tsu = tsu.recapitate(recombination_map=recomb_map, Ne=recapN)
del tsu # too much ram
print(slim_gen, recap_tsu.max_root_time, recap_tsu.num_mutations)

8216 117025.40151410378 0


In [149]:
mut_map = msp_mutation_rate_map(exons, total_mut_rate, ex_mut_rate, int(recap_tsu.sequence_length))
model_recap = msprime.SLiMMutationModel(type=3) # TODO: figure out the type number from the treeseq
model_slim = msprime.SLiMMutationModel(type=4) # TODO: figure out the type number from the treeseq
print("Before mutate:", recap_tsu.num_mutations)
recap_tsu = msprime.mutate(recap_tsu, end_time=slim_gen, model=model_recap, rate=total_mut_rate, keep=True)
print("Mutations added in the recapitation:", recap_tsu.num_mutations)
recap_tsu = msprime.mutate(recap_tsu, start_time=slim_gen, model=model_slim, rate=mut_map, keep=True)
print("Total mutations:", recap_tsu.num_mutations)

Before mutate: 232
Mutations added in the recapitation: 395
Total mutations: 441


In [143]:
sample_size = 10
win_size = 10**6
seed = 8297
rng = np.random.default_rng(seed)

In [144]:
# getting contemporary samples
# note the time of "contemporary" samples varies bc of differences in generation times
pop_samples = [recap_tsu.samples(population_id=i+1) for i in range(len(pops))] 
contemp_time = [np.min(recap_tsu.tables.nodes.time[samples]) for samples in pop_samples]
contemp_samples = [rng.choice(pop_samples[pid][recap_tsu.tables.nodes.time[pop_samples[pid]] == contemp_time[pid]], sample_size, replace=False)
                                                        for pid in range(len(pop_samples))]

In [145]:
# windowing
windows = np.arange(start=0,stop=recap_tsu.sequence_length, step=win_size)
if not np.isclose(recap_tsu.sequence_length, windows[-1], rtol=1e-12):
    windows = np.append(windows, [recap_tsu.sequence_length])

In [146]:
# obtaining indexes for all possible pairs (including diversity, i.e. i==j for (i,j))
indexes = [(x, y) for x in range(len(pops)) for y in range(len(pops)) if x >= y]

In [147]:
dxy = recap_tsu.divergence(sample_sets=contemp_samples, mode="site", windows=windows, indexes=indexes)

In [104]:
# half matrix + diagonal
assert dxy.shape[1] == ((len(pops)**2 - len(pops))/2) + len(pops)

In [105]:
labels = np.array([[pops[i],pops[j]] for i, j in indexes])

In [91]:
np.savez(f"rand-id_{rand_id}_rep_{rep}_win-size_{win_size}_sample-size_{sample_size}.npz", windows, dxy, labels)