In [None]:
import os as _os
# Move the working directory up one level so that the relative paths used below
# ('meta/...', 'data/...', 'ref/...') are resolved from the parent directory.
# NOTE(review): not idempotent -- re-running this cell moves up one more level.
_os.chdir('..')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from operator import eq, sub
import scipy.stats
import scipy as sp
import seaborn as sns
from lib.pandas_util import idxwhere
from lib.plot import construct_ordered_palette

In [None]:
# Load the metadata tables, each indexed by its own ID column.
mgen = pd.read_table('meta/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/subject.tsv', index_col='subject_id')

# One row per library, widened by following the ID columns
# library -> preparation -> stool -> visit -> subject.
mgen_meta = (
    mgen
    # preparation's 'library_type' column is dropped before joining
    .join(preparation.drop(columns='library_type'), on='preparation_id')
    .join(stool, on='stool_id')
    # columns of `visit` whose names already exist on the left get a trailing '_'
    .join(visit, on='visit_id', rsuffix='_')
    .join(subject, on='subject_id')
)

# Every library must still have a subject_id after the joins.
assert not any(mgen_meta.subject_id.isna())

# meta.columns

In [None]:
# One row per (subject_id, week_number), keyed by the string '<subject_id>_<week_number>'.
subject_week = (
    visit
    .join(subject, on='subject_id')
    .reset_index()
    .dropna(subset=['subject_id', 'week_number'])
    .groupby(['subject_id', 'week_number'])
    # Keep a single row per group: the one with the highest count of non-null fields
    # (the last label after sorting rows by that count).
    .apply(lambda d: d.loc[d.notna().sum(1).sort_values().index[-1]])
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    # Attach the mean fecal_calprotectin over the stool rows of the retained visit.
    .join(stool.groupby('visit_id').fecal_calprotectin.mean(), on='visit_id')
)

# Map every library that has a week_number to the same '<subject_id>_<week_number>' key.
mgen_to_subject_week = mgen_meta.dropna(subset=['week_number']).apply(lambda x: x.subject_id + '_' + str(int(x.week_number)), axis=1).rename('subject_week_id')
mgen_to_subject_week

#.groupby(['subject_id', 'week_number']).visit_id.count().sort_values(ascending=False)

In [None]:
# Read the long-format strain depth table (indexed by sample and strain), pivot it into a
# sample-by-strain matrix with missing combinations filled with 0, then sum the rows that
# map to the same subject_week_id through mgen_to_subject_week.
# NOTE(review): .squeeze() presumably relies on the file having a single value column -- confirm.
strain_depth_with_minor = pd.read_table(
    'data/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.fit-sfacts3-s75-g10000-seed0.collapse-10.strain_depth.tsv',
    # names=['library_id', 'species_strain_id', 'depth'],
    index_col=['sample', 'strain'],
).squeeze().unstack('strain', fill_value=0).groupby(mgen_to_subject_week).sum()

In [None]:
# Species-level counterpart of strain_depth_with_minor: pivot the long table into a
# sample-by-species matrix (missing combinations filled with 0) and sum per subject_week_id.
species_depth = (
    pd.read_table('data/hmp2.a.r.proc.gtpro.species_depth.tsv', index_col=['sample', 'species_id'])
    .squeeze()
    .unstack('species_id', fill_value=0)
    .groupby(mgen_to_subject_week)
    .sum()
)
# Column labels are cast to str; the taxonomy table below also stores species_id as str.
species_depth.columns = species_depth.columns.astype(str)
# Distribution of (total strain depth - total species depth) per subject-week.
plt.hist(strain_depth_with_minor.sum(1) - species_depth.sum(1))

In [None]:
# Strains whose maximum depth over all subject-weeks never exceeds `thresh` are relabelled
# '<prefix>-other', where <prefix> is the part of the strain label before its last '-'.
thresh = 0.5

strain_collapse = strain_depth_with_minor.columns.to_series()
# `n` has to be passed by keyword: it is keyword-only in pandas >= 2.0.
strain_other = strain_collapse.str.rsplit('-', n=1).str[0] + '-other'
strain_collapse = strain_collapse.where(strain_depth_with_minor.max() > thresh, strain_other)

# Number of distinct labels that remain after collapsing.
strain_collapse.value_counts().shape

In [None]:
# Sum the depth of all strain columns that share a collapsed label.
# groupby(axis='columns') is deprecated in recent pandas, so group the transposed frame instead.
strain_depth = strain_depth_with_minor.T.groupby(strain_collapse).sum().T

In [None]:
# Read the taxonomy table (column names are supplied here), cast species_id to str and index by it,
# keep only taxonomy_string, and add the ';'-split form of that string as a list column.
species_taxonomy = pd.read_table('ref/gtpro/species_taxonomy_ext.tsv', names=['genome_id', 'species_id', 'taxonomy_string']).assign(species_id=lambda x: x.species_id.astype(str)).set_index('species_id')[['taxonomy_string']].assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))

# One column per rank prefix, holding the element at that position of the split taxonomy string.
for level_name, level_number in [('p__', 1), ('c__', 2), ('o__', 3), ('f__', 4), ('g__', 5), ('s__', 6)]:
    species_taxonomy = species_taxonomy.assign(**{level_name: species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])}) 
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])
    
# Taxonomy per strain column: species_id is the part of the strain label before the first '-'.
strain_taxonomy = strain_depth.columns.to_series().str.split('-').str[0].to_frame(name='species_id').join(species_taxonomy, on='species_id')

# NOTE: rebinds species_taxonomy to one row per species_id that occurs among the strain_depth columns.
species_taxonomy = strain_taxonomy.drop_duplicates(subset=['species_id']).set_index('species_id')

In [None]:
# Display the species taxonomy table as rebound in the previous cell.
species_taxonomy

In [None]:
# Sanity check: collapsing only sums columns, so per-sample totals should be unchanged
# and these differences should be (numerically) zero.
plt.hist(strain_depth_with_minor.sum(1) - strain_depth.sum(1))

In [None]:
# Relative abundance of each collapsed strain: depth divided by the row (sample) total.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# Histogram of log10(maximum relative abundance) per strain, 50 bins between -5 and 0.
plt.hist(np.log10(strain_rabund.max()), bins=np.linspace(-5, 0, num=51))
# A trailing None keeps the cell from echoing plt.hist's return value.
None

In [None]:
# Relative abundance of each species: depth divided by the row (sample) total.
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)
# Histogram of log10(maximum relative abundance) per species, 50 bins between -5 and 0.
plt.hist(np.log10(species_rabund.max()), bins=np.linspace(-5, 0, num=51))
# A trailing None keeps the cell from echoing plt.hist's return value.
None

In [None]:
# Cumulative relative abundance over species, ordered from most to least prevalent
# (prevalence = fraction of samples with relative abundance above 1e-5).
_depth = species_depth
_n_taxa = _depth.shape[1]
_rabund = _depth.divide(_depth.sum(axis=1), axis=0)

c = 'tab:blue'

_prevalence = (_rabund > 1e-5).mean()
_mean_rabund = _rabund.mean()
_decreasing_prevalence = _prevalence.sort_values(ascending=False).index
# Rows: taxa in prevalence order; columns: quantiles (over samples) of the running total.
_quantile_rabund = (
    _rabund
    .loc[:, _decreasing_prevalence]
    .cumsum(axis=1)
    .quantile([0.0, 0.05, 0.25, 0.5, 0.75, 0.95, 1.0])
    .T
)

xx = np.arange(_n_taxa)
plt.plot(xx, _quantile_rabund[0.5], c=c, lw=2, label='median_rabund')
# Nested bands: full range, 5-95 % and interquartile range.
for _lower, _upper, _alpha in [(0.0, 1.0, 0.05), (0.05, 0.95, 0.2), (0.25, 0.75, 0.2)]:
    plt.fill_between(xx, _quantile_rabund[_lower], _quantile_rabund[_upper], color=c, alpha=_alpha, edgecolor=None)

# Number of taxa for which the minimum cumulative abundance is still below 0.99.
plt.axvline((_quantile_rabund[0.0] < 0.99).sum(), linestyle='--', lw=1, color='grey')

In [None]:
# Same cumulative-abundance view as the species cell, computed on the collapsed strains.
_depth = strain_depth
_n_taxa = _depth.shape[1]
_rabund = _depth.divide(_depth.sum(axis=1), axis=0)

c = 'tab:blue'

_prevalence = (_rabund > 1e-5).mean()
_mean_rabund = _rabund.mean()
_decreasing_prevalence = _prevalence.sort_values(ascending=False).index
# Rows: taxa in prevalence order; columns: quantiles (over samples) of the running total.
_quantile_rabund = (
    _rabund
    .loc[:, _decreasing_prevalence]
    .cumsum(axis=1)
    .quantile([0.0, 0.05, 0.25, 0.5, 0.75, 0.95, 1.0])
    .T
)

xx = np.arange(_n_taxa)
plt.plot(xx, _quantile_rabund[0.5], c=c, lw=2, label='median_rabund')
# Nested bands: full range, 5-95 % and interquartile range.
for _lower, _upper, _alpha in [(0.0, 1.0, 0.05), (0.05, 0.95, 0.2), (0.25, 0.75, 0.2)]:
    plt.fill_between(xx, _quantile_rabund[_lower], _quantile_rabund[_upper], color=c, alpha=_alpha, edgecolor=None)

# Number of taxa for which the minimum cumulative abundance is still below 0.99.
plt.axvline((_quantile_rabund[0.0] < 0.99).sum(), linestyle='--', lw=1, color='grey')

# Distinct label prefixes (text before the first '-') among those taxa.
_species_before_cutoff = {s.split('-')[0] for s in idxwhere(_quantile_rabund[0.0] < 0.99)}
print('Unique species:', len(_species_before_cutoff))

In [None]:
fig, ax = plt.subplots(figsize=(10, 20))
# 1.0 where any row of a (subject_id, week_number) group has a truthy status_antibiotics, else 0.0.
abx_status_matrix = subject_week.groupby(['subject_id', 'week_number']).status_antibiotics.any().astype(float)
# Heatmap with one row per subject_id (every row labelled) and one column per week_number.
cm = sns.heatmap(abx_status_matrix.unstack(), cmap='coolwarm', yticklabels=1, ax=ax, cbar=False)

In [None]:
# Per subject: mean of the 0/1 antibiotic flags over the weeks that have a value
# (missing weeks are dropped before averaging), joined to the subject metadata.
d = abx_status_matrix.unstack().apply(lambda x: x.dropna().mean(), axis=1).to_frame(name='frac_libraries_during_abx').join(subject)
# Compare that fraction across ibd_diagnosis groups, split by sex.
sns.stripplot(y='frac_libraries_during_abx', x='ibd_diagnosis', hue='sex', data=d, alpha=0.3, dodge=True)

In [None]:
fig, ax = plt.subplots(figsize=(10, 20))

# Flag each subject-week whose ID appears in the index of species_depth.
v = subject_week.assign(has_mgen=lambda d: d.index.to_series().isin(species_depth.index))
# 1.0 where any row of a (subject_id, week_number) group is flagged, else 0.0.
has_mgen_matrix = v.groupby(['subject_id', 'week_number']).has_mgen.any().astype(float)

# Heatmap with one row per subject_id and one column per week_number.
sns.heatmap(has_mgen_matrix.unstack(), cmap='coolwarm', yticklabels=1, ax=ax, cbar=False)

In [None]:
fig, ax = plt.subplots(figsize=(10, 20))

# Heatmap of fecal_calprotectin with one row per subject_id and one column per week_number.
sns.heatmap(subject_week.set_index(['subject_id', 'week_number']).fecal_calprotectin.unstack(), yticklabels=1, ax=ax)

In [None]:
# Library IDs collected into a list for each (subject_id, week_number).
mgen_id_list_matrix = mgen_meta.reset_index().groupby(['subject_id', 'week_number']).library_id.apply(list)
mgen_id_list_matrix.sort_index()

In [None]:
# Per (subject_id, week_number): metagenome flag, list of library IDs and antibiotic status,
# aligned on their shared index and re-keyed by '<subject_id>_<week_number>'.
d = (
    pd.DataFrame(dict(
        has_mgen=has_mgen_matrix,
        library_id_list=mgen_id_list_matrix,
        status_antibiotics=abx_status_matrix.astype(bool)
    ))
    .reset_index()
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    )


# For each subject, find a pair of visits with metagenomes where the first has no antibiotics (neither at that visit nor at the visit before it),
# the second does have antibiotics (and the visit before it does not),
# and the two metagenome visits are at most 2 weeks apart (consecutive fortnightly visits).

def _find_antibiotic_comparison_pairs(data):
    d0 = data.sort_values(['subject_id', 'week_number']).copy()
    d0['has_mgen'] = d0['has_mgen'].astype(bool)
    d1 = (
        d0.assign(
            maybe_control=lambda x: ~(x.status_antibiotics | x.status_antibiotics.shift(1)),  # No abx this visit or last
            maybe_abx=lambda x: (x.status_antibiotics & (~x.status_antibiotics).shift(1)),  # Abx this visit but not last
        )
        [lambda x: x.has_mgen]
        .assign(
            last_mgen_maybe_control=lambda x: x.maybe_control.shift(1),
            next_mgen_maybe_abx=lambda x: x.maybe_abx.shift(-1),
            time_delta_last_mgen=lambda x: x.week_number - x.week_number.shift(1),
            time_delta_next_mgen=lambda x: x.week_number.shift(-1) - x.week_number,
        )
    )

    out = d1[lambda x: (
        (x.maybe_control & x.next_mgen_maybe_abx & (x.time_delta_next_mgen <= 2.0)) # First of a pair
        | (x.maybe_abx & x.last_mgen_maybe_control & (x.time_delta_last_mgen <= 2.0)) # Second of a pair
    )].head(2).index.values
    if len(out) == 2:
        return pd.Series(out, index=['pre', 'post'])
    else:
        return pd.Series(np.nan, index=['pre', 'post'])

# Run the pair finder once per subject; subjects for which it returned NaN are dropped.
abx_perturbation_pairs = (
    d.groupby('subject_id')
    .apply(_find_antibiotic_comparison_pairs)
    .dropna()
)
abx_perturbation_pairs

# For each subject, find a pair of visits with metagenomes where neither has antibiotics (at that visit or at the visit before it),
# and the two metagenome visits are at most 2 weeks apart (consecutive fortnightly visits).

def _find_dummy_comparison_pairs(data):
    d0 = data.sort_values(['subject_id', 'week_number']).copy()
    d0['has_mgen'] = d0['has_mgen'].astype(bool)
    d1 = (
        d0.assign(
            maybe_control=lambda x: ~(x.status_antibiotics | x.status_antibiotics.shift(1)),  # No abx this visit or last
            maybe_dummy=lambda x: ~(x.status_antibiotics | x.status_antibiotics.shift(1)),  # No abx this visit or last
        )
        [lambda x: x.has_mgen]
        .assign(
            last_mgen_maybe_control=lambda x: x.maybe_control.shift(1),
            next_mgen_maybe_dummy=lambda x: x.maybe_dummy.shift(-1),
            time_delta_last_mgen=lambda x: x.week_number - x.week_number.shift(1),
            time_delta_next_mgen=lambda x: x.week_number.shift(-1) - x.week_number,
        )
    )

    out = d1[lambda x: (
        (x.maybe_control & x.next_mgen_maybe_dummy & (x.time_delta_next_mgen <= 2.0)) # First of a pair
        | (x.maybe_dummy & x.last_mgen_maybe_control & (x.time_delta_last_mgen <= 2.0)) # Second of a pair
    )].head(2).index.values
    if len(out) == 2:
        return pd.Series(out, index=['pre', 'post'])
    else:
        return pd.Series(np.nan, index=['pre', 'post'])

# Run the dummy pair finder once per subject; subjects for which it returned NaN are dropped.
dummy_perturbation_pairs = (
    d.groupby('subject_id')
    .apply(_find_dummy_comparison_pairs)
    .dropna()
)
dummy_perturbation_pairs

# Keep only subjects that have both an abx pair and a dummy pair.
# sorted(): the iteration order of a set of strings changes between interpreter runs,
# which would make the row order (and anything derived from it) irreproducible.
subjects_with_both_abx_and_dummy = sorted(set(abx_perturbation_pairs.index) & set(dummy_perturbation_pairs.index))

abx_perturbation_pairs = abx_perturbation_pairs.loc[subjects_with_both_abx_and_dummy]
dummy_perturbation_pairs = dummy_perturbation_pairs.loc[subjects_with_both_abx_and_dummy]

# One row per subject with columns pre_abx, post_abx, pre_dummy, post_dummy.
perturbation_pairs = abx_perturbation_pairs.join(dummy_perturbation_pairs, lsuffix='_abx', rsuffix='_dummy')
perturbation_pairs

In [None]:
def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k]] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))
print(sp.stats.wilcoxon(d['pre_abx'], d['pre_dummy']))

In [None]:
# Strains whose taxonomy string starts with the Bacteroidota prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Bacteroidota;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Strains whose taxonomy string starts with the Proteobacteria prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Proteobacteria;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Strains whose taxonomy string starts with the Firmicutes prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Strains whose taxonomy string starts with the Firmicutes_A prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Strains whose taxonomy string starts with the Firmicutes_C prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_C;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Strains whose taxonomy string starts with the Actinobacteriota prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Actinobacteriota;'))


def diversity_func(k):
    """For the samples listed in perturbation_pairs[k], count selected strains with relative abundance above 1e-5."""
    return (strain_rabund.loc[perturbation_pairs[k], _tax_subset] > 1e-5).sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: diversity_func(k) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='diversity', data=d.stack().rename('diversity').reset_index())
plt.yscale('symlog')
plt.ylim(bottom=-1)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Species whose taxonomy string starts with the Firmicutes_A prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='value', data=d.stack().rename('value').reset_index())
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Species whose taxonomy string starts with the Firmicutes_C prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_C;'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='value', data=d.stack().rename('value').reset_index())
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Species whose taxonomy string starts with the Proteobacteria prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Proteobacteria;'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

# Long format joined to the subject metadata so points can be coloured by ibd_diagnosis.
sns.stripplot(
    x='sample_type',
    y='value',
    hue='ibd_diagnosis',
    data=d.stack().to_frame(name='value').join(subject).reset_index(),
)
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Square matrix of pairwise Bray-Curtis dissimilarities between the rows of species_rabund.
species_bc_dist = pd.DataFrame(
    sp.spatial.distance.squareform(
        sp.spatial.distance.pdist(species_rabund, metric='braycurtis')
    ),
    index=species_rabund.index,
    columns=species_rabund.index,
)

In [None]:
# Species-level Bray-Curtis distance between the pre and post sample of each subject's
# abx pair and dummy pair.
abx_perturbation = [species_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [species_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type')
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

# Paired test of abx vs dummy distances within subjects.
print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Square matrix of pairwise Bray-Curtis dissimilarities between the rows of strain_rabund.
strain_bc_dist = pd.DataFrame(
    sp.spatial.distance.squareform(
        sp.spatial.distance.pdist(strain_rabund, metric='braycurtis')
    ),
    index=strain_rabund.index,
    columns=strain_rabund.index,
)

In [None]:
# Strain-level Bray-Curtis distance between the pre and post sample of each subject's
# abx pair and dummy pair.
abx_perturbation = [strain_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [strain_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type')
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

# Paired test of abx vs dummy distances within subjects.
print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Square matrix of pairwise Jaccard distances between samples, computed on strain
# presence/absence (relative abundance above 1e-5).
strain_jc_dist = pd.DataFrame(
    sp.spatial.distance.squareform(
        sp.spatial.distance.pdist(strain_rabund > 1e-5, metric='jaccard')
    ),
    index=strain_rabund.index,
    columns=strain_rabund.index,
)

In [None]:
# Strain-level Jaccard distance between the pre and post sample of each subject's
# abx pair and dummy pair.
abx_perturbation = [strain_jc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [strain_jc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type')
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='jc', data=d.unstack().to_frame('jc').reset_index())

# Paired test of abx vs dummy distances within subjects.
print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Bacteroidota prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Bacteroidota;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Proteobacteria prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Proteobacteria;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Firmicutes prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Firmicutes_A prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Firmicutes_C prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_C;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Restrict to strains whose taxonomy string starts with the Actinobacteriota prefix and
# renormalise each sample so the selected strains sum to 1.
_rabund = strain_rabund.loc[:, idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Actinobacteriota;'))].apply(lambda x: x / x.sum(), axis=1)
_bc_dist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_rabund, metric='braycurtis')), index=_rabund.index, columns=_rabund.index)

abx_perturbation = [_bc_dist.loc[pair.pre_abx, pair.post_abx] for _, pair in perturbation_pairs.iterrows()]
dummy_perturbation = [_bc_dist.loc[pair.pre_dummy, pair.post_dummy] for _, pair in perturbation_pairs.iterrows()]

# Subjects with a missing distance in either column are dropped before plotting and testing.
d = pd.DataFrame(dict(abx=abx_perturbation, dummy=dummy_perturbation), index=perturbation_pairs.index).rename_axis(columns='pair_type').dropna()
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data` positionally.
sns.stripplot(x='pair_type', y='bc', data=d.unstack().to_frame('bc').reset_index())

print(sp.stats.wilcoxon(d['abx'], d['dummy']))

In [None]:
# Species whose taxonomy string starts with the Firmicutes_A / Clostridia prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')

sns.swarmplot(x='sample_type', y='value', data=d.stack().rename('value').reset_index())
# Paired Wilcoxon tests: abx pre vs post, dummy pre vs post, and the two "pre" samples.
for _first, _second in [('pre_abx', 'post_abx'), ('pre_dummy', 'post_dummy'), ('pre_abx', 'pre_dummy')]:
    print(sp.stats.wilcoxon(d[_first], d[_second]))

In [None]:
# Top 20 Lachnospiraceae taxonomy strings by relative abundance, averaged first within
# each subject and then across subjects.
(
    species_rabund
    # groupby(axis='columns') is deprecated in recent pandas, so group the transposed
    # frame by taxonomy string and transpose back; species outside the subset are dropped.
    .T
    .groupby(species_taxonomy[
        lambda x: x.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae')
    ].taxonomy_string)
    .sum()
    .T
    .groupby(subject_week.subject_id)
    .mean()
    .mean()
    .sort_values(ascending=False)
    .head(20)
)

In [None]:
# Species whose taxonomy string starts with the Lachnospiraceae family prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae;'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')
palette = construct_ordered_palette(perturbation_pairs.index, cm='tab20')
shift = {'abx': 0, 'dummy': 1.25}

# One line per subject and pair type: x in [0, 1] for abx pairs, shifted by 1.25 for dummy pairs.
for label, left, right in [('abx', 'pre_abx', 'post_abx'), ('dummy', 'pre_dummy', 'post_dummy')]:
    xs = np.array([0, 1]) + shift[label]  # same x positions for every subject of this pair type
    for subject_id in d.index:
        plt.plot(xs, [d.loc[subject_id, left], d.loc[subject_id, right]], alpha=0.75, color=palette[subject_id])
plt.yscale('symlog', linthresh=1e-5)
plt.ylim(bottom=-1e-6)
print(sp.stats.wilcoxon(d['pre_abx'], d['post_abx']))
print(sp.stats.wilcoxon(d['pre_dummy'], d['post_dummy']))
print(sp.stats.wilcoxon(d['pre_abx'], d['pre_dummy']))

In [None]:
# Species whose taxonomy string starts with the Roseburia intestinalis prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae;g__Roseburia;s__Roseburia intestinalis'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')
palette = construct_ordered_palette(perturbation_pairs.index, cm='tab20')
shift = {'abx': 0, 'dummy': 1.25}

# One line per subject and pair type: x in [0, 1] for abx pairs, shifted by 1.25 for dummy pairs.
for label, left, right in [('abx', 'pre_abx', 'post_abx'), ('dummy', 'pre_dummy', 'post_dummy')]:
    xs = np.array([0, 1]) + shift[label]  # same x positions for every subject of this pair type
    for subject_id in d.index:
        plt.plot(xs, [d.loc[subject_id, left], d.loc[subject_id, right]], alpha=0.75, color=palette[subject_id])
plt.yscale('symlog', linthresh=1e-5)
plt.ylim(bottom=-1e-6)
print(sp.stats.wilcoxon(d['pre_abx'], d['post_abx']))
print(sp.stats.wilcoxon(d['pre_dummy'], d['post_dummy']))
print(sp.stats.wilcoxon(d['pre_abx'], d['pre_dummy']))

In [None]:
# Species whose taxonomy string starts with the Agathobacter faecis prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae;g__Agathobacter;s__Agathobacter faecis'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')
palette = construct_ordered_palette(perturbation_pairs.index, cm='tab20')
shift = {'abx': 0, 'dummy': 1.25}

# One line per subject and pair type: x in [0, 1] for abx pairs, shifted by 1.25 for dummy pairs.
for label, left, right in [('abx', 'pre_abx', 'post_abx'), ('dummy', 'pre_dummy', 'post_dummy')]:
    xs = np.array([0, 1]) + shift[label]  # same x positions for every subject of this pair type
    for subject_id in d.index:
        plt.plot(xs, [d.loc[subject_id, left], d.loc[subject_id, right]], alpha=0.75, color=palette[subject_id])
plt.yscale('symlog', linthresh=1e-5)
plt.ylim(bottom=-1e-6)
print(sp.stats.wilcoxon(d['pre_abx'], d['post_abx']))
print(sp.stats.wilcoxon(d['pre_dummy'], d['post_dummy']))
print(sp.stats.wilcoxon(d['pre_abx'], d['pre_dummy']))

In [None]:
# Species whose taxonomy string starts with the Agathobacter rectalis prefix (selected via the project helper idxwhere).
_tax_subset = idxwhere(species_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae;g__Agathobacter;s__Agathobacter rectalis'))


def _func(ii):
    """Summed relative abundance of the selected species for each sample label in `ii`."""
    return species_rabund.loc[ii, _tax_subset].sum(1).values


# One column per sample type (pre/post x abx/dummy), one row per subject.
d = pd.DataFrame(
    {k: _func(perturbation_pairs[k]) for k in perturbation_pairs.columns},
    index=perturbation_pairs.index,
).rename_axis(columns='sample_type')
palette = construct_ordered_palette(perturbation_pairs.index, cm='tab20')
shift = {'abx': 0, 'dummy': 1.25}

# One line per subject and pair type: x in [0, 1] for abx pairs, shifted by 1.25 for dummy pairs.
for label, left, right in [('abx', 'pre_abx', 'post_abx'), ('dummy', 'pre_dummy', 'post_dummy')]:
    xs = np.array([0, 1]) + shift[label]  # same x positions for every subject of this pair type
    for subject_id in d.index:
        plt.plot(xs, [d.loc[subject_id, left], d.loc[subject_id, right]], alpha=0.75, color=palette[subject_id])
plt.yscale('symlog', linthresh=1e-5)
plt.ylim(bottom=-1e-6)
print(sp.stats.wilcoxon(d['pre_abx'], d['post_abx']))
print(sp.stats.wilcoxon(d['pre_dummy'], d['post_dummy']))
print(sp.stats.wilcoxon(d['pre_abx'], d['pre_dummy']))

In [None]:
# Rebuild the per-visit table: the name `d` was reused for other frames in the cells above.
# Per (subject_id, week_number): metagenome flag, list of library IDs and antibiotic status,
# re-keyed by '<subject_id>_<week_number>'.
d = (
    pd.DataFrame(dict(
        has_mgen=has_mgen_matrix,
        library_id_list=mgen_id_list_matrix,
        status_antibiotics=abx_status_matrix.astype(bool)
    ))
    .reset_index()
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    )

def _find_antibiotic_comparison_pairs(data):
    d0 = data.sort_values(['subject_id', 'week_number']).copy()
    d0['has_mgen'] = d0['has_mgen'].astype(bool)
    d1 = (
        d0.assign(
            maybe_control=lambda x: ~(x.status_antibiotics),  # No abx this visit
            maybe_abx=lambda x: (x.status_antibiotics),  # Abx this visit
            maybe_dummy=lambda x: ~(x.status_antibiotics),  # No abx this visit
        )
        [lambda x: x.has_mgen]
        .assign(
            last_mgen_maybe_control=lambda x: x.maybe_control.shift(1),
            next_mgen_maybe_abx=lambda x: x.maybe_abx.shift(-1),
            next_mgen_maybe_dummy=lambda x: x.maybe_dummy.shift(-1),
            time_delta_last_mgen=lambda x: x.week_number - x.week_number.shift(1),
            time_delta_next_mgen=lambda x: x.week_number.shift(-1) - x.week_number,
            next_subject_week_id=lambda x: x.index.to_series().shift(-1),
        )
    )
    out = []
    for pair_i, (this_subject_week_id, d2) in enumerate(d1[lambda x: x.maybe_control & x.next_mgen_maybe_abx & (x.time_delta_next_mgen <= 4.0)].iterrows()):
        out.append(['abx', pair_i, d2.time_delta_next_mgen, this_subject_week_id, d2.next_subject_week_id])
    for pair_i, (this_subject_week_id, d2) in enumerate(d1[lambda x: x.maybe_control & x.next_mgen_maybe_dummy & (x.time_delta_next_mgen <= 4.0)].iterrows()):
        out.append(['dummy', pair_i, d2.time_delta_next_mgen, this_subject_week_id, d2.next_subject_week_id])
    return pd.DataFrame(out, columns=['pair_type', 'pair_index', 'time_delta', 'left_subject_week_id', 'right_subject_week_id'])

# NOTE: rebinds perturbation_pairs from the earlier wide frame (one row per subject) to a
# long frame with one row per pair, indexed by the left sample's subject-week ID.
perturbation_pairs = (
    d.groupby('subject_id')
    .apply(_find_antibiotic_comparison_pairs)
    .dropna()
    .reset_index()
    # 'level_1' is the column reset_index() makes from the unnamed inner index level
    .drop(columns=['level_1'])
    # drop=False keeps left_subject_week_id available as a column as well
    .set_index('left_subject_week_id', drop=False)
    .rename_axis(index='subject_week_id')
)
perturbation_pairs

In [None]:
# Strains whose taxonomy string starts with the Agathobacter rectalis prefix
# (selected via the project helper idxwhere).
_tax = idxwhere(strain_taxonomy.taxonomy_string.str.startswith('d__Bacteria;p__Firmicutes_A;c__Clostridia;o__Lachnospirales;f__Lachnospiraceae;g__Agathobacter;s__Agathobacter rectalis'))
# Summed relative abundance of those strains for each sample label in `idx`.
_func = lambda idx: strain_rabund.loc[idx, _tax].sum(1).values
# Evaluate that sum at the left and the right sample of every pair.
d0 = perturbation_pairs.assign(
    left_value=lambda x: _func(x.left_subject_week_id),
    right_value=lambda x: _func(x.right_subject_week_id),
)
    


def _response_type(x, thresh):
    if (x.left_value > thresh) and (x.right_value < thresh):
        return 'decreasing'
    elif (x.left_value > thresh) and (x.right_value > thresh):
        return 'non_decreasing'
    elif (x.left_value < thresh) and (x.right_value < thresh):
        return 'unknown'
    elif (x.left_value < thresh) and (x.right_value > thresh):
        return 'increasing'
    else:
        assert False, "This shouldn't happen"
        

# NOTE: rebinds the notebook-level `thresh` (set to 0.5 in the strain-collapsing cell) to a
# relative-abundance threshold.
thresh = 1e-4
# Label every pair with its response category.
d0 = d0.assign(response_type=d0.apply(_response_type, thresh=thresh, axis=1))
    

# Colours keyed by response type; horizontal offset keyed by pair type.
palette = construct_ordered_palette(d0.response_type, cm='cool')
shift = {'dummy': 1.25, 'abx': 0}

# decreasing_pairs = idxwhere((d0.pair_type == 'abx') & (d0.left_value > thresh) & (d0.right_value < thresh))
# non_decreasing_pairs = idxwhere((d0.pair_type == 'abx') & (d0.left_value > thresh) & (d0.right_value > thresh))
# unknown_pairs = idxwhere((d0.pair_type == 'abx') & (d0.left_value < thresh) & (d0.right_value < thresh))
# increasing_pairs = idxwhere((d0.pair_type == 'abx') & (d0.left_value < thresh) & (d0.right_value > thresh))


# One line per pair from its left value to its right value; the y axis is linear below
# `thresh` and logarithmic above it, with the threshold drawn as a dashed line.
fig, ax = plt.subplots()
for _, d1 in d0.iterrows():
    ax.plot(np.array([0, 1]) + shift[d1.pair_type], [d1.left_value, d1.right_value], c=palette[d1.response_type])
ax.set_yscale('symlog', linthresh=thresh, linscale=0.1)
ax.axhline(thresh, lw=1, linestyle='--', color='grey')

# Display the abx pairs together with their values and response type.
d0[lambda x: x.pair_type == 'abx']

In [None]:
# Number of abx pairs per subject and response type (combinations that do not occur count as 0).
# .size() is the built-in group count and avoids a Python-level apply(len) per group.
sns.heatmap(
    d0[lambda x: x.pair_type == 'abx']
    .groupby(['subject_id', 'response_type'])
    .size()
    .unstack('response_type', fill_value=0)
)