In [None]:
from itertools import chain

import pandas as pd
import sqlite3
from sklearn.cross_decomposition import PLSCanonical
from sklearn.mixture import BayesianGaussianMixture
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
from matplotlib import animation
from IPython.display import HTML

## Constants

In [None]:
# Plot colors keyed by metadata label. The 'UM'/'UT' entries are looked up by
# `site` when coverage traces are drawn; the remaining keys are presumably
# treatment, sex and cohort labels -- TODO confirm against the metadata tables.
color_map = dict([
    ('acarbose', 'goldenrod'),
    ('control', 'darkblue'),
    ('UM', 'darkblue'),
    ('UT', 'darkgreen'),
    ('male', 'blue'),
    ('female', 'magenta'),
    ('C2013', 'blue'),
    ('Glenn', 'red'),
])

## Load Data

In [None]:
# Open the project database; `con` is reused by later cells for further queries.
con = sqlite3.connect('res/core.1.denorm.db')

# Relative abundance: one row per extraction, one column per sequence.
# Missing (extraction, sequence) pairs are treated as a count of zero and every
# row is scaled to sum to 1.
rrs_count = (pd.read_sql('SELECT * FROM rrs_taxon_count;',
                         con=con, index_col=['extraction_id', 'sequence_id'])
               .tally.unstack().fillna(0).astype(int))
rabund = rrs_count.div(rrs_count.sum(axis=1), axis=0)

# Coverage: one row per extraction, one column per bin, summed over the
# libraries of each extraction and scaled so every row sums to 1.
bin_cvrg = (pd.read_sql("""
SELECT bin_id, extraction_id, SUM(coverage) AS coverage
FROM bin_coverage
JOIN library USING (library_id)
GROUP BY bin_id, extraction_id;""",
                        con=con, index_col=['extraction_id', 'bin_id'])
              .coverage.unstack().fillna(0)
              .pipe(lambda df: df.div(df.sum(axis=1), axis=0)))

# Only keep shared extractions.
# A set cannot be used as a `.loc` indexer in current pandas and has no defined
# order, so filter `rabund` with a boolean mask (keeping its row order) and then
# reorder `bin_cvrg` to match; later cells rely on the two tables being
# row-aligned.
extractions = set(rabund.index) & set(bin_cvrg.index)
rabund = rabund.loc[rabund.index.isin(extractions)]
bin_cvrg = bin_cvrg.loc[rabund.index]

# Phylotypes: map every sequence to its OTU and give it a readable name of the
# form "<otu_id>_<rank>", where rank 1 is the sequence with the highest mean
# relative abundance inside that OTU.
phylotype = pd.read_sql('SELECT sequence_id, otu_id FROM rrs_taxon_count GROUP BY sequence_id;',
                       con=con, index_col='sequence_id')

_ranked_by_abundance = (
    pd.DataFrame({'mean_rabund': rabund.mean(),
                  'otu_id': phylotype.otu_id})
      .sort_values('mean_rabund', ascending=False)
)
name_map = {
    sequence_id: '{}_{}'.format(otu_label, rank)
    for otu_label, members in _ranked_by_abundance.groupby('otu_id')
    for rank, sequence_id in enumerate(members.index, start=1)
}
phylotype['name'] = pd.Series(name_map)
phylotype['mean_rabund'] = rabund.mean()

# Contig -> bin assignment table, indexed by contig.
contig_bin = pd.read_sql("SELECT * FROM contig_bin", con=con, index_col='contig_id')

## OTU Details

In [None]:
# Taxonomic ranks per sequence, re-indexed by the readable phylotype name.
_taxonomy_raw = pd.read_sql(
    'SELECT sequence_id, phylum_, class_, order_, family_, genus_ FROM taxonomy;',
    con=con,
    index_col='sequence_id',
)
taxonomy = _taxonomy_raw.rename(index=phylotype.name)

In [None]:
# Select abundant taxa and bins.
# The cut-offs are named so they can be tuned in one place; everything at or
# below a cut-off is pooled into a single 'other' column so rows still sum to 1.
MIN_MEAN_RABUND = 0.0001  # minimum mean relative abundance for a taxon to be kept
MIN_MEAN_CVRG = 0.0001    # minimum mean relative coverage for a bin to be kept

major_taxa = phylotype.index[phylotype.mean_rabund > MIN_MEAN_RABUND]
major_bins = bin_cvrg.columns[bin_cvrg.mean() > MIN_MEAN_CVRG]

d_rabund = rabund[major_taxa].copy()
d_rabund['other'] = rabund.drop(columns=major_taxa).sum(axis=1)
# Relabel sequence ids with phylotype names (not in place, so re-running the
# cell always starts from a fresh frame).
d_rabund = d_rabund.rename(columns=phylotype.name)

d_cvrg = bin_cvrg[major_bins].copy()
d_cvrg['other'] = bin_cvrg.drop(columns=major_bins).sum(axis=1)

d_rabund.shape, d_cvrg.shape

In [None]:
# Mean relative abundance plus taxonomy for two phylotypes of interest.
(
    d_rabund.mean()
            .to_frame(name='mean_rabund')
            .join(taxonomy)
            .sort_values('mean_rabund', ascending=False)
            .loc[['Otu0058_1', 'Otu0041_1']]
)

In [None]:
# Mean relative abundance plus taxonomy for every retained column, most
# abundant first.
(
    d_rabund.mean()
            .to_frame(name='mean_rabund')
            .join(taxonomy)
            .sort_values('mean_rabund', ascending=False)
)

## Metabinning

Find combinations of bins that are likely to have contigs from the same genome.

PLSCanonical is closely related to canonical correlation analysis (CCA): two different matrices
are cross-decomposed into a shared set of latent components.
We're modeling both the relative abundance of each OTU and the coverage of each bin as different
transformations of the same underlying latent space (actual genome density).
We then use the cross-product of the two loading vectors to get an aggregate estimate of the
relationship between bins and OTUs.
The key parameter in this machine-learning approach is **n_components**, in other words the
complexity of the underlying latent state.
With a small value of **n_components** similar OTUs will get more and more similar, potentially
introducing noise into the metabin.

In [None]:
# Cross-decompose square-root-transformed bin coverage (X) against
# square-root-transformed OTU relative abundance (Y).
_sqrt_cvrg = d_cvrg.apply(np.sqrt)
_sqrt_rabund = d_rabund.apply(np.sqrt)
pls_fit = PLSCanonical(scale=False, n_components=40).fit(_sqrt_cvrg, _sqrt_rabund)

# Bin-by-OTU score: product of the X loadings with the transposed Y loadings.
_loading_product = pls_fit.x_loadings_ @ pls_fit.y_loadings_.T
bin_otu_contrib = (
    pd.DataFrame(_loading_product, index=d_cvrg.columns, columns=d_rabund.columns)
      .rename(columns=phylotype.name)
)

Now we'll plot the most important results.
The bin hits for a given taxon are identified as the bin with the highest "contribution score" in that column,
along with all other bins with a score greater than some fraction (**factor**) of that score.

In [None]:
# Keep taxa whose 95th-percentile relative abundance across extractions exceeds 1%.
tax_filter = lambda x: x.quantile(0.95) > 0.01

taxa_of_interest = sorted(d_rabund.loc[:, tax_filter].rename(columns=phylotype.name).columns)
if 'other' in taxa_of_interest:
    # 'other' is the pooled remainder, not a taxon.
    taxa_of_interest.remove('other')
print(len(taxa_of_interest))

# A bin counts as a hit for a taxon when its score exceeds this fraction of the
# taxon's best score.
factor = 1/3

_hits = {}
for tax in taxa_of_interest:
    tax_scores = bin_otu_contrib[tax].sort_values(ascending=False)
    top_score = tax_scores.max()
    _hits[tax] = list(tax_scores[tax_scores > top_score * factor].index)

all_hits = set(chain.from_iterable(_hits.values()))

# Select the hit bins with an order-preserving index rather than the set itself:
# sets are not valid `.loc` indexers in current pandas and have no defined order.
hit_bins = bin_otu_contrib.index[bin_otu_contrib.index.isin(all_hits)]
for_plotting = (bin_otu_contrib.loc[hit_bins, taxa_of_interest]
                               .rename(columns=phylotype.name)
                               .clip(lower=0))  # negative scores are shown as zero
a = sns.clustermap(for_plotting, robust=True,
                   figsize=(14, 18), col_cluster=False, cmap='Reds')

# Address the heatmap axes by name instead of by position in the figure's axes list.
ax = a.ax_heatmap
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

print()

## Metabin Refinement

Now I need to do some filtering of the contigs in the metabins that were grouped together.
That's because the binning was inherently noisy and may include sequence from other genomes.
The challenge is figuring out how to identify contigs to be removed.
I can't just go on overall coverage because I may get false negatives: for instance, sequences that are only present in some strains, sequences that are misassembled, or sequences at higher copy numbers (e.g. 16S).
I might also get some false positives,
sequences that coincidentally have mean coverage close to the expected level.
I also can't use raw coverage covariance across libraries
(although this would fix many of the above problems), because I'll get false negatives
for genome components that are not universally present (the "accessory genome").

The approach I'm actually going to take will look for clusters of contigs that behave similarly
across samples.  That way I'm not saying that everything will behave like the core genome,
but I'm also relying on coverage covariance as a powerful indicator of shared genome membership.

I will then filter these groups of contigs based on how much they look like some other organism.
Looking like some other organism means that they have large swings in coverage across samples
while still having consistent coverage across contigs.

In terms of a statistic to measure this, I'll use the ratio of standard-deviation of mean contig abundance
(higher when different samples look very different)
to the mean standard-deviation of contig abundance (lower when the group of contigs are all acting the same
in each sample).
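Since this has a tendency to exclude contig groups that don't have a lot of sequence (and are therefore
noisier), I'll adjust this ratio by multiplying it by the log total contig length.
(Note: the code below currently divides by the *square* of the mean standard deviation and has the
length adjustment commented out.)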

### OTU-1

Grab the set of bins that were associated with any of the top four Otu0001 phylotypes.

In [None]:
# A bin is kept for a phylotype when its score exceeds this fraction of that
# phylotype's best score.
keep_thresh_factor = 1/3
otus = ['Otu0001_1', 'Otu0001_2', 'Otu0001_3', 'Otu0001_4']

# Union of the bins kept for each of the listed phylotypes.
bins = set()
for otu in otus:
    otu_scores = bin_otu_contrib[otu]
    max_contrib = otu_scores.max()
    bins.update(otu_scores.index[otu_scores > max_contrib * keep_thresh_factor])

print(bins)

# Every contig assigned to one of those bins.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

Extract the necessary data and metadata.

In [None]:
# Render the contig ids as a comma-separated list of SQL string literals.
# Single quotes (with embedded quotes doubled) are the standard string-literal
# syntax; double-quoted tokens are identifiers in SQL and only work as strings
# through a SQLite compatibility fallback.
contig_ids_sql = ', '.join("'{}'".format(str(contig_id).replace("'", "''"))
                           for contig_id in contig_ids)

# Per-extraction coverage of each selected contig (summed over libraries),
# extractions as rows and contigs as columns; missing pairs become 0.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined with sample and mouse records, plus the total
# mapping count per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig-level columns for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Divide each extraction's coverage by its mapping count and restrict/reorder
# the rows to the extractions present in `d_rabund`.
cvrg = cvrg.div(extraction_meta.mapping_count, axis=0).loc[d_rabund.index]

I need to normalize the coverages since each sample has different amounts of
each genome.
My approach to normalization is to normalize to the coverage of sequence that is
ubiquitous and in a single copy in all genomes (i.e. "core" sequence).
The result is a good estimate of copies per core-genome, which could be
greater than 1 (multiple copies) or less than 1 (e.g. the contig is misassembled),
but is frequently very close to 1.

I estimate this coverage by finding contigs that are indisputably
in the "core genome" and normalize everything to the mean coverage of these contigs.
The way I find these "core" contigs is by picking a "seed" contig and then looking for
all of the contigs with coverages that are SUPER closely correlated with the seed (like r > 0.99).
How I pick the seed is more of an art.
Right now I'm picking the seed by finding the contig with the highest correlation to the
16S relative abundance.

In [None]:
# Coverage of one hand-picked contig (scaled by 1e6) against the summed relative
# abundance of the selected OTU phylotypes, one point per extraction.
_candidate_cvrg = cvrg['core-k161_641135'] * 1e6
_otu_total_rabund = d_rabund[otus].sum(1).loc[cvrg.index]
plt.scatter(_candidate_cvrg, _otu_total_rabund)

In [None]:
# Contigs whose coverage is most correlated with the 16S signal, i.e. the summed
# relative abundance of the selected OTU phylotypes.
# This cell previously correlated against `cvrg[seed]`, but `seed` is only
# assigned in a later cell, so it failed on a fresh kernel; the 16S signal is
# the reference described for choosing the seed.
rrs_signal = d_rabund[otus].sum(1).loc[cvrg.index]
cvrg.apply(lambda x: sp.stats.pearsonr(rrs_signal, x)[0]).sort_values().tail()

In [None]:
# Rank seed candidates: Pearson correlation of each contig's coverage with the
# 16S signal (summed relative abundance of the selected OTU phylotypes), shown
# alongside contig metadata for contigs longer than 10 kb.
# This cell previously correlated against `cvrg[seed]`, but `seed` is only
# assigned in a later cell, so it failed on a fresh kernel.
rrs_signal = d_rabund[otus].sum(1).loc[cvrg.index]
rrs_corr = cvrg.apply(lambda x: sp.stats.pearsonr(rrs_signal, x)[0]).sort_values()
rrs_corr.name = 'rrs_corr'
contig_meta.join(rrs_corr).sort_values('rrs_corr')[lambda x: x.length > 10000].tail()

In [None]:
# Seed contig used as the reference for the trusted contig set.
seed = 'core-k161_375476'
assert seed in contig_ids

# Sanity plot: seed coverage (scaled by 1e6) against the summed relative
# abundance of the selected OTU phylotypes.
plt.scatter(cvrg[seed]*1e6, d_rabund[otus].sum(1).loc[cvrg.index])

# Trusted contigs: Pearson r with the seed's coverage above 0.99.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > 0.99].index

# Trusted extractions: mean coverage of the trusted contigs is more than twice
# their standard deviation within that extraction.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > 2].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Normalize every extraction by the mean coverage of its trusted contigs, then
# compare the distribution of log mean normalized coverage per contig before
# and after dropping the untrusted extractions.
_hist_bins = np.linspace(-4, 6)

cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)

cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)

plt.legend()

In [None]:
# Square-root transform of the normalized coverage; the mixture model below is
# fit on the transpose, so each contig is one observation.
cluster_data = np.sqrt(cvrg_norm)

# Fit a mixture for each candidate component count and record the model score
# on the same data.
nn = range(1, 40)
scores = [
    # n_components is passed by keyword: recent scikit-learn releases no longer
    # accept it positionally.
    BayesianGaussianMixture(n_components=n,
                            covariance_type='diag',
                            random_state=1,
                           ).fit(cluster_data.T).score(cluster_data.T)
    for n in nn
]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by their coverage profile across extractions.
# n_components is passed by keyword: recent scikit-learn releases no longer
# accept it positionally.
bgm = BayesianGaussianMixture(n_components=20,
                              covariance_type='diag',
                              random_state=1,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Group the contigs (rows of the transposed table) by cluster once and reuse it.
# This is equivalent to grouping the columns of `cvrg_norm`, without relying on
# the deprecated `axis='columns'` argument of `DataFrame.groupby`.
_by_group = cvrg_norm.T.groupby(group_assign)
_group_mean = _by_group.mean()  # groups x extractions

# Per-group summary statistics (one row per group).
group_cvrg = _group_mean.mean(axis=1).to_frame(name='group_mean_mean_coverage')
# Mean over extractions of the between-contig standard deviation.
group_cvrg['group_mean_std_coverage'] = _by_group.std().mean(axis=1)
# Standard deviation over extractions of the group's mean coverage.
group_cvrg['group_std_mean_coverage'] = _group_mean.std(axis=1)
group_cvrg['group_max_coverage'] = _by_group.max().max(axis=1)
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
# Between-extraction spread of the group mean divided by the squared
# within-group spread.
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage**2
group_cvrg.index.name = 'group'

# One row per contig: its group, the group statistics, and contig metadata,
# ordered by contamination score and then by decreasing length.
group_assign = group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign = group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False])

# Overview figure: one line per extraction (colored by site) across all contigs
# in `group_assign` order, with group boundaries and labels overlaid.
fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

_contig_order = group_assign.index
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[_contig_order].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every contig along the x axis, and the span each group occupies.
group_assign['contig_index'] = range(group_assign.shape[0])


def _index_span(positions):
    """Return the middle, left and right x positions of one group's contigs."""
    return pd.Series({'middle': positions.mean(),
                      'left': positions.min(),
                      'right': positions.max()})


group_order = (group_assign.groupby('group').contig_index
                           .apply(_index_span)
                           .unstack()
                           .sort_values('left'))

# Groups scoring above this value (or with an undefined score) are treated as
# contamination: they get a thin red boundary and no label.
contam_threshold = 20
_label_height = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx].contamination_score
    if score > contam_threshold or pd.isna(score):
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, _label_height), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, _label_height * 0.5),
                ha='center', rotation=-90)

# NOTE(review): `linthreshy` is believed to be the older matplotlib spelling of
# this option (newer releases use `linthresh`) -- confirm against the installed
# matplotlib version.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Total contig length per bin, split by retained group, with one extra 'contam'
# column pooling the groups that were rejected (score above the threshold or
# undefined).
_score = group_cvrg.contamination_score
_kept_groups = group_cvrg.index[_score < contam_threshold]
_rejected_groups = group_cvrg.index[(_score > contam_threshold) | _score.isna()]

a = (group_assign[group_assign.group.isin(_kept_groups)]
         .groupby(['bin_id', 'group']).length.sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_rejected_groups)]
         .groupby('bin_id').length.sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions in each site/treatment combination.
_site_treatment = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_site_treatment).count().iloc[:, 0])

# Mean normalized coverage of every retained group within each site/treatment
# combination. Contigs are grouped on the transposed table because the
# `axis='columns'` argument of `DataFrame.groupby` is deprecated.
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby(_site_treatment).mean()).loc[:, _retained_groups]

In [None]:
# Animated figure: all extractions are drawn faintly over the retained contigs
# and each frame highlights one extraction in turn.
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')
artists = []
plotting_order = []
included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    # One line per extraction in this site, in the row order of `d0`.
    des_artists = ax.plot(d0[included_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)

# Remember each line's resting style so it can be restored after highlighting.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed relative abundance of the selected OTU phylotypes, in plotting order.
otu_rabund = list(d_rabund.loc[plotting_order, otus].sum(axis=1))

# Filter with a callable so the mask is computed on the already re-ordered
# frame; a mask taken from `group_cvrg` directly is indexed in a different
# order and has to be silently realigned by pandas.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

_label_height = group_cvrg_included.group_max_coverage.max()
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, _label_height), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, _label_height * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE(review): `linthreshy` is believed to be the older matplotlib spelling of
# this option (newer releases use `linthresh`) -- confirm against the installed
# matplotlib version.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line artist as the initial set of animated artists."""
    return artists


def _animate(i):
    """Highlight line `i` and restore line `i - 1` to its original style.

    For `i == 0` the restored line is the last one (negative indexing). The
    annotation is updated with the extraction id and its OTU relative abundance
    as a percentage. Returns the artists that changed.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=cvrg_norm.shape[0], interval=200, blit=True)

In [None]:
# Render the animation to an HTML5 video tag and display it inline.
_video_markup = anim.to_html5_video()
HTML(_video_markup)

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import SymLogNorm

# Colormap anchors as RGB triples: white followed by three progressively
# brighter greens.
colors = [
    (1, 1, 1),
    (0, 0.3, 0),
    (0, 0.7, 0),
    (0, 1, 0),
]
n_bins = [3, 6, 10, 100]  # not used in this cell
cmap_name = 'custom1'
# Interpolate between the anchors with 100 discrete levels.
cm = LinearSegmentedColormap.from_list(name=cmap_name, colors=colors, N=100)

In [None]:
# Groups to exclude by hand in addition to the contamination filter.
drop_groups = []  #[0, 12, 5]

# Median normalized coverage per retained group (columns) and extraction (rows).
# Contigs are grouped on the transposed table because the `axis='columns'`
# argument of `DataFrame.groupby` is deprecated.
_retained_groups = group_cvrg[lambda x: (x.contamination_score < contam_threshold) &
                                        ~x.index.isin(drop_groups)].index
d = cvrg_norm.T.groupby(group_assign.group).median().T.loc[:, _retained_groups]

vmin, vmax, cmap = 0, 8, cm
# Put the color limits on the norm itself so they apply whether or not the
# plotting stack forwards vmin/vmax when a norm is supplied.
norm = SymLogNorm(linthresh=1, linscale=0.9, vmin=vmin, vmax=vmax)

# Align the site labels to the rows of `d` before building boolean masks.
_site = extraction_meta.site.reindex(d.index)
for _site_name in ['UM', 'UT']:
    sns.clustermap(d[_site == _site_name], vmin=vmin, vmax=vmax,
                   col_cluster=False, robust=True, cmap=cmap, norm=norm)


In [None]:
# Write the ids of all contigs in the hand-picked groups, one per line.
# NOTE(review): group 17 appears twice in this list; harmless for `isin`, but
# confirm a different group was not intended.
_selected_groups = [1, 3, 7, 13, 14, 17, 2, 6, 8, 15, 17, 19]
_selected_contigs = group_assign.index[group_assign.group.isin(_selected_groups)]

with open('res/core.a.mags.d/OTU-1.contigs.list', 'w') as handle:
    for contig_id in _selected_contigs:
        handle.write('{}\n'.format(contig_id))