Mushi
==
_All that the rain promises and more..._

A notebook for testing `mushi`'s ability to invert data simulated under the forward model.

API documentation can be viewed with
```python
help(History)
help(kSFS)
```

In [1]:
%matplotlib notebook
import mushi
from mushi import History, kSFS
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import expit
import time
import msprime
%cd stdpopsim
from stdpopsim import homo_sapiens
%cd ../

/Users/williamdewitt/Desktop/repos/dement/stdpopsim
/Users/williamdewitt/Desktop/repos/dement


### Time grid

In [2]:
# Time grid: 300 points spaced evenly in log10 between 10**0 = 1 and 3e4.
# (A linear alternative that was also tried: np.linspace(1, 3e4, 500).)
t = np.logspace(start=0, stop=np.log10(3e4), num=300)

### Demographic history $\eta(t)$

In [3]:
# Two-population out-of-Africa model (Tennessen) as packaged by stdpopsim.
model = homo_sapiens.TennessenTwoPopOutOfAfrica()

# The msprime debugger is used here only to read population sizes back
# through time for the model above.
dd = msprime.DemographyDebugger(
    Ne=model.default_population_size,
    population_configurations=model.population_configurations,
    demographic_events=model.demographic_events,
    migration_matrix=model.migration_matrix,
)

# Sizes are evaluated at time 0 and at every grid point in t. Only the
# population in column 1 is kept, and its size is doubled
# (presumably diploid individuals -> haploid genomes; confirm).
y = 2 * dd.population_size_trajectory(np.concatenate(([0], t)))[:, 1]
η = History(t, y)

# Demographic history on a symlog time axis and log size axis.
plt.figure(figsize=(3, 3))
η.plot()
plt.xscale('symlog')
plt.yscale('log')
plt.xlabel('$t$')
plt.ylabel('$η(t)$')
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

### Mutation rate history $\mu(t)$

In [4]:
# Column vector of times: 0 followed by the grid t, shape (len(t) + 1, 1),
# so each trajectory below broadcasts across a block of columns of Z.
tt = np.concatenate(([0], t))[:, np.newaxis]

# Baseline rate that scales all three trajectories.
μ0 = 1

# Three mutation rate trajectories as functions of time:
#   flat  - constant at 1.1 * μ0
#   ramp  - logistic step centred at t = 50, from 2 * μ0 down to μ0
#   pulse - logistic rise centred at t = 80, then a 1.5x larger logistic
#           drop centred at t = 600
flat = μ0 * (1.1 * np.ones_like(tt))
ramp = μ0 * (1 + expit(-.01 * (tt - 50)))
pulse = μ0 * (1 + expit(.01 * (tt - 80)) - 1.5 * expit(.01 * (tt - 600)))

# 96 columns (presumably one per mutation type; confirm): the first 10
# follow the pulse, the next 10 the ramp, and the remaining 76 are flat.
Z = np.zeros((len(t) + 1, 96))
Z[:, :10] += pulse
Z[:, 10:20] += ramp
Z[:, 20:] += flat
μ = History(t, Z)

# Show one representative column from each group.
plt.figure(figsize=(3, 3))
μ.plot((0, 10, 20))
plt.xlabel('$t$')
plt.ylabel('$μ(t)$')
plt.xscale('symlog')
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

### Simulate a $k$-SFS under this history
We'll sample 200 haplotypes and plot the first SFS (index 0). The CDF of the sample's TMRCA is plotted in the next section.

In [5]:
# Number of sampled haplotypes.
n = 200

# k-SFS object built on the demographic history η, then filled with data
# simulated under the mutation rate history μ. The seed is fixed so the
# simulated data are the same on every run.
sfs = kSFS(η, n=n)
sfs.simulate(μ, seed=1)

# Plot the SFS at index 0.
plt.figure(figsize=(3, 3))
sfs.plot(0)
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

### TMRCA CDF

In [6]:
# TMRCA CDF from the k-SFS object, plotted against the change points of η.
plt.figure(figsize=(3, 3))
plt.plot(η.change_points, sfs.tmrca_cdf())
plt.ylim(0, 1)
plt.xlabel('$t$')
plt.ylabel('TMRCA CDF')
# The time axis is left linear here; plt.xscale('symlog') is an alternative.
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

### Invert the $k$-SFS conditioned on $\eta(t)$ to get $\boldsymbol\mu(t)$

In [23]:
# Settings forwarded unchanged to sfs.infer_μ. The meanings below are
# inferred from the names only; confirm against help(kSFS).
λ_tv, α_tv = 1e3, 0    # presumably total-variation penalty; α_tv = 0.99 was also tried
λ_r, α_r = 1e1, .99    # presumably a second (rank/ridge-type?) penalty; confirm
γ = 0.8                # NOTE(review): optimiser parameter, meaning not visible here
steps, tol = 10000, 1e-10  # step budget and relative loss-change tolerance
bins = None            # alternative tried: np.logspace(0, np.log10(n), 5)

# Turn every NumPy floating-point warning (overflow, invalid, divide,
# underflow) into an exception, so numerical problems during inference
# stop the run rather than passing silently.
with np.errstate(all='raise'):
    μ_inferred = sfs.infer_μ(
        λ_tv=λ_tv,
        α_tv=α_tv,
        λ_r=λ_r,
        α_r=α_r,
        γ=γ,
        steps=steps,
        tol=tol,
        bins=bins,
    )

relative change in loss function 4.3e-11 is within tolerance 1e-10 after 1557 steps


In [24]:
# Overlay the true trajectories (dashed) with the inferred ones (translucent),
# one colour per group of columns: pulse = C0, ramp = C1, flat = C2.
plt.figure(figsize=(4, 4))

# All true curves are drawn first, then the inferred curves on top.
for truth, colour in zip((pulse, ramp, flat), ('C0', 'C1', 'C2')):
    plt.step(tt, truth, c=colour, ls='--', lw=2)

for columns, colour, opacity in (
    (range(10), 'C0', 0.2),
    (range(10, 20), 'C1', 0.2),
    (range(20, 96), 'C2', 0.1),
):
    μ_inferred.plot(columns, c=colour, alpha=opacity, lw=2)

plt.ylabel('$μ(t)$')
plt.xscale('log')
plt.ylim((0, None))
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

In [22]:
# One figure per representative column (0 = pulse, 10 = ramp, 20 = flat):
# the simulated SFS together with the fit implied by the inferred mutation
# rate history.
for i in (0, 10, 20):
    plt.figure(figsize=(3, 3))
    sfs.plot(i, μ=μ_inferred, prf_quantiles=True)
    # tight_layout acts on the current figure only, so it is called once
    # per figure rather than once after the loop.
    plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [14]:
# Singular values of the inferred mutation rate matrix, returned in
# descending order (compute_uv=False skips the singular vectors).
singular_values = np.linalg.svd(μ_inferred.vals, compute_uv=False)

plt.figure(figsize=(3, 3))
# Bar positions come from the singular values themselves, so the two
# arguments always have the same length whatever the matrix shape.
plt.bar(range(len(singular_values)), singular_values)
plt.xlabel('index')
plt.ylabel('singular value')
# A log y-axis (plt.yscale('log')) is an alternative view of the spectrum.
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>