# Notebook Setup

In [None]:
# Re-import edited modules (e.g. the project's scripts.lib helpers) automatically
# before each cell executes, so helper edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import statsmodels.formula.api as sm
import sqlite3
import seaborn as sns
import patsy
from sklearn.decomposition import PCA
from lifelines import KaplanMeierFitter
from matplotlib.ticker import StrMethodFormatter
from statsmodels.stats.multitest import fdrcorrection
import itertools

import matplotlib as mpl

import rpy2.ipython
%load_ext rpy2.ipython.rmagic

from scripts.lib.stats import raise_low, lrt_phreg, phreg_aic, mannwhitneyu
from scripts.lib.plotting import boxplot_with_points, load_style, residuals_plot
from skbio.diversity.alpha import chao1, simpson_e
from skbio.stats import subsample_counts
from skbio import DistanceMatrix
from skbio.stats.ordination import pcoa

def concat(list_of_lists):
    """Flatten one level of nesting: an iterable of iterables -> one flat list."""
    return list(itertools.chain.from_iterable(list_of_lists))


def richness(x):
    """Number of entries of ``x`` that are strictly positive."""
    return (x > 0).sum()

In [None]:
# Reproducibility: seed NumPy's global RNG before anything random runs. Later
# cells draw random `zorder` layers with np.random.choice. They also rely on
# seaborn strip-plot jitter and on scikit-bio's `subsample_counts` rarefaction.
# NOTE(review): some seaborn / scikit-bio versions use their own RNG instead of
# NumPy's global state. Confirm these are covered for the pinned versions.
SEED = 0
np.random.seed(SEED)

loaded_style = load_style('paper')

# Unpack the figure-style helpers used by every plotting cell below.
color_map = loaded_style['color_map']
mark_map = loaded_style['mark_map']
assign_significance_symbol = loaded_style['assign_significance_symbol']
savefig = loaded_style['savefig']
fullwidth = loaded_style['fullwidth']
halfwidth = loaded_style['halfwidth']

In [None]:
# Load every results table from the SQLite results DB and publish each entry of
# the returned dict as a notebook-level global. Later cells refer directly to
# names such as `mouse`, `abund`, `rabund`, `conc` and `con`, presumably all
# provided here; confirm against load_data.
from scripts.lib.data import load_data

gl = globals()
loaded_data = load_data('res/C2013.results.db')
gl.update(loaded_data)

print(loaded_data.keys())

In [None]:
# (rows, columns) of the absolute-abundance table
np.shape(abund)

# Study population

## Longevity

### C2013 Cohort

## Dropped Samples

In [None]:
# Per site/sex/treatment counts of C2013 mice that have 16S (rrs) abundance data
# and HPLC concentration data, respectively.
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(abund, how='left')
        .join(conc, how='left'))

for i, (title, required_col) in enumerate([("rrs Data:", 'Otu0001'),
                                           ("Conc. Data:", 'butyrate')]):
    if i:
        print()
    print(title)
    present = data.dropna(subset=[required_col])
    print(present.groupby(['site', 'sex', 'treatment']).cohort.count())

## Pellet Properties

In [None]:
# Fecal pellet properties for C2013 mice with 16S data: median weight by
# treatment with a Mann-Whitney U test, then one box plot per property.
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(rabund, how='left')
        .dropna(subset=['Otu0001']))

print(data.groupby(['treatment']).sample_weight.median())
print(mannwhitneyu('treatment', 'sample_weight', data))

for prop in ['hydration_factor', 'sample_weight', 'sample_hydrated_weight']:
    fig, ax = plt.subplots()
    boxplot_with_points('sex', prop, 'treatment', data=data, ax=ax)

## Amplicon Properties

In [None]:
# Print the amplicon length table (by its file name, presumably the length
# distribution of the clustered representative sequences; confirm).
!cat res/C2013.rrs.procd.clust.reps.length.dist.tsv

# Community Survey

In [None]:
# Total OTU-level (otu-0.03) read tally per C2013 mouse, summed across that
# mouse's libraries; shown as the median and minimum depth.
depth_query = """
    SELECT mouse_id, SUM(tally) AS total_count
    FROM _rrs_library_taxon_count
    JOIN rrs_library_metadata USING (rrs_library_id)
    WHERE cohort = 'C2013'
      AND taxon_level = 'otu-0.03'
    GROUP BY mouse_id
                   """
data = pd.read_sql(depth_query, con=con, index_col=['mouse_id'])

(data.median(), data.min())

## PCoA

In [None]:
# Community ordination: Bray-Curtis distances between C2013 samples' relative
# abundances, projected with PCoA. Three panels show the same points colored by
# treatment, site and sex; marker shape always encodes treatment.
data = (mouse
            [lambda x: x.cohort.isin(['C2013'])]
            .join(rabund, how='left')
            .dropna(subset=['Otu0001'])
       )

taxa = list(rabund.loc[data.index].sum().sort_values(ascending=False).index)
# Near-zero thresholds: keeps practically every taxon seen in at least one sample.
tax_selector = lambda x: (x.mean() > 0.00000001) & ((x > 0).mean() > 0.00000005)


# Community PCoA
dmat = DistanceMatrix(sp.spatial.distance.pdist(data[taxa].loc[:, tax_selector],
                                                metric='braycurtis'),
                      ids=data.index)
p = pcoa(dmat)
coords = p.samples

fig, axs = plt.subplots(1, 3, figsize=(fullwidth, 2.1), sharey=True, sharex=True)

for ax, (factorA, factorB) in zip(axs, [('treatment', 'treatment'),
                                        ('site', 'treatment'),
                                        ('sex', 'treatment')]):
    dummy_artists = {}
    for groupA, d0 in coords.groupby(data[factorA]):
        color = color_map[groupA]
        for groupB, d1 in d0.groupby(data[factorB]):
            marker = mark_map[groupB]
            ax.scatter(d1.PC1, d1.PC2, c=color,
                       marker=marker, s=20, lw=0.3, edgecolor='k', label='_nolegend_')
        # Empty scatter that exists only as a square color swatch for the legend.
        dummy_artists[groupA] = ax.scatter([], [], marker='s', color=color)
    # Pass handles AND labels explicitly. With a single positional argument,
    # `ax.legend` treats it as a list of *labels* and pairs them, in creation
    # order, with the axes' first artists (the data scatters), not the swatches.
    ax.legend(list(dummy_artists.values()), list(dummy_artists.keys()),
              frameon=True, fontsize=6)

# Panel letters and per-panel axis formatting
for panel, ax in zip(['A', 'B', 'C'], axs):
    ax.set_aspect('equal', adjustable='datalim')
    ax.annotate(panel, xy=(0.02, 1.03), xycoords='axes fraction', fontweight='heavy')
    ax.set_xlabel('PCo1 ({:.0%})'.format(p.proportion_explained.iloc[0]))
    ax.tick_params(axis='both', labelsize=6)

axs[0].set_ylabel('PCo2 ({:.0%})'.format(p.proportion_explained.iloc[1]))
# Axes are shared (sharex/sharey), so these ticks apply to all three panels.
axs[-1].set_yticks([-0.3, 0, 0.3])
axs[-1].set_xticks([-0.3, 0, 0.3])

fig.subplots_adjust(wspace=0.05)


savefig(fig, 'fig/comm_pcoa')

## PERMANOVA

In [None]:
# Inputs for the R PERMANOVA / dispersion cells below:
#   y    - per-sample re-normalized relative abundances of the more common taxa
#   meta - the design factors
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(rabund, how='left')
        .dropna(subset=['Otu0001']))

taxa = list(rabund.loc[data.index].sum().sort_values(ascending=False).index)
# Keep taxa with mean relative abundance > 0.01% that occur in > 5% of samples.
tax_selector = lambda x: (x.mean() > 0.0001) & ((x > 0).mean() > 0.05)

kept = data[taxa].loc[:, tax_selector]
# Rescale every sample (row) so that the retained taxa sum to 1.
y = kept.div(kept.sum(axis=1), axis=0)
meta = data[['cohort', 'sex', 'site', 'treatment']]

In [None]:
%%R -i y -i meta
library(vegan)

# PERMANOVA is a permutation test: fix R's RNG so Pr(>F) is reproducible.
set.seed(0)
fit = adonis2(y ~ treatment * site * sex, data=meta)
print(fit)

# Fraction of total variance per term (R^2). adonis2's table already has a
# "Total" row, so dividing by the sum of the whole SumOfSqs column would count
# the total twice and halve every fraction. Divide by the Total row instead,
# falling back to the column sum if this vegan version omits it.
total_ss = if ("Total" %in% rownames(fit)) fit["Total", "SumOfSqs"] else sum(fit$SumOfSqs)
print(fit$SumOfSqs / total_ss)

In [None]:
%%R
# Bray-Curtis distance matrix; also reused by the dispersion cells that follow.
d = vegdist(y, method='bray')
# Homogeneity of multivariate dispersion across every sex x site x treatment group.
groups = with(meta, paste(sex, site, treatment))
disp = betadisper(d, groups)
print(disp)
anova(disp)


In [None]:
%%R
# Dispersion homogeneity by site alone (uses distance matrix `d`).
disp = betadisper(d, paste(meta[["site"]]))
print(disp)
anova(disp)

In [None]:
%%R
# Dispersion homogeneity by sex alone (uses distance matrix `d`).
disp = betadisper(d, paste(meta[["sex"]]))
print(disp)
anova(disp)

In [None]:
%%R
# Dispersion homogeneity by treatment alone (uses distance matrix `d`).
disp = betadisper(d, paste(meta[["treatment"]]))
print(disp)
anova(disp)

## Family Level Abundance

In [None]:
# Top-20 families by median relative abundance among C2013 control mice.
is_c2013_control = lambda x: (x.cohort == 'C2013') & (x.treatment == 'control')
data = mouse[is_c2013_control].join(rabund_family).dropna(subset=['Muribaculaceae'])

data[families].median().sort_values(ascending=False).head(20)

In [None]:
# Pooled share of relative abundance (each sample weighted equally) assigned to
# the 'unclassified' family.
data = mouse[lambda x: (x.cohort == 'C2013')].join(rabund_family).dropna(subset=['Muribaculaceae'])
family_totals = data[families].sum()
family_totals['unclassified'] / family_totals.sum()

In [None]:
# Five most abundant families (by control median relative abundance), each
# shown as "median (Q1, Q3)" per arm, with a Mann-Whitney U test of treatment.
data = mouse[lambda x: x.cohort == 'C2013'].join(rabund_family, how='inner')

out = (data.groupby(['treatment'])[families]
           .quantile([0.25, 0.5, 0.75])
           .transpose())

# Only families observed at least once are tested; the rest get NaN p-values.
out['p'] = pd.Series({f: mannwhitneyu('treatment', f, data).pvalue
                      for f in families if data[f].sum() != 0})
out['sig'] = out['p'].apply(assign_significance_symbol)

out = out.sort_values(('control', 0.5), ascending=False).head(5)

# Quantiles arrive ordered 0.25, 0.5, 0.75 -> "median (Q1, Q3)".
fmt_iqr = lambda q: '{1:.1%} ({0:.1%}, {2:.1%})'.format(*q)
for arm in ['control', 'acarbose']:
    out[(arm, 'pretty')] = out[arm].apply(fmt_iqr, axis=1)

print(out[['p', 'sig']])
out[[('control', 'pretty'), ('acarbose', 'pretty')]]

In [None]:
# Same five-family summary on the absolute-abundance family table: median fold
# change (acarbose / control) with a significance symbol.
data = mouse[lambda x: x.cohort == 'C2013'].join(abund_family, how='inner')

out = (data.groupby(['treatment'])[families]
           .quantile([0.25, 0.5, 0.75])
           .transpose())

out['p'] = pd.Series({f: mannwhitneyu('treatment', f, data).pvalue
                      for f in families if data[f].sum() != 0})
out['sig'] = out['p'].apply(assign_significance_symbol)

out = out.sort_values(('control', 0.5), ascending=False).head(5)


def _fold_change_label(row):
    """Median ratio acarbose/control (2 significant digits) + significance symbol."""
    ratio = row[('acarbose', 0.5)] / row[('control', 0.5)]
    return '{:.2}{}'.format(ratio, row.sig.values[0])


out['fold_change'] = out.apply(_fold_change_label, axis=1)

print(out['p'])
out[['fold_change']]

## OTU-1 / OTU-4

In [None]:
# Relative abundance of OTU-1 and OTU-4 by site and treatment, with per-site
# Mann-Whitney U tests of treatment printed and annotated on the figure.
feats = ['Otu0001', 'Otu0004']
names = ['OTU-1', 'OTU-4']

data = (mouse
            [lambda x: x.cohort.isin(['C2013'])]
            .join(rabund, how='left')
            .dropna(subset=feats)
       )
# Random layer (5-9) per mouse. Points are drawn one (sex, zorder) group at a
# time below, so overlapping markers are stacked in a shuffled order.
data['zorder'] = np.random.choice(np.arange(5, 10), size=len(data))


fig, axs = plt.subplots(1, 2, figsize=(halfwidth, 2.5), sharex=True, sharey=True)


for feat, name, ax in zip(feats, names, axs):
    # White, whisker-less boxes as a backdrop for the individual points.
    sns.boxplot(x='site', y=feat, hue='treatment', data=data,
                hue_order=['control', 'acarbose'],
                palette={'control': 'white', 'acarbose': 'white'},
                whis=0, linewidth=1, showfliers=False, saturation=1,
                ax=ax)
    # Points: color = treatment, marker = sex.
    for (sex, zorder), d in data.groupby(['sex', 'zorder']):
        sns.stripplot(x='site', y=feat, hue='treatment', data=d,
                      hue_order=['control', 'acarbose'],
                      palette=color_map, marker=mark_map[sex],
                      jitter=True, split=True, linewidth=0.5,
                      s=3, alpha=0.7, zorder=zorder,
                      ax=ax)

    ax.set_ylim(-0.05, 0.85)
    ypos = 0.78
    for xpos, site in enumerate(data.site.unique()):
        print(feat, site)
        print(data[(data.site == site)].groupby('treatment')[feat].median())
        pvalue = mannwhitneyu('treatment', feat, data[data.site == site]).pvalue
        print('pvalue: ', pvalue)
        print()
        symbol = assign_significance_symbol(pvalue)
        if symbol == '†':
            # '†' (presumably the marginal-significance mark) is not drawn here.
            continue
        ax.annotate(symbol, xy=(xpos, ypos), ha='center', fontweight='bold')
    print()

    ax.set_xlabel('')
    # Minor ticks under each hue box: '−' = control, '+' = acarbose (hue_order).
    ax.set_xticks(concat([x - 0.2, x + 0.2] for x in ax.xaxis.get_ticklocs()),
                  minor=True)
    ax.set_xticklabels(['−', '+'] * len(ax.xaxis.get_ticklocs()), minor=True,
                       position=(0, 0.01))

    ax.set_xticklabels(ax.get_xticklabels(), position=(0, -0.02))

    ax.legend([])
    ax.set_ylabel('')
    # Relative abundance as a percentage (the top tick is left unlabelled).
    ax.set_yticklabels('{:2.0f}'.format(x * 100) for x in ax.get_yticks()[:-1])
    # Booleans, not 'off'. A non-empty string is truthy and turns the ticks back
    # on in matplotlib versions that no longer translate 'on'/'off'.
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.set_title(name)


axs[0].set_ylabel('Relative Abundance (%)')
fig.subplots_adjust(wspace=0.05)

savefig(fig, 'fig/dominant_otus_box')

In [None]:
# Same layout as the OTU-1/OTU-4 figure, but boxes and colors split by sex, with
# per-site Mann-Whitney U tests of sex. Reuses `data`, `feats` and `names` from
# the previous cell. The figure is displayed, not saved.
fig, axs = plt.subplots(1, 2, figsize=(halfwidth, 2.5), sharex=True, sharey=True)


for feat, name, ax in zip(feats, names, axs):
    sns.boxplot(x='site', y=feat, hue='sex', data=data,
                hue_order=['male', 'female'],
                palette={'male': 'white', 'female': 'white'},
                whis=0, linewidth=1, showfliers=False, saturation=1,
                ax=ax)
    # Points are drawn one (treatment, zorder) layer at a time, and the marker
    # encodes that layer's treatment. This loop does not bind `sex`; using it
    # would pick up a stale value from an earlier cell and give every point the
    # same marker.
    for (treatment, zorder), d in data.groupby(['treatment', 'zorder']):
        sns.stripplot(x='site', y=feat, hue='sex', data=d,
                      hue_order=['male', 'female'],
                      palette=color_map, marker=mark_map[treatment],
                      jitter=True, split=True, linewidth=0.5,
                      s=3, alpha=0.7, zorder=zorder,
                      ax=ax)

    ax.set_ylim(-0.05, 0.85)
    ypos = 0.78
    for xpos, site in enumerate(data.site.unique()):
        print(feat, site)
        print(data[(data.site == site)].groupby('sex')[feat].median())
        pvalue = mannwhitneyu('sex', feat, data[data.site == site]).pvalue
        print('pvalue: ', pvalue)
        print()
        symbol = assign_significance_symbol(pvalue)
        ax.annotate(symbol, xy=(xpos, ypos), ha='center', fontweight='bold')
    print()

    ax.set_xlabel('')
    # Minor ticks under each hue box, labelled in hue_order (male, female).
    # The treatment figure's '−'/'+' labels do not apply to a sex split.
    ax.set_xticks(concat([x - 0.2, x + 0.2] for x in ax.xaxis.get_ticklocs()),
                  minor=True)
    ax.set_xticklabels(['M', 'F'] * len(ax.xaxis.get_ticklocs()), minor=True,
                       position=(0, 0.01))

    ax.set_xticklabels(ax.get_xticklabels(), position=(0, -0.02))

    ax.legend([])
    ax.set_ylabel('')
    # Relative abundance as a percentage (the top tick is left unlabelled).
    ax.set_yticklabels('{:2.0f}'.format(x * 100) for x in ax.get_yticks()[:-1])
    # Booleans, not 'off' (a non-empty string is truthy in newer matplotlib).
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.set_title(name)


axs[0].set_ylabel('Relative Abundance (%)')

In [None]:
# Combined relative abundance of OTU-1 + OTU-4 (added as a column to `data`
# from the cells above), compared between sexes within each site.
feat = 'combined_otu1_otu4'

data[feat] = data['Otu0001'] + data['Otu0004']

boxplot_with_points('treatment', feat, hue='sex', data=data)

for site in data.site.unique():
    in_site = data[data.site == site]
    print(feat, site)
    print(in_site.groupby('sex')[feat].median())
    print('pvalue: ', mannwhitneyu('sex', feat, in_site).pvalue)
    print()

In [None]:
# Median OTU-4 relative abundance for each site x treatment cell.
data.groupby(['site', 'treatment'])['Otu0004'].median()

In [None]:
feats = ['Otu0001', 'Otu0004']

data = (mouse
            [lambda x: x.cohort.isin(['C2013'])]
            .join(rabund, how='left')
            .dropna(subset=feats)
       )

# Per site: in how many mice each OTU exceeds 0.1% relative abundance, out of
# all mice at that site. Named (label, func) pairs yield (OTU, statistic)
# columns directly. The previous dict-of-renamers form was deprecated in
# pandas 0.20 and raises SpecificationError from pandas 1.0 on.
out = data.groupby(['site'])[feats].agg([('found_in', lambda x: (x > 0.001).sum()),
                                         ('out_of', len)])
out.sort_index(axis='columns')

In [None]:
# Absolute-abundance partition: OTUs 1+4, the rest of Muribaculaceae, and
# everything outside Muribaculaceae (`dens` presumably total density; confirm).
# For each part: Mann-Whitney U test of treatment, then median fold change
# (acarbose / control).
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(abund)
        .join(abund_family)
        .dropna(subset=['Otu0001']))

data['otus_1_4'] = data.Otu0001 + data.Otu0004
data['muribac_drop_1_4'] = data.Muribaculaceae - data.otus_1_4
data['non_muribac'] = data.dens - data.Muribaculaceae

comparisons = [('Otus 1, 4 (abund)', 'otus_1_4'),
               ('Not 1, 4 (abund)', 'muribac_drop_1_4'),
               ('Non-Muri. OTUs (abund)', 'non_muribac')]
for i, (label, col) in enumerate(comparisons):
    if i:
        print()
    print(label)
    print(mannwhitneyu('treatment', col, data))
    median_density = data.groupby('treatment')[col].median()
    print(median_density['acarbose'] / median_density['control'])

In [None]:
# Relative abundance of Muribaculaceae excluding OTUs 1 and 4, by treatment.
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(rabund)
        .join(rabund_family)
        .dropna(subset=['Otu0001'])
        .assign(otus_1_4=lambda df: df.Otu0001 + df.Otu0004)
        .assign(muribac_drop_1_4=lambda df: df.Muribaculaceae - df.otus_1_4))

print('Not OTUs 1, 4 (rabund)')
print(mannwhitneyu('treatment', 'muribac_drop_1_4', data))
median_rabund = data.groupby('treatment')['muribac_drop_1_4'].median()
print(median_rabund)

In [None]:
# Fraction of Muribaculaceae made up by OTUs 1 + 4, by treatment.
data = (mouse[lambda x: x.cohort.isin(['C2013'])]
        .join(rabund)
        .join(rabund_family)
        .dropna(subset=['Otu0001'])
        .assign(otus_1_4=lambda df: df.Otu0001 + df.Otu0004)
        .assign(muribac_drop_1_4=lambda df: df.Muribaculaceae - df.otus_1_4)
        .assign(otus_1_4_muribac_frac=lambda df: df.otus_1_4 / df.Muribaculaceae))

print('OTUs 1, 4 fraction')
print(mannwhitneyu('treatment', 'otus_1_4_muribac_frac', data))
median_rabund = data.groupby('treatment')['otus_1_4_muribac_frac'].median()
print(median_rabund)

## Alpha Diversity

In [None]:
# Alpha diversity on rarefied OTU counts, with and without the two dominant OTUs
# (OTU-1, OTU-4). Mann-Whitney U tests of treatment, pooled and per site.
data = mouse[lambda x: x.cohort == 'C2013'].join(count, how='inner').dropna(subset=['Otu0001'])
taxa = count.columns
dominant = ['Otu0001', 'Otu0004']

# Rarefaction depth = the smallest library (with / without the dominant OTUs).
min_n = data[taxa].sum(axis=1).min()
min_n_drop_otu1_4 = data[taxa].drop(dominant, axis='columns').sum(axis=1).min()

metrics = {'simpson_e': simpson_e, 'chao1': chao1, 'richness': richness}


def _rarefy(counts, depth):
    """Subsample each row of ``counts`` to ``depth`` reads, keeping column labels.

    Wrapping the ndarray from ``subsample_counts`` in a Series makes
    ``DataFrame.apply(axis='columns')`` expand the rows back into a DataFrame on
    every pandas version. With a bare ndarray, pandas >= 0.23 returns a Series of
    arrays, and the subsequent ``.apply(metric, axis='columns')`` fails.
    """
    return counts.apply(lambda row: pd.Series(subsample_counts(row, depth), index=row.index),
                        axis='columns')


# Rarefy once per variant so every metric is computed on the same subsample
# (previously each metric drew its own random subsample).
rarefied = _rarefy(data[taxa], min_n)
rarefied_drop = _rarefy(data[taxa].drop(dominant, axis='columns'), min_n_drop_otu1_4)

for m in metrics:
    data[m] = rarefied.apply(metrics[m], axis='columns')
    data[m + '_drop'] = rarefied_drop.apply(metrics[m], axis='columns')
    for name in [m, m + '_drop']:
        print('{} (pooled):'.format(name), mannwhitneyu('treatment', name, data).pvalue)
        for site, d in data.groupby('site'):
            print('{} ({}):'.format(name, site), mannwhitneyu('treatment', name, d).pvalue)

data.groupby(['treatment'])[list(metrics) + [m + '_drop' for m in metrics]].agg(['mean', 'std']).T

In [None]:
# Look up the OTU-1 and OTU-4 rows of the representative-sequence length table
# (contents inferred from the file name; confirm).
!grep '\(Otu0001\|Otu0004\)' res/C2013.rrs.procd.clust.reps.length.tsv

# HPLC
## Treatment Effects

In [None]:
# Interactive `fig.savefig?` help lookup removed: it was a development leftover
# that only opened the docstring pager and produced no analysis output.

In [None]:
# Treatment effect on each metabolite: boxes + points on a symlog axis, with a
# Mann-Whitney U significance mark above each molecule.
# `conc` is reshaped to long form: one row per (mouse, molecule_id).
data = (mouse[lambda x: (x.cohort == 'C2013')]
             .join(conc.unstack().to_frame(name='concentration')
                       .reset_index(level='molecule_id')
                  ).dropna(subset=['concentration']))
# Random drawing layer per row so overlapping points stack in a shuffled order.
data['zorder'] = np.random.choice(np.arange(5, 10), size=len(data))

mols = ['glucose', 'lactate', 'succinate',
        'acetate', 'butyrate', 'propionate',
        'total_scfa'
       ]

fig, ax = plt.subplots(figsize = (halfwidth, 2.5))
sns.boxplot(x='molecule_id', y='concentration', hue='treatment', data=data,
            order=mols, hue_order=['control', 'acarbose'], palette={'control': 'white', 'acarbose': 'white'},
            whis=0, linewidth=1, ax=ax, showfliers=False, saturation=1)
for (sex, zorder), d in data.groupby(['sex', 'zorder']):
    sns.stripplot(x='molecule_id', y='concentration', hue='treatment', data=d,
                  order=mols, hue_order=['control', 'acarbose'],
                  palette=color_map,
                  marker='o',
                  jitter=True, split=True, linewidth=0.5,
                  s=3, alpha=0.5, zorder=zorder,
                  ax=ax)
ax.legend([])
ax.set_xlabel('')
ax.set_ylabel('conc. (mmols/kg)')

# Symmetric-log y axis: linear within +/-thresh so zero values stay on the plot.
thresh = 0.1
ax.set_yscale('symlog', linthreshy=thresh, linscaley=0.2)
ax.set_ylim(bottom=-thresh, top=300)

# Minor ticks every 0.1 up to 1, every 1 up to 10, every 10 up to 100.
minor_yticks = np.concatenate([np.arange(0, 1.1, step=0.1),
                               np.arange(0, 11, step=1)[1:],
                               np.arange(0, 110, step=10)[1:]
                              ])
ax.set_yticks(minor_yticks, minor=True)
ax.yaxis.set_tick_params(which='minor', size=1)
ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}'))

for xpos, m in enumerate(mols):
    ypos = 140
    pvalue = mannwhitneyu('treatment', 'concentration', data=data[data.molecule_id == m]).pvalue
    print(m, pvalue)

    symbol = assign_significance_symbol(pvalue)
    ss_annotation_kwrgs = dict(ha='center', fontweight='bold')
    if symbol == '†':
        # The '†' mark is drawn smaller and higher than the other symbols.
        ss_annotation_kwrgs['fontsize'] = 5
        ypos = 180
    ax.annotate(symbol, xy=(xpos, ypos), **ss_annotation_kwrgs)

# Minor ticks under each hue box: '−' = control, '+' = acarbose (hue_order).
ax.set_xticks(concat([x - 0.25, x + 0.25] for x in ax.xaxis.get_ticklocs()),
              minor=True)
ax.set_xticklabels(['−', '+'] * len(ax.xaxis.get_ticklocs()), minor=True,
                   position=(0, 0.01))


xticklabels = [t.get_text() for t in ax.get_xticklabels()]
assert xticklabels[-1] == 'total_scfa'
xticklabels[-1] = 'total SCFA'
ax.set_xticklabels(xticklabels, position=(0, -0.02), rotation=30)

for ticklabel in ax.xaxis.get_majorticklabels():
    ticklabel.set_verticalalignment('top')

# Shade the acetate / butyrate / propionate columns (x positions 3-5).
ax.axvspan(2.5, 5.5, color='grey', alpha=0.1)
# Booleans, not 'off' (a non-empty string is truthy in newer matplotlib).
ax.tick_params(axis='x', which='both', bottom=False, top=False)

savefig(fig, 'fig/scfa_box')

data[data.molecule_id.isin(mols)].groupby(['molecule_id', 'treatment']).concentration.median().unstack()[['control', 'acarbose']]

In [None]:
import matplotlib.gridspec as gridspec
# 2 x 4 grid for the two-panel metabolite figure in the next cell.
# (The scratch `gs[:2,1]` inspection that only echoed a SubplotSpec was removed.)
gs = gridspec.GridSpec(2, 4)

In [None]:
# Two-panel variant of the metabolite figure: SCFAs across the full top row of
# `gs`, the other metabolites in the first three columns of the bottom row.
data = (mouse[lambda x: (x.cohort == 'C2013')]
             .join(conc.unstack().to_frame(name='concentration')
                       .reset_index(level='molecule_id')
                  ).dropna(subset=['concentration']))
data['zorder'] = np.random.choice(np.arange(5, 10), size=len(data))

mols_sets = [
             ['acetate', 'butyrate', 'propionate', 'total_scfa'],
             ['glucose', 'lactate', 'succinate'],
            ]

fig = plt.figure(figsize=(halfwidth, 4))
axs = [fig.add_subplot(gs[0,:]), fig.add_subplot(gs[1,:-1])]

for mols, ax in zip(mols_sets, axs):
    sns.boxplot(x='molecule_id', y='concentration', hue='treatment', data=data,
                order=mols, hue_order=['control', 'acarbose'], palette={'control': 'white', 'acarbose': 'white'},
                whis=0, linewidth=1, ax=ax, showfliers=False, saturation=1)
    for (sex, zorder), d in data.groupby(['sex', 'zorder']):
        sns.stripplot(x='molecule_id', y='concentration', hue='treatment', data=d,
                      order=mols, hue_order=['control', 'acarbose'],
                      palette=color_map,
                      marker='o',
                      jitter=True, split=True, linewidth=0.5,
                      s=3, alpha=0.5, zorder=zorder,
                      ax=ax)
    ax.legend([])
    ax.set_xlabel('')
    ax.set_ylabel('conc. (mmols/kg)')

    # Symmetric-log y axis: linear within +/-thresh so zero values stay visible.
    thresh = 0.1
    ax.set_yscale('symlog', linthreshy=thresh, linscaley=0.2)
    ax.set_ylim(bottom=-thresh, top=300)

    minor_yticks = np.concatenate([np.arange(0, 1.1, step=0.1),
                                   np.arange(0, 11, step=1)[1:],
                                   np.arange(0, 110, step=10)[1:]
                                  ])
    ax.set_yticks(minor_yticks, minor=True)
    ax.yaxis.set_tick_params(which='minor', size=1)
    ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}'))

    for xpos, m in enumerate(mols):
        ypos = 110
        pvalue = mannwhitneyu('treatment', 'concentration', data=data[data.molecule_id == m]).pvalue
        print(m, pvalue)

        symbol = assign_significance_symbol(pvalue)
        ss_annotation_kwrgs = dict(ha='center', fontweight='bold')
        if symbol == '†':
            ss_annotation_kwrgs['fontsize'] = 5
            ypos = 160
        ax.annotate(symbol, xy=(xpos, ypos), **ss_annotation_kwrgs)

    # Minor ticks under each hue box: '−' = control, '+' = acarbose (hue_order).
    ax.set_xticks(concat([x - 0.25, x + 0.25] for x in ax.xaxis.get_ticklocs()),
                  minor=True)
    ax.set_xticklabels(['−', '+'] * len(ax.xaxis.get_ticklocs()), minor=True,
                       position=(0, 0.01))

    xticklabels = [t.get_text() for t in ax.get_xticklabels()]
    if xticklabels[-1] == 'total_scfa':
        xticklabels[-1] = 'total\nSCFA'

    ax.set_xticklabels(xticklabels, position=(0, -0.02))

    for ticklabel in ax.xaxis.get_majorticklabels():
        ticklabel.set_verticalalignment('top')

    # Booleans, not 'off' (a non-empty string is truthy in newer matplotlib).
    ax.tick_params(axis='x', which='both', bottom=False, top=False)

# Not saved: 'fig/scfa_box' is written by the single-panel figure cell.

# Medians for every plotted molecule. After the loop, `mols` holds only the
# last panel's list, which would silently drop the SCFAs from this table.
plotted_mols = concat(mols_sets)
(data[data.molecule_id.isin(plotted_mols)].groupby(['molecule_id', 'treatment'])
     .concentration.median().unstack()[['control', 'acarbose']])

In [None]:
# Total SCFA and the per-SCFA "_scfa_frac" values (by name, each SCFA's share
# of the total): Mann-Whitney U test of treatment and per-arm medians.
long_conc = (conc.unstack()
                 .to_frame(name='concentration')
                 .reset_index(level='molecule_id')
                 .dropna(subset=['concentration']))
data = mouse[lambda x: x.cohort == 'C2013'].join(long_conc)

mols = ['total_scfa', 'propionate_scfa_frac', 'acetate_scfa_frac', 'butyrate_scfa_frac']

for m in mols:
    subset = data[data.molecule_id == m]
    print(m, mannwhitneyu('treatment', 'concentration', data=subset).pvalue)


data[data.molecule_id.isin(mols)].groupby(['molecule_id', 'treatment']).concentration.median().unstack()

## Interaction Effects

In [None]:
data = (mouse
            [lambda x: x.cohort.isin(['C2013'])]
            .join(conc, how='left')
            .dropna(subset=['butyrate'])
       )

# Molecule examined in this cell. It drives the medians, the plot and the model,
# so changing it here changes all three consistently.
feat = 'propionate'

# Concentrations multiplied by each pellet's hydration factor.
for mol in ['butyrate', 'acetate', 'propionate', 'lactate', 'glucose', 'succinate']:
    data[mol + '_hydrated'] = data[mol] * data.hydration_factor

print(data.groupby(['sex', 'treatment'])[feat].median())

boxplot_with_points('sex', feat, 'treatment', data=data, points_kwargs={'jitter': True}, palette=color_map)
plt.yscale('log')


# OLS on log concentration: treatment (reference = control) x sex interaction,
# adjusted for site. `raise_low` is a project helper, presumably lifting zero /
# very low values so the log is finite; confirm.
fit = sm.ols('np.log(raise_low({})) ~ C(treatment, Treatment("control")) * sex + site'.format(feat), data=data).fit()
fit.summary()

## Stratified by Sex / Site (Supplemental Figure)

In [None]:
# Supplemental figure: metabolite boxplots stratified by site (top panels,
# lettered A-C) and by sex (bottom row, D-E). Each panel has per-molecule
# Mann-Whitney U marks for treatment.
data = (mouse[lambda x: (x.cohort == 'C2013')]
             .join(conc.unstack().to_frame(name='concentration')
                       .reset_index(level='molecule_id')
             .dropna(subset=['concentration'])
                  )
       )
data['zorder'] = np.random.choice(np.arange(5, 10), size=len(data))

mols = ['glucose', 'lactate', 'succinate',
        'acetate', 'butyrate', 'propionate',
        'total_scfa'
       ]
# Category labels for the bottom row, in the same order seaborn draws `mols`.
mol_labels = ['total SCFA' if m == 'total_scfa' else m for m in mols]

fig, axs = plt.subplots(3, 2, figsize=(fullwidth, 7), sharey=True, sharex=True)

# Bottom row: one panel per sex.
for (sex, d1), ax in zip(data.groupby('sex'), axs.flatten()[-2:]):
    sns.boxplot(x='molecule_id', y='concentration', hue='treatment', data=d1,
                order=mols, hue_order=['control', 'acarbose'],
                palette={'control': 'white','acarbose': 'white'},
                whis=0, linewidth=1, ax=ax, showfliers=False, saturation=1)
    sns.stripplot(x='molecule_id', y='concentration', hue='treatment', data=d1,
                  order=mols, hue_order=['control', 'acarbose'],
                  palette=color_map, marker=mark_map[sex],
                  jitter=True, split=True, linewidth=0.5, s=3, alpha=0.7,
                  ax=ax)
    ax.legend([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_yscale('symlog', linthreshy=0.1, linscaley=0.2)
    ax.set_ylim(bottom=-1e-1, top=200)

    minor_yticks = np.concatenate([np.arange(0, 1.1, step=0.1),
                                   np.arange(0, 11, step=1)[1:],
                                   np.arange(0, 110, step=10)[1:]
                                  ])
    ax.set_yticks(minor_yticks, minor=True)
    ax.yaxis.set_tick_params(which='minor', size=1)
    ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}'))

    for xpos, m in enumerate(mols):
        pvalue = mannwhitneyu('treatment', 'concentration', data=d1[d1.molecule_id == m]).pvalue
        ypos = 110
        symbol = assign_significance_symbol(pvalue)
        ss_annotation_kwrgs = dict(ha='center', va='center', fontweight='bold')
        if symbol == '†':
            ss_annotation_kwrgs['fontsize'] = 5
            ypos = 140
        ax.annotate(symbol, xy=(xpos, ypos), **ss_annotation_kwrgs)


    xticklabels = [t.get_text() for t in ax.get_xticklabels()]
    assert xticklabels[-1] == 'total_scfa'
    xticklabels[-1] = 'total\nSCFA'
    ax.set_xticklabels(xticklabels, position=(0, -0.02), fontdict={'size': 8})

    for ticklabel in ax.xaxis.get_majorticklabels():
        ticklabel.set_verticalalignment('top')

    # Shade the acetate / butyrate / propionate columns.
    ax.axvspan(2.5, 5.5, color='grey', alpha=0.1)
    # Booleans, not 'off' (a non-empty string is truthy in newer matplotlib).
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.set_title(sex)

# Upper panels: one per site, filling the first four axes in row order.
for (site, d1), ax in zip(data.groupby('site'), axs.flatten()[:-2]):
    sns.boxplot(x='molecule_id', y='concentration', hue='treatment', data=d1,
                order=mols, hue_order=['control', 'acarbose'], palette={'control': 'white', 'acarbose': 'white'},
                whis=0, linewidth=1, ax=ax, showfliers=False, saturation=1)
    for (sex, zorder), d2 in d1.groupby(['sex', 'zorder']):
        sns.stripplot(x='molecule_id', y='concentration', hue='treatment', data=d2,
                      order=mols, hue_order=['control', 'acarbose'],
                      palette=color_map, marker=mark_map[sex],
                      jitter=True, split=True, linewidth=0.5, s=3, alpha=0.7, zorder=zorder, ax=ax)
    ax.legend([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_yscale('symlog', linthreshy=0.1, linscaley=0.2)
    ax.set_ylim(bottom=-1e-1, top=200)

    minor_yticks = np.concatenate([np.arange(0, 1.1, step=0.1),
                                   np.arange(0, 11, step=1)[1:],
                                   np.arange(0, 110, step=10)[1:]
                                  ])
    ax.set_yticks(minor_yticks, minor=True)
    ax.yaxis.set_tick_params(which='minor', size=1)
    ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}'))

    for xpos, m in enumerate(mols):
        pvalue = mannwhitneyu('treatment', 'concentration', data=d1[d1.molecule_id == m]).pvalue
        ypos = 110
        symbol = assign_significance_symbol(pvalue)
        ss_annotation_kwrgs = dict(ha='center', va='center', fontweight='bold')
        if symbol == '†':
            ss_annotation_kwrgs['fontsize'] = 5
            ypos = 140
        ax.annotate(symbol, xy=(xpos, ypos), **ss_annotation_kwrgs)

    xticklabels = [t.get_text() for t in ax.get_xticklabels()]
    if xticklabels:
        if xticklabels[-1] == 'total_scfa':
            xticklabels[-1] = 'total SCFA'
    ax.set_xticklabels(xticklabels, position=(0, -0.02), fontdict={'size': 8})

    for ticklabel in ax.xaxis.get_majorticklabels():
        ticklabel.set_verticalalignment('top')

    ax.axvspan(2.5, 5.5, color='grey', alpha=0.1)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.set_title(site)

# Set the y-label only on left-most axes
for row in axs:
    row[0].set_ylabel('conc. (mmols/kg)')

# Blank the unused middle-right axes (presumably only three sites fill the
# four upper slots; confirm).
ax = axs[1,1]
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.set_frame_on(False)

# Bottom row: '−'/'+' (control/acarbose) minor labels under each box pair, and
# rotated molecule names taken from `mols` rather than from whichever panel the
# loops above happened to finish on.
for ax in axs[-1]:
    ax.set_xticks(concat([x - 0.25, x + 0.25] for x in ax.xaxis.get_ticklocs()),
                  minor=True)
    ax.set_xticklabels(['−', '+'] * len(ax.xaxis.get_ticklocs()), minor=True,
                       position=(0, 0.01))
    ax.set_xticklabels(mol_labels, position=(0, -0.02), rotation=30)

for panel, ax in zip(['A', 'B', 'C'], axs.flatten()[:-2]):
    ax.annotate(panel, xy=(0.02, 1.03), xycoords='axes fraction', fontweight='heavy')
for panel, ax in zip(['D', 'E'], axs.flatten()[-2:]):
    ax.annotate(panel, xy=(0.02, 1.03), xycoords='axes fraction', fontweight='heavy')

fig.subplots_adjust(wspace=0.05, hspace=0.25)
savefig(fig, 'fig/scfa_box_supplement')

# Metabolite-Taxon Correlations

In [None]:
# Spearman correlations between major families (absolute abundance) and
# metabolite concentrations, tabulated as "rho" + significance symbol.
data = (mouse[lambda x: x.cohort.isin(['C2013']) &
                  x.treatment.isin(['control', 'acarbose']) &
                  x.site.isin(['UT', 'UM', 'TJL'])
             ]
             .join(abund_family)
             .join(conc)
             .dropna(subset=['butyrate', 'dens'])
       )
assert data.index.is_unique

tax_features = ['Muribaculaceae','Lachnospiraceae', 'Ruminococcaceae', 'Lactobacillaceae', 'Erysipelotrichaceae']
mol_features = ['butyrate', 'acetate', 'propionate', 'lactate', 'glucose']


def spearmanr(x, y, data):
    """Spearman correlation between columns ``x`` and ``y`` of ``data``.

    Returns a Series with entries 'rho' and 'pvalue'. The scipy result is
    tuple-unpacked, which works regardless of how the installed scipy version
    names the statistic attribute.
    """
    rho, pvalue = sp.stats.spearmanr(data[x], data[y])
    return pd.Series(dict(rho=rho, pvalue=pvalue))


out = {(tax, mol): spearmanr(tax, mol, data)
       for tax in tax_features
       for mol in mol_features}

(pd.DataFrame(out)
   .T[['rho', 'pvalue']]
   # Create table entries as the correlation (rho) + *-for significance
   .apply(lambda x: '{:.02f}{}'.format(x.rho, assign_significance_symbol(x.pvalue)), axis=1)
   .unstack([1])
   # `.ix` was removed in pandas 1.0; `.loc` is the label-based equivalent.
   .loc[['Muribaculaceae', 'Lachnospiraceae', 'Ruminococcaceae', 'Lactobacillaceae', 'Erysipelotrichaceae']]
   [['acetate', 'butyrate', 'propionate', 'lactate', 'glucose']]
)

In [None]:
# Spearman correlation of butyrate with Ruminococcaceae within each treatment arm.
data.groupby(['treatment']).apply(
    lambda arm: sp.stats.spearmanr(arm['butyrate'], arm['Ruminococcaceae']))

In [None]:
# Data, figure and per-variable axis settings for the taxon-metabolite scatter
# panels. Each `cust` entry holds an axis label, an axis limit, and (by name)
# the symlog linear threshold; presumably consumed by the plotting loop that
# follows; confirm.
data = (mouse[lambda x: (x.cohort == 'C2013')]
        .join(abund_family)
        .join(conc)
        .dropna(subset=['butyrate', 'dens']))
assert data.index.is_unique

fig, axs = plt.subplots(2, 2, figsize=(fullwidth, 0.75 * fullwidth))

cust = {
    'Muribaculaceae': dict(label='Muribaculaceae (spike-adjust)', lim=2e4, linthresh=1e2),
    'Lachnospiraceae': dict(label='Lachnospiraceae (spike-adjust)', lim=1e4, linthresh=1e2),
    'Lactobacillaceae': dict(label='Lactobacillaceae (spike-adjust)', lim=5e3, linthresh=1e1),
    'propionate': dict(label='propionate (mmols/kg)', lim=10, linthresh=1e-1),
    'butyrate': dict(label='butyrate (mmols/kg)', lim=50, linthresh=1),
    'lactate': dict(label='lactate (mmols/kg)', lim=50, linthresh=0.1),
    'acetate': dict(label='acetate (mmols/kg)', lim=1e2, linthresh=1),
}

for (x, y), ax in zip([('Muribaculaceae', 'propionate'),
                       ('Lachnospiraceae', 'butyrate'),
                       ('Lactobacillaceae', 'lactate'),
                       ('acetate', 'butyrate')],
                      axs.flatten()):
    print(x, y, sp.stats.spearmanr(data[x], data[y]))
    for treatment, d1 in data.groupby('treatment'):
        print(x, y, treatment, sp.stats.spearmanr(d1[x], d1[y]))
        for sex, d2 in d1.groupby('sex'):
            marker = mark_map[sex]
            addn_kwargs = dict(lw=0.5, edgecolor='black')
            if marker == 'X':
                addn_kwargs = dict(lw=4, edgecolor='none')
            ax.scatter(x, y, data=d2, color=color_map[treatment],
                       s=20, marker=marker,
                       alpha=0.7, **addn_kwargs, label='_nolegend_')
            print(x, y, sex, treatment, sp.stats.spearmanr(d2[x], d2[y]))

    ax.set_xlabel(cust[x]['label'], labelpad=0.2)
    ax.set_ylabel(cust[y]['label'], labelpad=0.2)
    for tick_label in ax.get_yticklabels():
        tick_label.set_position([0.02, 0])
    for tick_label in ax.get_xticklabels():
        tick_label.set_position([0, 0.02])

    ax.set_xscale('symlog', linthreshx=cust[x]['linthresh'], linscalex=0.25)
    ax.set_xlim(-cust[x]['linthresh'] / 2, cust[x]['lim'])
    ax.set_yscale('symlog', linthreshy=cust[y]['linthresh'], linscaley=0.25)
    ax.set_ylim(-cust[y]['linthresh'] / 2, cust[y]['lim'])
    #ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=4, integer=True))
    #    ax.yaxis.set_major_formatter(StrMethodFormatter('{x:}'))
    #    ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.02}'))
    #    ax.set_xscale('log')
    
#for sex in ['male', 'female']:
#    for treatment in ['control', 'acarbose']:
#        marker = mark_map[treatment]
#        addn_kwargs = dict(lw=0.5, edgecolor='black')
#        if marker == 'X':
#            addn_kwargs = dict(lw=4, edgecolor='none')
#        ax.scatter([], [], data=d2, color=color_map[sex],
#                   s=20, marker=marker,
#                   alpha=0.7, **addn_kwargs, label=(sex, treatment))
#ax.legend(bbox_to_anchor=(1.4,0.75), frameon=True, edgecolor='black')
    
#axs[-1,-1].remove()

for panel, ax in zip(['A', 'B', 'C', 'D'], axs.flatten()):
    ax.annotate(panel, xy=(0.02, 1.03), xycoords='axes fraction', fontweight='heavy')
    
    
# Add a legend
dummy_artists = {}
for treatment, _ in data.groupby('treatment'):
    dummy_artists[treatment] = ax.scatter([], [], color=color_map[treatment], marker='s', label=treatment)
for sex, _ in data.groupby('sex'):
    dummy_artists[sex] = ax.scatter([], [], color='grey', marker=mark_map[sex], label=sex)
ax.legend(dummy_artists, frameon=True, fontsize=6)
    
    
    
fig.subplots_adjust(hspace=0.35)
savefig(fig, 'fig/metab_corr')

In [None]:
# Spearman correlation matrix among metabolites and bacterial families:
# upper triangle (row < column) from control mice, lower triangle
# (row > column) from acarbose mice. The diagonal is left at rho = 0, p = 1.
data = (mouse[lambda x: (x.cohort == 'C2013')]
             .join(abund_family)
             .join(conc)
             .dropna(subset=['butyrate', 'dens'])
       )
assert data.index.is_unique

chem_feats = ['acetate', 'butyrate', 'propionate', 'lactate']
fam_feats = ['Muribaculaceae', 'Lachnospiraceae', 'Ruminococcaceae', 'Lactobacillaceae']
feats = chem_feats + fam_feats

# Subset each treatment once instead of re-filtering for every feature pair.
ctl_data = data[data.treatment == 'control']
aca_data = data[data.treatment == 'acarbose']

corrs = np.zeros((len(feats), len(feats)))
pvals = np.ones_like(corrs)

for i, featA in enumerate(feats):
    for j, featB in enumerate(feats[i + 1:], i + 1):
        # Upper triangle: control mice.
        corrs[i, j], pvals[i, j] = sp.stats.spearmanr(ctl_data[featA], ctl_data[featB])
        # Lower triangle: acarbose mice.
        corrs[j, i], pvals[j, i] = sp.stats.spearmanr(aca_data[featA], aca_data[featB])

corrs = pd.DataFrame(corrs, index=feats, columns=feats)
pvals = pd.DataFrame(pvals, index=feats, columns=feats)

fig, ax = plt.subplots()

sns.heatmap(corrs,
            annot=pvals.applymap(assign_significance_symbol), fmt='',
            # Centre the diverging colormap on rho = 0 so that colour encodes
            # the sign of each correlation.
            center=0,
            square=True, ax=ax, cbar_kws=dict(pad=0.08), cmap='RdBu_r')

# Separate the two triangles with a line along the main diagonal, i.e. from
# the top-left corner (row 0, column 0) to the bottom-right corner. With an
# inverted y axis (row 0 at y=0, drawn at the top) that is (0, 0)->(n, n);
# the hard-coded (0, n)->(n, 0) only matches when the y axis is NOT inverted
# and otherwise traces the anti-diagonal.
n_feats = len(feats)
if ax.yaxis_inverted():
    ax.plot([0, n_feats], [0, n_feats], color='k')
else:
    ax.plot([0, n_feats], [n_feats, 0], color='k')

ax.xaxis.tick_top()
for tl in ax.get_xticklabels():
    tl.set_rotation(45)
    tl.set_horizontalalignment('left')

# NOTE(review): control values fill the upper-right triangle and acarbose the
# lower-left; confirm these labels (ACA on the right edge, Control on the
# bottom edge) point at the intended triangles.
ax.set_ylabel('ACA')
ax.yaxis.set_label_position('right')
ax.set_xlabel('Control')

# The colorbar axes added by sns.heatmap is the figure's last axes.
fig.axes[-1].set_xlabel('$\\rho$')

savefig(fig, 'fig/corr_matrix', bbox_inches='tight', pad_inches=0)

In [None]:
# Spearman p-value for Muribaculaceae vs propionate within each treatment x site.
data.groupby(['treatment', 'site']).apply(
    lambda grp: sp.stats.spearmanr(grp['Muribaculaceae'], grp['propionate']).pvalue
)