Testing with simulations
==

A notebook for testing `mushi`'s ability to invert data simulated under the forward model

In [None]:
import mushi

import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.special import expit
import stdpopsim

### Time grid

In [None]:
# Epoch boundaries: 200 log-spaced change points running from 10^0 to 10^5
log10_t_max = np.log10(100000)
change_points = np.logspace(0, log10_t_max, num=200)

# Full time grid: the change points with the present (t = 0) prepended
t = np.concatenate(([0], change_points))

### Define true demographic history

In [None]:
# Look up the human "Zigzag_1S14" demographic model in the stdpopsim catalog
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Zigzag_1S14")
ddb = model.get_demography_debugger()

# Coalescence rate trajectory evaluated on the time grid t for a sample of
# two lineages; only element [0] of the returned value is used.
coalescence_rate = ddb.coalescence_rate_trajectory(
    steps=t,
    num_samples=[2],
    double_step_validation=False,
)[0]

# The true history is parameterized by the reciprocal of the coalescence
# rate on each epoch delimited by change_points.
eta_true = mushi.eta(change_points, 1 / coalescence_rate)

plt.figure(figsize=(3.5, 3.5))
eta_true.plot(c='k');

### Mutation rate history $\mu(t)$
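A 96-dimensional history built from a mixture of latent signatures: constant, pulse, and ramp. Two mutation types (`TCC>TTC` and `GAA>GGA`) vary over time; the third column, labeled `None`, carries a constant rate of $94\mu_0$ that stands in for the remaining 94 mutation types.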

In [None]:
# Latent time profiles, each evaluated on the full time grid t
flat = np.ones_like(t)                                    # constant
pulse = expit(.1 * (t - 100)) - expit(.01 * (t - 2000))   # difference of two sigmoids
ramp = expit(-.01 * (t - 100))                            # decreasing sigmoid in t

mu0 = 1
np.random.seed(0)

# One column of Z per entry of mutation_types (see below); the last column
# is a flat rate scaled by 94.
cols = 3
column_profiles = (
    mu0 * (1 * flat + .5 * pulse),
    mu0 * (.5 * flat + .4 * ramp),
    94 * mu0 * flat,
)
Z = np.zeros((len(t), cols))
for column_index, profile in enumerate(column_profiles):
    Z[:, column_index] = profile

mutation_types = ['TCC>TTC', 'GAA>GGA', None]

mu_true = mushi.mu(change_points, Z, mutation_types)

# Plot only the two explicitly named mutation types on a shared figure
plt.figure(figsize=(4, 4))
for mutation_type in mutation_types[:2]:
    mu_true.plot((mutation_type,), alpha=0.75, lw=3, clr=False)

Estimate the total mutation rate as the sum of the per-type rates at $t=0$

In [None]:
# Total mutation rate: sum across all mutation-type columns of the first
# row of Z, i.e. the row for the first time-grid entry.
mu0 = mu_true.Z[0].sum()
print(mu0)

## Simulate a $k$-SFS
- We'll sample 200 haplotypes.
- Note that this simulation will have a slightly varying total mutation rate, because the pulse and ramp components of $\mu(t)$ change over time.

In [None]:
# Number of sampled haplotypes
n = 200

# Simulate a k-SFS under the true eta and mu histories.
# NOTE(review): r is presumably the ancestral-state misidentification rate
# (the inference cell reports ksfs.r under that name) — confirm.
ksfs = mushi.kSFS(n=n)
ksfs.simulate(eta_true, mu_true, r=0.02, seed=1)

# Total SFS on log-log axes
plt.figure(figsize=(4, 3))
ksfs.plot_total(kwargs={'ls': '', 'marker': '.'})
plt.xscale('log')
plt.yscale('log')

# The two named mutation types, drawn with the same marker style; each call
# receives its own copy of the style dict.
plt.figure(figsize=(4, 3))
point_style = {'alpha': 0.75, 'ls': '', 'marker': 'o'}
ksfs.plot(('TCC>TTC',), clr=True, kwargs=dict(point_style))
ksfs.plot(('GAA>GGA',), clr=True, kwargs=dict(point_style))

Number of segregating sites

In [None]:
# Sum every entry of ksfs.X; per the heading above, this total is the number
# of segregating sites in the simulated data.
ksfs.X.sum()

### TMRCA CDF

In [None]:
# TMRCA CDF under the true demography, plotted against the change points
# on a logarithmic time axis.
fig, ax = plt.subplots(figsize=(3.5, 3.5))
ax.plot(change_points, ksfs.tmrca_cdf(eta_true))
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim([0, 1])  # a CDF lives on [0, 1]
ax.set_xscale('log')
fig.tight_layout()

## Inference

### Convergence parameters and time grid

In [None]:
# Keyword arguments unpacked into both infer_eta and infer_mush below.
convergence = {
    'tol': 0,
    'max_iter': 100,
    'trend_kwargs': {'max_iter': 20},
}

# Forwarded as `pts` to infer_eta; presumably the number of points in the
# inference time grid — confirm against the mushi documentation.
pts = 100

### Infer $\eta(t)$

In [None]:
# Passed positionally to infer_eta; each pair is presumably
# (trend order, penalty strength) — confirm against the mushi documentation.
trend_penalties = ((0, 1e1), (1, 1e0))

# Shared by the fit and by the total-SFS plot below
folded = False

# Reset state on the kSFS object before fitting
ksfs.clear_eta()
ksfs.clear_mu()
ksfs.r = None

eta_fit_options = dict(ridge_penalty=1e-4, pts=pts, verbose=True, folded=folded)
ksfs.infer_eta(mu0, *trend_penalties, **eta_fit_options, **convergence)

# r was set to None above; report it only if it holds a value after the fit
if ksfs.r is not None:
    print(f'inferred ancestral misidentification rate: {ksfs.r:.3f}')

plt.figure(figsize=(8, 4))

# Left panel: total SFS on log-log axes
plt.subplot(1, 2, 1)
ksfs.plot_total(
    kwargs={'ls': '', 'marker': 'o', 'ms': 5, 'c': 'k', 'alpha': 0.75},
    line_kwargs={'c': 'C0', 'alpha': 0.75, 'lw': 3},
    fill_kwargs={'color': 'C0', 'alpha': 0.1},
    folded=folded,
)
plt.xscale('log')
plt.yscale('log')

# Right panel: true versus inferred eta
plt.subplot(1, 2, 2)
eta_true.plot(c='k', lw=2, label='true')
ksfs.eta.plot(lw=3, alpha=0.75, label='inferred')
plt.legend()

plt.tight_layout()
plt.show()

### Infer $\boldsymbol\mu(t)$

In [None]:
# Reset the stored mutation-rate state before fitting
ksfs.clear_mu()

# Passed positionally to infer_mush; each pair is presumably
# (trend order, penalty strength) — confirm against the mushi documentation.
trend_penalties = ((0, 2e2), (3, 1e-1))

ksfs.infer_mush(*trend_penalties, ridge_penalty=1e-4, verbose=True,
                **convergence)

# The two named mutation types, each paired with the color used in both panels
highlighted_types = (('TCC>TTC', 'C0'), ('GAA>GGA', 'C1'))

plt.figure(figsize=(8, 4))

# Left panel: k-SFS for each highlighted type
plt.subplot(1, 2, 1)
for mutation_type, color in highlighted_types:
    ksfs.plot(
        (mutation_type,),
        clr=True,
        kwargs={'alpha': 0.75, 'ls': '', 'marker': '.', 'ms': 10,
                'mfc': 'none', 'c': color},
        line_kwargs={'alpha': 0.75, 'lw': 2, 'c': color},
    )

# Right panel: true (solid) versus inferred (dashed) history per type
plt.subplot(1, 2, 2)
for mutation_type, color in highlighted_types:
    mu_true.plot((mutation_type,), alpha=0.75, lw=2, c=color)
    ksfs.mu.plot((mutation_type,), alpha=0.75, lw=3, ls='--', c=color)

plt.tight_layout()
plt.show()