In [None]:
import sqlite3
from collections import defaultdict
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats  # make sure the `sp.stats` submodule is actually loaded
import seaborn as sns

In [None]:
# Open the SQLite database that every query in this notebook reads from.
DB_PATH = 'res/core.1.denorm.db'
con = sqlite3.connect(DB_PATH)

In [None]:
# Select the data

# Relative Abundance
# Sum the tallies per (extraction, OTU), pivot OTUs into columns (missing
# combinations become 0) and divide every row by its total so that each
# extraction's OTU fractions sum to 1.
rabund = (pd.read_sql('SELECT * FROM rrs_taxon_count;', con=con)
               .groupby(['extraction_id', 'otu_id'])
               .tally.sum()
               .unstack(fill_value=0)
               .apply(lambda x: x / x.sum(), axis=1))

# Coverage
# Sum coverage per (bin, extraction) in SQL, pivot bins into columns
# (missing combinations become 0) and normalise each extraction row the
# same way as the abundance table.
cvrg = (pd.read_sql("""
SELECT bin_id, extraction_id, SUM(coverage) AS coverage
FROM bin_coverage
JOIN library USING (library_id)
GROUP BY bin_id, extraction_id;""",
                        con=con, index_col=['extraction_id', 'bin_id'])
              .coverage.unstack().fillna(0).apply(lambda x: x / x.sum(), axis=1))

# Only keep shared extractions
# Use an Index (not a set) as the row indexer: pandas 2.0 rejects set
# indexers in .loc, and a set has no stable order, so the row order of the
# two tables would not be reproducible between runs.
extractions = rabund.index.intersection(cvrg.index)
rabund = rabund.loc[extractions]
cvrg = cvrg.loc[extractions]

# OTUs whose mean relative abundance over the shared extractions exceeds 0.001.
otus = rabund.mean()[lambda x: x > 0.001].index

In [None]:
# Smallest strictly positive value anywhere in the coverage table
# (non-positive entries are masked out before taking the minimum).
cvrg.where(cvrg > 0).min().min()

In [None]:
# Pearson correlation between the log-transformed profile of one bin and
# one OTU across the shared extractions.
# NOTE(review): no pseudocount is added here, and both tables were
# zero-filled when they were built, so any zero in these two columns turns
# into -inf under np.log — confirm the columns are strictly positive.
log_bin_cvrg = np.log(cvrg['bin01311'])
log_otu_rabund = np.log(rabund['Otu0001'])
sp.stats.pearsonr(log_bin_cvrg, log_otu_rabund)

In [None]:
# Shape of the log-transformed abundance table when every OTU column is
# offset by its own smallest positive value (a per-column pseudocount).
rabund_col_min_positive = rabund.where(rabund > 0).min()
np.log(rabund + rabund_col_min_positive).values.shape

In [None]:
# Smallest strictly positive value anywhere in the relative-abundance table.
rabund.where(rabund > 0).min().min()

In [None]:
# Preview of the matrix fed to the correlation step: both tables are
# log-transformed after adding their global smallest positive value as a
# pseudocount, then joined column-wise (bins first, OTUs second).
cvrg_pseudocount = cvrg.where(cvrg > 0).min().min()
rabund_pseudocount = rabund.where(rabund > 0).min().min()
np.concatenate(
    [
        np.log(cvrg + cvrg_pseudocount).values,
        np.log(rabund + rabund_pseudocount).values,
    ],
    axis=1,
)

In [None]:
# Log-transform both tables, using each table's smallest positive value as a
# pseudocount so that zeros stay finite.
log_cvrg = np.log(cvrg + cvrg[cvrg > 0].min().min())
log_rabund = np.log(rabund + rabund[rabund > 0].min().min())

# np.corrcoef treats rows as variables, hence the transpose: the result is
# the full (bins + OTUs) x (bins + OTUs) Pearson correlation matrix.
results = np.corrcoef(np.concatenate([log_cvrg.values, log_rabund.values], axis=1).T)

# Keep only the bins-vs-OTUs block (bin rows from the top, OTU columns from
# the right-hand side of the matrix).
corr = pd.DataFrame(results[:cvrg.shape[1], -rabund.shape[1]:],
                    index=cvrg.columns, columns=rabund.columns)

# np.corrcoef returns a plain ndarray with no p-values, so derive the
# two-sided p-value of each coefficient from the t distribution with n - 2
# degrees of freedom: t = r * sqrt((n - 2) / (1 - r**2)).
n_obs = log_cvrg.shape[0]
dof = n_obs - 2
r = corr.clip(-1, 1).values  # guard against |r| marginally above 1 from rounding
with np.errstate(divide='ignore', invalid='ignore'):
    t_stat = r * np.sqrt(dof / (1 - r ** 2))
pval = pd.DataFrame(2 * sp.stats.t.sf(np.abs(t_stat), dof),
                    index=cvrg.columns, columns=rabund.columns)

In [None]:
# Print the bins that are significantly correlated with one OTU and whose
# correlation is more than half of that OTU's best correlation, strongest
# first, on a single '|'-separated line.
# NOTE(review): `fdrcorrection` is not imported in the import cell —
# presumably statsmodels.stats.multitest.fdrcorrection; confirm it is in scope.
# NOTE(review): this also needs `pval`, a p-value table laid out like `corr`.
tax = 'Otu0001'
is_significant = fdrcorrection(pval[tax])[1] < 0.01
is_strong = corr[tax] > corr[tax].max() / 2
selected = corr[is_significant & is_strong][tax].sort_values(ascending=False)
for a in selected.index:
    print(a, end='|')

In [None]:
# Group bins by the OTU they correlate with most strongly:
# mbins[otu] is the list of bins whose row-wise maximum in `corr` is `otu`.
# The loop variable is named `bin_id` so the builtin `bin` is not shadowed
# at notebook scope.
mbins = defaultdict(list)
for bin_id, best_otu in corr.idxmax(axis=1).items():
    mbins[best_otu].append(bin_id)

In [None]:
# Correlations with Otu0002 for the bins assigned to it, strongest first.
otu0002_bins = mbins['Otu0002']
corr.loc[otu0002_bins, 'Otu0002'].sort_values(ascending=False)

In [None]:
# Print the 40 bins most strongly correlated with Otu0007 on one line, each
# followed by '|'. The loop variable is `bin_id` rather than `bin` so the
# builtin of that name is not overwritten for the rest of the notebook.
top_otu0007 = corr['Otu0007'].sort_values(ascending=False).head(40)
for bin_id in top_otu0007.index:
    print(bin_id, end='|')

In [None]:
# For the bins assigned to Otu0007, rank their correlation with it
# (strongest first), plot the ranked values and display the ranked series.
tax = 'Otu0007'
assigned_bins = mbins[tax]
a = corr.loc[assigned_bins, tax].sort_values(ascending=False)
plt.plot(a.values)
a

In [None]:
# Print every index label of `a` on a single line, each followed by '|'.
for b in a.index:
    print(b, end='|')

In [None]:
# Clustered heatmap of `corr2`: rows are clustered, columns keep their order
# (col_cluster=False); diverging colour map centred on 0.
# NOTE(review): `corr2` is not assigned in any cell visible here — confirm
# where it comes from (a fresh kernel would raise NameError).
sns.clustermap(corr2, col_cluster=False,
               cmap='coolwarm', center=0, robust=True,
               figsize=(14, 18))

In [None]:
# Select abundant taxa and bins
# TODO: Set these thresholds as parameters
# NOTE(review): `taxonomy` and `bin_cvrg` are not defined in the cells
# visible here — confirm they are loaded before this cell runs.
major_taxa = taxonomy.index[taxonomy.mean_rabund > 0.0001]
major_bins = bin_cvrg.columns[bin_cvrg.mean() > 0.0001]

# Keep the major columns and lump everything else into a single 'other'
# column; taxa are then relabelled through the `taxonomy.name` mapping.
d_rabund = (rabund[major_taxa]
            .assign(other=rabund.drop(columns=major_taxa).sum(1))
            .rename(columns=taxonomy.name))
d_cvrg = (bin_cvrg[major_bins]
          .assign(other=bin_cvrg.drop(columns=major_bins).sum(1)))

d_rabund.shape, d_cvrg.shape

In [None]:
def crossval(X, Y, model, k, random_state=None):
    """Score `model` with k-fold cross-validation over the rows of X and Y.

    The row labels of X are shuffled once and cut into `k` consecutive
    folds. For each fold the model is fit on all other rows and scored on
    the held-out rows.

    Parameters
    ----------
    X, Y : pandas objects with `.drop` / `.loc`
        Predictors and targets; Y is indexed with the row labels taken
        from X.
    model : object
        Must provide `fit(X, Y)` returning an object with `score(X, Y)`.
        The same instance is re-fit for every fold.
    k : int
        Number of folds.
    random_state : optional
        Passed to `X.sample` to make the shuffle reproducible.

    Returns
    -------
    list
        The `k` held-out scores, in fold order.

    Raises
    ------
    AssertionError
        If the smallest fold would hold fewer than three rows.
    """
    n = len(X.index) // k
    assert n > 2, "each fold needs more than 2 samples; lower k or add rows"
    order = list(X.sample(frac=1, random_state=random_state).index)
    # Spread the remainder over the first folds so that every row is held
    # out exactly once, even when len(X) is not a multiple of k.
    extra = len(order) % k
    scores = []
    start = 0
    for i in range(k):
        stop = start + n + (1 if i < extra else 0)
        outgroup = order[start:stop]
        start = stop
        fit = model.fit(X.drop(outgroup), Y.drop(outgroup))
        scores.append(fit.score(X.loc[outgroup], Y.loc[outgroup]))
    return scores

# Compare candidate numbers of PLS components by mean 4-fold CV score.
# NOTE(review): `PLSCanonical` is not imported in the import cell —
# presumably sklearn.cross_decomposition.PLSCanonical; confirm it is in scope.

# The square-root transforms do not depend on n_components: compute them once.
sqrt_cvrg = d_cvrg.apply(np.sqrt)
sqrt_rabund = d_rabund.apply(np.sqrt)

# A fixed seed gives every candidate the same folds, so the scores are
# comparable with each other and reproducible across runs.
CV_SEED = 0

for n_components in [10, 15, 20, 25, 30, 35, 40, 100, 500]:
    model = PLSCanonical(scale=False, n_components=n_components)
    fold_scores = crossval(sqrt_cvrg, sqrt_rabund, model, 4,
                           random_state=CV_SEED)
    print(n_components, np.mean(fold_scores))

In [None]:
# Pick n_components
# Number of PLS components used by the final fit in the next cell
# (presumably chosen from the cross-validation scores printed above — confirm).

n_components = 25

In [None]:
# Fit the PLS model on square-root transformed coverage (X) and abundance (Y).
sqrt_cvrg = d_cvrg.apply(np.sqrt)
sqrt_rabund = d_rabund.apply(np.sqrt)
fit = PLSCanonical(scale=False, n_components=n_components).fit(sqrt_cvrg, sqrt_rabund)

# Product of the X and Y loadings: one row per coverage column, one column
# per abundance column.
loadings_product = fit.x_loadings_ @ fit.y_loadings_.T
contrib = pd.DataFrame(
    loadings_product,
    index=d_cvrg.columns,
    columns=d_rabund.columns,
).rename(columns=taxonomy.name)

In [None]:
def tax_filter(x):
    """Column mask: True where the 90th percentile exceeds 0.02.

    Used as a callable column indexer for `.loc`, which calls it with the
    whole frame.
    """
    return x.quantile(0.90) > 0.02


# Taxa passing the filter, relabelled via `taxonomy.name` and sorted by
# name; the lumped 'other' column is never a taxon of interest.
selected_rabund = d_rabund.loc[:, tax_filter].rename(columns=taxonomy.name)
taxa_of_interest = sorted(selected_rabund.columns)
if 'other' in taxa_of_interest:
    taxa_of_interest.remove('other')
len(taxa_of_interest)

In [None]:
# A row counts as a hit for a taxon when its contribution exceeds this
# fraction of the taxon's top contribution.
factor = 0.33

hits = {}
for tax in taxa_of_interest:
    ranked = contrib[tax].sort_values(ascending=False)
    top_score = ranked.max()
    print(tax, top_score)
    # Labels above the threshold, kept in descending order of contribution.
    hits[tax] = list(ranked[ranked > top_score * factor].index)

print()
for tax, tax_hits in hits.items():
    print(tax, tax_hits)

# Union of the hits over all taxa of interest.
all_hits = set(chain.from_iterable(hits.values()))

In [None]:
# Clustered heatmap of the contributions of every hit to the taxa of
# interest. Rows are clustered; columns keep their order.
# `all_hits` is a set: turn it into a sorted list before indexing, because
# pandas 2.0 rejects set indexers in .loc and a set has no stable order.
hit_rows = sorted(all_hits)
a = sns.clustermap(contrib.loc[hit_rows, taxa_of_interest].rename(columns=taxonomy.name), robust=True,
                   figsize=(14, 18), col_cluster=False, cmap='coolwarm', center=0)

# Third axes of the figure — presumably the heatmap panel; confirm.
# Make its row labels horizontal and its column labels vertical.
ax = a.fig.get_axes()[2]
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)