In [None]:
# Jupyter notebook setup
%load_ext autoreload
%autoreload 2

# Deep learning imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Data science imports
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

# System imports
import os
import warnings
import sys 
sys.path.append('../src/')

# Custom imports
from training import train
from gpu_utils import restrict_GPU_pytorch

# Configure settings.
# NOTE: chained-assignment warnings and all FutureWarnings are silenced
# notebook-wide; this also hides pandas deprecation notices.
pd.options.mode.chained_assignment = None
warnings.filterwarnings('ignore', category=FutureWarning)
restrict_GPU_pytorch('0')  # project helper (gpu_utils); presumably pins the process to GPU 0

# Seed every RNG the pipeline draws from. Seeding numpy alone does not cover
# torch (initialisation of freshly created layers, DataLoader shuffling).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Visualization imports
import matplotlib.pyplot as plt

### 1. Load extracted EDW data.

In [None]:
# Anticoagulation: a patient counts as treated when an order timestamp is present.
anticoag_treatment = pd.read_csv('../extracted_EDW_data/trt_antic_ecglst.csv')
anticoag_treatment = anticoag_treatment.assign(
    anticoag_treatment=anticoag_treatment['OrderDTS'].notna()
)[['UniqueID', 'anticoag_treatment']]

# Hospitalization: per patient, the earliest admission strictly after
# ecg_date_lst (only UniqueID membership is used downstream in this notebook).
hospitalization = pd.read_csv('../extracted_EDW_data/hospital_vists.csv')
for date_col in ('HospitalAdmitDTS', 'ecg_date_lst'):
    hospitalization[date_col] = pd.to_datetime(hospitalization[date_col])
admitted_after_ecg = hospitalization['HospitalAdmitDTS'] > hospitalization['ecg_date_lst']
hospitalization = (
    hospitalization[admitted_after_ecg]
    .sort_values('HospitalAdmitDTS')
    .groupby('UniqueID')
    .first()
    .reset_index()
)

specialist_visit = pd.read_csv('../extracted_EDW_data/specvis_lst.csv')[['UniqueID', 'spec_vis']]

# Insurance type; dual-eligible patients are folded into Medicare.
instype_map = {
    'Medicaid': 'Medicaid',
    'Unknown/Missing': 'Unknown/Missing',
    'Commercial': 'Commercial',
    'Dual': 'Medicare',
    'Medicare': 'Medicare',
    'Other': 'Other',
}
insurance = pd.read_csv('../extracted_EDW_data/ins_ecglst.csv')[['UniqueID', 'instype_final']]
insurance['instype_final'] = insurance['instype_final'].map(instype_map)

# Established-care visit bins (PCP, cardiology, other).
est_care = pd.read_csv('../extracted_EDW_data/estcare_ecglst.csv')[
    ['UniqueID', 'PCPvisits_bin', 'CARvisits_bin', 'OTHvisits_bin']
]

# Rate / rhythm control treatment indicators.
rate = pd.read_csv('../extracted_EDW_data/rate_ecglst.csv')[['UniqueID', 'trt_rate']]
rhythm = pd.read_csv('../extracted_EDW_data/rhythm_ecglst.csv')[['UniqueID', 'trt_rhythm']]

# Stroke, plus a flag for a stroke diagnosed at most 365 days after ecg_date_lst.
stroke = pd.read_csv('../extracted_EDW_data/stroke_ecglst.csv')
stroke['ecg_date_lst'] = pd.to_datetime(stroke['ecg_date_lst'])
stroke['DGNS_DT'] = pd.to_datetime(stroke['DGNS_DT'])
days_ecg_to_stroke = (stroke['DGNS_DT'] - stroke['ecg_date_lst']).dt.days
# NOTE(review): there is no lower bound, so a stroke dated *before* the ECG also
# counts as "within year" — confirm the extract only contains post-ECG strokes.
stroke['stroke_within_year'] = ((stroke['stroke'] == 1) & (days_ecg_to_stroke <= 365)).astype(int)
stroke = stroke[['UniqueID', 'stroke', 'stroke_within_year']]

# Patients that could be located in the new EDW system.
exists_in_new_system = pd.read_csv('../extracted_EDW_data/Missing Patients EDW RZ.csv')
found_in_new_system = exists_in_new_system['patid_found'] == 1
unique_ids_in_new_system = exists_in_new_system.loc[found_in_new_system, 'UniqueID'].unique()

# Outcome base rates.
rate['trt_rate'].mean(), rhythm['trt_rhythm'].mean(), stroke['stroke'].mean(), stroke['stroke_within_year'].mean()


### 2. Load AF dataset and combine with EDW data.

In [None]:
from paths import map_params_to_filename
from ecg_datasets import ECGDemographicsDataset, ECGDataset
from ecg_preprocessing_fs import create_name_dob_hash
import pandas as pd 

outcome= 'afib'
merge_with_EDW_vars = True
# Parameters that identify the preprocessed ECG file. Keys added further below
# do not affect the file that is read, but they do enter the checkpoint name.
preprocessing_params = {'max_pred_gap': 90, 
                        'selection_criteria': 'va', 
                        'include_single_ecgs': True, 
                        'mini': False}
unique_id_col = 'UniqueID'


ecg_data = pd.read_csv('./processed_data/processed_afib_' + map_params_to_filename(preprocessing_params) + '.csv')   
# Assign the result back instead of calling fillna(..., inplace=True) on the
# selected column: that chained in-place call stops updating ecg_data under
# pandas Copy-on-Write (the default from pandas 3.0), and the FutureWarning
# pandas 2.x emits for it is silenced in the setup cell.
# The 'nan' string presumably keeps the name+DOB hash defined for patients
# without a first name.
ecg_data['PatientFirstName'] = ecg_data['PatientFirstName'].fillna('nan')
ecg_data[unique_id_col] = create_name_dob_hash(ecg_data, 'PatientFirstName', 'PatientLastName', 'DateofBirth')

print("# of (Patients, ECGs) before merging with map to UniqueID: ", ecg_data['UniqueID'].nunique(), len(ecg_data))

preprocessing_params['one_ecg_per_patient'] = 'last' # can be 'false', 'last', 'first', 'last_white', 'last_two_ecgs'
preprocessing_params['loss'] = 'CE' # can be CE or Focal
preprocessing_params['mini'] = False # Refers to subsampling the train set.
# ECG files that raise errors when loaded (collected from earlier runs). A few
# entries are repeated; repeats are harmless for the `isin` filter below.
files_to_skip = [
    '/data/workspace/ekg_bwr_trunc_norm/003161595_08-05-2019_15-20-53_SCD10410491PA05082019152053.npy',
    '/data/workspace/ekg_bwr_trunc_norm/000122100_09-21-2019_01-02-14_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006363943_05-24-2018_18-00-22_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/010464725_05-24-2019_15-17-48_SCD12365371PA24052019151748.npy',
    '/data/workspace/ekg_bwr_trunc_norm/000951233_08-12-2019_15-28-53_SKJ14029684PA12082019152853.npy',
    '/data/workspace/ekg_bwr_trunc_norm/003145657_08-14-2019_11-01-13_SKJ13388441SA14082019110113.npy',
    '/data/workspace/ekg_bwr_trunc_norm/035737121_08-13-2017_18-51-36_SKJ13408672SA13082017185136.npy',
    '/data/workspace/ekg_bwr_trunc_norm/003156598_11-22-2019_16-12-28_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006619841_05-07-2019_08-28-52_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006619841_05-11-2019_03-20-46_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006619841_05-06-2019_23-35-11_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/001410670_04-19-2016_14-07-26_SCD06223397PA19042016140726.npy',
    '/data/workspace/ekg_bwr_trunc_norm/002105344_08-10-2018_16-05-48_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/002105344_08-10-2018_16-05-48_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/002063599_01-26-2017_10-56-00_SCD07047035PA26012017105600.npy',
    '/data/workspace/ekg_bwr_trunc_norm/001410670_04-19-2016_14-07-26_SCD06223397PA19042016140726.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006619841_05-06-2019_23-35-11_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/006619841_05-07-2019_01-21-07_none.npy',
    '/data/workspace/ekg_bwr_trunc_norm/002120566_12-19-2018_08-39-41_SKJ14080390PA19122018083941.npy',
    '/data/workspace/ekg_bwr_trunc_norm/002021971_05-14-2016_09-49-49_none.npy',
]

is_loadable = ~ecg_data['path_to_bwr_trunc_norm_data'].isin(files_to_skip)
ecg_data = (
    ecg_data[is_loadable]
    # ECG date and age at ECG time: as a Timedelta, in years (mean Gregorian
    # year length), and in years / 100 (the scaled model feature).
    .assign(
        ecg_date=lambda d: pd.to_datetime(d[['year', 'month', 'day']]),
        DateofBirth=lambda d: pd.to_datetime(d['DateofBirth']),
        PatientAge=lambda d: d['ecg_date'] - d['DateofBirth'],
        PatientAge_years=lambda d: d['PatientAge'].dt.days / 365.2425,
        PatientAge_years_01=lambda d: d['PatientAge_years'] / 100,
    )
    .drop(columns=['Unnamed: 0.1', 'Unnamed: 0', 'PatientID', 'year', 'month', 'day', 'muse_mrn'])
)


print("# of (Patients, ECGs) before merging with map to UniqueID: ", ecg_data['UniqueID'].nunique(), len(ecg_data))
muse_edw_map = pd.read_csv('../outputs/intermediate/muse_edw_map.csv', dtype='str')
has_edw_id = ecg_data['UniqueID'].isin(muse_edw_map['UniqueID'])
# NOTE(review): if a UniqueID maps to several PatientIDs, this inner merge
# duplicates that patient's ECG rows — confirm the map is one-to-one.
ecg_data = ecg_data[has_edw_id].merge(muse_edw_map[['UniqueID', 'PatientID']], on='UniqueID')
print("# of (Patients, ECGs) after merging with map to UniqueID: ", ecg_data['UniqueID'].nunique(), len(ecg_data))

# Demographics based on Brianna's pull.
demographics_file = pd.read_csv('../extracted_EDW_data/demographics_with_diagnosis_info.csv')
demographics_file['earliest_diagnosis'] = pd.to_datetime(demographics_file['earliest_diagnosis'])

# Patient-level split (no patient appears in more than one split):
# 60% train, and the remaining 40% halved into validation and test.
test_set_size = .4
random_state = 0
preprocessing_params['test_set_size'] = test_set_size

pids = sorted(set(ecg_data[unique_id_col]))
train_ids, held_out_ids = train_test_split(pids, test_size=preprocessing_params['test_set_size'], random_state=random_state)
val_ids, test_ids = train_test_split(held_out_ids, test_size=.5, random_state=random_state)

if merge_with_EDW_vars:
    # Keep only ECGs of patients present in the demographics pull, then attach it.
    ecg_data = ecg_data[ecg_data['UniqueID'].isin(demographics_file['UniqueID'])]
    print("# of (Patients, ECGs) after demographics merge: ", ecg_data['UniqueID'].nunique(), len(ecg_data))
    ecg_data = ecg_data.merge(demographics_file, on='UniqueID')

    # NOTE: the sum runs over ECG rows, so this counts ECGs (not patients).
    print("# of Patients in sample matched to some diagnosis: ", ecg_data['diagnosis_in_charts'].sum())

    # Exclude ECGs taken on/after an established AF diagnosis: keep rows with no
    # charted diagnosis, or whose diagnosis date is strictly after the ECG.
    # NOTE(review): `~` is a logical NOT only if diagnosis_in_charts is boolean.
    diagnosed_after_ecg = (ecg_data['earliest_diagnosis'] - ecg_data['ecg_date']).dt.days > 0
    ecg_data = ecg_data[~ecg_data['diagnosis_in_charts'] | diagnosed_after_ecg]
    print("# of (Patients, ECGs) after filtering out established AFib diagnoses: ", ecg_data['UniqueID'].nunique(), len(ecg_data))
    ecg_data['time_to_diagnosis'] = (ecg_data['earliest_diagnosis'] - ecg_data['ecg_date']).dt.days

    # One-hot race indicators and a sex indicator; legacy race columns are dropped.
    race_values = ['White', 'Hispanic or Latino', 'Black or African American', 'Asian', 'Other',
                   'Declined or Unavailable', 'Native American or Pacific Islander']
    race_indicators = {'binary_' + race_val: ecg_data['PatientRaceFinal'] == race_val
                       for race_val in race_values}
    ecg_data = ecg_data.assign(**race_indicators, binary_Male=ecg_data['SexDSC'] == 'Male')
    ecg_data = ecg_data.drop(columns=['binary_Race_CAUCASIAN', 'binary_Race_HISPANIC', 'binary_Race_BLACK'])

    # Downstream outcomes. The merges are inner joins, so patients missing from
    # any of these tables are dropped.
    ecg_data = ecg_data.merge(anticoag_treatment, on='UniqueID')  # 13% of patients go on to have anticoag treatment
    ecg_data['mortality'] = ecg_data['DeathDTS'].notna()  # 12% of patients die
    ecg_data['hospitalization'] = ecg_data['UniqueID'].isin(hospitalization['UniqueID'])
    for edw_table in (specialist_visit, rate, rhythm, stroke, est_care, insurance):
        ecg_data = ecg_data.merge(edw_table, on='UniqueID')
print("\n# of Patients in  sample: ", ecg_data['UniqueID'].nunique())
print("# of ECGs in sample: ", len(ecg_data))

for date_col in ('DeathDTS', 'date'):
    ecg_data[date_col] = pd.to_datetime(ecg_data[date_col])
# NOTE(review): measured from the 'date' column rather than 'ecg_date' — confirm
# which reference date is intended.
ecg_data['mortality_within_one_year'] = (ecg_data['DeathDTS'] - ecg_data['date']).dt.days < 365

# Remove UniqueIDs associated with both the positive and the negative class.
# Per the original note, this arises because a patient's name sometimes
# includes their middle name, which can give one patient two UniqueIDs.
uniqueids_positive_and_negative = ecg_data.groupby('UniqueID')['label'].nunique()
repeated_unique_ids_across_class = uniqueids_positive_and_negative[uniqueids_positive_and_negative > 1].index.values
ecg_data = ecg_data[~ecg_data['UniqueID'].isin(repeated_unique_ids_across_class)]

if preprocessing_params['one_ecg_per_patient'] != 'false':
    # Keep one ECG per UniqueID: the earliest for 'first', otherwise the latest
    # ('last', 'last_white', 'last_two_ecgs').
    # - kind='stable' resolves same-day ties (ecg_date has day resolution)
    #   deterministically, in file order.
    # - drop_duplicates keeps whole rows. groupby().first()/.last() take the
    #   first/last *non-null* value of each column independently, which can
    #   splice one "ECG" together from several rows.
    keep_which = 'first' if preprocessing_params['one_ecg_per_patient'].startswith('first') else 'last'
    ecg_data = (
        ecg_data.sort_values('ecg_date', kind='stable')
        .drop_duplicates(subset='UniqueID', keep=keep_which)
        .sort_values('UniqueID')  # same row order as the former groupby-based selection
        .reset_index(drop=True)
    )
print(ecg_data['label'].mean(), preprocessing_params)
print("\n# of Patients in  sample: ", ecg_data['UniqueID'].nunique())
print("# of ECGs in sample: ", len(ecg_data))



In [650]:
# For sensitivity analysis, is not used in main analysis.
# Builds `pids_to_keep_under_two_ecg_constraint`: every positive patient plus the
# negative patients whose two earliest cached ECGs are 1-89 days apart. The
# train and validation splits are later restricted to these IDs.
if preprocessing_params['one_ecg_per_patient'] == 'last_two_ecgs':
    negatives = ecg_data[ecg_data['label'] == 0]
    # NOTE(review): this reads from '../outputs_intermediate/' whereas the
    # MUSE->EDW map is read from '../outputs/intermediate/' — confirm the path.
    muse_cache_df = pd.read_csv('../outputs_intermediate/muse_cache_files/patient_mrn_to_file.csv', dtype='str')

    # All cached ECGs of the negative patients; keep each patient's two earliest.
    all_negatives_ecgs = muse_cache_df[muse_cache_df['UniqueID'].isin(negatives['UniqueID'])]
    all_negatives_ecgs['date'] = pd.to_datetime(all_negatives_ecgs[['year', 'month', 'day']])
    all_negatives_ecgs = all_negatives_ecgs.sort_values(by = 'date')
    all_negatives_ecgs = all_negatives_ecgs.groupby('UniqueID').head(2).reset_index()

    # Filter out  patients with only 1 ECG
    unique_id_to_counts = all_negatives_ecgs['UniqueID'].value_counts() 
    unique_id_to_counts =  unique_id_to_counts[unique_id_to_counts > 1]
    print("# of Unique IDs in negatives: ", all_negatives_ecgs['UniqueID'].nunique())
    all_negatives_ecgs = all_negatives_ecgs[all_negatives_ecgs['UniqueID'].isin(unique_id_to_counts.index)]
    print("# of Unique IDs in negatives after filtering out patients with 1 ECG: ",all_negatives_ecgs['UniqueID'].nunique())

    # Filter out patients where second ECG happens more than 90 days after first
    # NOTE(review): the window actually applied below is 1-89 days inclusive
    # (same-day pairs and exactly-90-day gaps are dropped too), while these
    # comments and the print mention 90 days — confirm the intended bound.
    all_negatives_ecgs = all_negatives_ecgs.sort_values(by=['UniqueID', 'date'])

    # Compute the difference in days between the two rows for each UniqueID
    # (NaT on each patient's first row, the gap on the second row).
    date_diff = all_negatives_ecgs.groupby('UniqueID')['date'].diff().abs()

    # Keep only UniqueIDs where the date difference is ≤ 90 days
    # (max() skips the NaT; between() is inclusive on both ends).
    valid_ids = date_diff.groupby(all_negatives_ecgs['UniqueID']).max().between(pd.Timedelta(days=1), pd.Timedelta(days=89))

    # Filter the dataframe to retain only valid UniqueIDs
    all_negatives_ecgs = all_negatives_ecgs[all_negatives_ecgs['UniqueID'].isin(valid_ids[valid_ids].index)]
    print("# of Unique IDs in negatives after filtering for patients with 2 ECGs within 90 days ",all_negatives_ecgs['UniqueID'].nunique())

    # Negative IDs (repeated once per kept ECG) plus all positive IDs; the
    # repeats are harmless for the isin() filter that consumes this array.
    all_negatives_ecgs = all_negatives_ecgs.sort_values(by='date')
    negative_pids_with_two_ecgs =  all_negatives_ecgs['UniqueID'].values
    positive_pids = ecg_data[ecg_data['label'] == 1]['UniqueID'].values
    pids_to_keep_under_two_ecg_constraint = np.concatenate([negative_pids_with_two_ecgs, positive_pids])


### 2. Split dataset into train, calibration, and study sample.

In [None]:
# Training hyperparameters.
BATCH_SIZE = 8
N_WORKERS = 6
NUM_EPOCHS = 10
LR = 0.0001

unique_id_col = 'UniqueID'
# Non-ECG inputs passed to the network with each waveform. binary_White is not
# included (presumably the reference category for the race indicators).
additional_feat_names = ['binary_Black or African American', 'binary_Hispanic or Latino',
                         'binary_Declined or Unavailable','binary_Asian',  'binary_Other', 
                         'binary_Native American or Pacific Islander',
                         'binary_Male', 
                         'PatientAge_years_01']

split_dfs = []
split_paths = []
split_y = []
additional_feats = []
# i == 0: train, i == 1: validation, i == 2: test.
for i, pid_set in enumerate([train_ids, val_ids, test_ids]):
    split_df = ecg_data[ecg_data[unique_id_col].isin(pid_set)].reset_index(drop=True)

    # Relevant to sensitivity analyses; not used in main analysis.
    if i in [0,1] and preprocessing_params['one_ecg_per_patient'] == 'last_white':
        split_df = split_df[split_df['PatientRaceFinal'] == 'White']
    
    # Relevant to sensitivity analyses; not used in main analysis.
    if i in [0,1] and preprocessing_params['one_ecg_per_patient'] == 'last_two_ecgs':
        split_df = split_df[split_df['UniqueID'].isin(pids_to_keep_under_two_ecg_constraint)]

    if i > 0:
        # Evaluate each patient on a single, randomly chosen ECG (seeded shuffle,
        # then the first whole row per patient; groupby().first() would instead
        # combine the first non-null value of each column across rows).
        split_df = split_df.sample(frac=1, random_state=0)
        split_df = (split_df.drop_duplicates(subset=unique_id_col, keep='first')
                    .sort_values(unique_id_col)
                    .reset_index(drop=True))
        print(split_df[unique_id_col].nunique())
        # Restrict evaluation to patients found in the new EDW system.
        split_df = split_df[split_df[unique_id_col].isin(unique_ids_in_new_system)]
        print(split_df[unique_id_col].nunique())

    split_paths.append(list(split_df['path_to_bwr_trunc_norm_data']))
    split_y.append(np.array(list(split_df['label'])))
    # Cast to float32, not int: PatientAge_years_01 is age / 100 (mostly in
    # [0, 1)), so an integer cast truncated it to 0 for virtually every patient
    # and removed age from the model input. Bool indicators become 0.0 / 1.0.
    # NOTE: checkpoints trained on the integer-cast features need retraining.
    additional_feat_values = split_df[additional_feat_names].fillna(0).values.astype(np.float32)
    split_dfs.append(split_df)
    additional_feats.append(additional_feat_values)

train_ecg_paths, val_ecg_paths, test_ecg_paths = split_paths
train_additional_feats, val_additional_feats, test_additional_feats = additional_feats
train_y, val_y, test_y = split_y

if len(additional_feat_names) == 0:
    train_dataset = ECGDataset(train_ecg_paths, train_y)
    val_dataset = ECGDataset(val_ecg_paths, val_y)
    test_dataset = ECGDataset(test_ecg_paths, test_y)
else:
    train_dataset = ECGDemographicsDataset(train_ecg_paths,  train_additional_feats, train_y)
    val_dataset = ECGDemographicsDataset(val_ecg_paths, val_additional_feats, val_y)
    test_dataset = ECGDemographicsDataset(test_ecg_paths, test_additional_feats, test_y)

# Only training batches are shuffled, so val/test predictions stay aligned with
# the rows of split_dfs[1] / split_dfs[2].
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=False, shuffle=True, num_workers=N_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=False, shuffle=False, num_workers=N_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, pin_memory=False, shuffle=False, num_workers=N_WORKERS)
print(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))

expt_config = {'arch': 'Net1D', 'additional_features': len(additional_feat_names) > 0}

print("AF rates across splits: ", train_y.mean(), val_y.mean(), test_y.mean())


### 3. Train model

In [None]:
from models import get_model
from torch import optim, nn
from kornia.losses import BinaryFocalLossWithLogits

# Build the network once; the number of extra (non-ECG) inputs depends on
# whether demographics are fed to the model.
n_additional_inputs = len(additional_feat_names) if expt_config['additional_features'] else 0
model = get_model(additional_inputs=n_additional_inputs)  # 2.8M params with additional features

# Initialise from PreOpNet weights. Obtain from: https://github.com/ecg-net/PreOpNet.
# strict=False skips keys that do not match (presumably the extra-input layers,
# which keep their fresh initialisation).
checkpoint = torch.load('../PreOpNet/PreOpNet MACE.pt')
model.load_state_dict(checkpoint, strict=False)
model.cuda()

# Option to use a focal loss instead of cross-entropy. We found it did not 
# improve calibration of the network trained on the ECG + demographics. 
criterion = nn.BCEWithLogitsLoss()
if preprocessing_params['loss'] == 'Focal':
    kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
    criterion = BinaryFocalLossWithLogits(**kwargs)

optimizer_ft = optim.Adam(model.parameters(), lr=LR) # Same as VA paper

# Checkpoint path. It must match the path rebuilt in the "Load model" section,
# which reads from ./outputs/saved_models/ in both configurations (the ECG-only
# branch used to write to ../outputs/saved_models/ and could not be reloaded).
save_file_name = './outputs/saved_models/model_tmp_' + outcome
if expt_config['additional_features']:
    save_file_name += '_with_other_features'

preprocessing_keys = sorted(preprocessing_params.keys())
for key in preprocessing_keys:
    save_file_name = save_file_name + '_' + key + '=' + str(preprocessing_params[key])

print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")
print(expt_config)
print(save_file_name)

model, history = train(
    model,
    criterion,
    optimizer_ft,
    train_loader,
    val_loader,
    save_file_name,
    max_epochs_stop=10,
    n_epochs=NUM_EPOCHS,
    print_every=1,
    augmentations_on=True,
    expt_config=expt_config)
torch.cuda.empty_cache()

### 4. Load model

In [None]:
from models import get_model
import statsmodels.api as sm
from utils import calibration_plot, diagnosis_curve_plot

# Rebuild the architecture used for training, then load the checkpoint stored
# under the same parameter-derived file name (presumably written by `train`).
with_features = expt_config['additional_features']
if with_features:
    model = get_model(additional_inputs=len(additional_feat_names)) # 2.8M params
else:
    model = get_model()
save_file_name = './outputs/saved_models/model_tmp_' + outcome + ('_with_other_features' if with_features else '')

preprocessing_keys = sorted(preprocessing_params.keys())
save_file_name += ''.join('_' + key + '=' + str(preprocessing_params[key]) for key in preprocessing_keys)

print(save_file_name)
checkpoint = torch.load(save_file_name)
model.load_state_dict(checkpoint)
model.cuda()

### 5. Obtain model predictions on calibration and study sample

In [None]:
def eval(model, dataloader, device="cuda:0"):
    """Run ``model`` over ``dataloader`` and compute binary-classification metrics.

    The name shadows Python's built-in ``eval`` in this notebook; it is kept
    because the cells below call it by this name.

    Args:
        model: network whose output is passed through a sigmoid to get the
            positive-class probability (assumed to be one logit per example —
            outputs are flattened to 1-D below).
        dataloader: yields ``(data, target)`` batches, where ``data`` is either
            an ECG tensor or a sequence ``(ecg, additional_features)``.
        device: device used for inference; defaults to the first CUDA GPU.

    Returns:
        ``(acc, auc, auprc, sensitivity, specificity, labels, pred_probs, preds)``
        with 1-D numpy arrays for the last three. ``preds`` are the
        probabilities rounded with ``torch.round`` (an exact 0.5 rounds to 0).
        NaNs in inputs and targets are replaced by 0.
    """
    model.eval()
    device = torch.device(device)
    all_pred_prob = []
    all_pred = []
    all_labels = []
    # Inference only: no autograd graph is needed (saves GPU memory and time).
    with torch.no_grad():
        for data, target in tqdm(dataloader):
            # Dispatch on the container type, not len(data): a plain ECG batch
            # of size 2 also has len() == 2 and would be split into a bogus pair.
            if isinstance(data, (list, tuple)):
                cleaned_data = tuple(
                    torch.nan_to_num(part.to(device=device, dtype=torch.float), nan=0) for part in data
                )
            else:
                cleaned_data = torch.nan_to_num(data.to(device=device, dtype=torch.float), nan=0)

            pred_prob = torch.sigmoid(model(cleaned_data))
            all_pred_prob.append(pred_prob.cpu())
            all_pred.append(torch.round(pred_prob).cpu())
            all_labels.append(torch.nan_to_num(target, nan=0).cpu())

    # Flatten so an (N, 1) model output lines up element-wise with (N,) labels
    # instead of broadcasting to (N, N) in the accuracy comparison.
    all_labels = torch.cat(all_labels).numpy().reshape(-1)
    all_pred = torch.cat(all_pred).numpy().reshape(-1)
    all_pred_prob = torch.cat(all_pred_prob).numpy().reshape(-1)
    acc = np.mean(all_labels == all_pred)

    sensitivity = np.mean(all_pred[all_labels == 1])
    specificity = 1 - np.mean(all_pred[all_labels == 0])
    auc = roc_auc_score(all_labels, all_pred_prob)
    auprc = average_precision_score(all_labels, all_pred_prob)
    return acc, auc, auprc, sensitivity, specificity, all_labels, all_pred_prob, all_pred

# Validation metrics (sensitivity/specificity are not needed here).
val_metrics = eval(model, val_loader)
acc, auc, auprc = val_metrics[:3]
val_labels, val_pred_probs, val_preds = val_metrics[5:]
print('Accuracy: ', acc, '\tAUPRC: ', auprc, '\tAUC: ', auc)

# Test metrics, displayed as the cell output.
(acc, auc, auprc, sensitivity, specificity,
 test_labels, test_pred_probs, test_preds) = eval(model, test_loader)
acc, auc, auprc, sensitivity, specificity


### 6. Calibrate model using the calibration sample. 

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegression, LinearRegression

scaler = StandardScaler()

age_thresh = 17
train_split_df, val_split_df, test_split_df = split_dfs

# Attach the network's predicted probability to each row; the val/test loaders
# do not shuffle, so prediction order matches row order.
val_split_df['risk_score'] = val_pred_probs
test_split_df['risk_score'] = test_pred_probs

# Main analysis: pool the validation and test patients and re-split them
# 62.5% / 37.5% into the calibration sample (val_split_df) and the study sample
# (test_split_df). Only the 'last_two_ecgs' sensitivity analysis keeps the
# original validation/test split.
if preprocessing_params['one_ecg_per_patient'] != 'last_two_ecgs':
    combined_df = pd.concat([val_split_df, test_split_df], ignore_index=True)
    shuffled_df = shuffle(combined_df, random_state=42)
    split_ratio = 0.625
    split_point = int(len(shuffled_df) * split_ratio)
    val_split_df = shuffled_df.iloc[:split_point].reset_index(drop=True)
    test_split_df = shuffled_df.iloc[split_point:].reset_index(drop=True)

if preprocessing_params['one_ecg_per_patient'] == 'last_white':
    val_split_df = val_split_df[val_split_df['PatientRaceFinal'] == 'White']


def _restrict_to_study_population(df):
    """Keep rows older than `age_thresh` years, with Medicare/Medicaid/Commercial
    insurance, in one of the five analysed race groups, and with English,
    Spanish, or Other as primary language."""
    in_population = (
        (df['PatientAge_years'] > age_thresh)
        & df['instype_final'].isin(['Medicare', 'Medicaid', 'Commercial'])
        & df['PatientRaceFinal'].isin(['White', 'Black or African American', 'Hispanic or Latino', 'Asian', 'Other'])
        & df['PrimLangDSC'].isin(['English', 'Spanish', 'Other'])
    )
    return df[in_population]


# `instype_final` was already passed through `instype_map` when the insurance
# table was loaded; that mapping is idempotent, so it is not re-applied here.
val_split_df = _restrict_to_study_population(val_split_df)
test_split_df = _restrict_to_study_population(test_split_df)

def add_necessary_binary_features(df):
    """Add the binary indicator columns used by the recalibration models.

    The columns are written onto ``df`` in place and the same object is
    returned. The encounter indicators flag rows whose ``*visits_bin`` is 0.
    """
    indicator_columns = {
        'binary_Female': df['SexDSC'] == 'Female',
        'binary_Non-English': df['PrimLangDSC'] != 'English',
        'binary_Medicare': df['instype_final'] == 'Medicare',
        'binary_Medicaid': df['instype_final'] == 'Medicaid',
        'binary_NoPCPEncounter': df['PCPvisits_bin'] == 0,
        'binary_NoCAREncounter': df['CARvisits_bin'] == 0,
        'binary_NoOTHEncounter': df['OTHvisits_bin'] == 0,
    }
    for column_name, indicator in indicator_columns.items():
        df[column_name] = indicator
    return df
    
val_split_df = add_necessary_binary_features(val_split_df)
test_split_df = add_necessary_binary_features(test_split_df)
train_split_df = add_necessary_binary_features(train_split_df)

# Collapse the study sample's languages to English vs. Non-English for the subgroup figures.
test_split_df['PrimLangDSC'] = test_split_df['PrimLangDSC'].map(lambda x: x if x == "English" else "Non-English")


lr = LogisticRegression(penalty=None)
lpm = LinearRegression()
# Recalibration inputs: demographic / access-to-care indicators, scaled age,
# and the network's risk score.
feat_names = [
            'binary_Black or African American', 'binary_Hispanic or Latino', 'binary_Other', 'binary_Asian', 
            'binary_Medicaid', 'binary_Medicare', 'binary_Non-English',
            'binary_NoPCPEncounter', 'binary_NoCAREncounter', 'binary_NoOTHEncounter',
            'binary_Female', 'PatientAge_years_01', 'risk_score' ]

# Fail fast on missing values. Predictions are written back to the full frames
# row-for-row, so dropping incomplete rows from one sample only (as before)
# could only surface later as a length-mismatch error.
for sample_name, sample_df in [('calibration', val_split_df), ('study', test_split_df)]:
    n_incomplete = int(sample_df[feat_names + ['label']].isna().any(axis=1).sum())
    assert n_incomplete == 0, f"{n_incomplete} rows with missing values in the {sample_name} sample"

val_X_Y = val_split_df[feat_names + ['label']]
val_X = scaler.fit_transform(val_X_Y[feat_names].values.astype(float))
val_y = val_X_Y['label'].values.astype(float)

test_X_Y = test_split_df[feat_names + ['label']]
test_X = scaler.transform(test_X_Y[feat_names].values.astype(float))  # scaler fit on the calibration sample only
test_y = test_X_Y['label'].values.astype(float)

# Logistic-regression recalibration.
lr.fit(val_X, val_y)
val_split_df['lr_adjusted_risk_score'] = lr.predict_proba(val_X)[:, 1]
test_lr_pred_probs = lr.predict_proba(test_X)[:, 1]
test_split_df['lr_adjusted_risk_score'] = test_lr_pred_probs

# Linear-probability-model recalibration; predictions clipped into [0, 1].
lpm.fit(val_X, val_y)
val_split_df['lpm_adjusted_risk_score'] = np.clip(lpm.predict(val_X), 0, 1)
test_split_df['lpm_adjusted_risk_score'] = np.clip(lpm.predict(test_X), 0, 1)


print(roc_auc_score(val_split_df['label'], val_split_df['lr_adjusted_risk_score']), roc_auc_score(test_split_df['label'], test_split_df['lr_adjusted_risk_score']))
print(roc_auc_score(val_split_df['label'], val_split_df['lpm_adjusted_risk_score']),roc_auc_score(test_split_df['label'], test_split_df['lpm_adjusted_risk_score']))

# Verify that the risk score is calibrated with respect to the label using an
# LPM (OLS of the 0/1 label on the risk score plus demographic / access
# variables) fit to the study sample.
demographic_variables = [
    'binary_Black or African American', 'binary_Hispanic or Latino',
    'binary_Asian', 'binary_Other',
    'binary_Medicare', 'binary_Medicaid',
    'binary_Non-English',
    'binary_Female',
    'PatientAge_years_01',
    'binary_NoPCPEncounter', 'binary_NoCAREncounter', 'binary_NoOTHEncounter',
]

# Dedicated names: `outcome` ('afib') builds the checkpoint path in the training
# and loading cells, and `model` holds the trained network. Overwriting either
# here would break re-running those cells.
outcome_col = 'label'
for risk_score_col_name in ['lpm_adjusted_risk_score']:
    regression_df = test_split_df[demographic_variables + [risk_score_col_name, outcome_col]].dropna()

    X = sm.add_constant(regression_df[[risk_score_col_name] + demographic_variables].astype(float))
    y = regression_df[outcome_col].astype(int)

    results = sm.OLS(y, X).fit()
    print(results.summary())

### 7. Measure performance of risk score.

#### 7.1 AUC plots (Figure 2)

In [None]:
import roc_utils as ru
from utils import colors, prettify_group_name
import matplotlib.gridspec as gridspec

# To perform a ROC analysis using bootstrapping
# NOTE(review): n_samples is not used in this cell.
n_samples = 10

# Subgroup variables and the groups compared within each (one panel per variable).
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
               ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
               ('PrimLangDSC', ('English', 'Non-English'))]

n_cols = 3
n_rows = 2
risk_score_col_name = 'lpm_adjusted_risk_score'

# Two rows sharing column widths: ROC curves on top (twice as tall), event rates below.
fig = plt.figure(figsize=(15, 8), dpi=300)
gs = gridspec.GridSpec(n_rows, n_cols, height_ratios=[2, 1], hspace=0.2)
axs = [[fig.add_subplot(gs[row, col]) for col in range(n_cols)] for row in range(n_rows)]
(ax1, ax2, ax3), (ax4, ax5, ax6) = axs

# Top row: one ROC curve per group in each panel.
for panel_idx, (col_name, feature_names) in enumerate(axis_params):
    plt.sca(axs[0][panel_idx])
    for feature_name in feature_names:
        group_rows = test_split_df[test_split_df[col_name] == feature_name]
        roc = ru.compute_roc(X=group_rows[risk_score_col_name], y=group_rows['label'], pos_label=1)
        ru.plot_roc(roc, color=colors[feature_name], label=feature_name)

    plt.title(prettify_group_name[col_name])
    if panel_idx != 0:
        plt.ylabel("")

### Plot event rate in different groups
# Bottom row: per-group mean predicted risk ("Predicted", filled markers) vs.
# observed AF rate ("Empirical", hollow markers), both in percent.
# NOTE(review): identical to the axis_params used for the ROC panels above.
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
        ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
        ('PrimLangDSC', ('English', 'Non-English')),
        ]

for k, (group_name, group_vals) in enumerate(axis_params):
    ax = axs[1][k]
    plt.sca(ax)

    # Mean risk score and mean label per group, restricted to the plotted groups,
    # then reshaped to long format: one row per (group, Type).
    plot_estimates = test_split_df.groupby(group_name)[[risk_score_col_name, 'label']].mean().reset_index()
    plot_estimates = plot_estimates[plot_estimates[group_name].isin(group_vals)]
    plot_estimates = pd.melt(plot_estimates, id_vars=[group_name], value_vars=[risk_score_col_name, 'label'], var_name='Type', value_name='Estimate')
    plot_estimates['Type'] = plot_estimates['Type'].map({risk_score_col_name: 'Predicted', 'label': 'Empirical'})
    # Fix the top-to-bottom display order; PrimLangDSC keeps groupby's sorted order.
    if group_name == 'PatientRaceFinal':
        plot_estimates['PatientRaceFinal'] = pd.Categorical(plot_estimates[group_name], categories=['White', 'Black or African American', 'Hispanic or Latino', 'Other', 'Asian'], ordered=True)
        plot_estimates.sort_values('PatientRaceFinal', inplace=True)
    elif group_name == 'instype_final':
        plot_estimates['instype_final'] = pd.Categorical(plot_estimates[group_name], categories=['Commercial', 'Medicare', 'Medicaid'], ordered=True)
        plot_estimates.sort_values('instype_final', inplace=True)

    # Fractions -> percentages.
    plot_estimates['Estimate'] = plot_estimates['Estimate']*100

    for i, type in enumerate(['Predicted', 'Empirical']):
    
        # NOTE(review): `colname` and `edgecolor` are assigned but not used below;
        # marker edges always use the group colour.
        if type == 'Empirical':
            colname = 'label'
            color = 'white'
            edgecolor = 'black'
        else:
            colname = risk_score_col_name
            color = 'black'
            edgecolor = 'black'
        # Note: this rebinds the outer loop's `group_vals` to this Type's group labels.
        values = plot_estimates[plot_estimates['Type'] == type]['Estimate'].values
        group_vals = plot_estimates[plot_estimates['Type'] == type][group_name].values
        yticks = list(range(len(values)))

        for j, (value, ytick, group_val) in enumerate(zip(values, yticks, group_vals)):
            group_color = colors[group_val]
            # Label only the first marker of each Type so the legend has one entry per Type.
            type_name = type
            if j != 0:
                type_name = ""
            if type == 'Empirical':
                ax.scatter(value, ytick, s=100, facecolor=color,  edgecolor=group_color, label=type_name)
            else:
                ax.scatter(value, ytick, s=100, facecolor=group_color,  edgecolor=group_color, label=type_name, alpha=0.7)

    ax.set_yticklabels([])
    # ax.set_title(prettify_group_name[group_name])
    ax.set_ylabel('')
    ax.set_xlabel('AF ECG Prevalence (%)')

    # yticks comes from the last Type drawn; both Types cover the same groups.
    ax.set_ylim(np.min(yticks)-.5, np.max(yticks)+.5)
    ax.set_xlim(0, 4.5)
    # Legend is kept only on the last panel (k == 2).
    ax.legend(loc='lower right', title='Event Rate')
    if k != 2:
        ax.legend().remove()
    ax.invert_yaxis()
    # Re-pins the left limit at 0; the right limit stays at 4.5 from above.
    ax.set_xlim(left=0)
    # ax.set_xlim(right=0.05)
    ax.grid(color='lightgray', linestyle='--', linewidth=0.5)

plt.tight_layout()

# Ratio of the observed AF rate in the highest vs. the lowest risk-score
# quantile bin of the study sample (duplicate bin edges are dropped).
n_bins = 5
test_split_df['quintile'] = pd.qcut(test_split_df[risk_score_col_name], q=n_bins, labels=False, duplicates='drop')

quintile = test_split_df['quintile']
top_quintile = test_split_df[quintile == quintile.max()]
bottom_quintile = test_split_df[quintile == quintile.min()]

top_rate = top_quintile['label'].mean()
bottom_rate = bottom_quintile['label'].mean()

# Report an infinite ratio when the lowest bin has no positives.
if bottom_rate > 0:
    ratio = top_rate / bottom_rate
else:
    ratio = float('inf')

print(f"Rate Ratio bt highest and lowest bins, ({n_bins} Bins): {ratio}")

#### 7.2 Calibration plots (Figure S3)

In [None]:
from utils import colors
# Equal-width bins over predicted risk from 0 to 0.03 (10 edges -> 9 bins).
bins = np.linspace(0, 0.03, num=10)

# (demographic column, groups drawn in that panel) for each calibration panel.
# NOTE(review): 'Non-English' only matches if PrimLangDSC was already collapsed to
# English/Non-English earlier in the notebook — confirm.
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
               ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
               ('PrimLangDSC', ('English', 'Non-English')),]

# Panel titles; also read by the calibration-test cell that follows.
prettify_group_name = {'PatientRaceFinal': 'Race', 'instype_final': 'Insurance', 'PrimLangDSC': 'Primary Language'}

n_cols = 3
n_rows = 1
fig, axs = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, figsize=(10, 4))

risk_score_col_name = 'lpm_adjusted_risk_score'
# Shared axis limits, also the endpoints of the y = x reference line.
lower = -.002
upper = 0.03
for i, (group_name, group_vals) in enumerate(axis_params):

    plt.sca(axs[i])
    plot_data = []

    for group_val in group_vals:
        # Explicit copy: adding the 'bin' column must not write into a view of test_split_df.
        group_df = test_split_df[test_split_df[group_name] == group_val].copy()
        group_df['bin'] = pd.cut(group_df[risk_score_col_name], bins, include_lowest=True)
        grouped = group_df.groupby('bin')

        # x: mean predicted risk per bin; y: observed (empirical) label rate per bin.
        x_values = grouped[risk_score_col_name].mean()
        y_values = grouped['label'].mean()
        # Marker size is sqrt(patients in bin). Only bins with sqrt(count) > 20,
        # i.e. more than 400 patients, are drawn.
        sizes = np.sqrt(grouped.size())
        valid_bins = sizes[sizes > 20].index
        x_values = x_values[valid_bins]
        y_values = y_values[valid_bins]
        sizes = sizes[valid_bins]

        plot_data.append((x_values, y_values, sizes, group_val))

    for x, y, sizes, group_val in plot_data:
        plt.scatter(
            x,
            y,
            s=sizes,
            label=group_val,
            color=colors.get(group_val, 'gray'),
            alpha=0.7
        )

    # Perfect calibration: empirical rate equals predicted risk.
    plt.plot([lower, upper], [lower, upper], linestyle='--', linewidth=1, color='gray', alpha=0.7, zorder=-100)
    plt.xlabel('Mean Predicted Risk Score')
    # y is the empirical rate itself (the residual version is no longer plotted),
    # so the label must not say "Prediction Error".
    plt.ylabel('Empirical Event Rate' if i == 0 else '')
    plt.title("Calibration by\n{}".format(prettify_group_name[group_name]))
    plt.legend(loc='lower right', fontsize=8)
    plt.xlim(lower, upper)
    plt.ylim(lower, upper)
    plt.grid(color='lightgray', linestyle='--')

plt.tight_layout()

#### 7.3 Calibration statistical tests (Section 4.2)

In [None]:
from pycaleva import CalibrationEvaluator

# Run pycaleva's calibration z-test separately for every subgroup listed in axis_params.
for col_name, feature_names in axis_params:
    for feature_name in feature_names:
        print()
        subgroup_df = test_split_df[test_split_df[col_name] == feature_name]
        evaluator = CalibrationEvaluator(
            subgroup_df['label'],
            subgroup_df['lpm_adjusted_risk_score'],
            outsample=True,
            n_groups='auto',
        )
        print(prettify_group_name[col_name], ':', feature_name, '\n', evaluator.z_test())

### 8. Run main analysis


#### 8.1 Main analysis set-up

In [688]:
from utils import dem_feat_names, ecg_feat_names
from utils import plot_CIs_covariates
from matplotlib.ticker import FuncFormatter

# Short outcome key -> panel title used in the RAR coefficient plots.
table_name_target_col_map = {
    'diag': 'Diagnosis',
    'diag_1yr': 'Diagnosis, 1yr',
    'ecg': 'AF ECG',
    'rvr': 'HR > 160',
    'stroke': 'Stroke',
    'stroke_within_year': 'Stroke',
}

# Main-analysis configuration (keys set in the same order as before).
preprocessing_params.update({
    'rar_covariates': 'demographics_alone',  # 'demographics_alone' or 'demographics_joint'
    'rar_model': 'OLS',                      # 'OLS' or 'Logit'
    'risk_score_col_name': 'lpm_adjusted_risk_score',
})

def convert_results_to_CI_df(results, z=1.96):
    """Tabulate coefficient estimates with normal-approximation confidence intervals.

    Parameters
    ----------
    results : fitted model results
        Any object exposing ``params`` and ``bse`` as pandas Series indexed by
        variable name (e.g. a statsmodels results object).
    z : float, default 1.96
        Critical value for the interval; 1.96 gives an approximate 95% CI.

    Returns
    -------
    pandas.DataFrame
        Indexed by variable name and sorted by ascending estimate, with columns
        'Estimate', 'Lower bound' (estimate - z*SE) and 'Upper bound' (estimate + z*SE).
    """
    coefficients = results.params
    # Align standard errors to the coefficient labels before sorting.
    errors = results.bse.reindex(coefficients.index)

    # Positional sort order. Use .iloc: indexing a label-indexed Series with integer
    # positions via [] is a deprecated fallback in pandas and fails in newer versions.
    order = np.argsort(coefficients.to_numpy())
    coefficients = coefficients.iloc[order]
    errors = errors.iloc[order]

    return pd.DataFrame(
        {
            'Estimate': coefficients.to_numpy(),
            'Lower bound': (coefficients - z * errors).to_numpy(),
            'Upper bound': (coefficients + z * errors).to_numpy(),
        },
        index=coefficients.index,
    )

## Add in HR > 160
max_hr_window = 365
# Name of the per-patient max atrial rate column for this window.
max_hr_col = 'max_hr_within_' + str(max_hr_window) + '_days'
## NOTE: Make sure to run Compute HR > 160 cell at the bottom of the notebook otherwise this will fail.
unique_id_to_max_hr = pd.read_csv('unique_id_to_max_hr_' + str(max_hr_window) + '_days')
uid_to_max_hr = dict(zip(unique_id_to_max_hr['UniqueID'], unique_id_to_max_hr['AtrialRate']))
# Patients without an HR record in the window get 0, so they never exceed the threshold.
# Assign the result back: fillna(inplace=True) on a selected column is chained assignment
# and silently does nothing under pandas Copy-on-Write. The column name also follows
# max_hr_window (it was previously hardcoded to the 365-day name).
test_split_df[max_hr_col] = test_split_df['UniqueID'].map(uid_to_max_hr).fillna(0)

# Charted diagnosis within 1095 / 730 days (assuming time_to_diagnosis is in days — TODO confirm).
test_split_df['diagnosis_within_3yr'] = (test_split_df['diagnosis_in_charts'] & (test_split_df['time_to_diagnosis'] <= 1095)).astype(int)
test_split_df['diagnosis_within_2yr'] = (test_split_df['diagnosis_in_charts'] & (test_split_df['time_to_diagnosis'] <= 730)).astype(int)

hr_threshold = 160
hr_col = 'hr_over_' + str(hr_threshold)
test_split_df[hr_col] = (test_split_df[max_hr_col] > hr_threshold).astype(int)

## Rename for ease of reading
test_split_df['binary_NoPCPEncounter'] = test_split_df['PCPvisits_bin'] == 0
test_split_df['binary_NoCAREncounter'] = test_split_df['CARvisits_bin'] == 0
test_split_df['binary_NoOTHEncounter'] = test_split_df['OTHvisits_bin'] == 0
test_split_df['binary_prim_language'] = test_split_df['PrimLangDSC'].map(lambda x: x if x == 'English' else 'Non-English')

risk_score_col_name = preprocessing_params['risk_score_col_name']
covariate_names = ['binary_Female',
                   'PatientAge_years_01',
                   'binary_Black or African American',
                   'binary_Hispanic or Latino',
                   'binary_Asian', 'binary_Other',
                   'binary_Non-English',
                   'binary_Medicare', 'binary_Medicaid',
                   'binary_NoPCPEncounter',
                   'binary_NoCAREncounter',
                   'binary_NoOTHEncounter',
                   risk_score_col_name]

# Alternative covariate sets:
# # Kitchen sink
# covariate_names = dem_feat_names + ecg_feat_names + ['binary_NoPCPEncounter',  'binary_NoCAREncounter',  'binary_NoOTHEncounter']
# # Demographics alone
# covariate_names = ['binary_Black or African American',
#                    'binary_Hispanic or Latino',
#                    'binary_Asian', 'binary_Other']

# rar_df is the same object as test_split_df (not a copy).
rar_df = test_split_df
rar_model = preprocessing_params['rar_model']

target_cols = ['diagnosis_in_charts', 'label', hr_col, 'stroke_within_year', 'diagnosis_within_3yr', 'diagnosis_within_2yr']

#### 8.2 Outcome rates conditional on risk for AF diagnosis and ECG with HR > 160 (Figure 3)

In [None]:
from utils import colors, prettify_col_name
# Figure 3 shows diagnosis and HR > threshold only; the next cell (Figure S4) adds AF ECG and stroke.
outcome_cols = ['diagnosis_in_charts', hr_col]
# Row y-labels keyed by outcome; the HR label follows hr_threshold instead of hardcoding 160.
outcome_ylabels = {'diagnosis_in_charts': "Diagnosis\nOutcome (%)",
                   hr_col: f"HR > {hr_threshold}\nOutcome (%)"}
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
               ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
               ('binary_prim_language', ('English', 'Non-English')),]

n_rows = len(outcome_cols)
n_cols = len(axis_params)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(9, 6), sharey=True, sharex=True, dpi=150)
# Risk quintiles 1 (lowest) .. 5 (highest). Ranking with method='first' breaks ties
# so qcut always gets distinct bin edges.
test_split_df['lpm_adjusted_risk_score_rank'] = test_split_df['lpm_adjusted_risk_score'].rank(method='first')
test_split_df['quintile'] = pd.qcut(test_split_df['lpm_adjusted_risk_score_rank'], q=5, labels=False) + 1

for j, (group_col, group_vals) in enumerate(axis_params):
    group_df = test_split_df[test_split_df[group_col].isin(group_vals)]
    for i, outcome_col in enumerate(outcome_cols):
        ax = axs[i, j]
        plt.sca(ax)

        # New frame with the outcome in percent (no write into a filtered slice).
        plot_df = group_df.assign(outcome_pct=group_df[outcome_col] * 100)
        sns.lineplot(
            data=plot_df,
            x="quintile",
            y='outcome_pct',
            hue=group_col,
            hue_order=group_vals,
            palette=[colors[group_val] for group_val in group_vals],
            ax=ax,
            marker="o",
            errorbar="ci",
            err_style="bars"
        )
        # Column titles only on the top row.
        ax.set_title(prettify_col_name(group_col) if i == 0 else "")
        ax.set_xlabel("Risk Quintile")
        ax.set_ylabel(outcome_ylabels[outcome_col])
        plt.legend(title=prettify_col_name(group_col), fontsize=8)
        # Legends only on the bottom row(s).
        if i == 0:
            plt.legend().remove()
        plt.grid(color='gray', alpha=0.4, linestyle='--')

plt.tight_layout()


#### 8.3 Outcome rates conditional on risk for AF diagnosis, ECG with HR > 160, AF ECG within 90 days, Stroke (Figure S4)

In [None]:
from utils import colors, prettify_col_name

outcome_cols = ['diagnosis_in_charts', hr_col, 'label', 'stroke_within_year']
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
               ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
               ('binary_prim_language', ('English', 'Non-English')),]

# One row per outcome, one column per demographic axis.
fig, axs = plt.subplots(len(outcome_cols), len(axis_params), figsize=(10, 10), sharey=True, sharex=True, dpi=150)
# Risk quintiles 1 (lowest) .. 5 (highest); method='first' breaks ties for qcut.
test_split_df['lpm_adjusted_risk_score_rank'] = test_split_df['lpm_adjusted_risk_score'].rank(method='first')
test_split_df['quintile'] = pd.qcut(test_split_df['lpm_adjusted_risk_score_rank'], q=5, labels=False) + 1

# Outcomes whose highest-quintile rates are recorded in rates_to_report.
report_outcome_cols = ['diagnosis_within_year', 'diagnosis_in_charts', 'label', hr_col, 'stroke_within_year']
top_quintile_df = test_split_df[test_split_df['quintile'] == 5]

# rates_to_report[outcome][group] = {'rate': highest-quintile rate in %, 'n': highest-quintile size}
rates_to_report = {}

for j, (group_col, group_vals) in enumerate(axis_params):
    group_df = test_split_df[test_split_df[group_col].isin(group_vals)]
    for i, outcome_col in enumerate(outcome_cols):
        rates_to_report.setdefault(outcome_col, {})
        ax = axs[i, j]
        plt.sca(ax)

        # Plot percentages to match the "Outcome (%)" label. Previously the scaled
        # frame was built but the unscaled 0-1 data was passed to seaborn.
        plot_df = group_df.assign(**{outcome_col: group_df[outcome_col] * 100})
        sns.lineplot(
            data=plot_df,
            x="quintile",
            y=outcome_col,
            hue=group_col,
            hue_order=group_vals,
            palette=[colors[group_val] for group_val in group_vals],
            ax=ax,
            marker="o",
            errorbar="ci",
            err_style="bars"
        )
        # Row titles go on the middle column only.
        ax.set_title(prettify_col_name(outcome_col) if j == 1 else "")
        ax.set_xlabel("Risk Quintile")
        ax.set_ylabel("Outcome (%)")
        plt.legend(title=prettify_col_name(group_col), fontsize=8)
        if i != 1:
            plt.legend().remove()
        plt.grid(color='gray', alpha=0.4, linestyle='--')

        if outcome_col in report_outcome_cols:
            # Same definition for every group, including the reference group
            # (group_vals[0]): both rate and n come from the highest quintile.
            # The reference group's n previously counted all quintiles.
            for group_val in group_vals:
                top_group = top_quintile_df[top_quintile_df[group_col] == group_val]
                rates_to_report[outcome_col][group_val] = {
                    'rate': np.round(top_group[outcome_col].mean() * 100, 1),
                    'n': len(top_group),
                }

plt.tight_layout()


In [None]:
from scipy.stats import norm

# Define confidence interval function for binary proportions
def proportion_confint(successes, n, alpha=0.05):
    """Wald (normal-approximation) confidence interval for a binomial proportion.

    Returns ``(p_hat, lower, upper)`` where ``p_hat = successes / n`` and the
    bounds are ``p_hat -/+ z * sqrt(p_hat * (1 - p_hat) / n)``, with ``z`` the
    ``1 - alpha / 2`` standard-normal quantile. Bounds are not clipped to [0, 1].
    """
    p_hat = successes / n
    std_err = np.sqrt(p_hat * (1 - p_hat) / n)
    half_width = norm.ppf(1 - alpha / 2) * std_err
    return p_hat, p_hat - half_width, p_hat + half_width

# Highest-risk-quintile patients in the two race groups being compared.
rate_comparison_df = test_split_df[
    test_split_df['PatientRaceFinal'].isin(['White', 'Black or African American'])
    & (test_split_df['quintile'] == 5)
]

rates_with_cis_to_report = {}
ci_cols = ['label', 'diagnosis_in_charts', hr_col]

# Group size plus event totals. Sum columns are named f'{col}_sum' for every col,
# including hr_col; that name was previously hardcoded as 'hr_over_160_sum', which
# broke the lookup below for any other hr_threshold.
summary = rate_comparison_df.groupby('PatientRaceFinal').agg(
    n=('label', 'count'),
    **{f'{col}_sum': (col, 'sum') for col in ci_cols},
).reset_index()

# Point estimate and Wald CI for each outcome.
for col in ci_cols:
    summary[[f'{col}_mean', f'{col}_ci_lower', f'{col}_ci_upper']] = summary.apply(
        lambda row: proportion_confint(row[f'{col}_sum'], row['n']), axis=1, result_type='expand'
    )

# Display strings of the form "mean (lower, upper)".
display_cols = ['PatientRaceFinal']
for col in ci_cols:
    summary[f'{col}_with_ci'] = summary.apply(
        lambda row: f"{row[f'{col}_mean']:.3f} ({row[f'{col}_ci_lower']:.3f}, {row[f'{col}_ci_upper']:.3f})",
        axis=1
    )
    display_cols.append(f'{col}_with_ci')

# Final output table
summary[display_cols]

#### 8.4 Output variables referenced in paper

In [None]:
def parse_str_with_mean_ci(mean_ci_str):
    """Parse a ``"mean (lower, upper)"`` proportion string into percentages.

    Expects strings like ``"0.123 (0.100, 0.146)"`` (the ``*_with_ci`` columns of
    ``summary``) and returns ``(mean, lower, upper)`` times 100, rounded to 1 dp.
    """
    mean, lci, uci = mean_ci_str.split(' ')
    # lci looks like "(0.100," and uci like "0.146)"; strip the punctuation.
    return np.round(float(mean)*100, 1), np.round(float(lci[1:-1])*100, 1), np.round(float(uci[:-1])*100, 1)


def print_latex_newcommand(name, value):
    r"""Print ``\newcommand{\<name>}{<value>}``, rendering ``value`` with ``str()``.

    The previous ``{ {value}}`` f-string formatted a one-element set literal. Its
    repr becomes ``{np.float64(12.3)}`` under NumPy 2 instead of ``{12.3}``.
    """
    print(f"\\newcommand{{\\{name}}}{{{value}}}")


def _summary_ci(race, col):
    """(mean, lower, upper) percentages for ``col`` from ``race``'s row of ``summary``."""
    race_row = summary[summary['PatientRaceFinal'] == race]
    return parse_str_with_mean_ci(race_row[f'{col}_with_ci'].iloc[0])


meanAFWhite, lciAFWhite, uciAFWhite = _summary_ci('White', 'label')
meanAFBlack, lciAFBlack, uciAFBlack = _summary_ci('Black or African American', 'label')
# The HR column follows hr_threshold (previously hardcoded as 'hr_over_160_with_ci').
meanRVRWhite, lciRVRWhite, uciRVRWhite = _summary_ci('White', hr_col)
meanRVRBlack, lciRVRBlack, uciRVRBlack = _summary_ci('Black or African American', hr_col)

highest_quintile_AF_ECG_rate = np.round(test_split_df[test_split_df['quintile'] == 5]['label'].mean()*100, 1)
lowest_quintile_AF_ECG_rate = np.round(test_split_df[test_split_df['quintile'] == 1]['label'].mean()*100, 1)

diag_rates = rates_to_report['diagnosis_in_charts']
# Macros are printed in this order.
latex_macros = [
    ('rateDiagWhiteHighest', diag_rates['White']['rate']),
    ('rateDiagBlackHighest', diag_rates['Black or African American']['rate']),
    ('rateDiagHispHighest', diag_rates['Hispanic or Latino']['rate']),
    ('rateDiagAsianHighest', diag_rates['Asian']['rate']),
    ('rateDiagCommercialHighest', diag_rates['Commercial']['rate']),
    ('rateDiagMedicaidHighest', diag_rates['Medicaid']['rate']),
    ('rateDiagEngHighest', diag_rates['English']['rate']),
    ('rateDiagNonEngHighest', diag_rates['Non-English']['rate']),
    ('rateMeanAFWhiteHighest', meanAFWhite),
    ('rateLCIAFWhiteHighest', lciAFWhite),
    ('rateUCIAFWhiteHighest', uciAFWhite),
    ('rateMeanAFBlackHighest', meanAFBlack),
    ('rateLCIAFBlackHighest', lciAFBlack),
    ('rateUCIAFBlackHighest', uciAFBlack),
    ('rateMeanRVRWhiteHighest', meanRVRWhite),
    ('rateLCIRVRWhiteHighest', lciRVRWhite),
    ('rateUCIRVRWhiteHighest', uciRVRWhite),
    ('rateMeanRVRBlackHighest', meanRVRBlack),
    ('rateLCIRVRBlackHighest', lciRVRBlack),
    ('rateUCIRVRBlackHighest', uciRVRBlack),
    ('highestQuintileRate', highest_quintile_AF_ECG_rate),
    ('lowestQuintileRate', lowest_quintile_AF_ECG_rate),
]
for macro_name, macro_value in latex_macros:
    print_latex_newcommand(macro_name, macro_value)

In [None]:
# Two-proportion z-tests: does each group's diagnosis rate in the highest risk
# quintile differ from the reference group's rate in that same quintile?
from statsmodels.stats.proportion import proportions_ztest

col_name = 'diagnosis_in_charts'
# col_name = 'diagnosis_within_year'

# (demographic column, reference group, groups compared against it)
ztest_comparisons = [
    ('PatientRaceFinal', 'White', ['Black or African American', 'Hispanic or Latino', 'Asian']),
    ('instype_final', 'Commercial', ['Medicare', 'Medicaid']),
    ('binary_prim_language', 'English', ['Non-English']),
]

# Counts are taken directly from the highest quintile. They were previously
# reconstructed as int(rounded % * n / 100), which truncated counts and, for the
# reference group, used its size across all quintiles.
top_quintile_df = test_split_df[test_split_df['quintile'] == 5]

for group_col, ref_group, comparison_groups in ztest_comparisons:
    ref_outcomes = top_quintile_df.loc[top_quintile_df[group_col] == ref_group, col_name].dropna()
    count_a, nobs_a = int(ref_outcomes.sum()), len(ref_outcomes)  # events, group size (reference)

    for group in comparison_groups:
        outcomes = top_quintile_df.loc[top_quintile_df[group_col] == group, col_name].dropna()
        count_b, nobs_b = int(outcomes.sum()), len(outcomes)  # events, group size (comparison)

        counts = [count_a, count_b]
        nobs = [nobs_a, nobs_b]
        print(counts, nobs, [count_a/nobs_a, count_b/nobs_b])
        stat, pval = proportions_ztest(count=counts, nobs=nobs, alternative='two-sided')

        print(f"{group} P-value: {pval:.4f}")

#### 8.5 Compute RAR coefficients for each outcome

In [694]:
from utils import deterministic_dict_hash
# Controls included in every RAR regression (the risk score among them).
controls = ['binary_Female',
            'PatientAge_years_01',
            'binary_NoPCPEncounter',
            'binary_NoCAREncounter',
            'binary_NoOTHEncounter',
            risk_score_col_name
            ]

# 'demographics_joint': one model with all demographic indicators together;
# otherwise one model per demographic family (race / insurance / language).
if preprocessing_params['rar_covariates'] == 'demographics_joint':
    demographic_cols = [['binary_Black or African American', 'binary_Hispanic or Latino', 'binary_Asian', 'binary_Other',
                         'binary_Medicare', 'binary_Medicaid', 'binary_Non-English']]
else:
    demographic_cols = [['binary_Black or African American', 'binary_Hispanic or Latino', 'binary_Asian', 'binary_Other'],
                        ['binary_Medicare', 'binary_Medicaid'],
                        ['binary_Non-English']]
# Event rates among White patients, used only when normalize_by_event_rate is True.
event_rates = test_split_df[test_split_df['PatientRaceFinal'] == 'White'][target_cols].mean()
event_rate_map = dict(zip(event_rates.index, event_rates.values))
normalize_by_event_rate = False

ci_value_cols = ['Estimate', 'Lower bound', 'Upper bound']
all_CI_df = []
for target_col in target_cols:
    for demographic_col_names in demographic_cols:
        regressors = controls + demographic_col_names
        df = rar_df[regressors + [target_col]].dropna()

        X = sm.add_constant(df[regressors].astype(float))
        y = df[target_col].astype(int)

        if rar_model == 'Logit':
            model = sm.Logit(y, X)
            # HC2/HC3 are OLS leverage corrections; use the plain sandwich for Logit.
            cov_type = 'HC0'
        elif rar_model == 'OLS':
            model = sm.OLS(y, X)
            cov_type = 'HC3'
        else:
            raise ValueError(f"Unknown rar_model {rar_model!r}; expected 'OLS' or 'Logit'")

        # statsmodels' keyword is cov_type; the previous cov='HC3' was silently
        # ignored, so non-robust standard errors were being reported.
        results = model.fit(cov_type=cov_type)
        CI_df = convert_results_to_CI_df(results)
        # Keep only demographic coefficients. Coefficient names live in the row index;
        # the old check against CI_df.columns never matched, so controls were never dropped.
        CI_df = CI_df.drop(index=['const'] + controls, errors='ignore')

        if rar_model == 'OLS':
            # OLS on a 0/1 outcome: report coefficients in percentage points.
            CI_df[ci_value_cols] = CI_df[ci_value_cols] * 100
        if normalize_by_event_rate:
            CI_df[ci_value_cols] = CI_df[ci_value_cols] / event_rate_map[target_col]

        CI_df['target_col'] = target_col
        all_CI_df.append(CI_df)

all_CI_df = pd.concat(all_CI_df)
# Short outcome keys; the HR key follows hr_col instead of a hardcoded 'hr_over_160'.
all_CI_df['target_col'] = all_CI_df['target_col'].map({'diagnosis_in_charts': 'diag', 'label': 'ecg', hr_col: 'rvr', 'stroke': 'stroke', 'diagnosis_within_year': 'diag_1yr',
                                                       'spec_vis': 'spec', 'trt_rate': 'trt', 'trt_rhythm': 'trt', 'stroke_within_year': 'stroke_within_year',
                                                       'anticoag_treatment': 'anticoag_treatment',
                                                       'diagnosis_within_3yr': 'diag_3yr', 'diagnosis_within_2yr': 'diag_2yr'})

all_CI_df.to_csv('./outputs_est_coeffs/' + deterministic_dict_hash(preprocessing_params))

#### 8.6 Plot RAR coeffs (Figure 4, Figure S5)

In [None]:
# Demographic coefficients to display.
covariate_names = ['binary_Black or African American',
                   'binary_Hispanic or Latino', 'binary_Asian', 'binary_Medicaid', 'binary_Medicare', 'binary_Non-English']
# Display labels without the 'binary_' prefix, e.g. 'binary_Asian' -> 'Asian'.
covariate_labels = [name.split('_')[1] for name in covariate_names]

# ['diag', 'rvr'] reproduces Figure 4; the list below reproduces Figure S5.
panel_names_order = ['diag', 'rvr', 'ecg', 'stroke_within_year']

n_rows = 1
fig, axs = plt.subplots(n_rows, len(panel_names_order), figsize=(15, 4 ), sharex=True, sharey=True, dpi=100)

for panel_ax, panel_name in zip(axs, panel_names_order):
    plt.sca(panel_ax)
    panel_rows = all_CI_df[all_CI_df['target_col'] == panel_name].copy()
    panel_rows = panel_rows[panel_rows.index.isin(covariate_names)]
    panel_rows['cleaned_group_name'] = [name.split('_')[1] for name in panel_rows.index]
    plot_CI_df = panel_rows.set_index('cleaned_group_name')

    ax = plot_CIs_covariates(plot_CI_df[['Estimate', 'Lower bound', 'Upper bound']], covariate_names=covariate_labels,
                             show=False, ax=panel_ax, ylabel_size=15, xlabel_size=15, horizontal_lines=False)

    plt.title(table_name_target_col_map[panel_name], fontsize=15)
    plt.grid(color='lightgray', alpha=.5)
    plt.xlabel("Effect Size", fontsize=15)
    if rar_model == 'OLS':
        # OLS estimates are in percentage points, so tick labels get a '%' suffix.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda tick, _: f"{tick:.0f}%"))
    plt.xticks(fontsize=12)
    plt.tight_layout()

plt.tight_layout()

In [None]:
# Collapse primary language to English vs Non-English (idempotent if re-run).
ecg_data['PrimLangDSC'] = ecg_data['PrimLangDSC'].apply(lambda x: 'English' if x == 'English' else 'Non-English')

count_group_cols = ['PatientRaceFinal', 'instype_final', 'PrimLangDSC']

# Patient counts and charted-diagnosis counts per group value, keyed by the value alone.
# NOTE(review): values that occur in more than one column (e.g. 'Other') overwrite
# each other; the later column in count_group_cols wins.
demographic_to_n = {}
demographic_to_n_diag = {}
for demographic_group in count_group_cols:
    demographic_to_n.update(ecg_data[demographic_group].value_counts().items())
    demographic_to_n_diag.update(ecg_data.groupby(demographic_group)['diagnosis_in_charts'].sum().items())

# Indicators for the reference groups (rar_df is test_split_df, so these land there too).
rar_df['binary_White'] = (rar_df['PatientRaceFinal'] == 'White').astype(int)
rar_df['binary_Commercial'] = (rar_df['instype_final'] == 'Commercial').astype(int)
rar_df['binary_English'] = (rar_df['PrimLangDSC'] == 'English').astype(int)


#### 8.7 Output prevalence analysis (Table 2)

In [None]:
# (indicators modelled for a category, reference group whose indicator is omitted, category label)
demographic_cols = [(['binary_White', 'binary_Hispanic or Latino', 'binary_Asian'], 'Black or African American', 'Race'),
                    (['binary_Medicare', 'binary_Medicaid'], 'Commercial', 'Insurance'),
                    (['binary_Non-English'], 'English', 'Primary Language'),
                    ]

all_demographics = ['binary_White',
                    'binary_Hispanic or Latino', 'binary_Asian', 'binary_Medicaid', 'binary_Medicare', 'binary_Non-English']
prevalence_df = []
n_decimals = 1
target_col = 'diagnosis_in_charts'
for demographic_col_names, default_group_name, category_name in demographic_cols:
    # Model for this category (fit once; it does not depend on the group being corrected):
    #   diag ~ this category's indicators + other demographics + controls
    # controls contains age, sex, access-to-care indicators, and the risk score.
    other_demographics = [x for x in all_demographics if x not in demographic_col_names]
    regressors = controls + demographic_col_names + other_demographics
    df = rar_df[regressors + [target_col]].dropna()

    X = sm.add_constant(df[regressors].astype(float))
    y = df[target_col].astype(int)
    model = sm.OLS(y, X).fit()

    for group_name in demographic_col_names:
        group_key = group_name.split('_')[-1]
        # test_split_df is equivalent to the study split referred to in the paper.
        orig_rate = test_split_df[test_split_df[group_name] == 1][target_col].mean()

        # Counterfactual: switch only this category's indicators to the reference group
        # (e.g. all Hispanic patients treated as Black) and average the model predictions.
        group_patients = X[X[group_name] == 1].copy()
        group_patients[demographic_col_names] = 0
        corrected_rate = model.predict(group_patients).mean()
        diff = corrected_rate - orig_rate

        n_hidden_diags = int(diff * demographic_to_n[group_key])
        n_diag = demographic_to_n_diag[group_key]
        print(group_name, n_hidden_diags, n_diag)
        change_rate = np.round(100 * n_hidden_diags / n_diag, n_decimals)
        prevalence_df.append({'Category': category_name, 'Group': group_name.split('_')[1],
                              'Obs. Prevalence': f'{np.round(orig_rate*100, n_decimals)} ({n_diag})',
                              'Est. Prevalence': f'{np.round(corrected_rate*100, n_decimals)} ({n_diag + n_hidden_diags})',
                              'Change in Diagnosis': str(change_rate) + f' ({n_hidden_diags})'})

    # Reference group: rows with none of this category's indicators set.
    # NOTE(review): for Race this also includes patients with binary_Other == 1,
    # since that indicator is not in the model — confirm this is intended.
    orig_rate = test_split_df[test_split_df[demographic_col_names].sum(axis=1) == 0][target_col].mean()
    n_diag = demographic_to_n_diag[default_group_name]
    default_patients = X[X[demographic_col_names].sum(axis=1) == 0]
    corrected_rate = model.predict(default_patients).mean()
    diff = corrected_rate - orig_rate
    n_hidden_diags = int(diff * demographic_to_n[default_group_name])
    change_rate = np.round(100 * n_hidden_diags / n_diag, n_decimals)
    prevalence_df.append({'Category': category_name, 'Group': default_group_name,
                          'Obs. Prevalence': f'{np.round(orig_rate*100, n_decimals)} ({n_diag})',
                          'Est. Prevalence': f'{np.round(corrected_rate*100, n_decimals)} ({n_diag + n_hidden_diags})',
                          'Change in Diagnosis': str(change_rate) + f' ({n_hidden_diags})'})
    print()
prevalence_df = pd.DataFrame(prevalence_df)

# One column spec per output column (Category, Group, Obs., Est., Change);
# the previous "lcc" described only three of the five columns.
print(prevalence_df.to_latex(index=False, column_format="llccc", caption="Corrected Prevalence Rates by Group", label="tab:prevalences"))


In [None]:
# Display-name cleanup; keys are group names with spaces removed.
value_rename_map = {
    'BlackorAfricanAmerican': 'Black/African American',
    'HispanicorLatino': 'Hispanic/Latino'
}

# Start LaTeX table (requires \usepackage{makecell} in LaTeX preamble).
# The header has one cell per column (5 total) with every \makecell closed, and
# '%' is escaped because an unescaped % starts a LaTeX comment.
latex_rows = [
    r"\begin{tabular}{llccc}",
    r"\toprule",
    r" & & \makecell{\textbf{Obs. Prevalence}} & \makecell{\textbf{Est. Prevalence}} & \makecell{\textbf{$\Delta$ in Diagnosis Rate}} \\",
    r" & & \makecell{\textbf{\% (N)}} & \makecell{\textbf{\% (N)}} & \makecell{\textbf{\% (N)}} \\",
    r"\hline",
]

prev_category = None
for _, row in prevalence_df.iterrows():
    category = row['Category'].strip()
    compact_val = row['Group'].strip().replace(' ', '')
    val = value_rename_map.get(compact_val, compact_val)

    # Start each category with a rule and show its label only on its first row.
    if category != prev_category:
        latex_rows.append(r"\midrule")
        group_col = f"\\textbf{{{category}}}"
    else:
        group_col = ""
    prev_category = category

    latex_rows.append(
        f"{group_col} & {val}  & {row['Obs. Prevalence']} & {row['Est. Prevalence']} & {row['Change in Diagnosis']} \\\\"
    )
latex_rows.append(r"\bottomrule")
latex_rows.append(r"\end{tabular}")

latex_code = "\n".join(latex_rows)
print(latex_code)


### 9. Run supplementary analyses

#### 9.1 Compare to baselines (Figure S1)

In [None]:
import copy

from utils import dem_feat_names, ecg_feat_names

# Fit logistic-regression baselines on demographic and/or ECG features and compare
# their test-set ROC curves with our model's risk score.
risk_score_col_name = 'lpm_adjusted_risk_score'
dem_ecg_feat_names = dem_feat_names + ecg_feat_names

# Every model is evaluated on the same test rows: those with all features, the label
# and our model's score present.
comparison_test_split_df = test_split_df[dem_ecg_feat_names + ['label', risk_score_col_name]].dropna().copy()
test_Y = comparison_test_split_df['label'].values
train_split_df = split_dfs[0]

for feat_set_name, feat_names in [('dems', dem_feat_names),
                                  ('ecg', ecg_feat_names),
                                  ('dem_ecg_feats', dem_ecg_feat_names)]:
    # Training rows only need the features of this particular baseline.
    train_X_Y = train_split_df[feat_names + ['label']].dropna()
    train_Y = train_X_Y['label'].values

    # Fit a private copy of the notebook-level `scaler` so the shared object is not
    # silently refit on these baseline features.
    feat_scaler = copy.deepcopy(scaler)
    train_X = feat_scaler.fit_transform(train_X_Y[feat_names].values)
    test_X = feat_scaler.transform(comparison_test_split_df[feat_names].values)

    lr = LogisticRegression()
    lr.fit(train_X, train_Y)

    lr_pred_probs = lr.predict_proba(test_X)[:, 1]
    comparison_test_split_df['lr_risk_score_' + feat_set_name] = lr_pred_probs
    print(feat_set_name, roc_auc_score(test_Y, lr_pred_probs))


# Plot all ROC curves on a single set of axes, in the order of this mapping.
prettify_risk_score_col_name = {
    'lpm_adjusted_risk_score': 'Our Model',
    'lr_risk_score_dem_ecg_feats': 'LR + Demographics, ECG',
    'lr_risk_score_ecg': 'LR + ECG',
    'lr_risk_score_dems': 'LR + Demographics'
}
fig = plt.figure(figsize=(6, 6), dpi=150)
for risk_score_to_compare, pretty_name in prettify_risk_score_col_name.items():
    roc = ru.compute_roc(X=comparison_test_split_df[risk_score_to_compare],
                         y=comparison_test_split_df['label'], pos_label=1)
    ru.plot_roc(roc, color=colors[risk_score_to_compare], label=pretty_name)

# All curves share these axes, so the y-axis label is left as drawn (blanking it
# after the first curve would remove it from the only panel).
plt.title('Comparison to Alternative Models')
plt.show()

#### 9.2 Plot distributions of estimated risk by demographic group (Figure S2)

In [None]:
# Collapse primary language to English vs. Non-English (missing values also become
# 'Non-English'). Note: this overwrites the column in test_split_df itself.
test_split_df['PrimLangDSC'] = test_split_df['PrimLangDSC'].apply(lambda x: 'English' if x == 'English' else 'Non-English')

# (column, groups to compare) for each panel; the group order also fixes the legend
# order and colour assignment.
axis_params = [('PatientRaceFinal', ('White', 'Black or African American', 'Hispanic or Latino', 'Asian')),
               ('instype_final', ('Commercial', 'Medicare', 'Medicaid')),
               ('PrimLangDSC', ('English', 'Non-English'))]

risk_score_col_name = 'lpm_adjusted_risk_score'
log_risk_col_name = 'log_' + risk_score_col_name

fig, axes = plt.subplots(1, len(axis_params), figsize=(15, 5), sharey=True)

for ax, (col_name, values) in zip(axes, axis_params):
    # Explicit copy, so adding the log column never writes into a slice of test_split_df.
    filtered_df = test_split_df[test_split_df[col_name].isin(values)].copy()
    filtered_df[log_risk_col_name] = np.log(filtered_df[risk_score_col_name])

    # Density of log risk per group; common_norm=False normalises each group on its own.
    sns.kdeplot(
        data=filtered_df,
        x=log_risk_col_name,
        hue=col_name,
        hue_order=list(values),
        ax=ax,
        common_norm=False,
        legend=True,
        bw_adjust=1,
    )
    ax.get_legend().set_title(prettify_col_name(col_name))
    ax.set_title(f'Distribution of Risk by {prettify_col_name(col_name)}')
    ax.set_xlabel("log(Predicted Risk of AF within 90 days)")
    ax.set_ylabel("Density")
    ax.grid(color='lightgray', linestyle='--', linewidth=0.5, zorder=-100)

plt.tight_layout()
plt.show()

#### 9.3 Sensitivity analyses (Figure S6)

In [None]:
# Inspect the 'diag' coefficients estimated under the 'last_two_ecgs' configuration
# (the "Patients with confirmed negative" sensitivity analysis), restricted to the
# demographic covariates. Coefficient files live under './outputs/est_coeffs/', the
# same directory the sensitivity-analysis cell below reads this configuration from.
prefix = './outputs/est_coeffs/'

diag_config = {'max_pred_gap': 90, 'selection_criteria': 'va', 'include_single_ecgs': True, 'mini': False,
               'one_ecg_per_patient': 'last_two_ecgs', 'loss': 'CE', 'test_set_size': 0.4,
               'rar_model': 'OLS', 'risk_score_col_name': 'lpm_adjusted_risk_score', 'rar_covariates': 'demographics_alone'}

df = pd.read_csv(os.path.join(prefix, deterministic_dict_hash(diag_config)))
df = df[df['target_col'] == 'diag']
# The unnamed first column holds the covariate name (renamed to 'Group' in the next cell).
df[df['Unnamed: 0'].isin(all_demographics)]

In [594]:
# Sensitivity analysis: gather the 'diag' coefficients estimated under several
# alternative configurations, plus the main configuration's coefficients for 2- and
# 3-year diagnosis, into one long frame labelled by 'config_name'.

# Main-text configuration. Each sensitivity analysis overrides one or two of its keys;
# overriding an existing key keeps its position, so every dict has the same key order.
main_config = {'max_pred_gap': 90, 'selection_criteria': 'va', 'include_single_ecgs': True, 'mini': False,
               'one_ecg_per_patient': 'last', 'loss': 'CE', 'test_set_size': 0.4, 'rar_model': 'OLS',
               'risk_score_col_name': 'lpm_adjusted_risk_score', 'rar_covariates': 'demographics_alone'}

all_configs = [
    # Restrict dataset to patients with two ECGs.
    ('Patients with confirmed negative', {**main_config, 'one_ecg_per_patient': 'last_two_ecgs'}),
    # Restrict train + calibration dataset to white patients.
    ('White patients', {**main_config, 'one_ecg_per_patient': 'last_white'}),
    # Regress on all demographics at once, not individually.
    ('Joint regression', {**main_config, 'rar_covariates': 'demographics_joint'}),
    # Main text results.
    ('Main', dict(main_config)),
]

prefix = './outputs/est_coeffs/'

sensitivity_coeff_frames = []
for config_name, config in all_configs:
    config_coeffs = pd.read_csv(os.path.join(prefix, deterministic_dict_hash(config)))
    config_coeffs = config_coeffs[config_coeffs['target_col'] == 'diag'].assign(config_name=config_name)
    sensitivity_coeff_frames.append(config_coeffs)

# Alternative outcome windows under the main configuration: their target_col becomes
# the config_name label.
diag_config = dict(main_config)
df = pd.read_csv(os.path.join(prefix, deterministic_dict_hash(diag_config)))
df = df[df['target_col'].isin(['diag_2yr', 'diag_3yr'])].rename(columns={'target_col': 'config_name'})
df['config_name'] = df['config_name'].map({'diag_2yr': 'Diagnosis within 2 years',
                                           'diag_3yr': 'Diagnosis within 3 years'})

coeffs_across_sensitivity_analyses = (pd.concat(sensitivity_coeff_frames + [df], axis=0)
                                      .rename(columns={'Unnamed: 0': 'Group'}))

In [None]:

from matplotlib.ticker import FuncFormatter

# One forest plot per covariate, comparing its estimated effect across the
# sensitivity-analysis configurations loaded in the previous cell.
covariate_names = ['binary_Black or African American',
                   'binary_Hispanic or Latino', 'binary_Asian', 'binary_Medicaid', 'binary_Medicare', 'binary_Non-English']
config_names = ['Main', 'Patients with confirmed negative', 'White patients', 'Joint regression',
                'Diagnosis within 2 years', 'Diagnosis within 3 years']

# Derived from the configurations actually plotted (rather than a leftover global):
# OLS effects get a '%' suffix on the x-axis.
use_percent_axis = all(config['rar_model'] == 'OLS' for _, config in all_configs)

n_rows = 2
n_cols = -(-len(covariate_names) // n_rows)  # ceiling division, so every covariate gets a panel
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5), sharex=True, sharey=True, dpi=200, squeeze=False)

for i, covariate_name in enumerate(covariate_names):
    row, col = divmod(i, n_cols)
    ax = axs[row, col]
    plt.sca(ax)
    # This covariate's estimates: one row per configuration, indexed by configuration name.
    plot_CI_df = coeffs_across_sensitivity_analyses[
        coeffs_across_sensitivity_analyses['Group'] == covariate_name].set_index('config_name')
    ax = plot_CIs_covariates(plot_CI_df[['Estimate', 'Lower bound', 'Upper bound']], covariate_names=config_names,
                             show=False, ax=ax, ylabel_size=15, xlabel_size=15, horizontal_lines=False,
                             color_CIs_by_significance=False)
    ax.set_xlim(-4.5, 3.5)
    ax.set_title(covariate_name.split('_', 1)[1], fontsize=15)  # drop the 'binary_' prefix
    ax.grid(color='lightgray', alpha=.5)
    # Only the bottom row carries the x-axis label.
    ax.set_xlabel("Effect Size" if row == n_rows - 1 else "", fontsize=15)
    if use_percent_axis:
        ax.xaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.0f}%"))
    ax.tick_params(axis='x', labelsize=12)

# Hide any grid cells left empty when the covariates do not fill the last row.
for unused_ax in axs.flat[len(covariate_names):]:
    unused_ax.set_visible(False)

plt.tight_layout()

### 10. Write out scores

In [60]:
# Export per-ECG risk scores with the site, demographic, diagnosis, treatment/outcome
# and visit columns; the CSV keeps the column order assembled below.
score_cols = ['risk_score', 'lr_adjusted_risk_score', 'lpm_adjusted_risk_score']
site_cols = ['Hosp', 'LocationName']
demographic_cols = ['binary_Male', 'PrimLangDSC', 'PatientAge_years', 'PatientRaceFinal', 'instype_final']
diagnosis_cols = ['diagnosis_in_charts', 'diagnosis_in_encounters', 'diagnosis_in_problem_list']
treatment_and_outcome_cols = ['anticoag_treatment', 'spec_vis', 'trt_rate', 'trt_rhythm',
                              'hr_over_160', 'hospitalization', 'mortality', 'stroke']
visit_cols = ['PCPvisits_bin', 'CARvisits_bin', 'OTHvisits_bin']
cols_to_include = (score_cols + site_cols + demographic_cols + diagnosis_cols
                   + treatment_and_outcome_cols + visit_cols + ['afib_pos_ecg'])

# Export aliases: the ECG label as 'afib_pos_ecg' and the ECG site as 'Hosp'.
test_split_df['afib_pos_ecg'] = test_split_df['label']
test_split_df['Hosp'] = test_split_df['ecg_location']

df_to_write_out = test_split_df[cols_to_include]
df_to_write_out.to_csv('afib_risk_scores_outcomes_v7.csv')

### 11. Compute HR column

In [None]:
# For every patient in the train/val/test splits, find the maximum ventricular and
# atrial rate over the MUSE ECGs recorded within `window` days after the row's `date`.
import ipyparallel as ipp
from ecg_feature_names import xml_file_to_extracted_features

window = 365  # days (1 year)
n_engines = 8
# Parsing every ECG XML is slow (~21 minutes with 8 engines). Set to False to reuse
# the per-ECG features written by a previous run with the same `window`.
recompute_ecg_features = True
all_ecg_info_path = '../outputs_intermediate/unique_id_to_all_ecg_info_' + str(window) + '_days.csv'
max_hr_path = '../outputs_intermediate/unique_id_to_max_hr_' + str(window) + '_days.csv'

muse_cache_df = pd.concat([pd.read_csv('../outputs_intermediate/muse_cache_files/patient_mrn_to_file.csv', dtype='str'),
                           pd.read_csv('../outputs_intermediate/muse_cache_files/patient_mrn_to_file_2020_2021_2022.csv', dtype='str')])
muse_cache_df['date'] = pd.to_datetime(muse_cache_df[['year', 'month', 'day']])
test_split_df['date'] = pd.to_datetime(test_split_df['date'])

train_split_df = split_dfs[0]
all_data = pd.concat([train_split_df, val_split_df, test_split_df])
# Parse after concatenating so train/val dates are datetimes too; otherwise the window
# comparisons below could mix strings and timestamps.
all_data['date'] = pd.to_datetime(all_data['date'])
heldout_uniqueids = all_data['UniqueID'].unique()  # NOTE: IDs from all three splits, not only held-out ones
relevant_test_split = all_data[['UniqueID', 'date']]
relevant_muse_cache = muse_cache_df[muse_cache_df['UniqueID'].isin(heldout_uniqueids)]

# Pair each (UniqueID, date) row with every MUSE ECG of that patient ('date_other' is
# the MUSE ECG's date), keeping ECGs strictly after `date` and at most `window` days later.
unique_id_to_max_hr = pd.merge(relevant_test_split, relevant_muse_cache, on='UniqueID', suffixes=('', '_other'))
in_window = ((unique_id_to_max_hr['date_other'] > unique_id_to_max_hr['date'])
             & (unique_id_to_max_hr['date_other'] <= unique_id_to_max_hr['date'] + pd.Timedelta(days=window)))
unique_id_to_max_hr = unique_id_to_max_hr[in_window]

if recompute_ecg_features:
    # Parse each ECG file once, even when it pairs with several rows above; duplicated
    # paths would also make the merge on 'path' below multiply rows.
    all_paths = unique_id_to_max_hr['path'].drop_duplicates().tolist()

    cluster = ipp.Cluster(n=n_engines)
    cluster.start_cluster_sync()
    try:
        rc = cluster.connect_client_sync()
        rc.wait_for_engines(n_engines)
        dview = rc[:]
        xml_outputs = dview.map_sync(xml_file_to_extracted_features, all_paths)
    finally:
        # Stop the engines even if parsing fails, so they do not outlive this cell.
        cluster.stop_cluster_sync()

    feature_df = pd.DataFrame(list(xml_outputs))
    feature_df['path'] = all_paths
    unique_id_to_max_hr = pd.merge(unique_id_to_max_hr, feature_df, on='path', validate='many_to_one')
    unique_id_to_max_hr.to_csv(all_ecg_info_path)

# Collapse to one row per patient with the maximum rates over that patient's windowed ECGs.
unique_id_to_max_hr = pd.read_csv(all_ecg_info_path)
unique_id_to_max_hr = (unique_id_to_max_hr[['UniqueID', 'VentricularRate', 'AtrialRate']]
                       .groupby('UniqueID').max().reset_index())
unique_id_to_max_hr.to_csv(max_hr_path)