In [None]:
pip install rpy2 pandas numpy matplotlib seaborn scipy statsmodels

In [None]:
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin as pg
import rpy2.robjects as ro
import seaborn as sns
import statsmodels.formula.api as smf
from rpy2.robjects import pandas2ri
from scipy.spatial.distance import pdist, squareform
from scipy.stats import zscore
from skbio import TreeNode
from skbio.diversity import beta_diversity
from skbio.diversity import tree as skbio_tree  # NOTE(review): unused; confirm this submodule exists in the installed scikit-bio
from skbio.diversity.alpha import shannon, simpson, observed_otus  # NOTE(review): observed_otus is unused and may be absent in newer scikit-bio
from skbio.stats.distance import DistanceMatrix, mantel, permanova
from skbio.stats.ordination import pcoa
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_auc_score, roc_curve
)
from sklearn.model_selection import train_test_split, GridSearchCV, RepeatedStratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import label_binarize
from statsmodels.robust.robust_linear_model import RLM

from microbiomeutil import MicrobiomeDataset
# Turn on automatic pandas <-> R conversion, then attach the R packages used
# through rpy2 (MaAsLin2 itself is run in a later cell).
pandas2ri.activate()
for r_library_call in ("library(Maaslin2)", "library(dplyr)", "library(tibble)"):
    ro.r(r_library_call)

In [None]:
# Load inputs and build a MicrobiomeDataset
otu_table = pd.read_csv("otu_table.tsv", sep="\t", index_col=0)   # taxa x samples
taxonomy = pd.read_csv("taxonomy.tsv", sep="\t", index_col=0)     # taxa x taxonomy ranks
metadata = pd.read_csv("sample_metadata.tsv", sep="\t", index_col=0)
tree = TreeNode.read("tree.nwk")  # Newick tree for UniFrac

# Restrict all three tables to the taxa and samples they share, in one order.
# Both sides are trimmed: samples present in the OTU table but missing from the
# metadata (or taxa missing from the taxonomy) would otherwise leak into every
# downstream step.
shared_taxa = taxonomy.index.intersection(otu_table.index)
shared_samples = otu_table.columns.intersection(metadata.index)
otu_table = otu_table.loc[shared_taxa, shared_samples]
taxonomy = taxonomy.loc[shared_taxa]
metadata = metadata.loc[shared_samples]
assert not otu_table.empty, "no taxa/samples shared between otu_table, taxonomy and metadata"

# Create the MicrobiomeDataset object
ps = MicrobiomeDataset(
    feature_table=otu_table,
    taxonomy=taxonomy,
    sample_metadata=metadata,
    tree=tree
)

In [None]:
#QC

# Filtering thresholds
MIN_TAXON_READS = 10      # a taxon counts as present in a sample only above this many reads
MIN_TAXON_SAMPLES = 5     # ...and must be present in more than this many samples
MIN_SAMPLE_READS = 5000   # samples need more than this many reads after the prevalence filter

#Remove low abundance, low prevalence taxa
keep_taxa = (otu_table > MIN_TAXON_READS).sum(axis=1) > MIN_TAXON_SAMPLES
otu_table_filt = otu_table.loc[keep_taxa]

#Remove lowly covered samples, and any sample that has no metadata row
keep_samples = (
    (otu_table_filt.sum(axis=0) > MIN_SAMPLE_READS)
    & otu_table_filt.columns.isin(metadata.index)
)
otu_table_filt = otu_table_filt.loc[:, keep_samples]
# Select metadata by label so its rows follow the kept OTU-table samples exactly
metadata = metadata.loc[otu_table_filt.columns]

#Remove unwanted taxa
mask = (
    (taxonomy["Kingdom"] != "Eukaryota") &
    ((taxonomy["Order"].isna()) | (taxonomy["Order"] != "Chloroplast")) &
    ((taxonomy["Phylum"].isna()) | (taxonomy["Phylum"] != "Fibrobacteres"))
)

taxonomy_filt = taxonomy.loc[mask]

# Keep the feature table and taxonomy restricted to the same taxa, in the same order
otu_table_filt = otu_table_filt.loc[otu_table_filt.index.intersection(taxonomy_filt.index)]
taxonomy_filt = taxonomy_filt.loc[otu_table_filt.index]

# Rebuild MicrobiomeDataset
ps_filtered = MicrobiomeDataset(
    feature_table=otu_table_filt,
    taxonomy=taxonomy_filt,
    sample_metadata=metadata,
    tree=tree
)

In [None]:
#Principal Coordinate Analysis

# NOTE: `otu_table` is re-bound here to the samples x taxa transpose of the
# filtered feature table; every later cell indexes it with samples as rows.
otu_table = ps_filtered.feature_table.T
tree = ps_filtered.tree
sample_ids = otu_table.index

# Weighted UniFrac distances between samples. scikit-bio needs the taxon IDs
# (the columns of `counts`) together with the tree to place counts on tree tips.
# (Newer scikit-bio releases call this argument `taxa`.)
unifrac_dm = beta_diversity(
    metric='weighted_unifrac',
    counts=otu_table.values,
    ids=sample_ids,
    otu_ids=list(otu_table.columns),
    tree=tree
)

# Perform PCoA
pcoa_res = pcoa(unifrac_dm)

# Eigenvalues and explained variance
evals = pcoa_res.eigvals
explained = evals / evals.sum()

print("Explained variance (first 6 axes):", explained.head(6))

#Add PCoA axes (renamed PCoA1, PCoA2, ...) to metadata
pcoa_df = pcoa_res.samples.copy()
pcoa_df.columns = [f"PCoA{i+1}" for i in range(pcoa_df.shape[1])]
sampledf = ps_filtered.sample_metadata.join(pcoa_df)

# Scree plot of up to 20 axes (PCoA returns fewer axes when there are few samples)
n_scree = min(20, len(explained))
plt.figure(figsize=(6,4))
sns.barplot(x=np.arange(1, n_scree + 1), y=explained.values[:n_scree])
plt.xlabel("PCoA Axis")
plt.ylabel("Variance Explained")
plt.title("Weighted UniFrac PCoA Scree Plot")
plt.show()

# Ordination scatter plot
plt.figure(figsize=(8,6))
sns.scatterplot(
    data=sampledf,
    x='PCoA1', y='PCoA2',
    hue='Lifestyle',
    style='Sex',
    alpha=0.9,
    s=80
)
plt.axhline(0, color='gray', linestyle='--', lw=0.8)
plt.axvline(0, color='gray', linestyle='--', lw=0.8)
plt.title(f"Weighted UniFrac PCoA\n(PCo1: {explained.iloc[0]*100:.1f}%, PCo2: {explained.iloc[1]*100:.1f}%)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

# Beta-diversity association tests.
# pingouin has no PERMANOVA, so scikit-bio is used. Its permanova takes one
# categorical grouping at a time (no multi-term formula), so each categorical
# factor is tested on its own. Continuous covariates get a Mantel test between
# the UniFrac matrix and the pairwise absolute differences of the covariate.
N_PERMUTATIONS = 9999
CATEGORICAL_FACTORS = ["Lifestyle", "Sex", "Shelter"]
# "Longtitude" kept as spelled in the original code — confirm against sample_metadata.tsv
CONTINUOUS_FACTORS = ["Latitude", "Longtitude", "Altitude", "Age.C"]

sampledf = sampledf.loc[list(unifrac_dm.ids)]

beta_tests = []
for factor in CATEGORICAL_FACTORS:
    present = list(sampledf.index[sampledf[factor].notna()])
    if sampledf.loc[present, factor].nunique() < 2:
        print(f"Skipping PERMANOVA for {factor}: fewer than two groups")
        continue
    res = permanova(
        unifrac_dm.filter(present), sampledf, column=factor,
        permutations=N_PERMUTATIONS
    )
    beta_tests.append({
        "factor": factor, "test": "PERMANOVA",
        "statistic": res["test statistic"], "p_value": res["p-value"],
        "n": res["sample size"],
    })

for factor in CONTINUOUS_FACTORS:
    values = pd.to_numeric(sampledf[factor], errors="coerce").dropna()
    if len(values) < 3:
        print(f"Skipping Mantel test for {factor}: fewer than three samples with values")
        continue
    covariate_dm = DistanceMatrix(
        squareform(pdist(values.to_numpy().reshape(-1, 1))), ids=list(values.index)
    )
    r, p, n = mantel(
        unifrac_dm.filter(list(values.index)), covariate_dm,
        method="spearman", permutations=N_PERMUTATIONS
    )
    beta_tests.append({
        "factor": factor, "test": "Mantel (Spearman)",
        "statistic": r, "p_value": p, "n": n,
    })

permanova_results = pd.DataFrame(beta_tests)
print(permanova_results)

In [None]:
#Alpha diversity

# Per-sample alpha diversity. `otu_table` is samples x taxa here (transposed in
# the PCoA cell). Counts are not rarefied, so richness also reflects depth.
alpha_div = pd.DataFrame(
    {
        "Shannon": otu_table.apply(lambda x: shannon(x.values, base=2), axis=1),
        "Simpson": otu_table.apply(lambda x: simpson(x.values), axis=1),
        "Richness": (otu_table > 0).sum(axis=1),   # number of observed taxa
    },
    index=otu_table.index,
)

# Combine with metadata
alpha_df = metadata.join(alpha_div)


def plot_alpha_by_lifestyle(df, metric, palette, title, ylabel):
    """Box plot of one alpha-diversity column by Lifestyle with samples overlaid.

    Args:
        df: frame with a "Lifestyle" column and the `metric` column.
        metric: name of the column to plot on the y axis.
        palette: seaborn palette name for the boxes.
        title, ylabel: figure title and y-axis label.
    """
    plt.figure(figsize=(8,5))
    sns.boxplot(data=df, x="Lifestyle", y=metric, palette=palette, fliersize=3)
    sns.stripplot(data=df, x="Lifestyle", y=metric, color="black", size=3, alpha=0.5)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel("Lifestyle")
    plt.tight_layout()
    plt.show()


plot_alpha_by_lifestyle(
    alpha_df, "Shannon", "viridis",
    "Shannon Diversity by Lifestyle", "Shannon Index (bits)"
)
plot_alpha_by_lifestyle(
    alpha_df, "Richness", "magma",
    "Observed Species Richness by Lifestyle", "Number of Taxa Observed"
)

# Export for downstream stats
alpha_df.to_csv("alpha_diversity_metrics.tsv", sep="\t")

#Statistical Tests
# Kruskal-Wallis test of Shannon diversity across Lifestyle groups
kruskal_res = pg.kruskal(data=alpha_df, dv='Shannon', between='Lifestyle')
print(kruskal_res)

# Robust linear model. Q() quotes "Age.C": unquoted, the formula parser would
# read the dot as attribute access on a variable named `Age`.
model_rlm = smf.rlm(
    formula='Shannon ~ C(Lifestyle) + Q("Age.C") + Altitude + Latitude + C(Sex)',
    data=alpha_df
).fit()

print(model_rlm.summary())

In [None]:
#Visualize the most abundant taxa

# `otu_table` is samples x taxa here (transposed in the PCoA cell), so taxa are
# columns and per-taxon statistics use axis=0.
assert otu_table.index.isin(metadata.index).all(), "expected otu_table with samples as rows"
TOP_N_TAXA = 20

# Per-sample relative abundance, so deep samples do not dominate group means and
# the y axis really is "relative abundance"
rel_abundance = otu_table.div(otu_table.sum(axis=1), axis=0)
mean_abundance = rel_abundance.mean(axis=0)            # mean relative abundance per taxon
top20_taxa = mean_abundance.nlargest(TOP_N_TAXA).index
otu_top20 = rel_abundance[top20_taxa]

# Long format: one row per (sample, taxon)
otu_long = (
    otu_top20
    .rename_axis("SampleID")
    .reset_index()
    .melt(id_vars="SampleID", var_name="Taxon", value_name="Abundance")
)

otu_long = otu_long.merge(metadata, left_on='SampleID', right_index=True)

otu_long_grouped = (
    otu_long
    .groupby(['Lifestyle','Taxon'])['Abundance']
    .mean()
    .reset_index()
)

plt.figure(figsize=(10,5))
sns.barplot(
    data=otu_long_grouped,
    x='Lifestyle',
    y='Abundance',
    hue='Taxon'
)

plt.ylabel("Mean Relative Abundance")
plt.title(f"Top {TOP_N_TAXA} Taxa by Lifestyle")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
# Random Forest
# Classify Lifestyle from per-sample taxon counts (otu_table rows are samples).

# The label column is aligned to the count rows by sample ID.
data = otu_table.copy()
data["Lifestyle"] = metadata["Lifestyle"]
y = data["Lifestyle"]
X = data.drop(columns="Lifestyle")

# Stratified 80/20 hold-out split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=100, stratify=y, test_size=0.2
)

# Grid search on the training split only: 5-fold stratified CV, repeated 3 times.
param_grid = {
    "n_estimators": [100, 300, 500],
    "max_features": ["sqrt", "log2", None],
}
rf = RandomForestClassifier(random_state=100)
cv = RepeatedStratifiedKFold(n_repeats=3, n_splits=5, random_state=100)
grid = GridSearchCV(
    estimator=rf,
    param_grid=param_grid,
    scoring="accuracy",
    cv=cv,
    verbose=1,
    n_jobs=-1,
)
grid.fit(X_train, y_train)

rf_best = grid.best_estimator_
print("Best RF model:", grid.best_params_)

# Held-out evaluation.
y_pred = rf_best.predict(X_test)
print("\nConfusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))

#  ROC curves (one-vs-rest, one curve per Lifestyle class)
# Build the class-indicator matrix directly from rf_best.classes_. label_binarize
# returns a single column when there are only two classes, which would make
# y_test_bin[:, 1] raise IndexError.
y_test_bin = (
    np.asarray(y_test)[:, None] == np.asarray(rf_best.classes_)[None, :]
).astype(int)
y_score = rf_best.predict_proba(X_test)   # columns follow rf_best.classes_

plt.figure(figsize=(4,4))
colors = ["#F60239","#008607","#00DBC5","#003C86","#9400E6"]

for i, cls in enumerate(rf_best.classes_):
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_score[:, i])
    auc = roc_auc_score(y_test_bin[:, i], y_score[:, i])
    # Cycle the palette so more than five classes do not raise IndexError
    plt.plot(fpr, tpr, color=colors[i % len(colors)], label=f"{cls} (AUC={auc:.2f})")

plt.plot([0,1],[0,1],"k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(title="Lifestyle", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
# bbox_inches="tight" keeps the outside legend inside the saved PDF
plt.savefig("AUC_Lifestyle.pdf", bbox_inches="tight")
plt.show()

# Variable importance from the tuned forest
IMPORTANCE_THRESHOLD = 0.01   # features above this are treated as "important" below
N_TOP_PLOTTED = 30            # number of features shown in the strip plot

importances = pd.DataFrame({
    "ASV": X.columns,
    "Importance": rf_best.feature_importances_
}).sort_values("Importance", ascending=False)

plt.figure(figsize=(5,6))
sns.stripplot(data=importances.head(N_TOP_PLOTTED), x="Importance", y="ASV", size=4)
plt.axvline(x=IMPORTANCE_THRESHOLD, color="red", linestyle="--")
plt.title("Top Variable Importances")
plt.tight_layout()
plt.show()

# Keep features above the same threshold that is drawn on the plot
imp_features = importances.loc[importances["Importance"] > IMPORTANCE_THRESHOLD, "ASV"]
print("Top features:", imp_features.head())

# Taxonomy annotation table: one row per taxon, with every taxonomy column joined
# by "-" into TaxaName (missing values shown as "__"), and the taxon ID moved
# from the index into an "ASV" column.
tax_df = taxonomy.copy()
filled_ranks = tax_df.fillna("__")
tax_df["TaxaName"] = filled_ranks.apply(lambda ranks: "-".join(ranks.astype(str)), axis=1)
tax_df = tax_df.assign(ASV=tax_df.index).reset_index(drop=True)
print(tax_df.head())

# Heatmap of important taxa
# Mean relative abundance of each RF-important taxon per Lifestyle, z-scored
# across Lifestyle groups within each taxon (row).
otu_rel = otu_table.div(otu_table.sum(axis=1), axis=0)   # samples x taxa, rows sum to 1
otu_imp = otu_rel[imp_features]

if otu_imp.shape[1] == 0:
    print("No features above the importance threshold; skipping heatmap.")
else:
    # Mean abundance by Lifestyle -> taxa x Lifestyle
    lifestyle = metadata["Lifestyle"].reindex(otu_imp.index)
    otu_mean = (
        otu_imp.groupby(lifestyle).mean()
        .T
        .rename_axis(index="OTU", columns="Lifestyle")
    )

    # Row z-score with population SD (scipy.stats.zscore's default), computed in
    # pandas so the result is a labelled DataFrame regardless of SciPy version.
    # Rows that are constant across groups give 0/0 -> NaN; show them as 0 so
    # clustering does not fail on non-finite values.
    otu_z = (
        otu_mean.sub(otu_mean.mean(axis=1), axis=0)
        .div(otu_mean.std(axis=1, ddof=0), axis=0)
        .fillna(0)
    )

    g = sns.clustermap(
        otu_z,
        cmap="YlOrRd",
        figsize=(8,6),
        row_cluster=otu_z.shape[0] > 1,   # hierarchical clustering needs >= 2 rows
        col_cluster=False
    )
    # Title the whole figure; plt.title would land on the last-created axes
    g.fig.suptitle("Top Variable Taxa (z-score normalized)")
    plt.show()

In [None]:
#Maaslin

# MaAsLin2 reads its inputs from disk, so stage them as TSVs in the output folder.
outdir = "maaslin2_output"
Path(outdir).mkdir(exist_ok=True)

df_input_metadata = metadata.copy()
# Feature table as taxa x samples; coerce every cell to a number (non-numeric
# -> NaN -> 0) and add a pseudocount of 1.
df_input_data = (
    otu_table.T.copy()
    .apply(pd.to_numeric, errors='coerce')
    .fillna(0)
    + 1
)

data_csv = Path(outdir) / "maaslin_input.tsv"
meta_csv = Path(outdir) / "maaslin_meta.tsv"
for frame, destination in ((df_input_data, data_csv), (df_input_metadata, meta_csv)):
    frame.to_csv(destination, sep="\t", index=True)

# Prepare R call
# The paths are interpolated into R source code below, so pass them with
# forward slashes: on Windows str(Path) uses backslashes, which R would parse
# as escape sequences inside its string literals.
r_data = Path(data_csv).as_posix()
r_meta = Path(meta_csv).as_posix()
r_out = (Path(outdir) / "maaslin_res").as_posix()

# Negative-binomial model on CSS-normalised counts, no prevalence filter,
# Lifestyle as the only fixed effect with "1.Foragers" as the reference level.
r_cmd = f'''
fit_data2 <- Maaslin2(
  input_data = "{r_data}",
  input_metadata = "{r_meta}",
  min_prevalence = 0,
  analysis_method = "NEGBIN",
  normalization = "CSS",
  transform = 'NONE',
  output = "{r_out}",
  fixed_effects = c("Lifestyle"),
  reference = c("Lifestyle,1.Foragers")
)
'''
ro.r(r_cmd)

# After MaAsLin2 completes, read the all_results.tsv produced by it

def _to_r_feature_name(name) -> str:
    """Return a canonical, R-style key for a feature ID.

    Approximates R's make.names: prefix "X" unless the ID starts with a letter
    or with "." not followed by a digit, and map characters outside
    [A-Za-z0-9._] to ".". It additionally maps "_" to ".", as the earlier
    normalisation in this notebook did. The mapping is idempotent, so raw
    OTU IDs and R-mangled names from all_results.tsv reduce to the same key.
    """
    name = str(name)
    if not re.match(r"[A-Za-z]|\.(?![0-9])", name):
        name = "X" + name
    return re.sub(r"[^A-Za-z0-9.]", ".", name)


res_file = Path(r_out) / "all_results.tsv"
df_res = pd.read_csv(res_file, sep="\t")
df_res["feature"] = df_res["feature"].map(_to_r_feature_name)

# Join taxonomy on the taxon ID. MaAsLin2's `feature` comes from the row names of
# the input table (the OTU-table index), which is also tax_df["ASV"]. Both sides
# go through the same normalisation so R-side renaming cannot break the join.
tax_join = tax_df.replace({np.nan: "__"})
tax_join["feature"] = tax_join["ASV"].map(_to_r_feature_name)
df_res = df_res.merge(tax_join, how="left", on="feature")
df_res["Phylum"] = df_res["Phylum"].fillna("__")   # features with no taxonomy match

# Very strong hits (-log10(qval) > 40); kept for inspection, not used in this cell
df_res["neglog10q"] = -np.log10(df_res["qval"] + 1e-300)
df_subset = df_res[df_res["neglog10q"] > 40]

# Plot volcano-like plot (coef vs -log10 qval), one colour per phylum
phyla = df_res["Phylum"].unique()
palette = dict(zip(phyla, sns.color_palette("tab10", n_colors=len(phyla))))
plt.figure(figsize=(6,4))
sns.scatterplot(data=df_res, x='coef', y='neglog10q', hue='Phylum', palette=palette, s=40)
plt.axhline(-np.log10(0.05), linestyle='--', color='red')
plt.axvline(-1, linestyle='--', color='red')
plt.axvline(1, linestyle='--', color='red')
plt.title("MaAsLin2: coefficient vs -log10(qval)")
plt.tight_layout()
plt.savefig("Maaslin_lifestyle_volcano.png", dpi=300)
plt.show()

# Identify significant features (qval < 0.05 & abs(coef)>1)
df_sig = df_res[(df_res['qval'] < 0.05) & (df_res['coef'].abs() > 1)]
sig_features = df_sig['feature'].unique().tolist()

# Heatmap of mean relative abundance per Lifestyle for the significant features.
# `otu_table` is samples x taxa here (transposed in the PCoA cell), so the taxon
# IDs are its columns; they are normalised like the MaAsLin2 feature names.
assert otu_table.index.isin(df_input_metadata.index).all(), "expected otu_table with samples as rows"
otu_for_merge = otu_table.div(otu_table.sum(axis=1), axis=0)
otu_for_merge.columns = otu_for_merge.columns.map(_to_r_feature_name)

otu_sig = otu_for_merge.loc[:, otu_for_merge.columns.isin(sig_features)]
if otu_sig.shape[1] == 0:
    print("No significant features found to plot heatmap.")
else:
    lifestyle = df_input_metadata["Lifestyle"].reindex(otu_sig.index)
    mean_by_group = otu_sig.groupby(lifestyle).mean().T  # features x Lifestyle

    # Z-score rows (features) with population SD, as scipy.stats.zscore does by
    # default; constant rows give 0/0 -> NaN and are shown as 0.
    mean_by_group_z = (
        mean_by_group.sub(mean_by_group.mean(axis=1), axis=0)
        .div(mean_by_group.std(axis=1, ddof=0), axis=0)
        .fillna(0)
    )

    g = sns.clustermap(
        mean_by_group_z,
        cmap="YlOrRd",
        col_cluster=False,
        row_cluster=mean_by_group_z.shape[0] > 1,   # clustering needs >= 2 rows
        figsize=(8, max(4, 0.2*mean_by_group_z.shape[0]))
    )
    g.fig.suptitle("MaAsLin2 significant features (z-scored means by Lifestyle)")
    g.savefig("Maaslin_heatmap.png", dpi=300)
    plt.show()