In [1]:
import itertools
import os
import re
import sys
from collections import defaultdict
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pandas.errors import SettingWithCopyWarning
import numpy as np
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('retina', 'png')

from tqdm.notebook import tqdm 
from math import factorial, log, log10, log1p, floor, ceil
from scipy.stats import chi2

import multiprocess

sns.set()
sns.set_style("ticks")

import tskit
import msprime

import geneinfo as gi
gi.email('kaspermunch@birc.au.dk')

# scale down size of default plots
sns.set_context("paper")
import matplotlib as mpl
scale = 0.8
d = dict([(k, v*scale) for (k, v) in sns.plotting_context('paper').items()])
d['figure.figsize'] = [5.4, 3.5]
mpl.rcParams.update(d)

def modpath(p, parent=None, base=None, suffix=None):
    """
    Build a new path from `p` by replacing its directory, base name and/or suffix.

    p: path to modify.
    parent: replacement directory (default: keep the directory of `p`).
    base: replacement file name without its extension (default: keep it).
    suffix: either a string that replaces the extension found by
        os.path.splitext, or a (pattern, replacement) tuple. For a tuple,
        `pattern` is used as a regular expression anchored at the end of the
        path and is substituted by `replacement`.

    Returns the new path as a string.
    Raises AssertionError if a tuple suffix does not have exactly two elements
    or if its pattern is not substituted exactly once.
    """
    par, name = os.path.split(p)
    name_no_suffix, suf = os.path.splitext(name)
    if isinstance(suffix, str):
        suf = suffix
    if parent is not None:
        par = parent
    if base is not None:
        name_no_suffix = base
    new_path = os.path.join(par, name_no_suffix + suf)
    if isinstance(suffix, tuple):
        assert len(suffix) == 2
        # `re` has to be imported at the top of the notebook for this branch to work
        new_path, nsubs = re.subn(r'{}$'.format(suffix[0]), suffix[1], new_path)
        assert nsubs == 1, nsubs
    return new_path

In [2]:
import scipy

class Comb():
    """
    Memoised binomial coefficient.

    Calling an instance as comb(n, k, exact=True) returns
    scipy.special.comb(n, k, exact=exact) and remembers the result in a
    dictionary that is shared by all instances of the class.

    An instance can also be used as a context manager: entering starts a
    multiprocess.Pool whose workers each get the instance stored in their own
    global `comb`; leaving shuts the pool down.
    """

    # memo shared by all instances: (n, k, exact) -> binomial coefficient
    cache = {}
    
    def __init__(self):
        pass

    def __call__(self, n, k, exact=True):
        # `exact` is part of the key so that an integer result and a
        # floating point result for the same (n, k) are never mixed up
        key = (n, k, exact)
        if key not in self.cache:
            self.cache[key] = scipy.special.comb(n, k, exact=exact)
        return self.cache[key]

    @classmethod
    def clear(cls):
        """Replace the shared memo by an empty one."""
        cls.cache = {}

    def __enter__(self):

        def init_worker(data):
            # declare scope of a new global variable
            global comb
            # store argument in the global variable for this process
            comb = data
            
        self.pool = multiprocess.Pool(processes=8, initializer=init_worker, initargs=(self,))
        return self.pool

    def __exit__(self, type, value, tb):
        # close() only stops new tasks from being submitted; join() then waits
        # for the workers to exit so that no processes are left behind
        self.pool.close()
        self.pool.join()

# Start from an empty class-level memo (clear() rebinds Comb.cache to a new dict).
Comb().clear()
# Module-level memoised binomial; prob_nr_of_runs looks up this global name.
comb = Comb()

# Smoke test of the context manager: entering starts the worker pool and binds
# it to `pool`, leaving the block closes it again. No work is submitted.
with Comb() as pool:
    pass

In [3]:
def get_coalescence_runs(all_times, clade_times):
    """
    Turn the coalescence times of a tree into a 0/1 sequence.

    all_times: coalescence times in the order they should be scanned.
    clade_times: coalescence times that belong to the derived clade.

    Every time that follows the first derived one in `all_times` is encoded
    as 1 if it is in `clade_times` and as 0 otherwise. The first derived time
    itself, and everything before it, is left out, so the returned sequence
    can start with either a 0 or a 1.

    Returns a numpy array (empty if there is no derived time, or nothing
    follows the first one).
    """
    derived = set(clade_times)
    # skip ahead to the first derived coalescence ...
    remaining = itertools.dropwhile(lambda t: t not in derived, all_times)
    # ... and discard it, keeping only what comes after
    next(remaining, None)
    return np.array([int(t in derived) for t in remaining])

def get_runs_of_1s(bits):
    """
    Generate, in sequence order, the sum of each stretch of consecutive equal,
    truthy values in `bits`. For a 0/1 sequence this is the length of every
    run of ones.
    """
    yield from (sum(stretch) for flag, stretch in itertools.groupby(bits) if flag)

def get_all_runs(bits):
    """
    Generate the lengths of all runs in a 0/1 sequence.

    bits: sequence of zeros and ones (numpy array, list, tuple, ...).

    The lengths of the runs of ones are generated first, in sequence order,
    followed by the lengths of the runs of zeros. The total number of
    generated values is the number of runs in the sequence.
    """
    # coerce up front so that the arithmetic below also works for plain lists
    bits = np.asarray(bits)
    # |bits - 1| swaps zeros and ones, which turns the runs of zeros into runs of ones
    for values in (bits, np.absolute(bits - 1)):
        for bit, group in itertools.groupby(values):
            if bit:
                yield sum(group)

In [4]:
def prob_nr_of_runs(n, n1, n2):
    """
    Probability of observing exactly `n` runs in a sequence of two kinds of
    symbols (e.g. zeros and ones).

    n: number of runs
    n1: number of symbols of the first kind
    n2: number of symbols of the second kind

    The count of arrangements with `n` runs is divided by the total number of
    arrangements, comb(n1+n2, n1).
    """
    if n % 2 == 0:
        # even: both kinds of symbols form the same number of runs
        half = n // 2
        favourable = 2*comb(n1-1, half-1)*comb(n2-1, half-1)
    else:
        # uneven: one kind of symbol forms one run more than the other
        half = (n - 1) // 2
        favourable = (comb(n1-1, half)*comb(n2-1, half-1)
                      + comb(n1-1, half-1)*comb(n2-1, half))
    return favourable / comb(n1+n2, n1)

In [5]:
%%time
import multiprocess
pool = multiprocess.Pool(processes=8)

n = 200
dim = 2*n+1
cache = np.ndarray(shape=(n+1, n+1, 2*n+1), dtype=float)
cache[:, :, :] = np.nan
for n0 in tqdm(range(1, n+1)):
    for n1 in range(1, n+1):
        for r in range(1, n0+n1):
            cache[n0, n1, r] = prob_nr_of_runs(r, n0, n1)

pool.close()

np.save('prob_nr_of_runs_cache.npy', cache)

  0%|          | 0/200 [00:00<?, ?it/s]

CPU times: user 17.3 s, sys: 121 ms, total: 17.4 s
Wall time: 17.9 s


In [None]:
prob_nr_of_runs_cache = np.load('prob_nr_of_runs_cache.npy')

# One Relate tree sequence per population; only the population code differs between the paths.
treeseq_path_template = ('/home/ari/ari-intern/people/ari/ariadna-intern/steps/'
                         '{}/relate/run_relate/1000g_ppl_phased_haplotypes.trees')
populations = [
    # africans
    'LWK', 'GWD', 'ESN', 'MSL', 'YRI',
    # europeans
    'GBR', 'FIN', 'IBS', 'TSI',
    # asians
    'CDX', 'CHB', 'CHS', 'JPT', 'KHV',
]

results_dir = '../results'
os.makedirs(results_dir, exist_ok=True)

for population in populations:
    treeseq_file_name = treeseq_path_template.format(population)
    print(population, os.path.basename(treeseq_file_name))

    # All input files share the same base name, so the population code has to be
    # part of the output name. Otherwise every population writes to the same
    # file and only the results of the last one survive.
    stem = os.path.splitext(os.path.basename(treeseq_file_name))[0]
    output_file_name = modpath(treeseq_file_name, parent=results_dir,
                               base='{}_{}'.format(population, stem),
                               suffix='_runstats.h5')

    tree_seq = tskit.load(treeseq_file_name)

    nr_samples = tree_seq.num_samples

    records = []
    nodes_time = tree_seq.nodes_time

    for tree in tqdm(tree_seq.trees()):

        # times of all internal nodes of the tree, in the order given by timedesc()
        all_times = [nodes_time[n] for n in tree.timedesc() if not tree.is_leaf(n)]
        for mut in tree.mutations():
            node = tree_seq.mutations_node[mut.id]
            # times of the internal nodes in the clade below the mutation
            clade_times = [nodes_time[n] for n in tree.timedesc(node) if not tree.is_leaf(n)]

            # nr-all-runs and max ones-run probabilities
            runs = get_coalescence_runs(all_times, clade_times)

            # skip mutations with too few coalescences after the first derived one
            if len(runs) < nr_samples / 4:
                continue

            n1 = sum(runs)
            n0 = len(runs) - n1
            run_lengths = np.fromiter(get_all_runs(runs), int)
            runs_of_1s = list(get_runs_of_1s(runs))

            if len(runs_of_1s) == 0:
                # trippleton or smaller
                continue


            # # TODO: Filter on der and/or anc allele freq

            # if n1/(n0+n1) < 10:
            #     continue

            max_ones_run_len = max(runs_of_1s)
            nr_runs = run_lengths.size


            # if sum(runs) < 2 or sum(runs) == len(runs) or max_ones_run_len == 1:
            #     pvalue_max_der_run = np.nan
            # else:
            #     try:
            #         pvalue_max_der_run = prob_longest_der_run_cache[n0, n1, max_ones_run_len:(n0+n1)].sum()
            #     except IndexError:
            #         pvalue_max_der_run = sum(prob_longest_1s_run(n0+n1, x, n1/(n0+n1)) for x in range(max_ones_run_len, n0+n1))


            if nr_runs == 1 or len(runs) <= 2 or nr_runs == len(runs):
                pvalue_nr_runs = np.nan
            else:
                # probability of observing nr_runs or fewer runs
                try:
                    pvalue_nr_runs = prob_nr_of_runs_cache[n0, n1, 1:(nr_runs+1)].sum()
                except IndexError:
                    # n0 or n1 is outside the precomputed table: compute directly
                    pvalue_nr_runs = sum(prob_nr_of_runs(x, n0, n1) for x in range(1, nr_runs+1))

            interval = tree.interval
            num_mutations = tree.num_mutations

            clade_left, clade_right = interval.left, interval.right

            assert clade_times

            records.append(['nr_runs',
                            mut.site,
                            pvalue_nr_runs,
                            nr_runs,
                            len(runs),
                            clade_times[0],
                            0,
                            interval.left,
                            interval.right,
                            clade_left,
                            clade_right,
                            num_mutations])


    df = pd.DataFrame.from_records(records, columns=['stat_name', 'site', 'p', 'stat', 'nr_coal', 't1', 't2',
                                                     'tree_left', 'tree_right', 'clade_left', 'clade_right',
                                                     'nr_mut'])
    # explicit integer index so that the lookup also works when no records were collected
    df['pos'] = tree_seq.sites_position[df.site.to_numpy(dtype=int)]
    df.to_hdf(output_file_name, key='df', format='table')

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/13348 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/13024 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/12803 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/12515 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/13196 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/5050 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/5428 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/6379 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/5924 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/4680 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/4831 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/5070 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/4524 [00:00<?, ?it/s]

1000g_ppl_phased_haplotypes.trees


  0%|          | 0/5123 [00:00<?, ?it/s]