In [None]:
from itertools import chain

import pandas as pd
import sqlite3
from sklearn.cross_decomposition import PLSCanonical
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# TODO: The full algorithm

1.  Select the seq_rabund and bin_cvrg data
1.  Normalize each matrix so that every sample (extraction) sums to 1
1.  Cross-validate PLSCanonical to pick n_components
1.  Fit PLSCanonical
1.  Calculate contribution scores (`x_loadings_ @ y_loadings_.T`)
1.  Identify and drop bins with which many seqs are correlated
1.  Re-normalize, re-run CV, and re-fit PLSCanonical on the reduced data?
1.  For each phylotype of interest (top n seqs?), fit a t-distribution to its contribution scores
1.  Pick bins so as to stay under a false-positive threshold

In [None]:
# Select the data

# Minimum mean relative abundance / mean coverage fraction for a sequence or
# bin to be kept as its own feature; everything below is pooled into 'other'.
MIN_MEAN_RABUND = 0.0001
MIN_MEAN_CVRG = 0.0001

con = sqlite3.connect('res/core.1.denorm.db')

# Relative Abundance: tally per (extraction, sequence); each extraction's row
# is scaled so it sums to 1.
rrs_count = (pd.read_sql('SELECT * FROM rrs_taxon_count;',
                         con=con, index_col=['extraction_id', 'sequence_id'])
               .tally.unstack().fillna(0).astype(int))
rabund = rrs_count.div(rrs_count.sum(axis=1), axis=0)

# Coverage: bin coverage summed over each extraction's libraries; each
# extraction's row is scaled so it sums to 1.
cvrg_raw = (pd.read_sql("""
SELECT bin_id, extraction_id, SUM(coverage) AS coverage
FROM bin_coverage
JOIN library USING (library_id)
GROUP BY bin_id, extraction_id;""",
                        con=con, index_col=['extraction_id', 'bin_id'])
              .coverage.unstack().fillna(0))
bin_cvrg = cvrg_raw.div(cvrg_raw.sum(axis=1), axis=0)

# Only keep shared extractions. Index.intersection gives a deterministic row
# order; a raw set has arbitrary order and is rejected by .loc in pandas >= 2.0.
extractions = rabund.index.intersection(bin_cvrg.index)
rabund = rabund.loc[extractions]
bin_cvrg = bin_cvrg.loc[extractions]

# Taxonomy: one otu_id per sequence_id
# NOTE(review): otu_id is a bare column under GROUP BY sequence_id; this
# assumes every sequence maps to a single OTU — confirm against the schema.
taxonomy = pd.read_sql('SELECT sequence_id, otu_id FROM rrs_taxon_count GROUP BY sequence_id;',
                       con=con, index_col='sequence_id')
mean_rabund = rabund.mean()

# Name each sequence '<otu_id>_<rank>', ranking sequences within an OTU by
# decreasing mean relative abundance (rank 1 = most abundant).
ranked_seqs = (pd.DataFrame({'mean_rabund': mean_rabund,
                             'otu_id': taxonomy.otu_id})
                 .sort_values('mean_rabund', ascending=False))
name_map = {}
for otu, d in ranked_seqs.groupby('otu_id'):
    for i, sequence_id in enumerate(d.index, start=1):
        name_map[sequence_id] = '{}_{}'.format(otu, i)
taxonomy['name'] = pd.Series(name_map)
taxonomy['mean_rabund'] = mean_rabund


# Select abundant taxa and bins; pool the remainder of each into 'other'
major_taxa = taxonomy.index[taxonomy.mean_rabund > MIN_MEAN_RABUND]
major_bins = bin_cvrg.columns[bin_cvrg.mean() > MIN_MEAN_CVRG]
d_rabund = (rabund[major_taxa]
            .assign(other=rabund.drop(columns=major_taxa).sum(axis=1))
            .rename(columns=taxonomy.name))
d_cvrg = (bin_cvrg[major_bins]
          .assign(other=bin_cvrg.drop(columns=major_bins).sum(axis=1)))

d_rabund.shape, d_cvrg.shape

In [None]:
def crossval(X, Y, model, k, random_state=None):
    """Return k-fold cross-validation scores of ``model`` on (X, Y).

    Rows are shuffled with ``X.sample(frac=1, random_state=random_state)`` and
    split into ``k`` folds whose sizes differ by at most one, so every row is
    held out exactly once.  For each fold, ``model`` is fit on the remaining
    rows and its ``score`` is evaluated on the held-out rows.

    Parameters
    ----------
    X, Y : pandas.DataFrame
        Predictors and targets sharing the same index labels
        (assumed unique — TODO confirm for new inputs).
    model : estimator
        Object whose ``fit(X, Y)`` returns a fitted object with
        ``score(X, Y)``; it is refit for every fold.
    k : int
        Number of folds.
    random_state : int, optional
        Seed for the shuffle; pass a fixed value for reproducible folds.

    Returns
    -------
    list
        One score per fold, in fold order.

    Raises
    ------
    AssertionError
        If the smallest fold would contain 2 or fewer rows.
    """
    n_samples = len(X.index)
    assert n_samples // k > 2, 'each fold needs more than 2 rows'
    shuffled = X.sample(frac=1, random_state=random_state).index
    scores = []
    # array_split spreads the n_samples % k leftover rows over the first folds
    # instead of leaving them permanently in every training set.
    for positions in np.array_split(np.arange(n_samples), k):
        outgroup = shuffled[positions]
        fitted = model.fit(X.drop(index=outgroup), Y.drop(index=outgroup))
        scores.append(fitted.score(X.loc[outgroup], Y.loc[outgroup]))
    return scores

# Cross-validate PLSCanonical over candidate n_components. A fixed shuffle
# seed gives every candidate the same folds, so the mean scores are
# comparable with each other and reproducible on re-run.
CV_FOLDS = 4
CV_RANDOM_STATE = 0
cvrg_sqrt = d_cvrg.apply(np.sqrt)
rabund_sqrt = d_rabund.apply(np.sqrt)

# PLSCanonical requires n_components <= min(n_train_samples, n_features,
# n_targets). Every training split has at least n - ceil(n / k) rows.
n_samples = len(cvrg_sqrt.index)
max_components = min(n_samples - int(np.ceil(n_samples / CV_FOLDS)),
                     cvrg_sqrt.shape[1], rabund_sqrt.shape[1])

for n_components in [10, 15, 20, 25, 30, 35, 40, 100, 500]:
    if n_components > max_components:
        print(n_components, 'skipped (exceeds max of {})'.format(max_components))
        continue
    model = PLSCanonical(scale=False, n_components=n_components)
    cv_scores = crossval(cvrg_sqrt, rabund_sqrt, model, CV_FOLDS,
                         random_state=CV_RANDOM_STATE)
    print(n_components, np.mean(cv_scores))

In [None]:
# Pick n_components
# NOTE(review): presumably chosen by inspecting the mean cross-validation
# scores printed by the previous cell; revisit if the data or the candidate
# list changes.

n_components = 25

In [None]:
# Final model: square-root-transformed bin coverage (X) vs. square-root-
# transformed relative abundance (Y), without per-feature scaling.
cvrg_sqrt = d_cvrg.apply(np.sqrt)
rabund_sqrt = d_rabund.apply(np.sqrt)
pls = PLSCanonical(scale=False, n_components=n_components)
fit = pls.fit(cvrg_sqrt, rabund_sqrt)

# Contribution score of every bin (rows) to every taxon (columns): the product
# of the X loadings with the transposed Y loadings.
contrib_values = np.dot(fit.x_loadings_, fit.y_loadings_.T)
contrib = pd.DataFrame(contrib_values,
                       index=d_cvrg.columns,
                       columns=d_rabund.columns)
contrib = contrib.rename(columns=taxonomy.name)

In [None]:
# Taxa of interest: columns of d_rabund whose 90th-percentile relative
# abundance across extractions exceeds 0.02. The pooled 'other' column is
# never a taxon of interest.
def tax_filter(frame):
    """Column mask: True where a column's 90th percentile exceeds 0.02."""
    return frame.quantile(0.90) > 0.02


kept_columns = d_rabund.loc[:, tax_filter].rename(columns=taxonomy.name).columns
taxa_of_interest = sorted(column for column in kept_columns if column != 'other')
len(taxa_of_interest)

In [None]:
# A bin is a hit for a taxon when its contribution score exceeds `factor`
# times that taxon's maximum score; hits are listed highest score first.
factor = 0.5

hits = {}
for tax in taxa_of_interest:
    tax_scores = contrib[tax]
    top_score = tax_scores.max()
    print(tax, top_score)
    ranked = tax_scores.sort_values(ascending=False)
    hits[tax] = list(ranked[ranked > top_score * factor].index)

print()
for tax, bins in hits.items():
    print(tax, bins)

# Every bin that is a hit for at least one taxon of interest.
all_hits = set(chain.from_iterable(hits.values()))

In [None]:
# Heatmap of contribution scores for every hit bin (rows, clustered) against
# the taxa of interest (columns, kept in their sorted order).
# all_hits is a set, which .loc rejects in pandas >= 2.0; select the hit rows
# in contrib's own order instead (rows are re-ordered by clustering anyway).
hit_bins = contrib.index[contrib.index.isin(list(all_hits))]
heatmap_data = contrib.loc[hit_bins, taxa_of_interest].rename(columns=taxonomy.name)
grid = sns.clustermap(heatmap_data, robust=True,
                      figsize=(14, 18), col_cluster=False, cmap='coolwarm', center=0)

# Address the heatmap axes by name rather than by position in fig.get_axes().
grid.ax_heatmap.tick_params(axis='y', labelrotation=0)
grid.ax_heatmap.tick_params(axis='x', labelrotation=90)

grid.ax_heatmap.figure.savefig('heatmap.pdf')