1KG
==
Run mushi on the 3-SFS computed from 1000 Genomes Project (1KG) data.

In [None]:
# %matplotlib inline 
%matplotlib notebook
import time
from pathlib import Path

import msprime
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.special import expit
from sklearn.decomposition import PCA

import histories
import mushi
%cd stdpopsim
from stdpopsim import homo_sapiens
%cd ../
import dadi

In [None]:
# plt.style.use('dark_background')

### Load 1KG 3-SFS

In [None]:
# Load the observed 3-SFS: columns are triplet mutation types (e.g. 'TCC>TTC'),
# entries are SNP counts; rows are presumably sample frequencies 1..n-1.
ksfs_CEU = pd.read_csv('1KG/scons_output/3-SFS.tsv', sep='\t', index_col=0)
mutation_types = ksfs_CEU.columns

# One row per polymorphic frequency, so the sample size is the row count plus one
n = ksfs_CEU.shape[0] + 1

# Display as the last expression so it actually renders (a bare expression
# in the middle of a cell is evaluated and discarded)
ksfs_CEU.head()

Rank plot of the number of SNPs of each triplet mutation type

In [None]:
# Rank plot: total SNP count per triplet mutation type, from most to least common
type_totals = ksfs_CEU.sum().sort_values(ascending=False)

plt.figure(figsize=(15, 3))
plt.plot(type_totals.to_frame(), '.')
plt.xticks(rotation='vertical', family='monospace')
plt.xlabel('mutation type')
plt.ylabel('number of SNPs')
plt.yscale('symlog')
plt.tight_layout()
# Save under a repository-relative directory instead of a user-specific absolute path,
# so the notebook runs on any machine
fig_dir = Path('figures')
fig_dir.mkdir(exist_ok=True)
plt.savefig(fig_dir / '1KG_type_counts.png')
plt.show()

In [None]:
# Fit dadi's bottleneck + exponential growth model to the total SFS (summed over
# mutation types) to obtain a demographic history for mushi.
# dadi spectra include the monomorphic bins 0 and n: pad with placeholders and mask them.
sfs = np.concatenate(([-1], ksfs_CEU.values.sum(1), [-1]))
# Fit only sample frequencies 2..179: mask the padding, singletons, and frequencies >= 180.
# Named dadi_mask (length n + 1) so it is not confused with mushi's `mask` (length n - 1).
freqs = np.arange(len(sfs))
dadi_mask = (freqs <= 1) | (freqs >= 180)
fs = dadi.Spectrum(sfs, mask=dadi_mask)

# Starting point for (nuB, nuF, T) in dadi units. Other values recorded earlier
# (provenance undocumented — TODO): (0.221412, 1.0319, 0.102472), (0.219817, 0.991569, 0.107134)
params_init = (0.0748317, 3.23003, 0.0890549)
lb = (0, 0, 0)  # non-negative parameters; no upper bound is imposed
pts_l = [100, 500, 1000]  # grid sizes used by dadi's extrapolation
func = dadi.Demographics1D.bottlegrowth
func_ex = dadi.Numerics.make_extrap_log_func(func)

params = dadi.Inference.optimize(params_init, fs, func_ex, pts_l,
                                 epsilon=1e-6, gtol=1e-10, maxiter=np.inf,
                                 lower_bound=lb, verbose=True)[0]
model = func_ex(params, (n,), pts_l)

# θ is the scaling of the model SFS that best matches the observed counts
theta = dadi.Inference.optimal_sfs_scaling(model, fs)

plt.figure()
dadi.Plotting.plot_1d_comp_multinom(model, fs)
plt.show()

In [None]:
# Change points for η(t) and μ(t): 500 log-spaced times from 10^0 to 10^7
# (presumably generations, given the conversion to dadi's 2·Na-generation units)
t = 10.0 ** np.linspace(0, 7, 500)

In [None]:
nuB, nuF, T = params

# Rates are per genome (μ ≡ 1), so θ = 4 Na μ gives the ancestral size Na = θ / 4
Na = theta / 4

# dadi's bottlegrowth in backward time τ (units of 2 Na generations): the relative size
# changes exponentially from nuF at τ = 0 to nuB at τ = T, and is 1 (ancestral) before T.
# The leading 0 adds the present-day epoch; sizes are scaled by 2 Na.
y = [2 * Na * nuF * np.exp(-np.log(nuF / nuB) * tau / T) if tau < T else 2 * Na
     for tau in np.concatenate(([0], t / (2 * Na)))]

η = histories.η(t, np.array(y))

plt.figure(figsize=(3, 3))
η.plot()
plt.show()

### Mushi $k$-SFS object conditioned on this demographic history

In [None]:
# Observed k-SFS wrapped as a mushi object, conditioned on the demographic history η(t)
ksfs = mushi.kSFS(η,
                  mutation_types=mutation_types,
                  X=ksfs_CEU.values)

### TMRCA CDF

In [None]:
# CDF of the time to most recent common ancestor implied by η(t)
fig, ax = plt.subplots(figsize=(3, 3))
ax.plot(η.change_points, ksfs.tmrca_cdf())
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim([0, 1])
ax.set_xscale('symlog')
fig.tight_layout()
plt.show()

### Mutation type enrichment as a heatmap with correlation clustering

In [None]:
# Mutation-type enrichment heatmap, ordered by correlation clustering
ksfs.clustermap(figsize=(25, 10))
plt.show()

### Invert the $k$-SFS conditioned on $\eta(t)$ to get $\boldsymbol\mu(t)$
The optimization uses accelerated proximal gradient descent.

In [None]:
# Frequency mask for mushi, one entry per k-SFS row (row i presumably holds sample
# frequency i + 1): exclude row 0 (singletons) and every row from index 179 onward.
freq_idx = np.arange(n - 1)
mask = (freq_idx == 0) | (freq_idx >= 179)

In [None]:
# Sanity check before per-type inference: pool every mutation type into a single SFS
# and see whether η(t) with a constant mutation rate reproduces it.
pooled_X = ksfs.X.sum(axis=1, keepdims=True)
sfs_total = mushi.kSFS(η, pooled_X)
μ_constant = sfs_total.constant_μ_MLE(mask=mask)

plt.figure(figsize=(3, 3))
sfs_total.plot1(0, μ=μ_constant, prf_quantiles=True)
plt.tight_layout()
plt.show()

In [None]:
# Hyperparameters for inferring μ(t) from the k-SFS conditioned on η(t)
infer_kwargs = dict(
    # loss function: 'prf' fit on the masked frequencies, no binning
    fit='prf', mask=mask, bins=None,
    # time-derivative (total variation) regularization
    λ_tv=1e6, α_tv=0,
    # spectral regularization
    λ_r=2e4, α_r=1-1/2e4, hard=True,
    # convergence
    max_iter=10000, tol=1e-10, γ=0.8,
)
μ, f_trajectory = ksfs.infer_μ(**infer_kwargs)

Convergence

In [None]:
# Cost at each optimizer iteration, to check convergence
fig, ax = plt.subplots(figsize=(4, 2))
ax.plot(f_trajectory)
ax.set_xlabel('iterations')
ax.set_ylabel('cost')
ax.set_xscale('symlog')
fig.tight_layout()
plt.show()

### Singular value spectrum of $Z$

In [None]:
# Singular value spectrum of the inferred μ(t) matrix Z; a steep decay means Z is
# well approximated by a low-rank matrix.
singular_values = np.linalg.svd(μ.Z, compute_uv=False)

fig, ax = plt.subplots(figsize=(3, 3))
# One bar per singular value: there are min(Z.shape) of them, so derive the x positions
# from the result rather than assuming Z.shape[1] is the smaller dimension
ax.bar(range(len(singular_values)), singular_values)
ax.set_xlabel('singular value index')
ax.set_ylabel('singular value')
ax.set_yscale('log')
fig.tight_layout()
plt.show()

The inferred histories for each mutation type (raw mutation rate in units of mutations per genome per generation)

In [None]:
# Mutation types highlighted in the figures below (presumably the TCC>TTC-like
# "pulse" signal discussed by Harris & Pritchard)
pulse_types = tuple('TCC>TTC CCC>CTC ACC>ATC TCT>TTT'.split())

In [None]:
# Left: normalized k-SFS plot for every mutation type (black) with the pulse types in color.
# Right: normalized inferred histories μ(t), using the same colors per pulse type.
plt.figure(figsize=(6, 2.5))

plt.subplot(121)
ksfs.plot(μ=μ, alpha=0.02, c='k', lw=2, normed=True)
for color_idx, mut_type in enumerate(pulse_types):
    ksfs.plot(mut_type, μ=μ, lw=2, c=f'C{color_idx}', normed=True, label=mut_type)
plt.legend(loc=1, prop={'size': 8}, framealpha=0, edgecolor='k')
plt.xlim([2, None])

plt.subplot(122)
μ.plot(alpha=0.1, lw=1, c='k', normed=True)
# Pin colors explicitly so each type matches its left-panel curve regardless of how
# μ.plot interacts with the color cycle; the left-panel legend serves both panels
for color_idx, mut_type in enumerate(pulse_types):
    μ.plot(types=[mut_type], lw=2, c=f'C{color_idx}', normed=True, label=mut_type)

# Save under a repository-relative directory instead of a user-specific absolute path
fig_dir = Path('figures')
fig_dir.mkdir(exist_ok=True)
plt.savefig(fig_dir / '1KG.pdf', transparent=True)
plt.show()

Heatmap of the inferred mutation spectrum history, plotted as relative mutation intensity as in Harris and Pritchard

In [None]:
# μ.clustermap(figsize=(25, 10))
# # plt.savefig('/Users/williamdewitt/Downloads/cluster.png', transparent=False)
# plt.show()

Plot the $\chi^2$ goodness of fit for each $k$-SFS matrix element, and compute a $\chi^2$ goodness-of-fit test for the $k$-SFS matrix as a whole.

In [None]:
# ksfs.clustermap(μ, figsize=(25, 10), cmap='Reds')
# plt.show()

Plot the SFS fit for an individual mutation type (here TCC>TTC).

In [None]:
# plt.figure(figsize=(3, 3))
# ksfs.plot1('TCC>TTC', μ=μ, prf_quantiles=True)
# plt.tight_layout()
# plt.show()

In [None]:
# PCA of the inferred histories: each mutation type (a column of μ.Z) is an observation
# and each row of μ.Z (presumably an epoch) a feature, so principal vectors are functions of time.
n_pcs = 3  # must be >= 3 for the PC 2 vs PC 3 scatter below
pca = PCA(n_components=n_pcs, whiten=False).fit(μ.Z.T)
# Transform once and reuse: row k holds each mutation type's score on PC k + 1
pc_scores = pca.transform(μ.Z.T).T
# Change points with t = 0 prepended; one entry per row of μ.Z
t_epochs = np.concatenate(([0], t))

plt.figure(figsize=(6, 2.5))
plt.subplot(121)
for pc_number, component in enumerate(pca.components_, start=1):
    plt.plot(t_epochs, component, label=f'principal vector {pc_number}')
plt.xlabel('$t$')
plt.xscale('log')
plt.legend(loc='lower left', prop={'size': 7.5}, framealpha=.5)

plt.subplot(122)
plt.scatter(*pc_scores[1:3, :], c='k', alpha=0.2, s=20)
# Explicit colors keep each pulse type consistent with the other figures
for color_idx, mut_type in enumerate(pulse_types):
    plt.scatter(*pc_scores[1:3, μ.mutation_types.get_loc(mut_type)],
                color=f'C{color_idx}', label=mut_type, s=20)
plt.xlabel('PC 2')
plt.ylabel('PC 3')
plt.legend(loc='lower right', prop={'size': 8}, framealpha=.5)
plt.tight_layout()

# Save under a repository-relative directory instead of a user-specific absolute path
fig_dir = Path('figures')
fig_dir.mkdir(exist_ok=True)
plt.savefig(fig_dir / '1KG_PC.pdf', transparent=True)
plt.show()

In [None]:
# Variance captured by each principal component, absolute and as a fraction of the total
pd.DataFrame({'explained variance': pca.explained_variance_,
              'explained variance ratio': pca.explained_variance_ratio_},
             index=pd.RangeIndex(1, pca.n_components_ + 1, name='PC'))

In [None]:
# Clustered heatmap of mutation-type scores on PCs 2 and up; pulse types are flagged red
# in the column color bar. (seaborn is imported with the other imports at the top.)
col_colors = ['red' if mut_type in pulse_types else 'grey' for mut_type in μ.mutation_types]

Z_pca = pca.transform(μ.Z.T).T
# PC 1 is excluded (as in the PC 2 vs PC 3 scatter); rows are labelled by 1-based PC number
pc_weights = pd.DataFrame(data=Z_pca[1:], index=range(2, Z_pca.shape[0] + 1),
                          columns=μ.mutation_types)
g = sns.clustermap(pc_weights, center=0, col_colors=col_colors, method='ward',
                   cbar_kws={'label': 'PC weight'})
plt.show()