# Evaluate generalizability of the model across holdout plate(s)

## Import libraries

In [1]:
# Identifiers for this notebook. Both are passed to setup_logger() and end up
# in the logger name; MODEL_ID is also part of the log file name.
MODEL_ID = "ensemble"
ROLE = "generalizability"

import logging
from datetime import datetime
import pathlib

# ============================================
# 1) Choose a RUN_ID
# ============================================
# RUN_ID is part of the logger name and of the log file name.
# Set RUN_ID_OVERRIDE to a fixed string to pin the run (for example to keep
# writing to the log file of an earlier run); set it to None to get a fresh
# timestamp-based id of the form "MM_DD_HH_MM".
RUN_ID_OVERRIDE = "12_08_08_12"

RUN_ID = RUN_ID_OVERRIDE or datetime.now().strftime("%m_%d_%H_%M")

# Default analysis label used in logger names and log file names
ANALYSIS_TYPE = "generalizability"


def setup_logger(
    run_id: str,
    model_id: str,
    role: str,
    log_dir: str = "logs",
    analysis_type: str = ANALYSIS_TYPE,
) -> logging.Logger:
    """
    Create (or fetch) a logger that writes to both the console and a log file.

    - Logger name:  "<analysis_type>_<run_id>_<model_id>_<role>"
    - Log file:     "log_<analysis_type>_<run_id>_<model_id>.log" in `log_dir`
      (shared by all notebooks for the same model & run & analysis_type).

    The console handler is a plain ``logging.StreamHandler()``, which writes
    to ``sys.stderr`` by default. The file handler appends to the log file.

    Parameters
    ----------
    run_id : str
        Identifier of the run; part of the logger name and the file name.
    model_id : str
        Identifier of the model; part of the logger name and the file name.
    role : str
        Role of the calling notebook; part of the logger name only.
    log_dir : str
        Directory for the log file. It is created if missing, including any
        missing parent directories.
    analysis_type : str
        Prefix for the logger name and the file name.

    Returns
    -------
    logging.Logger
        Logger at INFO level with propagation to the root logger disabled.
        Calling this function again with the same identifiers returns the
        same logger without attaching duplicate handlers.
    """
    log_path = pathlib.Path(log_dir)
    # parents=True so that a nested log_dir such as "out/logs" can be created
    # in one go instead of raising FileNotFoundError.
    log_path.mkdir(parents=True, exist_ok=True)

    logger_name = f"{analysis_type}_{run_id}_{model_id}_{role}"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    # Do not hand records to the root logger, which would print them twice
    # if the root logger has its own handlers.
    logger.propagate = False

    # Avoid adding handlers multiple times if the cell is re-run
    if not logger.handlers:
        # Common formatter for both handlers
        formatter = logging.Formatter(
            fmt="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )

        # Console handler (StreamHandler defaults to sys.stderr)
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.INFO)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        # File handler (one file per analysis_type + run_id + model_id).
        # Explicit encoding keeps the file independent of the platform locale.
        log_file = log_path / f"log_{analysis_type}_{run_id}_{model_id}.log"
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        # Echo the log directory once, when the handlers are first attached
        print(log_path)

    return logger

# Build the logger used by the rest of the notebook. setup_logger() only
# attaches handlers when the logger has none, so re-running this cell does
# not duplicate log lines.
logger = setup_logger(RUN_ID, MODEL_ID, ROLE)
logger.info("Initialized logger.")


2025-12-06T08:17:56 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Initialized logger.


logs


In [2]:
import pathlib
import logging

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

import seaborn as sns
import matplotlib.pyplot as plt

logger.info("Imported libraries for generalizability comparison.")


2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Imported libraries for generalizability comparison.


In [3]:
# Everything is resolved relative to the current working directory
base_dir = pathlib.Path(".")

# Per-model input folders holding the generalizability results
lr_results_dir = base_dir.joinpath("results")
rf_results_dir = base_dir.joinpath("results_randomforest")
xgb_results_dir = base_dir.joinpath("results_xgboost")

# Output folder for the combined / comparison tables (created on first run)
compare_results_dir = base_dir.joinpath("results_ensemble")
compare_results_dir.mkdir(exist_ok=True)

# Plate 6 (QC) accuracy table of each model, keyed by model name
acc_files = dict(
    logistic_regression=lr_results_dir / "plate6_accuracy_final_model_qc.parquet",
    random_forest=rf_results_dir / "plate6_accuracy_final_model_rf_qc.parquet",
    xgboost=xgb_results_dir / "plate6_accuracy_final_model_xgb_qc.parquet",
)

# Plate 6 (QC) precision-recall table of each model, keyed by model name.
# NOTE(review): unlike its accuracy file, the XGBoost PR file name carries no
# "_xgb" suffix. Confirm that this is the file written by the XGBoost
# notebook and not a copy of the logistic regression output.
pr_files = dict(
    logistic_regression=lr_results_dir / "plate6_precision_recall_final_model_qc.parquet",
    random_forest=rf_results_dir / "plate6_precision_recall_final_model_rf_qc.parquet",
    xgboost=xgb_results_dir / "plate6_precision_recall_final_model_qc.parquet",
)

# Record the resolved input paths in the log
logger.info("Accuracy files: %s", acc_files)
logger.info("PR files: %s", pr_files)


2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Accuracy files: {'logistic_regression': PosixPath('results/plate6_accuracy_final_model_qc.parquet'), 'random_forest': PosixPath('results_randomforest/plate6_accuracy_final_model_rf_qc.parquet'), 'xgboost': PosixPath('results_xgboost/plate6_accuracy_final_model_xgb_qc.parquet')}
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: PR files: {'logistic_regression': PosixPath('results/plate6_precision_recall_final_model_qc.parquet'), 'random_forest': PosixPath('results_randomforest/plate6_precision_recall_final_model_rf_qc.parquet'), 'xgboost': PosixPath('results_xgboost/plate6_precision_recall_final_model_qc.parquet')}


# NOTE(review): df_lr, df_rf and df_xgb (per-model single cell tables) are not
# defined in the visible cells; they must be loaded before this code runs.

# Every column whose name begins with "Metadata" is treated as metadata
meta_cols = [col for col in df_lr.columns if col.startswith("Metadata")]
logger.info("Number of metadata columns: %d", len(meta_cols))

# The three tables must hold the same number of rows
assert df_lr.shape[0] == df_rf.shape[0] == df_xgb.shape[0], "Row counts differ between models"

# Element-wise, position-by-position comparison of the metadata values,
# using the logistic regression table as the reference
lr_meta_values = df_lr[meta_cols].values
same_lr_rf = (lr_meta_values == df_rf[meta_cols].values).all()
same_lr_xgb = (lr_meta_values == df_xgb[meta_cols].values).all()
logger.info("Metadata identical LR vs RF: %s", same_lr_rf)
logger.info("Metadata identical LR vs XGB: %s", same_lr_xgb)

# Stop here on a mismatch: averaging probabilities row by row only makes
# sense when the rows line up. Comment these out to continue regardless.
assert same_lr_rf, "Metadata mismatch between LR and RF"
assert same_lr_xgb, "Metadata mismatch between LR and XGB"


def get_wt_prob(df: pd.DataFrame) -> np.ndarray:
    """
    Return the per-row WT probability of one model as a NumPy array.

    The ``probability_WT`` column is used when present. Otherwise the value
    is derived as ``1 - probability_Null``. A ValueError listing the
    available columns is raised when neither column exists.
    """
    available = df.columns
    if "probability_WT" in available:
        wt_prob = df["probability_WT"].to_numpy()
    elif "probability_Null" in available:
        # Complement of the Null probability
        wt_prob = 1.0 - df["probability_Null"].to_numpy()
    else:
        raise ValueError(
            "Could not find probability_WT or probability_Null in columns: "
            f"{df.columns.tolist()}"
        )
    return wt_prob


# Get WT probabilities from each model
p_lr = get_wt_prob(df_lr)
p_rf = get_wt_prob(df_rf)
p_xgb = get_wt_prob(df_xgb)

# Ensemble probability as simple (unweighted) average of the three models
p_ensemble = (p_lr + p_rf + p_xgb) / 3.0

logger.info("Ensemble probabilities computed.")

# Build ensemble dataframe by copying the logistic df structure
ensemble_df = df_lr.copy()

# Store per model probabilities if you want to inspect them
ensemble_df["probability_WT_logreg"] = p_lr
ensemble_df["probability_WT_rf"] = p_rf
ensemble_df["probability_WT_xgb"] = p_xgb

# Store ensemble probability (overwrites any probability_WT copied from df_lr)
ensemble_df["probability_WT"] = p_ensemble

# Build a binary true label from Metadata_genotype: Null -> 0, WT -> 1.
# Any other genotype (e.g. HET) is not in the mapping and becomes NaN.
mapping = {"Null": 0, "WT": 1}
ensemble_df["true_genotype_binary"] = ensemble_df["Metadata_genotype"].map(mapping)

# Predicted binary label from ensemble probability (WT when >= 0.5)
ensemble_df["predicted_genotype_binary"] = np.where(
    ensemble_df["probability_WT"] >= 0.5, 1, 0
)

logger.info("ensemble_df shape: %s", ensemble_df.shape)

# Save single cell ensemble probabilities.
# compare_results_dir ("results_ensemble") is the output directory defined
# with the other paths; the previously used name ensemble_results_dir was
# never defined and raised a NameError.
probs_out_file = compare_results_dir / "plate_6_single_cell_probabilities_ensemble_qc.parquet"
ensemble_df.to_parquet(probs_out_file)
logger.info("Saved ensemble single cell probabilities to %s", probs_out_file)

ensemble_df.head()


In [4]:
import json

# Read the accuracy table of every model and tag each row with its model name
acc_dfs = []
for model_name, path in acc_files.items():
    logger.info("Loading accuracy for %s from %s", model_name, path)
    model_acc_df = pd.read_parquet(path)
    model_acc_df["model"] = model_name
    acc_dfs.append(model_acc_df)

# Stack the per-model tables into one long table with a fresh index
accuracy_all_models_df = pd.concat(acc_dfs, ignore_index=True)
logger.info("Combined accuracy_all_models_df shape: %s", accuracy_all_models_df.shape)

# Per-model mean of the "final" accuracy rows, written to the log
is_final = accuracy_all_models_df["data_type"] == "final"
for model_name in accuracy_all_models_df["model"].unique():
    is_model = accuracy_all_models_df["model"] == model_name
    mean_acc = accuracy_all_models_df.loc[is_model & is_final, "accuracy"].mean()
    logger.info(
        "%s mean final accuracy across institutions: %.4f",
        model_name,
        mean_acc,
    )

# Persist the combined table next to the other comparison outputs
acc_out_file = compare_results_dir / "plate6_accuracy_final_all_models_qc.parquet"
accuracy_all_models_df.to_parquet(acc_out_file)
logger.info("Saved combined accuracy table to %s", acc_out_file)

# Dump the whole table into the log as a JSON list of row records
acc_records = accuracy_all_models_df.to_dict(orient="records")
logger.info("accuracy_all_models_df_json=%s", json.dumps(acc_records))

# Preview in the notebook
accuracy_all_models_df.head()


2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading accuracy for logistic_regression from results/plate6_accuracy_final_model_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading accuracy for random_forest from results_randomforest/plate6_accuracy_final_model_rf_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading accuracy for xgboost from results_xgboost/plate6_accuracy_final_model_xgb_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Combined accuracy_all_models_df shape: (12, 4)
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: logistic_regression mean final accuracy across institutions: 0.5918
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: random_forest mean final accuracy across institutions: 0.6043
2025-12-06T08:17:57 [generalizabilit

Unnamed: 0,Metadata_Institution,data_type,accuracy,model
0,MGH,final,0.496654,logistic_regression
1,MGH,shuffled,0.501585,logistic_regression
2,iNFixion,final,0.686848,logistic_regression
3,iNFixion,shuffled,0.501044,logistic_regression
4,MGH,final,0.592814,random_forest


In [5]:
import json

# Read the PR curve table of every model and tag each row with its model name
pr_dfs = []
for model_name, path in pr_files.items():
    logger.info("Loading PR curves for %s from %s", model_name, path)
    model_pr_df = pd.read_parquet(path)
    model_pr_df["model"] = model_name
    pr_dfs.append(model_pr_df)

# Stack the per-model tables into one long table with a fresh index
precision_recall_all_models_df = pd.concat(pr_dfs, ignore_index=True)
logger.info(
    "Combined precision_recall_all_models_df shape: %s",
    precision_recall_all_models_df.shape,
)

# Persist the combined table next to the other comparison outputs
pr_out_file = compare_results_dir / "plate6_precision_recall_final_all_models_qc.parquet"
precision_recall_all_models_df.to_parquet(pr_out_file)
logger.info("Saved combined PR table to %s", pr_out_file)

# Record the available columns in the log
pr_cols = precision_recall_all_models_df.columns.tolist()
logger.info("precision_recall_all_models_df_columns=%s", pr_cols)

# Dump the curves into the log as JSON row records. Only the rows with
# data_type == "final" are logged, which keeps the log entry smaller.
final_mask = precision_recall_all_models_df["data_type"] == "final"
pr_json_df = precision_recall_all_models_df.loc[final_mask].copy()

pr_records = pr_json_df.to_dict(orient="records")
logger.info("precision_recall_all_models_df_final_json=%s", json.dumps(pr_records))

# Preview in the notebook
precision_recall_all_models_df.head()


2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading PR curves for logistic_regression from results/plate6_precision_recall_final_model_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading PR curves for random_forest from results_randomforest/plate6_precision_recall_final_model_rf_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Loading PR curves for xgboost from results_xgboost/plate6_precision_recall_final_model_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Combined precision_recall_all_models_df shape: (17926, 5)
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: Saved combined PR table to results_ensemble/plate6_precision_recall_final_all_models_qc.parquet
2025-12-06T08:17:57 [generalizability_12_08_08_12_ensemble_generalizability] INFO: precision_recall_all_models_df_colum

Unnamed: 0,Precision,Recall,Metadata_Institution,data_type,model
0,0.677801,1.0,iNFixion,final,logistic_regression
1,0.677577,0.998973,iNFixion,final,logistic_regression
2,0.677352,0.997947,iNFixion,final,logistic_regression
3,0.677127,0.99692,iNFixion,final,logistic_regression
4,0.676902,0.995893,iNFixion,final,logistic_regression


In [None]:
# Example PR plot: all models, final data
pr_final_df = precision_recall_all_models_df[
    precision_recall_all_models_df["data_type"] == "final"
]

# The combined PR table names its institution column "Metadata_Institution";
# plain "Institution" is still accepted in case a result file uses that name.
institution_col = next(
    (
        col
        for col in ("Institution", "Metadata_Institution")
        if col in pr_final_df.columns
    ),
    None,
)

# Prefer pooled "all_institutions" rows when they exist. Otherwise keep all
# final rows and give each institution its own line style, so that curves
# from different institutions are not drawn through one line per model.
style_col = None
plot_df = pr_final_df.copy()
if institution_col is not None:
    pooled_rows = plot_df[institution_col] == "all_institutions"
    if pooled_rows.any():
        plot_df = plot_df[pooled_rows].copy()
    else:
        style_col = institution_col

plt.figure(figsize=(6, 4))
sns.lineplot(
    data=plot_df,
    x="Recall",
    y="Precision",
    hue="model",
    style=style_col,
)
plt.title("Plate 6 PR curves (final)")
if style_col is None:
    plt.legend(title="Model")
# With a style variable the legend lists both model and institution, so the
# legend generated by seaborn is kept as is.
plt.tight_layout()
plt.show()
