In [None]:
import mdsine2 as md2
from mdsine2.names import STRNAMES
import numpy as np 
from pathlib import Path 
import matplotlib.pyplot as plt 


In [None]:
# Input data folder and the folders this notebook writes results/figures to.
unhealthy_dir = Path('../data/unhealthy-toy')
cv_output_dir = Path('../output/cv')
fig_dir = Path('../figs/')

# Create the output folders up front; existing folders are left untouched.
fig_dir.mkdir(parents=True, exist_ok=True)
cv_output_dir.mkdir(parents=True, exist_ok=True)

# Index every TSV in the data folder by its file name without the extension,
# e.g. 'counts' -> Path('../data/unhealthy-toy/counts.tsv').
tsv_files = {
    tsv_path.stem: tsv_path
    for tsv_path in sorted(unhealthy_dir.glob('*.tsv'))
}

# Parse the individual tables into one Study object. The validation subject
# is split off from this study in a later cell.
holdout_study = md2.dataset.parse(
    name=unhealthy_dir.stem,
    metadata=tsv_files['metadata'],
    taxonomy=tsv_files['rdp_species'],
    reads=tsv_files['counts'],
    qpcr=tsv_files['qpcr'],
    perturbations=tsv_files['perturbations'],
)


In [None]:
# Hold out subject 8: it is removed from `holdout_study` and returned as its
# own study, which we rename so the two are easy to tell apart.
# NOTE(review): pop_subject presumably mutates `holdout_study` in place, so
# re-running this cell without re-running the parse cell is expected to fail.
val = holdout_study.pop_subject('8')
val.name += '-validate'

print('Holdout study name:', holdout_study.name)
print('Subject IDs in holdout study:', [s.name for s in holdout_study])
print('Perturbations: \n', *holdout_study.perturbations)

print('Validation study name', val.name)
# Label fixed: this line lists the subjects of the validation study.
print('Subject IDs in validation study:', [s.name for s in val])
print('Validation perturbations', *val.perturbations)

# We learned the negative binomial model in the previous tutorial, so
# we'll reuse those parameters here.
# An .npz archive is read lazily through an open file handle; the context
# manager closes it once the two scalars have been extracted.
with np.load('./negbin_params.npz') as negbin_params:
    a0 = negbin_params['a0'].item()
    a1 = negbin_params['a1'].item()
print(f'Reusing a0 and a1 from tutorial 2: a0 = {a0}, a1 = {a1}')

# Learn the model on the remaining subjects. Results are written under
# cv_output_dir/<study name>.
params = md2.config.MDSINE2ModelConfig(
    basepath=cv_output_dir / holdout_study.name,
    seed=0,
    burnin=50, n_samples=100,
    negbin_a0=a0, negbin_a1=a1,
    checkpoint=50
)
# Override the clustering initialization option before building the graph.
params.INITIALIZATION_KWARGS[STRNAMES.CLUSTERING]['value_option'] = 'no-clusters'
mcmc = md2.initialize_graph(params=params, graph_name=holdout_study.name, subjset=holdout_study)
mcmc = md2.run_graph(mcmc, crash_if_error=True)


In [None]:
# Forward simulate the held-out subject from its first measured time point

# Observed absolute abundances of the validation subject and its time points
subj = val['8']
M_truth = subj.matrix()['abs']
times = subj.times

# Initial conditions are the first column of the observed matrix.
# Take a copy: basic slicing returns a view, and the in-place zero replacement
# below would otherwise overwrite M_truth, which is plotted later as the
# observed data.
initial_conditions = M_truth[:, 0].copy()
# Replace exact zeros with 1e5 so every taxon starts from a positive value
# (presumably a detection-floor value; it matches the lower y-limit of the plot).
initial_conditions[initial_conditions == 0] = 1e5

# Forward simulate for each gibb step
M = md2.model.gLVDynamicsSingleClustering.forward_sim_from_chain(
    mcmc, subj=subj,
    initial_conditions=initial_conditions,
    times=times,
    simulation_dt=0.01
)


In [None]:
# Summarise the forward simulations for one taxon: median trajectory with a
# 5th-95th percentile band, overlaid on the held-out measurements.
taxa = subj.taxa

# Percentiles are taken across axis 0 of M (one entry per simulated trajectory).
low, med, high = np.percentile(M, q=[5, 50, 95], axis=0)

oidx = 3 # OTU 4

fig, ax = plt.subplots()

# Shaded band = 5th-95th percentile, solid line = median of the simulations
ax.fill_between(times, low[oidx], high[oidx], alpha=0.2)
ax.plot(times, med[oidx], label='Forward Sim')

# Held-out measurements for the same taxon
ax.plot(
    times, M_truth[oidx],
    label='Data', color='black', marker='x', linestyle=':',
)

ax.set_yscale('log')

# Mark the perturbation windows, then fix the visible abundance range
md2.visualization.shade_in_perturbations(ax, perturbations=subj.perturbations, subj=subj)
ax.set_ylim(bottom=1e5, top=1e12)

ax.legend()

fig.suptitle(md2.taxaname_for_paper(taxa[oidx], taxa))
plt.savefig(fig_dir / 'forward_sim.png')