In [None]:
import sqlite3
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats  # loads the submodule so `sp.stats` is always available
import seaborn as sns
from IPython.display import HTML
from matplotlib import animation
from sklearn.cross_decomposition import PLSCanonical
from sklearn.mixture import BayesianGaussianMixture

## Constants

In [None]:
# Plot colors keyed by metadata level; the plots below look up
# `extraction_meta.site` values in this mapping.
color_map = dict(
    acarbose='goldenrod',
    control='darkblue',
    UM='darkblue',
    UT='darkgreen',
    male='blue',
    female='magenta',
    C2013='blue',
    Glenn='red',
)

## Load Data

In [None]:
# Select the data

con = sqlite3.connect('res/core.1.denorm.db')

# Relative abundance: 16S read tallies per (extraction, sequence), with each
# row scaled to sum to 1.
rrs_count = (pd.read_sql('SELECT * FROM rrs_taxon_count;',
                         con=con, index_col=['extraction_id', 'sequence_id'])
               .tally.unstack().fillna(0).astype(int))
rabund = rrs_count.apply(lambda x: x / x.sum(), axis=1)

# Coverage: bin coverage summed over each extraction's libraries, with each
# row scaled to sum to 1.
bin_cvrg = (pd.read_sql("""
SELECT bin_id, extraction_id, SUM(coverage) AS coverage
FROM bin_coverage
JOIN library USING (library_id)
GROUP BY bin_id, extraction_id;""",
                        con=con, index_col=['extraction_id', 'bin_id'])
              .coverage.unstack().fillna(0).apply(lambda x: x / x.sum(), axis=1))

# Only keep shared extractions. `.loc` needs a list-like, not a set: pandas
# >= 2.0 raises TypeError for set indexers. Sorting also fixes the row order,
# which would otherwise follow arbitrary set iteration order.
extractions = set(rabund.index) & set(bin_cvrg.index)
rabund = rabund.loc[sorted(extractions)]
bin_cvrg = bin_cvrg.loc[sorted(extractions)]

# Phylotypes: name each sequence '<otu_id>_<rank>', where rank 1 is the
# sequence with the highest mean relative abundance within its OTU.
# NOTE(review): GROUP BY sequence_id keeps a single otu_id per sequence; this
# assumes each sequence belongs to exactly one OTU.
phylotype = pd.read_sql('SELECT sequence_id, otu_id FROM rrs_taxon_count GROUP BY sequence_id;',
                       con=con, index_col='sequence_id')
name_map = {}
for otu, d in (pd.DataFrame({'mean_rabund': rabund.mean(),
                             'otu_id': phylotype.otu_id})
                 .sort_values('mean_rabund',
                              ascending=False)
                 .groupby('otu_id')):
    for i, sequence_id in enumerate(d.index, start=1):
        name_map[sequence_id] = '{}_{}'.format(otu, i)
phylotype['name'] = pd.Series(name_map)
phylotype['mean_rabund'] = rabund.mean()

# Contig -> bin assignments (indexed by contig_id).
contig_bin = pd.read_sql("SELECT * FROM contig_bin", con=con, index_col='contig_id')

## OTU Details

In [None]:
# Taxonomic ranks for every sequence, re-indexed by phylotype name ('<otu>_<rank>').
taxonomy = (pd.read_sql('SELECT sequence_id, phylum_, class_, order_, family_, genus_ FROM taxonomy;',
                        con=con,
                        index_col='sequence_id')
              .rename(index=phylotype.name))

In [None]:
# Select abundant taxa and bins. Everything below threshold is pooled into a
# single 'other' column.
# Minimum mean relative abundance for a taxon, and minimum mean coverage
# fraction for a bin, to be kept as its own column.
MIN_MEAN_RABUND = 0.0001
MIN_MEAN_CVRG = 0.0001

major_taxa = phylotype.index[phylotype.mean_rabund > MIN_MEAN_RABUND]
major_bins = bin_cvrg.columns[bin_cvrg.mean() > MIN_MEAN_CVRG]

d_rabund = rabund[major_taxa].copy()
d_rabund['other'] = rabund.drop(columns=major_taxa).sum(1)
# Rename columns from sequence IDs to phylotype names.
d_rabund = d_rabund.rename(columns=phylotype.name)

d_cvrg = bin_cvrg[major_bins].copy()
d_cvrg['other'] = bin_cvrg.drop(columns=major_bins).sum(1)

d_rabund.shape, d_cvrg.shape

In [None]:
# Spot-check two OTU representatives: mean relative abundance plus taxonomy.
# The rows come back in list order, so no sort is needed.
(d_rabund.mean()
         .loc[['Otu0058_1', 'Otu0041_1']]
         .to_frame(name='mean_rabund')
         .join(taxonomy))

In [None]:
# Every d_rabund column (including 'other') with its taxonomy, most abundant first.
(d_rabund.mean()
         .to_frame(name='mean_rabund')
         .join(taxonomy)
         .sort_values(by='mean_rabund', ascending=False))

## Metabinning

In [None]:
# Canonical PLS between sqrt-transformed bin coverage (X) and sqrt-transformed
# 16S relative abundance (Y). The product of the X and Y loadings gives a
# bins x taxa matrix.
pls_fit = PLSCanonical(n_components=40, scale=False)
pls_fit.fit(np.sqrt(d_cvrg), np.sqrt(d_rabund))
bin_otu_contrib = (pd.DataFrame(np.dot(pls_fit.x_loadings_, pls_fit.y_loadings_.T),
                                index=d_cvrg.columns,
                                columns=d_rabund.columns)
                   .rename(columns=phylotype.name))

In [None]:
# Taxa of interest: those whose 95th-percentile relative abundance exceeds 1%
# (the pooled 'other' column is excluded).
tax_filter = lambda x: x.quantile(0.95) > 0.01

taxa_of_interest = sorted(d_rabund.loc[:, tax_filter].rename(columns=phylotype.name).columns)
if 'other' in taxa_of_interest:
    taxa_of_interest.remove('other')
print(len(taxa_of_interest))

# A bin is a "hit" for a taxon when its contribution exceeds `factor` times
# that taxon's largest bin contribution.
factor = 1/3

_hits = {}
for tax in taxa_of_interest:
    top_score = bin_otu_contrib[tax].max()
    print(tax, top_score)
    _hits[tax] = list((bin_otu_contrib[tax].sort_values(ascending=False) > top_score * factor)[lambda x: x].index)

print()
for tax in _hits:
    print(tax, _hits[tax])

all_hits = set(chain(*_hits.values()))

# `.loc` needs a list-like (pandas >= 2.0 rejects sets); sorting gives a
# deterministic input row order before clustering.
a = sns.clustermap(bin_otu_contrib.loc[sorted(all_hits), taxa_of_interest].rename(columns=phylotype.name), robust=True,
                   figsize=(14, 18), col_cluster=False, cmap='coolwarm', center=0)

# Use the grid's named heatmap axes instead of relying on the figure's axes order.
ax = a.ax_heatmap
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

print()

## Low-diversity OTUs (sanity check)

### OTU-3

In [None]:
# Candidate bins for the focal OTU(s): any bin whose contribution exceeds
# `keep_thresh_factor` times that OTU's largest bin contribution.
keep_thresh_factor = 1/3
otus = ['Otu0003_1']
bins = set()
for otu in otus:
    contrib = bin_otu_contrib[otu]
    bins.update(contrib.index[contrib > contrib.max() * keep_thresh_factor])

print(bins)
# All contigs assigned to any candidate bin.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# Build the SQL IN-list for the selected contigs. Each ID is written as a
# single-quoted string literal with embedded quotes doubled. Double quotes
# denote identifiers in SQL; SQLite only reads "..." as a string through a
# legacy fallback, which can be compiled out and misfires if an ID matches a
# column name.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig per extraction, summed over libraries
# (extractions x contigs; missing combinations filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig attributes (e.g. `length`) for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Normalize by each extraction's total coverage and align rows with d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# Seed contig (and one to compare against): scatter their normalized coverage,
# colored by the summed relative abundance of `otus`.
seed, compare = 'core-k161_1062998', 'core-k161_1089326'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg * 1e5,
            c=d_rabund.loc[cvrg.index, otus].sum(axis=1), cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
# Trusted extractions: mean / std of trusted-contig coverage above 0.5.
trusted_contigs = (cvrg
                   .apply(lambda col: sp.stats.pearsonr(cvrg[seed], col)[0])
                   .loc[lambda r: r > 0.99]
                   .index)
trusted_extractions = (cvrg[trusted_contigs]
                       .pipe(lambda t: t.mean(axis=1) / t.std(axis=1))
                       .loc[lambda ratio: ratio > 0.5]
                       .index)

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Divide each extraction by the mean coverage of its trusted contigs, then
# compare the distribution of log mean normalized coverage per contig before
# and after dropping untrusted extractions.
hist_bins = np.linspace(-4, 6)
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='with_untrusted')
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='without_untrusted')
plt.legend()

In [None]:
# Fit Bayesian GMMs with 1-39 components to the contigs' sqrt-transformed
# coverage profiles (contigs are the samples, hence `.T`) and plot each
# model's `score` on the same data.
# `n_components` is passed by keyword: it is keyword-only in current
# scikit-learn, so a positional value raises TypeError.
cluster_data = np.sqrt(cvrg_norm)

nn = range(1, 40)
scores = []
for n in nn:
    score = BayesianGaussianMixture(n_components=n,
                                    covariance_type='diag',
                                    random_state=1,
                                    ).fit(cluster_data.T).score(cluster_data.T)
    scores.append(score)
plt.plot(nn, scores)

In [None]:
# Cluster contigs on their sqrt-transformed, trusted-normalized coverage
# profiles and summarize each cluster ("group").
# `n_components` is passed by keyword (keyword-only in current scikit-learn).
bgm = BayesianGaussianMixture(n_components=5,
                              covariance_type='diag',
                              random_state=1,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-extraction statistics over each group's contigs (extractions x groups).
# Grouping the transposed frame replaces `groupby(..., axis='columns')`,
# which pandas deprecated in 2.1.
contigs_by_group = cvrg_norm.T.groupby(group_assign)
group_mean = contigs_by_group.mean().T
group_std = contigs_by_group.std().T
group_max = contigs_by_group.max().T

group_cvrg = group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = group_std.mean()
group_cvrg['group_std_mean_coverage'] = group_mean.std()
group_cvrg['group_max_coverage'] = group_max.max()
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
# Across-extraction spread of the group mean relative to the average
# within-group spread, scaled by the log of the group's total length.
group_cvrg['contamination_score'] = (group_cvrg.group_std_mean_coverage
                                     / group_cvrg.group_mean_std_coverage
                                     * np.log(group_cvrg.total_length))
group_cvrg.index.name = 'group'

# Per-contig table: group, group statistics, bin and length; sorted by
# contamination score, then by length (descending).
group_assign = group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, drawn across contigs in `group_assign` order and colored by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

# Mean/first/last x-position of each group's contigs in the sorted table.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(middle='mean', left='min', right='max')
               .sort_values('left'))

# Group boundaries: thin red for contaminated or unscorable (NaN) groups,
# dashed black with id/total-length labels otherwise.
contam_threshold = 100
ymax = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    if pd.isna(score) or score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, ymax * 0.5),
                ha='center', rotation=-90)

# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)

In [None]:
# Contig length per bin: one column per clean group plus a 'contam' column for
# everything else. The two sides use complementary masks, so every contig is
# counted exactly once. Groups scoring >= contam_threshold, or with a NaN
# score, count as contamination; previously a score exactly equal to the
# threshold was dropped from both sides.
clean_groups = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
in_clean_group = group_assign.group.isin(clean_groups)
a = (group_assign[in_clean_group]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = group_assign[~in_clean_group].groupby('bin_id').length.sum()
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of trusted extractions per (site, treatment).
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:,0])

# Mean normalized coverage of each clean contig group, averaged within
# (site, treatment). The transposed groupby replaces the deprecated
# `groupby(..., axis='columns')` (pandas 2.1).
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, group_cvrg[lambda x: x.contamination_score < contam_threshold].index])

In [None]:
# Animated view of the clean groups: each frame highlights one extraction's
# trace and labels it with its ID and the summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    # One Line2D per extraction (the columns of the transposed block).
    des_artists = ax.plot(d0[clean_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = d_rabund.loc[plotting_order, otus].sum(axis=1).tolist()

# Boundaries and labels for clean groups only. The filter is applied via a
# callable so it evaluates on the reordered frame, with no reindexing warning.
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
ymax = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, ymax * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)
fig.tight_layout()


# The callbacks bind this cell's objects as default arguments. Reading the
# module-level names at draw time would pick up whatever a later cell
# rebinds them to.
def _init(artists=artists):
    """Return every trace so the blitted animation starts from the full plot."""
    return artists


def _animate(i, artists=artists, annot=annot, plotting_order=plotting_order,
             otu_rabund=otu_rabund, original_lw=original_lw,
             original_alpha=original_alpha, original_zorder=original_zorder):
    """Highlight trace `i`, restore trace `i - 1`, and update the label.

    For i == 0, `i - 1` wraps to the last trace, restoring it when the
    animation repeats.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn trace. Extractions with a missing site are dropped by the
# groupby and have no line, so cvrg_norm.shape[0] could overrun `artists`.
# Display e.g. with HTML(anim.to_jshtml()).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Write the contig IDs of the chosen groups to the OTU-3 list, one per line.
with open('res/core.a.mags.d/OTU-3.contigs.list', 'w') as handle:
    selected = group_assign.index[group_assign.group.isin([0, 2, 3, 1, 4])]
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected)

### OTU-2

In [None]:
# Candidate bins for the focal OTU(s): any bin whose contribution exceeds
# `keep_thresh_factor` times that OTU's largest bin contribution.
keep_thresh_factor = 1/3
otus = ['Otu0002_1']
bins = set()
for otu in otus:
    contrib = bin_otu_contrib[otu]
    bins.update(contrib.index[contrib > contrib.max() * keep_thresh_factor])

print(bins)
# All contigs assigned to any candidate bin.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# Build the SQL IN-list for the selected contigs. Each ID is written as a
# single-quoted string literal with embedded quotes doubled. Double quotes
# denote identifiers in SQL; SQLite only reads "..." as a string through a
# legacy fallback, which can be compiled out and misfires if an ID matches a
# column name.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig per extraction, summed over libraries
# (extractions x contigs; missing combinations filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig attributes (e.g. `length`) for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Normalize by each extraction's total coverage and align rows with d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# Seed contig (and one to compare against): scatter their normalized coverage,
# colored by the summed relative abundance of `otus`.
seed, compare = 'core-k161_578034', 'core-k161_2211409'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg * 1e5,
            c=d_rabund.loc[cvrg.index, otus].sum(axis=1), cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
# Trusted extractions: mean / std of trusted-contig coverage above 0.5.
trusted_contigs = (cvrg
                   .apply(lambda col: sp.stats.pearsonr(cvrg[seed], col)[0])
                   .loc[lambda r: r > 0.99]
                   .index)
trusted_extractions = (cvrg[trusted_contigs]
                       .pipe(lambda t: t.mean(axis=1) / t.std(axis=1))
                       .loc[lambda ratio: ratio > 0.5]
                       .index)

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Divide each extraction by the mean coverage of its trusted contigs, then
# compare the distribution of log mean normalized coverage per contig before
# and after dropping untrusted extractions.
hist_bins = np.linspace(-4, 6)
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='with_untrusted')
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='without_untrusted')
plt.legend()

In [None]:
# Fit Bayesian GMMs with 1-39 components to the contigs' sqrt-transformed
# coverage profiles (contigs are the samples, hence `.T`) and plot each
# model's `score` on the same data.
# `n_components` is passed by keyword: it is keyword-only in current
# scikit-learn, so a positional value raises TypeError.
cluster_data = np.sqrt(cvrg_norm)

nn = range(1, 40)
scores = []
for n in nn:
    score = BayesianGaussianMixture(n_components=n,
                                    covariance_type='diag',
                                    random_state=1,
                                    ).fit(cluster_data.T).score(cluster_data.T)
    scores.append(score)
plt.plot(nn, scores)

In [None]:
# Cluster contigs on their sqrt-transformed, trusted-normalized coverage
# profiles and summarize each cluster ("group").
# `n_components` is passed by keyword (keyword-only in current scikit-learn).
bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
                              random_state=1,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-extraction statistics over each group's contigs (extractions x groups).
# Grouping the transposed frame replaces `groupby(..., axis='columns')`,
# which pandas deprecated in 2.1.
contigs_by_group = cvrg_norm.T.groupby(group_assign)
group_mean = contigs_by_group.mean().T
group_std = contigs_by_group.std().T
group_max = contigs_by_group.max().T

group_cvrg = group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = group_std.mean()
group_cvrg['group_std_mean_coverage'] = group_mean.std()
group_cvrg['group_max_coverage'] = group_max.max()
# Across-extraction spread of the group mean relative to the average within-group spread.
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group, group statistics, bin and length; sorted by
# contamination score, then by length (descending).
group_assign = group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, drawn across contigs in `group_assign` order and colored by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

# Mean/first/last x-position of each group's contigs in the sorted table.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(middle='mean', left='min', right='max')
               .sort_values('left'))

# Group boundaries: thin red for contaminated or unscorable (NaN) groups,
# dashed black with id/total-length labels otherwise.
contam_threshold = 2
ymax = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    if pd.isna(score) or score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, ymax * 0.5),
                ha='center', rotation=-90)

# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)

In [None]:
# Contig length per bin: one column per clean group plus a 'contam' column for
# everything else. The two sides use complementary masks, so every contig is
# counted exactly once. Groups scoring >= contam_threshold, or with a NaN
# score, count as contamination; previously a score exactly equal to the
# threshold was dropped from both sides.
clean_groups = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
in_clean_group = group_assign.group.isin(clean_groups)
a = (group_assign[in_clean_group]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = group_assign[~in_clean_group].groupby('bin_id').length.sum()
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of trusted extractions per (site, treatment).
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:,0])

# Mean normalized coverage of each clean contig group, averaged within
# (site, treatment). The transposed groupby replaces the deprecated
# `groupby(..., axis='columns')` (pandas 2.1).
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, group_cvrg[lambda x: x.contamination_score < contam_threshold].index])

In [None]:
# Animated view of the clean groups: each frame highlights one extraction's
# trace and labels it with its ID and the summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    # One Line2D per extraction (the columns of the transposed block).
    des_artists = ax.plot(d0[clean_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = d_rabund.loc[plotting_order, otus].sum(axis=1).tolist()

# Boundaries and labels for clean groups only. The filter is applied via a
# callable so it evaluates on the reordered frame, with no reindexing warning.
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
ymax = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, ymax * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)
fig.tight_layout()


# The callbacks bind this cell's objects as default arguments. Reading the
# module-level names at draw time would pick up whatever a later cell
# rebinds them to.
def _init(artists=artists):
    """Return every trace so the blitted animation starts from the full plot."""
    return artists


def _animate(i, artists=artists, annot=annot, plotting_order=plotting_order,
             otu_rabund=otu_rabund, original_lw=original_lw,
             original_alpha=original_alpha, original_zorder=original_zorder):
    """Highlight trace `i`, restore trace `i - 1`, and update the label.

    For i == 0, `i - 1` wraps to the last trace, restoring it when the
    animation repeats.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn trace. Extractions with a missing site are dropped by the
# groupby and have no line, so cvrg_norm.shape[0] could overrun `artists`.
# Display e.g. with HTML(anim.to_jshtml()).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)



In [None]:
# Write the contig IDs of the chosen group to the OTU-2 list, one per line.
with open('res/core.a.mags.d/OTU-2.contigs.list', 'w') as handle:
    selected = group_assign.index[group_assign.group.isin([0])]
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected)

## Lachnospiraceae

### OTU-15

In [None]:
# Candidate bins for the focal OTU(s): any bin whose contribution exceeds
# `keep_thresh_factor` times that OTU's largest bin contribution.
keep_thresh_factor = 1/3
otus = ['Otu0015_1']
bins = set()
for otu in otus:
    contrib = bin_otu_contrib[otu]
    bins.update(contrib.index[contrib > contrib.max() * keep_thresh_factor])

print(bins)
# All contigs assigned to any candidate bin.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# Build the SQL IN-list for the selected contigs. Each ID is written as a
# single-quoted string literal with embedded quotes doubled. Double quotes
# denote identifiers in SQL; SQLite only reads "..." as a string through a
# legacy fallback, which can be compiled out and misfires if an ID matches a
# column name.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig per extraction, summed over libraries
# (extractions x contigs; missing combinations filled with 0).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig attributes (e.g. `length`) for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Normalize by each extraction's total coverage and align rows with d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# The five longest selected contigs (shortest of the five first).
contig_meta.sort_values(by='length').tail(n=5)

In [None]:
# Seed contig (and one to compare against): scatter their normalized coverage,
# colored by the summed relative abundance of `otus`.
seed, compare = 'core-k161_1048491', 'core-k161_571666'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg * 1e4,
            c=d_rabund.loc[cvrg.index, otus].sum(axis=1), cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
# Trusted extractions: mean / std of trusted-contig coverage above 0.5.
trusted_contigs = (cvrg
                   .apply(lambda col: sp.stats.pearsonr(cvrg[seed], col)[0])
                   .loc[lambda r: r > 0.99]
                   .index)
trusted_extractions = (cvrg[trusted_contigs]
                       .pipe(lambda t: t.mean(axis=1) / t.std(axis=1))
                       .loc[lambda ratio: ratio > 0.5]
                       .index)

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))

In [None]:
# Divide each extraction by the mean coverage of its trusted contigs, then
# compare the distribution of log mean normalized coverage per contig before
# and after dropping untrusted extractions.
hist_bins = np.linspace(-4, 6)
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(axis=1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='with_untrusted')
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), bins=hist_bins, alpha=0.8, label='without_untrusted')
plt.legend()

In [None]:
# Square-root-transformed normalized coverage used as clustering features.
cluster_data = cvrg_norm.apply(np.sqrt)

In [None]:
# Cluster contigs on their sqrt-transformed, trusted-normalized coverage
# profiles and summarize each cluster ("group").
# `n_components` is passed by keyword (keyword-only in current scikit-learn).
bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
                              random_state=1,
                              n_init=2,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-extraction statistics over each group's contigs (extractions x groups).
# Grouping the transposed frame replaces `groupby(..., axis='columns')`,
# which pandas deprecated in 2.1.
contigs_by_group = cvrg_norm.T.groupby(group_assign)
group_mean = contigs_by_group.mean().T
group_std = contigs_by_group.std().T
group_max = contigs_by_group.max().T

group_cvrg = group_mean.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = group_std.mean()
group_cvrg['group_std_mean_coverage'] = group_mean.std()
group_cvrg['group_max_coverage'] = group_max.max()
# Across-extraction spread of the group mean relative to the average within-group spread.
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group, group statistics, bin and length; sorted by
# contamination score, then by length (descending).
group_assign = group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False], inplace=True)

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction, drawn across contigs in `group_assign` order and colored by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

# Mean/first/last x-position of each group's contigs in the sorted table.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(middle='mean', left='min', right='max')
               .sort_values('left'))

# Group boundaries: thin red for contaminated or unscorable (NaN) groups,
# dashed black with id/total-length labels otherwise.
contam_threshold = 2
ymax = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    if pd.isna(score) or score > contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, ymax * 0.5),
                ha='center', rotation=-90)

# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)

In [None]:
# Contig length per bin: one column per clean group plus a 'contam' column for
# everything else. The two sides use complementary masks, so every contig is
# counted exactly once. Groups scoring >= contam_threshold, or with a NaN
# score, count as contamination; previously a score exactly equal to the
# threshold was dropped from both sides.
clean_groups = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
in_clean_group = group_assign.group.isin(clean_groups)
a = (group_assign[in_clean_group]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = group_assign[~in_clean_group].groupby('bin_id').length.sum()
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of trusted extractions per (site, treatment).
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:,0])

# Mean normalized coverage of each clean contig group, averaged within
# (site, treatment). The transposed groupby replaces the deprecated
# `groupby(..., axis='columns')` (pandas 2.1).
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, group_cvrg[lambda x: x.contamination_score < contam_threshold].index])

In [None]:
# Animated view of the clean groups: each frame highlights one extraction's
# trace and labels it with its ID and the summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    # One Line2D per extraction (the columns of the transposed block).
    des_artists = ax.plot(d0[clean_contigs].values.T, lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = d_rabund.loc[plotting_order, otus].sum(axis=1).tolist()

# Boundaries and labels for clean groups only. The filter is applied via a
# callable so it evaluates on the reordered frame, with no reindexing warning.
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
ymax = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, ymax), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, ymax * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# `linthreshy` was removed in Matplotlib 3.5; `linthresh` replaces it (3.3+).
ax.set_yscale('symlog', linthresh=1)
fig.tight_layout()


# The callbacks bind this cell's objects as default arguments. Reading the
# module-level names at draw time would pick up whatever a later cell
# rebinds them to.
def _init(artists=artists):
    """Return every trace so the blitted animation starts from the full plot."""
    return artists


def _animate(i, artists=artists, annot=annot, plotting_order=plotting_order,
             otu_rabund=otu_rabund, original_lw=original_lw,
             original_alpha=original_alpha, original_zorder=original_zorder):
    """Highlight trace `i`, restore trace `i - 1`, and update the label.

    For i == 0, `i - 1` wraps to the last trace, restoring it when the
    animation repeats.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn trace. Extractions with a missing site are dropped by the
# groupby and have no line, so cvrg_norm.shape[0] could overrun `artists`.
# Display e.g. with HTML(anim.to_jshtml()).
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)



In [None]:
# Save the contigs of the selected groups, one ID per line. The group IDs refer to
# the BGM fit above (random_state=1); re-check them if that fit changes.
selected_contigs = group_assign.index[group_assign.group.isin([5, 3, 0, 4])]
with open('res/core.a.mags.d/OTU-15.contigs.list', 'w') as handle:
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

### OTU-25

In [None]:
# Keep every bin whose contribution to one of `otus` exceeds `keep_thresh_factor`
# times that OTU's largest bin contribution, then collect all contigs in those bins.
keep_thresh_factor = 1/3
otus = ['Otu0025_1']
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins |= set(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# IN-list of SQL string literals for the selected contigs. Single quotes (with any
# embedded quote doubled) are real string literals; SQLite reads double-quoted tokens
# as identifiers and only falls back to strings via a legacy quirk that can be disabled.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig summed over libraries, as extraction x contig (0 if absent).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Divide by each extraction's total coverage and restrict to the extractions in `d_rabund`.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# The five longest contigs of bin00826, listed shortest-first.
contig_meta.loc[contig_meta.bin_id == 'bin00826'].sort_values('length').tail()

In [None]:
import scipy.stats  # `import scipy as sp` alone does not load `sp.stats` on older SciPy releases

# Seed contig and a comparison contig: scatter their coverage (scaled by 1e4) against
# each other, colored by the OTU's relative abundance, with a y = x reference line.
seed, compare = 'core-k161_461086', 'core-k161_449184'
assert seed in contig_ids, seed
assert compare in contig_ids, compare
plt.scatter(seed, compare, data=cvrg*1e4, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r with the seed's coverage above 0.99.
# Trusted extractions: mean/std of the trusted contigs' coverage above 0.5.
trusted_contigs = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])[lambda x: x > 0.99].index
trusted_extractions = (cvrg[trusted_contigs].mean(1) / cvrg[trusted_contigs].std(1))[lambda x: x > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))

In [None]:
# Normalize every contig by the trusted contigs' mean coverage in the same extraction,
# then compare log mean normalized coverage with and without untrusted extractions.
cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
hist_bins = np.linspace(-4, 6)
for label, frame in (('with_untrusted', cvrg_norm_all),
                     ('without_untrusted', cvrg_norm_all.loc[trusted_extractions])):
    _ = plt.hist(np.log(frame.mean()), bins=hist_bins, label=label, alpha=0.8)
cvrg_norm = cvrg_norm_all.loc[trusted_extractions]
plt.legend()

In [None]:
# Model-size scan: mean log-likelihood of a diagonal-covariance Bayesian GMM fit to the
# contigs (rows of the transposed sqrt-normalized coverage) per component count.
cluster_data = np.sqrt(cvrg_norm)
contig_profiles = cluster_data.T

nn = range(1, 40)
scores = [BayesianGaussianMixture(n_components=k, covariance_type='diag', random_state=1)
          .fit(contig_profiles)
          .score(contig_profiles)
          for k in nn]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by sqrt-transformed normalized coverage (contigs are samples,
# extractions are features) and summarize each cluster ("group").
bgm = BayesianGaussianMixture(30,
                              covariance_type='diag',
                              random_state=1,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# `groupby(..., axis='columns')` is deprecated (pandas 2.1): group the transposed frame
# (contigs as rows) once and reuse it; `.T` restores the extraction x group layout.
cvrg_by_group = cvrg_norm.T.groupby(group_assign)
group_mean_cvrg = cvrg_by_group.mean().T
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().T.mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().T.max()
# Spread of the group mean across extractions relative to the average within-group
# spread; NaN when the within-group std is undefined (e.g. a single-contig group).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group id, group summary, bin and length; least contaminated first.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# x-position of every contig in the plot, and each group's span along that axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(['mean', 'min', 'max'])
               .rename(columns={'mean': 'middle', 'min': 'left', 'max': 'right'})
               .sort_values('left'))
contam_threshold = 2
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    # Same rule as the `< contam_threshold` filters in later cells: anything not
    # strictly below the threshold (including NaN) is contamination -> thin red line.
    if not score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    # Row access yields a float, which the precision format `{:0.02}` requires.
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Bin composition of the clustered contigs: total length per bin in each included
# group (`a`), plus total length per bin across contaminated groups (`b`).
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
# Exact complement of the filter above, so every group is counted once: NaN scores and
# scores equal to the threshold are contamination instead of silently vanishing.
contam_groups = group_cvrg[lambda x: ~(x.contamination_score < contam_threshold)].index
a = (group_assign
         [lambda x: x.group.isin(included_groups)]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign
         [lambda x: x.group.isin(contam_groups)]
         .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions (rows of cvrg_norm) per site x treatment, counted on the first contig column.
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:, 0])

# Mean normalized coverage of each non-contaminated contig group, averaged within
# site x treatment. `groupby(..., axis='columns')` is deprecated (pandas 2.1), so the
# contigs are grouped along the rows of the transposed frame and transposed back.
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, included_groups])

In [None]:
# Animated view: one line per extraction across the non-contaminated contigs (in
# `group_assign` order); each frame highlights one extraction and labels it with its
# summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Transposing makes each extraction (a row of d0) its own Line2D.
    des_artists = ax.plot(d0[included_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line, so `_animate` can restore it after highlighting.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built from the reordered frame itself;
# indexing with the unaligned `group_cvrg` mask forces a reindex (and a warning).
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
label_y = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Initial-frame callback for FuncAnimation: return every extraction line."""
    return artists

def _animate(i):
    """Highlight extraction ``i`` and restore the previous line (``i == 0`` restores the last)."""
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per drawn line: extractions whose `site` is missing are dropped by the
# groupby above, so `cvrg_norm.shape[0]` could exceed `len(artists)` and overrun it.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)



In [None]:
# Save the contigs of the selected groups, one ID per line. The group IDs refer to
# the BGM fit above (random_state=1); re-check them if that fit changes.
selected_contigs = group_assign.index[group_assign.group.isin([0, 6, 26, 15, 16, 28, 12])]
with open('res/core.a.mags.d/OTU-25.contigs.list', 'w') as handle:
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

### OTU-32

In [None]:
# Keep every bin whose contribution to one of `otus` exceeds `keep_thresh_factor`
# times that OTU's largest bin contribution, then collect all contigs in those bins.
keep_thresh_factor = 1/3
otus = ['Otu0032_1']
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins |= set(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# IN-list of SQL string literals for the selected contigs. Single quotes (with any
# embedded quote doubled) are real string literals; SQLite reads double-quoted tokens
# as identifiers and only falls back to strings via a legacy quirk that can be disabled.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig summed over libraries, as extraction x contig (0 if absent).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Divide by each extraction's total coverage and restrict to the extractions in `d_rabund`.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
import scipy.stats  # `import scipy as sp` alone does not load `sp.stats` on older SciPy releases

# Seed contig and a comparison contig: scatter their coverage (scaled by 1e4) against
# each other, colored by the OTU's relative abundance, with a y = x reference line.
seed, compare = 'core-k161_1443667', 'core-k161_886844'
assert seed in contig_ids, seed
assert compare in contig_ids, compare
plt.scatter(seed, compare, data=cvrg*1e4, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r with the seed's coverage above 0.999.
# Trusted extractions: mean/std of the trusted contigs' coverage above 0.5.
trusted_contigs = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])[lambda x: x > 0.999].index
trusted_extractions = (cvrg[trusted_contigs].mean(1) / cvrg[trusted_contigs].std(1))[lambda x: x > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))

In [None]:
# Normalize every contig by the trusted contigs' mean coverage in the same extraction,
# then compare log mean normalized coverage with and without untrusted extractions.
cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
hist_bins = np.linspace(-4, 6)
for label, frame in (('with_untrusted', cvrg_norm_all),
                     ('without_untrusted', cvrg_norm_all.loc[trusted_extractions])):
    _ = plt.hist(np.log(frame.mean()), bins=hist_bins, label=label, alpha=0.8)
cvrg_norm = cvrg_norm_all.loc[trusted_extractions]
plt.legend()

In [None]:
# Model-size scan: mean log-likelihood of a diagonal-covariance Bayesian GMM fit to the
# contigs (rows of the transposed sqrt-normalized coverage) per component count.
cluster_data = np.sqrt(cvrg_norm)
contig_profiles = cluster_data.T

nn = range(1, 40, 5)
scores = [BayesianGaussianMixture(n_components=k, covariance_type='diag', random_state=1)
          .fit(contig_profiles)
          .score(contig_profiles)
          for k in nn]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by sqrt-transformed normalized coverage (contigs are samples,
# extractions are features) and summarize each cluster ("group").
bgm = BayesianGaussianMixture(30,
                              covariance_type='diag',
                              random_state=1,
                              n_init=2,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# `groupby(..., axis='columns')` is deprecated (pandas 2.1): group the transposed frame
# (contigs as rows) once and reuse it; `.T` restores the extraction x group layout.
cvrg_by_group = cvrg_norm.T.groupby(group_assign)
group_mean_cvrg = cvrg_by_group.mean().T
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().T.mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().T.max()
# Spread of the group mean across extractions relative to the average within-group
# spread; NaN when the within-group std is undefined (e.g. a single-contig group).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group id, group summary, bin and length; least contaminated first.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# x-position of every contig in the plot, and each group's span along that axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(['mean', 'min', 'max'])
               .rename(columns={'mean': 'middle', 'min': 'left', 'max': 'right'})
               .sort_values('left'))
contam_threshold = 3
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    # Same rule as the `< contam_threshold` filters in later cells: anything not
    # strictly below the threshold (including NaN) is contamination -> thin red line.
    if not score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    # Row access yields a float, which the precision format `{:0.02}` requires.
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Bin composition of the clustered contigs: total length per bin in each included
# group (`a`), plus total length per bin across contaminated groups (`b`).
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
# Exact complement of the filter above, so every group is counted once: NaN scores and
# scores equal to the threshold are contamination instead of silently vanishing.
contam_groups = group_cvrg[lambda x: ~(x.contamination_score < contam_threshold)].index
a = (group_assign
         [lambda x: x.group.isin(included_groups)]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign
         [lambda x: x.group.isin(contam_groups)]
         .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions (rows of cvrg_norm) per site x treatment, counted on the first contig column.
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:, 0])

# Mean normalized coverage of each non-contaminated contig group, averaged within
# site x treatment. `groupby(..., axis='columns')` is deprecated (pandas 2.1), so the
# contigs are grouped along the rows of the transposed frame and transposed back.
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, included_groups])

In [None]:
# Animated view: one line per extraction across the non-contaminated contigs (in
# `group_assign` order); each frame highlights one extraction and labels it with its
# summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Transposing makes each extraction (a row of d0) its own Line2D.
    des_artists = ax.plot(d0[included_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line, so `_animate` can restore it after highlighting.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built from the reordered frame itself;
# indexing with the unaligned `group_cvrg` mask forces a reindex (and a warning).
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
label_y = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Initial-frame callback for FuncAnimation: return every extraction line."""
    return artists

def _animate(i):
    """Highlight extraction ``i`` and restore the previous line (``i == 0`` restores the last)."""
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per drawn line: extractions whose `site` is missing are dropped by the
# groupby above, so `cvrg_norm.shape[0]` could exceed `len(artists)` and overrun it.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Save the contigs of the selected groups, one ID per line. The group IDs refer to
# the BGM fit above (random_state=1); re-check them if that fit changes.
selected_contigs = group_assign.index[group_assign.group.isin([16, 0, 17, 12, 3, 25, 6])]
with open('res/core.a.mags.d/OTU-32.contigs.list', 'w') as handle:
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

## Other OTUs

### OTU-12 (Ruminiclostridium)

In [None]:
# Keep every bin whose contribution to one of `otus` exceeds `keep_thresh_factor`
# times that OTU's largest bin contribution, then collect all contigs in those bins.
keep_thresh_factor = 1/3
otus = ['Otu0012_1']
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins |= set(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

# Taxonomy rows for the OTUs of interest.
taxonomy.loc[otus]

In [None]:
# IN-list of SQL string literals for the selected contigs. Single quotes (with any
# embedded quote doubled) are real string literals; SQLite reads double-quoted tokens
# as identifiers and only falls back to strings via a legacy quirk that can be disabled.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig summed over libraries, as extraction x contig (0 if absent).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Divide by each extraction's total coverage and restrict to the extractions in `d_rabund`.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# The ten longest contigs of bin00559, listed shortest-first.
contig_meta.loc[contig_meta.bin_id.isin(['bin00559'])].sort_values('length').tail(10)

In [None]:
import scipy.stats  # `import scipy as sp` alone does not load `sp.stats` on older SciPy releases

# Seed contig and a comparison contig: scatter their coverage (scaled by 1e5) against
# each other, colored by the OTU's relative abundance, with a y = x reference line.
seed, compare = 'core-k161_422699', 'core-k161_1821146'
assert seed in contig_ids, seed
assert compare in contig_ids, compare
plt.scatter(seed, compare, data=cvrg*1e5, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r with the seed's coverage above 0.99.
# Trusted extractions: mean/std of the trusted contigs' coverage above 0.5.
trusted_contigs = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])[lambda x: x > 0.99].index
trusted_extractions = (cvrg[trusted_contigs].mean(1) / cvrg[trusted_contigs].std(1))[lambda x: x > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Normalize every contig by the trusted contigs' mean coverage in the same extraction,
# then compare log mean normalized coverage with and without untrusted extractions.
cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
hist_bins = np.linspace(-4, 6)
for label, frame in (('with_untrusted', cvrg_norm_all),
                     ('without_untrusted', cvrg_norm_all.loc[trusted_extractions])):
    _ = plt.hist(np.log(frame.mean()), bins=hist_bins, label=label, alpha=0.8)
cvrg_norm = cvrg_norm_all.loc[trusted_extractions]
plt.legend()

In [None]:
# Model-size scan: mean log-likelihood of a diagonal-covariance Bayesian GMM fit to the
# contigs (rows of the transposed sqrt-normalized coverage) per component count.
cluster_data = np.sqrt(cvrg_norm)
contig_profiles = cluster_data.T

nn = range(1, 40)
scores = [BayesianGaussianMixture(n_components=k, covariance_type='diag', random_state=1)
          .fit(contig_profiles)
          .score(contig_profiles)
          for k in nn]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by sqrt-transformed normalized coverage (contigs are samples,
# extractions are features) and summarize each cluster ("group").
bgm = BayesianGaussianMixture(20,
                              covariance_type='diag',
                              random_state=1,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# `groupby(..., axis='columns')` is deprecated (pandas 2.1): group the transposed frame
# (contigs as rows) once and reuse it; `.T` restores the extraction x group layout.
cvrg_by_group = cvrg_norm.T.groupby(group_assign)
group_mean_cvrg = cvrg_by_group.mean().T
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().T.mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().T.max()
# Spread of the group mean across extractions relative to the average within-group
# spread; NaN when the within-group std is undefined (e.g. a single-contig group).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group id, group summary, bin and length; least contaminated first.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# x-position of every contig in the plot, and each group's span along that axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(['mean', 'min', 'max'])
               .rename(columns={'mean': 'middle', 'min': 'left', 'max': 'right'})
               .sort_values('left'))
contam_threshold = 2.5
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    # Same rule as the `< contam_threshold` filters in later cells: anything not
    # strictly below the threshold (including NaN) is contamination -> thin red line.
    if not score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    # Row access yields a float, which the precision format `{:0.02}` requires.
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Bin composition of the clustered contigs: total length per bin in each included
# group (`a`), plus total length per bin across contaminated groups (`b`).
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
# Exact complement of the filter above, so every group is counted once: NaN scores and
# scores equal to the threshold are contamination instead of silently vanishing.
contam_groups = group_cvrg[lambda x: ~(x.contamination_score < contam_threshold)].index
a = (group_assign
         [lambda x: x.group.isin(included_groups)]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign
         [lambda x: x.group.isin(contam_groups)]
         .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Number of extractions (rows of cvrg_norm) per site x treatment, counted on the first contig column.
print(cvrg_norm.groupby([extraction_meta.site, extraction_meta.treatment]).count().iloc[:, 0])

# Mean normalized coverage of each non-contaminated contig group, averaged within
# site x treatment. `groupby(..., axis='columns')` is deprecated (pandas 2.1), so the
# contigs are grouped along the rows of the transposed frame and transposed back.
included_groups = group_cvrg[lambda x: x.contamination_score < contam_threshold].index
(cvrg_norm.T.groupby(group_assign.group).mean().T
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()
          .loc[:, included_groups])

In [None]:
# Animated view: one line per extraction across the non-contaminated contigs (in
# `group_assign` order); each frame highlights one extraction and labels it with its
# summed relative abundance of `otus`.
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
included_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Transposing makes each extraction (a row of d0) its own Line2D.
    des_artists = ax.plot(d0[included_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line, so `_animate` can restore it after highlighting.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built from the reordered frame itself;
# indexing with the unaligned `group_cvrg` mask forces a reindex (and a warning).
group_cvrg_included = (group_cvrg.loc[group_order.index]
                       .loc[lambda x: x.contamination_score < contam_threshold])
group_order_included = group_order.loc[group_cvrg_included.index]
label_y = group_cvrg_included.group_max_coverage.max()

for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Initial-frame callback for FuncAnimation: return every extraction line."""
    return artists

def _animate(i):
    """Highlight extraction ``i`` and restore the previous line (``i == 0`` restores the last)."""
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per drawn line: extractions whose `site` is missing are dropped by the
# groupby above, so `cvrg_norm.shape[0]` could exceed `len(artists)` and overrun it.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)



In [None]:
# Save the contigs of the selected groups, one ID per line. The group IDs refer to
# the BGM fit above (random_state=1); re-check them if that fit changes.
selected_contigs = group_assign.index[group_assign.group.isin([0, 18, 6, 3, 7, 16, 17])]
with open('res/core.a.mags.d/OTU-12.contigs.list', 'w') as handle:
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

### OTU-6 (Turicibacter)

In [None]:
# Keep every bin whose contribution to one of `otus` exceeds `keep_thresh_factor`
# times that OTU's largest bin contribution, then collect all contigs in those bins.
keep_thresh_factor = 1/3
otus = ['Otu0006_1']
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins |= set(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# IN-list of SQL string literals for the selected contigs. Single quotes (with any
# embedded quote doubled) are real string literals; SQLite reads double-quoted tokens
# as identifiers and only falls back to strings via a legacy quirk that can be disabled.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in sorted(contig_ids))

# Coverage of each selected contig summed over libraries, as extraction x contig (0 if absent).
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Divide by each extraction's total coverage and restrict to the extractions in `d_rabund`.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# The ten longest contigs of bin01024, listed shortest-first.
contig_meta.loc[contig_meta.bin_id == 'bin01024'].sort_values('length').tail(10)

In [None]:
import scipy.stats  # `import scipy as sp` alone does not load `sp.stats` on older SciPy releases

# Seed contig and a comparison contig: scatter their coverage (scaled by 1e4) against
# each other, colored by the OTU's relative abundance, with a y = x reference line.
seed, compare = 'core-k161_2698281', 'core-k161_948080'
assert seed in contig_ids, seed
assert compare in contig_ids, compare
plt.scatter(seed, compare, data=cvrg*1e4, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r with the seed's coverage above 0.99.
# Trusted extractions: mean/std of the trusted contigs' coverage above 0.5.
trusted_contigs = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])[lambda x: x > 0.99].index
trusted_extractions = (cvrg[trusted_contigs].mean(1) / cvrg[trusted_contigs].std(1))[lambda x: x > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))

In [None]:
# Normalize every contig by the trusted contigs' mean coverage in the same extraction,
# then compare log mean normalized coverage with and without untrusted extractions.
cvrg_norm_all = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
hist_bins = np.linspace(-4, 6)
for label, frame in (('with_untrusted', cvrg_norm_all),
                     ('without_untrusted', cvrg_norm_all.loc[trusted_extractions])):
    _ = plt.hist(np.log(frame.mean()), bins=hist_bins, label=label, alpha=0.8)
cvrg_norm = cvrg_norm_all.loc[trusted_extractions]
plt.legend()

In [None]:
# Model-size scan: mean log-likelihood of a diagonal-covariance Bayesian GMM fit to the
# contigs (rows of the transposed sqrt-normalized coverage) per component count.
cluster_data = np.sqrt(cvrg_norm)
contig_profiles = cluster_data.T

nn = range(1, 40)
scores = [BayesianGaussianMixture(n_components=k, covariance_type='diag', random_state=1)
          .fit(contig_profiles)
          .score(contig_profiles)
          for k in nn]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by sqrt-transformed normalized coverage (contigs are samples,
# extractions are features) and summarize each cluster ("group").
bgm = BayesianGaussianMixture(30,
                              covariance_type='diag',
                              random_state=1,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# `groupby(..., axis='columns')` is deprecated (pandas 2.1): group the transposed frame
# (contigs as rows) once and reuse it; `.T` restores the extraction x group layout.
cvrg_by_group = cvrg_norm.T.groupby(group_assign)
group_mean_cvrg = cvrg_by_group.mean().T
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().T.mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().T.max()
# Spread of the group mean across extractions relative to the average within-group
# spread; NaN when the within-group std is undefined (e.g. a single-contig group).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group id, group summary, bin and length; least contaminated first.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# x-position of every contig in the plot, and each group's span along that axis.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = (group_assign.groupby('group').contig_index
               .agg(['mean', 'min', 'max'])
               .rename(columns={'mean': 'middle', 'min': 'left', 'max': 'right'})
               .sort_values('left'))
contam_threshold = 2
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx, 'contamination_score']
    # Same rule as the `< contam_threshold` filters in later cells: anything not
    # strictly below the threshold (including NaN) is contamination -> thin red line.
    if not score < contam_threshold:
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    # Row access yields a float, which the precision format `{:0.02}` requires.
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh`; the old name fails on 3.5+.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Contig length per bin, split by clean group, plus one 'contam' column that
# pools every other group. The contaminated set is the exact complement of
# the clean mask, so groups scoring exactly contam_threshold (previously in
# neither table) and NaN scores both land in 'contam'.
is_clean = group_cvrg.contamination_score < contam_threshold
clean_groups = group_cvrg.index[is_clean]
contam_groups = group_cvrg.index[~is_clean]
a = (group_assign[group_assign.group.isin(clean_groups)]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(contam_groups)]
     .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Extractions per site/treatment (non-null count of the first contig column).
by_site_treatment = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(by_site_treatment).count().iloc[:, 0])

# Average, per site/treatment, of each clean group's mean normalized coverage.
clean_group_ids = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
per_extraction_group_mean = cvrg_norm.groupby(group_assign.group, axis='columns').mean()
per_extraction_group_mean.groupby(by_site_treatment).mean().loc[:, clean_group_ids]

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
# Contigs of groups below the contamination threshold, in group_assign order.
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Plotting a (contig x extraction) array draws one line per extraction.
    des_artists = ax.plot(d0[clean_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line so _animate can restore it.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed relative abundance of the focal OTU(s) in each plotted extraction.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built on the already-reordered frame
# (a Series mask from group_cvrg needed implicit reindexing and warned).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

label_y = group_cvrg_included.group_max_coverage.max()
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitted animation draws them initially."""
    return artists


def _animate(i):
    """Highlight extraction ``i``'s line and restore the previous line.

    For ``i == 0`` the previous line is ``artists[-1]``, which undoes the
    highlight left over from the end of the preceding loop.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn line. Extractions without a site are dropped by the
# groupby above and get no line, so cvrg_norm's row count could overrun.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Export the contigs of hand-picked group 3, one ID per line.
otu6_contigs = group_assign.index[group_assign.group.isin([3])]
with open('res/core.a.mags.d/OTU-6.contigs.list', 'w') as handle:
    for contig_id in otu6_contigs:
        handle.write('{}\n'.format(contig_id))

### OTU-20 (Intestinimonas)

In [None]:
keep_thresh_factor = 1/3
otus = ['Otu0020_1']
# Keep every bin whose contribution to a focal OTU exceeds keep_thresh_factor
# times that OTU's largest bin contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

taxonomy.loc[otus]

In [None]:
# Render contig IDs as standard SQL string literals ('...', embedded quotes
# doubled). SQLite resolves "..." as an identifier first and only falls back
# to a string when no column matches; builds with SQLITE_DQS=0 reject it.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in contig_ids)

# Per-extraction coverage of each selected contig, summed over libraries.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Coverage relative to each extraction's total, restricted to (and ordered
# like) the extractions of d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# Seed contig anchoring the genome, and a second contig to eyeball against it.
seed = 'core-k161_288653'
compare = 'core-k161_535156'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg*1e5, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
# 1:1 reference line that leaves the autoscaled limits untouched.
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
seed_cvrg = cvrg[seed]
seed_corr = cvrg.apply(lambda contig_cvrg: sp.stats.pearsonr(seed_cvrg, contig_cvrg)[0])
trusted_contigs = seed_corr.index[seed_corr > 0.99]
# Trusted extractions: mean/std of coverage over trusted contigs above 0.5.
trusted_cvrg = cvrg[trusted_contigs]
trusted_extractions = (trusted_cvrg.mean(1) / trusted_cvrg.std(1))[lambda ratio: ratio > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
hist_kw = dict(bins=np.linspace(-4, 6), alpha=0.8)
# Normalize each extraction by the mean coverage of its trusted contigs.
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), label='with_untrusted', **hist_kw)
# Keep only trusted extractions and overlay the new distribution.
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), label='without_untrusted', **hist_kw)
plt.legend()

In [None]:
# Square-root transformed normalized coverage; contigs are the samples.
cluster_data = np.sqrt(cvrg_norm)

# Score (mean per-sample log-likelihood on the training data) for 1..39
# components, to guide the choice of n_components. Passed by keyword because
# recent scikit-learn releases make the constructor keyword-only.
nn = range(1, 40)
scores = []
for n in nn:
    score = BayesianGaussianMixture(n_components=n,
                                    covariance_type='diag',
                                    # weight_concentration_prior_type='dirichlet_distribution',
                                    # weight_concentration_prior=10,
                                    random_state=1,
                                    ).fit(cluster_data.T).score(cluster_data.T)
    scores.append(score)
plt.plot(nn, scores)

In [None]:
# Cluster contigs (rows of cluster_data.T) into coverage-covariation groups.
# n_components is passed by keyword: recent scikit-learn releases declare
# BayesianGaussianMixture's constructor parameters keyword-only, so the
# positional form raises TypeError there (older releases accept both).
bgm = BayesianGaussianMixture(n_components=20,
                              covariance_type='diag',
                              # weight_concentration_prior_type='dirichlet_distribution',
                              # weight_concentration_prior=10,
                              random_state=1,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage statistics; the column-wise groupby is built once and
# reused instead of being rebuilt for every statistic.
cvrg_by_group = cvrg_norm.groupby(group_assign, axis='columns')
group_mean_cvrg = cvrg_by_group.mean()  # extraction x group
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().max()
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
# Between-extraction spread of the group mean over the typical within-group
# spread, weighted by log total contig length (for the same ratio, larger
# groups score higher). NaN when the within-group std is undefined.
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage * np.log(group_cvrg.total_length)
group_cvrg.index.name = 'group'

# Per-contig table: group statistics plus bin/length, cleanest groups first,
# longest contigs first within a score.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One translucent line per extraction, coloured by site; x runs over contigs
# in the sorted order above.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Left/middle/right contig position of every group in the sorted order.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')

# A group is clean only when its score is strictly below the threshold (the
# same `<` test the table and animation cells use); NaN scores fail it too.
contam_threshold = 30
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        # Contaminated or unscorable group: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE(review): Matplotlib >= 3.3 renamed linthreshy to linthresh and later
# releases drop the old name; switch once the environment is upgraded.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Group summary table, cleanest groups first.
group_cvrg.sort_values(by='contamination_score')

In [None]:
# Contig length per bin, split by clean group, plus one 'contam' column that
# pools every other group. The contaminated set is the exact complement of
# the clean mask, so groups scoring exactly contam_threshold (previously in
# neither table) and NaN scores both land in 'contam'.
is_clean = group_cvrg.contamination_score < contam_threshold
clean_groups = group_cvrg.index[is_clean]
contam_groups = group_cvrg.index[~is_clean]
a = (group_assign[group_assign.group.isin(clean_groups)]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(contam_groups)]
     .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Extractions per site/treatment (non-null count of the first contig column).
by_site_treatment = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(by_site_treatment).count().iloc[:, 0])

# Average, per site/treatment, of each clean group's mean normalized coverage.
clean_group_ids = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
per_extraction_group_mean = cvrg_norm.groupby(group_assign.group, axis='columns').mean()
per_extraction_group_mean.groupby(by_site_treatment).mean().loc[:, clean_group_ids]

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
# Contigs of groups below the contamination threshold, in group_assign order.
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Plotting a (contig x extraction) array draws one line per extraction.
    des_artists = ax.plot(d0[clean_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line so _animate can restore it.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed relative abundance of the focal OTU(s) in each plotted extraction.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built on the already-reordered frame
# (a Series mask from group_cvrg needed implicit reindexing and warned).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

label_y = group_cvrg_included.group_max_coverage.max()
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitted animation draws them initially."""
    return artists


def _animate(i):
    """Highlight extraction ``i``'s line and restore the previous line.

    For ``i == 0`` the previous line is ``artists[-1]``, which undoes the
    highlight left over from the end of the preceding loop.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn line. Extractions without a site are dropped by the
# groupby above and get no line, so cvrg_norm's row count could overrun.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Render the animation inline as an HTML5 <video> element.
HTML(data=anim.to_html5_video())

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import SymLogNorm

# Sequential colormap: white -> dark green -> brighter green.
colors = [(1, 1, 1), (0, 0.5, 0), (0, 0.8, 0)]
cmap_name = 'custom1'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)

# Median normalized coverage of each clean group in each trusted extraction.
d = (cvrg_norm.groupby(group_assign.group, axis='columns').median()
              .loc[:, group_cvrg[lambda x: x.contamination_score < contam_threshold].index])

vmin, vmax, cmap = 0, 8, cm
# The colour limits live on the norm itself. Recent seaborn releases skip
# vmin/vmax when a norm is supplied, and this single norm is shared by both
# plots, so without explicit limits it would autoscale to the first site's
# data and silently reuse that range for the second.
norm = SymLogNorm(linthresh=1, linscale=0.9, vmin=vmin, vmax=vmax)

# One clustermap per site; the site labels are aligned to d's rows explicitly
# rather than relying on boolean-Series reindexing (which warns).
d_site = extraction_meta.site.reindex(d.index)
for site_name in ['UM', 'UT']:
    sns.clustermap(d[d_site == site_name], vmin=vmin, vmax=vmax,
                   col_cluster=False, robust=True, cmap=cmap, norm=norm)


In [None]:
# Site-specific exports that were tried and left disabled:
#   OTU-20-UM.contigs.list -> groups [0, 7, 17, 15, 6]
#   OTU-20-UT.contigs.list -> groups [0, 7, 11, 15, 6]
# The combined list below covers the groups of both (plus group 8).
otu20_contigs = group_assign.index[group_assign.group.isin([0, 6, 7, 8, 11, 15, 17])]
with open('res/core.a.mags.d/OTU-20.contigs.list', 'w') as handle:
    for contig_id in otu20_contigs:
        handle.write('{}\n'.format(contig_id))

# Groups 11 and 17 seem to differentiate the two sites pretty well.

### OTU-35 (Ruminococcaceae_UCG-014)

In [None]:
keep_thresh_factor = 1/3
otus = ['Otu0035_1']
# Keep every bin whose contribution to a focal OTU exceeds keep_thresh_factor
# times that OTU's largest bin contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

print(bins)
# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

taxonomy.loc[otus]

In [None]:
# Render contig IDs as standard SQL string literals ('...', embedded quotes
# doubled). SQLite resolves "..." as an identifier first and only falls back
# to a string when no column matches; builds with SQLITE_DQS=0 reject it.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in contig_ids)

# Per-extraction coverage of each selected contig, summed over libraries.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Coverage relative to each extraction's total, restricted to (and ordered
# like) the extractions of d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# The five longest contigs of bin00425 (candidate seed contigs).
contig_meta[contig_meta.bin_id == 'bin00425'].sort_values(by='length').tail(5)

In [None]:
# Seed contig anchoring the genome, and a second contig to eyeball against it.
seed = 'core-k161_382136'
compare = 'core-k161_2295479'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg*1e5, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
# 1:1 reference line that leaves the autoscaled limits untouched.
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
seed_cvrg = cvrg[seed]
seed_corr = cvrg.apply(lambda contig_cvrg: sp.stats.pearsonr(seed_cvrg, contig_cvrg)[0])
trusted_contigs = seed_corr.index[seed_corr > 0.99]
# Trusted extractions: mean/std of coverage over trusted contigs above 0.5.
trusted_cvrg = cvrg[trusted_contigs]
trusted_extractions = (trusted_cvrg.mean(1) / trusted_cvrg.std(1))[lambda ratio: ratio > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
hist_kw = dict(bins=np.linspace(-4, 6), alpha=0.8)
# Normalize each extraction by the mean coverage of its trusted contigs.
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), label='with_untrusted', **hist_kw)
# Keep only trusted extractions and overlay the new distribution.
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), label='without_untrusted', **hist_kw)
plt.legend()

In [None]:
# Square-root transformed normalized coverage; contigs are the samples.
cluster_data = np.sqrt(cvrg_norm)

# Score (mean per-sample log-likelihood on the training data) for 1..39
# components, to guide the choice of n_components. Passed by keyword because
# recent scikit-learn releases make the constructor keyword-only.
nn = range(1, 40)
scores = []
for n in nn:
    score = BayesianGaussianMixture(n_components=n,
                                    covariance_type='diag',
                                    # weight_concentration_prior_type='dirichlet_distribution',
                                    # weight_concentration_prior=10,
                                    random_state=1,
                                    ).fit(cluster_data.T).score(cluster_data.T)
    scores.append(score)
plt.plot(nn, scores)

In [None]:
# Cluster contigs (rows of cluster_data.T) into coverage-covariation groups.
# n_components is passed by keyword: recent scikit-learn releases declare
# BayesianGaussianMixture's constructor parameters keyword-only, so the
# positional form raises TypeError there (older releases accept both).
bgm = BayesianGaussianMixture(n_components=10,
                              covariance_type='diag',
                              # weight_concentration_prior_type='dirichlet_distribution',
                              # weight_concentration_prior=10,
                              random_state=1,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage statistics; the column-wise groupby is built once and
# reused instead of being rebuilt for every statistic.
cvrg_by_group = cvrg_norm.groupby(group_assign, axis='columns')
group_mean_cvrg = cvrg_by_group.mean()  # extraction x group
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().max()
# Spread of the group mean across extractions relative to the typical spread
# among the group's contigs within an extraction. NaN when the within-group
# std is undefined (e.g. single-contig groups).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group statistics plus bin/length, cleanest groups first,
# longest contigs first within a score.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One translucent line per extraction, coloured by site; x runs over contigs
# in the sorted order above.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Left/middle/right contig position of every group in the sorted order.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')

# A group is clean only when its score is strictly below the threshold (the
# same `<` test the table and animation cells use); NaN scores fail it too.
contam_threshold = 2
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        # Contaminated or unscorable group: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE(review): Matplotlib >= 3.3 renamed linthreshy to linthresh and later
# releases drop the old name; switch once the environment is upgraded.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Contig length per bin, split by clean group, plus one 'contam' column that
# pools every other group. The contaminated set is the exact complement of
# the clean mask, so groups scoring exactly contam_threshold (previously in
# neither table) and NaN scores both land in 'contam'.
is_clean = group_cvrg.contamination_score < contam_threshold
clean_groups = group_cvrg.index[is_clean]
contam_groups = group_cvrg.index[~is_clean]
a = (group_assign[group_assign.group.isin(clean_groups)]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(contam_groups)]
     .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Extractions per site/treatment (non-null count of the first contig column).
by_site_treatment = [extraction_meta.site, extraction_meta.treatment]
print(cvrg_norm.groupby(by_site_treatment).count().iloc[:, 0])

# Average, per site/treatment, of each clean group's mean normalized coverage.
clean_group_ids = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
per_extraction_group_mean = cvrg_norm.groupby(group_assign.group, axis='columns').mean()
per_extraction_group_mean.groupby(by_site_treatment).mean().loc[:, clean_group_ids]

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
# Contigs of groups below the contamination threshold, in group_assign order.
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Plotting a (contig x extraction) array draws one line per extraction.
    des_artists = ax.plot(d0[clean_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line so _animate can restore it.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed relative abundance of the focal OTU(s) in each plotted extraction.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built on the already-reordered frame
# (a Series mask from group_cvrg needed implicit reindexing and warned).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

label_y = group_cvrg_included.group_max_coverage.max()
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitted animation draws them initially."""
    return artists


def _animate(i):
    """Highlight extraction ``i``'s line and restore the previous line.

    For ``i == 0`` the previous line is ``artists[-1]``, which undoes the
    highlight left over from the end of the preceding loop.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn line. Extractions without a site are dropped by the
# groupby above and get no line, so cvrg_norm's row count could overrun.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Export the contigs of hand-picked group 1, one ID per line.
otu35_contigs = group_assign.index[group_assign.group.isin([1])]
with open('res/core.a.mags.d/OTU-35.contigs.list', 'w') as handle:
    for contig_id in otu35_contigs:
        handle.write('{}\n'.format(contig_id))

### OTU-58 (Mollicutes sp.)

In [None]:
keep_thresh_factor = 1/3
otus = ['Otu0058_1']
# Keep every bin whose contribution to a focal OTU exceeds keep_thresh_factor
# times that OTU's largest bin contribution.
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins.update(otu_contrib.index[otu_contrib > max_contrib * keep_thresh_factor])

In [None]:
print(bins)
# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin.index[contig_bin.bin_id.isin(bins)])

In [None]:
# Render contig IDs as standard SQL string literals ('...', embedded quotes
# doubled). SQLite resolves "..." as an identifier first and only falls back
# to a string when no column matches; builds with SQLITE_DQS=0 reject it.
contig_ids_sql = ', '.join("'{}'".format(contig_id.replace("'", "''"))
                           for contig_id in contig_ids)

# Per-extraction coverage of each selected contig, summed over libraries.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction/sample/mouse metadata plus total library coverage per extraction.
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Coverage relative to each extraction's total, restricted to (and ordered
# like) the extractions of d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# Seed contig anchoring the genome, and a second contig to eyeball against it.
seed = 'core-k161_367789'
compare = 'core-k161_1175581'
assert seed in contig_ids
assert compare in contig_ids
plt.scatter(seed, compare, data=cvrg*1e5, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
# 1:1 reference line that leaves the autoscaled limits untouched.
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)

# Trusted contigs: Pearson r > 0.99 with the seed across extractions.
seed_cvrg = cvrg[seed]
seed_corr = cvrg.apply(lambda contig_cvrg: sp.stats.pearsonr(seed_cvrg, contig_cvrg)[0])
trusted_contigs = seed_corr.index[seed_corr > 0.99]
# Trusted extractions: mean/std of coverage over trusted contigs above 0.5.
trusted_cvrg = cvrg[trusted_contigs]
trusted_extractions = (trusted_cvrg.mean(1) / trusted_cvrg.std(1))[lambda ratio: ratio > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
hist_kw = dict(bins=np.linspace(-4, 6), alpha=0.8)
# Normalize each extraction by the mean coverage of its trusted contigs.
cvrg_norm = cvrg.div(cvrg[trusted_contigs].mean(1), axis=0)
_ = plt.hist(np.log(cvrg_norm.mean()), label='with_untrusted', **hist_kw)
# Keep only trusted extractions and overlay the new distribution.
cvrg_norm = cvrg_norm.loc[trusted_extractions]
_ = plt.hist(np.log(cvrg_norm.mean()), label='without_untrusted', **hist_kw)
plt.legend()

In [None]:
# Square-root transformed normalized coverage; contigs are the samples.
cluster_data = np.sqrt(cvrg_norm)

# Score (mean per-sample log-likelihood on the training data) for 1..39
# components, to guide the choice of n_components. Passed by keyword because
# recent scikit-learn releases make the constructor keyword-only.
nn = range(1, 40)
scores = []
for n in nn:
    score = BayesianGaussianMixture(n_components=n,
                                    covariance_type='diag',
                                    # weight_concentration_prior_type='dirichlet_distribution',
                                    # weight_concentration_prior=10,
                                    random_state=1,
                                    ).fit(cluster_data.T).score(cluster_data.T)
    scores.append(score)
plt.plot(nn, scores)

In [None]:
# Cluster contigs (rows of cluster_data.T) into coverage-covariation groups.
# n_components is passed by keyword: recent scikit-learn releases declare
# BayesianGaussianMixture's constructor parameters keyword-only, so the
# positional form raises TypeError there (older releases accept both).
bgm = BayesianGaussianMixture(n_components=20,
                              covariance_type='diag',
                              # weight_concentration_prior_type='dirichlet_distribution',
                              # weight_concentration_prior=10,
                              random_state=1,
                              ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage statistics; the column-wise groupby is built once and
# reused instead of being rebuilt for every statistic.
cvrg_by_group = cvrg_norm.groupby(group_assign, axis='columns')
group_mean_cvrg = cvrg_by_group.mean()  # extraction x group
group_cvrg = group_mean_cvrg.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = cvrg_by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = group_mean_cvrg.std()
group_cvrg['group_max_coverage'] = cvrg_by_group.max().max()
# Spread of the group mean across extractions relative to the typical spread
# among the group's contigs within an extraction. NaN when the within-group
# std is undefined (e.g. single-contig groups).
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group statistics plus bin/length, cleanest groups first,
# longest contigs first within a score.
group_assign = (group_assign.to_frame(name='group')
                .join(group_cvrg, on='group')
                .assign(bin_id=contig_meta.bin_id, length=contig_meta.length)
                .sort_values(['contamination_score', 'length'], ascending=[True, False]))

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One translucent line per extraction, coloured by site; x runs over contigs
# in the sorted order above.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color_map[des])

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# Left/middle/right contig position of every group in the sorted order.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')

# A group is clean only when its score is strictly below the threshold (the
# same `<` test the animation cell uses); NaN scores fail it too.
contam_threshold = 2
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    if not group_cvrg.loc[inx].contamination_score < contam_threshold:
        # Contaminated or unscorable group: thin red boundary, no labels.
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

# NOTE(review): Matplotlib >= 3.3 renamed linthreshy to linthresh and later
# releases drop the old name; switch once the environment is upgraded.
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Contig length per bin, split by clean group, plus one 'contam' column that
# pools every other group. Uses this section's contam_threshold (set in the
# clustering cell) instead of a separate hard-coded cutoff of 1, so the table
# agrees with the plot and the animation. The contaminated set is the exact
# complement of the clean mask, so NaN scores and scores equal to the
# threshold land in 'contam' instead of being dropped from both sides.
is_clean = group_cvrg.contamination_score < contam_threshold
clean_groups = group_cvrg.index[is_clean]
contam_groups = group_cvrg.index[~is_clean]
a = (group_assign[group_assign.group.isin(clean_groups)]
     .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign[group_assign.group.isin(contam_groups)]
     .groupby('bin_id').length.sum())
b.name = 'contam'
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
# Average, per site/treatment, of each clean group's mean normalized coverage.
# Clean groups use this section's contam_threshold (as the plot and animation
# do) instead of a separate hard-coded cutoff of 1.
clean_group_ids = group_cvrg.index[group_cvrg.contamination_score < contam_threshold]
(cvrg_norm.groupby(group_assign.group, axis='columns').mean()
          .groupby([extraction_meta.site, extraction_meta.treatment]).mean()).loc[:, clean_group_ids]

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
# Contigs of groups below the contamination threshold, in group_assign order.
clean_contigs = group_assign[lambda x: x.contamination_score < contam_threshold].index
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    # Plotting a (contig x extraction) array draws one line per extraction.
    des_artists = ax.plot(d0[clean_contigs].values.T,
                          lw=1, alpha=0.1, color=color_map[des])
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line so _animate can restore it.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed relative abundance of the focal OTU(s) in each plotted extraction.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built on the already-reordered frame
# (a Series mask from group_cvrg needed implicit reindexing and warned).
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

label_y = group_cvrg_included.group_max_coverage.max()
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, label_y * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()


def _init():
    """Return every line so the blitted animation draws them initially."""
    return artists


def _animate(i):
    """Highlight extraction ``i``'s line and restore the previous line.

    For ``i == 0`` the previous line is ``artists[-1]``, which undoes the
    highlight left over from the end of the preceding loop.
    """
    j = i - 1
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]


# One frame per drawn line. Extractions without a site are dropped by the
# groupby above and get no line, so cvrg_norm's row count could overrun.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Export the contigs of groups 12, 11 and 7 (hard-coded group ids) as the
# OTU-58 contig list, one id per line.
with open('res/core.a.mags.d/OTU-58.contigs.list', 'w') as handle:
    selected_contigs = group_assign[group_assign.group.isin([12, 11, 7])].index
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

### OTU-41 (Bacteroides)

In [None]:
# Keep bins whose contribution to a target OTU exceeds `keep_thresh_factor`
# times that OTU's largest single-bin contribution.
keep_thresh_factor = 1/3
otus = ['Otu0041_1']
bins = set()
for otu in otus:
    otu_contrib = bin_otu_contrib[otu]
    max_contrib = otu_contrib.max()
    bins |= set(otu_contrib[otu_contrib > max_contrib * keep_thresh_factor].index)

print(bins)
# Every contig assigned to one of the selected bins.
contig_ids = set(contig_bin[contig_bin.bin_id.isin(bins)].index)

taxonomy.loc[otus]

In [None]:
# Render the contig ids as single-quoted SQL string literals, doubling any
# embedded quote. SQLite parses double-quoted tokens as identifiers and only
# falls back to treating them as strings when no such column exists (a legacy
# quirk that builds may disable), so double quotes are not a safe literal syntax.
contig_ids_sql = ', '.join("'{}'".format(str(contig_id).replace("'", "''"))
                           for contig_id in contig_ids)

# Coverage of each selected contig in each extraction, summed over libraries;
# reshaped to extractions x contigs with missing pairs filled with 0.
cvrg = pd.read_sql("""
SELECT extraction_id, contig_id, SUM(coverage) AS coverage
FROM contig_coverage
JOIN library USING (library_id)
WHERE contig_id IN ({})
GROUP BY extraction_id, contig_id
                   """.format(contig_ids_sql), con=con,
                   index_col=['extraction_id', 'contig_id']).coverage.unstack('contig_id', fill_value=0)

# Extraction metadata joined through sample and mouse, plus each extraction's
# total library coverage (summed over its libraries).
extraction_meta = pd.read_sql("""
SELECT *
FROM extraction
JOIN sample USING (sample_id)
JOIN mouse USING (mouse_id)
JOIN (SELECT extraction_id, SUM(coverage) AS coverage
      FROM library_total_coverage
      JOIN library USING (library_id)
      GROUP BY extraction_id) USING (extraction_id)
                               """, con=con, index_col='extraction_id')

# Bin assignment and contig table rows for the selected contigs.
contig_meta = pd.read_sql("""
SELECT *
FROM contig_bin
JOIN contig USING (contig_id)
WHERE contig_id IN ({})
                          """.format(contig_ids_sql),
                         con=con, index_col='contig_id')

# Normalize by each extraction's total coverage and align rows to d_rabund.
cvrg = cvrg.div(extraction_meta.coverage, axis=0).loc[d_rabund.index]

In [None]:
# Longest contigs of bin01631: the last five rows after sorting by length (ascending).
in_bin01631 = contig_meta['bin_id'] == 'bin01631'
contig_meta[in_bin01631].sort_values('length').tail()

In [None]:
# Compare two contigs' normalized coverage across extractions (scaled by 1e5),
# colored by the summed abundance of `otus`; the black line is y = x.
seed, compare = 'core-k161_514932', 'core-k161_2549350'
assert seed in contig_ids, 'seed contig {} is not among the selected contigs'.format(seed)
assert compare in contig_ids, 'compare contig {} is not among the selected contigs'.format(compare)
plt.scatter(seed, compare, data=cvrg*1e5, c=d_rabund[otus].sum(1).loc[cvrg.index], cmap='coolwarm')
plt.plot([-1e3, 1e3], [-1e3, 1e3], c='k', lw=1, scalex=False, scaley=False)
plt.xlabel('{} normalized coverage (x1e5)'.format(seed))
plt.ylabel('{} normalized coverage (x1e5)'.format(compare))

# Trusted contigs: Pearson r > 0.99 with the seed contig across extractions.
trusted_contigs = cvrg.apply(lambda x: sp.stats.pearsonr(cvrg[seed], x)[0])[lambda x: x > 0.99].index
# Trusted extractions: mean / std of coverage over the trusted contigs above 0.5.
trusted_extractions = (cvrg[trusted_contigs].mean(1) / cvrg[trusted_contigs].std(1))[lambda x: x > 0.5].index

print('{} trusted contigs and {} trusted extractions identified'.format(len(trusted_contigs), len(trusted_extractions)))


In [None]:
# Express each contig's coverage relative to the mean coverage of the trusted
# contigs in the same extraction, then keep only the trusted extractions.
trusted_mean_cvrg = cvrg[trusted_contigs].mean(axis=1)
cvrg_norm = cvrg.div(trusted_mean_cvrg, axis=0).loc[trusted_extractions]
# Distribution of each contig's log mean normalized coverage.
_ = plt.hist(np.log(cvrg_norm.mean()), bins=np.linspace(-4, 6), label='without_untrusted', alpha=0.8)
plt.legend()

In [None]:
# Square-root transform of the normalized coverage; contigs (rows of the
# transpose) are the observations being clustered.
cluster_data = np.sqrt(cvrg_norm)
contig_profiles = cluster_data.T

# Per-sample average log-likelihood of a diagonal-covariance Bayesian GMM for
# each component count, scored on the same data it was fitted to.
nn = range(1, 40)
scores = [
    BayesianGaussianMixture(n, covariance_type='diag', random_state=1)
    .fit(contig_profiles)
    .score(contig_profiles)
    for n in nn
]
plt.plot(nn, scores)

In [None]:
# Cluster contigs by their square-root normalized coverage profiles: rows of
# cluster_data.T are contigs, columns are extractions.
bgm = BayesianGaussianMixture(20,
                              covariance_type='diag',
                              random_state=1,
                             ).fit(cluster_data.T)
group_assign = pd.Series(bgm.predict(cluster_data.T), index=cvrg_norm.columns)

# Per-group coverage summaries; the column grouping is built once and reused.
by_group = cvrg_norm.groupby(group_assign, axis='columns')
group_mean_by_extraction = by_group.mean()  # rows: extractions, columns: groups
group_cvrg = group_mean_by_extraction.mean().to_frame(name='group_mean_mean_coverage')
group_cvrg['group_mean_std_coverage'] = by_group.std().mean()
group_cvrg['group_std_mean_coverage'] = group_mean_by_extraction.std()
group_cvrg['group_max_coverage'] = by_group.max().max()
# Across-extraction spread of a group's mean coverage relative to the average
# within-extraction spread among its contigs. Groups scoring above
# `contam_threshold`, or NaN, are drawn as contaminated below.
group_cvrg['contamination_score'] = group_cvrg.group_std_mean_coverage / group_cvrg.group_mean_std_coverage
group_cvrg['total_length'] = contig_meta.groupby(group_assign).length.sum()
group_cvrg.index.name = 'group'

# Per-contig table: group id, the group's summaries, bin and length. Rows are
# ordered by ascending contamination score, then by descending length.
group_assign = group_assign.to_frame(name='group').join(group_cvrg, on='group')
group_assign['bin_id'] = contig_meta.bin_id
group_assign['length'] = contig_meta.length
group_assign = group_assign.sort_values(['contamination_score', 'length'], ascending=[True, False])

fig, ax = plt.subplots(figsize=(15, 5))
ax.axhline(y=1, color='k', linestyle='--')

# One line per extraction across all contigs in the order above, colored by site.
for des, d in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    _ = ax.plot(d[group_assign.index].values.T, lw=1, alpha=0.25, color=color)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)

# x positions (contig_index) spanned by each group, sorted left to right.
group_assign['contig_index'] = range(group_assign.shape[0])
group_order = \
    (group_assign.groupby('group').contig_index
                         .apply(lambda x: pd.Series({'middle': x.mean(),
                                                     'left': x.min(),
                                                     'right': x.max()}))).unstack().sort_values('left')
contam_threshold = 2
label_y = cvrg_norm.max().max()
for inx, d in group_order.iterrows():
    score = group_cvrg.loc[inx].contamination_score
    # Contaminated or unscorable groups get a thin red boundary and no labels.
    # NaN arises, e.g., for a one-contig group, whose within-group std is NaN.
    if score > contam_threshold or pd.isna(score):
        ax.axvline(d.left - 0.5, color='r', lw=0.5)
        continue
    ax.axvline(d.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d.middle, label_y), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg.loc[inx].total_length),
                xy=(d.middle, label_y * 0.5),
                ha='center', rotation=-90)

ax.set_xlabel('contig (ordered by group contamination score, then length)')
ax.set_ylabel('coverage / mean trusted-contig coverage')
# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh` (old name removed in 3.5).
ax.set_yscale('symlog', linthreshy=1)

In [None]:
# Total contig length per (bin, group) for groups scoring below
# `contam_threshold`. Every other group goes into a single per-bin 'contam'
# column: above or exactly at the threshold, or with a NaN score. Using a mask
# and its complement counts every contig exactly once; the previous `<` / `>`
# pair dropped groups scoring exactly at the threshold from both tables.
included_mask = group_cvrg.contamination_score < contam_threshold
a = (group_assign
         [lambda x: x.group.isin(group_cvrg[included_mask].index)]
         .groupby(['bin_id', 'group']).length.sum().unstack(fill_value=0))
b = (group_assign
         [lambda x: x.group.isin(group_cvrg[~included_mask].index)]
         .groupby('bin_id').length.sum())
b.name = 'contam'
# Outer join keeps bins that appear in only one of the two tables; gaps become 0.
a.join(b, how='outer').fillna(0).astype(int)

In [None]:
site_treatment = [extraction_meta.site, extraction_meta.treatment]
# Trusted extractions per site/treatment (non-null count of the first contig column).
print(cvrg_norm.groupby(site_treatment).count().iloc[:, 0])

# Mean normalized coverage of each included group, averaged within site/treatment.
included_groups = group_cvrg[group_cvrg.contamination_score < contam_threshold].index
per_group_means = cvrg_norm.groupby(group_assign.group, axis='columns').mean()
per_group_means.groupby(site_treatment).mean().loc[:, included_groups]

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

ax.axhline(y=1, color='k', linestyle='--')
# One faint line per extraction (a row of cvrg_norm), drawn across the contigs
# of groups scoring below `contam_threshold` and colored by site. The lines are
# collected in `artists`, aligned with `plotting_order`, for the animation.
artists = []
plotting_order = []
for des, d0 in cvrg_norm.groupby(extraction_meta.site):
    color = color_map[des]
    des_artists = ax.plot(d0[group_assign[lambda x: x.contamination_score < contam_threshold].index].values.T,
                          lw=1, alpha=0.1, color=color)
    artists.extend(des_artists)
    plotting_order.extend(d0.index)
# Resting style of every line, restored after the line has been highlighted.
original_lw = {a: a.get_linewidth() for a in artists}
original_alpha = {a: a.get_alpha() for a in artists}
original_zorder = {a: a.get_zorder() for a in artists}
# Summed abundance of `otus` in each plotted extraction, shown in the annotation.
otu_rabund = [d_rabund.loc[extraction_id][otus].sum() for extraction_id in plotting_order]

# Filter with a callable so the mask is built on the reordered frame itself; a
# mask taken from `group_cvrg` directly has group_cvrg's index order, which
# generally differs and triggers pandas' "Boolean Series key will be reindexed" warning.
group_cvrg_included = group_cvrg.loc[group_order.index][lambda x: x.contamination_score < contam_threshold]
group_order_included = group_order.loc[group_cvrg_included.index]

# Boundary, group id and total contig length for each included group.
for inx, d1 in group_order_included.iterrows():
    ax.axvline(d1.left - 0.5, color='k', lw=1, linestyle='--')
    ax.annotate('({})'.format(inx), xy=(d1.middle, group_cvrg_included.group_max_coverage.max()), ha='center')
    ax.annotate('{:0.02}'.format(group_cvrg_included.loc[inx].total_length),
                xy=(d1.middle, group_cvrg_included.group_max_coverage.max() * 0.5),
                ha='center', rotation=-90)

annot = ax.annotate('', xy=(0.02, 0.8), xycoords="axes fraction", rotation=90)
# NOTE: matplotlib 3.3 renamed `linthreshy` to `linthresh` (old name removed in 3.5).
ax.set_yscale('symlog', linthreshy=1)
fig.tight_layout()

def _init():
    """Return every line so the blitted animation starts from the full plot."""
    return artists

def _animate(i):
    """Highlight line ``i`` and restore the line highlighted in the previous frame.

    ``i - 1`` wraps to the last line on frame 0. Returns the artists changed
    in this frame, as ``blit=True`` requires.
    """
    j = i - 1
    # Restore the previous line first: with a single line (i == j), the
    # highlight applied below is then not immediately undone.
    artists[j].set_linewidth(original_lw[artists[j]])
    artists[j].set_alpha(original_alpha[artists[j]])
    artists[j].set_zorder(original_zorder[artists[j]])
    artists[i].set_linewidth(1)
    artists[i].set_alpha(0.9)
    artists[i].set_zorder(999)
    annot.set_text('{} ({:0.1f}%)'.format(plotting_order[i], otu_rabund[i]*100))
    return [artists[i], artists[j], annot]

# One frame per plotted line. Extractions whose site is missing are dropped by
# the site groupby and never plotted, so cvrg_norm.shape[0] frames could index
# past the end of `artists`.
anim = animation.FuncAnimation(fig, _animate, init_func=_init,
                               frames=len(artists), interval=200, blit=True)

In [None]:
# Export the contigs of groups 7 and 17 (hard-coded group ids) as the OTU-41
# contig list, one id per line.
with open('res/core.a.mags.d/OTU-41.contigs.list', 'w') as handle:
    selected_contigs = group_assign[group_assign.group.isin([7, 17])].index
    handle.writelines('{}\n'.format(contig_id) for contig_id in selected_contigs)

## Do any MAGs share contigs?

In [None]:
import os
from glob import glob

# Gather every exported MAG contig list into one table of (contig_id, otu_id)
# pairs; the OTU label is the file name up to its first '.'. The paths are
# sorted so row order does not depend on the filesystem's listing order.
alldata = []
for filepath in sorted(glob('res/core.a.mags.d/*.contigs.list')):
    # os.path.basename rather than splitting on '/': glob returns
    # backslash-separated paths on Windows.
    otu = os.path.basename(filepath).split('.')[0]
    df = pd.read_table(filepath, names=['contig_id'])
    df['otu_id'] = otu
    alldata.append(df)
if not alldata:
    raise FileNotFoundError('no *.contigs.list files found in res/core.a.mags.d/')
all_contigs = pd.concat(alldata)
all_contigs['present'] = 1
all_contigs = all_contigs.set_index(['contig_id', 'otu_id'])

# For each MAG, the number of its contigs that also appear in another MAG's list.
all_contigs.unstack(fill_value=0)[lambda x: x.sum(1) > 1].sum()
# Only OTU-1-UM and OTU-1-UT share any contigs
# NOTE(review): the next cell looks for shared contigs involving OTU-49, so the
# note above may be stale — re-check it against this cell's output.

In [None]:
# Contigs listed in more than one MAG, restricted to those in OTU-49's list,
# shown for the OTU-1-UM, OTU-1-UT and OTU-49 MAGs.
presence = all_contigs.present.unstack(fill_value=0)
multi_mag = presence[presence.sum(axis=1) > 1]
multi_mag.loc[multi_mag['OTU-49'] > 0, ['OTU-1-UM', 'OTU-1-UT', 'OTU-49']]

In [None]:
# Bin assignments of two specific contigs.
contigs_to_inspect = ['core-k161_1003382', 'core-k161_90922']
contig_bin.loc[contigs_to_inspect]