Mushi
==
_All that the rain promises and more..._

A notebook for testing `mushi`'s ability to invert data simulated under the forward model

In [None]:
%matplotlib inline 
import mushi
import histories
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import expit
import time
import msprime
import stdpopsim

In [None]:
# plt.style.use('dark_background')

### Time grid

In [None]:
# 50 log-spaced change points from 1 to 200000; the full time grid `t`
# is the same sequence with 0 prepended
change_points = np.logspace(0, np.log10(200000), num=50)
t = np.concatenate(([0], change_points))

### Demographic history $\eta(t)$ from the European population in Tennessen et al.

In [None]:
# Human out-of-Africa model (Tennessen et al.) from the stdpopsim catalog
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfrica_2T12")
ddb = model.get_demography_debugger()

# Coalescence rate on the time grid for two lineages sampled from the second
# population only (none from the first); element 0 of the returned value is
# used as the rate trajectory
coalescence_rate = ddb.coalescence_rate_trajectory(
    steps=t,
    num_samples=[0, 2],
    double_step_validation=False,
)[0]

# η(t) is taken to be the reciprocal of the coalescence rate
η_Tennessen = histories.η(change_points, 1 / coalescence_rate)

In [None]:
# Plot the true demographic history on its own small figure
plt.figure(figsize=(3, 3))
η_Tennessen.plot(label='Tennessen')
plt.show()

### Mutation rate history $\mu(t)$
A 96-dimensional history built from a mixture of two latent signatures: constant and pulse.

In [None]:
# Two latent signatures evaluated on the time grid: a constant one and a
# smooth pulse that switches on around t=100 and off around t=1000
flat = np.ones_like(t)
pulse = expit(t - 100) - expit(t - 1000)

cols = 96
μ0 = 1
Z = np.zeros((len(t), cols))
pulse_idxs = []
flat_idxs = []

np.random.seed(2)
for col in range(cols):
    # Draw order matters for reproducibility: the per-column scale first,
    # then (for column 0 only) the pulse weight
    scale = np.random.lognormal(0, 0.2)
    if col == 0:
        pulse_weight = np.random.lognormal(-1, .05)
    else:
        pulse_weight = 0
    Z[:, col] = μ0 * (scale * (flat + pulse_weight * pulse))
    # record which columns carry the pulse and which are purely flat
    (pulse_idxs if pulse_weight else flat_idxs).append(col)

μ = histories.μ(change_points, Z)

In [None]:
# Plot the true mutation rate histories
plt.figure(figsize=(3, 3))
# columns without a pulse: thin, nearly transparent black lines
μ.plot(flat_idxs, alpha=0.1, lw=1, c='k')
# column(s) carrying the pulse: heavier, more opaque line
μ.plot(pulse_idxs, alpha=0.5, lw=3)
plt.show()

Estimate the total mutation rate using $t=0$

In [None]:
# Sum row 0 of μ.Z over all columns (presumably the t=0 row of the mutation
# rate matrix — confirm against histories.μ). This overwrites the scalar μ0
# that was set when the history was constructed.
μ0 = μ.Z[0, :].sum()

## Simulate a $k$-SFS
- We'll sample 200 haplotypes
- note that this simulation will have a slightly varying total mutation rate, due to the pulse

In [None]:
# Number of sampled haplotypes
n = 200

# Build the k-SFS object from the true histories and simulate it with a
# fixed seed
ksfs = mushi.kSFS(η=η_Tennessen, μ=μ, n=n)
ksfs.simulate(seed=1)

# Flat components are drawn first as faint black lines, then the pulse
# component(s) with a heavier line
plt.figure(figsize=(4, 3))
plot_groups = (
    (flat_idxs, dict(alpha=0.05, lw=1, c='k')),
    (pulse_idxs, dict(lw=3)),
)
for idxs, line_style in plot_groups:
    for idx in idxs:
        ksfs.plot(idx, **line_style, clr=True)
plt.show()

### TMRCA CDF

In [None]:
# TMRCA CDF plotted against the change points, on a symlog time axis
plt.figure(figsize=(3, 3))
ax = plt.gca()
ax.plot(change_points, ksfs.tmrca_cdf())
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim([0, 1])
ax.set_xscale('symlog')
plt.tight_layout()
plt.show()

### Infer $\eta(t)$ and $\boldsymbol\mu(t)$
Define regularization parameters and convergence criteria.

In [None]:
# Penalty strengths for the demographic history η(t)
η_regularization = dict(α_tv=1e1, α_spline=2e2, α_ridge=1e-10)
# Penalty settings for the mutation rate history μ(t).
# The ridge key is `β_ridge`, mirroring `α_ridge` above; it was previously
# misspelled `β_β_ridge`, which would be forwarded to `infer_history` as an
# unknown keyword.
# NOTE(review): confirm the name against the `infer_history` signature.
μ_regularization = dict(hard=True, β_rank=0, β_tv=1e1, β_ridge=1e-10)
# Optimizer stopping criteria: tolerance and iteration cap
convergence = dict(tol=1e-10, max_iter=10000)

Run inference

In [None]:
# Run inference on the simulated k-SFS. Keyword arguments are unpacked from
# the η/μ regularization and convergence dictionaries; μ0 is the total
# mutation rate estimate computed earlier.
# NOTE(review): 'prf' presumably selects a Poisson random field loss — confirm.
ksfs.infer_history(change_points, μ0, loss='prf',
                   **η_regularization, **μ_regularization, **convergence)

Plot the results

In [None]:
# Line styles shared by the two "true vs. inferred" panels
true_style = dict(c='k', lw=3, label='true')
inferred_style = dict(lw=2, label='inferred')

plt.figure(figsize=(7, 7))

# Panel 1: total SFS
plt.subplot(2, 2, 1)
ksfs.plot_total()

# Panel 2: all k-SFS components, with the pulse component(s) highlighted
plt.subplot(2, 2, 2)
ksfs.plot(alpha=0.1, lw=1, c='C0', clr=True)
for idx in pulse_idxs:
    ksfs.plot(idx, lw=3, c='C1', clr=True)

# Panel 3: demographic history, true vs. inferred
plt.subplot(2, 2, 3)
η_Tennessen.plot(**true_style)
ksfs.η.plot(**inferred_style)
plt.legend()

# Panel 4: mutation rate history of the pulse component(s), true vs. inferred
plt.subplot(2, 2, 4)
μ.plot(pulse_idxs, clr=True, **true_style)
ksfs.μ.plot(pulse_idxs, clr=True, **inferred_style)
plt.legend()

plt.show()

Now let's try a parameter sweep of `α_tv` and `α_spline`, evaluating the L2 error of the demography at each point

In [None]:
def l2_error(η1: 'histories.η', η2: 'histories.η') -> float:
    '''L2 distance between two histories defined on the same time grid

    η1, η2: histories whose ``arrays()`` method returns a ``(t, y)`` pair
            of time points and values

    Returns the square root of the Δt-weighted sum of squared differences
    between the two value arrays, excluding the last interval.

    Raises ValueError if the two histories do not share the same time grid.
    '''
    t1, y1 = η1.arrays()
    t2, y2 = η2.arrays()
    # array_equal also copes with grids of different length, where an
    # elementwise `==` does not yield a well-defined boolean array
    if not np.array_equal(t1, t2):
        raise ValueError('histories must be defined on the same time grid')
    Δt = np.diff(t1)
    # The final interval is left out of the sum.
    # NOTE(review): presumably because it is unbounded — confirm against
    # histories.η.arrays()
    return np.sqrt((Δt * (y1 - y2) ** 2)[:-1].sum())

In [None]:
# Optimizer settings used for every point of the sweep
convergence = {'tol': 1e-10, 'max_iter': 1000, 'max_line_iter': 300, 'γ': 0.8}

# Log-spaced grids of the two penalty strengths being swept
α_tv_array = np.logspace(-3, 3, 10)
α_spline_array = np.logspace(-3, 3, 10)
X, Y = np.meshgrid(α_tv_array, α_spline_array)

# Z[i, j] holds the error for α_spline_array[i] and α_tv_array[j], matching
# the layout of the meshgrid arrays X and Y
Z = np.zeros(X.shape)
for j, α_tv in enumerate(α_tv_array):
    for i, α_spline in enumerate(α_spline_array):
        η_regularization = {'α_tv': α_tv, 'α_spline': α_spline, 'α_ridge': 1e-10}
        # reset the previous estimate so every fit starts from the same
        # initialization
        ksfs.η = None
        ksfs.infer_history(change_points, μ0, infer_μ=False,
                           **η_regularization, **convergence)
        Z[i, j] = l2_error(η_Tennessen, ksfs.η)

In [None]:
# Heat map of the demography error over the (α_tv, α_spline) grid, with both
# axes on a log scale
plt.figure()
plt.pcolor(X, Y, Z)
plt.xscale('log')
plt.xlabel(r'$\alpha_{\mathrm{TV}}$')
plt.yscale('log')
plt.ylabel(r'$\alpha_{\mathrm{spline}}$')
# Z is filled with l2_error values, so label it as an L2 error, not an RMS
plt.colorbar(label='L2 error')
plt.show()
    
# plt.figure(figsize=(3, 3))
# η_Tennessen.plot(c='k', lw=3, label='true')
# η.plot(lw=2, label='inferred')
# plt.legend()
# plt.show()

To do: study non-convexity issues. Look for suboptimal local minima by randomizing the initialization.