In [None]:
from itertools import chain

import pandas as pd
import sqlite3
from sklearn.cross_decomposition import PLSCanonical
from sklearn.mixture import BayesianGaussianMixture
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import scipy as sp
from matplotlib import animation
from IPython.display import HTML

## Constants

In [None]:
# Plot colours keyed by the metadata label being drawn. One dict serves every
# grouping used in the notebook (presumably treatment, site, sex and two other
# labels -- only the site keys 'UM'/'UT' are looked up in the cells below).
color_map = dict(
    acarbose='goldenrod',
    control='darkblue',
    UM='darkblue',
    UT='darkgreen',
    male='blue',
    female='magenta',
    C2013='blue',
    Glenn='red',
)

## Load Data

In [None]:
# Select the data

# One SQLite connection is opened here and reused by every later query cell.
con = sqlite3.connect('data/core.1.denorm.db')

# Relative Abundance
# Pivot the long (extraction_id, sequence_id) -> tally table into an
# extraction x sequence count matrix, then scale each row to sum to 1.
rrs_count = (pd.read_sql('SELECT * FROM rrs_taxon_count;',
                         con=con, index_col=['extraction_id', 'sequence_id'])
               .tally.unstack().fillna(0).astype(int))
rabund = rrs_count.apply(lambda x: x / x.sum(), axis=1)

# Coverage
# Bin coverage summed per (bin, extraction) across libraries, pivoted to an
# extraction x bin matrix and row-normalized the same way as the counts.
bin_cvrg = (pd.read_sql("""
SELECT bin_id, extraction_id, SUM(coverage) AS coverage
FROM bin_coverage
JOIN library USING (library_id)
GROUP BY bin_id, extraction_id;""",
                        con=con, index_col=['extraction_id', 'bin_id'])
              .coverage.unstack().fillna(0).apply(lambda x: x / x.sum(), axis=1))

# Only keep shared extractions
extractions = set(rabund.index) & set(bin_cvrg.index)
# Index with a sorted list, not the set itself: pandas >= 2.0 rejects sets as
# .loc indexers, and set iteration order for strings changes between kernel
# sessions, which made the row order (and anything downstream that depends on
# it) non-reproducible.
_shared_extractions = sorted(extractions)
rabund = rabund.loc[_shared_extractions]
bin_cvrg = bin_cvrg.loc[_shared_extractions]

# Phylotypes
phylotype = pd.read_sql('SELECT sequence_id, otu_id FROM rrs_taxon_count GROUP BY sequence_id;',
                       con=con, index_col='sequence_id')
# Name every sequence '<otu_id>_<rank>', where rank counts sequences within the
# OTU from the highest mean relative abundance (rank 1) downwards.
name_map = {}
for otu, d in (pd.DataFrame({'mean_rabund': rabund.mean(),
                             'otu_id': phylotype.otu_id})
                 .sort_values('mean_rabund',
                              ascending=False)
                 .groupby('otu_id')):
    for i, sequence_id in enumerate(d.index, start=1):
        name_map[sequence_id] = '{}_{}'.format(otu, i)
phylotype['name'] = pd.Series(name_map)
phylotype['mean_rabund'] = rabund.mean()

# contig_id -> bin membership table, used later to pull contigs for chosen bins.
contig_bin = pd.read_sql("SELECT * FROM contig_bin", con=con, index_col='contig_id')

In [None]:
# Library metadata table, indexed by library_id (the curation cells look up
# its extraction_id column).
library = pd.read_table(
    'meta/library.tsv',
    index_col='library_id',
)

## OTU Details

In [None]:
# Taxonomic ranks per sequence; the row labels are mapped from sequence_id to
# the phylotype display name.
taxonomy = (
    pd.read_sql(
        'SELECT sequence_id, phylum_, class_, order_, family_, genus_ FROM rrs_taxonomy;',
        con=con,
        index_col='sequence_id',
    )
    .rename(index=phylotype.name)
)

In [None]:
# Select abundant taxa and bins
# Thresholds on mean relative abundance / mean relative coverage. Taxa and bins
# at or below them are pooled into a single 'other' column.
min_mean_rabund = 0.0001
min_mean_cvrg = 0.0001

major_taxa = phylotype.index[phylotype.mean_rabund > min_mean_rabund]
major_bins = bin_cvrg.columns[bin_cvrg.mean() > min_mean_cvrg]

# Relative abundance restricted to the major taxa, with columns relabelled
# from sequence_id to phylotype name ('other' has no mapping and is kept).
d_rabund = rabund[major_taxa].copy()
d_rabund['other'] = rabund.drop(columns=major_taxa).sum(axis=1)
d_rabund = d_rabund.rename(columns=phylotype.name)

# Relative coverage restricted to the major bins.
d_cvrg = bin_cvrg[major_bins].copy()
d_cvrg['other'] = bin_cvrg.drop(columns=major_bins).sum(axis=1)

d_rabund.shape, d_cvrg.shape

In [None]:
# The 25 Muribaculaceae phylotypes with the highest mean relative abundance,
# shown alongside their taxonomy.
(d_rabund.mean()
         .to_frame(name='mean_rabund')
         .join(taxonomy)
         .sort_values('mean_rabund', ascending=False)
         .loc[lambda x: x.family_ == 'Muribaculaceae']
         .head(25))

## Metabinning

In [None]:
# Canonical PLS with 40 components between square-root transformed bin
# coverage (X) and square-root transformed phylotype relative abundance (Y).
_pls_x = np.sqrt(d_cvrg)
_pls_y = np.sqrt(d_rabund)
pls_fit = PLSCanonical(n_components=40, scale=False).fit(_pls_x, _pls_y)

# bins x phylotypes table: product of the X loadings with the transposed
# Y loadings, labelled by bin id and phylotype name.
bin_otu_contrib = pd.DataFrame(
    pls_fit.x_loadings_ @ pls_fit.y_loadings_.T,
    index=d_cvrg.columns,
    columns=d_rabund.columns,
).rename(columns=phylotype.name)

In [None]:
# Phylotypes of interest: those whose 95th-percentile relative abundance
# across extractions exceeds 1%.
tax_filter = lambda x: x.quantile(0.95) > 0.01

taxa_of_interest = sorted(d_rabund.loc[:, tax_filter].rename(columns=phylotype.name).columns)
if 'other' in taxa_of_interest:
    taxa_of_interest.remove('other')
print(len(taxa_of_interest))

# A bin counts as a hit for a phylotype when its contribution exceeds this
# fraction of the phylotype's top contribution.
factor = 1/3

_hits = {}
for tax in taxa_of_interest:
    contrib = bin_otu_contrib[tax].sort_values(ascending=False)
    top_score = bin_otu_contrib[tax].max()
    print(tax, top_score)
    _hits[tax] = list(contrib[contrib > top_score * factor].index)

print()
for tax in _hits:
    print(tax, _hits[tax])

all_hits = set(chain(*_hits.values()))

# .loc needs a list-like row indexer: pandas >= 2.0 rejects a set. Sorting also
# gives the clustering a deterministic input order.
a = sns.clustermap(bin_otu_contrib.loc[sorted(all_hits), taxa_of_interest].rename(columns=phylotype.name), robust=True,
                   figsize=(14, 18), col_cluster=False, cmap='coolwarm', center=0)

# Address the heatmap axes by name instead of by position in the figure's
# axes list, which shifts if row/column colour bars are ever added.
ax = a.ax_heatmap
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

print()

## Metabin Refinement

### B1 / OTU-1 / Otu0001

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/3
otus = [
    'Otu0001_1',
    'Otu0001_2',
    'Otu0001_3',
    'Otu0001_4',
    'Otu0001_5',
]

# Union, over the pooled OTUs, of the bins whose contribution exceeds
# keep_thresh_factor times that OTU's largest contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib[otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)

# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.loc[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig ids as a comma-separated list of SQL string literals.
# Single quotes are the standard string-literal syntax; the double quotes used
# previously denote identifiers and only work through SQLite's legacy
# double-quoted-string fallback. Embedded single quotes are doubled so an id can
# never terminate the literal early.
contig_ids_sql = "'" + "', '".join(c.replace("'", "''") for c in contig_ids) + "'"

# extraction x contig coverage matrix (coverage summed across the libraries of
# each extraction; combinations with no row are filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Per-extraction metadata joined with sample and mouse tables, plus the total
# mapping_count summed across each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Restrict (and order) the coverage rows to the extractions in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of each pooled OTU, per site.
rabund.rename(columns=phylotype.name)[otus].groupby(extraction_meta.site).mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help pick a seed contig.

# Longest contigs of the chosen bin (default head(): first five rows).
(contig_meta
     .loc[contig_meta.bin_id == 'bin00378']
     .sort_values('length', ascending=False)
     .head())

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '536619'
compare = '1830468'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, one point per extraction, coloured by the
# pooled OTU relative abundance (sorted so the most abundant are drawn last).
_scatter_data = (d_rabund[otus].sum(1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
# 1:1 reference line; it does not influence the axis limits.
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed contig (across
# extractions) above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / standard deviation of trusted-contig coverage
# above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Normalize every contig's coverage by the extraction's mean trusted-contig
# coverage, then compare the log mean-coverage distribution before and after
# dropping the untrusted extractions.
_hist_bins = np.linspace(-4, 6)

_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)

cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)

plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Cluster contigs on square-root transformed normalized coverage; the matrix is
# transposed so that each contig is one sample and each extraction one feature.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=40,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              max_iter=int(1e3),
                              random_state=1,
                             ).fit(cluster_data.T)
# Mixture component ("group") predicted for every contig.
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)
# Per-group summary statistics:
#   group_mean_mean_coverage: mean over extractions of the group's mean coverage
#   group_mean_std_coverage:  mean over extractions of the within-group std
#   group_std_mean_coverage:  std over extractions of the group's mean coverage
#   group_max_coverage:       largest normalized coverage seen in the group
#   total_length:             summed length of the group's contigs
# NOTE(review): groupby(..., axis='columns') is deprecated in recent pandas
# releases -- confirm against the pinned version.
group_cvrg = cvrg_norm.groupby(_group_assign, axis='columns').mean().mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').std().mean()
group_cvrg['group_std_mean_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').mean().std()
group_cvrg['group_max_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# contamination_score = (between-extraction std of the group mean)
#                       / (mean within-group std) * log(total length).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage * np.log(group_cvrg.total_length)
group_cvrg.index.name = 'group'
# Per-contig table: group id, the group's statistics, and the contig's bin and length.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
# This sort fixes the left-to-right contig order of the plots below:
# ascending contamination score, then longest contigs first.
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)
# order = group_assign.index

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction (normalized coverage across the ordered contigs),
# coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
#    color = None
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)
#_ = ax.plot(group_assign.group_mean_coverage.values, color='k')

# Empty annotation in the upper-left of the axes; its text is not set in this cell.
annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every contig along the x-axis, then each group's first, middle
# and last position, ordered from left to right.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 35
# Groups above the threshold, or with an undefined (NaN) score, get a thin red
# boundary and no labels; the remaining groups get a dashed black boundary and
# are labelled with their group id and total length.
for inx, d in group_order.iterrows():
    if group_cvrg.loc[inx].contamination_score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    elif group_cvrg.loc[inx].isna().contamination_score:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    else:
        ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# NOTE(review): newer Matplotlib releases spell this keyword `linthresh`
# (`linthreshy` was deprecated and later removed) -- confirm against the pinned version.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Total contig length per (bin, group) for groups strictly under the
# contamination threshold, plus one 'contam' column with the length that each
# bin has in groups over the threshold or with an undefined score.
_score = group_cvrg.contamination_score
_kept_groups = group_cvrg[_score < contam_threshold].index
_contam_groups = group_cvrg[(_score > contam_threshold) | _score.isna()].index

a = (group_assign[group_assign.group.isin(_kept_groups)]
         .groupby(['bin_id', 'group'])
         .length.sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_contam_groups)]
         .groupby('bin_id')
         .length.sum()
         .rename('contam'))

a.join(b, how='outer').fillna(0).astype(int).T

In [None]:
# Number of trusted extractions per (site, treatment) combination.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group per (site, treatment);
# groups are rows, design cells are columns.
_kept_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby(_design).mean()
          .loc[:, _kept_groups]
          .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Contigs belonging to groups under the contamination threshold, in plot order.
_kept_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One faint line per extraction, coloured by site. `artists[i]` is the line
# for extraction `plotting_order[i]`.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_kept_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Remember each line's resting appearance so a highlight can be undone.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is evaluated on the re-ordered frame;
# a mask built from the un-ordered group_cvrg is misaligned and makes pandas
# warn that the boolean key will be reindexed.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Dashed boundary at the left edge of every retained group, labelled with the
# group id and its total length.
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.min()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE(review): newer Matplotlib releases spell this keyword `linthresh` --
# confirm against the pinned version.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitting animation knows what to redraw."""
    return artists


def _animate(i):
    """Highlight extraction ``i`` and restore the previously highlighted one.

    For ``i == 0`` the "previous" line is ``artists[-1]``, so the highlight is
    also cleared correctly when the animation wraps around.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per plotted line. Using len(artists) rather than the row count of
# cvrg_norm avoids an IndexError when an extraction has no site and is
# therefore dropped by the groupby above.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Colour stops for the heatmap colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # not referenced again in this cell
cmap_name = 'custom1'
# 100-level colormap interpolated through the stops above.
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.
# INSTRUCTION: Pick a linscale that best contrasts enriched and depleted groups.

# Median normalized coverage per extraction for every retained group, minus
# the groups dropped by hand.
_kept_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, _kept_groups]
              .drop(columns=[22]))

vmin, vmax = 0, 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)
_heatmap_kws = dict(vmin=vmin, vmax=vmax, col_cluster=False, robust=True, cmap=cmap, norm=norm)

# One row-clustered heatmap per site.
sns.clustermap(d[extraction_meta.site == 'UM'], **_heatmap_kws)
sns.clustermap(d[extraction_meta.site == 'UT'], **_heatmap_kws)


#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Total trusted-contig coverage per library; the same series is written for
# both strains.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(1)

# "UM" Strain
_um_groups = [0, 6, 9, 13, 15, 29]
_um_exclude = set()
_um_include = {'EXT-0237', 'EXT-0226', 'EXT-0243', 'EXT-0045', 'EXT-0100', 'EXT-0094', 'EXT-0107'}
with open('chkpt/core.a.mags/B1A.g.contigs.list', 'w') as handle:
    for contig_id in group_assign[group_assign.group.isin(_um_groups)].index:
        handle.write('{}\n'.format(contig_id))
with open('chkpt/core.a.mags/B1A.g.library.list', 'w') as handle:
    # Trusted UM extractions, minus manual exclusions, plus manual additions.
    extraction_ids = (set(cvrg_norm[extraction_meta.site == 'UM'].index) - _um_exclude) | _um_include
    for library_id in library[library.extraction_id.isin(extraction_ids)].index:
        handle.write('{}\n'.format(library_id))
_trusted_depth.to_csv('chkpt/core.a.mags/B1A.g.trusted_depth.tsv', sep='\t', header=False)

# UT Strain
_ut_groups = [6, 9, 13, 21, 29]
_ut_exclude = {'EXT-0237', 'EXT-0226', 'EXT-0243', 'EXT-0045', 'EXT-0100', 'EXT-0094', 'EXT-0107',
               'EXT-0081', 'EXT-0246', 'EXT-0192'}
with open('chkpt/core.a.mags/B1B.g.contigs.list', 'w') as handle:
    for contig_id in group_assign[group_assign.group.isin(_ut_groups)].index:
        handle.write('{}\n'.format(contig_id))
with open('chkpt/core.a.mags/B1B.g.library.list', 'w') as handle:
    # Trusted UT extractions, minus manual exclusions.
    extraction_ids = set(cvrg_norm[extraction_meta.site == 'UT'].index) - _ut_exclude
#    extraction_ids |= set(['EXT-0421'])
    for library_id in library[library.extraction_id.isin(extraction_ids)].index:
        handle.write('{}\n'.format(library_id))
_trusted_depth.to_csv('chkpt/core.a.mags/B1B.g.trusted_depth.tsv', sep='\t', header=False)

### B2 / OTU-4 / Otu0007

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/4
otus = [
    'Otu0007_1',
    'Otu0007_2',
]

# Union, over the pooled OTUs, of the bins whose contribution exceeds
# keep_thresh_factor times that OTU's largest contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib[otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)

# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.loc[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig ids as a comma-separated list of SQL string literals.
# Single quotes are the standard string-literal syntax; the double quotes used
# previously denote identifiers and only work through SQLite's legacy
# double-quoted-string fallback. Embedded single quotes are doubled so an id can
# never terminate the literal early.
contig_ids_sql = "'" + "', '".join(c.replace("'", "''") for c in contig_ids) + "'"

# extraction x contig coverage matrix (coverage summed across the libraries of
# each extraction; combinations with no row are filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Per-extraction metadata joined with sample and mouse tables, plus the total
# mapping_count summed across each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Restrict (and order) the coverage rows to the extractions in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of each pooled OTU, per site.
rabund.rename(columns=phylotype.name)[otus].groupby(extraction_meta.site).mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help pick a seed contig.

# Ten longest contigs of the chosen bin.
(contig_meta
     .loc[contig_meta.bin_id == 'bin00225']
     .sort_values('length', ascending=False)
     .head(10))

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '35357'
compare = '1023797'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, one point per extraction, coloured by the
# pooled OTU relative abundance (sorted so the most abundant are drawn last).
_scatter_data = (d_rabund[otus].sum(1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
# 1:1 reference line; it does not influence the axis limits.
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed contig (across
# extractions) above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / standard deviation of trusted-contig coverage
# above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Normalize every contig's coverage by the extraction's mean trusted-contig
# coverage, then compare the log mean-coverage distribution before and after
# dropping the untrusted extractions.
_hist_bins = np.linspace(-4, 6)

_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)

cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)

plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Cluster contigs on square-root transformed normalized coverage; the matrix is
# transposed so that each contig is one sample and each extraction one feature.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=20,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
# Mixture component ("group") predicted for every contig.
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)
# Per-group summary statistics:
#   group_mean_mean_coverage: mean over extractions of the group's mean coverage
#   group_mean_std_coverage:  mean over extractions of the within-group std
#   group_std_mean_coverage:  std over extractions of the group's mean coverage
#   group_max_coverage:       largest normalized coverage seen in the group
#   total_length:             summed length of the group's contigs
# NOTE(review): groupby(..., axis='columns') is deprecated in recent pandas
# releases -- confirm against the pinned version.
group_cvrg = cvrg_norm.groupby(_group_assign, axis='columns').mean().mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').std().mean()
group_cvrg['group_std_mean_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').mean().std()
group_cvrg['group_max_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# contamination_score = (between-extraction std of the group mean)
#                       / (mean within-group std) * log(total length).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage * np.log(group_cvrg.total_length)
group_cvrg.index.name = 'group'
# Per-contig table: group id, the group's statistics, and the contig's bin and length.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
# This sort fixes the left-to-right contig order of the plots below:
# ascending contamination score, then longest contigs first.
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)
# order = group_assign.index

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction (normalized coverage across the ordered contigs),
# coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
#    color = None
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)
#_ = ax.plot(group_assign.group_mean_coverage.values, color='k')

# Empty annotation in the upper-left of the axes; its text is not set in this cell.
annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every contig along the x-axis, then each group's first, middle
# and last position, ordered from left to right.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 50
# Groups above the threshold, or with an undefined (NaN) score, get a thin red
# boundary and no labels; the remaining groups get a dashed black boundary and
# are labelled with their group id and total length.
for inx, d in group_order.iterrows():
    if group_cvrg.loc[inx].contamination_score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    elif group_cvrg.loc[inx].isna().contamination_score:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    else:
        ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# NOTE(review): newer Matplotlib releases spell this keyword `linthresh`
# (`linthreshy` was deprecated and later removed) -- confirm against the pinned version.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Total contig length per (bin, group) for groups strictly under the
# contamination threshold, plus one 'contam' column with the length that each
# bin has in groups over the threshold or with an undefined score.
_score = group_cvrg.contamination_score
_kept_groups = group_cvrg[_score < contam_threshold].index
_contam_groups = group_cvrg[(_score > contam_threshold) | _score.isna()].index

a = (group_assign[group_assign.group.isin(_kept_groups)]
         .groupby(['bin_id', 'group'])
         .length.sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_contam_groups)]
         .groupby('bin_id')
         .length.sum()
         .rename('contam'))

a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of trusted extractions per (site, treatment) combination.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group per (site, treatment);
# groups are rows, design cells are columns.
_kept_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby(_design).mean()
          .loc[:, _kept_groups]
          .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Contigs belonging to groups under the contamination threshold, in plot order.
_kept_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One faint line per extraction, coloured by site. `artists[i]` is the line
# for extraction `plotting_order[i]`.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_kept_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Remember each line's resting appearance so a highlight can be undone.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is evaluated on the re-ordered frame;
# a mask built from the un-ordered group_cvrg is misaligned and makes pandas
# warn that the boolean key will be reindexed.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Dashed boundary at the left edge of every retained group, labelled with the
# group id and its total length.
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.min()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE(review): newer Matplotlib releases spell this keyword `linthresh` --
# confirm against the pinned version.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitting animation knows what to redraw."""
    return artists


def _animate(i):
    """Highlight extraction ``i`` and restore the previously highlighted one.

    For ``i == 0`` the "previous" line is ``artists[-1]``, so the highlight is
    also cleared correctly when the animation wraps around.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per plotted line. Using len(artists) rather than the row count of
# cvrg_norm avoids an IndexError when an extraction has no site and is
# therefore dropped by the groupby above.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Colour stops for the heatmap colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # not referenced again in this cell
cmap_name = 'custom1'
# 100-level colormap interpolated through the stops above.
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Median normalized coverage per extraction for every retained group; the
# empty drop list is the placeholder for groups to hide by hand.
_kept_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, _kept_groups]
              .drop(columns=[]))

vmin, vmax = 0, 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

# sns.clustermap(d[extraction_meta.site == 'UM'], vmin=vmin, vmax=vmax,
#                col_cluster=False, robust=True, cmap=cmap, norm=norm)
# All extractions in one heatmap, clustering rows and columns by cosine distance.
sns.clustermap(
    d,
    vmin=vmin,
    vmax=vmax,
    col_cluster=True,
    robust=True,
    cmap=cmap,
    norm=norm,
    metric='cosine',
)

#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

_b2_groups = [12, 15, 0, 18]
_b2_exclude = set()
with open('chkpt/core.a.mags/B2.g.contigs.list', 'w') as handle:
    for contig_id in group_assign[group_assign.group.isin(_b2_groups)].index:
        handle.write('{}\n'.format(contig_id))
with open('chkpt/core.a.mags/B2.g.library.list', 'w') as handle:
    # All trusted extractions, minus manual exclusions.
    extraction_ids = set(trusted_extractions) - _b2_exclude
    for library_id in library[library.extraction_id.isin(extraction_ids)].index:
        handle.write('{}\n'.format(library_id))
# Total trusted-contig coverage per library.
(library.join(cvrg, on='extraction_id')[trusted_contigs]
        .sum(1)
        .to_csv('chkpt/core.a.mags/B2.g.trusted_depth.tsv', sep='\t', header=False))

### B3 / OTU-? / Otu0009

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/3
otus = [
    'Otu0009_1',
]

# Union, over the pooled OTUs, of the bins whose contribution exceeds
# keep_thresh_factor times that OTU's largest contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib[otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)

# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.loc[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig ids as a comma-separated list of SQL string literals.
# Single quotes are the standard string-literal syntax; the double quotes used
# previously denote identifiers and only work through SQLite's legacy
# double-quoted-string fallback. Embedded single quotes are doubled so an id can
# never terminate the literal early.
contig_ids_sql = "'" + "', '".join(c.replace("'", "''") for c in contig_ids) + "'"

# extraction x contig coverage matrix (coverage summed across the libraries of
# each extraction; combinations with no row are filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Per-extraction metadata joined with sample and mouse tables, plus the total
# mapping_count summed across each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Restrict (and order) the coverage rows to the extractions in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of each pooled OTU, per site.
rabund.rename(columns=phylotype.name)[otus].groupby(extraction_meta.site).mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help pick a seed contig.

# Ten longest contigs of the chosen bin.
(contig_meta
     .loc[contig_meta.bin_id == 'bin00800']
     .sort_values('length', ascending=False)
     .head(10))

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '1621700'
compare = '2880790'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, one point per extraction, coloured by the
# pooled OTU relative abundance (sorted so the most abundant are drawn last).
_scatter_data = (d_rabund[otus].sum(1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
# 1:1 reference line; it does not influence the axis limits.
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed contig (across
# extractions) above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / standard deviation of trusted-contig coverage
# above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Normalize every contig's coverage by the extraction's mean trusted-contig
# coverage, then compare the log mean-coverage distribution before and after
# dropping the untrusted extractions.
_hist_bins = np.linspace(-4, 6)

_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)

cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)

plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Cluster contigs on square-root transformed normalized coverage; the matrix is
# transposed so that each contig is one sample and each extraction one feature.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
# Mixture component ("group") predicted for every contig.
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)
# Per-group summary statistics:
#   group_mean_mean_coverage: mean over extractions of the group's mean coverage
#   group_mean_std_coverage:  mean over extractions of the within-group std
#   group_std_mean_coverage:  std over extractions of the group's mean coverage
#   group_max_coverage:       largest normalized coverage seen in the group
#   total_length:             summed length of the group's contigs
# NOTE(review): groupby(..., axis='columns') is deprecated in recent pandas
# releases -- confirm against the pinned version.
group_cvrg = cvrg_norm.groupby(_group_assign, axis='columns').mean().mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').std().mean()
group_cvrg['group_std_mean_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').mean().std()
group_cvrg['group_max_coverage'] = cvrg_norm.groupby(_group_assign, axis='columns').max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# contamination_score = (between-extraction std of the group mean)
#                       / (mean within-group std) * log(total length).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage * np.log(group_cvrg.total_length)
group_cvrg.index.name = 'group'
# Per-contig table: group id, the group's statistics, and the contig's bin and length.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
# This sort fixes the left-to-right contig order of the plots below:
# ascending contamination score, then longest contigs first.
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)
# order = group_assign.index

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction (normalized coverage across the ordered contigs),
# coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
#    color = None
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)
#_ = ax.plot(group_assign.group_mean_coverage.values, color='k')

# Empty annotation in the upper-left of the axes; its text is not set in this cell.
annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every contig along the x-axis, then each group's first, middle
# and last position, ordered from left to right.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 50
# Groups above the threshold, or with an undefined (NaN) score, get a thin red
# boundary and no labels; the remaining groups get a dashed black boundary and
# are labelled with their group id and total length.
for inx, d in group_order.iterrows():
    if group_cvrg.loc[inx].contamination_score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    elif group_cvrg.loc[inx].isna().contamination_score:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    else:
        ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# NOTE(review): newer Matplotlib releases spell this keyword `linthresh`
# (`linthreshy` was deprecated and later removed) -- confirm against the pinned version.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Cross-tabulate contig length by bin and group for the groups that pass the
# contamination threshold, and add one 'contam' column holding the length that
# falls in groups above the threshold or with an undefined score.
_score = group_cvrg.contamination_score
_clean_groups = group_cvrg[_score < contam_threshold].index
_contam_groups = group_cvrg[(_score > contam_threshold) | _score.isna()].index

a = (group_assign[group_assign.group.isin(_clean_groups)]
         .groupby(['bin_id', 'group'])['length']
         .sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_contam_groups)]
         .groupby('bin_id')['length']
         .sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions in each site x treatment cell.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site x treatment (columns).
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
(cvrg_norm
     .groupby(group_assign.group, axis='columns').mean()
     .groupby(_design).mean()
     .loc[:, _clean_groups]
     .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Contigs in groups that pass the contamination threshold, in group_assign order.
# Selected once here rather than once per site inside the loop.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One line artist per extraction; plotting_order records the extraction behind
# each artist so the two lists can be indexed together.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Remember each artist's resting style so the highlight can be undone.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# A callable indexer evaluates the mask on the already re-ordered frame, so the
# mask and the frame are aligned by construction (no implicit reindexing).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Annotation heights are the same for every group.
_label_y = group_cvrg_included.group_max_coverage.min()
_length_y = group_cvrg_included.group_max_coverage.max() * 0.5
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, _length_y),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return all line artists as the initial set to draw."""
    return artists

def _animate(i):
    """Highlight the line for frame *i* and restore the previous frame's line.

    For ``i == 0`` the "previous" index is ``-1``, i.e. the last artist, so the
    highlight is also cleared when the animation wraps around.  The annotation
    shows the extraction ID and the pooled relative abundance (as a
    percentage) of the OTUs in ``otus``.  Returns the artists that changed.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per plotted line.  len(artists) (rather than cvrg_norm.shape[0])
# stays in range even if the groupby above dropped extractions without a site.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Color stops as (R, G, B) triples: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # defined but not used below; the colormap is built with N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Per-extraction median normalized coverage of each retained group; list any
# groups to hide in the drop() call.
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
d = (cvrg_norm
         .groupby(group_assign.group, axis='columns').median()
         .loc[:, _clean_groups]
         .drop(columns=[]))

vmin, vmax = 0, 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=0.9)

# Settings shared by both heatmaps; rows (extractions) are clustered, columns are not.
_heatmap_kws = dict(vmin=vmin, vmax=vmax, col_cluster=False,
                    robust=True, cmap=cmap, norm=norm)
sns.clustermap(d[extraction_meta.site == 'UM'], **_heatmap_kws)
sns.clustermap(d[extraction_meta.site == 'UT'], metric='cosine', **_heatmap_kws)


#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Contig IDs of the curated groups, one per line.
with open('chkpt/core.a.mags/B3.g.contigs.list', 'w') as handle:
    _mag_contigs = group_assign[group_assign.group.isin([6])].index
    for contig_id in _mag_contigs:
        handle.write('{}\n'.format(contig_id))

# Library IDs of the trusted extractions (minus any listed exclusions), one per line.
with open('chkpt/core.a.mags/B3.g.library.list', 'w') as handle:
    extraction_ids = set(trusted_extractions).difference([])
    _mag_libraries = library[library.extraction_id.isin(extraction_ids)].index
    for library_id in _mag_libraries:
        handle.write('{}\n'.format(library_id))

# Per-library sum of coverage over the trusted contigs.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(axis=1)
_trusted_depth.to_csv('chkpt/core.a.mags/B3.g.trusted_depth.tsv', sep='\t', header=False)

### B4 / OTU-? / Otu0005

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/3
otus = ['Otu0005_1']

# Keep every bin whose contribution to an OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _contrib = bin_otu_contrib[otu]
    max_contrib = _contrib.max()
    bins.update(_contrib[_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
# All contigs belonging to the retained bins.
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the selected contig IDs as a comma-separated list of SQL string
# literals.  Single quotes are used (with any embedded quote doubled) because
# double-quoted tokens are identifiers in SQL and are only treated as strings
# through SQLite's legacy fallback.  Sorting makes the query text reproducible.
contig_ids_sql = ', '.join("'{}'".format(str(contig_id).replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage summed over libraries per extraction: rows are extractions, columns
# are contigs, and missing pairs are filled with 0.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined to sample and mouse tables, plus the total mapping
# count summed over each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig table columns for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Keep only the extractions indexed by d_rabund, in that order.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the selected OTUs at each site.
_named_rabund = rabund.rename(columns=phylotype.name)
_named_rabund.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# Ten longest contigs of the chosen bin.
_bin_contigs = contig_meta[contig_meta.bin_id == 'bin01956']
_bin_contigs.sort_values('length', ascending=False).iloc[:10]

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed, compare = '156806', '2140623'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, colored by the pooled relative abundance
# of the selected OTUs; the black line is y = x.
_scatter_data = (d_rabund[otus].sum(axis=1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.95
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(axis=1) / _trusted_cvrg.std(axis=1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Divide each extraction's coverage by its mean coverage over the trusted
# contigs, then compare the distribution of log mean contig coverage before and
# after restricting to the trusted extractions.
_hist_bins = np.linspace(-4, 6)
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Contigs (columns of cvrg_norm) are clustered on their square-root-transformed
# normalized coverage profiles across extractions.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group summaries.  The column-wise groupby and the per-extraction group
# means are built once and shared by every statistic below.
_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean = _by_group.mean()
group_cvrg = _group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean.std()
group_cvrg['group_max_coverage'] = _by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Score = (std across extractions of the group mean)
#         / (mean across extractions of the within-group std)
#         * log(total group length).
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, carrying its group's summaries, ordered by contamination
# score (ascending) and then contig length (descending).
group_assign = (_group_assign.to_frame(name='group')
                             .join(group_cvrg, on='group')
                             .assign(bin_id=contig_meta.bin_id,
                                     length=contig_meta.length)
                             .sort_values(['contamination_score', 'length'],
                                          ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, colored by site; x is the contig's position in the
# sorted order of group_assign.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Positional extent (left/right/middle) of each group along the x axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 50
# Annotation heights do not depend on the group, so compute them once.
_label_y = cvrg_norm.min().min() - 1e-1
_length_y = cvrg_norm.max().max() * 0.5
for inx, d in group_order.iterrows():
    _score = group_cvrg.loc[inx].contamination_score
    if _score > contam_threshold or pd.isna(_score):
        # Score above the threshold or undefined: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, _length_y),
                ha='center', rotation=-90)

ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Cross-tabulate contig length by bin and group for the groups that pass the
# contamination threshold, and add one 'contam' column holding the length that
# falls in groups above the threshold or with an undefined score.
_score = group_cvrg.contamination_score
_clean_groups = group_cvrg[_score < contam_threshold].index
_contam_groups = group_cvrg[(_score > contam_threshold) | _score.isna()].index

a = (group_assign[group_assign.group.isin(_clean_groups)]
         .groupby(['bin_id', 'group'])['length']
         .sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_contam_groups)]
         .groupby('bin_id')['length']
         .sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions in each site x treatment cell.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site x treatment (columns).
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
(cvrg_norm
     .groupby(group_assign.group, axis='columns').mean()
     .groupby(_design).mean()
     .loc[:, _clean_groups]
     .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Contigs in groups that pass the contamination threshold, in group_assign order.
# Selected once here rather than once per site inside the loop.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One line artist per extraction; plotting_order records the extraction behind
# each artist so the two lists can be indexed together.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Remember each artist's resting style so the highlight can be undone.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# A callable indexer evaluates the mask on the already re-ordered frame, so the
# mask and the frame are aligned by construction (no implicit reindexing).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Annotation heights are the same for every group.
_label_y = group_cvrg_included.group_max_coverage.min()
_length_y = group_cvrg_included.group_max_coverage.max() * 0.5
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, _length_y),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return all line artists as the initial set to draw."""
    return artists

def _animate(i):
    """Highlight the line for frame *i* and restore the previous frame's line.

    For ``i == 0`` the "previous" index is ``-1``, i.e. the last artist, so the
    highlight is also cleared when the animation wraps around.  The annotation
    shows the extraction ID and the pooled relative abundance (as a
    percentage) of the OTUs in ``otus``.  Returns the artists that changed.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per plotted line.  len(artists) (rather than cvrg_norm.shape[0])
# stays in range even if the groupby above dropped extractions without a site.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Color stops as (R, G, B) triples: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # defined but not used below; the colormap is built with N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Per-extraction median normalized coverage of each retained group; list any
# groups to hide in the drop() call.
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
d = (cvrg_norm
         .groupby(group_assign.group, axis='columns').median()
         .loc[:, _clean_groups]
         .drop(columns=[]))

vmin, vmax = 0, 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

# Settings shared by both heatmaps; rows (extractions) are clustered, columns are not.
_heatmap_kws = dict(vmin=vmin, vmax=vmax, col_cluster=False,
                    robust=True, cmap=cmap, norm=norm)
sns.clustermap(d[extraction_meta.site == 'UM'], **_heatmap_kws)
sns.clustermap(d[extraction_meta.site == 'UT'], metric='cosine', **_heatmap_kws)


#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Contig IDs of the curated groups, one per line.
with open('chkpt/core.a.mags/B4.g.contigs.list', 'w') as handle:
    _mag_contigs = group_assign[group_assign.group.isin([0, 6, 7])].index
    for contig_id in _mag_contigs:
        handle.write('{}\n'.format(contig_id))

# Library IDs of the trusted extractions (minus the listed exclusions), one per line.
with open('chkpt/core.a.mags/B4.g.library.list', 'w') as handle:
    extraction_ids = set(trusted_extractions).difference(['EXT-0107', 'EXT-0210'])
    _mag_libraries = library[library.extraction_id.isin(extraction_ids)].index
    for library_id in _mag_libraries:
        handle.write('{}\n'.format(library_id))

# Per-library sum of coverage over the trusted contigs.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(axis=1)
_trusted_depth.to_csv('chkpt/core.a.mags/B4.g.trusted_depth.tsv', sep='\t', header=False)

### B5 / OTU-? / Otu0004

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/3
otus = ['Otu0004_1', 'Otu0004_2']

# Keep every bin whose contribution to an OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _contrib = bin_otu_contrib[otu]
    max_contrib = _contrib.max()
    bins.update(_contrib[_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
# All contigs belonging to the retained bins.
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the selected contig IDs as a comma-separated list of SQL string
# literals.  Single quotes are used (with any embedded quote doubled) because
# double-quoted tokens are identifiers in SQL and are only treated as strings
# through SQLite's legacy fallback.  Sorting makes the query text reproducible.
contig_ids_sql = ', '.join("'{}'".format(str(contig_id).replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage summed over libraries per extraction: rows are extractions, columns
# are contigs, and missing pairs are filled with 0.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined to sample and mouse tables, plus the total mapping
# count summed over each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig table columns for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Keep only the extractions indexed by d_rabund, in that order.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the selected OTUs at each site.
_named_rabund = rabund.rename(columns=phylotype.name)
_named_rabund.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# Ten longest contigs of the chosen bin.
_bin_contigs = contig_meta[contig_meta.bin_id == 'bin01354']
_bin_contigs.sort_values('length', ascending=False).iloc[:10]

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed, compare = '3759879', '586388'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, colored by the pooled relative abundance
# of the selected OTUs; the black line is y = x.
_scatter_data = (d_rabund[otus].sum(axis=1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(axis=1) / _trusted_cvrg.std(axis=1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Divide each extraction's coverage by its mean coverage over the trusted
# contigs, then compare the distribution of log mean contig coverage before and
# after restricting to the trusted extractions.
_hist_bins = np.linspace(-4, 6)
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Contigs (columns of cvrg_norm) are clustered on their square-root-transformed
# normalized coverage profiles across extractions.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=20,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group summaries.  The column-wise groupby and the per-extraction group
# means are built once and shared by every statistic below.
_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean = _by_group.mean()
group_cvrg = _group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean.std()
group_cvrg['group_max_coverage'] = _by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Score = (std across extractions of the group mean)
#         / (mean across extractions of the within-group std)
#         * log(total group length).
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, carrying its group's summaries, ordered by contamination
# score (ascending) and then contig length (descending).
group_assign = (_group_assign.to_frame(name='group')
                             .join(group_cvrg, on='group')
                             .assign(bin_id=contig_meta.bin_id,
                                     length=contig_meta.length)
                             .sort_values(['contamination_score', 'length'],
                                          ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, colored by site; x is the contig's position in the
# sorted order of group_assign.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Positional extent (left/right/middle) of each group along the x axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 50
# Annotation heights do not depend on the group, so compute them once.
_label_y = cvrg_norm.min().min() - 1e-1
_length_y = cvrg_norm.max().max() * 0.5
for inx, d in group_order.iterrows():
    _score = group_cvrg.loc[inx].contamination_score
    if _score > contam_threshold or pd.isna(_score):
        # Score above the threshold or undefined: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, _length_y),
                ha='center', rotation=-90)

ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Cross-tabulate contig length by bin and group for the groups that pass the
# contamination threshold, and add one 'contam' column holding the length that
# falls in groups above the threshold or with an undefined score.
_score = group_cvrg.contamination_score
_clean_groups = group_cvrg[_score < contam_threshold].index
_contam_groups = group_cvrg[(_score > contam_threshold) | _score.isna()].index

a = (group_assign[group_assign.group.isin(_clean_groups)]
         .groupby(['bin_id', 'group'])['length']
         .sum()
         .unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(_contam_groups)]
         .groupby('bin_id')['length']
         .sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions in each site x treatment cell.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site x treatment (columns).
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
(cvrg_norm
     .groupby(group_assign.group, axis='columns').mean()
     .groupby(_design).mean()
     .loc[:, _clean_groups]
     .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Contigs in groups that pass the contamination threshold, in group_assign order.
# Selected once here rather than once per site inside the loop.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One line artist per extraction; plotting_order records the extraction behind
# each artist so the two lists can be indexed together.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Remember each artist's resting style so the highlight can be undone.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# A callable indexer evaluates the mask on the already re-ordered frame, so the
# mask and the frame are aligned by construction (no implicit reindexing).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Annotation heights are the same for every group.
_label_y = group_cvrg_included.group_max_coverage.min()
_length_y = group_cvrg_included.group_max_coverage.max() * 0.5
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, _length_y),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return all line artists as the initial set to draw."""
    return artists

def _animate(i):
    """Highlight the line for frame *i* and restore the previous frame's line.

    For ``i == 0`` the "previous" index is ``-1``, i.e. the last artist, so the
    highlight is also cleared when the animation wraps around.  The annotation
    shows the extraction ID and the pooled relative abundance (as a
    percentage) of the OTUs in ``otus``.  Returns the artists that changed.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per plotted line.  len(artists) (rather than cvrg_norm.shape[0])
# stays in range even if the groupby above dropped extractions without a site.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Color stops as (R, G, B) triples: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # defined but not used below; the colormap is built with N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Per-extraction median normalized coverage of each retained group; list any
# groups to hide in the drop() call.
_clean_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
d = (cvrg_norm
         .groupby(group_assign.group, axis='columns').median()
         .loc[:, _clean_groups]
         .drop(columns=[]))

vmin, vmax = 0, 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

# Settings shared by both heatmaps; rows (extractions) are clustered, columns are not.
_heatmap_kws = dict(vmin=vmin, vmax=vmax, col_cluster=False,
                    robust=True, cmap=cmap, norm=norm)
sns.clustermap(d[extraction_meta.site == 'UM'], **_heatmap_kws)
sns.clustermap(d[extraction_meta.site == 'UT'], metric='cosine', **_heatmap_kws)


#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# UT Strain
# Contig IDs of the curated groups, one per line.
with open('chkpt/core.a.mags/B5.g.contigs.list', 'w') as handle:
    _mag_contigs = group_assign[group_assign.group.isin([0, 2, 7, 8, 13, 15])].index
    for contig_id in _mag_contigs:
        handle.write('{}\n'.format(contig_id))

# Library IDs of the UT-site extractions present in cvrg_norm (minus any listed
# exclusions), one per line.
with open('chkpt/core.a.mags/B5.g.library.list', 'w') as handle:
    extraction_ids = set(cvrg_norm[extraction_meta.site == 'UT'].index).difference([])
    _mag_libraries = library[library.extraction_id.isin(extraction_ids)].index
    for library_id in _mag_libraries:
        handle.write('{}\n'.format(library_id))

# Per-library sum of coverage over the trusted contigs.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(axis=1)
_trusted_depth.to_csv('chkpt/core.a.mags/B5.g.trusted_depth.tsv', sep='\t', header=False)

### B6 / OTU-? / Otu0049

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/5
otus = ['Otu0049_1']

# Keep every bin whose contribution to an OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _contrib = bin_otu_contrib[otu]
    max_contrib = _contrib.max()
    bins.update(_contrib[_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
# All contigs belonging to the retained bins.
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the selected contig IDs as a comma-separated list of SQL string
# literals.  Single quotes are used (with any embedded quote doubled) because
# double-quoted tokens are identifiers in SQL and are only treated as strings
# through SQLite's legacy fallback.  Sorting makes the query text reproducible.
contig_ids_sql = ', '.join("'{}'".format(str(contig_id).replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage summed over libraries per extraction: rows are extractions, columns
# are contigs, and missing pairs are filled with 0.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined to sample and mouse tables, plus the total mapping
# count summed over each extraction's libraries.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig table columns for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

#cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]
# Keep only the extractions indexed by d_rabund, in that order.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the selected OTUs at each site.
_named_rabund = rabund.rename(columns=phylotype.name)
_named_rabund.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# Ten longest contigs of the chosen bin.
_bin_contigs = contig_meta[contig_meta.bin_id == 'bin01484']
_bin_contigs.sort_values('length', ascending=False).iloc[:10]

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed, compare = '975319', '3176163'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. comparison contig coverage, colored by the pooled relative abundance
# of the selected OTUs; the black line is y = x.
_scatter_data = (d_rabund[otus].sum(axis=1)
                               .to_frame(name='rabund')
                               .join(cvrg)
                               .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_scatter_data, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(axis=1) / _trusted_cvrg.std(axis=1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Divide each extraction's coverage by its mean coverage over the trusted
# contigs, then compare the distribution of log mean contig coverage before and
# after restricting to the trusted extractions.
_hist_bins = np.linspace(-4, 6)
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(_cvrg_norm_all.mean()), bins=_hist_bins, label='with_untrusted', alpha=0.8)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=_hist_bins, label='without_untrusted', alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# Contigs (columns of cvrg_norm) are clustered on their square-root-transformed
# normalized coverage profiles across extractions.
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=50,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group summaries.  The column-wise groupby and the per-extraction group
# means are built once and shared by every statistic below.
_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean = _by_group.mean()
group_cvrg = _group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean.std()
group_cvrg['group_max_coverage'] = _by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Score = (std across extractions of the group mean)
#         / (mean across extractions of the within-group std)
#         * log(total group length).
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, carrying its group's summaries, ordered by contamination
# score (ascending) and then contig length (descending).
group_assign = (_group_assign.to_frame(name='group')
                             .join(group_cvrg, on='group')
                             .assign(bin_id=contig_meta.bin_id,
                                     length=contig_meta.length)
                             .sort_values(['contamination_score', 'length'],
                                          ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, colored by site; x is the contig's position in the
# sorted order of group_assign.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Positional extent (left/right/middle) of each group along the x axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 80
# Annotation heights do not depend on the group, so compute them once.
_label_y = cvrg_norm.min().min() - 1e-1
_length_y = cvrg_norm.max().max() * 0.5
for inx, d in group_order.iterrows():
    _score = group_cvrg.loc[inx].contamination_score
    if _score > contam_threshold or pd.isna(_score):
        # Score above the threshold or undefined: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, _length_y),
                ha='center', rotation=-90)

ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Groups that pass the contamination filter. Every other group (score at or above the
# threshold, or NaN) is pooled into 'contam', so no group is left out of both tables.
_included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
_is_included = group_assign.group.isin(_included_groups)

# Total contig length per bin, split by included group ...
a = (group_assign[_is_included]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
# ... and pooled over all excluded groups.
b = (group_assign[~_is_included]
         .groupby('bin_id').length.sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Non-null count of the first contig column within each site/treatment combination.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site/treatment (columns).
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby(_design).mean()
          .loc[:, _retained_groups]
          .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Draw one faint line per extraction over the contigs whose group passed the
# contamination filter, remembering which extraction each line belongs to.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)

# Original line styling, restored after a line has been highlighted.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is computed on the re-ordered frame itself rather
# than relying on pandas realigning a mask that is indexed in a different order.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.min()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return every line artist so the animation can redraw them."""
    return artists

def _animate(i):
    """Highlight line *i*, restore line *i - 1*, and label the highlighted extraction.

    For ``i == 0`` the restored line is ``artists[-1]``, i.e. the last one drawn.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per line actually drawn, so the frame index can never run past `artists`
# (extractions without a site label are dropped by the groupby above).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Colour stops for the heatmap colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # not referenced in this cell; the colormap below uses N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Median normalized coverage per retained group (columns) in each extraction (rows).
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, _retained_groups]
              .drop(columns=[]))

vmin = 0
vmax = 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

sns.clustermap(d, vmin=vmin, vmax=vmax,
               col_cluster=True, robust=True, cmap=cmap, norm=norm, metric='cosine')


#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Contigs assigned to the curated groups, one ID per line.
_mag_contigs = group_assign[group_assign.group.isin([22, 4])].index
with open('chkpt/core.a.mags/B6.g.contigs.list', 'w') as handle:
    for contig_id in _mag_contigs:
        print(contig_id, file=handle)

# Libraries of the trusted extractions, minus the extractions excluded by hand.
extraction_ids = set(trusted_extractions) - {'EXT-0405', 'EXT-0238', 'EXT-0045'}
_mag_libraries = library[library.extraction_id.isin(extraction_ids)].index
with open('chkpt/core.a.mags/B6.g.library.list', 'w') as handle:
    for library_id in _mag_libraries:
        print(library_id, file=handle)

# Per library: summed un-normalized coverage of the trusted contigs in its extraction.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(1)
_trusted_depth.to_csv('chkpt/core.a.mags/B6.g.trusted_depth.tsv', sep='\t', header=False)

### B7 / OTU-? / Otu0017

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/5
otus = ['Otu0017_1']

# Keep every bin whose contribution to a pooled OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _otu_contrib = bin_otu_contrib[otu]
    max_contrib = _otu_contrib.max()
    bins.update(_otu_contrib[_otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig IDs as single-quoted SQL string literals (embedded quotes doubled).
# Double-quoted tokens are identifiers in SQL and are only read as strings through an
# SQLite compatibility fallback. Sorting keeps the query text stable between runs.
contig_ids_sql = "'" + "', '".join(contig_id.replace("'", "''")
                                   for contig_id in sorted(contig_ids)) + "'"

# Coverage summed over the libraries of each extraction: extractions x contigs,
# with 0 where an extraction has no row for a contig.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined with its sample, mouse, and total mapping count.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Restrict (and order) the coverage table to the extractions present in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the pooled OTUs at each site.
_rabund_named = rabund.rename(columns=phylotype.name)
_rabund_named.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# The ten longest contigs of the chosen bin.
(contig_meta
     .loc[contig_meta.bin_id == 'bin00280']
     .sort_values('length', ascending=False)
     .head(10))

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '1819450'
compare = '988141'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. compare coverage per extraction, coloured by the pooled OTU relative
# abundance; rows are sorted so the most abundant extractions are drawn last.
_seed_check = (d_rabund[otus].sum(1)
                             .to_frame(name='rabund')
                             .join(cvrg)
                             .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_seed_check, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)  # identity line
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Scale each extraction by its mean trusted-contig coverage, then keep only the
# trusted extractions; both versions are histogrammed for comparison.
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]

_log_bins = np.linspace(-4, 6)
for _label, _frame in [('with_untrusted', _cvrg_norm_all),
                       ('without_untrusted', cvrg_norm)]:
    _ = plt.hist(np.log(_frame.mean()), bins=_log_bins, label=_label, alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# The mixture model is fit on square-root-transformed normalized coverage, with one
# sample per contig (columns of cvrg_norm) and one feature per extraction (rows).
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=50,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage summaries; the column-wise grouping is built once and reused.
_cvrg_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean_cvrg = _cvrg_by_group.mean()
group_cvrg = _group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = _cvrg_by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Spread of the group mean across extractions, relative to the mean within-group
# spread, scaled by the log of the group's total length.
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, ordered by its group's contamination score, then longest first.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction across the ordered contigs, coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every group along the x-axis: first, last and mean contig index.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 20
for inx, d in group_order.iterrows():
    # A group is kept only when its score is strictly below the threshold, the same
    # `contamination_score < contam_threshold` test the downstream cells filter on.
    # NaN scores fail that comparison, so they are marked as contaminated as well.
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# Set the scale once, after the loop, so it is applied even when every group is
# excluded (inside the loop it was skipped by each `continue`).
# NOTE(review): `linthreshy` is the pre-3.3 Matplotlib spelling of `linthresh` --
# confirm against the installed Matplotlib version before changing it.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Groups that pass the contamination filter. Every other group (score at or above the
# threshold, or NaN) is pooled into 'contam', so no group is left out of both tables.
_included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
_is_included = group_assign.group.isin(_included_groups)

# Total contig length per bin, split by included group ...
a = (group_assign[_is_included]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
# ... and pooled over all excluded groups.
b = (group_assign[~_is_included]
         .groupby('bin_id').length.sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Non-null count of the first contig column within each site/treatment combination.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site/treatment (columns).
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby(_design).mean()
          .loc[:, _retained_groups]
          .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Draw one faint line per extraction over the contigs whose group passed the
# contamination filter, remembering which extraction each line belongs to.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)

# Original line styling, restored after a line has been highlighted.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is computed on the re-ordered frame itself rather
# than relying on pandas realigning a mask that is indexed in a different order.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.min()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return every line artist so the animation can redraw them."""
    return artists

def _animate(i):
    """Highlight line *i*, restore line *i - 1*, and label the highlighted extraction.

    For ``i == 0`` the restored line is ``artists[-1]``, i.e. the last one drawn.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per line actually drawn, so the frame index can never run past `artists`
# (extractions without a site label are dropped by the groupby above).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Colour stops for the heatmap colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # not referenced in this cell; the colormap below uses N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Median normalized coverage per retained group (columns) in each extraction (rows).
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, group_cvrg[lambda x: x.contamination_score < contam_threshold].index]
     .drop(columns=[46, 41, 10])
    )

vmin, vmax, cmap, norm = 0, 20, cm, SymLogNorm(linthresh=1, linscale=2)

# extraction_meta is indexed differently from d, so align the site labels to d's rows
# before building masks instead of relying on pandas to realign a boolean Series.
_site = extraction_meta.site.reindex(d.index)

# sns.clustermap(d[_site == 'UM'], vmin=vmin, vmax=vmax,
#                col_cluster=False, robust=True, cmap=cmap, norm=norm)
sns.clustermap(d[_site == 'UT'], vmin=vmin, vmax=vmax,
               col_cluster=False, robust=True, cmap=cmap, norm=norm, metric='cosine')
d[_site == 'UM']

#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# UT Strain
# Contigs assigned to the curated groups, one ID per line.
with open('chkpt/core.a.mags/B7.g.contigs.list', 'w') as handle:
    for contig_id in group_assign[lambda x: x.group.isin([36, 25])].index:
        print(contig_id, file=handle)

# Libraries of the UT extractions kept in cvrg_norm. The site labels are aligned to
# cvrg_norm's rows first, because extraction_meta is indexed differently and a raw
# boolean Series would otherwise have to be realigned implicitly by pandas.
_is_ut = extraction_meta.site.reindex(cvrg_norm.index) == 'UT'
with open('chkpt/core.a.mags/B7.g.library.list', 'w') as handle:
    extraction_ids = set(cvrg_norm[_is_ut].index)
    extraction_ids -= set([])  # extractions to exclude by hand (none for this MAG)
    for library_id in library[library.extraction_id.isin(extraction_ids)].index:
        print(library_id, file=handle)

# Per library: summed un-normalized coverage of the trusted contigs in its extraction.
library.join(cvrg, on='extraction_id')[trusted_contigs].sum(1).to_csv('chkpt/core.a.mags/B7.g.trusted_depth.tsv', sep='\t', header=False)

### B8 / OTU-? / Otu0013

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/2
otus = ['Otu0013_1']

# Keep every bin whose contribution to a pooled OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _otu_contrib = bin_otu_contrib[otu]
    max_contrib = _otu_contrib.max()
    bins.update(_otu_contrib[_otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig IDs as single-quoted SQL string literals (embedded quotes doubled).
# Double-quoted tokens are identifiers in SQL and are only read as strings through an
# SQLite compatibility fallback. Sorting keeps the query text stable between runs.
contig_ids_sql = "'" + "', '".join(contig_id.replace("'", "''")
                                   for contig_id in sorted(contig_ids)) + "'"

# Coverage summed over the libraries of each extraction: extractions x contigs,
# with 0 where an extraction has no row for a contig.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined with its sample, mouse, and total mapping count.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Restrict (and order) the coverage table to the extractions present in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the pooled OTUs at each site.
_rabund_named = rabund.rename(columns=phylotype.name)
_rabund_named.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# The ten longest contigs of the chosen bin.
(contig_meta
     .loc[contig_meta.bin_id == 'bin00599']
     .sort_values('length', ascending=False)
     .head(10))

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '766897'
compare = '2518556'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. compare coverage per extraction, coloured by the pooled OTU relative
# abundance; rows are sorted so the most abundant extractions are drawn last.
_seed_check = (d_rabund[otus].sum(1)
                             .to_frame(name='rabund')
                             .join(cvrg)
                             .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_seed_check, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)  # identity line
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Scale each extraction by its mean trusted-contig coverage, then keep only the
# trusted extractions; both versions are histogrammed for comparison.
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]

_log_bins = np.linspace(-4, 6)
for _label, _frame in [('with_untrusted', _cvrg_norm_all),
                       ('without_untrusted', cvrg_norm)]:
    _ = plt.hist(np.log(_frame.mean()), bins=_log_bins, label=_label, alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# The mixture model is fit on square-root-transformed normalized coverage, with one
# sample per contig (columns of cvrg_norm) and one feature per extraction (rows).
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage summaries; the column-wise grouping is built once and reused.
_cvrg_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean_cvrg = _cvrg_by_group.mean()
group_cvrg = _group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = _cvrg_by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Spread of the group mean across extractions, relative to the mean within-group
# spread, scaled by the log of the group's total length.
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, ordered by its group's contamination score, then longest first.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction across the ordered contigs, coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every group along the x-axis: first, last and mean contig index.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 20
for inx, d in group_order.iterrows():
    # A group is kept only when its score is strictly below the threshold, the same
    # `contamination_score < contam_threshold` test the downstream cells filter on.
    # NaN scores fail that comparison, so they are marked as contaminated as well.
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# Set the scale once, after the loop, so it is applied even when every group is
# excluded (inside the loop it was skipped by each `continue`).
# NOTE(review): `linthreshy` is the pre-3.3 Matplotlib spelling of `linthresh` --
# confirm against the installed Matplotlib version before changing it.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Groups that pass the contamination filter. Every other group (score at or above the
# threshold, or NaN) is pooled into 'contam', so no group is left out of both tables.
_included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
_is_included = group_assign.group.isin(_included_groups)

# Total contig length per bin, split by included group ...
a = (group_assign[_is_included]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
# ... and pooled over all excluded groups.
b = (group_assign[~_is_included]
         .groupby('bin_id').length.sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Non-null count of the first contig column within each site/treatment combination.
_design = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(_design).count().iloc[:, 0])

# Mean normalized coverage of each retained group (rows) per site/treatment (columns).
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby(_design).mean()
          .loc[:, _retained_groups]
          .T)

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Draw one faint line per extraction over the contigs whose group passed the
# contamination filter, remembering which extraction each line belongs to.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)

# Original line styling, restored after a line has been highlighted.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is computed on the re-ordered frame itself rather
# than relying on pandas realigning a mask that is indexed in a different order.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.min()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return every line artist so the animation can redraw them."""
    return artists

def _animate(i):
    """Highlight line *i*, restore line *i - 1*, and label the highlighted extraction.

    For ``i == 0`` the restored line is ``artists[-1]``, i.e. the last one drawn.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per line actually drawn, so the frame index can never run past `artists`
# (extractions without a site label are dropped by the groupby above).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm

# Colour stops for the heatmap colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
n_bins = [3, 6, 10, 100]  # not referenced in this cell; the colormap below uses N=100
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Median normalized coverage per retained group (columns) in each extraction (rows).
_retained_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, _retained_groups]
              .drop(columns=[]))

vmin = 0
vmax = 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

# sns.clustermap(d[extraction_meta.site == 'UM'], vmin=vmin, vmax=vmax,
#                col_cluster=False, robust=True, cmap=cmap, norm=norm)
sns.clustermap(d, vmin=vmin, vmax=vmax,
               col_cluster=False, robust=True, cmap=cmap, norm=norm, metric='cosine')

#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Contigs assigned to the curated groups, one ID per line.
_mag_contigs = group_assign[group_assign.group.isin([1])].index
with open('chkpt/core.a.mags/B8.g.contigs.list', 'w') as handle:
    for contig_id in _mag_contigs:
        print(contig_id, file=handle)

# Libraries of the trusted extractions; the subtracted set lists extractions to
# exclude by hand (none for this MAG).
extraction_ids = set(trusted_extractions) - set([])
_mag_libraries = library[library.extraction_id.isin(extraction_ids)].index
with open('chkpt/core.a.mags/B8.g.library.list', 'w') as handle:
    for library_id in _mag_libraries:
        print(library_id, file=handle)

# Per library: summed un-normalized coverage of the trusted contigs in its extraction.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(1)
_trusted_depth.to_csv('chkpt/core.a.mags/B8.g.trusted_depth.tsv', sep='\t', header=False)

### B9 / OTU-? / Otu0105

In [None]:
# INSTRUCTION: Select the closely related OTUs to pool together.
# INSTRUCTION: Pick a *keep_thresh_factor* to choose the bins associated with these OTUs that
# will be considered below.

keep_thresh_factor = 1/2
otus = ['Otu0105_1']

# Keep every bin whose contribution to a pooled OTU exceeds keep_thresh_factor times
# the largest contribution any bin makes to that OTU.
bins = set()
for otu in otus:
    _otu_contrib = bin_otu_contrib[otu]
    max_contrib = _otu_contrib.max()
    bins.update(_otu_contrib[_otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

In [None]:
# Render the contig IDs as single-quoted SQL string literals (embedded quotes doubled).
# Double-quoted tokens are identifiers in SQL and are only read as strings through an
# SQLite compatibility fallback. Sorting keeps the query text stable between runs.
contig_ids_sql = "'" + "', '".join(contig_id.replace("'", "''")
                                   for contig_id in sorted(contig_ids)) + "'"

# Coverage summed over the libraries of each extraction: extractions x contigs,
# with 0 where an extraction has no row for a contig.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined with its sample, mouse, and total mapping count.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(mapping_count) AS mapping_count
      FROM library_total_nucleotides_mapping
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin membership and contig attributes for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Restrict (and order) the coverage table to the extractions present in d_rabund.
cvrg = cvrg.loc[d_rabund.index]

In [None]:
# Mean relative abundance of the pooled OTUs at each site.
_rabund_named = rabund.rename(columns=phylotype.name)
_rabund_named.groupby(extraction_meta.site)[otus].mean()

In [None]:
# INSTRUCTION: Select a *bin_id* to help picking a seed contig.

# The ten longest contigs of the chosen bin.
(contig_meta
     .loc[contig_meta.bin_id == 'bin00573']
     .sort_values('length', ascending=False)
     .head(10))

#### Normalization

In [None]:
# INSTRUCTION: Select a *seed* contig whose coverage seems to best reflect the relative
# abundance of the OTUs of interest and find a second contig with matching coverage
# to *compare* as a double-check.
# INSTRUCTION: Select *trusted_contigs* threshold that selects the contigs used to normalize coverage.
# INSTRUCTION: Select *trusted_extractions* threshold that selects the extractions to be considered below.

seed = '588300'
compare = '1679634'
assert seed in contig_ids
assert compare in contig_ids

# Seed vs. compare coverage per extraction, coloured by the pooled OTU relative
# abundance; rows are sorted so the most abundant extractions are drawn last.
_seed_check = (d_rabund[otus].sum(1)
                             .to_frame(name='rabund')
                             .join(cvrg)
                             .sort_values('rabund'))
plt.scatter(seed, compare, c='rabund', data=_seed_check, cmap='coolwarm')
plt.plot([0, 1e3], [0, 1e3], c='k', lw=1, scalex=False, scaley=False)  # identity line
plt.yscale('symlog')
plt.xscale('symlog')

contig_thresh = 0.99
extract_thresh = 0.5

# Trusted contigs: Pearson correlation with the seed's coverage above contig_thresh.
_seed_corr = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])
trusted_contigs = _seed_corr[_seed_corr > contig_thresh].index

# Trusted extractions: mean / std of trusted-contig coverage above extract_thresh.
_trusted_cvrg = cvrg[trusted_contigs]
_mean_over_std = _trusted_cvrg.mean(1) / _trusted_cvrg.std(1)
trusted_extractions = _mean_over_std[_mean_over_std > extract_thresh].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Scale each extraction by its mean trusted-contig coverage, then keep only the
# trusted extractions; both versions are histogrammed for comparison.
_cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
cvrg_norm = _cvrg_norm_all.loc[trusted_extractions]

_log_bins = np.linspace(-4, 6)
for _label, _frame in [('with_untrusted', _cvrg_norm_all),
                       ('without_untrusted', cvrg_norm)]:
    _ = plt.hist(np.log(_frame.mean()), bins=_log_bins, label=_label, alpha=0.8)
plt.legend()

#### Clustering

In [None]:
# INSTRUCTION: Select the *n_components* (usually between 10 and 100) to best resolve the contigs into groups.
# INSTRUCTION: Pick a *contam_threshold* that excludes as many groups as possible without excluding any that
# should be included.

# The mixture model is fit on square-root-transformed normalized coverage, with one
# sample per contig (columns of cvrg_norm) and one feature per extraction (rows).
cluster_data = np.sqrt(cvrg_norm)

bgm = BayesianGaussianMixture(n_components=30,
                              covariance_type='diag',
#                              weight_concentration_prior_type='dirichlet_distribution',
#                              weight_concentration_prior=10,
                              random_state=1,
                             ).fit(cluster_data.T)
_group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage summaries; the column-wise grouping is built once and reused.
_cvrg_by_group = cvrg_norm.groupby(_group_assign, axis='columns')
_group_mean_cvrg = _cvrg_by_group.mean()
group_cvrg = _group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = _cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = _group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = _cvrg_by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(_group_assign).length.sum()
# Spread of the group mean across extractions, relative to the mean within-group
# spread, scaled by the log of the group's total length.
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# One row per contig, ordered by its group's contamination score, then longest first.
group_assign = _group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction across the ordered contigs, coloured by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Position of every group along the x-axis: first, last and mean contig index.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 50
for inx, d in group_order.iterrows():
    # A group is kept only when its score is strictly below the threshold, the same
    # `contamination_score < contam_threshold` test the downstream cells filter on.
    # NaN scores fail that comparison, so they are marked as contaminated as well.
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, cvrg_norm.min().min()-1e-1), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, cvrg_norm.max().max() * 0.5),
                ha='center', rotation=-90)

# Set the scale once, after the loop, so it is applied even when every group is
# excluded (inside the loop it was skipped by each `continue`).
# NOTE(review): `linthreshy` is the pre-3.3 Matplotlib spelling of `linthresh` --
# confirm against the installed Matplotlib version before changing it.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Total contig length per bin, split by included group, plus a single 'contam'
# column holding the length of everything else.
# "Included" means a contamination score strictly below the threshold; NaN
# scores fail the comparison.  'contam' is defined as the exact complement, so a
# group scoring exactly `contam_threshold` (previously matched by neither `<`
# nor `>`) is no longer silently dropped from the table.
_included_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
_is_included = group_assign.group.isin(_included_groups)

a = (group_assign[_is_included]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign[~_is_included]
         .groupby('bin_id').length.sum()
         .rename('contam'))
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Rows of cvrg_norm are grouped by (site, treatment) for both summaries below.
_design = [extraction_meta.site, extraction_meta.treatment]

# Number of rows in each (site, treatment) cell, taken from the first column's counts.
print(cvrg_norm.groupby(_design).count().iloc[:,0])

# Mean coverage of every group scoring below the threshold (rows) in each
# (site, treatment) cell (columns).
_included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
_group_mean = cvrg_norm.groupby(group_assign.group, axis='columns').mean()
_group_mean.groupby(_design).mean().loc[:, _included_groups].T

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

ax.axhline(y=1, color='k', linestyle='--')

# Columns to draw: those whose group scores below the contamination threshold,
# in group_assign order.  The selection is the same for every site, so it is
# computed once.
_included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index

# One line per row of cvrg_norm, coloured by site.  `artists` and
# `plotting_order` are filled in lockstep so artists[k] is the trace of the row
# labelled plotting_order[k].
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[_included_contigs].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)

# Remember each trace's resting appearance so the animation can restore it.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed d_rabund over `otus` for each plotted row, in plotting order.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter AFTER reordering, with the mask evaluated on the reordered frame.
# Indexing the reordered frame with a mask built from the original-order
# group_cvrg makes pandas re-align the mask and emit a UserWarning.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Label positions are the same for every group, so compute them once.
_label_y = group_cvrg_included.group_max_coverage.min()
_length_y = group_cvrg_included.group_max_coverage.max() * 0.5
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, _label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, _length_y),
                ha='center', rotation=-90)

# Text box that the animation rewrites on every frame.
annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE(review): `linthreshy` was renamed to `linthresh` in Matplotlib 3.3 --
# update this keyword when the environment is upgraded.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Init callback handed to FuncAnimation as `init_func`; returns the full list of line artists."""
    return artists

def _animate(i):
    """Highlight trace *i* and restore the trace highlighted on the previous frame.

    Parameters
    ----------
    i : int
        Frame number, used as a position into `artists`, `plotting_order`
        and `otu_rabund`.

    Returns
    -------
    list
        The artists touched on this frame (current trace, previous trace and
        the annotation).
    """
    # On frame 0 this is -1, i.e. the last trace, so the highlight left over
    # from the end of the previous loop is cleared.
    j = i - 1
    previous = artists[j]
    current = artists[i]

    # Restore first, highlight second: when `previous` and `current` are the
    # same artist (single-trace animation) the highlight must win.
    previous.set_linewidth(original_lw[previous])
    previous.set_alpha(original_alpha[previous])
    previous.set_zorder(original_zorder[previous])

    current.set_linewidth(1)
    current.set_alpha(0.9)
    current.set_zorder(999)

    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [current, previous, annot]

# One frame per plotted trace.  The frame count comes from `artists` rather than
# cvrg_norm.shape[0]: rows of cvrg_norm without a site are dropped by the
# groupby that builds `artists`, and a larger frame count would make
# `_animate` index past the end of the list.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import SymLogNorm

# Colour stops (RGB tuples) for the heatmap ramp: white -> dark green -> brighter green.
cmap_name = 'custom1'
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
# Candidate discretization levels; the colormap below is built with N=100.
n_bins = [3, 6, 10, 100]
cm = LinearSegmentedColormap.from_list(name=cmap_name, colors=colors, N=100)

In [None]:
# INSTRUCTION: Pick groups you want to drop from visualization.

# Per-row median coverage of each group scoring below the threshold; list any
# groups to hide in the `drop(columns=[...])` call.
_included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
_group_median = cvrg_norm.groupby(group_assign.group, axis='columns').median()
d = _group_median.loc[:, _included_groups].drop(columns=[])

# Colour scaling shared by the heatmap call below.
vmin = 0
vmax = 20
cmap = cm
norm = SymLogNorm(linthresh=1, linscale=2)

# To restrict the heatmap to a single site, pass e.g.
# d[extraction_meta.site == 'UM'] in place of d.
sns.clustermap(d,
               metric='cosine', col_cluster=False, robust=True,
               cmap=cmap, norm=norm, vmin=vmin, vmax=vmax)

#### Curation

In [None]:
# INSTRUCTION: Pick curated contig groups and extractions for each MAG.

# Contig IDs belonging to the curated groups, one per line.
with open('chkpt/core.a.mags/B9.g.contigs.list', 'w') as handle:
    curated = group_assign[lambda x: x.group.isin([1, 4])]
    for contig_id in curated.index:
        handle.write('{}\n'.format(contig_id))

# Library IDs of the trusted extractions (minus any listed for exclusion), one per line.
with open('chkpt/core.a.mags/B9.g.library.list', 'w') as handle:
    extraction_ids = set(trusted_extractions).difference(set([]))
    selected = library[library.extraction_id.isin(extraction_ids)]
    for library_id in selected.index:
        handle.write('{}\n'.format(library_id))

# Per-library sum over the trusted contigs, written as a headerless TSV.
_trusted_depth = library.join(cvrg, on='extraction_id')[trusted_contigs].sum(axis=1)
_trusted_depth.to_csv('chkpt/core.a.mags/B9.g.trusted_depth.tsv', sep='\t', header=False)