Mushi
==
_All that the rain promises and more..._

A notebook for testing `mushi`'s ability to invert data simulated under the forward model

In [None]:
%matplotlib inline 
import mushi
import histories
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import expit
import time
import msprime
import stdpopsim

In [None]:
# Optional styling, disabled by default: uncomment the next line to apply
# matplotlib's 'dark_background' style sheet to the figures in this notebook.
# plt.style.use('dark_background')

### Time grid

In [None]:
# Change points of the piecewise-constant histories: 50 values spaced evenly
# on a log10 scale, running from 1 up to 35000.
n_change_points = 50
oldest_change_point = 35000
change_points = np.logspace(0, np.log10(oldest_change_point), num=n_change_points)

# Full time grid: the change points with the present (t = 0) prepended.
t = np.concatenate(([0], change_points))

### Define true demographic history

In [None]:
# Fetch the "OutOfAfricaArchaicAdmixture_5R19" demographic model for the
# "HomSap" species from the stdpopsim catalog and build its demography debugger.
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfricaArchaicAdmixture_5R19")
ddb = model.get_demography_debugger()

# Coalescence rate trajectory on the time grid `t` for two samples taken from
# the population at index 1 of the model (none from the other four).
# Only the first element of the returned value is used
# (the rates, per the msprime documentation).
coalescence_rates = ddb.coalescence_rate_trajectory(
    steps=t,
    num_samples=[0, 2, 0, 0, 0],
    double_step_validation=False,
)[0]

# The true history is the reciprocal of the coalescence rate.
eta_true = histories.eta(change_points, 1 / coalescence_rates)

In [None]:
# Draw the true demographic history in black on a small square figure.
figure_options = dict(figsize=(3.5, 3.5))
line_options = dict(c='k')

plt.figure(**figure_options)
eta_true.plot(**line_options)
plt.show()

### Mutation rate history $\mu(t)$
A 96-dimensional history built as a mixture of two latent signatures: constant and pulse.

In [None]:
# Two latent signatures evaluated on the time grid `t`:
#   flat  -- constant in time
#   pulse -- difference of two steep logistic steps centred at t = 100 and
#            t = 1000, i.e. close to 1 between them and close to 0 outside
flat = np.ones_like(t)
pulse = expit(10 * (t - 100)) - expit(10 * (t - 1000))

cols = 96  # number of mutation types, one column of Z each
Z = np.zeros((len(t), cols))

# Per-type baseline rate applied before the random scaling.
# Deliberately not named `mu0`: that name holds the total mutation rate that
# is passed to `ksfs.infer_history`, and re-running this cell on its own must
# not silently reset it to 1.
baseline_rate = 1

np.random.seed(2)
pulse_idxs = []
flat_idxs = []
for col in range(cols):
    # NOTE: with the fixed seed, the order of the random draws below
    # determines the simulated history -- do not reorder them.
    scale = np.random.lognormal(0, 0.2)
    # Only the first mutation type carries the pulse signature.
    pulse_weight = np.random.lognormal(-.5, .05) if col == 0 else 0
    Z[:, col] = baseline_rate * (scale * (flat + pulse_weight * pulse))
    if pulse_weight:
        pulse_idxs.append(col)
    else:
        flat_idxs.append(col)

mu_true = histories.mu(change_points, Z)

In [None]:
# Plot the true mutation rate histories: the flat mutation types as faint,
# thin black lines, and the pulse mutation types as thicker, semi-transparent
# lines on top.
flat_style = dict(alpha=0.1, lw=1, c='k')
pulse_style = dict(alpha=0.5, lw=3)

plt.figure(figsize=(4, 4))
mu_true.plot(flat_idxs, **flat_style)
mu_true.plot(pulse_idxs, **pulse_style)
plt.show()

Estimate the total mutation rate as the sum of the per-type rates in the first epoch ($t=0$)

In [None]:
# Total mutation rate: the first row of Z (the row for t = 0) summed over
# all mutation types.
rates_at_present = mu_true.Z[0, :]
mu0 = rates_at_present.sum()

## Simulate a $k$-SFS
- We'll sample $n = 1000$ haplotypes
- Note that this simulation will have a slightly varying total mutation rate, due to the pulse

In [None]:
# Number of sampled haplotypes.
n = 1000

# Simulate a k-SFS from the true demographic and mutation rate histories,
# with a fixed seed.
ksfs = mushi.kSFS(n=n)
ksfs.simulate(eta_true, mu_true, seed=1)

# Plot the simulated k-SFS of each pulse mutation type as unconnected dots.
pulse_marker_style = dict(lw=3, clr=True, ls='', marker='.')

plt.figure(figsize=(4, 3))
# Optional overlay of the flat mutation types (left disabled):
# for idx in flat_idxs:
#     ksfs.plot(idx, alpha=0.05, lw=1, c='k', clr=True)
for idx in pulse_idxs:
    ksfs.plot(idx, **pulse_marker_style)
plt.show()

In [None]:
# Display the sum of all entries of `ksfs.X`.
# NOTE(review): presumably X is the simulated k-SFS count matrix, making this
# the total number of simulated variants -- confirm against mushi.kSFS.
ksfs.X.sum()

### TMRCA CDF

In [None]:
# Plot the TMRCA CDF obtained from the true demographic history against the
# change points, on a logarithmic time axis with the y-range fixed to [0, 1].
plt.figure(figsize=(3.5, 3.5))
tmrca_cdf_true = ksfs.tmrca_cdf(eta_true)
plt.plot(change_points, tmrca_cdf_true)

# Axis scaling first, then labels, then layout.
plt.xscale('log')
plt.ylim([0, 1])
plt.xlabel('$t$')
plt.ylabel('TMRCA CDF')
plt.tight_layout()
plt.show()

### Infer $\eta(t)$ and $\boldsymbol\mu(t)$

Run inference

In [None]:
# Reset the eta and mu histories held by the kSFS object before inferring.
ksfs.clear_eta()
ksfs.clear_mu()

# Regularization strengths for eta, and the optimizer's stopping criteria.
regularization_eta = dict(alpha_tv=1e1, alpha_spline=1e2, alpha_ridge=1e-10)
convergence = dict(tol=1e-10, max_iter=10000)

# First pass: infer eta only (mu inference is switched off).
metadata = ksfs.infer_history(
    change_points,
    mu0,
    infer_mu=False,
    **regularization_eta,
    **convergence,
)

# Three-panel diagnostic figure.
plt.figure(figsize=(12, 4))

# Left panel: output of ksfs.plot_total().
plt.subplot(1, 3, 1)
ksfs.plot_total()

# Middle panel: true history versus the inferred one.
plt.subplot(1, 3, 2)
eta_true.plot(c='k', lw=2, label='true')
ksfs.eta.plot(lw=3, alpha=0.75, label='inferred')
plt.legend()

# Right panel: the 'y_convergence' trace returned by the inference, on a
# logarithmic y-axis.
plt.subplot(1, 3, 3)
plt.plot(metadata['y_convergence'])
plt.yscale('log')
plt.show()

In [None]:
# Reset only the mu history; the eta held by the kSFS object is left in place.
ksfs.clear_mu()

# Stopping criteria for the optimizer, and regularization strengths for mu.
convergence = dict(tol=1e-10, max_iter=10000)
regularization_mu = dict(beta_tv=5e1, beta_spline=1e-3, beta_ridge=1e-10)

# Second pass: infer mu only (eta inference is switched off).
metadata = ksfs.infer_history(
    change_points,
    mu0,
    infer_eta=False,
    **regularization_mu,
    **convergence,
)

# Three-panel diagnostic figure.
plt.figure(figsize=(12, 4))

# Left panel: all mutation types drawn faintly, with the pulse types
# redrawn on top as thicker lines in a second color.
plt.subplot(1, 3, 1)
ksfs.plot(alpha=0.1, lw=1, c='C0', clr=True)
pulse_highlight = dict(lw=3, c='C1', clr=True)
for idx in pulse_idxs:
    ksfs.plot(idx, **pulse_highlight)

# Middle panel: true versus inferred mutation rate history for the pulse types.
plt.subplot(1, 3, 2)
mu_true.plot(pulse_idxs, clr=False, c='k', lw=3, label='true')
ksfs.mu.plot(pulse_idxs, clr=False, lw=2, label='inferred')
plt.legend()

# Right panel: the 'Z_convergence' trace returned by the inference, on a
# logarithmic y-axis.
plt.subplot(1, 3, 3)
plt.plot(metadata['Z_convergence'])
plt.yscale('log')
plt.show()