In [None]:
import matplotlib as mpl
from sklearn.decomposition import PCA
from sklearn.mixture import BayesianGaussianMixture
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans


#!/usr/bin/env python3
import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np
import theano.tensor as tt
import sys

def tt_simplex_normalize(x):
    """Rescale each row of a 2-D tensor/array so that the row sums to one."""
    row_totals = x.sum(1).reshape((x.shape[0], 1))
    return x / row_totals


def tt_harmonic_mean(x):
    """Harmonic mean over all elements of a Theano tensor."""
    return 1 / tt.mean(1 / x)


def tt_generalized_mean(x, r=1):
    """Power (generalized) mean of order ``r`` over all elements of a Theano tensor."""
    return tt.mean(x ** r) ** (1 / r)


def info(s, *args):
    """Write the %-formatted message ``s % args`` to stderr."""
    print(s % args, file=sys.stderr)

In [None]:
# Input tables, read in the next section:
#   seq_cvrg_path -- (sample_id, sequence_id, tally) per-sample sequence coverage
#   nlength_path  -- (contig_id, nlength) sequence lengths
#   tax_cvrg_path -- (sample_id, taxon_id, tally) per-sample taxon counts
seq_cvrg_path = 'data/sim/gut1_ecoli_conspecific.n4e6.z32.s00.a.backmap.cvrg.tsv'
nlength_path = 'data/sim/gut1_ecoli_conspecific.n4e6.z32.s00.a.nlength.tsv'
tax_cvrg_path = 'data/sim/gut1_ecoli_conspecific.n4e6.z32.s00.a.tax_counts.tsv'

# Filtering thresholds.
minlen = 200  # keep only sequences strictly longer than this
mintaxrabund = 0  # keep taxa whose max relative abundance (over samples) exceeds this
# Model hyper-parameters.
strain_reg_param1 = 1/3  # order of the generalized mean in the strain-abundance penalty
strain_reg_param2 = 10  # weight of that penalty
tax_fuzz_param = 1e-6  # added to every expected taxon fraction
seq_fuzz_param = 1e-6  # added to every expected cluster fraction
tax_uncertainty_param = 0.01  # total Dirichlet concentration of each strain's taxon profile

info("Starting latent strain analysis.")
for _fmt, _value in (
    ("seq_cvrg_path: %s", seq_cvrg_path),
    ("nlength_path: %s", nlength_path),
    ("tax_cvrg_path: %s", tax_cvrg_path),
    ("minlen: %f", minlen),
    ("mintaxrabund: %f", mintaxrabund),
    ("strain_reg_param1: %f", strain_reg_param1),
    ("strain_reg_param2: %f", strain_reg_param2),
    ("tax_fuzz_param: %f", tax_fuzz_param),
    ("seq_fuzz_param: %f", seq_fuzz_param),
    ("tax_uncertainty_param: %f", tax_uncertainty_param),
):
    info(_fmt, _value)

# Load data.
# `.squeeze('columns')` turns each single-value-column frame into a Series; it
# replaces the `squeeze=True` keyword of `read_table`, which is deprecated since
# pandas 1.4 and removed in pandas 2.0.
_seq_cvrg = (
    pd.read_table(seq_cvrg_path,
                  names=['sample_id', 'sequence_id', 'tally'],
                  index_col=['sample_id', 'sequence_id'])
    .squeeze('columns')
    .unstack('sequence_id', fill_value=0)
)
_seq_len = (
    pd.read_table(nlength_path,
                  names=['contig_id', 'nlength'], index_col=['contig_id'])
    .squeeze('columns')
)
_tax_count = (
    pd.read_table(tax_cvrg_path,
                  names=['sample_id', 'taxon_id', 'tally'],
                  index_col=['sample_id', 'taxon_id'])
    .squeeze('columns')
    .unstack('taxon_id', fill_value=0)
)

# Align tables and drop low coverage dimensions.
# Keep only sequences longer than `minlen`; sequences that are absent from the
# coverage table become all-NaN columns after reindexing and are dropped.
seq_cvrg = _seq_cvrg.reindex(columns=_seq_len[_seq_len > minlen].index).dropna(axis='columns')
seq_len = _seq_len.loc[seq_cvrg.columns]
# Restrict the taxon table to the samples that have sequence coverage *before*
# computing relative abundances, so the abundance filter is not driven by
# samples that are discarded anyway.
_tax_count_aligned = _tax_count.loc[seq_cvrg.index]
_tax_max_rabund = _tax_count_aligned.divide(_tax_count_aligned.sum(1), axis=0).max()
tax_count = _tax_count_aligned.loc[:, _tax_max_rabund > mintaxrabund]
# Scale to nucleotide counts (coverage x length).
seq_count = seq_cvrg.multiply(seq_len).round().astype(int)

# Later cells pair these tables row-by-row through `.values`, so enforce the
# alignment that the steps above are meant to guarantee.
assert (seq_count.index == tax_count.index).all(), "Sequence and taxon table indices must be aligned."
assert (seq_len.index == seq_count.columns).all(), "Sequence and seq_len tables must be aligned."

# Problem dimensions at the individual-sequence level.
n_samples, g_seqs = seq_count.shape
t_taxa = len(tax_count.columns)
# Allow 50% more latent strains than observed taxa.
s_strains = int(np.ceil(t_taxa * 1.5))
# Per-sample total counts.
u_tax_counts, v_seq_counts = (
    frame.sum(axis=1).round().astype(int) for frame in (tax_count, seq_count)
)

for _fmt, _value in (
    ("n_samples: %d", n_samples),
    ("g_seqs: %d", g_seqs),
    ("t_taxa: %d", t_taxa),
    ("s_strains: %d", s_strains),
    ("mean u_tax_counts: %f", u_tax_counts.mean()),
    ("mean v_seq_counts: %f", v_seq_counts.mean()),
):
    info(_fmt, _value)

In [None]:
# Per-contig k-mer composition (the `.k4` file suffix suggests k=4 -- confirm),
# read from a (contig_id, kmer, tally) table and pivoted to contigs x k-mers.
kmer_path = 'data/sim/gut1_ecoli_conspecific.n4e6.z32.s00.a.k4.tsv'
_seq_kmer = (
    pd.read_table(kmer_path, names=['contig_id', 'kmer', 'tally'],
                  index_col=['contig_id', 'kmer'])
    .squeeze('columns')  # replaces the deprecated/removed read_table(squeeze=True)
    .unstack(fill_value=0)
)
# Contigs without any k-mer rows get all-zero counts (a uniform composition after
# the pseudocount) instead of NaN rows, which PCA in the next cell rejects.
seq_kmer = (
    _seq_kmer.reindex(seq_len.index, fill_value=0)
    .add(1)  # pseudocount so the log below is finite
    .pipe(lambda counts: counts.divide(counts.sum(1), axis=0))  # per-contig frequencies
    .pipe(np.log)
    .apply(lambda x: (x - x.mean()) / x.std())  # standardize each k-mer column
)

In [None]:
# Keep the leading principal components of k-mer composition that together
# explain at least `thresh` of the variance.
thresh = 0.90

pca = PCA().fit(seq_kmer)
cum_var = np.cumsum(pca.explained_variance_ratio_)
# Number of components needed to reach `thresh`, clipped so floating-point
# round-off in the cumulative sum can never request more components than exist.
ncomps = min(int((cum_var < thresh).sum()) + 1, len(cum_var))

fig, ax = plt.subplots()
ax.plot(cum_var)
ax.axhline(thresh)
# The x axis is the 0-based component index, so the ncomps-th component is at ncomps - 1.
ax.axvline(ncomps - 1)
ax.set(xlabel='principal component index', ylabel='cumulative explained variance ratio')

pcs = pca.transform(seq_kmer)
fig, ax = plt.subplots()
# Plot the first two PCs from the full projection so this works even if ncomps == 1.
ax.scatter(pcs[:, 0], pcs[:, 1], s=0.5)
ax.set(xlabel='PC0', ylabel='PC1')

kmer_coords = pd.DataFrame(pcs[:, :ncomps], index=seq_kmer.index, columns=[f'PC{i}' for i in range(ncomps)])

In [None]:
# Standardized log-coverage with contigs as rows and samples as columns. Adding
# 100 / seq_len to coverage is equivalent to 100 extra nucleotides per contig,
# which keeps the log finite.
_cvrg_pseudo = 100 / seq_len
cvrg_coords = (
    np.log((seq_cvrg + _cvrg_pseudo).T)
    .apply(lambda col: (col - col.mean()) / col.std())
)
plt.scatter(
    cvrg_coords['gut1_ecoli_conspecific.n4e6.s0000'],
    cvrg_coords['gut1_ecoli_conspecific.n4e6.s0001'],
    s=0.1,
)

In [None]:
# Clustering features per contig: standardized log-coverage columns followed by k-mer PCs.
raw_coords = cvrg_coords.join(kmer_coords, how='left')

In [None]:
# Over-cluster contigs at roughly `clusters_per_strain` clusters per latent strain,
# capped at the number of contigs (KMeans needs n_clusters <= n_samples).
clusters_per_strain = 10
kmeans_seed = 0  # fixed initialization so the clustering is reproducible
km = KMeans(
    n_clusters=min(s_strains * clusters_per_strain, len(raw_coords)),
    verbose=2,
    n_init=1,
    random_state=kmeans_seed,
).fit(raw_coords)

In [None]:
# Cluster label of every contig, indexed by contig id.
_clust_labels = km.predict(raw_coords)
seq_clust = pd.Series(data=_clust_labels, index=seq_len.index)

In [None]:
# Pairwise views of standardized per-sample log-coverage, colored by cluster.
sample_prefix = 'gut1_ecoli_conspecific.n4e6.s'
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for ax, (i, j) in zip(axs.flatten(), [(1, 2), (3, 4), (5, 6), (7, 8)]):
    # Zero-pad to 4 digits: identical to 's000{i}' for i < 10, still correct for i >= 10.
    x_col, y_col = f'{sample_prefix}{i:04}', f'{sample_prefix}{j:04}'
    ax.scatter(raw_coords[x_col], raw_coords[y_col],
               s=0.1, c=seq_clust, cmap=mpl.cm.gist_rainbow)
    ax.set(xlabel=x_col, ylabel=y_col)

In [None]:
# Aggregate per-contig nucleotide counts and lengths into per-cluster totals.
# Transpose-group-transpose instead of `groupby(..., axis=1)`, which is
# deprecated since pandas 2.1; grouping aligns seq_clust on contig labels.
clust_count = seq_count.T.groupby(seq_clust).sum().T
clust_len = seq_len.groupby(seq_clust).sum()

# Mean per-nucleotide coverage of each cluster in each sample (samples x clusters).
clust_cvrg = clust_count.divide(clust_len)

In [None]:
# Five clusters with the largest total sequence length.
clust_len.sort_values(ascending=False).iloc[:5]

In [None]:
# Rank plot of cluster total lengths, smallest to largest.
plt.plot(np.sort(clust_len.values))

In [None]:
# Dimensions of the cluster-level model (clusters replace individual sequences).
n_samples, g_clusts = clust_count.shape
t_taxa = len(tax_count.columns)
# Allow 50% more latent strains than observed taxa.
s_strains = int(np.ceil(t_taxa * 1.5))
# Per-sample totals; these are the multinomial trial counts in the model below.
u_tax_counts, v_clust_counts = (
    tbl.sum(axis=1).round().astype(int) for tbl in (tax_count, clust_count)
)

for _fmt, _value in (
    ("n_samples: %d", n_samples),
    ("g_clusts: %d", g_clusts),
    ("t_taxa: %d", t_taxa),
    ("s_strains: %d", s_strains),
    ("mean u_tax_counts: %f", u_tax_counts.mean()),
    ("mean v_clust_counts: %f", v_clust_counts.mean()),
):
    info(_fmt, _value)

# Latent-strain model: per-sample strain abundances (pi), a taxon profile per
# strain (theta), and non-negative per-strain cluster loadings (phi). Observed
# taxon counts and cluster nucleotide counts are each modeled as multinomials.
with pm.Model() as model:
    # Latent strain abundances
    # NOTE(review): with shape (n_samples, s_strains), each row (sample) is presumably
    # its own Dirichlet draw over strains -- confirm against the PyMC3 version in use.
    pi = pm.Dirichlet('pi', a=np.ones(s_strains), shape=(n_samples, s_strains))
    # Adds -strain_reg_param2 * (power mean of order strain_reg_param1 over all
    # entries of pi) to the model log-probability. With an order < 1 this
    # presumably favors sparse strain usage -- confirm intent.
    pi_reg = pm.Potential('pi_reg', -strain_reg_param2 * tt_generalized_mean(pi, strain_reg_param1))


    # Strains to taxa
    # Each strain's taxon profile; total concentration tax_uncertainty_param is
    # split evenly over taxa, so small values favor profiles on few taxa.
    theta = pm.Dirichlet('theta',
                        a=np.ones(t_taxa) * tax_uncertainty_param / t_taxa,
                        shape=(s_strains, t_taxa))
    # Unnormalized expected taxon fractions; the fuzz keeps every entry > 0.
    # NOTE(review): p is passed unnormalized, which relies on pm.Multinomial
    # normalizing p internally -- confirm for the pinned PyMC3 version.
    expect_tax_frac_unnorm = pi.dot(theta) + tax_fuzz_param
    obs_tax = pm.Multinomial('obs_tax',
                            n=u_tax_counts,
                            p=expect_tax_frac_unnorm,
                            shape=(n_samples, t_taxa),
                            observed=tax_count.values)

    # Strains to clusts
    # phi = pm.Exponential('phi', lam=1, shape=(s_strains, g_clusts))  # Too much mass at intermediate values
    # Weibull with shape 0.1: heavy mass near zero plus a long tail.
    phi = pm.Weibull('phi', alpha=0.1, beta=0.1, shape=(s_strains, g_clusts))
    # Expected nucleotide share per cluster: strain-mixed loading times cluster
    # length. clust_len (a pandas Series) is applied positionally across the
    # cluster axis; it shares key order with clust_count.columns because both
    # come from grouping by seq_clust.
    expect_clust_frac_unnorm = pi.dot(phi) * clust_len + seq_fuzz_param


    obs_clust = pm.Multinomial('obs_clust',
                            n=v_clust_counts,
                            p=expect_clust_frac_unnorm,
                            shape=(n_samples, g_clusts),
                            observed=clust_count.values)

    # Save intermediate values:
    # row-normalized expected fractions, recorded as named deterministic nodes.
    pm.Deterministic('expect_tax_frac', tt_simplex_normalize(expect_tax_frac_unnorm))
    pm.Deterministic('expect_clust_frac', tt_simplex_normalize(expect_clust_frac_unnorm))


# Fit the model with ADVI (`pm.fit` default method). A fixed seed makes the
# stochastic optimization reproducible across Restart & Run All.
advi_seed = 0
advi = pm.fit(int(1.5e5), model=model, random_seed=advi_seed)
info("Finished fitting model with ADVI.")
# Variational means, mapped back to a dict keyed by (transformed) variable name.
advi_mean = advi.bij.rmap(advi.mean.eval())

strain_names = [f's{i:03}' for i in range(s_strains)]

# Push each variable's transformed-space mean back through its transform to get
# point estimates on the constrained scale.
theta_est = model.theta.eval({model.theta_stickbreaking__: advi_mean['theta_stickbreaking__']})
phi_est = model.phi.eval({model.phi_log__: advi_mean['phi_log__']})
pi_est = model.pi.eval({model.pi_stickbreaking__: advi_mean['pi_stickbreaking__']})

info("DONE")


In [None]:
# Loss history recorded during the ADVI fit.
fig, ax = plt.subplots()
ax.plot(advi.hist)
ax.set(xlabel='iteration', ylabel='loss')

In [None]:
# Wrap the raw point-estimate arrays in labeled DataFrames; the unlabeled
# arrays stay available under underscore-prefixed names.
_theta_est, _phi_est, _pi_est = theta_est, phi_est, pi_est

theta_est = pd.DataFrame(_theta_est, columns=tax_count.columns, index=strain_names)  # strains x taxa
phi_est = pd.DataFrame(_phi_est, index=strain_names, columns=clust_count.columns)  # strains x clusters
pi_est = pd.DataFrame(_pi_est, index=clust_count.index, columns=strain_names)  # samples x strains

In [None]:
# Histogram of each strain's mean abundance across samples.
mean_strain_abund = pi_est.mean(axis='index')
_ = plt.hist(mean_strain_abund, bins=100)

In [None]:
# How much sequence per taxonomic category was contributed by each strain?
# For each taxon, weight every strain's estimated sequence content by how likely
# the strain is to belong to that taxon (theta) and by the strain's mean
# abundance across samples (pi).

# Estimated total sequence length of each strain: per-cluster loadings weighted
# by cluster length, matched by cluster label rather than by position.
strain_size = phi_est.multiply(clust_len, axis='columns').sum(1)

# contrib[s, t] = mean_over_samples(pi[:, s]) * theta[s, t] * strain_size[s];
# the sample mean factors out, so no per-taxon loop is needed.
strain_weight = (pi_est.mean(0) * strain_size).values
per_tax_contrib = theta_est.values * strain_weight[:, np.newaxis]  # strains x taxa

fig, ax = plt.subplots(figsize=(10, 20))
im = ax.imshow(per_tax_contrib.T, aspect='auto')
ax.set_yticks(range(tax_count.shape[1]))
ax.set_yticklabels(tax_count.columns)
ax.set_xlabel('strain index')
fig.colorbar(im, ax=ax)

In [None]:
# Strain x cluster loadings restricted to clusters with total length > 5000,
# clustered by cosine distance with per-row scaling (standard_scale=0).
sns.clustermap(phi_est.loc[:, clust_len.gt(5000)], metric='cosine', standard_scale=0)

In [None]:
# Heatmap of each strain's estimated taxon profile (strains x taxa).
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(theta_est, ax=ax)

In [None]:
# Strains with the five largest theta values for the Escherichia_coli_58110 taxon.
d = pd.DataFrame(theta_est, columns=tax_count.columns)
ecoli_theta = d['Escherichia_coli_58110']
ecoli_theta.sort_values(ascending=False).head(5)

In [None]:
# For selected strains: cumulative sequence length captured as clusters are added
# in order of decreasing loading (phi), plotted against that loading.
strains_to_plot = ['s118', 's064', 's073']

# Clusters x strains loadings, plus each cluster's total length ('nlength').
d0 = pd.DataFrame(phi_est, columns=clust_len.index).T.join(clust_len)

fig, ax = plt.subplots(figsize=(5, 4))
for strain in strains_to_plot:
    if strain not in d0.columns:
        info("Skipping %s: not among the %d fitted strains.", strain, s_strains)
        continue
    ranked = d0.sort_values(strain, ascending=False)
    # Pass arrays explicitly: string column names combined with `data=` can be
    # misread by matplotlib as a format string.
    ax.plot(ranked[strain].values, ranked['nlength'].cumsum().values, label=strain)

ax.set_xscale('log')
ax.set_xlim(1, 200)
ax.set_ylim(1, 1e7)
ax.set(xlabel='loading (phi)', ylabel='cumulative cluster length')
ax.legend()

In [None]:
# Per-contig strain loadings: each contig inherits the phi column of its cluster
# (rows: contigs, columns: strains). A single vectorized label lookup replaces a
# per-contig Series.apply that builds one Series per row.
seq_phi_est = pd.DataFrame(
    phi_est.T.loc[seq_clust.values].values,
    index=seq_clust.index,
    columns=phi_est.index,
)

In [None]:
# Natural-log loadings per contig, rounded to one decimal, written as TSV.
seq_phi_est.pipe(np.log).round(decimals=1).to_csv('test.tsv', sep='\t')

In [None]:
# Discretize loadings with np.digitize: <0 -> 0, [0,2) -> 1, [2,5) -> 2,
# [5,10) -> 3, >=10 -> 4, then write the codes as TSV.
# NOTE(review): this targets the same 'test.tsv' written by the log-loading
# export cell just before it, so running both leaves only these codes -- confirm intended.
bins = np.array([0, 2, 5, 10])
pd.DataFrame(
    np.digitize(seq_phi_est.values, bins),
    index=seq_phi_est.index,
    columns=seq_phi_est.columns,
).to_csv('test.tsv', sep='\t')

In [None]:
# Sorted log-loadings across strains for one contig.
_contig_log_phi = np.log(seq_phi_est.loc['k101_5668'].values)
plt.plot(np.sort(_contig_log_phi))