# Introduction

**B**ayesian **I**nferential **R**egression for **D**ifferential **M**icrobiome **A**nalysis (BIRDMAn) is a work-in-progress tool for flexible differential abundance analysis through Bayesian inference. BIRDMAn is unique in that it has been designed to support custom statistical modelling. Other tools implement specific models designed for general use cases. BIRDMAn, on the other hand, uses the [Stan](https://mc-stan.org/) probabilistic programming language for model specification. The overall goal of this software is to allow users to specify their own statistical models to address their individual experimental designs and questions.

BIRDMAn also includes several default models for those who do not wish to tinker with custom Stan models. In this demo notebook, we'll walk through fitting the default Negative Binomial model to some example data. For more information, see Jamie's [blogpost](https://mortonjt.github.io/probable-bug-bytes/probable-bug-bytes/differential-abundance/) that inspired this project.

**NOTE:** BIRDMAn is still in development and things are likely to change.

# Preprocessing feature table

We will be using data from the study "Responses of gut microbiota to diet composition and weight loss in lean and obese mice" (Qiita ID: 107). This study looks at the effect of weight loss and diet composition on the gut microbiome.

We will first process and explore the raw data.

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import biom

raw_tbl = biom.load_table("../data/lean_obese_mice/44773_otu_table.biom")
raw_tbl

In [None]:
raw_tbl_df = raw_tbl.to_dataframe()
raw_tbl_df.iloc[:5, :5]

For this demo we are going to use a small subset of the OTUs present in this table based on prevalence (present in at least 20 samples). This is done primarily to let the models run quickly.
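
To make the prevalence filter concrete, here is a minimal sketch on a toy count matrix. The values and the threshold of 2 are made up for illustration. Prevalence is taken to be the number of samples in which a feature has a nonzero count, which for non-negative counts matches clipping at 1 and summing.

```python
import numpy as np

# Toy features x samples count matrix (values are made up).
counts = np.array([
    [5, 0, 3, 1],   # present in 3 samples
    [0, 0, 0, 2],   # present in 1 sample
    [1, 4, 0, 0],   # present in 2 samples
])

# Prevalence: number of samples with a nonzero count for each feature.
prevalence = (counts > 0).sum(axis=1)

# Keep features present in at least `min_samples` samples (hypothetical threshold).
min_samples = 2
filtered = counts[prevalence >= min_samples]

print(prevalence)       # [3 1 2]
print(filtered.shape)   # (2, 4)
```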

In [None]:
prev = raw_tbl_df.clip(upper=1).sum(axis=1)
filt_tbl_df = raw_tbl_df.loc[prev[prev >= 20].index, :]
filt_tbl_df.shape

BIRDMAn expects tables to be input in BIOM format so we save our filtered table in this format.

In [None]:
filt_tbl = biom.table.Table(
    filt_tbl_df.values,
    sample_ids=filt_tbl_df.columns,
    observation_ids=filt_tbl_df.index
)
filt_tbl

# Exploring metadata

Given the title of this study, we are going to be focused on weight loss and diet as covariates in our model.

In [None]:
import pandas as pd

metadata = pd.read_csv(
    "../data/lean_obese_mice/107_20180101-113755.txt",
    sep="\t",
    index_col=0
)
metadata.index = metadata.index.astype(str)
metadata.head()

In [None]:
metadata.groupby(["diet", "treatment"]).size()

We will consider diet, treatment, and their interaction in our regression model. Importantly, we will specify that we want to keep `Control` and `Ad Lib` as the respective reference values (same way as in Songbird).

# Running BIRDMAn

Now it's time to run BIRDMAn! To save on computation time we will only specify 100 iterations per chain. For actual modelling you will likely want this value to be higher (defaults to 500).

Every BIRDMAn model takes several parameters, outlined here:

* `table`: BIOM table of features x samples
* `formula`: Formula with which to fit model (same as in Songbird)
* `metadata`: DataFrame with columns specified in `formula`
* `num_iter`: Number of iterations of MCMC sampling to run *per chain*
* `chains`: Number of chains to run

For the default Negative Binomial model we can also include the prior values for $\beta$ and $\phi$.
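
As a rough illustration of what $\phi$ controls, the sketch below assumes the mean/overdispersion parameterization used by Stan's `neg_binomial_2`, where the variance is $\mu + \mu^2/\phi$. Whether the default BIRDMAn model uses exactly this parameterization is an assumption here, so check the Stan code of the model you are fitting. The values of `mu` and `phi` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

mu, phi = 20.0, 2.0   # illustrative mean and overdispersion values

# A negative binomial can be simulated as a gamma-Poisson mixture:
# rate ~ Gamma(shape=phi, scale=mu/phi), count ~ Poisson(rate).
rates = rng.gamma(shape=phi, scale=mu / phi, size=200_000)
counts = rng.poisson(rates)

# Under this parameterization the variance is mu + mu^2 / phi.
expected_var = mu + mu**2 / phi

print(counts.mean(), counts.var(), expected_var)
```

Under this parameterization, smaller values of $\phi$ give more overdispersion, and the distribution approaches a Poisson as $\phi$ grows large.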

In [None]:
from birdman import NegativeBinomial

nb = NegativeBinomial(
    table=filt_tbl,
    formula="C(diet, Treatment('Control'))*C(treatment, Treatment('Ad Lib'))",
    metadata=metadata,
    num_iter=100,
    chains=4,
    beta_prior=5.0,
    cauchy_scale=5.0
)

The Stan model must be compiled before fitting. This should only have to be done once *per model type*. So the second time you use the default `NegativeBinomial` model it will not have to compile.

In [None]:
nb.compile_model()

Finally, we'll fit the model to our data. This should take ~1 minute.

In [None]:
%%time

nb.fit_model()

## Data structure for model

BIRDMAn makes heavy use of the `arviz` package for downstream analysis (see [documentation](https://arviz-devs.github.io/arviz/index.html)). We provide a wrapper function `to_inference_object` to make this relatively easy.

This function wraps much of the functionality of `xarray` for multi-dimensional arrays (see [documentation](http://xarray.pydata.org/en/stable/)). The primary things you have to learn are the `coords` and `dims` system of indexing data.

* `dims` corresponds to the dimensionality of the data. In our example, the $\beta$ parameter is of dimension (number of covariates x number of features), while the $\phi$ over-dispersion parameter is only of dimension (number of features).
* `coords` provides the labels for each of the dimensions in `dims`. In this case, we want to specify the name of each covariate and each feature. BIRDMAn saves both of these values as `colnames` and `feature_names` respectively.

We also specify that the $\beta$ parameters are currently in *ALR* (additive log-ratio) coordinates. This means that for an input dataset of $N$ microbes, we only have coefficients for $N-1$ microbes, since one microbe serves as the reference. To address this, we specify that we want the $\beta$ variable to be transformed into *CLR* (centered log-ratio) coordinates to "get back" the microbe we lost.
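
The ALR-to-CLR conversion can be sketched with NumPy. This assumes the reference (dropped) feature is the first one, so its ALR coefficient is implicitly 0. The numbers are made up, and BIRDMAn's own conversion may differ in such details.

```python
import numpy as np

# ALR coefficients for one covariate: N - 1 = 3 values (made-up numbers),
# each a log-ratio against the reference feature.
beta_alr = np.array([1.0, -2.0, 0.5])

# The reference feature has an implicit coefficient of 0 in ALR coordinates
# (assumed here to be the first feature).
beta_full = np.concatenate([[0.0], beta_alr])

# CLR coordinates: center so the coefficients sum to zero across all N features.
beta_clr = beta_full - beta_full.mean()

print(beta_clr)        # N = 4 values, one per feature
print(beta_clr.sum())  # approximately 0
```

Differences between features are unchanged by the centering, so `beta_clr[1:] - beta_clr[0]` recovers the original ALR values.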

Finally, the last three arguments to this function specify that we want to include several things in the inference object:

* `posterior_predictive`: If we tried to predict the feature table entries using the fitted parameters (we do this automatically in the default model) then we can use these values downstream for diagnosing the model. The value for this argument should be the name of the Stan variable containing this information (`y_predict` by default).
* `log_likelihood`: We also calculated the log-likelihood values to figure out the best parameter estimates. We can use these values for diagnosis as well. The value for this argument should be the name of the Stan variable containing this information (`log_lik` by default).
* `include_observed_data`: This argument specifies whether or not to include the original table values in the inference object. We can use these "truth" values to compare our model performance.

In [None]:
inference = nb.to_inference_object(
    params=["beta", "phi"],
    coords={
        "feature": nb.feature_names,
        "covariate": nb.colnames
    },
    dims={
        "beta": ["covariate", "feature"],
        "phi": ["feature"]
    },
    alr_params=["beta"],
    posterior_predictive="y_predict",
    log_likelihood="log_lik",
    include_observed_data=True
)

# Diagnosing our fitted model

As with Songbird (or any regression/ML procedure), we want to diagnose our model to make sure we are extracting useful signal and not simply overfitting. There are a number of ways to do this.

We include an easy-to-use function to give an *initial* diagnosis for model fit.

In [None]:
nb.diagnose();

## LOO

Next we look at the Pareto-smoothed Importance Sampling Leave-One-Out Cross-Validation (PSIS-LOO-CV). This is a mouthful but the core of it is that we want to cross-validate our model to make sure we are not overfitting. This can be fairly computationally expensive for Bayesian models so we use an estimation developed by Aki Vehtari and others. This way we can estimate model performance *from the existing sample draws*.

See these two papers for more information:

* https://arxiv.org/abs/1507.04544
* https://arxiv.org/abs/1507.02646

**Note:** This function requires that you calculated the `log likelihood` values in your Stan code (done by the default Negative Binomial model) and passed them to the inference object.

In [None]:
import birdman.diagnostics as diag

diag.loo(inference, pointwise=True)

The log-likelihood is calculated for each entry in your input table, essentially measuring how likely each table value is given the estimated parameters. Higher values of `elpd_loo` indicate better predictive accuracy, so we want to maximize this entry. In this demonstration we also calculate the *pointwise* predictive accuracy. We want the majority of the Pareto k diagnostic values to be below 0.7. Many values above 0.7 are indicative of a relatively poor model.
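
To give a feel for how predictive accuracy can be estimated from existing draws, here is a sketch of plain importance-sampling LOO. It omits the Pareto smoothing step that PSIS adds to stabilize the weights, so it is not what `arviz` computes. The log-likelihood values are simulated, and `log_mean_exp` is a helper written for this example.

```python
import numpy as np

rng = np.random.default_rng(42)

# Simulated pointwise log-likelihoods: 400 posterior draws x 50 observations.
log_lik = rng.normal(loc=-3.0, scale=0.3, size=(400, 50))

def log_mean_exp(x, axis=0):
    """Numerically stable log(mean(exp(x))) along an axis."""
    x_max = x.max(axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.mean(np.exp(x - x_max), axis=axis))

# Raw importance-sampling LOO: the weight of draw s for observation i is
# proportional to 1 / p(y_i | theta_s), which gives
# elpd_i = -log(mean_s(exp(-log_lik[s, i]))).
elpd_pointwise = -log_mean_exp(-log_lik, axis=0)
elpd_loo = elpd_pointwise.sum()

# In-sample log pointwise predictive density, for comparison.
lppd = log_mean_exp(log_lik, axis=0).sum()

print(elpd_loo, lppd)
```

The leave-one-out estimate is never larger than the in-sample value, which reflects the optimism of evaluating a model on the data used to fit it.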

## $\hat{R}$ convergence

Another diagnostic tool is making sure your chains are converging. This can be done by checking the $\hat{R}$ values of your fitted parameters. In a nutshell, these values should be extremely close to 1 to ensure convergence.

See [this link](https://arviz-devs.github.io/arviz/api/generated/arviz.rhat.html) for more details. 
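
As a sketch of the idea behind $\hat{R}$, the classic Gelman-Rubin statistic compares between-chain and within-chain variance. Note that `arviz` defaults to a more robust rank-normalized split version, so the values below will not match its output exactly. The function name `classic_rhat` and the simulated chains are illustrative.

```python
import numpy as np

def classic_rhat(draws):
    """Classic Gelman-Rubin R-hat for an array of shape (chains, draws)."""
    n = draws.shape[1]
    within = draws.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * draws.mean(axis=1).var(ddof=1)   # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n   # pooled variance estimate
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(1)

# Four chains sampling the same distribution: R-hat should be near 1.
mixed = rng.normal(size=(4, 500))

# Shift one chain to mimic chains that disagree: R-hat rises well above 1.
stuck = mixed.copy()
stuck[0] += 3.0

print(classic_rhat(mixed), classic_rhat(stuck))
```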

In [None]:
diag.rhat(inference).mean()

We see that across both our covariate coefficients and overdispersion parameters our chains have converged well.

## Posterior predictive check

We can do further diagnosis with a *posterior predictive check*. This procedure uses the parameter distributions we've estimated to predict our original values. We've provided an easy-to-use visualization function that performs this for you. The figure plots the individual table entries (samples x features) alongside our model predictions. The black line represents the true values, the light gray vertical lines represent the middle 95% interval of predicted values, and the dark gray dots represent the median predicted values across all chains/iterations. We want to see the dark gray dots more-or-less follow the same shape as the black line. Small credible intervals are preferable, but predicting microbiome data is very difficult, so these intervals are expected to be fairly large.

Note that this requires posterior predictive values to have been calculated in Stan and provided to `to_inference_object`.
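
The quantities shown in such a plot can be sketched with NumPy. The observed counts and posterior predictive draws below are simulated stand-ins, not output from BIRDMAn.

```python
import numpy as np

rng = np.random.default_rng(7)

# Simulated stand-ins: 10 observed counts and 1000 posterior predictive draws each.
observed = rng.poisson(lam=10, size=10)
y_predict = rng.poisson(lam=10, size=(1000, 10))

# Per-entry median and middle 95% interval across draws.
median = np.median(y_predict, axis=0)
lower, upper = np.percentile(y_predict, [2.5, 97.5], axis=0)

# Fraction of observed values falling inside their interval.
coverage = np.mean((observed >= lower) & (observed <= upper))

print(median)
print(coverage)
```

For a well-calibrated model, the coverage should be close to the nominal interval width across many table entries.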

In [None]:
import birdman.visualization as viz

viz.plot_posterior_predictive_checks(inference);

This model seems to do an okay job of predicting counts from metadata.

## Comparing to a null model

Similarly to Songbird we can compare our fitted regression model to a "null" model with only an intercept.

In [None]:
from birdman import NegativeBinomial

nb_null = NegativeBinomial(
    table=filt_tbl,
    formula="1",
    metadata=metadata,
    num_iter=100,
    chains=4
)
nb_null.compile_model()
nb_null.fit_model()

In [None]:
inference_null = nb_null.to_inference_object(
    params=["beta", "phi"],
    coords={
        "feature": nb_null.feature_names,
        "covariate": nb_null.colnames
    },
    dims={
        "beta": ["covariate", "feature"],
        "phi": ["feature"]
    },
    alr_params=["beta"],
    posterior_predictive="y_predict",
    log_likelihood="log_lik",
    include_observed_data=True
)

We can use `birdman.diagnostics.loo` again to see how much predictive power this null model has.

In [None]:
diag.loo(inference_null, pointwise=True)

We see a lower value of `elpd_loo`, indicating that this null model has less predictive power than our regression model.

Another thing we can do is use the `arviz.compare` function to compare multiple models. This function takes a dictionary in the form of `{"model_1": InferenceObject, ...}` and outputs a table where the models are ranked from "best" at the top to "worst" at the bottom.

In [None]:
import arviz as az

az.compare({"null": inference_null, "model": inference})

We see that indeed our regression model is at the top. We can also compare the difference in `elpd` relative to the standard error to get a rough idea of how much better this model is.

In [None]:
(-8081.017922 - -8213.895754) / 105.563684

Looks like this model is about 1.26 SE above the null. Not bad!

# Analyzing differentials

We are now ready to use our fitted parameters for further analysis.

We can now plot our parameter estimates similarly to how we would do a rank-plot in Qurro. However, since in the Bayesian framework each parameter has a distribution, we also include the standard deviation of these parameter estimates.
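
The summary behind such a plot can be sketched as follows: pool the draws from all chains, then compute a mean and standard deviation per feature and sort by the mean. The feature names and posterior draws here are simulated for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)

# Simulated posterior draws for one covariate: (chain, draw, feature).
feature_names = np.array(["OTU_A", "OTU_B", "OTU_C"])   # hypothetical names
true_centers = np.array([0.5, -1.0, 2.0])
draws = rng.normal(loc=true_centers, scale=0.2, size=(4, 100, 3))

# Pool chains and draws, then summarize each feature.
pooled = draws.reshape(-1, draws.shape[-1])
means = pooled.mean(axis=0)
sds = pooled.std(axis=0)

# Rank features from most negative to most positive posterior mean.
order = np.argsort(means)
for name, m, s in zip(feature_names[order], means[order], sds[order]):
    print(f"{name}: {m:.2f} +/- {s:.2f}")
```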

In [None]:
nb.colnames

In [None]:
for col in nb.colnames[1:]:
    ax = viz.plot_parameter_estimates(
        inference,
        parameter="beta",
        coord={"covariate": col}
    )
    ax.set_title(":\n".join(col.split(":")))
    ax.axhline(y=0, color="gray")

We can use these differentials in the same ways as Songbird. Let's focus on the diet differentials - first we have to sort the features by their means across all chains and draws. This is fairly straightforward using `xarray`-style indexing.

In [None]:
diet_diffs = inference.posterior["beta"].sel({"covariate": "C(diet, Treatment('Control'))[T.DIO]"})
diet_diffs = diet_diffs.stack(mcmc_sample=("chain", "draw"))
diet_diffs_means = diet_diffs.mean(dim="mcmc_sample")
diet_diffs_means = diet_diffs_means.sortby(diet_diffs_means)
diet_diffs_means

## Plotting log-ratios

Finally, we'll calculate log-ratios using autoselected OTUs. We'll take the top and bottom 5 OTUs to use as our numerator and denominator respectively.

In [None]:
import numpy as np

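def log_ratio(table, top_feats, bot_feats):
    """Compute per-sample log-ratios of summed numerator to summed denominator counts."""
    # Sum counts over the numerator features for each sample
    num_df = table.loc[:, top_feats].sum(axis=1).to_frame()
    num_df.columns = ["num"]
    # Drop samples with a zero numerator (log-ratio undefined)
    num_df = num_df[num_df["num"] > 0]
    # Sum counts over the denominator features for each sample
    denom_df = table.loc[:, bot_feats].sum(axis=1).to_frame()
    denom_df.columns = ["denom"]
    # Drop samples with a zero denominator (log-ratio undefined)
    denom_df = denom_df[denom_df["denom"] > 0]
    # Keep only samples with both a valid numerator and denominator
    lr_df = num_df.join(denom_df, how="inner")
    lr_df["log_ratio"] = np.log(lr_df["num"]/lr_df["denom"])
    return lr_df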

In [None]:
bottom_otus = diet_diffs_means[:5].coords["feature"].values
top_otus = diet_diffs_means[-5:].coords["feature"].values

lr_df = log_ratio(filt_tbl_df.T, top_otus, bottom_otus).join(metadata, how="inner")
print(lr_df.shape)
lr_df.head()

We'll now plot these log-ratios and compare them by diet.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

args = {
    "data": lr_df,
    "x": "diet",
    "y": "log_ratio"
}

sns.boxplot(**args)
sns.stripplot(**args, linewidth=2, size=10)

plt.show()

In [None]:
import scipy.stats as ss

dio_samples = lr_df.query("diet == 'DIO'")["log_ratio"]
control_samples = lr_df.query("diet == 'Control'")["log_ratio"]

print(ss.mannwhitneyu(dio_samples, control_samples))

Indeed, this log-ratio separates the samples well by diet.

We can also use Hotelling's $T^2$ test on individual covariates to see if they are centered around 0 (in ALR coordinates).
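
For reference, a one-sample Hotelling's $T^2$ statistic against a zero mean vector can be sketched as below. This is the textbook form and only an assumption about what `birdman.stats.hotelling_ttest` does internally. The function name `hotelling_t2` and the simulated data are illustrative.

```python
import numpy as np

def hotelling_t2(x):
    """One-sample Hotelling's T^2 statistic against a zero mean vector.

    x has shape (n_samples, n_dims) with n_samples > n_dims.
    Returns (T^2, F statistic); F has (p, n - p) degrees of freedom.
    """
    n, p = x.shape
    mean = x.mean(axis=0)
    cov = np.cov(x, rowvar=False)
    # T^2 = n * mean^T * cov^{-1} * mean
    t2 = n * mean @ np.linalg.solve(cov, mean)
    f_stat = (n - p) / (p * (n - 1)) * t2
    return t2, f_stat

rng = np.random.default_rng(5)

# Draws centered at zero versus draws shifted away from zero (simulated).
centered = rng.normal(loc=0.0, size=(400, 3))
shifted = rng.normal(loc=[1.0, -1.0, 0.5], size=(400, 3))

print(hotelling_t2(centered), hotelling_t2(shifted))
```

A large statistic suggests that the vector of coefficients is not centered at zero.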

In [None]:
from birdman.stats import hotelling_ttest

hotelling_ttest(inference, {"covariate": "C(diet, Treatment('Control'))[T.DIO]"})