In [None]:
import biom
import pandas as pd
import numpy as np
import xarray as xr
import glob

from birdman import NegativeBinomial

In [None]:
import cmdstanpy

# Install CmdStan only if needed: cmdstan_path() raises ValueError when no
# valid installation can be located, so an existing installation is left
# untouched and re-running this cell does not trigger another install attempt.
try:
    cmdstanpy.cmdstan_path()
except ValueError:
    cmdstanpy.install_cmdstan()

In [None]:
cmdstanpy.cmdstan_path()

In [None]:
# Locate the metadata template. glob returns matches in arbitrary order, so
# sort to make the choice deterministic, and fail with a clear message instead
# of a bare IndexError when nothing matches.
template_paths = sorted(glob.glob("templates/*.txt"))
if not template_paths:
    raise FileNotFoundError("No metadata file matching 'templates/*.txt'")
fpath = template_paths[0]

table = biom.load_table("BIOM/44773/otu_table.biom")

# Tab-separated metadata; the first column becomes the index
# (presumably sample IDs — confirm against the table's sample axis).
metadata = pd.read_csv(
    fpath,
    sep="\t",
    index_col=0
)

metadata.head()

In [None]:
table

In [None]:
# Prevalence = number of samples in which each feature has a non-zero count.
prevalence = table.to_dataframe().clip(upper=1).sum(axis=1)
features_to_keep = prevalence[prevalence >= 5].index.tolist()
# biom's Table.filter defaults to inplace=True, which would also strip the
# low-prevalence features from `table` itself. Ask for a filtered copy so the
# original table stays intact for later cells and re-runs.
table_filt = table.filter(features_to_keep, axis="observation", inplace=False)

In [None]:
# Negative binomial model on the prevalence-filtered table, with `diet` as the
# only term in the formula.
nb = NegativeBinomial(
    table=table_filt,
    formula="diet",
    metadata=metadata,
)

In [None]:
# Compile the model, then fit it with method="vi", requesting 500 draws.
nb.compile_model()
nb.fit_model(method="vi", num_draws=500)

In [None]:
inference = nb.to_inference()
inference

In [None]:
metadata['diet']

In [None]:
inference.posterior['covariate'][1].to_numpy()

In [None]:
inference.posterior

In [None]:
# Show the shape of each per-chain group of beta_var.
beta_by_chain = inference.posterior['beta_var'].groupby('chain')
for _chain_label, chain_beta in beta_by_chain:
    print(chain_beta.shape)

In [None]:
from birdman.transform import posterior_alr_to_clr

inference = nb.to_inference()

# In the default NB model only beta is in ALR coordinates, and its feature
# dimension is named 'feature_alr'. The conversion renames that dimension to
# 'feature' and swaps the old labels (every feature name except the first)
# for the complete list of feature names.
clr_posterior = posterior_alr_to_clr(
    inference.posterior,
    alr_params=["beta_var"],
    dim_replacement={"feature_alr": "feature"},
    new_labels=nb.feature_names,
)
inference.posterior = clr_posterior

In [None]:
import birdman.visualization as viz

# Plot the beta_var estimates for the `diet[T.DIO]` covariate.
ax = viz.plot_parameter_estimates(
    inference,
    parameter="beta_var",
    coords={"covariate": "diet[T.DIO]"},
)

# Sam's data

In [None]:
import os
os.getcwd()

In [None]:
# Load the metadata and feature table for this dataset. The active paths are
# relative to the notebook's working directory; the commented-out lines keep
# an alternative absolute location for each file.
fpath = "../birdman/metadata_2024Feb02.tsv"
# fpath = "/home/cys226/private/birdman/metadata_2024Feb02.tsv"
table = biom.load_table("../data/sam/genus-table-exported/feature-table.biom")
# table = biom.load_table("/home/cys226/private/birdman/feature-table.biom")
# Tab-separated metadata; the first column becomes the index.
metadata = pd.read_csv(
    fpath,
    sep="\t",
    index_col=0
)

metadata

In [None]:
table.to_dataframe()

In [None]:
# Disease labels that are all treated as type 2 diabetes.
diabetes_cat = ['T2D', 'Obesity/T2D', 'Type II Diabetes', 'Type 2 diabetes', 'Type_II_Diabetes']

# Healthy controls plus every T2D-labelled sample. The .copy() makes this an
# independent frame, so adding a column below is not an assignment on a slice
# of `metadata` (which triggers pandas' SettingWithCopyWarning).
is_healthy = metadata['Disease'] == 'Healthy'
is_t2d = metadata['Disease'].isin(diabetes_cat)
short_meta = metadata[is_healthy | is_t2d].copy()

# Binary label: 'Healthy' is kept, every diabetes spelling collapses to 'T2D'.
short_meta['T2D'] = short_meta['Disease'].where(short_meta['Disease'] == 'Healthy', 'T2D')

# Keep only the samples present in both the metadata and the feature table.
short_meta = short_meta.loc[short_meta.index.intersection(table.to_dataframe().columns)]

# biom's Table.filter defaults to inplace=True. Without inplace=False the two
# calls below would permanently shrink `table` (first its samples, then its
# features), and every later cell that reads `table` would silently work on
# the Healthy/T2D subset instead of the full table.
short_table = table.filter(short_meta.index.tolist(), inplace=False)

# Keep features that are non-zero in at least 5 of the retained samples.
prevalence = short_table.to_dataframe().clip(upper=1).sum(axis=1)
features_to_keep = prevalence[prevalence >= 5].index.tolist()
short_table_filt = short_table.filter(features_to_keep, axis='observation', inplace=False)
short_table_filt.to_dataframe()

In [None]:
metadata['Notes'].value_counts()

In [None]:
metadata[metadata['Disease'] == 'T2D']

In [None]:
metadata[metadata['Disease'] == 'T2D']['Study'].value_counts()

In [None]:
metadata[metadata['Study'] == 'RadwanGilfillan_et_al_2020']

In [None]:
metadata[(metadata['Primer'] == 'V4') & (metadata['Age'] == 'Adult')]

In [None]:
metadata[metadata['Disease_Type'] == 'Metabolic']['Disease'].value_counts()

In [None]:
metadata['Disease_Type'].value_counts()

In [None]:
metadata['Disease'].value_counts()

In [None]:
diabetes_cat = ['T2D', 'Obesity/T2D', 'Type II Diabetes', 'Type 2 diabetes', 'Type_II_Diabetes']
# Number of samples carrying any of the T2D labels. isin() simply ignores
# labels that do not occur in the data, whereas value_counts().loc[...] raises
# KeyError as soon as one listed label is absent.
metadata['Disease'].isin(diabetes_cat).sum()

In [None]:
metadata['Age'].value_counts()

In [None]:
metadata['Age2'].value_counts()

In [None]:
# Minimum number of samples a feature must appear in to be kept.
# 5 is the default; a stricter cut-off (e.g. 100) can be used instead to
# shrink the table for the sake of running the model.
min_prevalence = 5

prevalence = table.to_dataframe().clip(upper=1).sum(axis=1)
features_to_keep = prevalence[prevalence >= min_prevalence].index.tolist()
# biom's Table.filter defaults to inplace=True; request a copy so `table`
# keeps all of its features and this cell can be re-run with another cut-off.
table_filt = table.filter(features_to_keep, axis="observation", inplace=False)

In [None]:
table.to_dataframe().index

In [None]:
table_filt

In [None]:
# remove unnecessary metadata: keep only rows for samples present in the table.
# Intersecting first avoids a KeyError when a table sample has no metadata row
# (the T2D subset cell handles this the same way); rows stay in table order,
# so the result is unchanged whenever every sample does have metadata.
table_samples = table.to_dataframe().columns
short_meta = metadata.loc[table_samples.intersection(metadata.index)]
short_meta

In [None]:
# two cohorts: Disease_Type (Healthy vs Metabolic)
in_cohorts = short_meta['Disease_Type'].isin(['Healthy', 'Metabolic'])
short_meta = short_meta[in_cohorts]
short_meta

In [None]:
# remove non-healthy, non-metabolic samples from otu table
# (keeps only the sample IDs that remain in short_meta's index)
table_filt = table_filt.filter(short_meta.index.tolist())

In [None]:
# Negative binomial model with Disease_Type as the only formula term, fitted
# against the reduced metadata frame (short_meta) rather than the full one.
nb = NegativeBinomial(
    table=table_filt,
    formula="Disease_Type",
    metadata=short_meta
    # metadata=metadata,
)

In [None]:
nb.compile_model()

In [None]:
nb.fit_model(method="vi", num_draws=100)

In [None]:
inference = nb.to_inference()

In [None]:
inference.posterior['covariate'].to_numpy()[1]

In [None]:
# break down _beta_alr_to_clr

def _alr_to_clr(x: np.ndarray) -> np.ndarray:
    """Convert ALR coordinates to centered CLR coordinates.

    :param x: Matrix of ALR coordinates (features x draws)
    :type x: np.ndarray

    :returns: Matrix of centered CLR coordinates
    :rtype: np.ndarray
    """
    num_draws = x.shape[1]
    z = np.zeros((1, num_draws))
    x_clr = np.vstack((z, x))
    x_clr = x_clr - x_clr.mean(axis=0).reshape(1, -1)
    return x_clr


def _beta_alr_to_clr(beta: np.ndarray) -> np.ndarray:
    """Convert feature-covariate coefficients from ALR to CLR.

    :param beta: Matrix of beta ALR coordinates (n draws x p covariates x
        d features)
    :type beta: np.ndarray

    :returns: Matrix of beta CLR coordinates (n draws x p covariates x d+1
        features)
    :rtype: np.ndarray
    """
    num_draws, num_covariates, num_features = beta.shape
    beta_clr = np.zeros((num_draws, num_covariates, num_features+1))
    for i in range(num_covariates):  # TODO: vectorize
        beta_slice = beta[:, i, :].T  # features x draws
        beta_clr[:, i, :] = _alr_to_clr(beta_slice).T
    return beta_clr

In [None]:
# break down posterior_alr_to_clr
posterior = inference.posterior
alr_params = ["beta_var"]
dim_replacement = {"feature_alr": "feature"}
new_labels = nb.feature_names

new_posterior = posterior.copy()
for param in alr_params:
    param_da = posterior[param]

    # Convert each chain separately, then stack the per-chain results back
    # into a single array with the chain axis first.
    all_chain_clr_coords = np.array([
        _beta_alr_to_clr(chain_alr_coords[0])
        for _chain, chain_alr_coords in param_da.groupby("chain")
    ])

    # Same dimension order as before, with the ALR dimension renamed.
    new_dims = [dim_replacement.get(dim, dim) for dim in param_da.dims]

    # Renamed dimensions take the new labels; all others keep their
    # existing coordinate values.
    new_coords = {}
    for old_dim in param_da.dims:
        if old_dim not in dim_replacement:
            new_coords[old_dim] = param_da.coords.get(old_dim).data
            continue
        new_coords[dim_replacement[old_dim]] = new_labels

    new_posterior[param] = xr.DataArray(
        all_chain_clr_coords,
        dims=new_dims,
        coords=new_coords,
    )

# Drop the old ALR dimension names now that the converted parameters have
# replaced the originals.
new_posterior = new_posterior.drop_vars(dim_replacement.keys())
inference.posterior = new_posterior

In [None]:
import birdman.visualization as viz

# Plot the beta_var estimates for the `Disease_Type[T.Metabolic]` covariate.
ax = viz.plot_parameter_estimates(
    inference,
    parameter="beta_var",
    coords={"covariate": "Disease_Type[T.Metabolic]"},
)

In [None]:
# Mean of beta_var over chains and draws for the Metabolic covariate, with
# values and feature labels both ordered by ascending mean.
metabolic_beta = inference.posterior['beta_var'].sel(covariate='Disease_Type[T.Metabolic]')
mean_per_feature = metabolic_beta.mean(['chain', 'draw'])
sort_indices = mean_per_feature.argsort().data
param_means = mean_per_feature.data[sort_indices]
param_labels = inference.posterior['feature'].data[sort_indices]

In [None]:
param_means

In [None]:
# Tab-separated taxonomy table; the first column becomes the index.
# NOTE(review): this is an absolute, machine-specific path, so the cell only
# runs where the file exists at exactly this location — consider a path
# relative to the project instead.
taxonomy = pd.read_csv('/Users/candusshi/academics/qiime2-capstone/metadata/taxonomy.tsv', sep='\t', index_col=0)
taxonomy

In [None]:
pd.DataFrame({'feature': param_labels, 'mean clr': param_means})

In [None]:
table.to_dataframe().index

In [None]:
taxonomy.index.intersection(table.to_dataframe().index)

In [None]:
table.to_dataframe().index.intersection(taxonomy.index)

In [None]:
taxonomy.loc[param_labels[25]]