Testing `mushi` with `msprime` simulations
==

In [None]:
%matplotlib notebook
import histories
import mushi
import msprime
%cd stdpopsim
from stdpopsim import homo_sapiens
%cd ../
from IPython.display import display, SVG
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import expit
import time

## `msprime` simulation

### Download human recombination maps

In [None]:
# Show which genetic map stdpopsim treats as the default for the human genome,
# then download the HapMap II (GRCh37) genetic map.
# NOTE(review): presumably chrom.recombination_map() used in the simulation cell
# relies on this download being present locally — confirm against stdpopsim docs.
print(homo_sapiens.genome.default_genetic_map)
gmap = homo_sapiens.HapmapII_GRCh37()
gmap.download()

### Define true demographic history $\eta(t)$ based on the Tennessen et al. `TwoPopOutOfAfrica` model in `stdpopsim`

In [None]:
# Chromosome 22 of the stdpopsim human genome; its recombination map, length and
# default mutation rate are used by the simulation and mutation cells below.
chrom = homo_sapiens.genome.chromosomes["chr22"]
# Demographic model used both for simulation and for the true η(t).
model = homo_sapiens.TennessenTwoPopOutOfAfrica()
# Alternative model (swap in to test a single-population history):
# model = homo_sapiens.TennessenOnePopAfrica()

Simulate a tree sequence for 200 haplotypes

In [None]:
# Simulate n haplotypes, all sampled at time 0 from a single population of the model.
n = 200
population_idx = 1  # 0 for AFR, 1 for EUR
samples = [msprime.Sample(population=population_idx, time=0) for _ in range(n)]
# msprime requires an integer seed in [1, 2**32 - 1]; time.time() is a float,
# so truncate it. Print the seed so a particular run can be reproduced.
seed = int(time.time())
print(f"random seed: {seed}")
tree_sequence = msprime.simulate(random_seed=seed,
                                 samples=samples,
                                 recombination_map=chrom.recombination_map(),
                                 **model.asdict())
# breakpoints() includes both sequence endpoints (0 and L), so its length is one
# more than the number of trees; num_trees reports the count directly.
print(f"{tree_sequence.num_trees} trees")

### First tree in the sequence

In [None]:
# Render the leftmost tree of the sequence as an SVG; the empty label maps
# mean no node or mutation labels are drawn.
first_tree_svg = tree_sequence.first().draw(format='svg', width=500, height=200,
                                            node_labels={}, mutation_labels={})
display(SVG(first_tree_svg))

An `msprime.DemographyDebugger` built from the model lets us query its demographic history, including the population-size trajectories and the epoch boundary times.

In [None]:
# Build a DemographyDebugger from the model's default size, population
# configurations, demographic events and migration matrix, so that the
# population-size trajectory and epoch times can be queried in later cells.
dd = msprime.DemographyDebugger(Ne=model.default_population_size,
                                population_configurations=model.population_configurations,
                                demographic_events=model.demographic_events,
                                migration_matrix=model.migration_matrix)

Define the time grid of change points $t$: 200 points spaced evenly on a log scale from $1$ to $10^5$.

In [None]:
# 200 change points spaced evenly in log10 between 10**0 = 1 and 10**5.
t = np.logspace(start=0, stop=5, num=200)

Extract effective population size history $\eta(t)$ for `mushi`

In [None]:
# note: the factor of 2 accounts for diploidy
y = 2 * dd.population_size_trajectory(np.concatenate(([0], t)))[:, population_idx]
η = histories.η(t, y)

plt.figure(figsize=(3, 3))
η.plot()
for tt in dd.epoch_times[1:]: plt.axvline(tt, c='k', ls=':')
plt.xlabel('$t$'); plt.ylabel('$η(t)$')
plt.xscale('symlog'); plt.yscale('log')
plt.tight_layout()
plt.show()

Define mutation intensity history $\mu(t)$

In [None]:
# Relative mutation intensity: a baseline of 1, plus a sigmoid rise of +1
# centred at t = 80 and a sigmoid drop of -1.5 centred at t = 600. z has one
# more entry than t; its leading entry is left at the baseline value.
rate_modulation = expit(.1 * (t - 80)) - 1.5 * expit(.1 * (t - 600))
z = np.concatenate(([1.0], 1.0 + rate_modulation))
# Scale the per-site default rate up to the whole chromosome.
z *= chrom.length * chrom.default_mutation_rate
μ = histories.μ(t, z)

plt.figure(figsize=(3, 3))
μ.plot()
ax = plt.gca()
ax.set_xlabel('$t$')
ax.set_ylabel('$μ(t)$')
ax.set_xscale('symlog')
plt.tight_layout()
plt.show()

### TMRCA CDF

In [None]:
# TMRCA CDF implied by η(t) for a sample of n haplotypes, plotted against η's change points.
tmrca_cdf = mushi.kSFS(η, n=n).tmrca_cdf()

plt.figure(figsize=(3, 3))
plt.plot(η.change_points, tmrca_cdf)
ax = plt.gca()
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim([0, 1])
ax.set_xscale('symlog')
plt.tight_layout()
plt.show()

### Place mutations on the simulated tree sequence according to $\mu(t)$
We iterate over the epochs of our dense time grid, adding mutations within each epoch at that epoch's rate.

In [None]:
# Add mutations epoch by epoch: each msprime.mutate call places mutations only on
# branch segments within [start_time, end_time), and keep=True retains the
# mutations added by earlier iterations.
for epoch_idx, (start_time, end_time, mutation_rate) in enumerate(μ.epochs()):
    print(f'epoch boundaries: ({start_time:.2f}, {end_time:.2f}), μ: {mutation_rate[0]:.2f}', flush=True, end='     \r')
    # note: the factor of 1 / chrom.length converts the mutation rate from
    # per-genome (as defined in μ) to per-site, which msprime expects
    tree_sequence = msprime.mutate(tree_sequence,
                                   rate=mutation_rate / chrom.length,
                                   start_time=start_time,
                                   end_time=end_time,
                                   # a distinct seed per epoch: reusing one seed would replay
                                   # the same random stream in every epoch, correlating the
                                   # mutations placed in different epochs
                                   random_seed=int(seed) + epoch_idx,
                                   keep=True)

### Compute and plot the SFS

In [None]:
# Unnormalised, polarised SFS of the simulated sample. [1:-1] drops the
# monomorphic classes (derived count 0 and n); wrapping in a list and
# transposing makes X a single column (presumably one mutation type in mushi's
# k-SFS layout — confirm).
X = np.array([tree_sequence.allele_frequency_spectrum(polarised=True, span_normalise=False)[1:-1]]).T
ksfs = mushi.kSFS(η, X=X)

plt.figure(figsize=(3, 3))
ksfs.plot(0)  # plot column 0 of the kSFS built above
plt.tight_layout()
plt.show()

## Inferring $\mu(t)$ from the SFS with `mushi`

Estimate $\mu(t)$ by proximal gradient descent.

In [None]:
# Settings for ksfs.infer_μ, grouped by role; the returned loss trajectory is
# plotted in the next cell.
infer_μ_settings = dict(
    # loss function parameters
    fit='prf',
    exclude_singletons=False,
    bins=None,
    # time derivative regularization parameters
    λ_tv=1e2,
    α_tv=0,
    # spectral regularization parameters
    λ_r=0,
    α_r=1,
    hard=True,
    # convergence parameters
    max_iter=10000,
    tol=1e-10,
    γ=0.8,
)
μ_inferred, f_trajectory = ksfs.infer_μ(**infer_μ_settings)

In [None]:
# Loss value at each optimizer iteration, on a symlog iteration axis.
plt.figure(figsize=(4, 2))
plt.plot(f_trajectory)
loss_ax = plt.gca()
loss_ax.set_xlabel('iterations')
loss_ax.set_ylabel('loss')
loss_ax.set_xscale('symlog')
plt.tight_layout()
plt.show()

In [None]:
# Plot the observed kSFS using the inferred history μ_inferred.
# NOTE(review): presumably this overlays the expected SFS under μ_inferred and
# prf_quantiles=True adds Poisson-random-field quantile bands — confirm in mushi docs.
plt.figure(figsize=(6, 3))
ksfs.plot(μ=μ_inferred, prf_quantiles=True)
plt.show()