In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from datetime import timedelta
from sklearn.metrics import roc_auc_score, roc_curve, auc
from scipy import stats
import seaborn as sns

In [None]:
# Load the prediction results table (replace the placeholder path with the real CSV).
# Later cells read 'K_6.5' as y_true, 'mean_prediction' as y_pred, and filter on
# the 0/1 subgroup columns listed in `column_list`.
df = pd.read_csv(filepath_or_buffer='your/prediction/result/with/subgroup/columns')

In [None]:
# Subgroup indicator columns, each coded 1 (in subgroup) / 0 (not in subgroup).
# The final entry 'all' is not a column: it denotes the whole, unfiltered table.
column_list = [
    'Age_65_older',
    'Male',
    'HTN',
    'DM',
    'CAD',
    'HF',
    'Af',
    'all',
]

In [2]:
def calculate_auc_95ci(y_true, y_pred):
    """
    Calculate the AUC and its 95% confidence interval.

    The standard error follows the Hanley & McNeil approximation
    (Q1 = A / (2 - A), Q2 = 2A^2 / (1 + A)) and the interval is the normal
    approximation around the AUC, truncated to the valid AUC range [0, 1].

    Parameters
    ----------
    y_true : array-like
        Binary labels; positives are counted by summing, so they are
        expected to be coded 1 and negatives 0.
    y_pred : array-like
        Predicted scores, passed unchanged to ``roc_auc_score``.

    Returns
    -------
    tuple
        ``(auc, (lower, upper))`` with all three values as Python floats.

    Raises
    ------
    ValueError
        If ``y_true`` does not contain at least one positive and one
        negative sample (the AUC and its standard error are undefined).
    """
    labels = np.asarray(y_true)
    n1 = int(np.sum(labels))   # number of positive samples
    n2 = len(labels) - n1      # number of negative samples
    if n1 <= 0 or n2 <= 0:
        raise ValueError(
            "AUC is undefined: y_true must contain both positive and negative "
            f"samples (got {n1} positive, {n2} negative)."
        )

    # Use a distinct local name so the sklearn `auc` import is never masked.
    auc_value = float(roc_auc_score(y_true, y_pred))

    q1 = auc_value / (2 - auc_value)
    q2 = 2 * auc_value**2 / (1 + auc_value)
    variance = (
        auc_value * (1 - auc_value)
        + (n1 - 1) * (q1 - auc_value**2)
        + (n2 - 1) * (q2 - auc_value**2)
    ) / (n1 * n2)
    # Rounding can push a zero variance slightly negative; clamp before sqrt.
    se_auc = float(np.sqrt(max(variance, 0.0)))

    if se_auc == 0.0:
        # norm.interval returns (nan, nan) for scale=0 (e.g. AUC of exactly 1.0);
        # the interval degenerates to the point estimate instead.
        lower = upper = auc_value
    else:
        lower, upper = stats.norm.interval(0.95, loc=auc_value, scale=se_auc)

    # An AUC cannot lie outside [0, 1], so truncate the normal approximation.
    ci = (max(0.0, float(lower)), min(1.0, float(upper)))
    return auc_value, ci

In [None]:
# AUC point estimates and their 95% CIs, appended in plotting order:
# for every subgroup column the positive (== 1) rows come first, then the
# negative (== 0) rows; the 'all' entry contributes one estimate for the full table.
means = []
conf_intervals = []

for column in column_list:
    if column == 'all':
        subsets = [df]
    else:
        subsets = [df[df[column] == 1], df[df[column] == 0]]

    for subset in subsets:
        # Named `auc_value` (not `auc`) so the `auc` function imported from
        # sklearn.metrics is not overwritten at notebook scope.
        auc_value, ci = calculate_auc_95ci(subset['K_6.5'], subset['mean_prediction'])
        means.append(auc_value)
        conf_intervals.append(ci)

In [None]:
# Row labels for the forest plot, in the same order as `means` / `conf_intervals`:
# positive subgroup, then negative subgroup, for each column; overall estimate last.
categories = [
    'Age ≥65 years', 'Age <65 years', 'Male sex', 'Female sex', 'Hypertension', 'Without hypertension', 
    'Diabetes', 'Without diabetes', 'CAD', 'Without CAD', 'Heart failure', 
    'Without Heart failure', 'Af', 'Without Af', 'Overall']

# Fail loudly if labels and estimates are out of step; otherwise the rows
# would be drawn with the wrong labels without any error.
if not (len(categories) == len(means) == len(conf_intervals)):
    raise ValueError(
        f"Expected {len(categories)} estimates to match the category labels, "
        f"got {len(means)} means and {len(conf_intervals)} confidence intervals."
    )

# The overall estimate is assumed to be the last entry.
overall_row = len(means) - 1
overall_mean = means[overall_row]

# Plot
fig, ax = plt.subplots(figsize=(10, 10))

# Subgroup rows: CI as a horizontal line, point estimate as a square.
# The overall row is excluded here so it is drawn only once, as a diamond.
for row, (mean, conf_int) in enumerate(zip(means[:overall_row], conf_intervals[:overall_row])):
    ax.plot(conf_int, [row, row], color='black')
    ax.plot(mean, row, 's', color='black')

# Dotted vertical reference line at the overall AUC
ax.axvline(x=overall_mean, color='grey', linestyle='--', linewidth=2)

# Overall row: thicker CI line and a diamond marker
ax.plot(conf_intervals[overall_row], [overall_row, overall_row], color='black', linewidth=2)
ax.plot(overall_mean, overall_row, 'D', color='black')  # 'D' for diamond shape

# Inverting y-axis so that the overall is at the bottom
ax.invert_yaxis()

# Setting the y-ticks to be the categories
ax.set_yticks(range(len(categories)))
ax.set_yticklabels(categories)
ax.tick_params(axis='y', labelsize=14)
ax.tick_params(axis='x', labelsize=14)

# Adding labels and title (if needed)
ax.set_title('AUC (95% CI)', fontsize=16)

# Set the range for the x-axis
ax.set_xlim(0.2, 1.0)

# Show the plot
plt.show()