In [None]:
import glob
import io
import os
import warnings
from typing import Any, Dict, List, Tuple, Union
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import zstandard as zstd
from scipy.stats import mannwhitneyu
from tqdm import tqdm
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    auc,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


# external et al paper: https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1010204

In [None]:
# Load the three OMOP tables used below (placeholder paths; point them at the
# local export). The whole of each CSV is read into memory.
concept_df = pd.read_csv('/path/to/omop/concept.csv')
visit_occurrence_df = pd.read_csv('/path/to/omop/visit_occurrence.csv')
# NOTE(review): despite the `_sub` suffix this reads the full drug_exposure file;
# the anti-TNF subset is only taken later by filter_by_drug_concept_id.
drug_df_sub = pd.read_csv('/path/to/omop/drug_exposure.csv')

In [None]:
# get all concept ids for anti-tnf therapy
def filter_by_concept_name(
    df: pd.DataFrame,
    substrings: Union[str, List[str]],
    name_col: str = "concept_name",
    id_col: str = "concept_id"
) -> Tuple[pd.DataFrame, List[Any]]:
    """Select rows whose name contains any of the given substrings.

    Matching is case-insensitive and *literal*: characters such as ``(``,
    ``.`` or ``+`` in a substring are matched as-is rather than interpreted
    as regular-expression syntax. Missing names never match.

    Parameters
    ----------
    df : table to search.
    substrings : one substring or a list of substrings; a row is kept when
        its name contains at least one of them.
    name_col : column holding the text to search.
    id_col : column whose unique values are returned alongside the rows.

    Returns
    -------
    (filtered_df, ids) : the de-duplicated matching rows (a copy) and the
        unique ``id_col`` values among them, in order of first appearance.
        An empty frame and empty list are returned when no substrings are
        given.
    """
    if isinstance(substrings, str):
        substrings = [substrings]
    if not substrings:
        print("No substrings provided; returning empty result.")
        return df.iloc[0:0].copy(), []

    names = df[name_col]
    mask = pd.Series(False, index=df.index)
    for sub in substrings:
        # regex=False: drug names can contain regex metacharacters, which
        # would otherwise raise or silently over-match.
        mask |= names.str.contains(sub, case=False, na=False, regex=False)

    # drop duplicate rows
    filtered_df = df.loc[mask].drop_duplicates().copy()

    ids = filtered_df[id_col].unique().tolist()

    print(f"Found {len(filtered_df)} unique rows matching {substrings!r} in '{name_col}'.")
    print(f"Unique {id_col} values: {ids}")

    return filtered_df, ids

# Name fragments used to look up the anti-TNF drugs in the concept table.
drug_names = ['adalimumab', 'etanercept', 'infliximab', 'certolizumab pegol', 'golimumab']
# filtered_concepts: matching concept rows; drug_ids: their unique concept_id values.
filtered_concepts, drug_ids = filter_by_concept_name(concept_df, drug_names)
# Displays every matching concept row as the cell output.
filtered_concepts

In [None]:
# subset drug dataframe to include only anti-tnf drugs
def filter_by_drug_concept_id(
    df: pd.DataFrame,
    concept_ids: List[int],
    column: str = "drug_concept_id"
) -> pd.DataFrame:
    """Keep only the rows whose ``column`` value is one of ``concept_ids``.

    Also prints, for every requested concept id (including those with no
    rows), how many rows were kept. Returns a copy, so the input frame is
    never modified through the result.
    """
    keep = df[column].isin(concept_ids)
    subset = df.loc[keep].copy()

    # Report in the caller's id order; ids absent from the data show as 0.
    tally = subset[column].value_counts().reindex(concept_ids, fill_value=0)
    print("Row counts by drug_concept_id:")
    for concept_id, n_rows in tally.items():
        print(f"  {concept_id}: {n_rows}")

    return subset


# Drug-exposure rows restricted to the anti-TNF concept ids found above.
filtered_drugs = filter_by_drug_concept_id(drug_df_sub, drug_ids)


In [None]:
# organize patient information for all patients who received anti-TNF
def create_person_drug_visit_dict(
    drugsub: pd.DataFrame,
    visdf: pd.DataFrame
) -> Dict[Any, Dict[str, List[pd.Timestamp]]]:
    """Collect, per patient, the sorted visit dates and drug-administration dates.

    Parameters
    ----------
    drugsub : drug-exposure rows; must have ``person_id`` and
        ``drug_exposure_start_DATE``.
    visdf : visit rows; must have ``person_id`` and ``visit_start_DATE``.

    Returns
    -------
    dict keyed by every ``person_id`` present in ``drugsub`` (in order of first
    appearance). Each value is
    ``{'all_visits': [...], 'drug_administrations': [...]}`` with ascending
    ``pd.Timestamp`` lists. Dates that cannot be parsed are dropped, so either
    list may be empty. Neither input frame is modified.
    """

    def _dates_by_person(frame: pd.DataFrame, date_col: str) -> Dict[Any, List[pd.Timestamp]]:
        # Parse once, drop unparseable dates, sort once, then split per patient.
        # This replaces a full-table boolean scan per patient.
        dated = frame[['person_id', date_col]].assign(
            **{date_col: pd.to_datetime(frame[date_col], errors='coerce')}
        )
        dated = dated.dropna(subset=[date_col]).sort_values(date_col)
        return {
            pid: dates.tolist()
            for pid, dates in dated.groupby('person_id', sort=False)[date_col]
        }

    visits_by_person = _dates_by_person(visdf, 'visit_start_DATE')
    drugs_by_person = _dates_by_person(drugsub, 'drug_exposure_start_DATE')

    summary_dict: Dict[Any, Dict[str, List[pd.Timestamp]]] = {}

    # Iterate over the drug table's ids so that patients whose drug dates were
    # all unparseable, or who have no visits, still get an (empty-list) entry.
    for pid in drugsub['person_id'].unique():
        summary_dict[pid] = {
            'all_visits': visits_by_person.get(pid, []),
            'drug_administrations': drugs_by_person.get(pid, [])
        }

    return summary_dict


# Per-patient visit and anti-TNF administration dates, keyed by person_id.
summary_dict = create_person_drug_visit_dict(filtered_drugs, visit_occurrence_df)
# NOTE(review): this displays the entire dictionary (every patient's dates) as
# the cell output; clear the output before sharing the notebook.
summary_dict


In [None]:
# convert dictionary to dataframe for easier downstream processing
def summary_dict_to_dataframe(
    summary_dict: Dict[Any, Dict[str, List[pd.Timestamp]]]
) -> pd.DataFrame:
    """Flatten the per-patient date dictionary into one summary row per patient.

    Each row holds the patient id, the number of visits strictly before the
    first drug administration, the number of visits strictly after the last
    one, and the total number of administrations. Both visit counts are 0 for
    a patient with no administrations. The date lists are assumed to be
    sorted ascending (first element earliest, last element latest).
    """

    def _row(pid: Any, entry: Dict[str, List[pd.Timestamp]]) -> Dict[str, Any]:
        visit_dates = entry.get("all_visits", [])
        drug_dates = entry.get("drug_administrations", [])

        n_before = n_after = 0
        if drug_dates:
            earliest, latest = drug_dates[0], drug_dates[-1]
            n_before = len([v for v in visit_dates if v < earliest])
            n_after = len([v for v in visit_dates if v > latest])

        return {
            "person_id": pid,
            "visits_before_first_drug": n_before,
            "visits_after_last_drug": n_after,
            "total_drug_administrations": len(drug_dates)
        }

    return pd.DataFrame.from_records(
        [_row(pid, entry) for pid, entry in summary_dict.items()]
    )


# One row per anti-TNF patient with visit/administration counts.
drug_admin_stat_df = summary_dict_to_dataframe(summary_dict)
# Inner join: patients without a row in rabit_metadata are dropped here.
# NOTE(review): `rabit_metadata` is not defined anywhere in the visible notebook;
# it must be loaded in an earlier cell for a fresh-kernel run to succeed.
drug_admin_stat_df_summary = pd.merge(drug_admin_stat_df, rabit_metadata, on='person_id', how='inner')
drug_admin_stat_df_summary

## Create label for ehr representation generation

In [None]:
# Build a lookup of each patient's earliest drug date (NaT if none).
# The administration lists are sorted ascending, so element 0 is the earliest.
earliest_drug = {}
for pid, data in summary_dict.items():
    drugs = data.get('drug_administrations', [])
    earliest_drug[pid] = drugs[0] if drugs else pd.NaT

# One label row per patient in the merged cohort table.
rabit_labels = pd.DataFrame()

rabit_labels['patient_id'] = drug_admin_stat_df_summary['person_id']
rabit_labels['prediction_time'] = (
    # to_datetime guarantees a datetime64 column so the .dt accessor is valid
    # even when no id matched or every value is NaT.
    pd.to_datetime(rabit_labels['patient_id'].map(earliest_drug))
    .dt.floor('min')                        # drop seconds ('T' is a deprecated alias)
    .dt.strftime('%Y-%m-%d %H:%M:%S')       # format as string
)
rabit_labels['label_type'] = 'boolean'
rabit_labels['value'] = True

# NOTE(review): the placeholder path looks like a directory; to_csv needs a
# file path. Confirm the real target when filling in paths.
rabit_labels.to_csv(
    '/path/to/label/directory',
    index=False
)
rabit_labels


# Create responder/nonresponder labels

### If a visit (with no anti-TNF record) is timestamped at least 12 months after the patient's latest anti-TNF record, the patient is marked as a non-responder

In [None]:
def responder_summary_df(
    summary_dict: Dict[Any, Dict[str, List[pd.Timestamp]]], deltamon=6
) -> pd.DataFrame:
    """Label each patient as responder / non-responder from their visit history.

    A patient is a non-responder when at least one visit falls on or after
    ``last drug administration + deltamon months``. Patients with no drug
    dates or no visit dates default to responder. The counts of each class
    are printed, and a frame with ``person_id`` and boolean ``responder``
    columns is returned.
    """

    def _is_responder(entry: Dict[str, List[pd.Timestamp]]) -> bool:
        visit_dates = entry.get('all_visits', [])
        drug_dates = entry.get('drug_administrations', [])
        if not (drug_dates and visit_dates):
            return True
        # Lists are expected in ascending order, so the last drug date is the latest.
        threshold = drug_dates[-1] + pd.DateOffset(months=deltamon)
        return not any(v >= threshold for v in visit_dates)

    rows = [
        {'person_id': pid, 'responder': _is_responder(entry)}
        for pid, entry in summary_dict.items()
    ]
    result = pd.DataFrame.from_records(rows)

    tally = result['responder'].value_counts()
    print(f"Responders (True): {tally.get(True, 0)}")
    print(f"Non-responders (False): {tally.get(False, 0)}")
    return result

# Label the cohort using a 12-month window (overrides the function default of 6).
responder_df = responder_summary_df(summary_dict, deltamon=12)
responder_df


In [None]:
# Restrict the anti-TNF cohort to patients with a rheumatoid arthritis (RA) diagnosis.
diag_labels = pd.read_csv('/path/to/omop/condition_occurrence.csv')
# Planned next steps (the code for them is not present in this cell):
#   1. extract patient ids and RA diagnosis times (RA defined as ICD-10 codes M05, M06)
#   2. keep patients whose RA diagnosis is on or before the anti-TNF treatment start date

## Generate RABIT proteomics for final cohort of patients at the time of anti-tnf treatment beginning

## Subset RABIT proteomics panel to include only the 344 proteins that overlap with external dataset for fair comparison

# Train RABIT-trained and External-trained model

In [None]:
def nested_enet_cv(
    df: pd.DataFrame,
    id_col: str,
    label_col: str,
    output_dir: str,
    outer_splits: int = 5,
    inner_splits: int = 10,
    random_state: int = 42,
    l1_ratio: float = 0.9,
):
    """Nested cross-validation of an elastic-net logistic regression.

    Every column of ``df`` other than ``id_col`` and ``label_col`` is used as
    a feature. For each outer stratified fold, an inner stratified
    ``GridSearchCV`` (scored by ROC AUC) tunes ``C`` and ``l1_ratio`` on the
    training part; the refit best pipeline (median imputation + classifier)
    is then evaluated on the held-out part and saved to ``output_dir``.

    Parameters
    ----------
    df : one row per sample with id, label and feature columns.
    id_col, label_col : names of the identifier and binary label columns.
    output_dir : directory (created if missing) receiving the per-fold
        ``.joblib`` models, ``nested_cv_curves.npz`` and the four CSV files.
    outer_splits, inner_splits : number of outer / inner folds.
    random_state : seed for both fold splitters and the classifier.
    l1_ratio : initial elastic-net mixing value. Note that the grid search
        below also tunes ``clf__l1_ratio``, so the fitted models use the
        selected grid value rather than this argument.

    Returns
    -------
    (metrics_df, best_params_df, coef_df_all, oof_df_all, curves) where
    ``curves`` is a dict with ``fpr_grid``, ``tpr_mat``, ``recall_grid`` and
    ``prec_mat``; the matrices have one row per outer fold, interpolated onto
    101-point grids.
    """
    os.makedirs(output_dir, exist_ok=True)

    X = df.drop(columns=[id_col, label_col])
    y = df[label_col]
    ids = df[id_col]

    pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("clf", LogisticRegression(
            penalty="elasticnet",
            solver="saga",
            l1_ratio=l1_ratio,
            max_iter=20_000,
            random_state=random_state,
            n_jobs=-1,
        ))
    ])

    # hyperparameter searching
    param_grid = {
        "clf__C": 10 ** np.linspace(-4, 4, 15),
        "clf__l1_ratio": np.linspace(0.0, 1.0, 6),
    }

    outer_cv = StratifiedKFold(
        n_splits=outer_splits, shuffle=True, random_state=random_state
    )

    metrics_records, param_records, coef_records, oof_records = [], [], [], []

    # Common grids so per-fold curves can be stacked and averaged.
    fpr_grid = np.linspace(0.0, 1.0, 101)
    recall_grid = np.linspace(0.0, 1.0, 101)
    tpr_rows, prec_rows = [], []          # will become (folds × 101) arrays

    # outer fold cross validation
    for fold, (train_idx, test_idx) in enumerate(
        tqdm(list(outer_cv.split(X, y)), desc="Outer CV folds"), start=1
    ):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        inner_cv = StratifiedKFold(
            n_splits=inner_splits, shuffle=True, random_state=random_state
        )
        grid = GridSearchCV(
            estimator=pipe,
            param_grid=param_grid,
            scoring="roc_auc",
            cv=inner_cv,
            n_jobs=-1,
            verbose=0,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            grid.fit(X_train, y_train)

        best_model = grid.best_estimator_
        best_params = grid.best_params_

        y_pred = best_model.predict(X_test)
        y_proba = best_model.predict_proba(X_test)[:, 1]

        metrics_records.append({
            "fold": fold,
            "n_test": len(test_idx),
            "auroc": roc_auc_score(y_test, y_proba),
            "auprc": average_precision_score(y_test, y_proba),
            "accuracy": accuracy_score(y_test, y_pred),
            "prevalence": y_test.mean(),
        })
        # Record both tuned hyperparameters, not only C.
        param_records.append({
            "fold": fold,
            "C": best_params["clf__C"],
            "l1_ratio": best_params["clf__l1_ratio"],
        })

        coef = best_model.named_steps["clf"].coef_.flatten()
        coef_records.append(pd.DataFrame({
            "feature": X.columns,
            "coef": coef,
            "abs_coef": np.abs(coef),
            "fold": fold,
        }))

        oof_records.append(pd.DataFrame({
            id_col: ids.iloc[test_idx].values,
            "fold":   fold,
            "y_true": y_test.values,
            "y_proba": y_proba,
            "y_pred":  y_pred,
        }))

        fpr, tpr, _ = roc_curve(y_test, y_proba)
        # precision_recall_curve returns (precision, recall, thresholds) in
        # that order; recall is decreasing, hence the reversal for np.interp.
        prec, rec, _ = precision_recall_curve(y_test, y_proba)

        tpr_interp = np.interp(fpr_grid, fpr, tpr)        # same grid for all folds
        prec_interp = np.interp(recall_grid, rec[::-1], prec[::-1])

        tpr_rows.append(tpr_interp)
        prec_rows.append(prec_interp)

        joblib.dump(
            best_model,
            os.path.join(output_dir, f"enet_nested_best_fold{fold}.joblib"),
        )

    metrics_df = pd.DataFrame(metrics_records)
    best_params_df = pd.DataFrame(param_records)
    coef_df_all = pd.concat(coef_records, ignore_index=True)
    oof_df_all = pd.concat(oof_records, ignore_index=True)

    tpr_mat = np.vstack(tpr_rows)    # shape: (n_folds, 101)
    prec_mat = np.vstack(prec_rows)   # shape: (n_folds, 101)
    np.savez(
        os.path.join(output_dir, "nested_cv_curves.npz"),
        fpr_grid=fpr_grid,
        tpr_mat=tpr_mat,
        recall_grid=recall_grid,
        prec_mat=prec_mat,
    )

    metrics_df.to_csv(os.path.join(output_dir, "nested_cv_metrics.csv"), index=False)
    best_params_df.to_csv(os.path.join(output_dir, "nested_best_params.csv"), index=False)
    coef_df_all.to_csv(os.path.join(output_dir, "nested_coefficients.csv"), index=False)
    oof_df_all.to_csv(os.path.join(output_dir, "nested_oof_predictions.csv"), index=False)

    return (
        metrics_df,
        best_params_df,
        coef_df_all,
        oof_df_all,
        {
            "fpr_grid": fpr_grid,
            "tpr_mat":  tpr_mat,
            "recall_grid": recall_grid,
            "prec_mat": prec_mat,
        },
    )


In [None]:
# Train RABIT-trained model
# The output directory is spelled exactly as the inference, plotting and
# odds-ratio cells read it ("RABIT-trained"); a different casing would leave
# their globs/reads empty on case-sensitive filesystems.
metrics_df, best_params_df, coef_df_all, oof_df_all, curve_info = nested_enet_cv(
    rabit_proteomics,
    id_col="sample_id",
    label_col="responder",
    output_dir="/path/to/RABIT-trained/model"
)

# AUPRC relative to the fold's positive-class prevalence.
metrics_df["adjusted_auprc"] = metrics_df["auprc"] / metrics_df["prevalence"]

# Summary row with the across-fold mean of each metric. Columns not listed
# here (e.g. n_test) are left missing in that row.
mean_row = {
    "fold": "mean",
    **{
        col: metrics_df[col].mean()
        for col in ("auroc", "auprc", "accuracy", "prevalence", "adjusted_auprc")
    },
}

metrics_df = pd.concat([metrics_df, pd.DataFrame([mean_row])], ignore_index=True)
metrics_df

In [None]:
# Train External-trained model
# NOTE: this rebinds metrics_df / best_params_df / coef_df_all / oof_df_all /
# curve_info, replacing the results of the RABIT training cell.
metrics_df, best_params_df, coef_df_all, oof_df_all, curve_info = nested_enet_cv(
    external_cohort_proteomics,
    id_col="sample_id",
    label_col="responder",
    output_dir="/path/to/external-trained/model"
)

# AUPRC relative to the fold's positive-class prevalence.
metrics_df["adjusted_auprc"] = metrics_df["auprc"] / metrics_df["prevalence"]

# Summary row with the across-fold mean of each metric. Columns not listed
# here (e.g. n_test) are left missing in that row.
mean_row = {
    "fold": "mean",
    **{
        col: metrics_df[col].mean()
        for col in ("auroc", "auprc", "accuracy", "prevalence", "adjusted_auprc")
    },
}

metrics_df = pd.concat([metrics_df, pd.DataFrame([mean_row])], ignore_index=True)
metrics_df

# Run inference: external measured proteomics as input into RABIT-trained model

In [None]:
# set paths
model_dir   = "/path/to/RABIT-trained/model"
model_paths = sorted(glob.glob(f"{model_dir}/enet_nested_best_fold*.joblib"))
id_col      = "sample_id"
label_col   = "responder"          # must be 0/1

# Fail early with a clear message instead of an opaque np.vstack error below.
if not model_paths:
    raise FileNotFoundError(
        f"No 'enet_nested_best_fold*.joblib' models found in {model_dir}"
    )

df_new = external_df_cleaned_final
assert label_col in df_new.columns, "New data must contain ground-truth labels"

# Common grids so the per-model curves can be stacked into matrices.
fpr_grid    = np.linspace(0.0, 1.0, 101)
recall_grid = np.linspace(0.0, 1.0, 101)
tpr_rows, prec_rows = [], []

# predict with every saved model
proba_matrix   = []
per_model_long = []

for path in model_paths:
    # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
    model = joblib.load(path)
    features = model.feature_names_in_
    X_raw    = df_new[features]

    if isinstance(model, Pipeline):     # pipeline does its own preprocessing
        proba = model.predict_proba(X_raw)[:, 1]
    else:
        # NOTE(review): for a bare estimator the scaler is fit on the
        # evaluation data itself; confirm this is intended.
        scaler = StandardScaler()
        proba = model.predict_proba(scaler.fit_transform(X_raw))[:, 1]

    proba_matrix.append(proba)

    fpr,  tpr,  _ = roc_curve(df_new[label_col], proba)
    # precision_recall_curve returns (precision, recall, thresholds).
    prec, rec,  _ = precision_recall_curve(df_new[label_col], proba)

    # interpolate onto common grids (recall is decreasing, so reverse it)
    tpr_rows.append(np.interp(fpr_grid, fpr, tpr))
    prec_rows.append(np.interp(recall_grid, rec[::-1], prec[::-1]))

    # long-format rows for CSV
    per_model_long.append(pd.DataFrame({
        id_col:    df_new[id_col],
        "y_true":  df_new[label_col],
        "y_proba": proba,
        "model_id": os.path.basename(path),
    }))

# Ensemble = unweighted mean of the fold models' probabilities.
proba_matrix = np.vstack(proba_matrix)
proba_avg    = proba_matrix.mean(axis=0)
pred_avg     = (proba_avg >= 0.5).astype(int)

# aggregate metrics
y_true     = df_new[label_col]
auroc      = roc_auc_score(y_true, proba_avg)
auprc      = average_precision_score(y_true, proba_avg)
accuracy   = accuracy_score(y_true, pred_avg)
prevalence = y_true.mean()
adj_auprc  = auprc / prevalence if prevalence > 0 else np.nan

metrics_df = pd.DataFrame([{
    "auroc":          auroc,
    "auprc":          auprc,
    "prevalence":     prevalence,
    "adjusted_auprc": adj_auprc,
    "accuracy":       accuracy,
    "n_models":       len(model_paths),
    "n_samples":      len(df_new),
}])

print(metrics_df)

tpr_mat  = np.vstack(tpr_rows)
prec_mat = np.vstack(prec_rows)

np.savez(
    f"{model_dir}/per_model_curves.npz",
    fpr_grid=fpr_grid,
    tpr_mat=tpr_mat,
    recall_grid=recall_grid,
    prec_mat=prec_mat,
)

fpr_ens,  tpr_ens,  _ = roc_curve(y_true, proba_avg)
# Unpack in the order scikit-learn returns them so the saved arrays are
# labelled correctly.
prec_ens, rec_ens,  _ = precision_recall_curve(y_true, proba_avg)
np.savez(
    f"{model_dir}/ensemble_curve.npz",
    fpr=fpr_ens, tpr=tpr_ens, recall=rec_ens, precision=prec_ens
)

pred_path_ensemble = f"{model_dir}/ensemble_predictions.csv"
pred_path_permdl   = f"{model_dir}/per_model_predictions.csv"
metrics_path       = f"{model_dir}/ensemble_metrics.csv"

pd.DataFrame({
    id_col:    df_new[id_col],
    "y_true":  y_true,
    "y_proba": proba_avg,
}).to_csv(pred_path_ensemble, index=False)

pd.concat(per_model_long, ignore_index=True).to_csv(pred_path_permdl, index=False)
metrics_df.to_csv(metrics_path, index=False)

print(f"Saved curves           → {model_dir}/per_model_curves.npz")
print(f"Saved ensemble curve   → {model_dir}/ensemble_curve.npz")
print(f"Saved ensemble preds   → {pred_path_ensemble}")
print(f"Saved per-model preds  → {pred_path_permdl}")
print(f"Saved metrics          → {metrics_path}")



# Train clinical (age+sex) model 

In [None]:
def nested_enet_cv(
    df: pd.DataFrame,
    id_col: str,
    label_col: str,
    output_dir: str,
    outer_splits: int = 5,
    inner_splits: int = 10,
    random_state: int = 42,
    l1_ratio: float = 0.9,
):
    """Nested cross-validation of an elastic-net logistic regression on age + sex.

    NOTE: this definition has the same name as the proteomics
    ``nested_enet_cv`` defined earlier in the notebook and replaces it once
    this cell runs; re-run the earlier cell before retraining proteomics
    models.

    Only the ``age`` and ``sex`` columns are used as features. ``sex`` is
    lower-cased and mapped ``female -> 0``, ``male -> 1``; any other value
    becomes missing and is filled by the pipeline's median imputer. For each
    outer stratified fold an inner stratified ``GridSearchCV`` (scored by ROC
    AUC) tunes ``C`` with ``l1_ratio`` held fixed; the refit best pipeline is
    evaluated on the held-out part and saved to ``output_dir``.

    Parameters
    ----------
    df : one row per sample; must contain ``age``, ``sex``, ``id_col`` and
        ``label_col``.
    id_col, label_col : names of the identifier and binary label columns.
    output_dir : directory (created if missing) receiving the per-fold
        ``.joblib`` models, ``nested_cv_curves.npz`` and the four CSV files.
    outer_splits, inner_splits : number of outer / inner folds.
    random_state : seed for both fold splitters and the classifier.
    l1_ratio : elastic-net mixing parameter of the classifier.

    Returns
    -------
    (metrics_df, best_params_df, coef_df_all, oof_df_all, curves) where
    ``curves`` is a dict with ``fpr_grid``, ``tpr_mat``, ``recall_grid`` and
    ``prec_mat`` (one matrix row per outer fold, on 101-point grids).

    Raises
    ------
    KeyError : if ``df`` lacks the ``age`` or ``sex`` column.
    """
    os.makedirs(output_dir, exist_ok=True)

    if not {"age", "sex"}.issubset(df.columns):
        raise KeyError("Input df must contain 'age' and 'sex'.")
    X = (
        df[["age", "sex"]].copy()
          .assign(sex=lambda d: d["sex"]
                              .str.lower()
                              .map({"female": 0, "male": 1}))
    )
    y = df[label_col]
    ids = df[id_col]

    pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("clf", LogisticRegression(
            penalty="elasticnet", solver="saga",
            l1_ratio=l1_ratio, max_iter=20_000,
            random_state=random_state, n_jobs=-1,
        ))
    ])
    param_grid = {"clf__C": 10 ** np.linspace(-4, 4, 15)}

    outer_cv = StratifiedKFold(
        n_splits=outer_splits, shuffle=True, random_state=random_state
    )

    metrics_records, param_records, coef_records, oof_records = [], [], [], []

    # Common grids so per-fold curves can be stacked and averaged.
    fpr_grid = np.linspace(0.0, 1.0, 101)
    recall_grid = np.linspace(0.0, 1.0, 101)
    tpr_rows, prec_rows = [], []        # one row per outer fold

    # outer fold cross validation
    for fold, (train_idx, test_idx) in enumerate(
        tqdm(list(outer_cv.split(X, y)), desc="Outer CV folds"), start=1
    ):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        inner_cv = StratifiedKFold(
            n_splits=inner_splits, shuffle=True, random_state=random_state
        )
        grid = GridSearchCV(
            estimator=pipe, param_grid=param_grid,
            scoring="roc_auc", cv=inner_cv,
            n_jobs=-1, verbose=0,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            grid.fit(X_train, y_train)

        best_model = grid.best_estimator_
        best_params = grid.best_params_

        y_pred = best_model.predict(X_test)
        y_proba = best_model.predict_proba(X_test)[:, 1]

        metrics_records.append({
            "fold":       fold,
            "n_test":     len(test_idx),        # size of the held-out fold
            "auroc":      roc_auc_score(y_test, y_proba),
            "auprc":      average_precision_score(y_test, y_proba),
            "accuracy":   accuracy_score(y_test, y_pred),
            "prevalence": y_test.mean(),
        })
        param_records.append({"fold": fold, "C": best_params["clf__C"]})

        coef = best_model.named_steps["clf"].coef_.flatten()
        coef_records.append(pd.DataFrame({
            "feature":  X.columns,
            "coef":     coef,
            "abs_coef": np.abs(coef),
            "fold":     fold,
        }))

        oof_records.append(pd.DataFrame({
            id_col:   ids.iloc[test_idx].values,
            "fold":   fold,
            "y_true": y_test.values,
            "y_proba": y_proba,
            "y_pred":  y_pred,
        }))

        fpr, tpr, _ = roc_curve(y_test, y_proba)
        # precision_recall_curve returns (precision, recall, thresholds) in
        # that order; recall is decreasing, hence the reversal for np.interp.
        prec, rec, _ = precision_recall_curve(y_test, y_proba)
        tpr_rows.append(np.interp(fpr_grid, fpr, tpr))
        prec_rows.append(np.interp(recall_grid, rec[::-1], prec[::-1]))

        # Save model
        joblib.dump(
            best_model,
            os.path.join(output_dir, f"enet_nested_best_fold{fold}.joblib"),
        )

    metrics_df = pd.DataFrame(metrics_records)
    best_params_df = pd.DataFrame(param_records)
    coef_df_all = pd.concat(coef_records, ignore_index=True)
    oof_df_all = pd.concat(oof_records, ignore_index=True)

    tpr_mat = np.vstack(tpr_rows)      # (outer_splits × 101)
    prec_mat = np.vstack(prec_rows)
    np.savez(
        os.path.join(output_dir, "nested_cv_curves.npz"),
        fpr_grid=fpr_grid,
        tpr_mat=tpr_mat,
        recall_grid=recall_grid,
        prec_mat=prec_mat,
    )

    metrics_df.to_csv(os.path.join(output_dir, "nested_cv_metrics.csv"), index=False)
    best_params_df.to_csv(os.path.join(output_dir, "nested_best_params.csv"), index=False)
    coef_df_all.to_csv(os.path.join(output_dir, "nested_coefficients.csv"), index=False)
    oof_df_all.to_csv(os.path.join(output_dir, "nested_oof_predictions.csv"), index=False)

    curves = dict(fpr_grid=fpr_grid, tpr_mat=tpr_mat,
                  recall_grid=recall_grid, prec_mat=prec_mat)

    return metrics_df, best_params_df, coef_df_all, oof_df_all, curves




# Train the clinical (age + sex) model with the function defined just above.
metrics_df, best_params_df, coef_df_all, oof_df_all, curves_nested = nested_enet_cv(
    clinical_model_input,
    id_col="sample_id",
    label_col="responder",
    output_dir="/path/to/clinical/model"
)

# AUPRC relative to the fold's positive-class prevalence.
metrics_df["adjusted_auprc"] = metrics_df["auprc"] / metrics_df["prevalence"]

# Summary row with the across-fold mean of each metric. Columns not listed
# here (e.g. n_test) are left missing in that row.
mean_row = {
    "fold": "mean",
    **{
        col: metrics_df[col].mean()
        for col in ("auroc", "auprc", "accuracy", "prevalence", "adjusted_auprc")
    },
}

metrics_df = pd.concat([metrics_df, pd.DataFrame([mean_row])], ignore_index=True)
metrics_df

In [None]:
# paths
model_dir   = "/path/to/clinical/model"
model_paths = sorted(glob.glob(f"{model_dir}/enet_nested_best_fold*.joblib"))
id_col      = "sample_id"
label_col   = "responder"

# Fail early with a clear message instead of an opaque np.vstack error below.
if not model_paths:
    raise FileNotFoundError(
        f"No 'enet_nested_best_fold*.joblib' models found in {model_dir}"
    )

df_new = external_df_cleaned_clinical
assert label_col in df_new.columns, "New data must contain ground-truth labels"

# Common grids so the per-model curves can be stacked into matrices.
fpr_grid    = np.linspace(0.0, 1.0, 101)
recall_grid = np.linspace(0.0, 1.0, 101)
tpr_rows, prec_rows = [], []

proba_matrix   = []
per_model_long = []

for path in model_paths:
    # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
    model    = joblib.load(path)
    features = model.feature_names_in_
    # NOTE(review): the clinical training function encodes sex as 0/1 before
    # fitting; df_new must already carry the same encoding. Confirm upstream.
    X_new    = df_new[features]

    proba = model.predict_proba(X_new)[:, 1]
    proba_matrix.append(proba)

    fpr,  tpr,  _ = roc_curve(df_new[label_col], proba)
    # precision_recall_curve returns (precision, recall, thresholds).
    prec, rec,  _ = precision_recall_curve(df_new[label_col], proba)

    # interpolate onto common grids (recall is decreasing, so reverse it)
    tpr_rows.append(np.interp(fpr_grid, fpr, tpr))
    prec_rows.append(np.interp(recall_grid, rec[::-1], prec[::-1]))

    # long-format rows
    per_model_long.append(pd.DataFrame({
        id_col:   df_new[id_col],
        "y_true": df_new[label_col],
        "y_proba": proba,
        "model_id": os.path.basename(path),
    }))

# Ensemble = unweighted mean of the fold models' probabilities.
proba_matrix = np.vstack(proba_matrix)
proba_avg    = proba_matrix.mean(axis=0)
pred_avg     = (proba_avg >= 0.5).astype(int)

y_true     = df_new[label_col]
auroc      = roc_auc_score(y_true, proba_avg)
auprc      = average_precision_score(y_true, proba_avg)
accuracy   = accuracy_score(y_true, pred_avg)
prevalence = y_true.mean()
adj_auprc  = auprc / prevalence if prevalence > 0 else np.nan

metrics_df = pd.DataFrame([{
    "auroc":          auroc,
    "auprc":          auprc,
    "prevalence":     prevalence,
    "adjusted_auprc": adj_auprc,
    "accuracy":       accuracy,
    "n_models":       len(model_paths),
    "n_samples":      len(df_new),
}])

print(metrics_df)

tpr_mat  = np.vstack(tpr_rows)      # shape: (n_models × 101)
prec_mat = np.vstack(prec_rows)

# Keep the curve paths in variables so the log lines below report the files
# that are actually written.
per_model_curves_path = f"{model_dir}/per_model_curves_external_clinical.npz"
ensemble_curve_path   = f"{model_dir}/ensemble_curve_external_clinical.npz"

np.savez(
    per_model_curves_path,
    fpr_grid=fpr_grid, tpr_mat=tpr_mat,
    recall_grid=recall_grid, prec_mat=prec_mat,
)

fpr_ens,  tpr_ens,  _ = roc_curve(y_true, proba_avg)
# Unpack in the order scikit-learn returns them so the saved arrays are
# labelled correctly.
prec_ens, rec_ens,  _ = precision_recall_curve(y_true, proba_avg)
np.savez(
    ensemble_curve_path,
    fpr=fpr_ens, tpr=tpr_ens, recall=rec_ens, precision=prec_ens
)

pred_path_ensemble = f"{model_dir}/ensemble_predictions.csv"
pred_path_permdl   = f"{model_dir}/per_model_predictions.csv"
metrics_path       = f"{model_dir}/ensemble_metrics.csv"

pd.DataFrame({
    id_col:   df_new[id_col],
    "y_true": y_true,
    "y_proba": proba_avg,
}).to_csv(pred_path_ensemble, index=False)

pd.concat(per_model_long, ignore_index=True).to_csv(pred_path_permdl, index=False)
metrics_df.to_csv(metrics_path, index=False)

print(f"Saved per-model curves   → {per_model_curves_path}")
print(f"Saved ensemble curve     → {ensemble_curve_path}")
print(f"Saved ensemble preds     → {pred_path_ensemble}")
print(f"Saved per-model preds    → {pred_path_permdl}")
print(f"Saved metrics            → {metrics_path}")


In [None]:
# for visualization
def _load(csv_or_buf):
    df = pd.read_csv(csv_or_buf)
    return df["y_true"].values, df["y_proba"].values


def plot_curves(
    csv_files,
    *,
    kind: str = "roc",
    title: str | None = None,
    ax=None,
    grid_kwargs: dict | None = None,
    legend_kwargs: dict | None = None,
    save: bool = True,
    out_dir: str | os.PathLike = (
        "/path/for/figures"
    ),
    filename: str | None = None,
):
    """Overlay ROC or precision-recall curves from prediction CSV files.

    Parameters
    ----------
    csv_files : iterable of ``(label, path)`` pairs; each CSV must have
        ``y_true`` and ``y_proba`` columns.
    kind : ``"roc"`` draws ROC curves; any other value draws
        precision-recall curves.
    title : axes title; a default is chosen from ``kind`` when omitted.
    ax : existing axes to draw on. When omitted a new 5x5 figure is created.
    grid_kwargs, legend_kwargs : passed to ``ax.grid`` / ``ax.legend``.
    save : write a PDF to ``out_dir``; only honoured when this function
        created the figure.
    out_dir : directory for the PDF (created if missing).
    filename : PDF file name; derived from ``title``/``kind`` when omitted.

    Returns
    -------
    ``(fig, ax)`` when the figure was created here, otherwise ``ax``.
    """
    grid_kwargs = grid_kwargs or {
        "which": "both",
        "linestyle": "--",
        "linewidth": 0.5,
        "alpha": 0.5,
    }
    legend_kwargs = legend_kwargs or {"loc": "lower right", "frameon": False}

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(5, 5))
    else:
        fig = ax.figure  # grab parent figure

    is_roc = kind == "roc"
    # Prevalence of the most recently loaded file; stays None when csv_files
    # is empty so the baseline is skipped instead of raising NameError.
    baseline = None

    # ── curves ──────────────────────────────────────────────────────
    for label, path in csv_files:
        y_true, y_proba = _load(path)

        if is_roc:
            fpr, tpr, _ = roc_curve(y_true, y_proba)
            score = auc(fpr, tpr)
            ax.plot(fpr, tpr, lw=1.6, label=f"{label} (AUC {score:.3f})")
        else:
            # precision_recall_curve returns (precision, recall, thresholds).
            prec, rec, _ = precision_recall_curve(y_true, y_proba)
            score = average_precision_score(y_true, y_proba)
            ax.plot(rec, prec, lw=1.6, label=f"{label} (AUPRC {score:.3f})")
            baseline = y_true.mean()

    # ── cosmetics ───────────────────────────────────────────────────
    if is_roc:
        ax.plot([0, 1], [0, 1], "--", color="grey", lw=1)
        ax.set(xlabel="False-positive rate", ylabel="True-positive rate")
    else:
        # baseline = prevalence (of the last file plotted)
        if baseline is not None:
            ax.hlines(baseline, 0, 1, ls="--", color="grey", lw=1)
        ax.set(xlabel="Recall", ylabel="Precision")

    ax.set_title(
        title or ("ROC curve" if is_roc else "Precision-Recall curve")
    )
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid(True, **grid_kwargs)
    ax.legend(**legend_kwargs)

    if not created_fig:
        return ax

    # Finalise the layout before saving so the PDF matches what is shown.
    fig.tight_layout()

    if save:
        os.makedirs(out_dir, exist_ok=True)

        def _slug(s: str) -> str:
            # Make a title safe to use as a file name.
            return (
                str(s)
                .strip()
                .replace(" ", "_")
                .replace("/", "_")
                .replace("\\", "_")
            )

        if filename is None:
            base = title or kind.upper()
            filename = f"{_slug(base)}_{kind}.pdf"

        pdf_path = os.path.join(out_dir, filename)
        fig.savefig(pdf_path, format="pdf", bbox_inches="tight")
        print(f"✅  Figure saved → {pdf_path}")

    plt.show()
    return fig, ax




# ── customise these paths ───────────────────────────────────────────────
# Out-of-fold predictions written by the nested CV runs for each model.
csv_protein = "/path/to/RABIT-trained/model/nested_oof_predictions.csv"
csv_clinic  = "/path/to/clinical/model/nested_oof_predictions.csv"

# (legend label, predictions CSV) pairs, in plotting order.
pairs = [("RABIT-trained Model", csv_protein),
         ("Age + Sex Model", csv_clinic)]

# ── draw the plots ──────────────────────────────────────────────────────
# The ROC figure is saved as a PDF (plot_curves default save=True).
plot_curves(pairs, kind="roc", title="AUROC comparison on RABIT cohort")

# The PR figure is only displayed, not written to disk.
plot_curves(pairs, kind="pr",  title="AUPRC comparison on RABIT cohort", save=False)


In [None]:
# ── customise these paths ───────────────────────────────────────────────
# Ensemble predictions of the RABIT-trained and clinical fold models on the
# external data, and the external model's own out-of-fold predictions.
csv_rabit = "/path/to/RABIT-trained/model/ensemble_predictions.csv"
csv_clinic  = "/path/to/clinical/model/ensemble_predictions.csv"
csv_external  = "/path/to/external-trained/model/nested_oof_predictions.csv"

# (legend label, predictions CSV) pairs, in plotting order.
pairs = [("RABIT-trained Model", csv_rabit),
         ("Age + Sex Model", csv_clinic),
         ("External-trained Model", csv_external)]

# ── draw the plots ──────────────────────────────────────────────────────
# Only the ROC figure is written to disk; the PR figure is display-only.
plot_curves(pairs, kind="roc", title="AUROC comparison on External cohort")
plot_curves(pairs, kind="pr",  title="AUPRC comparison on External cohort", save=False)


# Logistic Regression Analysis

In [None]:
# convert coefficients into odds ratios
coef_df = pd.read_csv("/path/to/RABIT-trained/model/nested_coefficients.csv")

# Per feature: mean coefficient across outer folds, its magnitude, and
# exp(mean coefficient) as the odds ratio.
avg_coef_df = (
    coef_df
    .groupby("feature", as_index=False)["coef"]
    .mean()
    .rename(columns={"coef": "avg_coef"})
    .assign(
        abs_coef=lambda d: d["avg_coef"].abs(),
        odds_ratio=lambda d: np.exp(d["avg_coef"]),
    )
)

# Largest odds ratio first, with a fresh 0..n-1 index.
coef_df_sorted = avg_coef_df.sort_values(by="odds_ratio", ascending=False).reset_index(drop=True)
coef_df_sorted


In [None]:
# for visualization
def plot_odds_ratios(
    df: pd.DataFrame,
    *,
    top_n: int = 10,
    print_n: int = 10,
    save: bool = True,
    out_dir: str | os.PathLike = (
        "/path/to/figures"
    ),
    filename: str | None = None,
):
    """
    Bar-plot the most extreme odds ratios and print the extreme feature names.

    Bars start at 1 and extend to each odds ratio, so values below 1 hang
    downwards (red) and values at or above 1 rise (steel blue).

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the columns 'feature' and 'odds_ratio'.
    top_n : int, default 10
        Rows taken from each end of the odds-ratio ranking. If ``2 * top_n``
        is not in ``(0, len(df)]`` every row is plotted instead.
    print_n : int, default 10
        Number of feature names printed for each end of the ranking.
    save : bool, default True
        Write the figure as a PDF into ``out_dir`` (created if missing).
    out_dir : str or os.PathLike
        Destination directory for the PDF.
    filename : str or None
        PDF file name; defaults to ``odds_ratio_extremes_top{top_n}.pdf``.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)

    Raises
    ------
    KeyError
        If 'feature' or 'odds_ratio' is missing from ``df``.
    """
    required = {"feature", "odds_ratio"}
    if not required.issubset(df.columns):
        raise KeyError(f"DataFrame must include columns {required}")

    # Rank once (ascending); both the plot subset and the printout reuse it.
    ranked = df.sort_values("odds_ratio")
    use_extremes = 0 < top_n * 2 <= len(ranked)
    if use_extremes:
        shown = pd.concat([ranked.head(top_n), ranked.tail(top_n)])
    else:
        shown = ranked
    shown = shown.sort_values("odds_ratio").reset_index(drop=True)

    ratios = shown["odds_ratio"]
    n_bars = len(shown)

    # ── plotting ─────────────────────────────────────────────────────
    # Figure width grows with the number of bars so tick labels stay legible.
    fig, ax = plt.subplots(figsize=(0.55 * n_bars + 2, 6))
    positions = np.arange(n_bars)
    bar_colors = ["red" if ratio < 1 else "steelblue" for ratio in ratios]

    ax.bar(
        positions,
        ratios - 1,          # bar height measured from the OR = 1 baseline
        bottom=1,
        color=bar_colors,
        edgecolor="black",
        width=0.8,
    )
    ax.axhline(1, color="gray", linestyle="--", linewidth=1)

    ax.set_xticks(positions)
    ax.set_xticklabels(shown["feature"], rotation=90, ha="center", fontsize=8)
    ax.set_xlabel("Proteins (ordered by odds ratio)")

    # Always keep at least the 0.8–1.2 band visible around the baseline.
    lower = min(0.8, ratios.min() * 0.95)
    upper = max(1.2, ratios.max() * 1.05)
    ax.set_ylim(lower, upper)
    ax.set_ylabel("Odds ratio")
    ax.set_title(f"{n_bars} extreme proteins (±{top_n}) by odds ratio")

    fig.tight_layout()

    if save:
        os.makedirs(out_dir, exist_ok=True)
        if filename is None:
            filename = f"odds_ratio_extremes_top{top_n}.pdf"
        pdf_path = os.path.join(out_dir, filename)
        fig.savefig(pdf_path, format="pdf", bbox_inches="tight")
        print(f"✅  Figure saved → {pdf_path}")

    plt.show()

    # ── print extremes ───────────────────────────────────────────────
    lowest_names = ranked.head(print_n)["feature"].tolist()
    highest_names = ranked.tail(print_n)["feature"].tolist()

    print(f"\n{print_n} smallest odds ratios:")
    print(", ".join(lowest_names))

    print(f"\n{print_n} largest odds ratios:")
    print(", ".join(highest_names))

    return fig, ax


plot_odds_ratios(coef_df_sorted, top_n=10, print_n=10)


In [None]:
# Same averaging as for the RABIT-trained model, applied to the
# external-trained model's coefficient table: mean coefficient per feature,
# its magnitude, and its odds ratio (exp of the mean coefficient).
coef_df_external = pd.read_csv("/path/to/external-trained/model/nested_coefficients.csv")

avg_coef_df_external = (
    coef_df_external.groupby("feature", as_index=False)["coef"]
    .mean()
    .rename(columns={"coef": "avg_coef"})
    .assign(
        abs_coef=lambda frame: frame["avg_coef"].abs(),
        odds_ratio=lambda frame: np.exp(frame["avg_coef"]),
    )
)

# Largest odds ratio first, re-indexed 0..n-1.
coef_df_sorted_external = avg_coef_df_external.sort_values(
    by="odds_ratio", ascending=False
).reset_index(drop=True)

# Persist the ranked table next to the notebook, then plot it.
coef_df_sorted_external.to_csv('./external_prot_features.csv', index=False)
plot_odds_ratios(coef_df_sorted_external, top_n=10, print_n=10)


In [None]:
bmi_df = pd.read_csv('./ehr_measurements_tables/bmi_3038553.csv')
bmi_df

In [None]:
antitnf_patients = pd.read_csv('./labels/rabit_antitnf_labels.csv')
antitnf_patients


In [None]:
def build_patient_measurement_dict(labeldf: pd.DataFrame,
                                   measdf: pd.DataFrame) -> dict:
    """
    For each patient in `labeldf`, grab their prediction_time and all measurements
    from `measdf`, sorted by measurement_DATE, collecting each column as a list.
    At the end, prints:
      - count of patients with ≥1 measurement
      - count of patients with ≥5 measurements
      - summary statistics of how many measurements each patient has.

    The measurement table is filtered, sorted and grouped once, instead of
    being scanned and re-sorted separately for every patient.

    Parameters
    ----------
    labeldf : pd.DataFrame
        Must contain columns ['patient_id', 'prediction_time'].
    measdf : pd.DataFrame
        Must contain columns [
            'person_id',
            'measurement_concept_id',
            'measurement_DATE',
            'value_as_number',
            'value_as_concept_id',
            'measurement_source_value',
            'measurement_source_concept_id',
            'unit_source_value'
        ].

    Returns
    -------
    dict
        Keys are patient_id (from labeldf). Values are dicts with:
            - 'prediction_time': single value (if a patient has several
              distinct prediction times, the last one in `labeldf` wins)
            - one entry per measurement column, where the value is a list
              of all that patient’s entries (in date order; measurements
              sharing a date keep their order from `measdf`).
        Patients without measurements get empty lists.

    Raises
    ------
    KeyError
        If a required column is missing from either frame.
    """
    value_cols = [
        'measurement_concept_id',
        'measurement_DATE',
        'value_as_number',
        'value_as_concept_id',
        'measurement_source_value',
        'measurement_source_concept_id',
        'unit_source_value',
    ]

    # Subset the columns we care about
    label_sub = labeldf[['patient_id', 'prediction_time']].drop_duplicates()
    meas_sub = measdf[['person_id'] + value_cols]

    # Keep only labelled patients, then sort once. A stable sort makes the
    # order of same-date measurements deterministic (input order).
    meas_sub = meas_sub[meas_sub['person_id'].isin(label_sub['patient_id'])]
    meas_sorted = meas_sub.sort_values(by='measurement_DATE', kind='stable')

    # groupby keeps the (date-sorted) row order inside each group.
    per_patient = {
        pid: {col: grp[col].tolist() for col in value_cols}
        for pid, grp in meas_sorted.groupby('person_id', sort=False)
    }

    # Build the nested dict
    result = {}
    for pid, pred_time in zip(label_sub['patient_id'], label_sub['prediction_time']):
        columns = per_patient.get(pid)
        if columns is None:
            columns = {col: [] for col in value_cols}
        entry = {'prediction_time': pred_time}
        entry.update(columns)
        result[pid] = entry

    # Compute per-patient counts (explicit dtype so an empty cohort still
    # yields a numeric summary)
    counts = pd.Series(
        [len(entry['measurement_concept_id']) for entry in result.values()],
        name='num_measurements',
        dtype='int64',
    )

    # Print the requested statistics
    num_at_least_1 = (counts >= 1).sum()
    num_at_least_5 = (counts >= 5).sum()

    print(f"Patients with ≥1 measurement: {num_at_least_1}")
    print(f"Patients with ≥5 measurements: {num_at_least_5}")
    print("\nSummary of measurement counts per patient:")
    print(counts.describe())

    return result



patient_dict_all = build_patient_measurement_dict(antitnf_patients, bmi_df) # for bmi

patient_dict_all


In [None]:
## Subset result for rheumatoid arthritis patients
patient_dict = {k: patient_dict_all[k]          # keep the original value
            for k in ra_patients_with_antitnf                  # iterate over the keys you care about
            if k in patient_dict_all}      # guard against missing keys
patient_dict

In [None]:
print(len(patient_dict.keys()))

In [None]:
def plot_patient_time_series(patient_dict):
    """
    Overlay one line per patient (value_as_number against measurement date)
    on a single figure and print how many patients were drawn.

    Patients without any measurement date are skipped. With at most 20 lines
    a per-patient legend is shown; beyond that the count is annotated in the
    lower-right corner instead.

    Parameters
    ----------
    patient_dict : dict
        Nested dict from build_patient_measurement_dict(); every value must
        provide the lists 'measurement_DATE' and 'value_as_number'.
    """
    plt.figure(figsize=(12, 6))

    n_drawn = 0
    for patient_id, record in patient_dict.items():
        when = pd.to_datetime(record['measurement_DATE'])
        readings = record['value_as_number']
        if len(when) == 0:
            continue  # nothing recorded for this patient
        plt.plot(when, readings, label=str(patient_id), alpha=0.6)
        n_drawn += 1

    print(f"Total patients plotted: {n_drawn}")

    plt.xlabel("Measurement Date")
    plt.ylabel("Value as Number")
    plt.title("Patient Measurements Over Time")
    plt.gcf().autofmt_xdate()  # slant the date tick labels

    if n_drawn > 20:
        # A legend with this many entries is unreadable; state the count instead.
        plt.annotate(
            f"{n_drawn} patients",
            xy=(0.99, 0.01),
            xycoords='axes fraction',
            ha='right',
            va='bottom',
        )
    else:
        plt.legend(title="Patient ID", bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


plot_patient_time_series(patient_dict)

In [None]:

def count_extreme_values(patient_dict, low=10, high=70):
    """
    Tally measurements lying strictly outside [low, high] and the number of
    patients who have at least one such measurement.

    Values are coerced to numbers first; anything missing or non-numeric
    becomes NaN, which compares False against both bounds and is therefore
    never counted as extreme.

    Parameters
    ----------
    patient_dict : dict
        Nested dict as returned by build_patient_measurement_dict(); only the
        'value_as_number' list of each patient is read (a missing key is
        treated as an empty list).
    low : float
        Values below this are extreme.
    high : float
        Values above this are extreme.

    Returns
    -------
    tuple (int, int)
        (total_extreme_entries, num_patients_with_extremes)
    """
    extreme_entries = 0
    affected_patients = 0

    for record in patient_dict.values():
        numeric = pd.to_numeric(record.get('value_as_number', []), errors='coerce')
        outside = (numeric < low) | (numeric > high)
        n_outside = int(outside.sum())
        if n_outside > 0:
            affected_patients += 1
            extreme_entries += n_outside

    return extreme_entries, affected_patients


# Name the thresholds once so the printed summary always reports the bounds
# that were actually used (the message previously hard-coded 10 and 70).
# NOTE(review): 30/400 differ from the 10–70 range applied in the "For BMI"
# filtering cell — confirm these bounds are intended for this measurement.
extreme_low, extreme_high = 30, 400
ext_entries, ext_patients = count_extreme_values(patient_dict, low=extreme_low, high=extreme_high)
print(f"Entries with value_as_number < {extreme_low} or > {extreme_high}: {ext_entries}")
print(f"Patients with ≥1 such entry: {ext_patients}")


In [None]:
def filter_patient_measurements(patient_dict, low=10, high=70):
    """
    Return a copy of the nested patient dict in which only measurements with
    low <= value_as_number <= high survive.

    The surviving positions are computed from 'value_as_number' and then
    applied to every other list of the patient, so all lists stay aligned.
    Values that cannot be compared with the bounds (e.g. None or strings) and
    NaN are dropped. The input dict is not modified.

    Parameters
    ----------
    patient_dict : dict
        Nested dict as returned by build_patient_measurement_dict(). Every
        patient needs a 'prediction_time'; all other entries must be lists
        of the same length as 'value_as_number'.
    low : float
        Lower bound (inclusive).
    high : float
        Upper bound (inclusive).

    Returns
    -------
    dict
        Same structure as the input, 'prediction_time' first, with
        out-of-range entries removed.
    """
    def _within_bounds(value):
        # Any failure while comparing (wrong type, ambiguous truth value)
        # means the value is not a usable number: treat it as out of range.
        try:
            return bool(low <= value <= high)
        except Exception:
            return False

    cleaned = {}
    for patient_id, record in patient_dict.items():
        kept = [
            position
            for position, value in enumerate(record.get('value_as_number', []))
            if _within_bounds(value)
        ]

        trimmed = {'prediction_time': record['prediction_time']}
        for field, sequence in record.items():
            if field != 'prediction_time':
                trimmed[field] = [sequence[position] for position in kept]

        cleaned[patient_id] = trimmed

    return cleaned



In [None]:
# For BMI
filtered_patient_dict = filter_patient_measurements(patient_dict, low=10, high=70)
ext_entries, ext_patients = count_extreme_values(filtered_patient_dict, low=10, high=70)
print(f"Entries with value_as_number < 10 or > 70: {ext_entries}")
print(f"Patients with ≥1 such entry: {ext_patients}")

In [None]:
plot_patient_time_series(filtered_patient_dict)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.dates as mdates

def plot_time_series_by_response(patient_dict: dict,
                                 responsedf: pd.DataFrame,
                                 normalize: bool = False):
    """
    Plot each patient’s value_as_number over time, coloring responders blue
    and non-responders red, with a best-fit trend line per group.

    The per-patient lines and the group trend lines are built from the same
    prepared series, so both always use identical values: non-numeric entries
    become NaN, and normalization divides by the patient's first non-missing
    value (left un-normalized if that value is 0).

    Parameters
    ----------
    patient_dict : dict
        Nested dict from build_patient_measurement_dict(), keyed by patient_id.
    responsedf : pd.DataFrame
        Must contain ['person_id', 'responder'] where responder is True/False.
        Only real booleans are recognised; patients whose status is missing
        or encoded differently (e.g. 0/1) are not drawn.
    normalize : bool, default False
        If True, divide each patient’s values by their first non-missing
        value before plotting.
    """
    # map patient → True/False
    resp_map = dict(zip(responsedf['person_id'], responsedf['responder']))
    responders    = [pid for pid in patient_dict if resp_map.get(pid) is True]
    nonresponders = [pid for pid in patient_dict if resp_map.get(pid) is False]
    groups = [(responders, 'blue'), (nonresponders, 'red')]

    def _prepare(pid):
        """Return (dates, float values) for one patient, normalized on request."""
        dates = pd.to_datetime(patient_dict[pid]['measurement_DATE'])
        vals = np.asarray(
            pd.to_numeric(patient_dict[pid]['value_as_number'], errors='coerce'),
            dtype=float,
        )
        if normalize:
            observed = vals[~np.isnan(vals)]
            # A NaN or zero baseline would blank or blow up the whole series.
            if observed.size and observed[0] != 0:
                vals = vals / observed[0]
        return dates, vals

    # Prepare every series once; lines and trend fits both read from here.
    series = {pid: _prepare(pid) for pids, _ in groups for pid in pids}

    plt.figure(figsize=(12, 6))

    # plot individual lines
    for pids, color in groups:
        for pid in pids:
            dates, vals = series[pid]
            if vals.size:
                plt.plot(dates, vals, color=color, alpha=0.3)

    # helper to fit & plot trend line
    def _fit_and_plot(pids, color):
        x_nums, y_vals = [], []
        for pid in pids:
            dates, vals = series[pid]
            # Drop points with a missing value or an unparseable date (NaT).
            keep = ~np.isnan(vals) & ~pd.isna(dates)
            if keep.any():
                x_nums.append(mdates.date2num(dates[keep]))
                y_vals.append(vals[keep])

        if not x_nums:
            return
        x_all = np.concatenate(x_nums)
        y_all = np.concatenate(y_vals)
        if np.unique(x_all).size < 2:
            return  # a line cannot be fitted through a single date
        slope, intercept = np.polyfit(x_all, y_all, 1)
        x_fit = np.linspace(x_all.min(), x_all.max(), 100)
        y_fit = slope * x_fit + intercept
        dates_fit = mdates.num2date(x_fit)
        plt.plot(dates_fit, y_fit, color=color, linewidth=2, label=f"{color.title()} trend")

    # plot trend lines
    for pids, color in groups:
        _fit_and_plot(pids, color)

    # finalize
    plt.xlabel("Measurement Date")
    ylabel = "Normalized Value" if normalize else "Value as Number"
    plt.ylabel(ylabel)
    title = "Patient Measurements Over Time by Responder Status"
    if normalize:
        title += " (normalized)"
    plt.title(title)
    plt.gcf().autofmt_xdate()

    # dummy lines for the legend
    plt.plot([], [], color='blue',  alpha=0.3, label="Responder")
    plt.plot([], [], color='red',   alpha=0.3, label="Non-responder")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


plot_time_series_by_response(filtered_patient_dict, responder_df, normalize=True)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_violin_by_response(patient_dict: dict,
                            responsedf: pd.DataFrame,
                            normalize: bool = False):
    """
    Create a violin plot for the first and last measurements of each patient,
    split by responder status, with paired dots and connecting lines.

    Missing (NaN) values are ignored when picking a patient's first and last
    measurement, and a group without any patient is left out of the violin
    drawing instead of making it fail.

    Parameters
    ----------
    patient_dict : dict
        Nested dict of patient measurements as from build_patient_measurement_dict().
    responsedf : pd.DataFrame
        DataFrame with columns ['person_id', 'responder'] (True/False).
        Patients absent from it are skipped.
    normalize : bool, default False
        If True, divide each patient’s values by their first non-missing
        measurement before plotting (skipped when that measurement is 0).

    Returns
    -------
    None
        The figure is shown; if no patient qualifies a message is printed
        and nothing is drawn.
    """
    # Build lookup of responder status
    resp_map = dict(zip(responsedf['person_id'], responsedf['responder']))

    # Containers for first/last values
    resp_first, resp_last = [], []
    nonr_first, nonr_last = [], []

    for pid, record in patient_dict.items():
        if pid not in resp_map:
            continue
        raw_vals = np.array(record['value_as_number'], dtype=float)
        # NaNs would poison both the normalization and the density estimate.
        raw_vals = raw_vals[~np.isnan(raw_vals)]
        if raw_vals.size == 0:
            continue
        # normalize if requested (a zero baseline cannot be divided by)
        if normalize and raw_vals[0]:
            vals = raw_vals / raw_vals[0]
        else:
            vals = raw_vals
        first, last = vals[0], vals[-1]
        if resp_map[pid]:
            resp_first.append(first)
            resp_last.append(last)
        else:
            nonr_first.append(first)
            nonr_last.append(last)

    # Prepare for plotting
    groups = ['Resp First', 'Resp Last', 'NonR First', 'NonR Last']
    samples = [resp_first, resp_last, nonr_first, nonr_last]
    positions = [0, 1, 3, 4]
    colors = ['lightblue', 'lightblue', 'salmon', 'salmon']

    # violinplot raises on an empty dataset, so only pass populated groups.
    populated = [i for i, sample in enumerate(samples) if len(sample) > 0]
    if not populated:
        print("No labelled patients with measurements; nothing to plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    violins = ax.violinplot(
        [samples[i] for i in populated],
        positions=[positions[i] for i in populated],
        widths=0.8,
        showmedians=True,
    )

    # Color violins (bodies come back in the order the datasets were passed)
    for body, i in zip(violins['bodies'], populated):
        body.set_facecolor(colors[i])
        body.set_edgecolor('black')
        body.set_alpha(0.5)

    # Paired lines & dots for responders
    for first, last in zip(resp_first, resp_last):
        ax.plot([0, 1], [first, last], color='blue', alpha=0.3)
    ax.scatter([0]*len(resp_first), resp_first, color='blue', edgecolor='black')
    ax.scatter([1]*len(resp_last), resp_last, color='blue', edgecolor='black')

    # Paired lines & dots for non-responders
    for first, last in zip(nonr_first, nonr_last):
        ax.plot([3, 4], [first, last], color='red', alpha=0.3)
    ax.scatter([3]*len(nonr_first), nonr_first, color='red', edgecolor='black')
    ax.scatter([4]*len(nonr_last), nonr_last, color='red', edgecolor='black')

    # Labels & title
    ax.set_xticks(positions)
    ax.set_xticklabels(groups)
    ylabel = 'Normalized Value' if normalize else 'Value as Number'
    ax.set_ylabel(ylabel)
    subtitle = " (normalized)" if normalize else ""
    ax.set_title(f'First vs Last Measurement by Responder Status{subtitle}')
    plt.tight_layout()
    plt.show()


plot_violin_by_response(filtered_patient_dict, responder_df, normalize=True)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_single_patient_time_series(patient_dict: dict,
                                    responsedf: pd.DataFrame,
                                    patient_id,
                                    normalize: bool = False):
    """
    Draw one patient's value_as_number against measurement date.

    The line is blue for responders and red otherwise; a patient that does
    not appear in `responsedf` is treated as a non-responder.

    Parameters
    ----------
    patient_dict : dict
        Nested dict from build_patient_measurement_dict(), keyed by patient_id.
        Must contain `patient_id` (KeyError otherwise).
    responsedf : pd.DataFrame
        Must contain ['person_id', 'responder'] where responder is True/False.
    patient_id : hashable
        The identifier of the patient you want to plot.
    normalize : bool, default False
        If True, divide the values by the first measurement, provided that
        measurement is neither NaN nor 0.
    """
    # responder status decides colour and legend text
    status_by_patient = dict(zip(responsedf['person_id'], responsedf['responder']))
    is_resp = status_by_patient.get(patient_id, False)
    if is_resp:
        color, status = 'blue', 'responder'
    else:
        color, status = 'red', 'non-responder'
    label = f"{patient_id} ({status})"

    # pull the patient's series
    record = patient_dict[patient_id]
    dates = pd.to_datetime(record['measurement_DATE'])
    raw_vals = np.array(record['value_as_number'], dtype=float)

    vals = raw_vals
    if normalize and raw_vals.size:
        baseline = raw_vals[0]
        if not np.isnan(baseline) and baseline != 0:
            vals = raw_vals / baseline

    # draw
    plt.figure(figsize=(8, 4))
    plt.plot(dates, vals, color=color, alpha=0.7, linewidth=2, label=label)

    # annotate
    plt.xlabel("Measurement Date")
    plt.ylabel("Normalized Value" if normalize else "Value as Number")
    plt.title(f"Patient {patient_id} Time Series")
    plt.gcf().autofmt_xdate()
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()


# single patient, raw values
plot_single_patient_time_series(patient_dict, responder_df, patient_id=30912777)

# single patient, normalized to first measurement
plot_single_patient_time_series(patient_dict, responder_df, patient_id=30912777, normalize=True)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.dates as mdates

def plot_time_series_by_response_with_drug_window(
        patient_dict: dict,
        responsedf: pd.DataFrame,
        drugdict: dict,
        normalize: bool = False):
    """
    Plot value_as_number over time per patient, restricted to the span between
    the patient's earliest and latest drug administration, plus one linear
    trend line per responder group.

    Patients without any parseable drug administration, or without a
    non-missing measurement inside their window, are skipped. With
    normalize=True each patient's values are divided by their first
    non-missing measurement inside the window (left as-is if that value is 0).

    Parameters
    ----------
    patient_dict : dict
        Nested dict from build_patient_measurement_dict(), keyed by patient_id.
    responsedf : pd.DataFrame
        Must contain ['person_id', 'responder'] where responder is True/False.
    drugdict : dict
        {
          patient_id1: {
            'all_visits': [...],
            'drug_administrations': [ts1, ts2, …]
          },
          ...
        }
    normalize : bool, default False
        If True, normalize each patient’s values to the first value within
        the drug-administration window.
    """
    # Responder lookup; only literal True / False assign a patient to a group.
    status_of = dict(zip(responsedf['person_id'], responsedf['responder']))
    responders    = [pid for pid in patient_dict if status_of.get(pid) is True]
    nonresponders = [pid for pid in patient_dict if status_of.get(pid) is False]

    plt.figure(figsize=(12, 6))

    def _get_drug_window(pid):
        """Earliest and latest administration for pid, or (None, None)."""
        administrations = drugdict.get(pid, {}).get('drug_administrations', [])
        stamps = pd.to_datetime(pd.Series(administrations), errors='coerce').dropna()
        if stamps.empty:
            return None, None
        return stamps.min(), stamps.max()

    def _strict_floats(seq):
        # Used for the individual lines: raises on non-numeric entries.
        return np.array(seq, dtype=float)

    def _coerced_numbers(seq):
        # Used for the trend fit: non-numeric entries become NaN.
        return pd.to_numeric(seq, errors='coerce')

    def _points_in_window(pid, to_numbers):
        """Return (dates, values) inside pid's drug window, or None to skip."""
        start, end = _get_drug_window(pid)
        if start is None:
            return None

        dates = pd.to_datetime(patient_dict[pid]['measurement_DATE'])
        raw = to_numbers(patient_dict[pid]['value_as_number'])

        inside = (dates >= start) & (dates <= end)
        usable = inside & (~pd.isna(raw))
        if not usable.any():
            return None

        if normalize:
            # Baseline = first usable measurement inside the window.
            baseline = raw[np.where(usable)[0][0]]
            if baseline and not np.isnan(baseline):
                vals = raw / baseline
            else:
                vals = raw.copy()
        else:
            vals = raw

        return dates[usable], vals[usable]

    # Individual patient lines, clipped to each drug window
    for group, color in ((responders, 'blue'), (nonresponders, 'red')):
        for pid in group:
            points = _points_in_window(pid, _strict_floats)
            if points is None:
                continue
            plt.plot(points[0], points[1], color=color, alpha=0.3)

    def _fit_and_plot(pids, color):
        """Least-squares line through all windowed points of one group."""
        xs, ys = [], []
        for pid in pids:
            points = _points_in_window(pid, _coerced_numbers)
            if points is None:
                continue
            xs.append(mdates.date2num(points[0]))
            ys.append(points[1])

        if not xs:
            return
        x_all = np.concatenate(xs)
        y_all = np.concatenate(ys)
        slope, intercept = np.polyfit(x_all, y_all, 1)
        x_fit = np.linspace(x_all.min(), x_all.max(), 100)
        y_fit = slope * x_fit + intercept
        plt.plot(mdates.num2date(x_fit), y_fit, color=color, linewidth=2, label=f"{color.title()} trend")

    _fit_and_plot(responders, 'blue')
    _fit_and_plot(nonresponders, 'red')

    # Axis text
    plt.xlabel("Measurement Date")
    plt.ylabel("Normalized Value" if normalize else "Value as Number")
    title = "Measurements Over Time by Responder Status"
    if normalize:
        title += " (normalized to first value in drug window)"
    plt.title(title)
    plt.gcf().autofmt_xdate()

    # Empty lines that only exist to give the legend one entry per group
    plt.plot([], [], color='blue',  alpha=0.3, label="Responder")
    plt.plot([], [], color='red',   alpha=0.3, label="Non-responder")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


plot_time_series_by_response_with_drug_window(
    filtered_patient_dict,
    responder_df,
    summary_dict,
    normalize=True
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.dates as mdates

def plot_single_patient_with_drug_markers(patient_dict: dict,
                                          responsedf: pd.DataFrame,
                                          drugdict: dict,
                                          patient_id,
                                          normalize: bool = False):
    """
    Plot one patient's measurements between their first and last drug
    administration and mark every administration with a green dot.

    The dots are placed on the measurement line by linear interpolation at
    each administration time. If the patient has no parseable administration,
    or no non-missing measurement inside the window, a message is printed and
    nothing is drawn. A patient absent from `responsedf` is drawn as a
    non-responder.

    Parameters
    ----------
    patient_dict : dict
        From build_patient_measurement_dict(); keyed by patient_id.
    responsedf : pd.DataFrame
        Must have ['person_id', 'responder'] where responder is True/False.
    drugdict : dict
        {
          patient_id: {
            'all_visits': [...],
            'drug_administrations': [timestamp1, timestamp2, …]
          },
          …
        }
    patient_id : hashable
        The ID of the patient to plot.
    normalize : bool, default False
        If True, divide values by the first measurement within the drug window.
    """
    # Responder status → colour and legend text
    status_by_patient = dict(zip(responsedf['person_id'], responsedf['responder']))
    is_resp = status_by_patient.get(patient_id, False)
    if is_resp:
        color, status = 'blue', 'responder'
    else:
        color, status = 'red', 'non-responder'
    label = f"{patient_id} ({status})"

    # Drug window = [earliest, latest] parseable administration
    raw_admins = drugdict.get(patient_id, {}).get('drug_administrations', [])
    admin_times = pd.to_datetime(pd.Series(raw_admins), errors='coerce').dropna()
    if admin_times.empty:
        print(f"No drug administrations for patient {patient_id}.")
        return
    window_start, window_end = admin_times.min(), admin_times.max()

    # Measurements inside the window that carry a value
    record = patient_dict[patient_id]
    dates = pd.to_datetime(record['measurement_DATE'])
    raw = np.array(record['value_as_number'], dtype=float)
    in_window = (dates >= window_start) & (dates <= window_end)
    usable = in_window & (~np.isnan(raw))
    if not usable.any():
        print(f"No measurements for patient {patient_id} in drug window.")
        return
    meas_dates = dates[usable]
    meas_vals = raw[usable]

    # Optional scaling by the first in-window value (skipped when it is 0)
    if normalize and meas_vals[0] and not np.isnan(meas_vals[0]):
        meas_vals = meas_vals / meas_vals[0]

    # Measurement line
    plt.figure(figsize=(8, 4))
    plt.plot(meas_dates, meas_vals,
             color=color, alpha=0.7, linewidth=2,
             label=label)

    # Administration markers, positioned on the line by interpolation
    within = (admin_times >= window_start) & (admin_times <= window_end)
    admin_in_window = admin_times[within]
    if not admin_in_window.empty:
        marker_y = np.interp(
            mdates.date2num(admin_in_window),
            mdates.date2num(meas_dates),
            meas_vals,
        )
        plt.scatter(admin_in_window, marker_y,
                    color='green', marker='o',
                    label='Drug administration', zorder=5)

    # Axis text and layout
    plt.xlabel("Measurement Date")
    plt.ylabel("Normalized Value" if normalize else "Value as Number")
    title = f"Patient {patient_id} Time Series"
    if normalize:
        title += " (normalized to first in drug window)"
    plt.title(title)
    plt.gcf().autofmt_xdate()
    plt.legend()
    plt.tight_layout()
    plt.show()


plot_single_patient_with_drug_markers(
    filtered_patient_dict,
    responder_df,
    summary_dict,
    patient_id=30116620,
    normalize=False
)


In [None]:
import gseapy as gp
# The module is bound only to the alias `gp`; the bare name `gseapy` is never
# defined, so the version must be read through the alias.
print(gp.__version__)

In [None]:
names = gp.get_library_name()
# Show every available GO Molecular Function release by name instead of by a
# hard-coded position (names[77:88]), which shifts whenever libraries are added.
[name for name in names if name.startswith('GO_Molecular_Function')]

In [None]:
## download library or read a .gmt file
go_mf = gp.get_library(name='GO_Molecular_Function_2025', organism='Human')
go_mf

In [None]:
# Load the GO Molecular Function 2025 library
go_mf = gp.get_library(name='GO_Molecular_Function_2025', organism='Human')

# Convert the library into a more usable format:
# one record per pathway, holding the pathway name and its member list
pathway_proteins = []

for pathway, proteins in go_mf.items():
    pathway_proteins.append({"Pathway": pathway, "Proteins": proteins})

# Optionally, convert the result to a pandas DataFrame for easier viewing and manipulation
import pandas as pd

pathway_df = pd.DataFrame(pathway_proteins)

# Save to a CSV if needed (written to the current working directory)
pathway_df.to_csv("go_mf_pathways_and_proteins.csv", index=False)

# Display the resulting table
pathway_df

In [None]:
actual_prot = pd.read_csv('../all_proteomics_lc_cleaned.csv')

# Columns that hold protein measurements carry a trailing "_protein" tag.
PROTEIN_SUFFIX = '_protein'
protein_columns = [col for col in actual_prot.columns if col.endswith(PROTEIN_SUFFIX)]

# Strip only the trailing tag (str.replace would also delete any earlier
# "_protein" inside a name), then upper-case what remains.
background_proteins = [col[:-len(PROTEIN_SUFFIX)].upper() for col in protein_columns]

# Resulting list of cleaned column names
background_proteins


In [None]:
import math
def get_top_bottom_pct(df,
                       feature_col: str = 'feature',
                       odds_col: str = 'odds_ratio',
                       pct: float = 0.10,
                       suffix: str = '_prediction'):
    """
    Return the feature names at both extremes of an odds-ratio ranking.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing at least `feature_col` and `odds_col`.
    feature_col : str
        Column holding the feature names (strings).
    odds_col : str
        Column the rows are ranked by.
    pct : float
        Fraction of rows to take from each end of the ranking.
    suffix : str
        Literal text removed from the END of each feature name, if present.

    Returns
    -------
    (top_pct, bot_pct) : tuple of two lists of str
        top_pct: names of the ceil(len(df) * pct) rows with the largest
                 `odds_col`, in descending order.
        bot_pct: names of the ceil(len(df) * pct) rows with the smallest
                 `odds_col`, in ascending order.
        Each list is capped at len(df) entries.
    """
    import re  # local import keeps this cell runnable on its own

    # how many to take from each end
    n = math.ceil(len(df) * pct)

    # Escape the suffix so it is matched literally; unescaped, a suffix such as
    # ".pred" would treat "." as "any character" and strip unintended text.
    suffix_pattern = f'{re.escape(suffix)}$'

    def _strip_suffix(names):
        # Remove the trailing suffix from every name and return a plain list.
        return names.str.replace(suffix_pattern, '', regex=True).tolist()

    # sort descending for top
    top_df = df.sort_values(by=odds_col, ascending=False).head(n)
    # sort ascending for bottom
    bot_df = df.sort_values(by=odds_col, ascending=True).head(n)

    return _strip_suffix(top_df[feature_col]), _strip_suffix(bot_df[feature_col])


# Take the default 10% from each end of the odds-ratio ranking in coef_df_sorted.
top10pct, bot10pct = get_top_bottom_pct(coef_df_sorted)
print(len(top10pct), len(bot10pct))  # each should be math.ceil(2923*0.1)=293 (assumes coef_df_sorted has 2923 rows — TODO confirm)
print("Top 10% features:", top10pct)
print("Bottom 10% features:", bot10pct)



In [None]:
# Enrichment of the top10pct feature list against GO_Molecular_Function_2025,
# using the measured proteins as the background.
# Author's note: `background` only recognizes a gene-list input.
top10_enr_bg = gp.enrichr(gene_list=top10pct,
                 gene_sets=['GO_Molecular_Function_2025'],
                 # organism='human', # the organism argument is ignored because a background is supplied
                 background=background_proteins,
                 outdir=None, # don't write to disk
                )

top10_enr_bg.results

In [None]:
# Show only the first 20 rows of the top-list enrichment results.
top10_enr_bg.results.head(20)

In [None]:
# Enrichment of the bot10pct feature list against GO_Molecular_Function_2025,
# using the measured proteins as the background.
# Author's note: `background` only recognizes a gene-list input.
bot10_enr_bg = gp.enrichr(gene_list=bot10pct,
                 gene_sets=['GO_Molecular_Function_2025'],
                 # organism='human', # the organism argument is ignored because a background is supplied
                 background=background_proteins,
                 outdir=None, # don't write to disk
                )

bot10_enr_bg.results

In [None]:
import gseapy as gp
from gseapy.parser import get_library

# Fetch the same Enrichr library that the enrichment cells query
go_lib = get_library(name='GO_Molecular_Function_2025', organism='Human')

# Flatten the members of every term into one de-duplicated gene universe
enrichr_genes = {gene for members in go_lib.values() for gene in members}

# Background entries with no exact (case-sensitive) match in that universe
missing = [gene for gene in background_proteins if gene not in enrichr_genes]

print(f"{len(missing)} genes of your {len(background_proteins)}-gene background were NOT found in GO_Molecular_Function_2025")
print("Missing genes:", missing)


In [None]:
import gseapy as gp
from gseapy.parser import get_library

def load_enrichr_library(lib_name: str, organism: str = 'Human') -> set:
    """
    Fetch an Enrichr gene-set library and flatten it into one set of members.

    Parameters
    ----------
    lib_name : str
        Enrichr library name, e.g. 'GO_Molecular_Function_2025'.
    organism : str
        Organism the library is requested for ('Human' or 'Mouse').

    Returns
    -------
    set
        Every member that appears in at least one term of the library,
        with duplicates across terms collapsed.
    """
    term_to_members = get_library(name=lib_name, organism=organism)
    # walk each term's member collection and gather the distinct entries
    return {member for members in term_to_members.values() for member in members}

def is_in_enrichr(gene: str, lib_genes: set) -> bool:
    """
    Report whether `gene` occurs in `lib_genes`, ignoring letter case.

    Both the query and every library entry are upper-cased before comparing.
    """
    query = gene.upper()
    normalized_library = frozenset(entry.upper() for entry in lib_genes)
    return query in normalized_library

# --- Example usage ---

# 1. Build the flat member set for the library of interest
library_genes = load_enrichr_library('GO_Molecular_Function_2025')

# 2. Report, ignoring case, whether each test symbol is in that set
for test_gene in ['LEU1', 'T1']:
    found = is_in_enrichr(test_gene, library_genes)
    print(f"{test_gene}: {'FOUND' if found else 'NOT found'} in GO_Molecular_Function_2025")


In [None]:


# Remove a trailing "_prediction" from each feature name, if present
# (the regex is anchored to the end of the string).
# NOTE(review): this overwrites the 'feature' column of coef_df_sorted itself,
# so every later cell sees the stripped names.
coef_df_sorted['feature'] = coef_df_sorted['feature'].str.replace(r'_prediction$', '', regex=True)

coef_df_sorted


In [None]:
# Write the coefficient table to ./test.csv, omitting the index column.
coef_df_sorted.to_csv('./test.csv', index=False)