In [None]:
import numpy as np
import pandas as pd
import os
import math
import warnings
from datetime import datetime
from typing import Dict, Sequence, Optional, Tuple, Iterable
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.exceptions import ConvergenceError
from lifelines.statistics import logrank_test
from sklearn.linear_model import Lasso
from sklearn.model_selection import (
    GridSearchCV,
    cross_val_score,
    learning_curve,
)
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr, linregress, zscore
from tqdm import tqdm
import joblib
from matplotlib.lines import Line2D


# Train aging clock using measured proteins

In [None]:
# Load the measured-protein, RABIT-protein and EHR-representation splits.
measuredprot_train = pd.read_csv('/path/to/measured_proteins/training_split')
measuredprot_test = pd.read_csv('/path/to/measured_proteins/test_split')

# RABIT files: the id column 'patient_ids' becomes 'eid', and every
# '_prediction' in a column name is replaced by '_protein'.
rabitprot_train = (
    pd.read_csv('/path/to/RABIT_proteins/training_split')
    .rename(columns={'patient_ids': 'eid'})
    .rename(columns=lambda name: name.replace('_prediction', '_protein'))
)

rabitprot_test = (
    pd.read_csv('/path/to/RABIT_proteins/test_split')
    .rename(columns={'patient_ids': 'eid'})
    .rename(columns=lambda name: name.replace('_prediction', '_protein'))
)

# EHR representation files: drop 'labeling_time' and use 'eid' as the id column.
ehr_train = (
    pd.read_csv('/path/to/EHR ehr representation/training_split')
    .drop(columns=["labeling_time"])
    .rename(columns={"patient_ids": "eid"})
)

ehr_test = (
    pd.read_csv('/path/to/EHR ehr representation/test_split')
    .drop(columns=["labeling_time"])
    .rename(columns={"patient_ids": "eid"})
)

In [None]:
# Metadata dataframe with the following columns:
#   "patient_id"        - patient identifier
#   "prediction_time"   - date of RABIT protein generation
#   "birth_datetime"    - patient birthday
#   "age_at_prediction" - patient age at the date of RABIT protein generation
age_df = pd.read_csv('./info_files/all_ukbb_patients_ages_at_proteomics_collection.csv')
age_df

### Organ-specific protein dataframe, downloaded from https://www.biorxiv.org/content/10.1101/2024.06.07.597771v1

In [None]:
# Table of organ-specific proteins. The next cell turns every column into a
# de-duplicated list keyed by the column name.
organ_prot = pd.read_csv('/path/to/organ_specific_proteins/organ_specific_proteins_for_aging.csv')
organ_prot

In [None]:
# Convert to dictionary
def dataframe_to_dict_unique(df):
    """Map every column name of ``df`` to its distinct non-null values.

    Values keep the order of their first appearance in the column.

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    dict
        ``{column_name: [unique non-NaN values, in first-seen order]}``.
    """
    # dict.fromkeys de-duplicates while keeping insertion order.
    return {
        column: list(dict.fromkeys(df[column].dropna().tolist()))
        for column in df.columns
    }

# Build {column name: de-duplicated non-null values} from the organ_prot table.
organ_prot_cleaned_dict = dataframe_to_dict_unique(organ_prot)
organ_prot_cleaned_dict


In [None]:
# Subset proteins by organspecific protein list
def subset_protdf_by_organkey(result_dict, protdf, organkey):
    """Return the columns of ``protdf`` that belong to one organ's protein list.

    Parameters
    ----------
    result_dict : dict
        Maps an organ key to a list of protein-name strings. An entry that
        contains dots is split on "." into several individual names.
    protdf : pandas.DataFrame
        Must contain 'eid' and 'age_at_prediction'. Every other column is a
        candidate protein column; a trailing '_protein' is ignored when
        matching, and matching is case-insensitive.
    organkey : hashable
        Key to look up in ``result_dict``; a missing key is treated as an
        empty protein list.

    Returns
    -------
    pandas.DataFrame
        Copy of ``protdf`` restricted to 'eid', 'age_at_prediction' and the
        matched protein columns (duplicates removed, first-match order).

    A short match summary is printed as a side effect.
    """
    mandatory_cols = ['eid', 'age_at_prediction']
    suffix = '_protein'

    organ_entries = result_dict.get(organkey, [])

    # Entries holding a dot name several proteins at once.
    multiprote_count = sum("." in entry for entry in organ_entries)

    # "a.b" -> ["a", "b"]; an entry without a dot splits to itself.
    expanded_names = [part for entry in organ_entries for part in entry.split(".")]
    total_individuals = len(expanded_names)

    # Lower-cased, suffix-free protein name -> actual column name.
    lookup = {}
    for column in protdf.columns:
        if column in mandatory_cols:
            continue
        base_name = column[:-len(suffix)] if column.endswith(suffix) else column
        lookup[base_name.lower()] = column

    hits = []
    misses = []
    for name in expanded_names:
        key = name.lower()
        if key in lookup:
            hits.append(lookup[key])
        else:
            misses.append(name)
    found_count = len(hits)

    # Mandatory columns first, then each matched column once.
    selected = mandatory_cols + list(dict.fromkeys(hits))
    subset_df = protdf[selected].copy()

    print(f"Out of {total_individuals} individual protein names from '{organkey}', {found_count} were found in protdf.")
    print(f"There were {multiprote_count} multiprote strings in '{organkey}' to begin with.")

    unique_misses = list(dict.fromkeys(misses))
    if not unique_misses:
        print("All protein names were found in protdf.")
    else:
        print("The following protein names were not found in protdf:")
        for name in unique_misses:
            print(name)

    return subset_df


In [None]:
# Attach 'age_at_prediction' from age_df to each measured-protein split.
# Left join on the participant id; the helper key 'patient_id' is dropped afterwards.
measuredprot_train_training = measuredprot_train.merge(
    age_df[['patient_id', 'age_at_prediction']],
    how='left',
    left_on='eid',
    right_on='patient_id',
).drop(columns='patient_id')

measuredprot_test_test = measuredprot_test.merge(
    age_df[['patient_id', 'age_at_prediction']],
    how='left',
    left_on='eid',
    right_on='patient_id',
).drop(columns='patient_id')


In [None]:
def plot_learning_curve_and_figures(currentorgan, Xtrain, ytrain, Xtest, ytest, ypred, organdictionary, best_alpha, modelpath):
    """Draw, save (SVG) and show the diagnostic plots for one organ model.

    Two figures are written to ``{modelpath}/{currentorgan}_model``:
    ``learning_curve.svg`` (5-fold learning curve of a Lasso with
    ``alpha=best_alpha``, scored by MSE) and ``predicted_vs_actual.svg``
    (scatter of ``ypred`` against ``ytest`` annotated with Pearson r).

    Parameters
    ----------
    currentorgan : str
        Organ name used in titles, messages and the output directory.
    Xtrain, ytrain
        Training features and target fed to ``learning_curve``.
    ytest, ypred
        Actual and predicted test-set ages.
    best_alpha : float
        Lasso regularisation strength for the learning curve.
    modelpath : str
        Parent directory for per-organ output folders.
    Xtest, organdictionary
        Accepted for call compatibility; not used in the body.
    """
    plot_dir = f"{modelpath}/{currentorgan}_model"
    os.makedirs(plot_dir, exist_ok=True)

    # ---- Learning curve -------------------------------------------------
    print("Generating Learning Curve...")
    estimator = Lasso(alpha=best_alpha, random_state=42, max_iter=5000)
    train_sizes, train_scores, val_scores = learning_curve(
        estimator,
        Xtrain,
        ytrain,
        cv=5,
        scoring='neg_mean_squared_error',
        train_sizes=np.linspace(0.1, 1.0, 10),
    )

    # Scores are negated MSE, so flip the sign of the means; the spread is sign-free.
    curves = [
        ('Training Error', -train_scores.mean(axis=1), train_scores.std(axis=1)),
        ('Validation Error', -val_scores.mean(axis=1), val_scores.std(axis=1)),
    ]

    fig, ax = plt.subplots(figsize=(10, 6))
    for curve_label, mean_mse, spread in curves:
        ax.plot(train_sizes, mean_mse, label=curve_label, marker='o')
        ax.fill_between(train_sizes, mean_mse - spread, mean_mse + spread, alpha=0.2)
    ax.set_xlabel('Training Set Size')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title('Learning Curve for Lasso Regression')
    ax.legend()
    ax.grid()

    learning_curve_svg_path = os.path.join(plot_dir, 'learning_curve.svg')
    fig.savefig(learning_curve_svg_path, format='svg')
    print(f"{currentorgan} Learning curve saved to {learning_curve_svg_path}")

    plt.show()
    plt.close(fig)

    # ---- Pearson correlation --------------------------------------------
    pearson_corr, p_value = pearsonr(ytest, ypred)
    print(f"Pearson Correlation: {pearson_corr:.4f} (p-value: {p_value:.4e})")

    # ---- Predicted vs actual --------------------------------------------
    print("Generating Predicted vs Actual Values Plot...")
    age_range = [ytest.min(), ytest.max()]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(ytest, ypred, alpha=0.6, label='Predictions')
    # Identity line: where a perfect predictor would fall.
    ax.plot(age_range, age_range, 'r--', label='Perfect Prediction')
    ax.set_xlabel('Actual Age')
    ax.set_ylabel('Predicted Age')
    ax.set_title(f'{currentorgan} Predicted vs Actual Values')
    ax.text(0.05, 0.95, f'Pearson r = {pearson_corr:.4f}', transform=ax.transAxes,
            fontsize=14, verticalalignment='top')
    ax.legend()
    ax.grid()

    predicted_vs_actual_svg_path = os.path.join(plot_dir, 'predicted_vs_actual.svg')
    fig.savefig(predicted_vs_actual_svg_path, format='svg')
    print(f"{currentorgan} Predicted vs Actual plot saved to {predicted_vs_actual_svg_path}")

    plt.show()
    plt.close(fig)

def train_organspecific_model(traindf, testdf, organdictionary, modelpath):
    """Train, evaluate and save one Lasso age predictor per organ.

    For every key of ``organdictionary`` the function:
      1. subsets ``traindf`` / ``testdf`` to that organ's protein columns,
      2. mean-imputes and standardises the features (both fitted on the
         training split only),
      3. picks the Lasso ``alpha`` with the best mean 3-fold cross-validated R^2,
      4. refits on the full training split and reports test-set MSE / R^2,
      5. saves the model, imputer, scaler and non-zero coefficients under
         ``{modelpath}/{organ}_model`` and draws the diagnostic plots.

    Organs for which no protein column is found in ``traindf`` are skipped
    with a message instead of aborting the whole run.

    Parameters
    ----------
    traindf, testdf : pandas.DataFrame
        Must contain 'eid', 'age_at_prediction' and the protein columns.
    organdictionary : dict
        Maps organ name to a list of protein names.
    modelpath : str
        Parent directory for the per-organ artifact folders.

    Returns
    -------
    None
    """
    for organ in organdictionary:
        # Subset dataframes
        train_subset = subset_protdf_by_organkey(organdictionary, traindf, organ)
        test_subset = subset_protdf_by_organkey(organdictionary, testdf, organ)

        # Clean data and standardize
        X_train = train_subset.drop(columns=['eid', 'age_at_prediction'])
        y_train = train_subset['age_at_prediction']

        # The imputer/scaler cannot be fitted on an empty feature matrix;
        # without this guard one unmatched organ stops every later organ.
        if X_train.shape[1] == 0:
            print(f"{organ}: no matching protein columns found in the training data. Skipping...")
            continue

        print(f"{organ}: {len(X_train.columns)} features — first five:",
        X_train.columns[:5].tolist())

        X_test = test_subset.drop(columns=['eid', 'age_at_prediction'])
        y_test = test_subset['age_at_prediction']

        imputer = SimpleImputer(strategy='mean')  # Replace missing values with the mean
        X_train = pd.DataFrame(imputer.fit_transform(X_train), columns=X_train.columns)
        X_test = pd.DataFrame(imputer.transform(X_test), columns=X_test.columns)

        scaler = StandardScaler()
        X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)
        X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

        # Hyperparameter tuning: 20 log-spaced alphas between 1e-2 and 1e1.
        alphas = np.logspace(-2, 1, 20)
        best_alpha = None
        best_score = -float('inf')

        print("Starting Lasso hyperparameter tuning...")
        for alpha in tqdm(alphas, desc="Tuning alpha"):
            model = Lasso(alpha=alpha, random_state=42, max_iter=5000)
            scores = cross_val_score(model, X_train, y_train, cv=3, scoring='r2') # Can potentially change this scoring function
            mean_score = scores.mean()

            # Strict '>' keeps the first (smallest) alpha on ties.
            if mean_score > best_score:
                best_score = mean_score
                best_alpha = alpha

        print(f"Best alpha: {best_alpha}")
        print(f"Best cross-validated R^2: {best_score}")

        # Train final model (best alpha)
        lasso = Lasso(alpha=best_alpha, random_state=42, max_iter=5000)
        lasso.fit(X_train, y_train)

        # Test on test set
        y_pred = lasso.predict(X_test)
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)

        print(f"Mean Squared Error (Test Set): {mse}")
        print(f"R^2 Score (Test Set): {r2}")

        # Save model and components
        model_dir = f"{modelpath}/{organ}_model"
        os.makedirs(model_dir, exist_ok=True)
        model_filename = os.path.join(model_dir, 'lasso_age_predictor.pkl')
        joblib.dump(lasso, model_filename)
        print(f"Model saved to {model_filename}")

        # Save the imputer
        imputer_filename = os.path.join(model_dir, 'imputer.pkl')
        joblib.dump(imputer, imputer_filename)
        print(f"Imputer saved to {imputer_filename}")

        # Save the scaler
        scaler_filename = os.path.join(model_dir, 'scaler.pkl')
        joblib.dump(scaler, scaler_filename)
        print(f"Scaler saved to {scaler_filename}")

        # Feature importance analysis: keep only features Lasso did not shrink to zero.
        feature_importance = pd.DataFrame({
            'Feature': X_train.columns,
            'Coefficient': lasso.coef_
        })

        important_features = feature_importance[feature_importance['Coefficient'] != 0]
        print("Important Features with Non-Zero Coefficients:")
        print(important_features)
        important_features_filename = os.path.join(model_dir, 'important_features.csv')
        important_features.to_csv(important_features_filename, index=False)
        print(f"Important features saved to {important_features_filename}")

        print(f"Creating learning curves and figures for {organ}")
        plot_learning_curve_and_figures(organ, X_train, y_train, X_test, y_test, y_pred, organdictionary, best_alpha, modelpath)



In [None]:
# Train one Lasso clock per key of organ_prot_cleaned_dict on the measured proteins.
# Artifacts and plots are written under '<modelpath>/<organ>_model'.
train_organspecific_model(measuredprot_train_training, measuredprot_test_test, organ_prot_cleaned_dict, 
                          modelpath='aging_clock_model_measuredprot_organ_specific')

# EHR aging clock training

In [None]:
# Attach 'age_at_prediction' from age_df to each EHR split.
# Left join on the participant id; the helper key 'patient_id' is dropped afterwards.
ehr_train_training = ehr_train.merge(
    age_df[['patient_id', 'age_at_prediction']],
    how='left',
    left_on='eid',
    right_on='patient_id',
).drop(columns='patient_id')

ehr_test_test = ehr_test.merge(
    age_df[['patient_id', 'age_at_prediction']],
    how='left',
    left_on='eid',
    right_on='patient_id',
).drop(columns='patient_id')


In [None]:
# Build the feature matrices and age targets for the EHR clock
# (everything except the id column and the target is a feature).
X_train = ehr_train_training.drop(columns=['eid', 'age_at_prediction'])
y_train = ehr_train_training['age_at_prediction']

X_test = ehr_test_test.drop(columns=['eid', 'age_at_prediction'])
y_test = ehr_test_test['age_at_prediction']

# Mean-impute missing values; the imputer is fitted on the training split only.
imputer = SimpleImputer(strategy='mean')
X_train = pd.DataFrame(imputer.fit_transform(X_train), columns=X_train.columns)
X_test = pd.DataFrame(imputer.transform(X_test), columns=X_test.columns)

# Standardise features; the scaler is fitted on the training split only.
scaler = StandardScaler()
X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)
X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

# Hyperparameter tuning for LASSO (conducted in original study)
alphas = np.logspace(-2, 1, 20)  # 20 log-spaced alphas between 1e-2 and 1e1
best_alpha = None
best_score = -float('inf')

print("Starting Lasso hyperparameter tuning...")
for alpha in tqdm(alphas, desc="Tuning alpha"):
    model = Lasso(alpha=alpha, random_state=42, max_iter=5000)  # max_iter set to 5000
    scores = cross_val_score(model, X_train, y_train, cv=3, scoring='r2') # Can potentially change this scoring function
    mean_score = scores.mean()

    # Strict '>' keeps the first (smallest) alpha when scores tie.
    if mean_score > best_score:
        best_score = mean_score
        best_alpha = alpha

print(f"Best alpha: {best_alpha}")
print(f"Best cross-validated R^2: {best_score}")

# Train final model (best alpha) on the full training split
lasso = Lasso(alpha=best_alpha, random_state=42, max_iter=5000)
lasso.fit(X_train, y_train)

# Test on hold-out test set
y_pred = lasso.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Mean Squared Error (Test Set): {mse}")
print(f"R^2 Score (Test Set): {r2}")

# Save model components (model, imputer and scaler are all needed at inference time)
model_dir = 'aging_clock_model_ehr_hamilton'
os.makedirs(model_dir, exist_ok=True)
model_filename = os.path.join(model_dir, 'lasso_age_predictor.pkl')
joblib.dump(lasso, model_filename)
print(f"Model saved to {model_filename}")

imputer_filename = os.path.join(model_dir, 'imputer.pkl')
joblib.dump(imputer, imputer_filename)
print(f"Imputer saved to {imputer_filename}")

scaler_filename = os.path.join(model_dir, 'scaler.pkl')
joblib.dump(scaler, scaler_filename)
print(f"Scaler saved to {scaler_filename}")

# Identify feature importances: pair every feature with its Lasso coefficient
feature_importance = pd.DataFrame({
    'Feature': X_train.columns,
    'Coefficient': lasso.coef_
})

# Keep only the features whose coefficient was not shrunk to exactly zero
important_features = feature_importance[feature_importance['Coefficient'] != 0]
print("Important Features with Non-Zero Coefficients:")
print(important_features)

# Save important features
important_features_filename = os.path.join(model_dir, 'important_features.csv')
important_features.to_csv(important_features_filename, index=False)
print(f"Important features saved to {important_features_filename}")


In [None]:
def run_age_clock_model(input_data, model_dir, suffix):
    """Apply one saved age clock (model + imputer + scaler) to ``input_data``.

    NOTE: a later cell of this notebook defines another function with the
    same name; once that cell has run, this version is shadowed.

    Parameters
    ----------
    input_data : pandas.DataFrame
        Must contain 'eid' and 'age_at_prediction'; every other column is
        passed to the model as a feature.
    model_dir : str
        Directory holding 'lasso_age_predictor.pkl', 'imputer.pkl' and
        'scaler.pkl'.
    suffix : str
        Appended to the prediction column name ('predicted_age_<suffix>').

    Returns
    -------
    pandas.DataFrame
        Columns 'eid', 'actual_age' and 'predicted_age_<suffix>'.

    MSE and R^2 against the actual ages are printed as a side effect.
    """
    # Artifacts are loaded in the order model, imputer, scaler.
    lasso, imputer, scaler = (
        joblib.load(os.path.join(model_dir, artifact))
        for artifact in ('lasso_age_predictor.pkl', 'imputer.pkl', 'scaler.pkl')
    )

    features = input_data.drop(columns=['eid', 'age_at_prediction'])
    actual_age = input_data['age_at_prediction']
    feature_names = features.columns

    # Same preprocessing as at training time: impute first, then scale.
    features = pd.DataFrame(imputer.transform(features), columns=feature_names)
    features = pd.DataFrame(scaler.transform(features), columns=feature_names)

    predicted_age = lasso.predict(features)

    result_df = pd.DataFrame({
        'eid': input_data['eid'],
        'actual_age': actual_age,
        f'predicted_age_{suffix}': predicted_age,
    })

    mse_new = mean_squared_error(actual_age, predicted_age)
    r2_new = r2_score(actual_age, predicted_age)
    print(f"Mean Squared Error (New Data): {mse_new}")
    print(f"R^2 Score (New Data): {r2_new}")

    return result_df



# The EHR training cell writes its artifacts to 'aging_clock_model_ehr_hamilton',
# so load them from that same directory. The previous value
# 'aging_clock_model_ehr' is never written by this notebook.
ehr_pred = run_age_clock_model(ehr_test_test, model_dir='aging_clock_model_ehr_hamilton', suffix='ehr')


# Calculate protein age gaps for RABIT and measured proteins using trained protein clock

In [None]:
def run_age_clock_model(input_data, parent_model_dir, suffix, orgdict):
    """Apply every saved organ-specific clock under ``parent_model_dir`` to ``input_data``.

    This definition replaces the earlier single-model ``run_age_clock_model``
    defined above in the notebook (same name, different signature).

    Each sub-directory is expected to be named '<organ>_model' and to hold
    'lasso_age_predictor.pkl', 'imputer.pkl' and 'scaler.pkl'. Sub-directories
    are visited in sorted order so the result dictionary is deterministic.
    Entries whose organ name is not a key of ``orgdict`` are skipped up front.

    Parameters
    ----------
    input_data : pandas.DataFrame
        Must contain 'eid', 'age_at_prediction' and the protein columns.
    parent_model_dir : str
        Directory containing one '<organ>_model' folder per organ.
    suffix : str
        Appended to the prediction column name ('predicted_age_<suffix>').
    orgdict : dict
        Maps organ name to a list of protein names.

    Returns
    -------
    dict
        ``{organ_name: DataFrame}`` with columns 'eid', 'actual_age' and
        'predicted_age_<suffix>'. Organs whose artifacts fail to load, or whose
        data cannot be processed, are reported with a message and omitted.
    """
    dir_suffix = '_model'
    results = {}

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for subdir in sorted(os.listdir(parent_model_dir)):
        full_model_dir = os.path.join(parent_model_dir, subdir)
        if not os.path.isdir(full_model_dir):
            continue

        # Strip only the trailing '_model', not every occurrence in the name.
        organ_name = subdir[:-len(dir_suffix)] if subdir.endswith(dir_suffix) else subdir

        # Unknown organs (e.g. stray folders) would yield an empty feature set
        # and fail later; skip them here with a clear message.
        if organ_name not in orgdict:
            print(f"Skipping '{subdir}': '{organ_name}' is not a key of the organ dictionary.")
            continue

        print(f"\n=== Processing {organ_name} ===")
        input_data_subset = subset_protdf_by_organkey(orgdict, input_data, organ_name)
        if input_data_subset.empty:
            print(f"No data for organ {organ_name}. Skipping...")
            continue

        try:
            # Load the saved model, imputer, and scaler for the current organ
            lasso = joblib.load(os.path.join(full_model_dir, 'lasso_age_predictor.pkl'))
            imputer = joblib.load(os.path.join(full_model_dir, 'imputer.pkl'))
            scaler = joblib.load(os.path.join(full_model_dir, 'scaler.pkl'))
            print(f"Loaded artifacts for {organ_name} from {full_model_dir}")
        except Exception as e:
            # Best effort: report and continue with the remaining organs.
            print(f"Error loading model artifacts for {organ_name}: {e}")
            continue

        try:
            X_new = input_data_subset.drop(columns=['eid', 'age_at_prediction'])
            y_new = input_data_subset['age_at_prediction']

            # Preprocess new data: impute then scale
            X_new_imputed = pd.DataFrame(imputer.transform(X_new), columns=X_new.columns)
            X_new_scaled = pd.DataFrame(scaler.transform(X_new_imputed), columns=X_new_imputed.columns)
        except Exception as e:
            print(f"Error processing input data for {organ_name}: {e}")
            continue

        try:
            # Make predictions
            y_pred_new = lasso.predict(X_new_scaled)
            mse_new = mean_squared_error(y_new, y_pred_new)
            r2_new = r2_score(y_new, y_pred_new)
            print(f"Results for {organ_name}: MSE = {mse_new:.4f}, R² = {r2_new:.4f}")
        except Exception as e:
            print(f"Error during prediction for {organ_name}: {e}")
            continue

        results[organ_name] = pd.DataFrame({
            'eid': input_data_subset['eid'],
            'actual_age': y_new,
            f'predicted_age_{suffix}': y_pred_new
        })

    return results

# NOTE(review): `testsplit_rabit_test` and `testsplit_measuredprot_test` are not
# defined in any visible earlier cell, so this cell raised NameError on a fresh
# kernel. They are assumed to be the test splits with 'age_at_prediction'
# attached (as done for the measured proteins above) - confirm.
testsplit_measuredprot_test = measuredprot_test_test

testsplit_rabit_test = pd.merge(
    rabitprot_test,
    age_df[['patient_id', 'age_at_prediction']],
    left_on='eid',
    right_on='patient_id',
    how='left'
).drop(columns='patient_id')

# Both protein sources are scored with the clocks trained on measured proteins.
rabit_results = run_age_clock_model(
    input_data=testsplit_rabit_test,
    parent_model_dir='aging_clock_model_measuredprot_organ_specific',
    suffix='rabit',
    orgdict=organ_prot_cleaned_dict
)

measuredprot_results = run_age_clock_model(
    input_data=testsplit_measuredprot_test,
    parent_model_dir='aging_clock_model_measuredprot_organ_specific',
    suffix='measuredprot',
    orgdict=organ_prot_cleaned_dict
)


In [None]:
# Organize all results (ehr, RABIT, Measured) into one dictionary for downstream processing

tissue_keys = ['Adipose', 'Artery', 'Brain', 'Heart', 'Immune', 'Intestine',
               'Kidney', 'Liver', 'Lung', 'Muscle', 'Pancreas', 'Organismal', 'Conventional']

# The EHR clock is not organ-specific: every tissue gets a copy of the same predictions.
ehr_results = {tissue: ehr_pred.copy() for tissue in tissue_keys}

# Merged per-organ dataframes: measured, RABIT and EHR predictions side by side.
age_gap_dfs = {}

# Loop through each organ available in measuredprot_results
for organ, measuredprot_pred in measuredprot_results.items():
    if organ not in rabit_results:
        continue

    # 'actual_age' is kept only from the measured-protein frame to avoid duplicate columns.
    rabit_pred_clean = rabit_results[organ].drop(columns=['actual_age'])
    # Fall back to ehr_pred for organs missing from the hard-coded tissue_keys
    # (previously a KeyError); the EHR predictions are identical for every organ.
    ehr_results_clean = ehr_results.get(organ, ehr_pred).drop(columns=['actual_age'])

    merged_df = pd.merge(measuredprot_pred, rabit_pred_clean, on='eid', how='inner')
    merged_df2 = pd.merge(merged_df, ehr_results_clean, on='eid', how='inner')

    # Store the merged dataframe in the dictionary, keyed by the organ name
    age_gap_dfs[organ] = merged_df2

age_gap_dfs



In [None]:
# Calculate age gaps
def calculate_age_gaps(agegapdf):
    """Compute LOESS-adjusted age gaps and their z-scores for the three clocks.

    For each source in ('measuredprot', 'rabit', 'ehr') the predicted age is
    smoothed against actual age with statsmodels' LOWESS. The age gap is
    ``predicted_age_<source>`` minus the smoothed prediction, and the z-score
    standardises that gap with the population standard deviation (ddof=0).

    Parameters
    ----------
    agegapdf : pandas.DataFrame
        Must contain 'eid', 'actual_age', 'predicted_age_measuredprot',
        'predicted_age_rabit' and 'predicted_age_ehr'. Not modified.

    Returns
    -------
    pandas.DataFrame
        Copy of the input with 'age_gap_<source>' and
        'age_gap_<source>_zscore' columns appended.

    Raises
    ------
    ValueError
        If a required column is missing, or if a source has no row where
        both the actual and the predicted age are present.
    """
    sources = ("measuredprot", "rabit", "ehr")

    # Make a copy of the DataFrame to avoid modifying the original data
    df = agegapdf.copy()

    # Ensure that the necessary columns are present
    required_columns = ["eid", "actual_age"] + [f"predicted_age_{source}" for source in sources]
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"Missing required column: '{col}'")

    # `sm` was referenced below without ever being imported in this notebook
    # (NameError). Imported locally so this cell works on a fresh kernel.
    import statsmodels.api as sm

    def compute_smoothed_pred_age(actual_age, pred_age, frac=2/3):
        """Return the LOWESS-smoothed predicted age evaluated at every ``actual_age``."""
        # Fit only on rows where both values are present.
        valid_idx = actual_age.notna() & pred_age.notna()
        actual_age_valid = actual_age[valid_idx]
        pred_age_valid = pred_age[valid_idx]

        if len(actual_age_valid) == 0:
            raise ValueError("No valid data points available for LOESS fitting.")

        # Sort the data by actual_age
        sorted_indices = actual_age_valid.argsort()
        sorted_age = actual_age_valid.iloc[sorted_indices]
        sorted_pred = pred_age_valid.iloc[sorted_indices]

        # Column 0 of the result is the sorted exog (age), column 1 the fitted values.
        lowess_results = sm.nonparametric.lowess(
            endog=sorted_pred,
            exog=sorted_age,
            frac=frac,
            return_sorted=True
        )
        smoothed_age = lowess_results[:, 0]
        smoothed_pred = lowess_results[:, 1]

        # Map the fitted curve back onto the original row order; ages outside
        # the fitted range (and missing ages) become NaN.
        return np.interp(
            actual_age,
            smoothed_age,
            smoothed_pred,
            left=np.nan,
            right=np.nan
        )

    # Age gap = raw prediction minus the age-trend of the predictions.
    for source in sources:
        smoothed = compute_smoothed_pred_age(
            actual_age=df['actual_age'],
            pred_age=df[f'predicted_age_{source}'],
            frac=2/3
        )
        df[f'age_gap_{source}'] = df[f'predicted_age_{source}'] - smoothed

    # Z-scores are added in a second pass so the gap columns stay grouped together.
    for source in sources:
        gap = df[f'age_gap_{source}']
        df[f'age_gap_{source}_zscore'] = (gap - gap.mean()) / gap.std(ddof=0)

    return df

def calculate_age_gaps_for_dict(dict_of_dfs):
    """Run ``calculate_age_gaps`` on every dataframe of a ``{tissue: DataFrame}`` dict.

    Each dataframe is copied before processing, so the inputs are left
    untouched. Returns a new dict with the same keys in the same order.
    """
    return {
        tissue: calculate_age_gaps(frame.copy())
        for tissue, frame in dict_of_dfs.items()
    }

# Calculate gaps for every organ (LOESS regression, same as original study).
# Each dataframe gains 'age_gap_*' and 'age_gap_*_zscore' columns.
age_gap_simple_dfs = calculate_age_gaps_for_dict(age_gap_dfs)
age_gap_simple_dfs


# Calculating disease hazard ratios for each organ

In [None]:
persondf = pd.read_csv('/path/to/person omop table')

# Gender concept ids -> labels; any id not listed here maps to NaN.
gender_mapping = {
    8507: 'male',
    8532: 'female'
}

# Create sex dataframe: one row per person with columns 'eid' and 'sex'.
subset_df = persondf[['person_id', 'gender_concept_id']]

renamed_df = (
    subset_df
    .rename(columns={'person_id': 'eid', 'gender_concept_id': 'sex'})
    .assign(sex=lambda frame: frame['sex'].map(gender_mapping))
)
renamed_df

In [None]:
# Per-organ covariate tables: sex joined with that organ's age-gap dataframe.
covariates_dfs = {}

for organ, age_gap_df in age_gap_simple_dfs.items():
    # Inner join keeps only participants present in both tables.
    covariates = renamed_df.merge(age_gap_df, how='inner', on='eid')
    covariates_dfs[organ] = covariates
covariates_dfs


In [None]:
# Read the disease diagnosis labels. The Cox analysis below also needs age_df,
# which is loaded near the top of the notebook.
master_label = pd.read_csv('/path/to/disease diagnosis labels/from original paper (see biorxiv link above)')


In [None]:
# Functions for cox analysis
import pandas as pd
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError

def perform_multiple_cox_models_individual_censor_dates(
    covdf_dict: dict,   # Dictionary of organ-specific covariate DataFrames, keyed by organ name
    target_df: pd.DataFrame,
    agedf: pd.DataFrame,
    disease_columns: list,
    covariate_columns: list = ['sex', 'actual_age'],
    categorical_covariates: list = ['sex'],
    censor_date: str = None,
    censor_date_df: pd.DataFrame = None,   # New: DataFrame with individual censor dates.
    default_censor_date: str = None,         # Fallback censor date for samples missing one.
    date_format: str = '%Y-%m-%d'
):
    """Fit one Cox proportional-hazards model per (organ, disease) pair.

    Parameters
    ----------
    covdf_dict : dict
        ``{organ: DataFrame}``; each frame needs 'eid' plus ``covariate_columns``.
    target_df : pandas.DataFrame
        Diagnosis dates, one column per disease ('person_id' is renamed to
        'eid'). A value of 0 is treated as "no diagnosis".
    agedf : pandas.DataFrame
        Needs 'patient_id' (renamed to 'eid') and 'prediction_time', which is
        the start of follow-up.
    disease_columns : list
        Disease columns of ``target_df`` to analyse.
    covariate_columns : list
        Model covariates taken from the merged frame.
    categorical_covariates : list
        Covariates to one-hot encode (first level dropped).
    censor_date : str, optional
        Common censor date, used only when ``censor_date_df`` is None. If it
        is also empty, the day after the latest diagnosis date is used.
    censor_date_df : pandas.DataFrame, optional
        Per-person censor dates ('patient_id' or 'eid', and 'censor_date').
    default_censor_date : str, optional
        Fills missing per-person censor dates.
    date_format : str
        Format used to parse every date; unparseable values become NaT.

    Returns
    -------
    tuple(dict, dict)
        ``results[organ][disease] = {'model': fitter or None, 'summary': ...}``
        and ``merged[organ]`` = the merged per-organ dataframe.

    Raises
    ------
    ValueError
        If per-person censor dates are used, some are missing, and no
        ``default_censor_date`` is given.
    """
    def _as_datetime(values):
        # One place for the shared parsing options.
        return pd.to_datetime(values, format=date_format, errors='coerce')

    # rename() returns new frames, so the caller's inputs are not modified below.
    diagnoses = target_df.rename(columns={'person_id': 'eid'})
    baseline = agedf.rename(columns={'patient_id': 'eid'})

    diagnoses[disease_columns] = diagnoses[disease_columns].replace(0, pd.NaT)
    for disease in disease_columns:
        diagnoses[disease] = _as_datetime(diagnoses[disease])

    baseline['prediction_time'] = _as_datetime(baseline['prediction_time'])

    # Individual censor dates from UK Biobank (varies by patient geographic location)
    per_person_censoring = None
    if censor_date_df is not None:
        per_person_censoring = censor_date_df.rename(columns={'patient_id': 'eid'})
        per_person_censoring['censor_date'] = _as_datetime(per_person_censoring['censor_date'])

    all_results = {}
    all_merged = {}

    for organ, covdf in covdf_dict.items():
        print(f"\nProcessing organ: {organ}")
        merged_df = (
            covdf
            .merge(baseline, on='eid', how='left')
            .merge(diagnoses, on='eid', how='left')
        )

        if per_person_censoring is None:
            # One censor date shared by everybody.
            if censor_date:
                common_censor_datetime = _as_datetime(censor_date)
            else:
                # Day after the latest diagnosis across all diseases.
                latest_diagnosis = diagnoses[disease_columns].max().max()
                common_censor_datetime = latest_diagnosis + pd.Timedelta(days=1)
            merged_df['censor_date'] = common_censor_datetime
        else:
            merged_df = merged_df.merge(
                per_person_censoring[['eid', 'censor_date']], on='eid', how='left'
            )
            if default_censor_date:
                merged_df['censor_date'] = merged_df['censor_date'].fillna(
                    _as_datetime(default_censor_date)
                )
            elif merged_df['censor_date'].isnull().any():
                raise ValueError("Missing censor dates.")

        all_merged[organ] = merged_df

        per_disease = {}
        for disease in disease_columns:
            print(f"\nAnalyzing disease: {disease} for organ: {organ}")
            wanted = ['eid', 'prediction_time', disease, 'censor_date'] + covariate_columns
            disease_df = merged_df[wanted].copy()
            disease_df['diagnosis_date'] = _as_datetime(disease_df[disease])
            disease_df['event_occurred'] = disease_df['diagnosis_date'].notna().astype(int)

            # Exclude prevalent cases: those with diagnosis_date <= prediction_time
            is_prevalent = disease_df['diagnosis_date'] <= disease_df['prediction_time']
            disease_df = disease_df[~is_prevalent].copy()

            # Follow-up ends at diagnosis for events and at the censor date otherwise.
            disease_df['time_to_event'] = (
                disease_df['diagnosis_date'] - disease_df['prediction_time']
            ).dt.days
            censored = disease_df['event_occurred'] == 0
            disease_df.loc[censored, 'time_to_event'] = (
                disease_df.loc[censored, 'censor_date']
                - disease_df.loc[censored, 'prediction_time']
            ).dt.days

            # Columns for modeling
            model_df = disease_df[['time_to_event', 'event_occurred'] + covariate_columns].copy()
            if categorical_covariates:
                model_df = pd.get_dummies(model_df, columns=categorical_covariates, drop_first=True)
            model_df = model_df.dropna()

            # Every remaining column except duration and event enters the model.
            predictors = [
                name for name in model_df.columns
                if name not in ['time_to_event', 'event_occurred']
            ]
            formula = " + ".join(predictors)

            cph = CoxPHFitter()
            try:
                cph.fit(model_df, duration_col='time_to_event', event_col='event_occurred', formula=formula)
                print(cph.summary)
                outcome = {'model': cph, 'summary': cph.summary}
            except ConvergenceError:
                print(f"Model for {disease} in organ {organ} did not converge.")
                outcome = {'model': None, 'summary': "Model did not converge."}
            except Exception as e:
                print(f"An error occurred while fitting the model for {disease} in organ {organ}: {e}")
                outcome = {'model': None, 'summary': f"Error: {e}"}
            per_disease[disease] = outcome

        all_results[organ] = per_disease

    return all_results, all_merged





def aggregate_cox_results_nested(results_dict, covariate):
    """Collect one covariate's Cox estimates into a summary table per organ.

    Parameters
    ----------
    results_dict : dict
        Nested mapping ``{organ: {disease: content}}``. A ``content`` entry is
        used only when it is a dict whose ``'summary'`` value is a DataFrame
        that has ``covariate`` in its index; anything else is reported with a
        printed message and skipped.
    covariate : str
        Row label to read from each summary table.

    Returns
    -------
    dict
        ``{organ: DataFrame}`` with columns ``Disease``, ``HR``, ``CI_lower``,
        ``CI_upper``, ``p``, ``Events`` and ``FDR``. ``FDR`` holds the
        ``fdr_bh``-adjusted p-values computed separately within each organ.
        ``Events`` is the sum of the model's ``event_observed`` attribute, or
        the string ``'N/A'`` when no model (or no such attribute) is present.
    """
    # NOTE(review): `multipletests` is not part of the notebook's import cell as
    # far as visible from here -- confirm it is imported before this is called.
    field_names = ('Disease', 'HR', 'CI_lower', 'CI_upper', 'p', 'Events')
    per_organ = {}

    for organ, disease_results in results_dict.items():
        collected = {name: [] for name in field_names}

        for disease, content in disease_results.items():
            if not isinstance(content, dict):
                print(f"Disease '{disease}' in organ '{organ}' has content of type {type(content)}. Skipping.")
                continue

            summary = content.get('summary')
            model = content.get('model')

            if not isinstance(summary, pd.DataFrame):
                print(f"Summary for disease '{disease}' in organ '{organ}' is not a DataFrame. Skipping.")
                continue

            if covariate not in summary.index:
                print(f"Covariate '{covariate}' not found in summary for disease '{disease}' in organ '{organ}'. Skipping.")
                continue

            fitted_row = (
                disease,
                summary.loc[covariate, 'exp(coef)'],
                summary.loc[covariate, 'exp(coef) lower 95%'],
                summary.loc[covariate, 'exp(coef) upper 95%'],
                summary.loc[covariate, 'p'],
                # Event count comes from the fitted model when it exposes one.
                model.event_observed.sum()
                if model is not None and hasattr(model, 'event_observed')
                else 'N/A',
            )
            for name, value in zip(field_names, fitted_row):
                collected[name].append(value)

        organ_table = pd.DataFrame(collected)

        if organ_table.empty:
            # Keep the output schema stable even when nothing was collected.
            organ_table['FDR'] = []
        else:
            # Benjamini-Hochberg correction across the diseases of this organ.
            _, adjusted, _, _ = multipletests(organ_table['p'], method='fdr_bh')
            organ_table['FDR'] = adjusted

        per_organ[organ] = organ_table

    return per_organ


def plot_forest_plot_with_annotations_nested(aggregated_results, covariate='age_gap_measuredprot', title_prefix='Forest Plot for'):
    """Draw one forest plot (hazard ratio with 95% CI) per organ.

    Parameters
    ----------
    aggregated_results : dict
        ``{organ: DataFrame}`` where each frame has the columns ``Disease``,
        ``HR``, ``CI_lower``, ``CI_upper``, ``Events`` and ``FDR``.
    covariate : str
        Not used by the plotting code; kept so existing calls keep working.
    title_prefix : str
        Text placed before the capitalised organ name in each title.

    Returns
    -------
    None
        Each figure is shown with ``plt.show()``. Organs whose table is empty
        are skipped with a printed message, because there is nothing to draw
        and the axis limits would otherwise be NaN.
    """
    for organ, summary_df in aggregated_results.items():
        if summary_df.empty:
            print(f"No hazard ratios available for organ '{organ}'. Skipping forest plot.")
            continue

        summary_df = summary_df.sort_values(by='HR', ascending=True)
        y_positions = np.arange(len(summary_df))

        # Solid blue marker if FDR < 0.05, otherwise an unfilled blue circle.
        facecolors = ['blue' if fdr < 0.05 else 'none' for fdr in summary_df['FDR']]
        edgecolors = ['blue' for _ in summary_df['FDR']]

        # Figure height grows with the number of diseases.
        fig, ax1 = plt.subplots(figsize=(14, 0.6 * len(summary_df) + 3))

        # 95% confidence intervals as asymmetric horizontal error bars.
        ax1.errorbar(
            summary_df['HR'],
            y_positions,
            xerr=[summary_df['HR'] - summary_df['CI_lower'], summary_df['CI_upper'] - summary_df['HR']],
            fmt='none',
            ecolor='black',
            elinewidth=1,
            capsize=3
        )

        # Hazard-ratio point estimates.
        ax1.scatter(
            summary_df['HR'],
            y_positions,
            edgecolors=edgecolors,
            facecolors=facecolors,
            marker='o',
            s=100,
            linewidth=1
        )

        ax1.axvline(x=1, color='red', linestyle='--')  # HR = 1 reference line
        ax1.set_yticks(y_positions)
        ax1.set_yticklabels(summary_df['Disease'])
        ax1.set_xlabel('Hazard Ratio (HR)')
        ax1.set_title(f'{title_prefix} {organ.capitalize()}')

        # Pad the x-range by 20% of the span covered by the intervals.
        x_max = summary_df['CI_upper'].max()
        x_min = summary_df['CI_lower'].min()
        x_buffer = (x_max - x_min) * 0.2
        ax1.set_xlim(x_min - x_buffer, x_max + x_buffer)
        ax1.grid(False)

        # Twin axis: same tick positions, labelled with the event counts.
        ax2 = ax1.twinx()
        ax2.set_ylim(ax1.get_ylim())
        ax2.set_yticks(y_positions)
        ax2.set_yticklabels(summary_df['Events'])
        ax2.set_ylabel('Number of Events')
        ax2.spines['right'].set_visible(False)
        ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax2.yaxis.set_label_position("right")
        ax2.yaxis.tick_right()
        ax2.grid(False)

        # Legend handles are circular markers (Line2D is imported in the
        # notebook's import cell) so they match the plotted points.
        legend_handles = [
            Line2D([], [], linestyle='None', marker='o', markersize=10,
                   markerfacecolor='blue', markeredgecolor='blue', label='FDR < 0.05'),
            Line2D([], [], linestyle='None', marker='o', markersize=10,
                   markerfacecolor='none', markeredgecolor='blue', label='FDR ≥ 0.05'),
        ]
        plt.legend(handles=legend_handles, loc='upper left')

        fig.tight_layout()
        plt.show()



In [None]:
# Read the censor-date table (censor dates vary by nationality). It is passed as
# `censor_date_df` to the Cox-model calls in the following cells.
censor_date_df = pd.read_csv('/path/to/censor/df')
censor_date_df

In [None]:
# --- Cox models: measured-protein age gap ---

# Default censor date (latest follow-up date across OMOP).
default_censor_date = '2022-12-31'

# 'sex' appears both as a model covariate and as the categorical column.
categorical_covariates = ['sex']
covariate_columns = ['age_gap_measuredprot_zscore', 'sex', 'actual_age']

# Every column of master_label except the person identifier is passed as a disease column.
disease_columns = [name for name in master_label.columns if name != 'person_id']

results_measuredprot, merged_dfs = perform_multiple_cox_models_individual_censor_dates(
    target_df=master_label,
    disease_columns=disease_columns,
    covdf_dict=covariates_dfs,
    agedf=age_df,
    covariate_columns=covariate_columns,
    categorical_covariates=categorical_covariates,
    default_censor_date=default_censor_date,
    censor_date_df=censor_date_df,
    date_format='%Y-%m-%d',
)



In [None]:
# --- Cox models: RABIT protein age gap ---

# Default censor date (latest follow-up date across OMOP).
default_censor_date = '2022-12-31'

# 'sex' appears both as a model covariate and as the categorical column.
categorical_covariates = ['sex']
covariate_columns = ['age_gap_rabit_zscore', 'sex', 'actual_age']

# Every column of master_label except the person identifier is passed as a disease column.
disease_columns = [name for name in master_label.columns if name != 'person_id']

results_rabit, merged_dfs = perform_multiple_cox_models_individual_censor_dates(
    target_df=master_label,
    disease_columns=disease_columns,
    covdf_dict=covariates_dfs,
    agedf=age_df,
    covariate_columns=covariate_columns,
    categorical_covariates=categorical_covariates,
    default_censor_date=default_censor_date,
    censor_date_df=censor_date_df,
    date_format='%Y-%m-%d',
)

In [None]:
# --- Cox models: EHR age gap ---

# Default censor date (latest follow-up date across OMOP).
default_censor_date = '2022-12-31'

# 'sex' appears both as a model covariate and as the categorical column.
categorical_covariates = ['sex']
covariate_columns = ['age_gap_ehr_zscore', 'sex', 'actual_age']

# Every column of master_label except the person identifier is passed as a disease column.
disease_columns = [name for name in master_label.columns if name != 'person_id']

# NOTE(review): this call passes censor_date_df=None, whereas the measured-protein
# and RABIT calls pass the censor-date table -- confirm the difference is intended.
results_ehr, merged_dfs = perform_multiple_cox_models_individual_censor_dates(
    target_df=master_label,
    disease_columns=disease_columns,
    covdf_dict=covariates_dfs,      # Dictionary of organ-specific covariate DataFrames
    agedf=age_df,
    covariate_columns=covariate_columns,
    categorical_covariates=categorical_covariates,
    default_censor_date=default_censor_date,
    censor_date_df=None,
    date_format='%Y-%m-%d',
)



# Visualize forest plots

In [None]:
def plot_forest_plots_grid(
    aggregated_results_list: Sequence[Dict[str, pd.DataFrame]],
    covariates: Sequence[str],
    *,
    title_prefix: str = '',
    plot_ground_truth: bool = False,
    ground_truth_path: str | None = None,
    order_by: str = 'ground_truth',
    exclude_covariates: Optional[Iterable[str]] = None,
    ncols: int = 2,

    # Label maps
    label_replacements: dict | None = None,
    covariate_label_map: dict | None = None,

    # ---- COLOR / STYLE CONTROL ----
    cov_colors: Dict[str, str] | Sequence[str] | None = None,  # map or list aligned with covariates
    gt_color: str = "black",
    dotted_line_color: str = "#5d6778",
    sig_edgecolor: str = "black",
    nonsig_face: str = "none",
    ci_linewidth: float = 1.0,
    ci_capsize: float = 3.0,
    point_size: float = 100,

    # Text sizes
    title_fontsize: float = 16,
    axis_label_fontsize: float = 14,
    tick_label_fontsize: float = 12,
    events_label_fontsize: float = 10,
    legend_fontsize: float = 12,

    # Right-hand events axis tweaks
    events_tick_rotation: int = 90,
    events_tick_pad: int = 6,

    # Spacing between subplots
    subplot_wspace: float = 0.45,
    subplot_hspace: float = 0.4,

    # Legend figure options
    make_legend_figure: bool = True,
    legend_title: str = "Legend",
    legend_box_face: str = "white",
    legend_box_edge: str = "#000000",
    legend_save_path: Optional[str] = None,
    legend_figsize: Tuple[float, float] = (4, 0),  # (width, height); height auto if 0
):

    # select covariates
    if exclude_covariates:
        excl = set([exclude_covariates] if isinstance(exclude_covariates, str) else exclude_covariates)
        keep = [(cv, ag) for cv, ag in zip(covariates, aggregated_results_list) if cv not in excl]
        if not keep:
            raise ValueError("All covariates excluded.")
        covariates, aggregated_results_list = zip(*keep)

    # load ground truth
    if plot_ground_truth:
        gt_df = pd.read_csv(ground_truth_path)

    label_replacements = label_replacements or {}
    covariate_label_map = covariate_label_map or {}

    # colors
    default_palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd',
                       '#8c564b', '#17becf', '#e377c2', '#7f7f7f']
    if cov_colors is None:
        colors = default_palette[:len(covariates)]
    elif isinstance(cov_colors, dict):
        colors = [cov_colors.get(cv, default_palette[i % len(default_palette)])
                  for i, cv in enumerate(covariates)]
    else:
        if len(cov_colors) < len(covariates):
            raise ValueError("need more colors.")
        colors = list(cov_colors[:len(covariates)])

    cov_handles = [
        Line2D([], [], marker='o', color='w',
               markerfacecolor=colors[i], markeredgecolor=colors[i],
               markersize=8,
               label=covariate_label_map.get(covariates[i], covariates[i]))
        for i in range(len(covariates))
    ]
    sig_handles = [
        Line2D([], [], marker='o', color=sig_edgecolor, markerfacecolor=sig_edgecolor,
               markersize=8, label='FDR < 0.05'),
        Line2D([], [], marker='o', color=sig_edgecolor, markerfacecolor='none',
               markersize=8, label='FDR ≥ 0.05'),
    ]
    if plot_ground_truth:
        cov_handles.append(
            Line2D([], [], marker='D', color=gt_color, markersize=8, label='Ground Truth')
        )

    organs = list(aggregated_results_list[0].keys())
    nplots = len(organs)
    nrows = math.ceil(nplots / ncols)
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(7 * ncols, 8 * len(organs) / ncols + 3 * nrows),
        squeeze=False
    )
    axes_flat = axes.flatten()
    fig.subplots_adjust(wspace=subplot_wspace, hspace=subplot_hspace)

    for idx, organ in enumerate(organs):
        ax = axes_flat[idx]

        # determine order
        if order_by == 'ground_truth':
            gt_sub = gt_df[gt_df['organ'] == organ]
            df0 = aggregated_results_list[0][organ].copy()
            df0['clean'] = df0['Disease'].str.replace('_earliest', '', regex=False)
            merged = df0.merge(gt_sub, left_on='clean', right_on='trait', how='left')
            merged['exp(coef)'].fillna(-np.inf, inplace=True)
            order = merged.sort_values('exp(coef)', ascending=False)['Disease']
        else:
            idx_cov = covariates.index(order_by)
            order = aggregated_results_list[idx_cov][organ] \
                        .sort_values('HR', ascending=False)['Disease']
        disease_order = order.tolist()
        base_y = np.arange(len(disease_order))

        plot_dfs = [
            agr[organ].set_index('Disease').loc[disease_order].reset_index()
            for agr in aggregated_results_list
        ]

        # x limits
        x_min = min(df['CI_lower'].min() for df in plot_dfs)
        x_max = max(df['CI_upper'].max() for df in plot_dfs)
        pad = (x_max - x_min) * 0.2

        # plot each covariate series
        offset = 0.3
        for i, df in enumerate(plot_dfs):
            dy = (i - (len(plot_dfs) - 1) / 2) * offset
            yy = base_y + dy
            col = colors[i]

            ax.errorbar(
                df['HR'], yy,
                xerr=[df['HR'] - df['CI_lower'], df['CI_upper'] - df['HR']],
                fmt='none', ecolor=col, elinewidth=ci_linewidth, capsize=ci_capsize
            )
            faces = [col if p < 0.05 else nonsig_face for p in df['FDR']]
            ax.scatter(
                df['HR'], yy, s=point_size, marker='o',
                facecolor=faces, edgecolor=col,
                linewidth=1, zorder=3
            )

        # overlay ground truth
        if plot_ground_truth:
            gt_map = gt_df[gt_df['organ'] == organ].set_index('trait')
            for j, dis in enumerate(disease_order):
                clean = dis.replace('_earliest', '')
                if clean in gt_map.index:
                    r = gt_map.loc[clean]
                    ax.errorbar(
                        r['exp(coef)'], base_y[j],
                        xerr=[[r['exp(coef)'] - r['exp(coef) lower 95%']],
                              [r['exp(coef) upper 95%'] - r['exp(coef)']]],
                        fmt='none', ecolor=gt_color, elinewidth=ci_linewidth, capsize=ci_capsize
                    )
                    ax.scatter(
                        r['exp(coef)'], base_y[j], s=point_size,
                        marker='D', color=gt_color, zorder=4
                    )

        # format axes
        ax.axvline(1, color=dotted_line_color, linestyle='--')  # dotted ref line
        labels = [label_replacements.get(d, d) for d in disease_order]
        ax.set_yticks(base_y)
        ax.set_yticklabels(labels, fontsize=tick_label_fontsize)
        ax.invert_yaxis()
        ax.set_xlabel('Hazard Ratio (HR)', fontsize=axis_label_fontsize)
        title = f"{title_prefix} {organ.capitalize()}" if title_prefix else organ.capitalize()
        ax.set_title(title, fontsize=title_fontsize)
        ax.set_xlim(x_min - pad, x_max + pad)
        ax.tick_params(axis='x', labelsize=tick_label_fontsize)

        # remove grid lines
        ax.grid(False)

        # right-hand events axis
        events = (
            aggregated_results_list[0][organ]
            .set_index('Disease')
            .loc[disease_order]['Events']
            .tolist()
        )
        ax2 = ax.twinx()
        ax2.patch.set_visible(False)
        ax2.set_zorder(ax.get_zorder() - 1)
        ax2.set_ylim(ax.get_ylim())
        ax2.set_yticks(base_y)
        ax2.set_yticklabels(events,
                            fontsize=events_label_fontsize,
                            rotation=events_tick_rotation,
                            va="center",
                            ha="left")
        ax2.tick_params(axis='y', pad=events_tick_pad)
        # ax2.set_ylabel('Number of Events', fontsize=axis_label_fontsize)
        ax2.spines['right'].set_visible(False)
        ax2.tick_params(axis='x', bottom=False, labelbottom=False)
        ax2.grid(False)

    # delete unused panels
    for j in range(nplots, nrows * ncols):
        fig.delaxes(axes_flat[j])

    # separate legend figure
    legend_fig = None
    if make_legend_figure:
        handles = cov_handles + sig_handles
        if legend_figsize[1] == 0:
            est_h = max(2, 0.55 * len(handles) + 0.5)
            legend_figsize = (legend_figsize[0], est_h)

        legend_fig, legend_ax = plt.subplots(figsize=legend_figsize)
        legend_ax.axis("off")
        leg = legend_ax.legend(
            handles=handles,
            loc='center',
            ncol=1,
            frameon=True,
            fontsize=legend_fontsize,
            title=legend_title,
            title_fontsize=axis_label_fontsize
        )
        frame = leg.get_frame()
        frame.set_facecolor(legend_box_face)
        frame.set_edgecolor(legend_box_edge)

        legend_fig.tight_layout()
        if legend_save_path:
            import os
            os.makedirs(os.path.dirname(legend_save_path) or ".", exist_ok=True)
            legend_fig.savefig(legend_save_path, dpi=300, bbox_inches="tight")

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    return fig, axes, legend_fig



# Per-organ summary tables (HR, CI, p, FDR, events) for each age-gap covariate.
aggregated_results_measuredprot = aggregate_cox_results_nested(
    results_measuredprot, covariate='age_gap_measuredprot_zscore'
)
aggregated_results_rabit = aggregate_cox_results_nested(
    results_rabit, covariate='age_gap_rabit_zscore'
)
aggregated_results_ehr = aggregate_cox_results_nested(
    results_ehr, covariate='age_gap_ehr_zscore'
)

# Short display names for the disease labels on the y-axis.
# NOTE(review): 'Leukemia' and 'Non-Hodgkin_lymphoma' carry no '_earliest' suffix,
# unlike the other keys -- confirm they match the disease column names.
label_replacements = {
    'Parkinson_disease_and_parkinsonism_earliest': 'PD',
    'Alzheimer_disease_earliest': 'AD',
    'Chronic_liver_disease_earliest': 'CLD',
    'Ischemic_heart_disease_earliest': 'IHD',
    'Osteoporosis_earliest': 'OSP',
    'Emphysema_COPD_earliest': 'COPD',
    'Type2_diabetes_earliest': 'T2DM',
    'Chronic_kidney_disease_earliest': 'CKD',
    'Leukemia': 'Leukemia',
    'Non-Hodgkin_lymphoma': 'NHL',
    'Cerebrovascular_disease_earliest': 'CBVD',
    'Osteoarthritis_earliest': 'OA',
    'Rheumatoid_arthritis_earliest': 'RA',
    'All_cause_dementia_earliest': 'Dementia',
    'Heart_failure_earliest': 'HF',
    'Atrial_fibrillation_or_flutter_earliest': 'AFib',
    'Vascular_dementia_earliest': 'VD',
}

# Legend names for the three covariate series.
covariate_label_map = {
    'age_gap_measuredprot_zscore': 'Measured Protein Age Gap',
    'age_gap_rabit_zscore': 'RABIT Protein Age Gap',
    'age_gap_ehr_zscore': 'EHR Age Gap',
}

# Four-column grid, diseases ordered by the RABIT hazard ratios, ground truth overlaid.
plot_forest_plots_grid(
    aggregated_results_list=[
        aggregated_results_measuredprot,
        aggregated_results_rabit,
        aggregated_results_ehr,
    ],
    covariates=[
        'age_gap_measuredprot_zscore',
        'age_gap_rabit_zscore',
        'age_gap_ehr_zscore',
    ],
    order_by='age_gap_rabit_zscore',
    plot_ground_truth=True,
    ground_truth_path='/ground/truth/dataframe/from/original/paper',
    ncols=4,
    label_replacements=label_replacements,
    covariate_label_map=covariate_label_map,
    cov_colors={
        'age_gap_measuredprot_zscore': '#d89a97',
        'age_gap_rabit_zscore': '#94bed8',
        'age_gap_ehr_zscore': '#ead490',
    },
    gt_color='#5d6778',
    title_fontsize=18,
    axis_label_fontsize=16,
    tick_label_fontsize=14,
    events_label_fontsize=12,
    legend_fontsize=14,
)


In [None]:
def dict_to_hr_dataframe(nested_results_dict, disease_col='Disease'):
    """Stack per-organ result tables into one long DataFrame.

    Parameters
    ----------
    nested_results_dict : dict
        ``{organ: DataFrame}``; each frame must contain ``disease_col``,
        ``'HR'`` and ``'FDR'``.
    disease_col : str
        Name of the disease column to keep.

    Returns
    -------
    pandas.DataFrame
        Columns ``[disease_col, 'HR', 'FDR', 'Organ']`` with a fresh integer
        index. For an empty input dict, an empty frame with columns
        ``['Organ', disease_col, 'HR', 'FDR']`` is returned instead.
    """
    wanted = [disease_col, 'HR', 'FDR']
    per_organ_frames = [
        table[wanted].assign(Organ=organ)
        for organ, table in nested_results_dict.items()
    ]
    if not per_organ_frames:
        return pd.DataFrame(columns=['Organ', disease_col, 'HR', 'FDR'])
    return pd.concat(per_organ_frames, ignore_index=True)

def plot_hr_scatter_nested_per_organ(df1, df2, disease_col='Disease', hr_col1='HR', hr_col2='HR',
                                     title_prefix="Hazard Ratios Comparison for",
                                     df1label='rabit', df2label='measuredprot'):
    """Scatter the hazard ratios of two result sets against each other, one figure per organ.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Long tables with an ``'Organ'`` column, ``disease_col``, an ``'FDR'``
        column and the hazard-ratio columns ``hr_col1`` / ``hr_col2``. They are
        merged per organ on ``disease_col`` with suffixes ``_1`` / ``_2``.
    disease_col : str
        Column holding the disease name (merge key and point label).
    hr_col1, hr_col2 : str
        Hazard-ratio column in ``df1`` (x-axis) and ``df2`` (y-axis).
    title_prefix : str
        Text placed before the capitalised organ name in each title.
    df1label, df2label : str
        Names used in the axis labels and the legend.

    Returns
    -------
    None
        Figures are shown with ``plt.show()``. Prints a message and returns
        early when the two tables share no organ. Organs for which the
        correlations or the regression cannot be computed are reported and
        skipped. A separate legend figure is shown at the end.
    """
    common_organs = set(df1['Organ'].unique()).intersection(df2['Organ'].unique())
    if not common_organs:
        print("No common organs.")
        return

    # Column names after the suffixing merge below.
    x_col = f"{hr_col1}_1"
    y_col = f"{hr_col2}_2"

    for organ in sorted(common_organs):
        merged_df = pd.merge(
            df1[df1['Organ'] == organ],
            df2[df2['Organ'] == organ],
            on=[disease_col],
            suffixes=('_1', '_2'),
        )

        # Correlations shown in the annotation box.
        try:
            pearson_corr, pearson_p = pearsonr(merged_df[x_col], merged_df[y_col])
            spearman_corr, spearman_p = spearmanr(merged_df[x_col], merged_df[y_col])
        except Exception as e:
            print(f"Error calculating correlations for organ {organ}: {e}")
            continue

        # Best-fit line parameters.
        try:
            slope, intercept, *_ = linregress(merged_df[x_col], merged_df[y_col])
        except Exception as e:
            print(f"Error in linear regression for organ {organ}: {e}")
            continue

        fig, ax = plt.subplots(figsize=(12, 10))

        # One marker per disease, coloured by where it is significant (FDR < 0.05).
        for x_val, y_val, fdr1, fdr2 in zip(merged_df[x_col], merged_df[y_col],
                                            merged_df['FDR_1'], merged_df['FDR_2']):
            if fdr1 < 0.05 and fdr2 >= 0.05:
                facecolor, edgecolor = 'red', None
            elif fdr2 < 0.05 and fdr1 >= 0.05:
                facecolor, edgecolor = 'blue', None
            elif fdr1 < 0.05 and fdr2 < 0.05:
                facecolor, edgecolor = 'purple', None
            else:
                facecolor, edgecolor = 'white', 'black'
            ax.scatter(x_val, y_val, color=facecolor, edgecolor=edgecolor, s=50, zorder=3, marker='o')

        # Best-fit line across the observed x-range.
        x_vals = [merged_df[x_col].min(), merged_df[x_col].max()]
        y_vals = [slope * x + intercept for x in x_vals]
        ax.plot(x_vals, y_vals, 'g-', zorder=2)

        # Label each point with its disease name, then nudge labels apart.
        texts = [
            ax.text(x_val, y_val, label, fontsize=9, ha='right', va='bottom', zorder=4)
            for x_val, y_val, label in zip(merged_df[x_col], merged_df[y_col], merged_df[disease_col])
        ]
        adjust_text(texts, only_move={'points': 'y', 'text': 'y'},
                    arrowprops=dict(arrowstyle="->", color='gray', lw=0.5), ax=ax)

        ax.set_xlabel(f"Hazard Ratio from {df1label} ({hr_col1})")
        ax.set_ylabel(f"Hazard Ratio from {df2label} ({hr_col2})")
        ax.set_title(f"{title_prefix} {organ.capitalize()}")

        # Correlation statistics in the upper-left corner.
        legend_text = (
            f"Pearson r = {pearson_corr:.3f} (p = {pearson_p:.3g})\n"
            f"Spearman r = {spearman_corr:.3f} (p = {spearman_p:.3g})"
        )
        ax.text(0.05, 0.95, legend_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.5))

        # The legend is intentionally left off the individual scatter plots.
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    # After all scatter plots, draw the shared legend in a figure of its own.
    figLegend = plt.figure(figsize=(4, 4))
    axLegend = figLegend.add_subplot(111)
    axLegend.axis('off')
    custom_handles = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markeredgecolor='red', markersize=8,
               label=f"Significant in {df1label} only"),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markeredgecolor='blue', markersize=8,
               label=f"Significant in {df2label} only"),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='purple', markeredgecolor='purple', markersize=8,
               label="Significant in both"),
        Line2D([0], [0], marker='o', color='black', markerfacecolor='white', markersize=8,
               label="Not significant in either"),
        Line2D([0], [0], color='g', lw=2, label="Best-fit line")
    ]
    axLegend.legend(handles=custom_handles, loc='center', ncol=1)
    figLegend.tight_layout()
    plt.show()


# Flatten each {organ: summary table} mapping into one long table with an 'Organ' column.
df_rabit = dict_to_hr_dataframe(aggregated_results_rabit, disease_col='Disease')
df_measuredprot = dict_to_hr_dataframe(aggregated_results_measuredprot, disease_col='Disease')
df_ehr = dict_to_hr_dataframe(aggregated_results_ehr, disease_col='Disease')

# Create separate scatter plots for each organ comparing the HRs from RABIT (x-axis)
# vs. measured proteomics (y-axis).
plot_hr_scatter_nested_per_organ(df_rabit, df_measuredprot, disease_col='Disease', hr_col1='HR', hr_col2='HR',
                                 title_prefix="Hazard Ratios Comparison for", 
                                 df1label='RABIT Proteomics', df2label='Measured Proteomics')

In [None]:
def _benjamini_hochberg(pvals):
    import numpy as np
    p = np.asarray(pvals, dtype=float)
    m = np.sum(~np.isnan(p))
    if m == 0:
        return [float("nan")] * len(p)
    idx = np.argsort(p)
    ranks = np.arange(1, len(p) + 1)
    non_nan = ~np.isnan(p[idx])
    sp = p[idx][non_nan]
    adj = sp * m / ranks[non_nan]
    adj = np.minimum.accumulate(adj[::-1])[::-1]
    full = np.full_like(p, np.nan)
    full[idx[non_nan]] = np.clip(adj, 0, 1)
    return full.tolist()

def plot_hr_scatter_series(
    df1, df2,
    disease_col: str = "Disease",
    hr_col1: str = "HR",
    hr_col2: str = "HR",
    df1label: str = "rabit",
    df2label: str = "measuredprot",
    label_replacements: dict | None = None,
    title_fontsize: float = 16,
    axis_label_fontsize: float = 14,
    tick_label_fontsize: float = 12,
    dot_label_fontsize: float = 10,
    corr_label_fontsize: float = 12,
    legend_fontsize: float = 12,
    invert: bool = False,

    color_df1_only: str = "#d62728",   # significant only in df1
    color_df2_only: str = "#1f77b4",   # significant only in df2
    color_both:     str = "#9467bd",   # significant in both
    color_nonsig_face: str = "white",  # non-sig fill
    color_nonsig_edge: str = "black",  # non-sig edge
    line_color: str = "#2ca02c",       # best-fit line
):
    label_replacements = label_replacements or {}

    # common organs
    organs = sorted(set(df1["Organ"]) & set(df2["Organ"]))
    if not organs:
        raise ValueError("No common organs to plot.")

    # pearson correlations
    panels = []
    for organ in organs:
        m1 = df1[df1["Organ"] == organ]
        m2 = df2[df2["Organ"] == organ]
        merged = pd.merge(m1, m2, on=[disease_col], suffixes=("_1", "_2"))
        if merged.empty:
            continue
        r, p_raw = pearsonr(merged[f"{hr_col1}_1"], merged[f"{hr_col2}_2"])
        panels.append({"organ": organ, "df": merged, "r": r, "p_raw": p_raw})

    # adjust p values with benjamini hochberg
    p_bh_list = _benjamini_hochberg([p["p_raw"] for p in panels])
    for panel, bh in zip(panels, p_bh_list):
        panel["p_bh"] = bh

    for panel in panels:
        organ  = panel["organ"]
        merged = panel["df"]

        fig, ax = plt.subplots(figsize=(6, 6))

        x_vals, y_vals, labels = [], [], []
        for _, row in merged.iterrows():
            xi = row[f"{hr_col2}_2"] if invert else row[f"{hr_col1}_1"]
            yi = row[f"{hr_col1}_1"] if invert else row[f"{hr_col2}_2"]
            f1, f2 = row["FDR_1"], row["FDR_2"]

            if f1 < 0.05 <= f2:
                fc, ec = color_df1_only, None
            elif f2 < 0.05 <= f1:
                fc, ec = color_df2_only, None
            elif f1 < 0.05 and f2 < 0.05:
                fc, ec = color_both, None
            else:
                fc, ec = color_nonsig_face, color_nonsig_edge

            ax.scatter(xi, yi, facecolor=fc, edgecolor=ec, s=60, zorder=3)
            x_vals.append(xi); y_vals.append(yi)
            labels.append(label_replacements.get(row[disease_col], row[disease_col]))

        # best-fit line
        slope, intercept, *_ = linregress(x_vals, y_vals)
        x0, x1 = min(x_vals), max(x_vals)
        ax.plot([x0, x1], [slope * x0 + intercept, slope * x1 + intercept],
                color=line_color, zorder=2)

        # labels
        texts = []
        for xv, yv, lab in zip(x_vals, y_vals, labels):
            texts.append(ax.text(xv, yv, lab,
                                 fontsize=dot_label_fontsize,
                                 ha="center", va="center", zorder=4))
        adjust_text(
            texts, x=x_vals, y=y_vals, ax=ax,
            arrowprops=dict(arrowstyle="->", color="gray", lw=0.5),
            expand_points=(1.2, 1.2), expand_text=(1.2, 1.2),
            force_points=(0.3, 0.3), force_text=(0.3, 0.3),
            avoid_self=True
        )

        # titles / axis labels
        ax.set_title(organ.capitalize(), fontsize=title_fontsize, pad=10)
        if invert:
            ax.set_xlabel(f"{df2label} {hr_col2}", fontsize=axis_label_fontsize, labelpad=8)
            ax.set_ylabel(f"{df1label} {hr_col1}", fontsize=axis_label_fontsize, labelpad=8)
        else:
            ax.set_xlabel(f"{df1label} {hr_col1}", fontsize=axis_label_fontsize, labelpad=8)
            ax.set_ylabel(f"{df2label} {hr_col2}", fontsize=axis_label_fontsize, labelpad=8)

        ax.tick_params(labelsize=tick_label_fontsize)
        ax.grid(False)

        ax.text(
            0.05, 0.95,
            f"r = {panel['r']:.3f}\nBH-p = {panel['p_bh']:.2e}",
            transform=ax.transAxes,
            fontsize=corr_label_fontsize,
            va="top", ha="left",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
        )

        plt.tight_layout()
        plt.show()

    # global legend
    handles = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=color_df1_only,
               markersize=8, label=f"Sig in {df1label} only"),
        Line2D([0], [0], marker="o", color="w", markerfacecolor=color_df2_only,
               markersize=8, label=f"Sig in {df2label} only"),
        Line2D([0], [0], marker="o", color="w", markerfacecolor=color_both,
               markersize=8, label="Sig in both"),
        Line2D([0], [0], marker="o", color=color_nonsig_edge, markerfacecolor=color_nonsig_face,
               markersize=8, label="Not sig"),
        Line2D([0], [0], color=line_color, lw=2, label="Best-fit line")
    ]

    fig_leg = plt.figure(figsize=(4, len(handles) * 0.7))
    ax_leg = fig_leg.add_subplot(111)
    ax_leg.axis("off")
    ax_leg.legend(handles=handles, loc="center", ncol=1, frameon=False,
                  fontsize=legend_fontsize)
    plt.tight_layout()
    plt.show()




# Long-format HR tables (one row per organ/disease) for each age-gap covariate.
df_rabit = dict_to_hr_dataframe(aggregated_results_rabit, disease_col='Disease')
df_measuredprot = dict_to_hr_dataframe(aggregated_results_measuredprot, disease_col='Disease')
df_ehr = dict_to_hr_dataframe(aggregated_results_ehr, disease_col='Disease')

# Short display names used to label the points.
# NOTE(review): 'Leukemia' and 'Non-Hodgkin_lymphoma' carry no '_earliest' suffix,
# unlike the other keys -- confirm they match the disease column names.
label_replacements = {
    'Parkinson_disease_and_parkinsonism_earliest': 'PD',
    'Alzheimer_disease_earliest': 'AD',
    'Chronic_liver_disease_earliest': 'CLD',
    'Ischemic_heart_disease_earliest': 'IHD',
    'Osteoporosis_earliest': 'OSP',
    'Emphysema_COPD_earliest': 'COPD',
    'Type2_diabetes_earliest': 'T2DM',
    'Chronic_kidney_disease_earliest': 'CKD',
    'Leukemia': 'Leukemia',
    'Non-Hodgkin_lymphoma': 'NHL',
    'Cerebrovascular_disease_earliest': 'CBVD',
    'Osteoarthritis_earliest': 'OA',
    'Rheumatoid_arthritis_earliest': 'RA',
    'All_cause_dementia_earliest': 'Dementia',
    'Heart_failure_earliest': 'HF',
    'Atrial_fibrillation_or_flutter_earliest': 'AFib',
    'Vascular_dementia_earliest': 'VD',
}

# RABIT vs. measured proteomics; invert=True puts the measured HRs on the x-axis.
plot_hr_scatter_series(
    df_rabit,
    df_measuredprot,
    df1label='RABIT Proteomics',
    df2label='Measured Proteomics',
    label_replacements=label_replacements,
    invert=True,
    color_df1_only="#94bed8",
    color_df2_only="#d89a97",
    color_both="#a073a0",
    color_nonsig_face="#FFFFFF",
    color_nonsig_edge="#333333",
    line_color="#5D6778",
)


In [None]:
def plot_organ_correlation_scatter(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    *,
    disease_col: str = "Disease",
    hr_col: str = "HR",
    df1label: str = "rabit",
    df2label: str = "measuredprot",
    label_fontsize: int = 16,
    save: bool = True,
    out_dir: str | os.PathLike = (
        "/save/dir/
    ),
    filename: str | None = None,
):
    organs1 = set(df1["Organ"])
    organs2 = set(df2["Organ"])
    common_organs = organs1 & organs2
    if not common_organs:
        print("No common organs found.")
        return

    rows = []
    for organ in sorted(common_organs):
        sub1 = df1[df1["Organ"] == organ]
        sub2 = df2[df2["Organ"] == organ]
        merged = pd.merge(sub1, sub2, on=disease_col, suffixes=("_1", "_2"))
        if merged.empty:
            continue
        try:
            r_p, p_p = pearsonr(merged[f"{hr_col}_1"], merged[f"{hr_col}_2"])
            r_s, _   = spearmanr(merged[f"{hr_col}_1"], merged[f"{hr_col}_2"])
        except Exception:
            continue
        rows.append(
            {
                "Organ": organ,
                "Pearson_r": r_p,
                "Pearson_p": p_p,
                "Spearman_r": r_s,
            }
        )

    if not rows:
        print("No correlations computed.")
        return

    res_df = pd.DataFrame(rows)
    # BH correction
    _, adj_p, _, _ = multipletests(res_df["Pearson_p"], method="fdr_bh")
    res_df["Adj_P"] = adj_p
    res_df["neg_log10_adj_p"] = -np.log10(res_df["Adj_P"])
    res_df["color"] = res_df.apply(
        lambda r: "red"
        if (r["Pearson_r"] >= 0.4 and r["Adj_P"] < 0.05)
        else "grey",
        axis=1,
    )

    fig, ax = plt.subplots(figsize=(8, 6))
    for _, row in res_df.iterrows():
        ax.scatter(
            row["Pearson_r"],
            row["neg_log10_adj_p"],
            s=100,
            color=row["color"],
            edgecolor="k",
            zorder=3,
        )

    ax.axvline(0.4, color="blue", linestyle=":", lw=1)
    ax.axhline(-np.log10(0.05), color="blue", linestyle=":", lw=1)

    ax.set_xlabel("Pearson Correlation")
    ax.set_ylabel("-log10 BH-adjusted Pearson p-value")
    ax.set_title(f"Organ System Correlations: {df1label} vs {df2label}")
    ax.grid(True)

    # label
    texts = [
        ax.text(
            row["Pearson_r"],
            row["neg_log10_adj_p"],
            row["Organ"],
            fontsize=label_fontsize,
            ha="center",
            va="center",
            zorder=4,
        )
        for _, row in res_df.iterrows()
    ]
    adjust_text(
        texts,
        expand_points=(2.5, 2.5),
        expand_text=(2.5, 2.5),
        force_text=(0.8, 0.8),
        force_points=(0.6, 0.6),
        arrowprops=dict(arrowstyle="->", color="gray", lw=0.5),
        lim=200,
    )

    fig.tight_layout()

    if save:
        os.makedirs(out_dir, exist_ok=True)
        if filename is None:
            filename = (
                f"organ_corr_{df1label}_vs_{df2label}.pdf".replace(" ", "_")
            )
        pdf_path = os.path.join(out_dir, filename)
        fig.savefig(pdf_path, format="pdf", bbox_inches="tight")
        print(f"saved figure {pdf_path}")

    plt.show()
    return fig, ax


# Per-organ correlation of the HR columns: RABIT table vs measured-proteomics table.
plot_organ_correlation_scatter(
    df_rabit,
    df_measuredprot,
    df1label='rabit',
    df2label='measuredprot',
    disease_col='Disease',
    hr_col='HR',
    filename='fig4b.pdf',
    out_dir="/path/to/save",
)


In [None]:
# Convert ground truth into the Organ / Disease / HR / FDR layout that
# plot_organ_correlation_scatter expects.
gtruth = pd.read_csv('/path/to/ground/truth')
gtruth_transformed = (
    gtruth
    .rename(columns={
        'organ': 'Organ',
        'exp(coef)': 'HR',
        'q': 'FDR'
    })
    # 'Disease' is the 'trait' value with '_earliest' appended
    .assign(Disease=lambda frame: frame['trait'] + '_earliest')
    # keep only the necessary columns
    .loc[:, ['Organ', 'Disease', 'HR', 'FDR']]
)

# Each comparison gets its own file name: previously both calls wrote
# '/save/dir/fig4b.pdf', so the second figure overwrote the first.
plot_organ_correlation_scatter(gtruth_transformed, df_measuredprot, disease_col='Disease', hr_col='HR',
                               df1label='Ground Truth', df2label='measuredprot',
                               out_dir="/save/dir",
                               filename='fig4b_ground_truth_vs_measuredprot.pdf')

plot_organ_correlation_scatter(gtruth_transformed, df_rabit, disease_col='Disease', hr_col='HR',
                               df1label='Ground Truth', df2label='rabit',
                               out_dir="/save/dir",
                               filename='fig4b_ground_truth_vs_rabit.pdf')

# Young immune agers mortality

In [None]:
# Read the death table from CSV; the bare name on the last line displays it as the cell output.
deathdf = pd.read_csv('/path/to/omop/death/table/death.csv')
deathdf

In [None]:
# Take the 'Immune' entry of age_gap_simple_dfs (built outside this section) and display it.
immune_age_gap_df = age_gap_simple_dfs['Immune']
immune_age_gap_df

In [None]:
# One row per person in each lookup table; the first occurrence is kept.
death_dates = deathdf[['person_id', 'death_date']].drop_duplicates(
    subset='person_id', keep='first'
)
prediction_ts = age_df[['patient_id', 'prediction_time']].drop_duplicates(
    subset='patient_id', keep='first'
)

# Left-join both lookups onto the immune age-gap table by 'eid', dropping
# the redundant right-hand key column after each join.
merged = (
    immune_age_gap_df
    .merge(death_dates, how='left', left_on='eid', right_on='person_id')
    .drop(columns='person_id')
    .merge(prediction_ts, how='left', left_on='eid', right_on='patient_id')
    .drop(columns='patient_id')
)

merged

In [None]:
def plot_km_by_quartile(
    df: pd.DataFrame,
    value_col: str,
    *,
    start_col: str = "prediction_time",
    death_col: str = "death_date",
    id_col: str = "eid",
    censor_df: pd.DataFrame | None = None,
    censor_id_col: str = "patient_id",
    censor_date_col: str = "censor_date",
    global_censor_date: str | pd.Timestamp | None = None,
    title_prefix: str | None = None,
    show_counts: bool = True,
    show_ci: bool = True,
    figsize: tuple[int, int] = (8, 6),
    # ── NEW SAVING CONTROLS ────────────────────────────────────────────
    save: bool = True,
    out_dir: str | os.PathLike = (
        "/save/dir"
    ),
    filename_prefix: str | None = None,   # e.g. "km_age_gap"
):

    cols = [id_col, value_col, start_col, death_col]
    data = df.loc[:, cols].copy()

    data[start_col] = pd.to_datetime(data[start_col], errors="coerce")
    data[death_col] = pd.to_datetime(data[death_col], errors="coerce")
    data = data.dropna(subset=[value_col, start_col])

    if censor_df is not None:
        censor_df = (
            censor_df[[censor_id_col, censor_date_col]]
            .drop_duplicates(subset=censor_id_col, keep="first")
        )
        censor_df[censor_date_col] = pd.to_datetime(
            censor_df[censor_date_col], errors="coerce"
        )
        data = data.merge(
            censor_df,
            how="left",
            left_on=id_col,
            right_on=censor_id_col,
            validate="m:1",
        )
    else:
        data[censor_date_col] = pd.NaT

    data["quartile"] = pd.qcut(
        data[value_col],
        q=4,
        labels=["Q1 (lowest)", "Q2", "Q3", "Q4 (highest)"],
    )

    # censor using same individual censor dates from cox regression
    if global_censor_date is None:
        global_censor_date = data[death_col].max()
        if pd.isna(global_censor_date):
            global_censor_date = pd.Timestamp.today().normalize()
    global_censor_date = pd.to_datetime(global_censor_date)

    data["death_or_censor"] = data[death_col]
    mask_no_death = data["death_or_censor"].isna()
    data.loc[mask_no_death, "death_or_censor"] = data.loc[
        mask_no_death, censor_date_col
    ]
    mask_still_na = data["death_or_censor"].isna()
    data.loc[mask_still_na, "death_or_censor"] = global_censor_date

    data["event_observed"] = data[death_col].notna()
    data["duration"] = (
        data["death_or_censor"] - data[start_col]
    ).dt.days.clip(lower=0)

    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=figsize)

    legend_entries = []  # (label, colour)
    warnings.filterwarnings("ignore", category=FutureWarning)

    for q, grp in data.groupby("quartile", sort=False, observed=False):
        label = f"{q} (n={len(grp)})" if show_counts else str(q)
        kmf.fit(grp["duration"], grp["event_observed"], label=label)
        kmf.plot(ci_show=show_ci, legend=False, ax=ax)

        colour = ax.get_lines()[-1].get_color()
        legend_entries.append((label, colour))

    warnings.filterwarnings("default", category=FutureWarning)

    # statistical test for q1 and q4
    q1_data = data[data["quartile"] == "Q1 (lowest)"]
    q4_data = data[data["quartile"] == "Q4 (highest)"]

    if not q1_data.empty and not q4_data.empty:
        lr_res = logrank_test(
            q1_data["duration"],
            q4_data["duration"],
            event_observed_A=q1_data["event_observed"],
            event_observed_B=q4_data["event_observed"],
        )
        p_val = lr_res.p_value
        print(f"Log-rank test (Q1 vs Q4): p = {p_val:.3g}")
        ax.text(
            0.03,
            0.05,
            f"Q1 vs Q4 log-rank p = {p_val:.3g}",
            transform=ax.transAxes,
            fontsize=12,
            va="bottom",
            ha="left",
            bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"),
        )
    else:
        p_val = np.nan
        print("Q1 or Q4 group empty – p-value not computed.")

    # plot
    title_bits = [title_prefix] if title_prefix else []
    title_bits.append(f"Survival by {value_col} quartiles")
    ax.set_title(" – ".join(title_bits))
    ax.set_xlabel(f"Days from {start_col}")
    ax.set_ylabel("Survival probability")

    ax.tick_params(
        axis="both",
        which="major",
        length=6,
        width=1.2,
        direction="out",
        bottom=True,
        top=False,
        left=True,
        right=False,
    )

    fig.tight_layout()
    leg_fig, leg_ax = plt.subplots(figsize=(4, 1))
    leg_ax.axis("off")
    handles = [Line2D([0], [0], color=c, lw=2) for _, c in legend_entries]
    labels = [lbl for lbl, _ in legend_entries]
    leg_ax.legend(handles, labels, loc="center left", frameon=False)
    leg_fig.tight_layout()
    if save:
        os.makedirs(out_dir, exist_ok=True)

        if filename_prefix is None:
            filename_prefix = f"km_{value_col}"

        km_path = os.path.join(out_dir, f"{filename_prefix}_km.pdf")
        leg_path = os.path.join(out_dir, f"{filename_prefix}_legend.pdf")

        fig.savefig(km_path, format="pdf", bbox_inches="tight")
        leg_fig.savefig(leg_path, format="pdf", bbox_inches="tight")
        print(f"KM plot saved   → {km_path}")
        print(f"Legend saved    → {leg_path}")

    plt.show(fig)
    plt.show(leg_fig)

    return data, p_val



# Kaplan-Meier curves by quartile of the RABIT age-gap z-score, using the
# per-person censor dates in censor_date_df (built outside this section).
plot_km_by_quartile(
    merged,
    value_col="age_gap_rabit_zscore",
    id_col="eid",
    censor_df=censor_date_df,
    censor_id_col="patient_id",
    censor_date_col="censor_date",
    title_prefix="Immune cohort",
    filename_prefix="rabit_km",
    out_dir="/save/dir",
)
