 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gibsonlab/mdsine2_tutorials/blob/main/notebooks/tut_02_inference.ipynb)
 # Running inference with the MDSINE2 model and exploring the posterior

In [None]:
# Colab-only setup: fetch the tutorial data and install MDSINE2 from source.
# Outside Colab this cell does nothing (assumes MDSINE2 is already installed).
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Download the tutorial data archive and extract it into ./data/.
    # NOTE(review): later cells read from ../data/*-toy; confirm the Colab
    # working directory / archive layout actually provides those folders.
    !curl -LJO https://github.com/gibsonlab/mdsine2_tutorials/raw/main/data/raw_tables.zip
    !mkdir -p ./data/ && unzip raw_tables.zip -d ./data/

    # Install MDSINE2 from the `cv-wip-mirror-dem` branch of the gerberlab repo.
    !git clone https://github.com/gerberlab/MDSINE2
    !cd MDSINE2 && git fetch && git checkout cv-wip-mirror-dem
    !pip install MDSINE2/.

In [None]:
import mdsine2 as md2
from mdsine2.names import STRNAMES
from pathlib import Path
import matplotlib.pyplot as plt 
import numpy as np 


 # Loading data
 Here we load the three small preprocessed "toy" datasets (replicates, healthy, and unhealthy) created in the previous tutorial.

In [None]:
# Preprocessed toy datasets written by the previous tutorial.
replicates_dir, healthy_dir, unhealthy_dir = (
    Path('../data/replicates-toy'),
    Path('../data/healthy-toy'),
    Path('../data/unhealthy-toy'),
)

# Destinations for MCMC output and rendered figures.
output_dir = Path('../output/')
fig_dir = Path('../figs/')

# Create both up front so later cells can write into them unconditionally.
for _dir in (output_dir, fig_dir):
    _dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Read data: parse each preprocessed toy dataset directory into an MDSINE2
# Study object, in the order replicates, healthy, unhealthy.
studies = []
for data_dir in (replicates_dir, healthy_dir, unhealthy_dir):
    # Index the directory's .tsv tables by file name without the extension.
    tables = {path.stem: path for path in sorted(data_dir.glob('*.tsv'))}

    studies.append(
        md2.dataset.parse(
            name=data_dir.stem,
            metadata=tables['metadata'],
            taxonomy=tables['rdp_species'],
            reads=tables['counts'],
            qpcr=tables['qpcr'],
            perturbations=tables['perturbations'],
        )
    )

replicates, healthy, unhealthy = studies


 # Learning the Negative Binomial dispersion parameters
 Before running the main inference loop, we learn the negative binomial dispersion parameters from our physical replicates.
 To do this, we build the compute graph for the negative binomial model and then run inference on it.

In [None]:
# Negative binomial MCMC settings: RNG seed, burn-in length, total number of
# samples, checkpoint interval, and the directory the chain is written to.
params = md2.config.NegBinConfig(
    seed=0,
    burnin=100,
    n_samples=200,
    checkpoint=100,
    basepath=str(output_dir / "negbin"),
)

# Assemble the compute graph on the replicate study, then sample from it.
mcmc_negbin = md2.negbin.build_graph(
    params=params,
    graph_name=replicates.name,
    subjset=replicates,
)
mcmc_negbin = md2.negbin.run_graph(mcmc_negbin, crash_if_error=True)

# Report posterior summaries of the two dispersion parameters a0 and a1.
for label, key in (('a0', STRNAMES.NEGBIN_A0), ('a1', STRNAMES.NEGBIN_A1)):
    print(label, md2.summary(mcmc_negbin.graph[key]))


 Here we visualize the fit of the learned negative binomial model. This is not representative of real results, because the toy dataset contains only the 15 most abundant taxa.

In [None]:
# Plot the learned negative binomial fit on the replicate data and save it.
fig = md2.negbin.visualize_learned_negative_binomial_model(mcmc_negbin)
fig.tight_layout()
# Save the figure we were handed explicitly, rather than whichever figure
# pyplot currently considers "current".
fig.savefig(fig_dir / 'negbin_fit.png')


 # Run inference on the full model
 ### Initialize and set model hyperparameters
 Here we use the learned parameters from the negative binomial model to run inference with the full model. First, we'll create a routine that takes our mcmc object as input and plots posterior quantities.

In [None]:
def visualize(mcmc, study, seed, n_burnin=50):
    """Plot and save posterior diagnostics for one MDSINE2 chain.

    Prints the posterior-mean growth rate of every taxon, then renders and
    saves to ``fig_dir`` (file names suffixed with the zero-padded ``seed``):

    * the growth-rate trace of the first taxon,
    * the process-variance trace,
    * the posterior co-clustering probabilities,
    * the trace of the number of modules.

    Args:
        mcmc: Inference object returned by ``md2.run_graph``. Its clustering
            node is passed to ``md2.generate_cluster_assignments_posthoc``
            with ``set_as_value=True`` (presumably overwriting the current
            clustering value — confirm against the MDSINE2 docs).
        study: Study the chain was fit on; its taxa label the co-cluster plot.
        seed: Seed of the chain; only used to name the output files.
        n_burnin: Forwarded as ``n_burnin`` to the growth-rate and
            process-variance trace plots. Defaults to 50, the burn-in used by
            the chains in this tutorial; pass the chain's own burn-in if it
            differs.
    """
    # Growth rates: per-taxon posterior means plus the trace of the first taxon.
    growth = mcmc.graph[STRNAMES.GROWTH_VALUE]
    growth_rates_trace = growth.get_trace_from_disk(section='entire')

    growth_rates_mean = md2.summary(growth)['mean']
    print('Mean growth rates for taxa over posterior', growth_rates_mean)

    # Column 0 is taken as the first taxon — assumes the trace is laid out as
    # (samples, taxa); TODO confirm.
    md2.visualization.render_trace(
        growth_rates_trace[:, 0], n_burnin=n_burnin, title='OTU_1 Growth rate')
    plt.savefig(fig_dir / 'posterior_growth_rates_{:04d}.png'.format(seed))

    # Process variance.
    processvar = mcmc.graph[STRNAMES.PROCESSVAR]
    pv_rates_trace = processvar.get_trace_from_disk(section='entire')

    md2.visualization.render_trace(
        pv_rates_trace, n_burnin=n_burnin, title='process variance')
    plt.savefig(fig_dir / 'posterior_process_variance_{:04d}.png'.format(seed))

    # Taxa module assignments, computed post hoc from the clustering posterior.
    clustering = mcmc.graph[STRNAMES.CLUSTERING_OBJ]
    md2.generate_cluster_assignments_posthoc(clustering, set_as_value=True)

    # Co-cluster posterior probability (reuse the clustering handle fetched above).
    coclusters = md2.summary(clustering.coclusters)['mean']
    md2.visualization.render_cocluster_probabilities(
        coclusters, taxa=study.taxa,
        yticklabels='%(paperformat)s | %(index)s')
    plt.savefig(fig_dir / 'posterior_cocluster_probs_{:04d}.png'.format(seed))

    # Trace of the number of modules over all samples.
    md2.visualization.render_trace(clustering.n_clusters)
    plt.savefig(fig_dir / 'posterior_num_modules_trace_{:04d}.png'.format(seed))



 Here we run inference on the unhealthy dataset once per seed value, producing one MCMC chain per seed.

In [None]:
# Fix the negative binomial dispersion parameters at their posterior means
# for the full-model runs below.
negbin_graph = mcmc_negbin.graph
a0 = md2.summary(negbin_graph[STRNAMES.NEGBIN_A0])['mean']
a1 = md2.summary(negbin_graph[STRNAMES.NEGBIN_A1])['mean']

# Persist them so the next tutorial can reuse them.
np.savez('./negbin_params.npz', a0=a0, a1=a1)

# One independent MCMC chain per seed; a later cell compares them via r-hat.
seeds = [0, 1]
chains = []

for seed in seeds:
    # Each chain writes to its own directory: <output_dir>/mdsine2/<study name><seed:04d>.
    chain_basepath = output_dir / "mdsine2" / f"{unhealthy.name}{seed:04d}"
    chain_basepath.mkdir(parents=True, exist_ok=True)

    # Model configuration for this chain, with a0/a1 held fixed.
    params = md2.config.MDSINE2ModelConfig(
        basepath=str(chain_basepath),
        seed=seed,
        burnin=50,
        n_samples=100,
        negbin_a0=a0,
        negbin_a1=a1,
        checkpoint=50,
    )

    # The default initial number of modules (30) is larger than the number of
    # taxa in this toy dataset, which would otherwise be flagged; start with
    # no clusters instead.
    params.INITIALIZATION_KWARGS[STRNAMES.CLUSTERING]['value_option'] = 'no-clusters'

    # Initialize the graph, then run inference on it.
    mcmc = md2.run_graph(
        md2.initialize_graph(params=params, graph_name=unhealthy.name, subjset=unhealthy),
        crash_if_error=True,
    )

    visualize(mcmc, unhealthy, seed)
    chains.append(mcmc)


In [None]:
# r-hat convergence diagnostic across the seed chains for a few key
# parameters, computed over the sample window delimited by `start` and `end`.
start = 10
end = 20

for description, vname in (
    ('Growth parameter r-hat:', STRNAMES.GROWTH_VALUE),
    ('Concentration parameter r-hat:', STRNAMES.CONCENTRATION),
    ('Process variance parameter r-hat:', STRNAMES.PROCESSVAR),
):
    rhat = md2.pylab.inference.r_hat(chains, start=start, end=end, vname=vname)
    print(description, rhat)


In [None]:
# Consensus clustering: rerun inference with the module assignments held
# fixed, learning only the remaining parameters.
basepath = output_dir / 'mdsine2-fixed-cluster'
basepath.mkdir(exist_ok=True, parents=True)

# Pass basepath as a string, like the other config objects in this notebook.
params = md2.config.MDSINE2ModelConfig(
    basepath=str(basepath), seed=0, burnin=50, n_samples=100,
    negbin_a0=a0, negbin_a1=a1, checkpoint=50)

# Do not learn the clustering or its concentration parameter.
params.LEARN[STRNAMES.CLUSTERING] = False
params.LEARN[STRNAMES.CONCENTRATION] = False

# Initialize the clustering from a previously saved run. `chain_basepath` is
# left over from the seed loop, so this is the LAST chain of that loop.
# Assumes run_graph wrote `mcmc.pkl` into that directory — confirm.
params.INITIALIZATION_KWARGS[STRNAMES.CLUSTERING]['value_option'] = 'fixed-clustering'
params.INITIALIZATION_KWARGS[STRNAMES.CLUSTERING]['value'] = str(chain_basepath / "mcmc.pkl")

mcmc = md2.initialize_graph(params=params, graph_name=unhealthy.name, subjset=unhealthy)
mcmc = md2.run_graph(mcmc, crash_if_error=True)


In [None]:

# Plot Bayes factors for module-module interactions from the fixed-clustering
# run: compute taxon-level Bayes factors, condense them to the fixed modules,
# and render them as a heatmap.
clustering = mcmc.graph[STRNAMES.CLUSTERING_OBJ]
bf_taxa = md2.generate_interation_bayes_factors_posthoc(mcmc)
bf_clustering = md2.condense_fixed_clustering_interaction_matrix(bf_taxa, clustering=clustering)

labels = ['Cluster {} | {}'.format(i + 1, i + 1) for i in range(len(clustering))]
md2.visualization.render_bayes_factors(bf_clustering, yticklabels=labels)
# Do not suffix the file with `seed`: at this point it is a leftover from the
# seed loop, not the seed of the fixed-clustering run, so the name would mislead.
plt.savefig(fig_dir / 'posterior_fix_assignments.png')
