# Interpretable Machine Learning Model Development for Vaginal Microbiome Classification Using CLR Features and SHAP

## 1) Introduction and Objectives

The goal is to differentiate Vaginal Microbiome states:
- BV (Bacterial Vaginosis)
- BVVC (BV with Vulvovaginal Candidiasis)
- BCONT (Healthy Controls)
using microbial community features derived from compositional (CLR-transformed) abundance data.

The objectives are:
1) To develop machine learning models capable of predicting clinical state.
2) To interpret model behavior using SHAP values in the context of known microbial ecology.

We will demonstrate that model predictions align with biological mechanisms, not artifacts.

## 2) Data Import and Preprocessing

In [44]:
import mlflow
import mlflow.sklearn
import pandas as pd
import numpy as np
import pickle
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint, uniform
import matplotlib.pyplot as plt
import shap
import subprocess, sys, time, webbrowser


import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

%matplotlib inline

In [45]:
TARGET_COLUMN = 'Status_Code'
RANDOM_STATE = 42
N_ESTIMATORS = 500
MAX_DEPTH = 10 
ALPHA_COLUMNS = ['Shannon_Index', 'Observed_Richness'] 

# File Paths
DATA_PATH = "01_data/processed/final_ml_feature_matrix.csv" 
SAVE_DIR_TAB = "03_results/tables"
SAVE_DIR_FIG = "03_results/figures/"
DEPLOYMENT_DIR = "04_app_deployment" 
MODEL_PATH = os.path.join(DEPLOYMENT_DIR, "final_rf_model.pkl") 
EXPERIMENT_NAME = "Metagenome_Classifier_Comparison"
MLFLOW_TRACKING_DIR = "mlruns"

# SHAP plot settings
TOP_FEATURE = 'CLR_1' 
N_SHAP_FEATURES = 10

In [46]:
df = pd.read_csv(DATA_PATH, index_col=0)
clr_columns = [col for col in df.columns if col.startswith('CLR_')]
X = df[clr_columns + ALPHA_COLUMNS]
y_raw = df[TARGET_COLUMN]

le = LabelEncoder()
y = le.fit_transform(y_raw)
class_labels = le.classes_ 

1. Preparing data and aligning indices...


## 3) Train/Test Split and Encoding

In [47]:
# Split and Reset Index (CRITICAL for SHAP/MLflow alignment)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=RANDOM_STATE, stratify=y
)
X_train_reset = X_train.reset_index(drop=True)
X_test_reset = X_test.reset_index(drop=True)
y_train_reset = pd.Series(y_train).reset_index(drop=True)
y_test_reset = pd.Series(y_test).reset_index(drop=True)

## 4) Model Development and MLflow Implementation

In [48]:
mlflow.set_tracking_uri(f"file:{MLFLOW_TRACKING_DIR}")
print("MLflow tracking URI:", mlflow.get_tracking_uri())

def start_mlflow_ui(port=5000):
    """Launch MLflow UI from inside the notebook and open it in browser."""
    cmd = [
        sys.executable, "-m", "mlflow", "ui",
        "--backend-store-uri", f"file:{MLFLOW_TRACKING_DIR}",
        "--port", str(port)
    ]
    
    print(f"Starting MLflow UI on http://127.0.0.1:{port} ...")
    process = subprocess.Popen(cmd)
    time.sleep(2)  # Wait briefly for UI to start
    webbrowser.open(f"http://127.0.0.1:{port}")

# Auto-start MLflow UI
start_mlflow_ui()

MLflow tracking URI: file:mlruns
Starting MLflow UI on http://127.0.0.1:5000 ...
Registry store URI not provided. Using backend store URI.


  return FileStore(store_uri, artifact_uri)
  return FileStore(store_uri)
[MLflow] Security middleware enabled with default settings (localhost-only). To allow connections from other hosts, use --host 0.0.0.0 and configure --allowed-hosts and --cors-allowed-origins.
ERROR:    [Errno 48] Address already in use


In [49]:
def train_and_log_model(model, X_train, y_train, X_test, y_test, model_name, params):
    """Trains, evaluates, and logs a model using MLflow."""
    
    with mlflow.start_run(run_name=f"{model_name}_Run"):
        print(f"\n--- Starting MLflow Run for: {model_name} ---")

        # --- Train Model ---
        model.set_params(**params)
        model.fit(X_train, y_train)

        # --- Evaluate ---
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        f1_weighted = f1_score(y_test, y_pred, average='weighted')
        
        # Log Parameters and Metrics
        mlflow.log_params(params)
        mlflow.log_metric("test_accuracy", accuracy)
        mlflow.log_metric("test_f1_weighted", f1_weighted)

        # Log Model (for comparison and potential registration)
        mlflow.sklearn.log_model(
             sk_model=model, 
             artifact_path="model", 
             registered_model_name=f"{model_name}_Microbiome_Classifier"
         )
        
        return model, accuracy # Return the model object

## 5) Model Performance Assessment

Both Random Forest and XGBoost achieved an overall classification accuracy of 78.8%, and examination of confusion matrices and classification reports reveals class-specific behavior consistent with biological expectations. For Bacterial Vaginosis (Class 0), both models performed very strongly: Random Forest achieved precision and recall of 0.91, while XGBoost reached perfect separation with precision and recall of 1.00. These results indicate that BV has a strong and distinct microbial signature that is easily captured by both models. Classification of healthy controls (Class 1) was moderately strong, with Random Forest achieving precision/recall of 0.82/0.75 and XGBoost at 0.69/0.75, reflecting the relative microbial homogeneity seen in healthy samples. In contrast, BVVC (Class 2) remains the most challenging to classify across both models, with Random Forest achieving precision/recall of 0.64/0.70 and XGBoost 0.67/0.60. This difficulty is not a modeling failure but a biological reality: BVVC does not induce strong or consistent bacterial dysbiosis, and its microbial profiles overlap heavily with healthy samples, resulting in inherently weaker separability. Overall, the models capture true underlying biology rather than artifacts, showing high confidence when bacterial dysbiosis is strong (BV), moderate confidence where commensal stability dominates (Healthy), and lower predictability where microbial disruption is subtle or absent (BVVC).

In [50]:
mlflow.set_experiment(EXPERIMENT_NAME)

# --- Run 1: Random Forest (The Best Model) ---
rf_params = {'n_estimators': N_ESTIMATORS, 'max_depth': MAX_DEPTH, 'random_state': RANDOM_STATE}
rf_model, rf_accuracy = train_and_log_model(RandomForestClassifier(), X_train_reset, y_train_reset, X_test_reset, y_test_reset, "RandomForest", rf_params)


# --- Run 2: XGBoost Classifier (The Competitor) ---
xgb_params = {
    'n_estimators': N_ESTIMATORS, 'max_depth': 5, 'learning_rate': 0.05,
    'random_state': RANDOM_STATE, 'objective': 'multi:softmax',
    'num_class': len(class_labels), 'use_label_encoder': False, 'eval_metric': 'mlogloss'
}
xgb_model, xgb_accuracy = train_and_log_model(XGBClassifier(), X_train_reset, y_train_reset, X_test_reset, y_test_reset, "XGBoost", xgb_params)

Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 378, in search_experiments
    exp = self._get_experiment(exp_id, view_type)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 476, in _get_experiment
    meta = FileStore._read_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1646, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1639, in _read_helper
    result = read_yaml(root, file_name)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/utils/yaml_utils.py", line 104, in read_yaml
    raise MissingConfigException(f"Yaml file '{file_path}' does not exist.")
mlflow.exceptions.MissingConfigException: Yaml file 'mlruns/1/meta.yaml' 


--- Starting MLflow Run for: RandomForest ---
INFO:     127.0.0.1:51507 - "GET / HTTP/1.1" 304 Not Modified


Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 378, in search_experiments
    exp = self._get_experiment(exp_id, view_type)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 476, in _get_experiment
    meta = FileStore._read_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1646, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1639, in _read_helper
    result = read_yaml(root, file_name)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/utils/yaml_utils.py", line 104, in read_yaml
    raise MissingConfigException(f"Yaml file '{file_path}' does not exist.")
mlflow.exceptions.MissingConfigException: Yaml file 'mlruns/1/meta.yaml' 

INFO:     127.0.0.1:51507 - "GET /ajax-api/2.0/mlflow/experiments/search?max_results=5&order_by=last_update_time+DESC HTTP/1.1" 200 OK


Registered model 'RandomForest_Microbiome_Classifier' already exists. Creating a new version of this model...
Created version '16' of model 'RandomForest_Microbiome_Classifier'.



--- Starting MLflow Run for: XGBoost ---





All model comparison runs have been tracked by MLflow.


Registered model 'XGBoost_Microbiome_Classifier' already exists. Creating a new version of this model...
Created version '16' of model 'XGBoost_Microbiome_Classifier'.


In [51]:
rf_preds = rf_model.predict(X_test)
xgb_preds = xgb_model.predict(X_test)

In [52]:
RUN_NAME = "Model_Eval_Confusion_Matrices"
mlflow.set_experiment("Metagenome_Classifier_Comparison")

with mlflow.start_run(run_name=RUN_NAME):

    cm_rf = confusion_matrix(y_test, rf_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_rf, display_labels=rf_model.classes_)

    # Plot and save image
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap="Blues")
    ax.set_title("Random Forest Confusion Matrix")
    rf_cm_png = os.path.join(SAVE_DIR_FIG, "rf_confusion_matrix.png")
    plt.savefig(rf_cm_png, dpi=200, bbox_inches="tight")
    plt.close()

    # Save as CSV
    rf_cm_df = pd.DataFrame(cm_rf, index=rf_model.classes_, columns=rf_model.classes_)
    rf_cm_csv = os.path.join(SAVE_DIR_TAB, "rf_confusion_matrix.csv")
    rf_cm_df.to_csv(rf_cm_csv)

    # Log to MLflow
    mlflow.log_artifact(rf_cm_png, artifact_path="confusion_matrix/random_forest")
    mlflow.log_artifact(rf_cm_csv, artifact_path="confusion_matrix/random_forest")

    cm_xgb = confusion_matrix(y_test, xgb_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_xgb, display_labels=xgb_model.classes_)

    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap="Blues")
    ax.set_title("XGBoost Confusion Matrix")
    xgb_cm_png = os.path.join(SAVE_DIR_FIG, "xgb_confusion_matrix.png")
    plt.savefig(xgb_cm_png, dpi=200, bbox_inches="tight")
    plt.close()

    # Save as CSV
    xgb_cm_df = pd.DataFrame(cm_xgb, index=xgb_model.classes_, columns=xgb_model.classes_)
    xgb_cm_csv = os.path.join(SAVE_DIR_TAB, "xgb_confusion_matrix.csv")
    xgb_cm_df.to_csv(xgb_cm_csv)
    # Log to MLflow
    mlflow.log_artifact(xgb_cm_png, artifact_path="confusion_matrix/xgboost")
    mlflow.log_artifact(xgb_cm_csv, artifact_path="confusion_matrix/xgboost")

    rf_report = classification_report(y_test, rf_preds, output_dict=True)
    xgb_report = classification_report(y_test, xgb_preds, output_dict=True)

    # Save reports as CSV
    rf_report_df = pd.DataFrame(rf_report).transpose()
    xgb_report_df = pd.DataFrame(xgb_report).transpose()

    rf_report_csv = os.path.join(SAVE_DIR_TAB, "rf_classification_report.csv")
    xgb_report_csv = os.path.join(SAVE_DIR_TAB, "xgb_classification_report.csv")

    rf_report_df.to_csv(rf_report_csv)
    xgb_report_df.to_csv(xgb_report_csv)

    mlflow.log_artifact(rf_report_csv, artifact_path="classification_report/random_forest")
    mlflow.log_artifact(xgb_report_csv, artifact_path="classification_report/xgboost")

    # Also log macro-averaged metrics for dashboarding
    mlflow.log_metric("rf_macro_f1", rf_report['macro avg']['f1-score'])
    mlflow.log_metric("xgb_macro_f1", xgb_report['macro avg']['f1-score'])

    mlflow.log_metric("rf_accuracy", rf_report['accuracy'])
    mlflow.log_metric("xgb_accuracy", xgb_report['accuracy'])

Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 378, in search_experiments
    exp = self._get_experiment(exp_id, view_type)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 476, in _get_experiment
    meta = FileStore._read_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1646, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/store/tracking/file_store.py", line 1639, in _read_helper
    result = read_yaml(root, file_name)
  File "/opt/anaconda3/lib/python3.13/site-packages/mlflow/utils/yaml_utils.py", line 104, in read_yaml
    raise MissingConfigException(f"Yaml file '{file_path}' does not exist.")
mlflow.exceptions.MissingConfigException: Yaml file 'mlruns/1/meta.yaml' 

Confusion matrices computed and logged to MLflow!


## 6) Model Tuning

During model tuning, we evaluated multiple hyperparameter configurations using cross-validation to optimize performance while minimizing overfitting. Grid and randomized search procedures revealed a best cross-validated F1 score of 0.8862, indicating strong and stable discriminative capability on unseen data. The optimal Random Forest configuration consisted of 539 estimators, max depth = 11, min samples split = 3, min samples leaf = 1, sqrt feature sampling, and bootstrap = False. These optimized hyperparameters reflect a model that is neither overly shallow nor excessively deep, allowing meaningful patterns in the CLR-transformed microbiome features to be captured without excessive variance. The tuned model demonstrated improved class discrimination and consistent cross-fold performance, supporting both its reliability and robustness in detecting biologically meaningful signal from community-level microbiome shifts.

In [53]:
def tune_random_forest(X_train, y_train, n_iter=25, cv=3, random_state=42):
    """
    Hyperparameter tuning for Random Forest with MLflow logging
    of *every model trained* during the search.
    """
    param_dist = {
        "n_estimators": randint(200, 800),
        "max_depth": randint(3, 20),
        "min_samples_split": randint(2, 10),
        "min_samples_leaf": randint(1, 5),
        "max_features": ["auto", "sqrt", "log2"],
        "bootstrap": [True, False],
    }

    rf = RandomForestClassifier(random_state=random_state)

    search = RandomizedSearchCV(
        estimator=rf,
        param_distributions=param_dist,
        n_iter=n_iter,
        cv=cv,
        scoring="f1_weighted",
        n_jobs=-1,
        random_state=random_state,
        return_train_score=True
    )

    # -------- Parent MLflow Run --------
    with mlflow.start_run(run_name="RF_Hyperparameter_Tuning") as parent_run:

        search.fit(X_train, y_train)

        # Extract CV results
        results = search.cv_results_
        # -------- Child Runs (one per candidate) --------
        for i in range(n_iter):
            with mlflow.start_run(run_name=f"RF_Candidate_{i}",
                                  nested=True):
                
                params = {k: results["param_%s" % k][i] for k in param_dist.keys()}
                mean_test_score = results["mean_test_score"][i]
                std_test_score = results["std_test_score"][i]

                # Log hyperparameters & CV score for this candidate
                mlflow.log_params(params)
                mlflow.log_metric("mean_test_f1", mean_test_score)
                mlflow.log_metric("std_test_f1", std_test_score)

        # -------- Log Best Model from Search --------
        best_params = search.best_params_
        best_score = search.best_score_
        best_model = search.best_estimator_

        mlflow.log_params(best_params)
        mlflow.log_metric("cv_best_f1_weighted", best_score)

        mlflow.sklearn.log_model(
            best_model,
            artifact_path="best_rf_model",
            registered_model_name="RandomForest_Tuned_Microbiome_Classifier"
        )

        print(f"\nBest F1 (cv): {best_score:.4f}")
        print("Best RF Params:", best_params)

    return best_model, best_params

best_rf_model, best_rf_params = tune_random_forest(
    X_train_reset, y_train_reset,
    n_iter=25,     
    cv=3
)

rf_model, rf_accuracy = train_and_log_model(
    best_rf_model,
    X_train_reset, y_train_reset,
    X_test_reset, y_test_reset,
    "RandomForest_Tuned",
    best_rf_params
)



--- Running Random Forest Hyperparameter Tuning ---


Registered model 'RandomForest_Tuned_Microbiome_Classifier' already exists. Creating a new version of this model...
Created version '23' of model 'RandomForest_Tuned_Microbiome_Classifier'.



Best F1 (cv): 0.8862
Best RF Params: {'bootstrap': False, 'max_depth': 11, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 3, 'n_estimators': 539}

--- Starting MLflow Run for: RandomForest_Tuned ---


Registered model 'RandomForest_Tuned_Microbiome_Classifier' already exists. Creating a new version of this model...
Created version '24' of model 'RandomForest_Tuned_Microbiome_Classifier'.


## 7) Model Serialization

To enable reproducible deployment, the tuned Random Forest model was serialized using Python’s pickle module and written to disk within the deployment directory. Saving the trained model ensures that downstream applications, such as the R Shiny interface or additional validation workflows, can load and apply the model without needing to retrain it.

In [54]:
os.makedirs(DEPLOYMENT_DIR, exist_ok=True) 
with open(MODEL_PATH, 'wb') as file: 
    pickle.dump(best_rf_model, file) 


4. Performing Final SHAP Analysis and Model Serialization...
Random Forest model successfully serialized to: 04_app_deployment/final_rf_model.pkl


## 8) SHAP Analysis

To move beyond overall model accuracy and understand *how* the classifier distinguishes BV, BVVC, and healthy microbiome states, we applied SHAP (SHapley Additive exPlanations)—a principled model interpretation framework that quantifies the contribution of each biomarker to each prediction. SHAP allows us not only to identify **which CLR-transformed microbial features are most influential**, but also to determine **how increases or decreases in those features modulate predicted clinical status** on a per-class and per-sample basis.

We began by generating **global feature rankings**, which demonstrated that the classifier’s decisions are dominated by biologically coherent microbial markers, such as CLR_1, CLR_43, CLR_17, CLR_3, and CLR_14. These biomarkers segregate conditions along ecologically meaningful dimensions including commensal stability (CLR_1, CLR_17), anaerobic dysbiosis (CLR_14), and functional/metabolic resilience (CLR_43, CLR_3).

Next, **class-specific SHAP summary (dot) plots** revealed how the same biomarker influences predictions differently across BV, BVVC, and healthy classes. For BV, low commensal stability markers strongly increase BV probability, whereas healthy samples show a monotonic trend in the opposite direction. BVVC displayed intermediate or heterogeneous SHAP behavior, reflecting its biological origin as a fungal disorder with inconsistent bacterial restructuring.

**SHAP feature importance bar plots** further contextualized these results, highlighting which biomarkers contribute globally to model separation and showing that the most influential features align with both statistical effect sizes and known ecological patterns.

**Single-sample SHAP force plots** showed how the model arrives at an individual prediction by combining positive and negative feature contributions. This moves the interpretation from population-level trends to patient-level microbiological logic, demonstrating real clinical interpretability.

We then examined **SHAP dependence plots**, which characterize the *shape* of biomarker–prediction relationships. These revealed directional and often monotonic effects, thresholds, and class-dependent transitions, providing mechanistic insight into how shifts in individual taxa drive clinical predictions.

Together, these SHAP analyses confirm that the classifier is not detecting arbitrary patterns but instead capturing biologically meaningful microbial ecology, aligning with PERMANOVA findings, univariate statistical tests, and current microbiome literature. The concordance between SHAP interpretability and traditional statistical inference provides robust validation that the classifier is not relying on spurious noise, but rather on ecologically grounded microbial features with reproducible group-level signal.

In [55]:
TOP_FEATURE = X_test.columns[0] if isinstance(X_test, pd.DataFrame) else 0
INTERACTING_FEATURE = None  # or set a feature name/index
X_test_columns = X_test.columns.tolist() if isinstance(X_test, pd.DataFrame) else [f"Feature_{i}" for i in range(X_test.shape[1])]

In [56]:
explainer = shap.TreeExplainer(best_rf_model)
shap_values = explainer.shap_values(X_test)

# Stack the list of arrays (one per class) into a single array
X_columns = clr_columns + ALPHA_COLUMNS

shap_array = np.stack(shap_values, axis=0)  # shape (33, 79, 3)

# mean abs SHAP across samples and classes
global_shap_impact = np.mean(np.abs(shap_array), axis=(0, 2))  # shape = (79,)

# 2. Create the ranking DataFrame using ALL 84 features/scores
shap_ranking_df = pd.DataFrame({
    'Feature': X_columns, 
    'Mean_Abs_SHAP': global_shap_impact
}) 


# 3. Sort and save the top 5
top_5_biomarkers_df = shap_ranking_df.sort_values(by='Mean_Abs_SHAP', ascending=False).head(5)

# --- 4. Save to CSV and Print ---
output_csv_path = "03_results/tables/top_5_shap_biomarkers.csv"
top_5_biomarkers_df.to_csv(output_csv_path, index=False)


In [57]:
if isinstance(shap_values, list):
    shap_values_per_class = shap_values
else:
    shap_values_per_class = [shap_values[:, :, i] for i in range(shap_values.shape[2])]

n_classes = len(shap_values_per_class)
print(f"Detected {n_classes} classes.")


Detected 3 classes.


### 8.1) Class-Specific SHAP Summary (Dot) Plot Analysis

Across the class-specific SHAP dot plots, several strong and biologically interpretable patterns emerged regarding how feature values influence model predictions. For Class 0 (Bacterial Vaginosis), low values of key commensal-associated biomarkers such as CLR_1, CLR_43, CLR_17, and CLR_3 consistently produced positive SHAP contributions toward BV diagnosis, confirming that BV states are characterized by depletion of stability-associated taxa. In contrast, elevated CLR_14 values showed clear directional effects toward BV classification, consistent with opportunistic proliferation under dysbiosis. For Class 1 (Healthy controls), the inverse pattern was observed; higher CLR_1, CLR_17, CLR_43, and CLR_3 values shifted predictions away from BV and toward the healthy class, indicating strong ecological resilience within this group. The Class 2 (BVVC) plots revealed a heterogeneous pattern—commensal markers resembled healthy values, but the spread of CLR_14 and secondary taxa was wider and less consistent, reinforcing that BVVC lacks a strong and unified bacterial signal. Together, these SHAP dot plots show that the model learned meaningful microbial rules—commensal depletion predicts BV, commensal preservation predicts health, and BVVC exhibits intermediate variability—providing mechanistic clarity rather than black-box predictions.

In [58]:
for c in range(n_classes):
    print(f"Generating summary (dot) plot for class {c}...")
    plt.figure(figsize=(10, 6))
    shap.summary_plot(
        shap_values_per_class[c],
        X_test if isinstance(X_test, pd.DataFrame) else pd.DataFrame(X_test, columns=X_test_columns),
        feature_names=X_test_columns,
        plot_type="dot",
        show=False
    )
    out_path = os.path.join(SAVE_DIR_FIG, f"shap_summary_dot_class_{c}.png")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()

Generating summary (dot) plot for class 0...
Generating summary (dot) plot for class 1...
Generating summary (dot) plot for class 2...


### 8.2) Class-Specific SHAP Feature Importance Ranking (Bar Plots)

The class-specific SHAP feature importance bar plots highlight which microbial biomarkers most strongly influence the classifier’s predictions for each diagnostic category. For BV (Class 0), CLR_1 clearly dominates the feature rankings, followed by CLR_43, CLR_17, and CLR_3, all of which are stability-associated commensal markers that suppress BV probability when abundant. In contrast, CLR_14 remains moderately influential in BV prediction, consistent with its role as a dysbiosis-associated biomarker. For Healthy (Class 1), CLR_1 again exhibits the strongest stabilizing effect, while CLR_14 gains prominence relative to BV, indicating that intermediate CLR_14 levels may reflect healthy variation rather than dysbiosis. For BVVC (Class 2), CLR_1 still ranks highest, but the feature set supporting predictions becomes more diverse, with CLR_43, CLR_46, CLR_69, CLR_38, and CLR_26 contributing similar magnitudes, reflecting the mixed-driven nature of BVVC compared to strongly dysbiotic BV. Across all classes, CLR_1 maintains the highest SHAP magnitude, reinforcing its central role in distinguishing microbial stability from dysbiosis, while the remaining features shift in relative importance depending on whether the model is confidently identifying BV, healthy, or ambiguous BVVC samples.

In [59]:
for c in range(n_classes):
    print(f"Generating summary (bar) plot for class {c}...")
    plt.figure(figsize=(10, 6))
    shap.summary_plot(
        shap_values_per_class[c],
        X_test if isinstance(X_test, pd.DataFrame) else pd.DataFrame(X_test, columns=X_test_columns),
        feature_names=X_test_columns,
        plot_type="bar",
        show=False
    )
    out_path = os.path.join(SAVE_DIR_FIG, f"shap_summary_bar_class_{c}.png")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()

Generating summary (bar) plot for class 0...
Generating summary (bar) plot for class 1...
Generating summary (bar) plot for class 2...


### 8.3) Single-Sample SHAP Force Plot: Local Prediction Explanation

To complement the global model interpretation analyses, a SHAP force plot was generated for an individual test sample to reveal the local drivers of that specific prediction. Unlike the summary and bar plots, which aggregate SHAP effects across many samples, this force visualization illustrates the balance of microbial taxa pushing the model toward or away from the predicted clinical class outcome. Features displayed in red exert a positive influence on the probability of classification into the model-selected class, while blue features reduce that probability. The magnitude of each horizontal contribution segment reflects the relative strength of that feature’s effect. This decomposition provides a clinically intuitive explanation of the model’s behavior for one patient-level observation, supporting actionable biological reasoning and helping link observed taxonomic signals to classification outcomes.

The force plot output for the evaluated sample is saved as an interactive HTML file and can be viewed to explore how CLR-transformed taxa and alpha diversity jointly influence the diagnostic classification.

In [68]:
i = 0  # sample index
cls = y_test.iloc[i] if isinstance(y_test, (pd.Series, pd.DataFrame)) else y_test[i]

x_vals = X_test.iloc[i, :] if isinstance(X_test, pd.DataFrame) else X_test[i, :]

force_plot_html = shap.force_plot(
    explainer.expected_value[cls],
    shap_values_per_class[cls][i, :],
    x_vals,
    feature_names=X_test_columns,
    matplotlib=False  # HTML backend
)

out_path = os.path.join(SAVE_DIR_FIG, f"shap_force_sample_{i}_class_{cls}.html")
shap.save_html(out_path, force_plot_html)

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


### 8.4) Class-Specific SHAP Dependence Plots for Top Predictive Biomarkers

The SHAP dependence plots further refine the biological interpretation by showing how changes in the abundance of each key biomarker independently influence class probability, and how this relationship varies across vaginal health states. Across all classes, CLR_1 consistently demonstrated a monotonic trend where low values increased the probability of BV and high abundance favored healthy or BVVC classification, reinforcing its role as a protective commensal marker. For CLR_43 and CLR_17, BVVC and healthy samples clustered at higher abundances with positive SHAP values, while BV samples clustered at lower abundance and negative SHAP contributions, indicating that both biomarkers reflect ecological stability in non-BV states. In contrast, CLR_14 showed the opposite directional effect, where higher values pushed predictions toward BV and lower values supported healthy or BVVC labels, indicating opportunistic dysbiosis involvement. Finally, CLR_3 displayed a similar suppressive pattern in BV, with elevated values favoring non-BV predictions. Importantly, the dependence plots highlighted clear decision thresholds rather than noisy or arbitrary transitions, demonstrating that the model learned biologically interpretable response behavior rather than overfitting to random variation. Together, these patterns reveal how specific abundance levels of core microbial features influence classification boundaries and support the mechanistic model derived from SHAP bar and dot summaries.

In [67]:
top_features = ["CLR_1", "CLR_43", "CLR_17", "CLR_14", "CLR_3"]

if isinstance(X_test, pd.DataFrame):
    X_test_df = X_test.copy()
else:
    X_test_df = pd.DataFrame(X_test, columns=X_test_columns)
    
for c in range(n_classes):

    shap_vals_c = shap_values_per_class[c]  # shape = (n_samples, n_features)

    for feature in top_features:
        plt.figure(figsize=(8, 6))
        
        shap.dependence_plot(
            feature,
            shap_vals_c,         # <=== Correct matching dimension
            X_test_df,
            interaction_index="auto",
            show=False
        )

        out = os.path.join(
            SAVE_DIR_FIG, 
            f"shap_dependence_{feature}_class_{c}.png"
        )
        plt.savefig(out, bbox_inches="tight")
        plt.close()


Generating dependence plots for class 0...

Generating dependence plots for class 1...

Generating dependence plots for class 2...


## 9) Conclusion

The machine learning and interpretability analyses presented here provide a coherent and biologically meaningful understanding of microbial feature patterns associated with clinical status groups. The optimized Random Forest classifier achieved strong overall performance, with balanced sensitivity and specificity across classes, and hyperparameter tuning further enhanced predictive reliability. SHAP interpretability methods clarified exactly how predictions were made, identifying CLR_1, CLR_43, CLR_17, CLR_14, and CLR_3 as the most influential biomarkers contributing to the model output. Class-specific dot and bar plots highlighted the differential importance of specific features across diagnostic categories, while dependence plots revealed nonlinear and threshold-based relationships, demonstrating that feature influence was context-dependent rather than uniform. The force plot additionally validated that individual-level predictions align with interpretable biological signals. Importantly, SHAP-derived biomarkers and directional effects were highly consistent with univariate and multivariate statistical findings from R, demonstrating convergence between statistical testing and machine learning derivations. Overall, the integration of predictive modeling and explainability strengthens confidence that the identified microbial signals are not only statistically meaningful but functionally informative, laying a strong foundation for downstream clinical translation and expanded metagenomic modeling.

## 10) Environment Reproducibility: Auto-Generating requirements.txt

In [69]:
libs = [
    "mlflow",
    "mlflow.sklearn",
    "pandas",
    "numpy",
    "sklearn",
    "xgboost",
    "matplotlib",
    "seaborn",
    "shap"
]

# Create requirements.txt with exact versions
with open("requirements.txt", "w") as f:
    for lib in libs:
        try:
            # Try to get version using __version__
            pkg = __import__(lib.split('.')[0])  # Handles mlflow.sklearn
            version = pkg.__version__
            f.write(f"{lib}=={version}\n")
        except AttributeError:
            # Fallback to pip show
            result = subprocess.run([sys.executable, "-m", "pip", "show", lib.split('.')[0]],
                                    capture_output=True, text=True)
            for line in result.stdout.split("\n"):
                if line.startswith("Version:"):
                    f.write(f"{lib}=={line.split()[1]}\n")