# UK Biobank - Survival Analysis - KDE Plots

***

In [None]:
# Import of packages required for data processing
import pandas as pd
import numpy as np
import scipy.stats as stats
import seaborn as sns
from sklearn.metrics import mean_absolute_error
from sklearn.linear_model import QuantileRegressor
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm


In [2]:
# Global matplotlib font sizes shared by every figure in this notebook
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 15, 18, 19

plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)   # axes title and x/y labels
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)               # tick labels on both axes
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend entries
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure-level title

***
## **Data preparation**

In [None]:
# Organs with an age prediction; kidneys and fundi are listed once per side
organs = [
    "brain", "heart", "liver",
    "left_kidney", "right_kidney",
    "pancreas", "spleen",
    "left_fundus", "right_fundus",
]

# Organ -> name substituted for '<organ>' in the key-file paths (side prefix stripped)
organ_dict = {organ: organ.rsplit('_', 1)[-1] for organ in organs}

# CSV with age predictions (including cols 'key', 'age', 'pred_<organ>')
ukb_age_predictions = (
    "/mnt/qdata/share/raecker1/ukbdata_70k/results/"
    "ukb_age_mainly_healthy_trainset_70k_masked_best_model/ukb_all_results.csv"
)

# CSV from which the 'key' and 'age' columns are read as the fundus age
ukb_age_fundus = '/mnt/qdata/share/rakuest1/data/UKB/interim/ukb_all.csv'

# Disease status: one table covering all organs, and one template per disease code
disease_status_all = "ukb_all_organs_disease_status.csv"
disease_status = "/home/raecker1/internship/ukb_data_processing/ukb677731_disease_status_<disease>.csv"

# Key files ('<organ>' is replaced per organ when the ids are loaded)
_key_dir = "/mnt/qdata/share/raecker1/ukbdata_70k/interim/keys"
healthy_ids = _key_dir + "/ukb_keys_mainly_healthy_<organ>.csv"
healthy_test_ids = _key_dir + "/test_<organ>_mainly_healthy.dat"
test_ids = _key_dir + "/full_test_<organ>_mainly_healthy.dat"

'# File path of csv with age predictions (including cols \'key\', \'age\', \'pred_<organ>\')\nukb_age_predictions =  "/mnt/qdata/share/raecker1/ukbdata_70k/results/ukb_age_mainly_healthy_trainset_70k_masked_best_model/ukb_all_results.csv"\n\nukb_age_fundus = \'/mnt/qdata/share/rakuest1/data/UKB/interim/ukb_all.csv\'\ndisease_status = "ukb_all_organs_disease_status.csv"\n\n# Key files\nhealthy_ids = "/mnt/qdata/share/raecker1/ukbdata_70k/interim/keys/ukb_keys_mainly_healthy_<organ>.csv"\nhealthy_test_ids = "/mnt/qdata/share/raecker1/ukbdata_70k/interim/keys/test_<organ>_mainly_healthy.dat"\ntest_ids = "/mnt/qdata/share/raecker1/ukbdata_70k/interim/keys/full_test_<organ>_mainly_healthy.dat"\n'

In [None]:
# Load the age predictions and attach a second age column ('fundus_age') from the fundus table
df_ukb_age_predictions = pd.read_csv(ukb_age_predictions)
df_fundus_age = pd.read_csv(ukb_age_fundus, usecols=['key', 'age'])
df_fundus_age = df_fundus_age.rename(columns={'age': 'fundus_age'})
# Left join: every prediction row is kept, 'fundus_age' is NaN where the key has no match
df_ukb_age_predictions = pd.merge(df_ukb_age_predictions, df_fundus_age, on='key', how='left')
# Keep one row per key. The explicit copy makes this an independent frame, so the
# columns added to it in later cells do not trigger SettingWithCopyWarning.
df_age_pred_clean = df_ukb_age_predictions.drop_duplicates(subset=["key"]).copy()
df_disease_status = pd.read_csv(disease_status_all)

In [6]:
# Get ids for testing
def get_id_dict(organs, organ_dict, id_file):
    """
    Function to create a dictionary mapping organ names to their corresponding IDs.

    Parameters:
    organs (list): List of organ names.
    organ_dict (dict): Maps each organ name to the string substituted for
        '<organ>' in ``id_file``.
    id_file (str): Path template of the header-less id file; must contain the
        placeholder '<organ>'.

    Returns:
    dict: Dictionary with organ names as keys and lists of IDs as values. For
        organs whose name contains 'fundus', entries are read as
        '<prefix>/<id>' strings; only entries whose prefix is a substring of
        the organ name are kept, and their ids are converted to int.
    """

    id_dict = {}
    for organ in organs:
        id_table = pd.read_csv(id_file.replace('<organ>', organ_dict[organ]), header=None)
        # Take the first column explicitly: squeeze() would collapse a one-row
        # file to a scalar, so the result would no longer be a list.
        id_dict[organ] = id_table.iloc[:, 0].tolist()
        if 'fundus' in organ:
            split_entries = (entry.split("/") for entry in id_dict[organ])
            id_dict[organ] = [int(parts[1]) for parts in split_entries if parts[0] in organ]
    return id_dict

# One id dictionary per key file: all healthy subjects, the healthy test split, the full test split
dict_healthy_ids = get_id_dict(organs=organs, organ_dict=organ_dict, id_file=healthy_ids)
dict_healthy_test_ids = get_id_dict(organs=organs, organ_dict=organ_dict, id_file=healthy_test_ids)
dict_test_ids = get_id_dict(organs=organs, organ_dict=organ_dict, id_file=test_ids)


Bias Correction

In [None]:
# Perform bias correction
# Function for median regression bias correction
def median_regression_bias_correction(df, organ, disease_status_df):
    """Remove the age-dependent bias from an organ's predicted age via median regression.

    A quantile regression (quantile 0.5) of the age gap ``pred_<organ> - age``
    is fitted on age, sex and the organ's current/future pathology flags. The
    fitted age slope, centred on the median age of the fitted rows, is then
    subtracted from the prediction:
    ``corrected = pred - age_coef * (age - median_age)``.

    Parameters:
    df (DataFrame): Needs 'key', 'age', 'fundus_age' (read for fundus organs)
        and 'pred_<organ>'. A 'sex' column, if present, is ignored.
    organ (str): Organ name; if it contains 'fundus', 'fundus_age' is used as
        the age, otherwise 'age'.
    disease_status_df (DataFrame): Needs 'eid', 'age', 'sex',
        '<organ>_disease_before_current_age' and
        '<organ>_disease_after_current_age'. Joined to ``df`` on key == eid.

    Returns:
    Series: Corrected predicted ages on the index of ``df``; NaN for every row
        that could not enter the regression (missing age, sex or prediction).
    """

    # 'sex' is read from the disease-status table after the merge; dropping it
    # here keeps the merged column name unsuffixed.
    df_age = df.drop(columns=['sex']) if 'sex' in df.columns else df.copy()

    df_disease_status = disease_status_df.drop(columns=['age'])
    df_merged = pd.merge(df_age, df_disease_status, left_on='key', right_on='eid', how='left')
    df_merged.index = df_age.index  # Keep original index for consistency

    age = df_merged['fundus_age'] if 'fundus' in organ else df_merged['age']

    # independent variables
    X = pd.DataFrame({
        'age': age,  # Use age from predictions df
        'sex': df_merged['sex'],
        'current_pathology': df_merged[f'{organ}_disease_before_current_age'].fillna(False),
        'future_pathology': df_merged[f'{organ}_disease_after_current_age'].fillna(False)
    }, index=df_merged.index)

    y = df_merged[f'pred_{organ}'] - age  # age gap

    # Only rows with complete predictors and a defined age gap enter the fit
    mask = ~(X.isna().any(axis=1) | y.isna())
    X_clean = X[mask]
    y_clean = y[mask]

    # Fit median regression
    median_reg = QuantileRegressor(quantile=0.5, alpha=0.01, solver='highs')
    median_reg.fit(X_clean, y_clean)

    print(f"\nMedian regression coefficients for {organ}:")
    print(f"Age: {median_reg.coef_[0]:.4f}")
    print(f"Sex: {median_reg.coef_[1]:.4f}")
    print(f"Current pathology: {median_reg.coef_[2]:.4f}")
    print(f"Future pathology: {median_reg.coef_[3]:.4f}")
    print(f"Intercept: {median_reg.intercept_:.4f}")

    # Reference point of the correction: the median age of the fitted rows
    ca = X_clean['age']
    median_age = ca.median()
    age_coef = median_reg.coef_[0]

    # y_clean + ca is the raw prediction; only the age-slope term is removed
    correction = y_clean + ca - age_coef * (ca - median_age)

    # Start from NaN rather than from the age gap: rows excluded from the fit
    # must not come back as raw gaps in a column of corrected ages.
    corrected_ages = pd.Series(np.nan, index=y.index, dtype=float)
    corrected_ages[mask] = correction

    return corrected_ages


# Function for age-wise bias correction
def age_wise_bias_correction(df, healthy_keys, var_x, var_y, group):
    """Bias-correct ``var_y`` by subtracting the healthy median of each ``var_x`` value.

    Parameters:
    df (DataFrame): Needs the columns 'key', ``var_x`` and ``var_y``. It is not
        modified.
    healthy_keys (iterable): Keys of the subjects used to compute the
        per-``var_x`` medians.
    var_x (str): Grouping column (e.g. chronological age).
    var_y (str): Column to correct (e.g. predicted age).
    group (str): "dev" returns ``var_y - median``; any other value returns
        ``var_y - median + var_x``.

    Returns:
    Series: Corrected values sorted by index. Rows whose ``var_x`` value has no
        healthy reference group are omitted.
    """

    # Median of var_y per var_x value, computed on the healthy subjects only
    df_healthy = df[df["key"].isin(healthy_keys)]
    medians = df_healthy.groupby(var_x)[var_y].median()

    # Look the medians up per row instead of merging through a helper column;
    # this leaves the caller's DataFrame untouched.
    has_reference = df[var_x].isin(medians.index)
    df_ref = df.loc[has_reference]
    corrected = df_ref[var_y] - df_ref[var_x].map(medians)

    if group != "dev":
        corrected = corrected + df_ref[var_x]

    return corrected.sort_index()


# Function for linear bias correction
def linear_bias_correction(df, healthy_keys, var_x, var_y, group):
    """Bias-correct ``var_y`` with a straight line fitted on the healthy subjects.

    The line ``var_y ~ slope * var_x + intercept`` is estimated from the rows
    whose 'key' is in ``healthy_keys`` (rows with NaN in either column are left
    out of the fit) and its value is subtracted from ``var_y`` for every row of
    ``df``. With ``group == "dev"`` the residual is returned; otherwise
    ``var_x`` is added back to it.
    """

    reference = df[df["key"].isin(healthy_keys)]
    x_ref = reference[var_x]
    y_ref = reference[var_y]

    # Use only the pairs in which both values are present
    usable = ~(np.isnan(x_ref) | np.isnan(y_ref))
    fit = stats.linregress(x_ref[usable], y_ref[usable])
    slope, intercept = fit[0], fit[1]
    print(f"Slope: {slope}, Intercept: {intercept}")

    residual = df[var_y] - (slope * df[var_x] + intercept)
    if group == "dev":
        return residual
    return residual + df[var_x]
    

# Largest absolute corrected age gap (years) that is kept; anything at or
# beyond it is treated as a strong outlier (only very few cases).
MAX_ABS_AGE_GAP = 25

for organ in organs:
    corr_col = f"pred_{organ}_corr"
    df_age_pred_clean[corr_col] = median_regression_bias_correction(df_age_pred_clean, organ, df_disease_status)

    # Measure the gap against the same age the correction and the age-gap
    # computation use: 'fundus_age' for fundus organs, 'age' otherwise.
    reference_age = df_age_pred_clean["fundus_age"] if 'fundus' in organ else df_age_pred_clean["age"]
    corrected_gap = df_age_pred_clean[corr_col] - reference_age

    # mask() blanks only rows where the condition is True, so rows with an
    # undefined gap keep their value, as before.
    df_age_pred_clean[corr_col] = df_age_pred_clean[corr_col].mask(corrected_gap.abs() >= MAX_ABS_AGE_GAP)



Calculate predicted age gaps

In [None]:
# Predicted age gap (PAG) per organ: bias-corrected predicted age minus the reference age
for organ in organs:
    # Fundus organs are compared against 'fundus_age', all others against 'age'
    age_col = "fundus_age" if 'fundus' in organ else "age"
    age = df_age_pred_clean[age_col]
    df_age_pred_clean[f"{organ}_pag"] = df_age_pred_clean[f"pred_{organ}_corr"] - age
    print(len(df_age_pred_clean))

df_age_pred_clean.head(5)

---

## **KDE-plots**
#### **KDE plots for healthy test data**

In [None]:
def plot_kde_healthy(df, organs):
    """Plot, per organ, a 2-D KDE of chronological age vs. predicted age gap for the healthy test subjects.

    Each panel also shows the mean age gap (solid line), mean ± 1.96 SD (dashed
    lines), and a text box with the MAE between corrected predicted age and
    'age', the SD of the absolute age gap and the Pearson r.

    Parameters:
    df (DataFrame): Needs 'key', 'age', 'fundus_age', '<organ>_pag' and
        'pred_<organ>_corr' for every organ.
    organs (list): Organ names; the figure is a fixed 3x3 grid, so at most nine
        organs can be drawn.

    The healthy test ids are read from the module-level ``dict_healthy_test_ids``.
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 15))
    for i, organ in enumerate(organs):
        ax = axes.flat[i]

        # Rows need the organ's age gap and both age columns to be present
        df_organ = df.dropna(subset=[f'{organ}_pag', 'fundus_age', 'age'])

        status = (df_organ['key'].isin(dict_healthy_test_ids[organ]))
        cmap = 'Blues'

        age = df_organ['age'][status]                      # Chronological age
        dev = df_organ[f'{organ}_pag'][status]             # Predicted age gap for the organ
        pred = df_organ[f'pred_{organ}_corr'][status]

        md        = np.mean(dev)                   # Mean of the difference
        sd        = np.std(dev, axis=0)            # Standard deviation of the difference
        CI_low    = md - 1.96*sd
        CI_high   = md + 1.96*sd

        r, p = stats.pearsonr(pred, age)
        print(f'{organ} - Pearson r: {r}, p-value: {p}')
        MAE = mean_absolute_error(pred, age)
        SD = np.std(np.abs(dev))

        sns.kdeplot(x=age,
                    y=dev,
                    cmap=cmap,
                    fill=True,
                    ax=ax)

        ax.axhline(md,      color='black', linestyle='-')
        ax.axhline(CI_high, color='gray', linestyle='--')
        ax.axhline(CI_low,  color='gray', linestyle='--')

        # NOTE(review): the '*' after r is printed unconditionally; it is not tied to p.
        ax.text(51, 13, f'MAE = {round(MAE, 2)} ± {round(SD, 2)}\nr = {round(r, 2)}*',
                 ma='right',
                 size= MEDIUM_SIZE)

        ax.set_xlabel('chronological age [y]')
        # Raw string: '\e' is not a valid Python escape and warns in a normal literal.
        # '\enspace' is the mathtext spacing command that replaces the underscore.
        organ_name = str(organ.replace('_', r' \enspace '))
        ax.set_ylabel(r"predicted age gap  $\bf{%s}$ [y]" % organ_name)

        ax.set_ylim(-25, 25)
        ax.set_xlim(40, 90)

        xOutPlot = 90

        # Label the three horizontal lines at the right edge of the panel
        for y_pos, label in ((CI_low, r'-1.96SD:'), (CI_high, r'+1.96SD:'), (md, r'Mean:')):
            ax.text(xOutPlot,
                    y_pos,
                    label + "\n" + "%.2f" % y_pos,
                    ha = "right",
                    va = "center")

    fig.tight_layout()
    #plt.savefig(f'kde_ca_dev_intervals_healthy.png', bbox_inches="tight", dpi=300)
    #plt.close()
    plt.show()


plot_kde_healthy(df_age_pred_clean, organs)

### **KDE-plots comparing healthy and diseased subgroups**

In [None]:
def get_disease_ids(disease):
    """Collect, per organ, the ids of subjects with the given disease before / after the current age.

    Reads the disease-status CSV for ``disease`` (the module-level
    ``disease_status`` template with '<disease>' replaced) and inner-joins it
    with the module-level ``df_age_pred_clean`` on eid == key.

    Parameters:
    disease (str): Value substituted for '<disease>' in the file path.

    Returns:
    tuple(dict, dict): ``(prescan, postscan)``. Each maps organ name to the
        list of 'eid's that have a prediction for that organ and whose
        'disease_before_current_age' / 'disease_after_current_age' flag is True.
    """
    # Dictionary to hold diseased IDs for each organ
    dict_diseased_prescan_ids = {}
    dict_diseased_postscan_ids = {}

    df_diseased = pd.read_csv(disease_status.replace("<disease>", disease), usecols=["eid", "has_diseases", "disease_before_current_age", "disease_after_current_age"])
    df_diseased = pd.merge(df_diseased, df_age_pred_clean, left_on="eid", right_on="key", how="inner")

    for organ in organs:
        # Filter into a per-organ frame. Reassigning df_diseased here would
        # restrict every later organ to subjects that also have predictions
        # for all earlier organs.
        df_with_pred = df_diseased[df_diseased[f'pred_{organ}'].notna()]
        # '== True' (rather than using the column as a mask) treats missing flags as False
        dict_diseased_prescan_ids[organ] = df_with_pred[df_with_pred['disease_before_current_age'] == True]['eid'].tolist()
        dict_diseased_postscan_ids[organ] = df_with_pred[df_with_pred['disease_after_current_age'] == True]['eid'].tolist()

    return dict_diseased_prescan_ids, dict_diseased_postscan_ids


def transparent_cmap(base_cmap="Greens"):
    """Return a colormap with alpha increasing with density.

    Parameters:
    base_cmap (str): Name of a registered matplotlib colormap.

    Returns:
    ListedColormap: 256 colours sampled from ``base_cmap`` whose alpha channel
        rises linearly from 0 (first colour) to 1 (last colour).
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # matplotlib 3.7 and removed in 3.9.
    base = plt.get_cmap(base_cmap)
    colors = base(np.linspace(0, 1, 256))
    colors[:, -1] = np.linspace(0, 1, 256)  # alpha channel = 0→transparent, 1→opaque
    return ListedColormap(colors)


def plot_kde_comparison(df, disease, group, organs):
    """Overlay, per organ, the age vs. age-gap KDE of a diseased group on that of the healthy test subjects.

    The healthy density is drawn in blue and the selected group in orange, both
    with transparency-graded colormaps. Horizontal lines mark the median age
    gap of the healthy (dashed) and selected (solid) subjects.

    Parameters:
    df (DataFrame): Needs 'key', 'age', 'fundus_age' and '<organ>_pag' for
        every organ.
    disease (str): Disease code passed to ``get_disease_ids``.
    group (str): 'diseased_prescan', 'diseased_postscan', or 'diseased' (every
        subject not listed in the module-level ``dict_healthy_ids``).
    organs (list): Organ names; the figure is a fixed 3x3 grid, so at most nine
        organs can be drawn.

    Raises:
    ValueError: If ``group`` is not one of the three supported values.
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 15))
    dict_diseased_prescan_ids, dict_diseased_postscan_ids = get_disease_ids(disease)
    for i, organ in enumerate(organs):
        ax = axes.flat[i]

        # Rows need the organ's age gap and both age columns to be present
        df_organ = df.dropna(subset=[f'{organ}_pag', 'fundus_age', 'age'])
        if group == 'diseased_prescan':
            status = (df_organ['key'].isin(dict_diseased_prescan_ids[organ]))  # diseased
            cmap = 'Oranges'
        elif group == 'diseased_postscan':
            status = (df_organ['key'].isin(dict_diseased_postscan_ids[organ]))  # diseased
            cmap = 'Oranges'
        elif group == 'diseased':
            status = (~df_organ['key'].isin(dict_healthy_ids[organ]))  # diseased
            cmap = 'Oranges'
        else:
            raise ValueError(f"Unknown group: {group}")

        healthy_compare = (df_organ['key'].isin(dict_healthy_test_ids[organ])) # healthy

        age = df_organ['age']                     # Chronological age
        dev = df_organ[f'{organ}_pag']           # Predicted age gap for the organ

        md_healthy = np.median(dev[healthy_compare])
        md_diseased = np.median(dev[status])

        sns.kdeplot(x=age[healthy_compare],
                    y=dev[healthy_compare],
                    cmap=transparent_cmap("Blues"),
                    fill=True,
                    ax=ax)

        sns.kdeplot(x=age[status],
                    y=dev[status],
                    cmap=transparent_cmap(cmap),
                    fill=True,
                    ax=ax)

        ax.axhline(md_healthy, color='black', linestyle='--')
        ax.axhline(md_diseased, color='black', linestyle='-')

        ax.set_xlabel('chronological age [y]')
        # Raw string: '\e' is not a valid Python escape and warns in a normal literal.
        # '\enspace' is the mathtext spacing command that replaces the underscore.
        organ_name = str(organ.replace('_', r' \enspace '))
        ax.set_ylabel(r"predicted age gap  $\bf{%s}$ [y]" % organ_name)

        ax.set_ylim(-15, 15)
        ax.set_xlim(40, 90)

        xOutPlot = 90

        # The leading / trailing newline shifts the two labels apart so they
        # stay readable when the medians are close together.
        ax.text(xOutPlot, md_healthy,
            f'\nHealthy: {md_healthy:.2f}',
            ha = "right",
            va = "center")

        ax.text(xOutPlot, md_diseased,
            f'Diseased: {md_diseased:.2f}\n',
            ha = "right",
            va = "center")

    fig.tight_layout()
    # NOTE(review): this message is printed although savefig below is commented out.
    print(f'saving KDE_{group}_{disease}.png')
    #plt.savefig(f'KDE_{group}_{disease}_median.png', bbox_inches="tight", dpi=300)
    #plt.close()
    plt.show()



In [None]:
# Disease codes to compare against the healthy test subjects (presumably ICD-10 — confirm)
disease_list = ['G30', 'E11', 'I10', 'I21', 'I25', 'N17', 'N18']

for disease in disease_list:
    # Every disease gets the post-scan plot; the pre-scan plot is skipped for 'G30'
    groups_to_plot = ['diseased_postscan']
    if disease != 'G30':
        groups_to_plot.append('diseased_prescan')
    for plot_group in groups_to_plot:
        plot_kde_comparison(df_age_pred_clean, disease, plot_group, organs)

***
## **Correlation heatmap**

In [None]:
# Age gaps of all organs; only subjects with a value for every organ are kept
df_dev = df_age_pred_clean[[f"{organ}_pag" for organ in organs]]
print(len(df_dev))
df_dev = df_dev.dropna()
print(len(df_dev))

corr = df_dev.corr('pearson')

# Calculate p-values (the diagonal stays 0)
pvals = pd.DataFrame(np.zeros((len(organs), len(organs))), columns=organs, index=organs)
for i in range(len(organs)):
    for j in range(len(organs)):
        if i != j:
            _, pvals.iloc[i, j] = stats.pearsonr(df_dev[f"{organs[i]}_pag"], df_dev[f"{organs[j]}_pag"])

# Significance threshold (presumably 0.05 / 36 organ pairs, i.e. Bonferroni — confirm)
threshold = 1.39e-3
mask_pvals = pvals > threshold

# Upper triangle incl. diagonal; the lower triangle is zero
matrix = np.triu(corr, k=0)

# Heatmap mask: non-zero (upper triangle) and NaN (p above threshold) entries hide the cell
corr_matrix = matrix.copy()
corr_matrix[mask_pvals.to_numpy()] = np.nan
plt.figure(figsize=(20, 20))

ax = sns.heatmap(corr, mask=corr_matrix, vmin=0, vmax=1, center=0.5, cmap=sns.color_palette("Spectral_r", as_cmap=True), square=True)
ax.set_xticklabels([label.get_text().replace('_', ' ').replace(' pag', '') for label in ax.get_xticklabels()], rotation=45, horizontalalignment='right', fontsize=20)
ax.set_yticklabels([label.get_text().replace('_', ' ').replace(' pag', '') for label in ax.get_yticklabels()], rotation=0, horizontalalignment='right', fontsize=20)

ax.invert_xaxis()
# Update the heatmap labels with correlation coefficients and p-values
for i in range(matrix.shape[0]):
    for j in range(matrix.shape[1]):
        if i >= j:
            continue
        pval = pvals.iloc[i, j]
        # Cells above the threshold are annotated as well; re-enable to skip them:
#            if pval > threshold:
#                continue
        # Separate name: reusing 'corr' would overwrite the correlation matrix with a scalar
        corr_value = matrix[i, j]

        # x = i, y = j with i < j places the text in the visible lower-triangle cell
        ax.text(i+0.5, j+0.4, round(corr_value, 2), ha='center', va='center', weight='bold', fontsize=16, color='black')
        ax.text(i+0.5, j+0.7, format(pval, '.2e'), ha='center', va='center', fontsize=12, color='black')
# Save before show(): some backends release the figure once it has been shown
ax.figure.savefig("Heatmap.png", bbox_inches="tight", dpi=300)
plt.show()

***
### **Statistical Tests**

In [None]:
def _group_mask(keys, group, organ, prescan_ids, postscan_ids):
    """Return the boolean mask over ``keys`` that selects the subjects of ``group`` for ``organ``."""
    if group == 'healthy':
        return keys.isin(dict_healthy_ids[organ])
    if group == 'healthy_test':
        return keys.isin(dict_healthy_test_ids[organ])
    if group == 'diseased_prescan':
        return keys.isin(prescan_ids[organ])
    if group == 'diseased_postscan':
        return keys.isin(postscan_ids[organ])
    if group == 'diseased':
        return ~keys.isin(dict_healthy_ids[organ])
    raise ValueError(f"Unknown group: {group}")


def test_statistical_differences(df, organs, disease_list, test):
    """Compare the predicted age gap of pre-scan diseased subjects with the healthy test subjects.

    For every disease and organ the two groups are compared with Levene's test
    and with either a Mann-Whitney U test or an independent t-test (equal
    variances assumed when Levene's p > 0.05). Results are printed.

    Parameters:
    df (DataFrame): Needs 'key' and '<organ>_pag' for every organ.
    organs (list): Organ names.
    disease_list (list): Disease codes passed to ``get_disease_ids``.
    test (str): 'u' for Mann-Whitney U, 't' for the t-test.

    Raises:
    ValueError: If ``test`` is neither 'u' nor 't'.
    """
    # Fail before any file is read instead of after the first comparison
    if test not in ('u', 't'):
        raise ValueError(f"Unknown test: {test}")

    groups = ['diseased_prescan', 'healthy_test']

    for disease in disease_list:
        dict_diseased_prescan_ids, dict_diseased_postscan_ids = get_disease_ids(disease)

        for organ in organs:
            df_organ = df.dropna(subset=[f'{organ}_pag'])
            status = [
                _group_mask(df_organ['key'], group, organ,
                            dict_diseased_prescan_ids, dict_diseased_postscan_ids)
                for group in groups
            ]
            group1 = df_organ[f'{organ}_pag'][status[0]]
            group2 = df_organ[f'{organ}_pag'][status[1]]

            # Name the disease as well, otherwise the per-organ blocks of
            # different diseases cannot be told apart in the output.
            print(f"Disease: {disease}")
            print(f"Organ: {organ}")
            print(f"Group 1: {groups[0]}, Group 2: {groups[1]}")

            # The tests are undefined for an empty sample
            if group1.empty or group2.empty:
                print(f"Skipped: empty group (n1 = {len(group1)}, n2 = {len(group2)})")
                continue

            stat, p_levene = stats.levene(group1, group2)
            equal_var = p_levene > 0.05  # True = equal variances
            print(f"Levene's test p = {p_levene:.4f} Equal variances: {equal_var}")

            if test == 'u':
                u_stat, p_val_u = stats.mannwhitneyu(group1, group2, alternative='two-sided')
                print(f"Mann-Whitney U test: U = {u_stat}, p = {p_val_u}")
            else:
                # Student's or Welch's t-test, depending on Levene's result
                t_stat, p_val_t = stats.ttest_ind(group1, group2, equal_var=equal_var)
                print(f"T-test: t = {t_stat}, p = {p_val_t}")
            
            
# Run the Mann-Whitney U comparison ('u') for every disease code and organ
disease_list = ['G30', 'E11', 'I10', 'I21', 'I25', 'N17', 'N18']
test = 'u'
test_statistical_differences(
    df=df_age_pred_clean,
    organs=organs,
    disease_list=disease_list,
    test=test,
)
