1KG
==
Run mushi on the 3-SFS computed from 1000 Genomes Project data

In [None]:
%matplotlib inline 
# %matplotlib notebook
from mushi import kSFS
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import seaborn as sns
import pickle
import composition as cmp

In [None]:
# set this to e.g. your Downloads folder path if you want plots saved to pdfs
# (the plotting cells only save when this is truthy, so None or '' disables saving)
plot_dir = '/Users/williamdewitt/Downloads/'

# plt.style.use('dark_background')

# render figure text with LaTeX; amsmath is loaded because a later axis label uses \text{}
mpl.rc('text', usetex=True)
# the preamble must be a single string: passing a list is deprecated since
# matplotlib 3.3 and rejected by later releases
mpl.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

Parse population manifest

In [None]:
# Map each super-population code to the list of population codes it contains,
# keeping populations in order of first appearance in the panel file.
pops = {}
with open('example_data/integrated_call_samples_v3.20130502.ALL.panel') as panel:
    panel.readline()  # discard the header row
    for line in panel:
        # columns 1 and 2 (0-based) of each tab-separated row hold pop and super_pop
        pop, super_pop = line.split('\t')[1:3]
        members = pops.setdefault(super_pop, [])
        if pop not in members:
            members.append(pop)

### Load 1KG 3-SFSs

In [None]:
# Canonical ordering of the triplet mutation types, written as
# "<5'><ancestral><3'>><5'><derived><3'>". The ancestral centre base is A or C;
# the order is ancestral base, then derived base, then 5' flank, then 3' flank.
sorted_triplets = [
    f'{five}{anc}{three}>{five}{der}{three}'
    for anc in 'AC'
    for der in 'ACGT' if der != anc
    for five in 'ACGT'
    for three in 'ACGT'
]


# Build one kSFS object per population and overlay every population's total SFS
# on a single figure.
ksfs_dict = {}
plt.figure(figsize=(8, 8))
for super_pop, pop_names in pops.items():
    for pop in pop_names:
        ksfs_df = pd.read_csv(f'example_data/{pop}/3-SFS.tsv', sep='\t', index_col=0)
        ksfs = kSFS(X=ksfs_df.values, mutation_types=ksfs_df.columns)
        ksfs_dict[pop] = ksfs
        ksfs.plot_total(kwargs=dict(ls='', alpha=0.75, marker='o', ms=5, mfc='none', label=pop))

        # Put the mutation-type columns into the canonical sorted_triplets order.
        # reindex yields the reordered labels plus the matching column positions
        # (assumes mutation_types is a pandas Index — TODO confirm).
        ordered_types, column_order = ksfs.mutation_types.reindex(sorted_triplets)
        ksfs.mutation_types = ordered_types
        ksfs.X = ksfs.X[:, column_order]
plt.legend()
plt.show()

Number of segregating variants in each super population

In [None]:
# Sum the entries of X over every population belonging to each super-population.
for super_pop, pop_names in pops.items():
    n_segregating = sum(ksfs_dict[pop].X.sum() for pop in pop_names)
    print(super_pop, n_segregating)

clip high frequencies due to ancestral state misidentification

In [None]:
clip_low = 0
clip_high = 10
# we need a different mask vector for each population because the number of haplotypes n
# (length of SFS vector) varies
freq_mask = {}
for super_pop, pop_names in pops.items():
    for pop in pop_names:
        n_freqs = ksfs_dict[pop].n - 1
        # keep entries clip_low .. n_freqs - clip_high - 1; everything else is masked out
        freq_mask[pop] = np.array([clip_low <= i < n_freqs - clip_high
                                   for i in range(n_freqs)])

time grid of epoch boundaries (measured in generations)

In [None]:
# 200 log-spaced epoch boundaries from 1 to 200000 generations (log10(1) == 0)
change_points = np.logspace(0, np.log10(200000), num=200)

masked genome size (excluding conserved sites, repeats, 1KG strict mask, and uncertain ancestral states)

In [None]:
# the file holds a single integer: the masked genome size
with open('example_data/masked_size.tsv') as size_file:
    size_text = size_file.read()
    masked_genome_size = int(size_text)

mutation rate per site per generation

In [None]:
u = 1.3e-8  # mutation rate per site per generation

mutation rate per masked genome per generation

In [None]:
# per-site rate scaled by the masked genome size, giving the rate for the whole masked genome
mu0 = u * masked_genome_size

generation time for time calibration

In [None]:
t_gen = 29  # generation time; later cells multiply times in generations by this for axes labelled in years

# Infer $\eta(t)$

regularization parameters and convergence criteria

In [None]:
# keyword arguments forwarded to infer_history when fitting eta
regularization_eta = {
    'alpha_tv': 1e2,
    'alpha_spline': 1e4,
    'alpha_ridge': 1e-3,
}
# stopping rule shared by every infer_history call in this notebook
convergence = {'tol': 1e-10, 'max_iter': 1000}

In [None]:
# Fit eta(t) for EVERY population and overlay the fits on two shared panels:
# left = total SFS, right = eta(t). Curves are coloured by super-population and
# only the first population of each super-population contributes a legend entry.
#
# A leftover `break` used to stop after the first population of each
# super-population; later cells read `eta` for all populations, so every
# population has to be fitted here.
fig, axes = plt.subplots(1, 2, sharex='col', figsize=(10, 5))
for idx_super, super_pop in enumerate(pops):
    for idx, pop in enumerate(pops[super_pop]):
        print(pop)
        # clear solutions, in case rerunning this cell
        ksfs_dict[pop].clear_eta()
        ksfs_dict[pop].clear_mu()
        ksfs_dict[pop].infer_history(change_points, mu0, infer_mu=False,
                                     loss='prf', **regularization_eta,
                                     **convergence, mask=freq_mask[pop])
        # the mushi plotting helpers presumably draw on the current axes, hence plt.sca
        plt.sca(axes[0])
        ksfs_dict[pop].plot_total(kwargs=dict(ls='', alpha=0.5, marker='o', ms=5, mfc='none', c=f'C{idx_super}', label=super_pop if idx == 0 else None),
                                  line_kwargs=dict(c=f'C{idx_super}', ls=':', marker='.', ms=3, alpha=0.5, lw=1),
                                  fill_kwargs=dict(color=f'C{idx_super}', alpha=0))
        plt.legend(fontsize=6)
        if idx_super < len(pops) - 1:
            plt.xlabel(None)
        plt.sca(axes[1])
        ksfs_dict[pop].eta.plot(t_gen=t_gen, lw=2, label=super_pop if idx == 0 else None, alpha=0.5, c=f'C{idx_super}')
        plt.xlim([1e3, 1e6])
        plt.legend(fontsize=6)
        if idx_super < len(pops) - 1:
            plt.xlabel(None)
plt.tight_layout()
if plot_dir:
    plt.savefig(f'{plot_dir}/1KG.eta.pdf')
plt.show()

In [None]:
# Fit the mutation-spectrum history of each EUR population (eta held fixed) and
# show the TCC>TTC component: left = SFS composition, right = inferred intensity.
regularization_mu = {'beta_tv': 7e1, 'beta_ridge': 1e-10}

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.8, 2.5))

for idx, pop in enumerate(pops['EUR']):
    print(pop)
    color = f'C{idx}'
    ksfs = ksfs_dict[pop]

    ksfs.clear_mu()
    ksfs.infer_history(change_points, mu0, infer_eta=False,
                       loss='prf', **regularization_mu,
                       **convergence, mask=freq_mask[pop])

    # left panel
    plt.sca(axes[0])
    point_style = dict(ls='', c=color, marker='o', ms=5, mfc='none', alpha=0.5, label=pop)
    fit_style = dict(c=color, ls=':', marker='.', ms=3, alpha=0.5, lw=1)
    ksfs.plot(('TCC>TTC',), clr=True, kwargs=point_style, line_kwargs=fit_style)
    plt.ylabel('TCC$\\to$TTC component of\nvariant count composition')
    plt.legend(fontsize=6)

    # right panel; the colour cycle is reset so it does not interfere with the explicit colour
    plt.sca(axes[1])
    plt.gca().set_prop_cycle(None)
    ksfs.mu.plot(('TCC>TTC',), t_gen=t_gen, clr=False, c=color, alpha=0.75, lw=2, label=pop)
    plt.ylabel(r'TCC$\to$TTC mutation intensity')
    plt.xlim([1e3, 1e6])
    plt.legend(fontsize=6)

plt.tight_layout()
if plot_dir:
    plt.savefig(f'{plot_dir}/europulse.pdf', dpi=300)
plt.show()

# Infer $\boldsymbol\mu(t)$

In [None]:
# keyword arguments forwarded to infer_history when fitting mu(t) for all populations;
# note this rebinds the regularization_mu used for the EUR-only fit above
regularization_mu = dict(hard=True, beta_rank=1e2, beta_tv=0, beta_spline=1e4, beta_ridge=1e-10)

Fit all populations and plot heatmaps of $k$-SFS and mush

In [None]:
# For every population: refit mu(t) with eta held fixed, then draw
#   (1) the k-SFS fit, the inferred mu(t), and the singular values of ilr(mu.Z);
#   (2) a clustermap of the k-SFS; (3) a clustermap of mu(t).
# The inferred history is accessed as `.mu` throughout, matching clear_mu() /
# infer_mu= and the other cells (this cell previously mixed `.μ` and `.mu`).
with mpl.rc_context(rc={'text.usetex': False}):

    center = 1 / len(sorted_triplets) # for vmin/vmax

    for idx_super, super_pop in enumerate(pops):
        for idx, pop in enumerate(pops[super_pop]):
            print(pop)

            ksfs_dict[pop].clear_mu()
            ksfs_dict[pop].infer_history(change_points, mu0, infer_eta=False,
                                         loss='prf', **regularization_mu,
                                         **convergence, mask=freq_mask[pop])

            plt.figure(figsize=(8, 10))
            plt.subplot(131)
            ksfs_dict[pop].plot(clr=True, kwargs=dict(alpha=0.25, ls='', marker='o',
                                                      ms=3, mfc='none', rasterized=True),
                                line_kwargs=dict(ls=':', marker='.', ms=2, alpha=0.25,
                                                 lw=1, rasterized=True))
            plt.subplot(132)
            ksfs_dict[pop].mu.plot(t_gen=t_gen, clr=True, alpha=0.5, lw=2)
            plt.xscale('log')

            plt.subplot(133)
            # singular values of the ilr-transformed history, on log-log axes
            sing_vals = np.linalg.svd(cmp.ilr(ksfs_dict[pop].mu.Z), compute_uv=False)
            x = np.arange(1, len(sing_vals) + 1)
            plt.scatter(x, sing_vals)
            plt.plot(x, sing_vals, 'ko', ms=5, mfc='none', mew=.1)
            plt.xscale('log')
            plt.yscale('log')
            plt.xlabel('singular value rank')
            plt.ylabel('singular value')
            plt.tight_layout()
            plt.show()

            # "X>Y" label built from characters 1 and 5 of each triplet string,
            # used to colour the heatmap columns
            singlets = ksfs_dict[pop].mutation_types.str[1].str.cat(ksfs_dict[pop].mutation_types.str[5], sep='>')
        #     a5 = ksfs_dict[pop].mutation_types.str[0]
        #     a3 = ksfs_dict[pop].mutation_types.str[2]

            col_map = {'A>C':'C0', 'A>G':'C1', 'A>T':'C2', 'C>A':'C3', 'C>G':'C4', 'C>T':'C5'}
            col_colors = [col_map[singlet] for singlet in singlets]

            g = ksfs_dict[pop].clustermap(figsize=(20, 7), col_cluster=False,
                                          xticklabels=True, rasterized=True,
                                          vmin=center / 1.7, vmax=1.7 * center,
                                          cmap='RdBu_r',
                                          col_colors=col_colors)
            g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize = 9, family='monospace')
        #     g.ax_col_dendrogram.set_visible(False)
            if plot_dir:
                g.savefig(f'{plot_dir}/heatmap.{pop}.X.pdf')
            plt.show()

            g = ksfs_dict[pop].mu.clustermap(t_gen=t_gen,
                                             figsize=(20, 7), col_cluster=False, xticklabels=True, rasterized=True,
                                             vmin=center / 1.5, vmax=1.5 * center,
                                             cmap='RdBu_r',
                                             col_colors=col_colors)
            g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize = 9, family='monospace')
        #     g.ax_col_dendrogram.set_visible(False)
            if plot_dir:
                g.savefig(f'{plot_dir}/heatmap.{pop}.mu.pdf')
            plt.show()

### TMRCA CDF

In [None]:
# One panel per super-population (panel count follows the manifest instead of
# being hard-coded to 5); each curve is one population's TMRCA CDF evaluated
# against that population's eta change points.
fig, axes = plt.subplots(1, len(pops), sharey=True, figsize=(18, 4))
for idx, super_pop in enumerate(pops):
    plt.sca(axes[idx])
    plt.title(super_pop)
    for idx2, pop in enumerate(pops[super_pop]):
        # use the ASCII `.eta` attribute for both the grid and the CDF argument
        plt.plot(ksfs_dict[pop].eta.change_points, ksfs_dict[pop].tmrca_cdf(ksfs_dict[pop].eta), label=pop)
    plt.xlabel('$t$')
    plt.ylabel('TMRCA CDF')
    plt.ylim([0, 1])
    plt.xscale('log')
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# from sklearn.cluster import MeanShift

# time = t_gen * np.concatenate(([0], ksfs_dict[pop].μ.change_points))

# Z_dict = {}
# for pop in ksfs_dict:
#     Z_dict[pop] = cmp.ilr(ksfs_dict[pop].μ.Z)

# clusterer = MeanShift(bandwidth=.1)
# for row in range(len(time)):
#     Z = np.concatenate(tuple(Z_dict[pop][np.newaxis, row, :] for pop in ksfs_dict))
#     labels = clusterer.fit_predict(Z)
#     print(max(labels))

In [None]:
# Indices into the time grid (0 followed by the mu change points) that bound the
# window used for the embedding in the following cell.
start = 58
end = -30

# NOTE(review): `pop` is whatever the previous loop left behind — presumably all
# populations share the same change points; confirm.
mu_grid = np.concatenate(([0], ksfs_dict[pop].mu.change_points))

# window boundaries converted to years (previously computed but never shown,
# and with the two indices duplicated as literals)
print(t_gen * mu_grid[[start, end]])

# number of grid points inside the window
len(mu_grid[start:end])

In [None]:
from sklearn.decomposition import PCA, KernelPCA
from umap import UMAP

# Embed the ilr-transformed mutation-spectrum histories of all populations
# (restricted to the window [start:end]) into 2D, then plot the trajectories.

# time grid in generations, restricted to the same window
# NOTE(review): `pop` here is left over from an earlier loop — presumably all
# populations share the same change points; confirm.
time = np.concatenate(([0], ksfs_dict[pop].mu.change_points))[start:end]

Z_dict = {}
for pop in ksfs_dict:
    Z_dict[pop] = cmp.ilr(ksfs_dict[pop].mu.Z[start:end])

# stack every population's rows so a single embedding is learned for all of them
Z = np.concatenate(tuple(Z_dict[pop] for pop in ksfs_dict))

# from jax import jit
# @jit
# def kernel(x, y):
#     return cmp.inner(x, y)

np.random.seed(1)

# use one of these, or the one above learned on the expected segregating sites
embedding = UMAP(n_components=2,
                 n_neighbors=30,
                 min_dist=0,
#                  metric='cosine',
                 local_connectivity=.01,
                 n_epochs=1000
                )
# embedding = KernelPCA(n_components=2, kernel='sigmoid', gamma=1e-3)

# embedding = PCA(n_components=2)


embedding.fit(Z)

# project each population separately with the shared embedding
Z_transform_dict = {pop:embedding.transform(Z_dict[pop]) for pop in ksfs_dict}



with mpl.rc_context(rc={'text.usetex': False}):

    # --- 2D trajectories, labelled at their first point ---
    plt.figure(figsize=(4, 4))
    # plt.subplot(311)
    for idx, super_pop in enumerate(pops):
        for idx2, pop in enumerate(pops[super_pop]):
            plt.plot(*Z_transform_dict[pop].T,
                     '-', lw=3, alpha=.5,
                     c=f'C{idx}',
                     label=super_pop if idx2 == 0 else None)
            plt.annotate(pop, Z_transform_dict[pop][0, :],
                         ha='center', va='center', c='w',
                         # compare by value: `is not` on a string literal tests identity
                         family='monospace' if pop != 'Batwa' else None,
                         bbox=dict(boxstyle='circle', fc=f'C{idx}', ec=f'C{idx}', lw=2),
                         size=6)
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.legend()
    plt.tight_layout()
    if plot_dir:
        plt.savefig(f'{plot_dir}/1KG.umap.pdf')
    plt.show()

    # --- 3D view with log10 time (years) on the z axis ---
    fig = plt.figure(figsize=(5, 5))
    # Figure.gca(projection=...) is no longer accepted by recent matplotlib
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(20, -105)
    # plt.subplot(311)
    for idx, super_pop in enumerate(pops):
        for idx2, pop in enumerate(pops[super_pop]):
            ax.plot(*Z_transform_dict[pop].T, np.log10(t_gen * time),
                    '-', lw=3, alpha=.5,
                    c=f'C{idx}',
                    label=super_pop if idx2 == 0 else None)
            # NOTE(review): the marker is the first embedded point (row 0) but is
            # placed at time[1] — confirm whether time[0] was intended
            ax.scatter(*Z_transform_dict[pop][None, 0, :].T, np.log10(t_gen * time[1]), s=50,
                       c=f'C{idx}', alpha=.5)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_zlabel(r'$\log_{10}(t)$')  # raw string: '\l' is not a valid escape
#     ax.set_zlim([3, 6])
    # ax.zaxis._set_scale('log')
    # ax.legend()
    plt.tight_layout()
    if plot_dir:
        plt.savefig(f'{plot_dir}/1KG.umap.3D.pdf')
    plt.show()

    # --- each UMAP component against time in years ---
    fig, axes = plt.subplots(2, 1, sharex='col', figsize=(5,5))
    for umap_comp in (0, 1):
        for idx, super_pop in enumerate(pops):
            for pop in pops[super_pop]:
                axes[umap_comp].plot(t_gen * time, Z_transform_dict[pop][:, umap_comp],
                         '-', lw=3, alpha=.75,
                         c=f'C{idx}',
                         label=pop)
                # anchor the label at the start of the curve, in the same units (years) as the x data
                axes[umap_comp].annotate(pop, (t_gen * time[0], Z_transform_dict[pop][0, umap_comp]),
                             ha='center', va='center', c='w',
                             family='monospace' if pop != 'Batwa' else None,
                             bbox=dict(boxstyle='circle', fc=f'C{idx}', ec=f'C{idx}', lw=2),
                             size=6)
        axes[umap_comp].set_ylabel(f'UMAP {umap_comp + 1}')
    plt.xlabel('time (years ago)')
    plt.xscale('symlog')
    # plt.legend(loc=(1.04, -1), fancybox=True, framealpha=0)
    plt.tight_layout()
    plt.show()

In [None]:
# One panel per super-population: for each population, the size of the step between
# consecutive rows of mu.Z (via cmp.perturb_inv and cmp.inner) divided by the
# elapsed time in years.
fig, axes = plt.subplots(len(pops), sharex='col', figsize=(4, 12))
for idx, super_pop in enumerate(pops):
    plt.sca(axes[idx])
    for idx2, pop in enumerate(pops[super_pop]):
        t, Z = ksfs_dict[pop].mu.arrays()
        n_steps = Z.shape[0] - 1
        # step between each row and the next
        delta = np.array([cmp.perturb_inv(Z[i, :], Z[i + 1, :])
                          for i in range(n_steps)])
        # norm of each step under cmp.inner
        dmush = [np.sqrt(cmp.inner(step, step)) for step in delta]
        # grid spacing; the final entry of t is excluded
        dt = np.diff(t[:-1])
        dmushdt = dmush / (t_gen * dt)
        plt.plot(t_gen * t[1:-1], dmushdt, label=pop)
    plt.legend()
plt.xscale('log')
plt.xlabel(r'$t$ (years ago)')
plt.ylabel(r'$\frac{d\text{mush}}{dt}$')
plt.xlim([1e3, 1e6])
plt.tight_layout()
if plot_dir:
    plt.savefig(f'{plot_dir}/1KG.drift.pdf')
plt.show()
#     break

In [None]:
# NMF was used below without ever being imported (NameError on a fresh kernel)
from sklearn.decomposition import NMF

# full time grid: 0 followed by the mu change points
# NOTE(review): `pop` is left over from an earlier loop — presumably all
# populations share the same change points; confirm.
time = np.concatenate(([0], ksfs_dict[pop].mu.change_points))

# NOTE(review): recent scikit-learn releases replaced NMF's `alpha` keyword with
# alpha_W / alpha_H — confirm against the installed version
embedding = NMF(alpha=1.5e0, n_components=10)#, verbose=False, tol=1e-10, max_iter=10000)
# embedding = PCA(n_components=3)
# fit on the transposed histories of all populations stacked row-wise, so each
# learned component is a curve over the time grid
embedding.fit(np.concatenate(tuple(ksfs_dict[pop].mu.Z.T for pop in ksfs_dict)))

Z_transform_dict = {pop:embedding.transform(ksfs_dict[pop].mu.Z.T) for pop in ksfs_dict}

plt.figure(figsize=(4, 2))
for i in range(embedding.n_components):
    plt.plot(time, embedding.components_[i], label=f'latent history {i + 1}')
plt.xlabel('$t$')
plt.xscale('symlog')
plt.legend(loc='center left', prop={'size': 7.5}, framealpha=.5)
plt.show()

In [None]:
def facet_heatmap(data, color, **kwargs):
    """Draw one facet: a 5'-by-3' heatmap of the 'value' column of `data`.

    `color` is accepted because FacetGrid.map_dataframe passes it; it is unused.
    Remaining keyword arguments are forwarded to sns.heatmap.
    """
    data = data.pivot(index="5'", columns="3'", values='value')
    sns.heatmap(data, **kwargs).invert_yaxis()


# For each population, show the per-mutation-type loading on every latent history
# as a grid of heatmaps (rows = mutation, columns = latent history).
for pop in ksfs_dict:
    print(pop)
    # components x mutation types
    W = embedding.transform(ksfs_dict[pop].mu.Z.T).T

#     # norms = nmf.components_.T.mean(0, keepdims=True)
#     H = nmf.components_.T# / norms
#     # W = norms.T * W

    # normalise each component's loadings to sum to one
    total_weight = W.sum(1, keepdims=True)
    W = W / total_weight

    df = pd.DataFrame(data=W.T,
                      index=ksfs_dict[pop].mu.mutation_types,
                      # the bare name n_components was never defined; read it from the fitted model
                      columns=range(1, embedding.n_components + 1))

    # shared colour limits across all facets
    df_min = df.values.min()
    df_max = df.values.max()

    # split each triplet label into 5' flank, centre mutation and 3' flank
    df["5'"] = df.index.str[0]
    df['mutation'] = df.index.str[1].str.cat(df.index.str[5], sep='→')
    df["3'"] = df.index.str[2]

    df = df.melt(id_vars=["5'", 'mutation', "3'"], var_name='latent history')

    g = sns.FacetGrid(df, row='mutation', col='latent history',
                      row_order=('C→A', 'C→G', 'C→T', 'A→G', 'A→C', 'A→T'),
                      margin_titles=True,
                      height=1.5  # `size` is the deprecated spelling of this keyword
                      )

    # colorbar axes
    cbar_ax = g.fig.add_axes([1.1, .3, .05, .4])

    g = g.map_dataframe(facet_heatmap,
                        cbar_ax=cbar_ax,
                        cmap='RdBu_r',
                        center=0,
                        vmin=df_min, vmax=df_max
                        )

    # so the colorbar doesn't overlap the plot
    # g.fig.subplots_adjust(right=.9)
#     plt.savefig('/Users/williamdewitt/Downloads/PC_heatmap.pdf')
    plt.show()

In [None]:
# Scatter of the 2nd vs 3rd embedding coordinates, one colour per population.
# Exactly two columns are selected: with more than three components, `[:, 1:]`
# would unpack extra positional arguments into plt.scatter.
plt.figure()
for pop in ksfs_dict:
    plt.scatter(*Z_transform_dict[pop][:, 1:3].T, alpha=0.8)
plt.xlabel('latent history 2')
plt.ylabel('latent history 3')
plt.show()

In [None]:
# Long-format table with one row per (triplet, population), holding the 2nd and
# 3rd embedding coordinates of that mutation type.
x = []
y = []
derived = []
context = []
# not named `pops`: that would overwrite the population-manifest dict used by
# the earlier cells and break them on re-run
populations = []

for i, triplet in enumerate(sorted_triplets):
    for pop in ksfs_dict:
        xy = Z_transform_dict[pop][i, 1:]
        x.append(xy[0])
        y.append(xy[1])
        populations.append(pop)
        context.append(triplet[:3])   # ancestral triplet
        derived.append(triplet[5])    # derived centre base

df = pd.DataFrame({'population':populations, 'context':context, 'derived':derived, 'latent history 2':x, 'latent history 3':y})

In [None]:
# Grid of scatter plots: one panel per (derived base, ancestral context), coloured by population.
g = sns.relplot(x='latent history 2', y='latent history 3', row='derived', col='context', data=df, hue='population',
           height=2, aspect=1, alpha=.8).set_titles("{col_name}>{row_name}")
# save under plot_dir like the other figures, instead of a hard-coded absolute path
if plot_dir:
    g.savefig(f'{plot_dir}/AFR.signatures.pdf')
plt.show()

In [None]:
# df has no 'singlet' column; within a row (ancestral context) the derived base
# identifies the mutation, so it is used for the x axis
sns.catplot(y='latent history 3', x='derived', row='context', data=df, hue='population')
plt.show()