In [None]:
# Move the working directory up one level so that the relative paths used
# below ('meta/...', 'ref/...', 'data/...') are resolved from the parent directory.
# NOTE(review): this is not idempotent — re-running the cell moves up another level.
import os as _os

_os.chdir('..')

In [None]:
# Load the autoreload extension but leave automatic reloading disabled (mode 0).
%load_ext autoreload
%autoreload 0

In [None]:
# Trigger a one-off reload of already-imported modules.
%autoreload

In [None]:
import sfacts as sf

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
import matplotlib as mpl
import scipy as sp
from operator import eq
from itertools import cycle
from tqdm import tqdm

In [None]:
def expected_sample_entropy(w, discretized=False):
    """Depth-weighted mean binary entropy per sample.

    The binary entropy of ``community @ genotype`` is weighted by
    ``w.metagenotype.total_counts()``, summed over the "position" dimension
    and divided by the summed weights.

    Parameters
    ----------
    w
        Object exposing ``genotype``, ``community`` and ``metagenotype``
        (presumably an ``sf.World`` — confirm against callers).
    discretized : bool
        When truthy, use ``w.genotype.discretized()`` instead of ``w.genotype``.

    Returns
    -------
    The weighted mean, renamed to "entropy".
    """
    genotype = w.genotype.discretized() if discretized else w.genotype
    gen = genotype.data
    com = w.community.data
    depth = w.metagenotype.total_counts()

    entropy = sf.math.binary_entropy(com @ gen)
    weighted_total = (entropy * depth).sum("position")
    return (weighted_total / depth.sum("position")).rename("entropy")

def max_strain_depth(w):
    """Maximum over the 'sample' dimension of community * mean depth, named 'depth'.

    `w` must expose `community.data` and `metagenotype.mean_depth()`
    (presumably an `sf.World` — confirm against callers).
    """
    strain_depth = w.community.data * w.metagenotype.mean_depth()
    return strain_depth.max('sample').rename('depth')

def total_strain_depth(w):
    """Sum over the 'sample' dimension of community * mean depth, named 'depth'.

    `w` must expose `community.data` and `metagenotype.mean_depth()`
    (presumably an `sf.World` — confirm against callers).
    """
    strain_depth = w.community.data * w.metagenotype.mean_depth()
    return strain_depth.sum('sample').rename('depth')

In [None]:
# Metadata tables, each indexed by its own identifier column.
mgen = pd.read_table('meta/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/subject.tsv', index_col='subject_id')

# Walk library -> preparation -> stool -> visit -> subject, one join per step.
# `library_type` is dropped from `preparation` before joining; columns of
# `visit` that clash with existing names receive a trailing '_'.
mgen_meta = mgen.join(preparation.drop(columns=['library_type']), on='preparation_id')
mgen_meta = mgen_meta.join(stool, on='stool_id')
mgen_meta = mgen_meta.join(visit, on='visit_id', rsuffix='_')
mgen_meta = mgen_meta.join(subject, on='subject_id')

# Every library must resolve to a subject.
assert mgen_meta.subject_id.notna().all()

# meta.columns

In [None]:
# Taxonomy table keyed by species_id (as a string), keeping the raw taxonomy
# string plus a temporary column holding its ';'-separated parts.
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
    .assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))
)

# One column per rank prefix, taken from positions 1..6 of the split string.
for level_number, level_name in enumerate(['p__', 'c__', 'o__', 'f__', 'g__', 's__'], start=1):
    _level_values = species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])
    species_taxonomy = species_taxonomy.assign(**{level_name: _level_values})
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])

In [None]:
def _most_complete_row(d):
    """Return the row of `d` whose non-null field count sorts last (i.e. is largest)."""
    filled_counts = d.notna().sum(1).sort_values()
    return d.loc[filled_counts.index[-1]]


# Mean fecal calprotectin over the stool records of each visit.
_calprotectin_by_visit = stool.groupby('visit_id').fecal_calprotectin.mean()

# One row per (subject, week): among visits sharing a subject and week number,
# keep the most completely filled-in one, then key it by '<subject>_<week>'.
subject_week = (
    visit
    .join(subject, on='subject_id')
    .reset_index()
    .dropna(subset=['subject_id', 'week_number'])
    .groupby(['subject_id', 'week_number'])
    .apply(_most_complete_row)
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    .join(_calprotectin_by_visit, on='visit_id')
)

# Map each metagenome library (with a known week number) to its
# '<subject_id>_<week_number>' key. Built with vectorized string operations,
# matching how `subject_week_id` is constructed for `subject_week`; unlike a
# row-wise `.apply(axis=1)`, this remains a Series even when no rows survive
# the `dropna`.
mgen_to_subject_week = (
    mgen_meta
    .dropna(subset=['week_number'])
    .pipe(lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .rename('subject_week_id')
)
mgen_to_subject_week

#.groupby(['subject_id', 'week_number']).visit_id.count().sort_values(ascending=False)

In [None]:
# Species analysed in the cells below; kept as a string because the
# `species_taxonomy` index and the depth-table columns are cast to str.
species_id = '101493'
species_taxonomy.loc[species_id]

In [None]:
# Long-format depth table pivoted to sample rows x species columns (absent
# pairs filled with 0), with the species labels cast to strings.
_all_species_depth = (
    pd.read_table('data/hmp2.a.r.proc.gtpro.species_depth.tsv', index_col=['sample', 'species_id'])
    .squeeze()
    .unstack('species_id', fill_value=0)
    .rename(columns=str)
)

# Each sample's depths divided by that sample's total depth.
_all_species_rabund = _all_species_depth.div(_all_species_depth.sum(axis=1), axis=0)

# Depth of the focal species in every sample.
_species_depth = _all_species_depth.loc[:, species_id]

# Depths summed over the libraries that map to the same subject-week.
all_species_depth = _all_species_depth.groupby(mgen_to_subject_week).sum()

In [None]:
# Input paths for the focal species: the filtered metagenotype and the fitted world.
metagenotype_stem = f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05'
world_path = f'{metagenotype_stem}.fit-sfacts10-s75-g10000-seed0.world.nc'

metagenotype = sf.Metagenotype.load(f'{metagenotype_stem}.mgen.nc')
world = sf.World.load(world_path)
print(world_path)

# Metadata for the samples in `world`, ordered by subject then visit date;
# the same ordering is then applied to both the metagenotype and the world.
meta = mgen_meta.loc[world.sample].sort_values(['subject_id', 'visit_date'])
_sample_order = meta.index
metagenotype = metagenotype.sel(sample=_sample_order)
world = world.sel(sample=_sample_order)

# Condensed pairwise vector: True where two samples have equal subject_id.
_subject_column = meta.subject_id.values.reshape((-1, 1))
same_subject = sp.spatial.distance.pdist(_subject_column, metric=eq).astype(bool)

# Smaller world for plotting: at most 1000 randomly chosen positions, and only
# strains whose maximum community value across samples exceeds 0.01.
n_position_ss = min(world.sizes['position'], 1000)
_world_subsampled = world.random_sample(position=n_position_ss)
w_ss = _world_subsampled.sel(strain=world.community.max("sample") > 0.01)

print(metagenotype.sizes['sample'])
print(metagenotype.sizes['position'], world.sizes['position'])
print(w_ss.sizes['strain'])

In [None]:
from itertools import cycle

# Integer codes used to colour the column annotations. Subjects cycle through
# 20 codes; diagnoses and sites are numbered in order of first appearance
# (at most 100 distinct values each).
subject_index_mod = dict(zip(meta.subject_id.unique(), cycle(range(20))))
diagnosis_map = {label: code for label, code in zip(meta.ibd_diagnosis.unique(), range(100))}
site_map = {label: code for label, code in zip(meta.site.unique(), range(100))}


def _community_col_colors(w):
    """Per-sample site / diagnosis / subject colour codes for the samples in `w`."""
    sample_meta = meta.loc[w.sample]
    return xr.Dataset(dict(
        site=sample_meta.site.map(site_map),
        diagnosis=sample_meta.ibd_diagnosis.map(diagnosis_map),
        subject=sample_meta.subject_id.map(subject_index_mod),
    ))


# Community heatmap: rows clustered on discretized genotypes, columns on
# community composition.
sf.plot.plot_community(
    w_ss.sel(sample=meta.index),
    col_colors_func=_community_col_colors,
    row_col_annotation_cmap=mpl.cm.tab20,
    row_linkage_func=lambda w: w.genotype.discretized().linkage(dim="strain"),
    col_linkage_func=lambda w: w.community.linkage("sample"),
    scalex=0.05,
    xticklabels=0,
    # col_cluster=False,
    # norm=mpl.colors.PowerNorm(1),
)

In [None]:
def _metagenotype_col_colors(w):
    """Per-sample site / diagnosis / subject colour codes for the samples in `w`."""
    sample_meta = meta.loc[w.sample]
    return xr.Dataset(dict(
        site=sample_meta.site.map(site_map),
        diagnosis=sample_meta.ibd_diagnosis.map(diagnosis_map),
        subject=sample_meta.subject_id.map(subject_index_mod),
    ))


# Metagenotype heatmap with columns clustered on community composition.
sf.plot.plot_metagenotype(
    w_ss.sel(sample=meta.index),
    col_colors_func=_metagenotype_col_colors,
    col_linkage_func=lambda w: w.community.linkage("sample"),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
    # col_cluster=False,
    # norm=mpl.colors.PowerNorm(1),
)

In [None]:
# Community table scaled by the focal species' depth in each sample. The
# transposes make the multiplication align `_species_depth` with the table's
# columns; entries with no match are set to 0.
_community_transposed = world.community.to_series().unstack().T
_strain_depth = (_community_transposed * _species_depth).T.fillna(0)

# Species depth not attributed to any strain goes into an 'other' column.
_strain_depth_other = _species_depth - _strain_depth.sum(axis=1)
_strain_depth = _strain_depth.assign(other=_strain_depth_other)

# Sum over the libraries that map to the same subject-week.
strain_depth = _strain_depth.groupby(mgen_to_subject_week).sum()

# Strain depth divided by the total depth of all species in that subject-week.
strain_rabund = strain_depth.div(all_species_depth.sum(axis=1), axis=0)

In [None]:
def _collect_pairs(df, status, keep):
    curr_subject_week_id = None
    curr_week_number = 0
    curr_perturbation_status = None
    intervening = False
    out = []
    for next_subject_week_id, x in df.sort_values(['subject_id', 'week_number']).iterrows():
        next_perturbation_status = status(x)
        if not keep(x):
            if next_perturbation_status == curr_perturbation_status:
                continue
            else:
                intervening = True
                continue
        else:
            out.append((
                curr_subject_week_id,
                curr_perturbation_status,
                next_subject_week_id,
                next_perturbation_status,
                x.week_number - curr_week_number,
                intervening,
            ))
            curr_subject_week_id = next_subject_week_id
            curr_perturbation_status = next_perturbation_status
            curr_week_number = x.week_number
            intervening = False
            continue
    return (
        pd.DataFrame(out, columns=['left_subject_week_id', 'left_status', 'right_subject_week_id', 'right_status', 'week_delta', 'intervening'])
        .set_index('left_subject_week_id', drop=False)
        .rename_axis(index='subject_week_id')
        .dropna(subset=['left_subject_week_id'])
    )
    
# Consecutive subject-weeks (per subject) that have metagenome-derived strain
# abundances, with the antibiotic status on each side of the pair.
_has_mgen = subject_week.index.isin(strain_rabund.index)
perturbation_pairs = (
    subject_week.assign(has_mgen=_has_mgen)
    .groupby('subject_id')
    .apply(_collect_pairs, status=lambda x: x.status_antibiotics, keep=lambda x: x.has_mgen)
    .reset_index('subject_id')
)
# Two-character transition label: first character of str(left status)
# followed by first character of str(right status).
perturbation_pairs = perturbation_pairs.assign(
    transition=(
        perturbation_pairs.left_status.astype(str).str[0]
        + perturbation_pairs.right_status.astype(str).str[0]
    )
)

# Log-ratio (right / left) of pseudocount-adjusted relative abundance for
# every column of `strain_rabund`, one value per pair.
pseudo = 1e-4
_rabund = strain_rabund
_ratio = {}
for _tax_id in tqdm(_rabund.columns):
    _tax_rabund = _rabund[_tax_id]
    _left_value = _tax_rabund.loc[perturbation_pairs.left_subject_week_id].values + pseudo
    _right_value = _tax_rabund.loc[perturbation_pairs.right_subject_week_id].values + pseudo
    _ratio[_tax_id] = pd.Series(
        np.log(_right_value) - np.log(_left_value),
        index=perturbation_pairs.index,
        name='log_ratio',
    )
strain_log_ratio = pd.DataFrame(_ratio)

In [None]:
# Display the log-ratio table: one row per pair, one column per column of `strain_rabund`.
strain_log_ratio

In [None]:
# Mean log-ratio per strain (rows) for every transition label (columns),
# ordered by the mean over 'FT' pairs.
d = strain_log_ratio.groupby(perturbation_pairs.transition).mean().T.sort_values('FT')

d1 = d

# Mean 'FT' versus mean 'TF' log-ratio across strains, with a regression line.
# `x`/`y` are passed by keyword: seaborn >= 0.12 rejects them as positional
# arguments, and the keyword form works on older versions as well.
fig, ax = plt.subplots()
ax.scatter('FT', 'TF', data=d1, color='grey', s=1)
sns.regplot(x='FT', y='TF', data=d1, scatter=False, ax=ax)

# Per-pair log-ratios, split by transition, for the strains with the
# smallest and the largest mean 'FT' log-ratio (first and last rows of `d1`).
fig, axs = plt.subplots(1, 2)
for ax, strain_id in zip(axs.flatten(), [d1.index[0], d1.index[-1]]):
    d2 = perturbation_pairs.assign(
                left_value=lambda x: strain_rabund[strain_id].loc[x.left_subject_week_id].values,
                right_value=lambda x: strain_rabund[strain_id].loc[x.right_subject_week_id].values,
                ratio=strain_log_ratio[strain_id],
            )
    ax.set_title(strain_id)
    sns.stripplot(x='transition', y='ratio', data=d2, ax=ax)

print(sp.stats.pearsonr(d1['FT'], d1['TF']))
d1

In [None]:
from patsy import dmatrix

# Presence/absence: relative abundance strictly above `thresh`.
thresh = 1e-5
_present = _rabund.gt(thresh)

# Subject-week metadata in the same row order as the abundance table.
_m = subject_week.loc[_rabund.index]

# Predictors: one indicator column per subject (no intercept) and the
# antibiotic status cast to integers.
x_subject = dmatrix('subject_id - 1', data=_m, return_type='dataframe')
x_abx = _m.status_antibiotics.astype(int)

# Response: only columns present in more than 3 rows.
_n_rows_present = _present.sum(axis=0)
y = _present.loc[:, _n_rows_present > 3]

# n observations, m response columns, r subject columns.
n, m = y.shape
r = x_subject.shape[1]

print(n, m, r)

In [None]:
# Clustered heatmap of the filtered presence/absence table `y`.
sns.clustermap(y)

In [None]:
# Histogram of how many columns of `_present` are True in each row.
# Bin edges run 0, 1, ..., max + 1 so that every integer count k gets its own
# bin [k, k + 1). With edges stopping at max - 1, rows at the maximum count
# fall outside the histogram and the two counts below it share the last bin.
_n_present = _present.sum(1)
plt.hist(_n_present, bins=np.arange(_n_present.max() + 2))

In [None]:
import pymc3 as pm

In [None]:
# Regression of the presence table `y` on per-subject indicator columns and
# antibiotic status, with a beta-binomial (n=1) observation model.
with pm.Model() as model0:
    # Register the response and both predictors as model data.
    _y = pm.Data('_y', y)
    _x_subject = pm.Data('_x_subject', x_subject)
    _x_abx = pm.Data('_x_abx', x_abx)

    # Normal(sigma=10) priors: one coefficient per (subject column, response
    # column), one antibiotic coefficient per response column, and a single
    # antibiotic coefficient shared by all response columns.
    beta_subject = pm.Normal('beta_subject', sigma=10., shape=(r, m))
    beta_abx_strain = pm.Normal('beta_abx_strain', sigma=10., shape=(1, m))
    beta_abx_pooled = pm.Normal('beta_abx_pooled', sigma=10.)
    
    # Linear predictor on the logit scale: subject term plus the antibiotic
    # column (reshaped to (n, 1)) times the per-column and the pooled effects.
    logit_prob = (
        (_x_subject @ beta_subject)
        + (_x_abx.reshape((n, 1)) @ beta_abx_strain)
        + (_x_abx.reshape((n, 1)) * beta_abx_pooled)
    )
    prob = pm.math.invlogit(logit_prob)
    # `alpha` scales both beta-binomial shape parameters, whose ratio keeps the mean at `prob`.
    # NOTE(review): with n=1 a beta-binomial reduces to Bernoulli(alpha / (alpha + beta)),
    # i.e. Bernoulli(prob), so `alpha` may not be informed by the data at all — confirm.
    alpha = pm.Lognormal('alpha', mu=2, sigma=1)  # FIXME: Increase sigma
    obs = pm.BetaBinomial('obs', alpha=prob * alpha, beta=(1 - prob) * alpha, n=1, observed=_y)

In [None]:
# Draw posterior samples from `model0` using 12 chains on 12 cores.
# A fixed `random_seed` makes the draws reproducible across kernel restarts.
with model0:
    trace = pm.sample(return_inferencedata=True, chains=12, cores=12, random_seed=0)

In [None]:
# Density of log(alpha) over all posterior draws, pooled across chains.
sns.kdeplot(np.log(trace.posterior.alpha.values.flatten()))

In [None]:
# Density of the pooled antibiotic coefficient over all posterior draws, pooled across chains.
sns.kdeplot(trace.posterior.beta_abx_pooled.values.flatten())

In [None]:
# For each response column: overlay the posterior density of its antibiotic
# coefficient and record the posterior median, keyed by the column label.
median_beta = {}
for i in range(m):
    _draws = trace.posterior.beta_abx_strain[:, :, 0, i].values.flatten()
    median_beta[y.columns[i]] = np.median(_draws)
    sns.kdeplot(_draws)
median_beta = pd.Series(median_beta)

In [None]:
# Mean log-ratio per strain and transition (ordered by 'FT'), with the
# posterior-median antibiotic coefficient attached by matching index labels.
d = strain_log_ratio.groupby(perturbation_pairs.transition).mean().T.sort_values('FT').assign(median_beta=median_beta)

# x: median coefficient, y: mean 'FT' log-ratio, colour: mean 'TF' log-ratio.
plt.scatter('median_beta', 'FT', c='TF', data=d)

In [None]:
import arviz as az

# Posterior summary of the per-column antibiotic coefficients, sorted by the 'hdi_3%' column.
az.summary(trace, var_names=['beta_abx_strain']).sort_values('hdi_3%')

In [None]:
# Joint posterior draws of two per-column antibiotic coefficients, coloured
# by the pooled antibiotic coefficient of the same draw.
# NOTE(review): positions 42 and 44 are hard-coded indices along the last axis
# of `beta_abx_strain` — confirm they are the intended columns of `y` and that m > 44.
plt.scatter(
    trace.posterior.beta_abx_strain[:,:,0,42].values.flatten(),
    trace.posterior.beta_abx_strain[:,:,0,44].values.flatten(),
    c=trace.posterior.beta_abx_pooled.values.flatten(),
    s=1,
)