# Custom Modelling with BIRDMAn - Longitudinal LME

One of the primary goals of BIRDMAn is to allow users to develop and implement their own statistical models based on their experimental design/question. For example, if you have paired or longitudinal data, you would likely want to specify random intercepts for individual subjects.

Here, we will do a fairly in-depth walkthrough of how you can use BIRDMAn to analyze longitudinal data. Note that for this workshop we will be using a custom Stan file, but in the future this file will be included as a default option.

## Data

We will be using data from the study "Linking the effects of helminth infection, diet and the gut microbiota with human whole-blood signatures (repeated measurements)" (Qiita ID: 11913). This study looks at the effect of de-worming on the gut microbiome. Importantly, the authors collected pre- and post-deworming stool samples from several individuals.

For our purposes, we would like to consider a linear mixed effects model (LME) where de-worming timepoint is a fixed effect and subject ID is a random effect. We will first process and explore the raw data.

In [None]:
%matplotlib inline

import biom
import numpy as np
import pandas as pd
import birdman

In [None]:
raw_tbl = biom.load_table("../data/helminth/94270_reference-hit.biom")
raw_tbl

In [None]:
metadata = pd.read_csv("../data/helminth/11913_20191016-112545.txt", sep="\t", index_col=0)
metadata.head()

First, we want to determine which subjects have paired samples (taken at `time_point` 1 and 2).

In [None]:
metadata["host_subject_id"].unique()

In [None]:
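# A subject is paired if it has exactly one sample at time point 1 and one at time point 2
# (order-independent, and safe for subjects with more or fewer than two samples)
subj_is_paired = metadata.groupby("host_subject_id")["time_point"].apply(
    lambda tp: sorted(tp.tolist()) == [1, 2]
)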
paired_subjs = subj_is_paired[subj_is_paired].index
paired_subjs

In [None]:
paired_samps = metadata[metadata["host_subject_id"].isin(paired_subjs)].index
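The pairing check can be sanity-checked on a small toy table. The subject IDs, sample IDs and time points below are made up; the point is that a subject counts as paired only when it has exactly one sample at time point 1 and one at time point 2, whatever order the rows are in.

```python
import pandas as pd

# Hypothetical toy metadata: S1 is paired (rows out of order), S2 has only
# one sample, and S3 has a duplicated time point.
toy_md = pd.DataFrame(
    {
        "host_subject_id": ["S1", "S1", "S2", "S3", "S3", "S3"],
        "time_point": [2, 1, 1, 1, 2, 2],
    },
    index=[f"samp{i}" for i in range(6)],
)

# Paired only if the sorted time points are exactly [1, 2]
is_paired = toy_md.groupby("host_subject_id")["time_point"].apply(
    lambda tp: sorted(tp.tolist()) == [1, 2]
)
paired = is_paired[is_paired].index.tolist()
paired_samples = toy_md[toy_md["host_subject_id"].isin(paired)].index.tolist()
print(paired, paired_samples)
```

Only `S1` survives; `S2` and `S3` are excluded even though both have a sample at time point 1.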

## Covariates of Interest

We will then explore the metadata to see which columns may be important.

In [None]:
metadata.columns

For this example we will consider the following covariates in our model:

* `helminth`
* `time_point`
* `life_stage`
* `sex`

In [None]:
cols_to_keep = ["helminth", "time_point", "life_stage", "sex", "host_subject_id"]
metadata_model = metadata.loc[paired_samps, cols_to_keep].dropna()

Additionally, we are going to change the encoding of `time_point` to be more explicit. We are also going to prefix each subject ID with "S" so that they are read as strings instead of integers.

In [None]:
metadata_model["time_point"] = metadata_model["time_point"].map({1: "pre-deworm", 2: "post-deworm"})
metadata_model["host_subject_id"] = "S" + metadata_model["host_subject_id"].astype(str)

In [None]:
metadata_model.head()

## Filtering the feature table

We want to filter the original BIOM table to include only the samples we've specified. We are also going to reduce the number of microbes in the table to a more manageable number for this demonstration.

In [None]:
raw_tbl_df = raw_tbl.to_dataframe(dense=True)  # dense DataFrame; the default is sparse
raw_tbl_df.iloc[:5, :5]

To filter out rare features, we compute the prevalence of each feature (the number of samples in which it has a nonzero count) and keep only features present in at least 20 of the 46 samples.

In [None]:
filt_tbl_df = raw_tbl_df.loc[:, metadata_model.index]
prev = filt_tbl_df.clip(upper=1).sum(axis=1)
filt_tbl_df = filt_tbl_df.loc[prev[prev >= 20].index, :]
filt_tbl_df.shape
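Here is the same prevalence filter on a tiny made-up table (features as rows, samples as columns, matching the BIOM orientation used above). For nonnegative counts, `(counts > 0).sum(axis=1)` gives the same prevalence as the `clip(upper=1)` approach; the threshold of 3 is an arbitrary stand-in for the 20 used in the notebook.

```python
import pandas as pd

# Hypothetical counts: rows are features, columns are samples
counts = pd.DataFrame(
    [[0, 3, 5, 1],   # featA: present in 3 samples
     [0, 0, 2, 0],   # featB: present in 1 sample
     [7, 1, 0, 4]],  # featC: present in 3 samples
    index=["featA", "featB", "featC"],
    columns=["s1", "s2", "s3", "s4"],
)

# Prevalence = number of samples with a nonzero count for each feature
prevalence = (counts > 0).sum(axis=1)
min_prevalence = 3  # toy threshold
kept = counts.loc[prevalence >= min_prevalence]
print(prevalence.tolist(), kept.index.tolist())
```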

Now we convert this DataFrame back into a BIOM table so we can pass it to BIRDMAn.

In [None]:
filt_tbl = biom.table.Table(
    filt_tbl_df.values,
    sample_ids=filt_tbl_df.columns,
    observation_ids=filt_tbl_df.index
)
filt_tbl

## LME Model

We have provided a Stan file with the code required to run this model. First, we'll print the contents to the screen.

In [None]:
with open("../models/negative_binomial_lme.stan", "r") as f:
    print(f.read())

### Description of the model

The basic Negative Binomial model is as follows:

$$ y_{ij} \sim \mathrm{NB}(\mu_{ij},\phi_j) $$

$$ \mu_{ij} = n_i \cdot p_{ij} $$

where $i$ indexes samples, $j$ indexes microbes, $n_i$ is the sampling depth of sample $i$, and $\phi_j$ is the dispersion parameter of microbe $j$. The mean abundance of microbe $j$ in sample $i$, $\mu_{ij}$, is therefore the total sampling depth $n_i$ multiplied by the probability $p_{ij}$ of microbe $j$ in sample $i$.
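As a rough illustration of this mean structure, the sketch below simulates counts for a single sample with $\mu_{ij} = n_i p_{ij}$. It assumes the common mean/dispersion parameterization (as in Stan's `neg_binomial_2`), in which the variance is $\mu + \mu^2/\phi$. NumPy's `negative_binomial(n, p)` is parameterized differently, so we convert with $n = \phi$ and $p = \phi/(\phi + \mu)$. The depth, probabilities and dispersions are made up.

```python
import numpy as np

rng = np.random.default_rng(42)

depth = 10_000                       # n_i: hypothetical sampling depth
probs = np.array([0.5, 0.3, 0.2])    # p_ij: hypothetical microbe probabilities
phi = np.array([5.0, 5.0, 5.0])      # phi_j: hypothetical dispersions

mu = depth * probs  # mu_ij = n_i * p_ij

# NumPy parameterization: n = phi, p = phi / (phi + mu) gives mean mu
draws = rng.negative_binomial(phi, phi / (phi + mu), size=(200_000, 3))

print(mu)
print(draws.mean(axis=0))                        # close to mu
print(draws.var(axis=0) / (mu + mu**2 / phi))    # close to 1
```

Smaller $\phi_j$ means more overdispersion relative to a Poisson with the same mean.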

For the default NB model, the probability of each microbe in a given sample depends only on the fixed effects. $x_i$ is the covariate vector of sample $i$ (row $i$ of the design matrix) and $\beta$ is the covariate coefficient vector (e.g. $[\beta_{\rm{intercept}}, \beta_{\rm{time}}, ...]^T$). The linear predictor $x_i\beta$ is on the additive log-ratio (ALR) scale, so the inverse ALR transform maps it back to probabilities:

$$ p_i = \mathrm{alr}^{-1} (x_i\beta) $$

To include a random effect, we add a random intercept for each subject ID to the linear predictor:

$$ p_i = \mathrm{alr}^{-1} (x_i\beta + z_i u) $$

where $z_i$ is the indicator (one-hot) row vector mapping sample $i$ to its subject and $u$ is the subject coefficient vector (e.g. $[u_{\textrm{S1}}, u_{\textrm{S2}}, ...]^T$).
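To make the link function concrete, here is a minimal NumPy sketch that maps a linear predictor with a per-subject random intercept to per-sample probabilities. Assumptions for illustration only: the first microbe is the ALR reference (so the inverse ALR prepends a zero and applies a softmax), the random intercept is shared across all log-ratios (the single `subject` dimension given to `subj_int` later in the notebook points that way, but check the Stan file), and every number is made up. `alr_inv` and `alr` are hypothetical helpers, not BIRDMAn functions.

```python
import numpy as np

def alr_inv(eta):
    """Inverse ALR: prepend a zero for the reference microbe, then softmax."""
    eta = np.asarray(eta, dtype=float)
    zeros = np.zeros(eta.shape[:-1] + (1,))
    full = np.concatenate([zeros, eta], axis=-1)
    full -= full.max(axis=-1, keepdims=True)  # numerical stability
    expd = np.exp(full)
    return expd / expd.sum(axis=-1, keepdims=True)

def alr(p):
    """ALR with the first column as reference (assumed for this sketch)."""
    p = np.asarray(p, dtype=float)
    return np.log(p[..., 1:] / p[..., :1])

# Toy design: 4 samples, intercept + one binary covariate, D = 3 microbes
X = np.array([[1, 0], [1, 1], [1, 0], [1, 1]], dtype=float)  # rows x_i
beta = np.array([[0.5, -1.0], [1.0, 0.2]])                    # (covariates, D - 1)
subj_codes = np.array([0, 0, 1, 1])                           # sample -> subject (0-based)
u = np.array([0.3, -0.3])                                     # one intercept per subject
Z = np.eye(len(u))[subj_codes]                                # one-hot rows z_i

# z_i u simply picks out the intercept of sample i's subject
assert np.allclose(Z @ u, u[subj_codes])

eta = X @ beta + (Z @ u)[:, None]  # assumed: same intercept added to every log-ratio
P = alr_inv(eta)
print(P.round(3))
print(P.sum(axis=1))               # each row sums to 1
```

Because $z_i$ is one-hot, multiplying by it is just an index lookup, which is why the Stan data only needs an integer subject index per sample rather than the full $Z$ matrix.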

In [None]:
nb_lme = birdman.Model(  # note that we are instantiating a base Model object
    table=filt_tbl,
    formula="C(time_point, Treatment('pre-deworm'))",
    metadata=metadata_model,
    model_path="../models/negative_binomial_lme.stan",
    num_iter=100,
    chains=4,
    seed=42
)

We then want to update our data dictionary with the new parameters.

By default BIRDMAn computes and includes:

* `y`: table data
* `x`: covariate design matrix
* `depth`: log sampling depths of samples
* `N`: number of samples
* `D`: number of features
* `p`: number of covariates (including Intercept)

We need to add the following variables for Stan:

* `S`: total number of groups (subjects)
* `subj_ids`: mapping of samples to subjects
* `B_p`: stdev prior for normally distributed covariate-feature coefficients
* `phi_s`: scale prior for half-Cauchy distributed overdispersion coefficients
* `u_p`: stdev prior for normally distributed subject intercept shifts

`subj_ids` maps each sample to its subject. Stan does not accept string data, so we encode each unique subject as an integer, starting at 1 because Stan arrays are 1-indexed.

In [None]:
group_var_series = metadata_model["host_subject_id"]
samp_subj_map = group_var_series.astype("category").cat.codes + 1

samp_subj_map.head(10)

In [None]:
groups = np.sort(group_var_series.unique())
groups
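The integer codes and the sorted `groups` array line up because pandas builds a categorical's categories from the sorted unique values. The toy labels below are made up; note that sorting is lexicographic, so `S10` comes before `S2`.

```python
import numpy as np
import pandas as pd

# Hypothetical subject labels in arbitrary sample order
subj = pd.Series(["S7", "S10", "S2", "S7", "S2", "S10"],
                 index=[f"samp{i}" for i in range(6)])

codes = subj.astype("category").cat.codes + 1  # 1-based for Stan
groups = np.sort(subj.unique())

# Code k refers to groups[k - 1], so the labels can be recovered exactly
recovered = groups[codes.values - 1]
print(groups, codes.tolist())
```

This alignment is what makes it safe to pass `groups` as the `subject` coordinates when building the inference object.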

In [None]:
param_dict = {
    "S": len(groups),
    "subj_ids": samp_subj_map.values,
    "B_p": 3.0,
    "phi_s": 3.0,
    "u_p": 1.0
}

nb_lme.add_parameters(param_dict)

In [None]:
nb_lme.dat
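Before compiling, it can help to sanity-check the data dictionary against the variables listed above. `check_lme_data` is a hypothetical helper, not part of BIRDMAn; it assumes `y` is samples by features and that the keys match the Stan file printed earlier. The toy dictionary uses made-up values.

```python
import numpy as np

def check_lme_data(dat):
    """Hypothetical helper: return a list of problems in a Stan data dict for the LME model."""
    problems = []
    N, D, S = dat["N"], dat["D"], dat["S"]
    subj_ids = np.asarray(dat["subj_ids"])

    if np.shape(dat["y"]) != (N, D):
        problems.append("y should be N x D (samples x features)")
    if np.shape(dat["x"]) != (N, dat["p"]):
        problems.append("x should be N x p")
    if len(dat["depth"]) != N:
        problems.append("depth should have one entry per sample")
    if subj_ids.shape != (N,):
        problems.append("subj_ids should have one entry per sample")
    elif subj_ids.min() < 1 or subj_ids.max() > S:
        problems.append("subj_ids should be 1-based integers in 1..S")
    elif len(np.unique(subj_ids)) != S:
        problems.append("every subject in 1..S should appear at least once")
    for key in ("B_p", "phi_s", "u_p"):
        if not dat[key] > 0:
            problems.append(f"{key} should be a positive prior scale")
    return problems

# Toy data dictionary with made-up values
toy_dat = {
    "y": np.zeros((4, 3), dtype=int), "x": np.ones((4, 2)),
    "depth": np.log([100, 120, 90, 110]),
    "N": 4, "D": 3, "p": 2,
    "S": 2, "subj_ids": np.array([1, 1, 2, 2]),
    "B_p": 3.0, "phi_s": 3.0, "u_p": 1.0,
}
print(check_lme_data(toy_dat))  # []

# 0-based subject codes are an easy mistake when coming from Python
bad_dat = dict(toy_dat, subj_ids=np.array([0, 0, 1, 1]))
print(check_lme_data(bad_dat))
```

Assuming `nb_lme.dat` is a plain dictionary with these keys, the same call could be run on it before `compile_model()`.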

In [None]:
nb_lme.compile_model()

In [None]:
%%time

nb_lme.fit_model()

In [None]:
nb_lme.diagnose();

In [None]:
inference = nb_lme.to_inference_object(
    params=["beta", "phi", "subj_int"],
    coords={
        "feature": nb_lme.feature_names,
        "covariate": nb_lme.colnames,
        "subject": groups
    },
    dims={
        "beta": ["covariate", "feature"],
        "phi": ["feature"],
        "subj_int": ["subject"]
    },
    alr_params=["beta"],
    posterior_predictive="y_predict",
    log_likelihood="log_lik",
    include_observed_data=True
)

In [None]:
import birdman.diagnostics as diag

diag.loo(inference, pointwise=True)

In [None]:
inference.posterior["subj_int"].dims

In [None]:
inference.posterior.coords["covariate"].values

In [None]:
from birdman.visualization import plot_posterior_predictive_checks

ax = plot_posterior_predictive_checks(inference)

For this notebook, we'll look at which taxa differ between post-deworm and pre-deworm samples by examining the posterior of the post-deworm coefficient.

In [None]:
deworm_diff = inference.posterior["beta"].sel(
    {"covariate": "C(time_point, Treatment('pre-deworm'))[T.post-deworm]"}
)
deworm_diff = deworm_diff.stack(mcmc_sample=("chain", "draw"))
deworm_diff_means = deworm_diff.mean(dim="mcmc_sample")
deworm_diff_stds = deworm_diff.std(dim="mcmc_sample")
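The same summary can be sketched in plain NumPy, which also shows one common extra step: a 95% credible interval per feature. The draws here are synthetic (made-up effects plus noise), and the `(chain, draw, feature)` layout is assumed to mirror the xarray object above; flattening the first two axes is the NumPy counterpart of `.stack(mcmc_sample=("chain", "draw"))`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic posterior draws for one covariate: (chain, draw, feature)
n_chain, n_draw, n_feat = 4, 100, 5
true_effects = np.array([-1.0, -0.2, 0.0, 0.3, 1.2])  # made-up values
draws = true_effects + 0.25 * rng.standard_normal((n_chain, n_draw, n_feat))

# Pool chains and draws into one sample axis
flat = draws.reshape(n_chain * n_draw, n_feat)
means = flat.mean(axis=0)
stds = flat.std(axis=0)  # ddof=0, like xarray's default
lower, upper = np.percentile(flat, [2.5, 97.5], axis=0)  # 95% credible interval

# Features whose interval excludes zero
credible = (lower > 0) | (upper < 0)

# Rank features by posterior mean, as sortby() does
order = np.argsort(means)
bottom_2, top_2 = order[:2], order[-2:]
print(means.round(2), credible, bottom_2, top_2)
```

Ranking by the mean alone ignores uncertainty; checking the interval (or `stds`) alongside it guards against selecting features whose effect is poorly estimated.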

In [None]:
sorted_deworm_diff_means = deworm_diff_means.sortby(deworm_diff_means)
sorted_deworm_diff_means

In [None]:
bottom_10 = sorted_deworm_diff_means[:10].coords["feature"].values
top_10 = sorted_deworm_diff_means[-10:].coords["feature"].values

In [None]:
def log_ratio(table, top_feats, bot_feats):
    """Per-sample log of (sum of top features) / (sum of bottom features).

    Samples where either sum is zero are dropped; no pseudocount is added.
    """
    num_df = table.loc[:, top_feats].sum(axis=1).to_frame()
    num_df.columns = ["num"]
    num_df = num_df[num_df["num"] > 0]
    denom_df = table.loc[:, bot_feats].sum(axis=1).to_frame()
    denom_df.columns = ["denom"]
    denom_df = denom_df[denom_df["denom"] > 0]
    lr_df = num_df.join(denom_df, how="inner")
    lr_df["log_ratio"] = np.log(lr_df["num"] / lr_df["denom"])
    return lr_df

lr_df = log_ratio(filt_tbl_df.T, top_10, bottom_10).join(metadata_model, how="inner")
print(lr_df.shape)
lr_df.head()
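A small worked example of the log-ratio on made-up counts (samples as rows) shows the effect of dropping zero sums instead of adding a pseudocount: only samples with nonzero numerator and denominator survive. This matters for the paired plot, because a subject can lose one of its two time points this way.

```python
import numpy as np

# Hypothetical counts: rows are samples, columns are features
counts = np.array([
    [5, 2, 1, 0],
    [0, 0, 4, 4],
    [3, 1, 0, 0],
])
top = [0, 1]     # numerator features
bottom = [2, 3]  # denominator features

num = counts[:, top].sum(axis=1)       # [7, 0, 4]
denom = counts[:, bottom].sum(axis=1)  # [1, 8, 0]

# Keep samples where both sums are positive
keep = (num > 0) & (denom > 0)
log_ratios = np.log(num[keep] / denom[keep])
print(keep, log_ratios)
```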

Finally, we'll plot the log-ratio changes from pre-deworm to post-deworm and include `life_stage` as a factor in our visualization.

In [None]:
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import seaborn as sns

fig = plt.figure(figsize=(10, 8), facecolor="white")
gs = GridSpec(ncols=2, nrows=2, figure=fig)
ax1 = fig.add_subplot(gs[:, 0])
ax3 = fig.add_subplot(gs[1, 1])
ax2 = fig.add_subplot(gs[0, 1], sharex=ax3)
ax2.xaxis.set_visible(False)

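for subj_id, subj_df in lr_df.groupby("host_subject_id"):
    # Skip subjects that lost a time point to the zero filter in log_ratio
    if subj_df["time_point"].nunique() < 2:
        continue
    # Change in log-ratio from pre- to post-deworm for this subject
    _d = (
        subj_df.query("time_point == 'post-deworm'")["log_ratio"].iloc[0]
        - subj_df.query("time_point == 'pre-deworm'")["log_ratio"].iloc[0]
    )

    args = {
        "data": subj_df,
        "x": "time_point",
        "y": "log_ratio",
        "order": ["pre-deworm", "post-deworm"],  # keep x positions aligned with tick labels
        "color": "black" if _d > 0 else "red"
    }
    sns.pointplot(**args, ax=ax1)
    sns.pointplot(**args, ax=ax2 if subj_df["life_stage"].iloc[0] == "juvenile" else ax3)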

pre_deworm_lrs = lr_df.query("time_point == 'pre-deworm'")["log_ratio"].values
post_deworm_lrs = lr_df.query("time_point == 'post-deworm'")["log_ratio"].values

bplot1 = ax1.boxplot(
    [pre_deworm_lrs, post_deworm_lrs],
    positions=[-0.3, 1.3],
    patch_artist=True
)

bplot_dict = {"both": bplot1}
for life_stage, ax in zip(["juvenile", "adult"], [ax2, ax3]):
    _1 = lr_df.query(f"time_point == 'pre-deworm' & life_stage == '{life_stage}'")["log_ratio"].values
    _2 = lr_df.query(f"time_point == 'post-deworm' & life_stage == '{life_stage}'")["log_ratio"].values
    bplot = ax.boxplot([_1, _2], positions=[-0.3, 1.3], patch_artist=True)
    bplot_dict[life_stage] = bplot

for bplot in bplot_dict.values():
    for patch in bplot["boxes"]:
        patch.set_facecolor("lightgray")
        patch.set_linewidth(2)
    for med in bplot["medians"]:
        med.set_color("black")
    for flier in bplot["fliers"]:
        flier.set_markerfacecolor("lightgray")
        flier.set_linewidth(2)
    for cap in bplot["caps"]:
        cap.set_linewidth(2)
    for whisk in bplot["whiskers"]:
        whisk.set_linewidth(2)


for ax in [ax1, ax2, ax3]:
    ax.set_xticks([0, 1])
    ax.set_xlim([-0.5, 1.5])
    ax.set_xlabel("")
    ax.set_xticklabels(["Pre\nDeworm", "Post\nDeworm"], fontsize=16)
    ax.tick_params("y", labelsize=14)
    ax.set_ylim(ax1.get_ylim())
    ax.grid(axis="y")
    ax.set_axisbelow(True)
    
ax1.set_ylabel("Log-Ratio", fontsize=16)
ax1.set_title("All Subjects", fontsize=20)
ax2.set_title("Juvenile", fontsize=20)
ax3.set_title("Adult", fontsize=20)
for ax in [ax2, ax3]:
    ax.set_ylabel("")

decrease_lr = Line2D([0], [0], color="red", lw=4)
increase_lr = Line2D([0], [0], color="black", lw=4)

ax1.legend(
    handles=[decrease_lr, increase_lr],
    labels=["Decreased", "Increased"],
    ncol=1,
    loc="lower right",
    fontsize=14
)

plt.show()
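The per-subject change that drives the line colors in the plot can also be tabulated directly. This sketch uses a made-up table in the same long format as `lr_df` (the column names are taken from the notebook; the values are hypothetical): pivoting to one row per subject makes the pre/post pairing explicit, and `dropna()` removes subjects observed at only one time point.

```python
import pandas as pd

# Hypothetical log-ratio table in the same long format as lr_df
toy_lr = pd.DataFrame({
    "host_subject_id": ["S1", "S1", "S2", "S2", "S3"],
    "time_point": ["pre-deworm", "post-deworm", "pre-deworm", "post-deworm", "pre-deworm"],
    "life_stage": ["juvenile", "juvenile", "adult", "adult", "adult"],
    "log_ratio": [0.5, 1.25, 2.0, 1.0, 0.75],
})

# One row per subject, one column per time point
wide = toy_lr.pivot(index="host_subject_id", columns="time_point", values="log_ratio")
wide = wide.dropna()  # keep subjects observed at both time points (drops S3)
change = wide["post-deworm"] - wide["pre-deworm"]

# Attach life stage and label the direction of change
summary = change.rename("change").to_frame().join(
    toy_lr.groupby("host_subject_id")["life_stage"].first()
)
summary["direction"] = ["Increased" if d > 0 else "Decreased" for d in summary["change"]]
print(summary)
```

A table like this is convenient for counting how many subjects increased or decreased within each `life_stage`.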