In [2]:
import msprime
import tskit
import numpy as np
import random
from math import *

In [3]:
def remove_mutations(ts, starts, ends, prop):
    '''
    This function will return a new tree sequence the same as the input,
    but after removing each non-SLiM mutation within regions specified in 
    lists of start and end positions with probability `proportion`, independently. 
    So then, if we want to add neutral mutations with rate 1.0e-8 within the regions 
    and 0.7e-8 outside the regions, we could do
      ts = pyslim.load("my.trees")
      first_mut_ts = msprime.mutate(ts, rate=1e-8)
      mut_ts = remove_mutations(first_mut_ts, start, end, 0.3)
    :param float proportion: The proportion of mutations to remove.
    '''
    pos = ts.tables.sites.position #getting the positions of all sites
    is_msp = (np.diff(ts.tables.mutations.metadata_offset) == 0) #getting which mutations are from msprime
    #but we want to know which sites are from msprime
    is_msp_site = np.repeat(False, ts.num_sites)
    is_msp_site[ts.tables.mutations.site] = is_msp
    #finding which sites are inside the regions
    breaks=np.concatenate(([-1], starts, ends))
    breaks.sort()
    #np.search sorted is going to return even numbers if the pos is inside one of the regions
    in_regions = np.searchsorted(breaks,pos,"right")%2 == 0
    removable_sites = np.where(np.logical_and(in_regions, is_msp_site))[0]
    #find sites to remove with probability prop
    remove = np.where(np.random.binomial(1,prop,len(removable_sites))==1)[0]
    new_table = ts.tables
    new_table.delete_sites(remove)
    return(new_table.tree_sequence())

In [4]:
def remove_mutations_old(ts, start, end, proportion):
    '''
    This function will return a new tree sequence the same as the input,
    but after removing each non-SLiM mutation within regions specified in lists
    start and end with probability `proportion`, independently. So then, if we
    want to add neutral mutations with rate 1.0e-8 within the regions and 0.7e-8
    outside the regions, we could do
      ts = pyslim.load("my.trees")
      first_mut_ts = msprime.mutate(ts, rate=1e-8)
      mut_ts = remove_mutations(first_mut_ts, start, end, 0.3)
    :param float proportion: The proportion of mutations to remove.
    '''
    new_tables = ts.dump_tables()
    new_tables.mutations.clear()
    mutation_map = [-1 for _ in range(ts.num_mutations)]
    for j, mut in enumerate(ts.mutations()):
        keep_mutation = True
        for i in range(len(start)):
            left = start[i]
            right = end[i]
            assert(left < right)
            if i > 0:
                assert(end[i - 1] <= left)
            if mut.position >= left and mut.position < right and len(mut.metadata) == 0:
                keep_mutation = (random.uniform(0, 1) > proportion)
        if keep_mutation:
            mutation_map[j] = new_tables.mutations.num_rows
            if mut.parent < 0:
                new_parent = -1
            else:
                new_parent = mutation_map[mut.parent]
            new_tables.mutations.add_row(site = mut.site, node = mut.node,
                    derived_state = mut.derived_state,
                    parent = new_parent,
                    metadata = mut.metadata)
    return new_tables.tree_sequence()

In [39]:
import time
starts=[x for x in range(0,100000,20)]
ends=[x+10 for x in range(0,100000, 20)]
prop = 0.9
new = []
old = []
for i in range(1):
    ts = msprime.simulate(10, Ne=10000, length=50000000, recombination_rate=1e-8)
    print("simulated")
    ts_mut = msprime.mutate(ts, rate=1.4e-8)
    print("mutated")
    s1 = time.time()
    remove_mutations(ts_mut, starts, ends, prop)
    e1 = time.time()
    new.append(e1-s1)
    print(e1,s1,e1-s1)
    s2 = time.time()
    remove_mutations_old(ts_mut, starts, ends, prop)
    e2 = time.time()
    old.append(e2-s2)
    print(e2,s1,e2-s2)

simulated
mutated
1569906640.939546 1569906640.840832 0.09871411323547363
1569906787.605782 1569906640.840832 146.6660192012787


In [40]:
print("new func", "old func", sep="\t")
print(np.mean(new), np.mean(old), sep="\t")

new func	old func
0.09871411323547363	146.6660192012787


In [41]:
starts=[0]
ends=[10000]
prop = 0.5
ts = msprime.simulate(10, Ne=1000, mutation_rate=1e-5, length=10000)
new_ts = remove_mutations(ts, starts, ends,prop)
print(ts.diversity(),new_ts.diversity())

0.032039999999999576 0.014955555555555543


In [42]:
starts=[0]
ends=[5000]
prop = 1
ts = msprime.simulate(10, Ne=1000, mutation_rate=1e-5, length=10000)
n_before = ts.num_mutations
breaks = np.concatenate(([-1], starts, ends))
breaks.sort()
n_within_before = sum(np.searchsorted(breaks,ts.tables.sites.position,"right")%2 == 0)
new_ts = remove_mutations(ts, starts, ends,prop)
pos_after = []
for j, mut in enumerate(new_ts.mutations()):
    pos_after.append(mut.position)
n_within_after = sum(np.searchsorted(breaks,pos_after,"right")%2 == 0)
n_after = new_ts.num_mutations
#assert n_after == n_before - n_within_before
[n_before, n_within_before, n_after, n_within_after]


[731, 344, 387, 0]

In [43]:
starts=[0]
ends=[5000]
prop = 0
ts = msprime.simulate(10, Ne=1000, mutation_rate=1e-5, length=10000)
n_before = ts.num_mutations
breaks = np.concatenate(([-1], starts, ends))
breaks.sort()
n_within_before = sum(np.searchsorted(breaks,ts.tables.sites.position,"right")%2 == 0)
new_ts = remove_mutations(ts, starts, ends,prop)
pos_after = []
for j, mut in enumerate(new_ts.mutations()):
    pos_after.append(mut.position)
n_within_after = sum(np.searchsorted(breaks,pos_after,"right")%2 == 0)
n_after = new_ts.num_mutations
assert n_before == n_after
[n_before, n_within_before, n_after, n_within_after]

[2898, 1448, 2898, 1448]

In [44]:
starts=[0]
ends=[5000]
prop = 0.5
ts = msprime.simulate(10, Ne=1000, mutation_rate=1e-5, length=10000)
n_before = ts.num_mutations
breaks = np.concatenate(([-1], starts, ends))
breaks.sort()
n_within_before = sum(np.searchsorted(breaks,ts.tables.sites.position,"right")%2 == 0)
new_ts = remove_mutations(ts, starts, ends,prop)
pos_after = []
for j, mut in enumerate(new_ts.mutations()):
    pos_after.append(mut.position)
n_within_after = sum(np.searchsorted(breaks,pos_after,"right")%2 == 0)
n_after = new_ts.num_mutations
var = (prop*(1-prop)*n_before)
assert abs(n_within_after - (n_within_before * prop)) < sqrt(var)*3 #3 std devs
[n_before, n_within_before, n_after, n_within_after]

[1285, 657, 971, 343]