Mushi
==
_All that the rain promises and more..._

A notebook for testing `mushi`'s ability to invert data simulated under the forward model

In [None]:
%matplotlib inline
import os
import time

import msprime
import numpy as np
import stdpopsim
from matplotlib import pyplot as plt
from scipy.special import expit

import histories
import mushi

In [None]:
# Directory where plots are saved as PDFs. Defaults to the current user's
# Downloads folder, resolved from the home directory so the notebook is not
# tied to one machine; point it at any other folder if you prefer.
# The trailing slash is kept in case later cells append file names to it.
plot_dir = os.path.expanduser('~/Downloads/')

In [None]:
# Optional: uncomment the next line to draw subsequent figures with matplotlib's
# built-in 'dark_background' style sheet.
# plt.style.use('dark_background')

### Time grid

In [None]:
# 50 logarithmically spaced change points running from 10**0 up to 100000
change_points = np.logspace(start=0, stop=np.log10(100000), num=50)
# full time grid: t = 0 followed by every change point
t = np.concatenate(([0], change_points))

### Define true demographic history

In [None]:
# Look up the "OutOfAfrica_2T12" human demographic model in the stdpopsim catalog
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfrica_2T12")
ddb = model.get_demography_debugger()

# Coalescence rate trajectory evaluated on the time grid `t`, with
# num_samples=[0, 2] (no samples in population 0, two in population 1).
# Only the first element of the returned result is used.
trajectory = ddb.coalescence_rate_trajectory(
    steps=t,
    num_samples=[0, 2],
    double_step_validation=False,
)
coalescence_rate = trajectory[0]

# The true history is the reciprocal of the coalescence rate on each grid interval
eta_true = histories.eta(change_points, 1 / coalescence_rate)

In [None]:
# Draw the true demographic history in black on a small square figure
plt.figure(figsize=(3.5, 3.5))
eta_true.plot(c='k')
plt.show()

### Mutation rate history $\mu(t)$
A 96-dimensional history with a mixture of two latent signatures: constant and pulse.

In [None]:
cols = 96                     # number of components in the mutation rate history
mu0 = 1                       # baseline rate multiplying every component
flat = np.ones_like(t)        # constant signature: 1 at every grid time
Z = np.zeros((len(t), cols))  # one row per grid time, one column per component

np.random.seed(1)
for col in range(cols):
    # NOTE: the three random draws below must stay in this order (and happen for
    # every column) so the seeded random stream yields the same values.
    pulse_on = np.random.normal(100, 100)
    pulse_off = np.random.normal(1000, 10)
    # pulse signature: a sigmoid rising near `pulse_on` minus a slower sigmoid
    # rising near `pulse_off`
    pulse = expit(.1 * (t - pulse_on)) - expit(.01 * (t - pulse_off))
    scale = np.random.lognormal(-1, 1)
    # only component 0 carries the pulse; every other component stays flat
    if col == 0:
        pulse_weight = 5
    else:
        pulse_weight = 0
    Z[:, col] = mu0 * scale * (flat + pulse_weight * pulse)

mu_true = histories.mu(change_points, Z)

In [None]:
# Plot the true mutation rate history twice on one figure:
#   1) with no component selection, faint blue (presumably every component — confirm)
#   2) component 0 only, highlighted in orange
# NOTE(review): `clr=True` is passed through to the project's plot method; it
# presumably requests a centred log-ratio transform — TODO confirm.
plt.figure(figsize=(4, 4))
mu_true.plot(alpha=0.1, lw=2, c='C0', clr=True)
mu_true.plot((0,), alpha=0.75, lw=3, c='C1', clr=True)
plt.show()

Estimate the total mutation rate by summing the components of $\boldsymbol\mu(t)$ at $t=0$

In [None]:
# total mutation rate: sum across all columns of the first row of mu_true.Z
mu0 = mu_true.Z[0].sum()

## Simulate a $k$-SFS
- We'll sample 200 haplotypes
- Note that this simulation will have a slightly varying total mutation rate due to the pulse

In [None]:
# number of sampled haplotypes
n = 200

# simulate a k-SFS under the true histories with a fixed seed
ksfs = mushi.kSFS(n=n)
ksfs.simulate(eta_true, mu_true, seed=1)

plt.figure(figsize=(4, 3))
# first call: no selection argument, small faint blue dots
ksfs.plot(c='C0', marker='.', ls='', alpha=0.1, clr=True)
# second call: index 0 only, larger orange markers
ksfs.plot(0, c='C1', marker='o', ls='', alpha=0.75, clr=True)
plt.show()

In [None]:
# Display the sum of all entries of ksfs.X
# (presumably the total number of simulated mutations — confirm)
ksfs.X.sum()

### TMRCA CDF

In [None]:
# TMRCA CDF under the true demography, drawn against the change points on a log time axis
fig, ax = plt.subplots(figsize=(3.5, 3.5))
ax.plot(change_points, ksfs.tmrca_cdf(eta_true))
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim([0, 1])
ax.set_xscale('log')
fig.tight_layout()
plt.show()

### Infer $\eta(t)$ and $\boldsymbol\mu(t)$

Run inference

In [None]:
# reset both histories on the kSFS object before fitting
ksfs.clear_eta()
ksfs.clear_mu()

# stopping criteria and penalty weights passed to the eta fit
convergence = {'tol': 1e-16, 'max_iter': 10000}
regularization_eta = {'alpha_tv': 1e-3, 'alpha_spline': 1e1, 'alpha_ridge': 1e-10}

# fit with mu inference switched off
metadata = ksfs.infer_history(
    change_points,
    mu0,
    infer_mu=False,
    **regularization_eta,
    **convergence,
)

plt.figure(figsize=(8, 4))

# left panel: output of ksfs.plot_total (black markers, blue line and fill)
plt.subplot(1, 2, 1)
ksfs.plot_total(
    kwargs={'ls': '', 'marker': 'o', 'ms': 5, 'c': 'k', 'alpha': 0.75},
    line_kwargs={'c': 'C0', 'alpha': 0.75, 'lw': 3},
    fill_kwargs={'color': 'C0', 'alpha': 0.1},
)

# right panel: true versus inferred eta
plt.subplot(1, 2, 2)
eta_true.plot(c='k', lw=2, label='true')
ksfs.eta.plot(lw=3, alpha=0.75, label='inferred')
plt.legend()
plt.show()

In [None]:
# reset only mu; clear_eta is deliberately not called here
ksfs.clear_mu()

# penalty weights and stopping criteria passed to the mu fit
regularization_mu = {'beta_tv': 1e1, 'beta_spline': 1e1, 'beta_ridge': 1e-10}
convergence = {'tol': 1e-16, 'max_iter': 1000}

# fit with eta inference switched off
metadata = ksfs.infer_history(
    change_points,
    mu0,
    infer_eta=False,
    **regularization_mu,
    **convergence,
)

plt.figure(figsize=(5, 3))

# left panel: k-SFS, first with no selection (faint blue), then index 0 (orange)
plt.subplot(1, 2, 1)
ksfs.plot(c='C0', lw=1, alpha=0.05, clr=True)
ksfs.plot(0, c='C1', lw=2, alpha=0.75, clr=True)

# right panel: true histories (solid) and inferred history for component 0 (dashed)
plt.subplot(1, 2, 2)
mu_true.plot(range(1, cols), c='C0', lw=2, alpha=0.1, clr=True)
mu_true.plot((0,), c='C1', lw=2, alpha=0.75, clr=True)
# inferred curves for the remaining components are disabled; uncomment to overlay them
# ksfs.mu.plot(range(1, cols), alpha=0.1, lw=3, ls='--', c='C0', clr=True)
ksfs.mu.plot((0,), c='C1', lw=3, ls='--', alpha=0.75, clr=True)
plt.show()