In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from statistics import mean, median
import os
import seaborn as sns
import matplotlib.dates as mdates
import matplotlib.ticker as mtick
import scipy.stats as stats
from scipy.special import gammaln
from sklearn.preprocessing import KBinsDiscretizer
import gc

from lifelines import CoxPHFitter, AalenJohansenFitter
from lifelines.plotting import add_at_risk_counts

##### Load validation data and set target groups

In [None]:
# Column-name -> dtype lookup, read from a two-column file whose header row is skipped.
dem_types = pd.read_csv('', names=['item', 'dtype'], skiprows=1)
dtype_dict = dict(zip(dem_types['item'], dem_types['dtype']))

# Validation cohort, parsed with the dtypes declared above.
base_path = ''
val_data = pd.read_csv(os.path.join(base_path, ''), low_memory=True, dtype=dtype_dict)
# Commas in column names are replaced with underscores.
val_data.columns = [col.replace(',', '_') for col in val_data.columns]

# Second table (tab-separated); only ppid and DateOfDeath are carried over.
all_data = pd.read_csv('', sep='\t')
print(all_data.columns)
# NOTE(review): a left merge repeats val_data rows if ppid is not unique in
# all_data -- confirm ppid uniqueness there.
death_dates = all_data[['ppid', 'DateOfDeath']]
val_data = val_data.merge(death_dates, how='left', on='ppid')

In [None]:
# Show every column name of all_data as a plain list (the load cell prints the same columns as an Index).
all_data.columns.tolist()

In [None]:
# Global matplotlib font settings: size 12, normal weight, serif family.
plt.rcParams.update({'font.size':12, 'font.weight':'normal', 'font.family':'serif'})

In [None]:
def discretize(y, nb=5):
    """Bin a NumPy array into `nb` quantile bins and return the ordinal bin codes.

    Parameters
    ----------
    y : numpy.ndarray
        Values to bin; reshaped to a single column before fitting, so it must
        support ``.reshape``.
    nb : int, default 5
        Number of bins requested from ``KBinsDiscretizer``.

    Returns
    -------
    numpy.ndarray
        Flattened output of ``KBinsDiscretizer.fit_transform`` with
        ``encode='ordinal'`` and ``strategy='quantile'``.
    """
    column = y.reshape(-1, 1)
    binner = KBinsDiscretizer(n_bins=nb, encode='ordinal', strategy='quantile')
    return binner.fit_transform(column).flatten()
# Ordinal bin code -> contact-intensity label.
bin_labels = {0: 'Very Low', 1: 'Low', 2: 'Medium', 3: 'Medium-High', 4: 'High'}
intensity_codes = discretize(np.array(val_data['total_count_all_tf'].values), nb=5).astype(int)
val_data['total_count_all_gr'] = pd.Series(intensity_codes, index=val_data.index).map(bin_labels)

### Set age groups
# Anything outside [50, 90) -- including missing ages and ages below 50 --
# falls through to the '90+' default.
age = val_data['AgeAtAdmission']
val_data['age_gr'] = np.select(
    [
        (age >= 50) & (age < 60),
        (age >= 60) & (age < 70),
        (age >= 70) & (age < 80),
        (age >= 80) & (age < 90),
    ],
    ['50-59', '60-69', '70-79', '80-89'],
    default='90+',
)

### Set SIMD groups
# simd_dec 1-2 -> group 1, 3-8 -> groups 2-4, everything else (including
# missing values) -> group 5.
simd = val_data['simd_dec']
val_data['simd_gr'] = np.select(
    [(simd >= 1) & (simd < 3), (simd >= 3) & (simd < 9)],
    ['1 - most deprived', '2-4'],
    default='5 - least deprived',
)

In [None]:
# Number of rows per contact-intensity label.
val_data['total_count_all_gr'].value_counts()

In [None]:
# Number of rows per age group.
val_data['age_gr'].value_counts()

In [None]:
# Number of rows per SIMD group.
val_data['simd_gr'].value_counts()

In [None]:
# Rows without a recorded death date get a fixed date (2024-03-01) before parsing.
fallback_death_date = pd.Timestamp('2024-03-01 00:00:00')
val_data['DateOfDeath'] = pd.to_datetime(
    np.where(val_data['DateOfDeath'].isnull(), fallback_death_date, val_data['DateOfDeath'])
)
for date_col in ('ED_adate_dt', 'HOSP_ddt'):
    val_data[date_col] = pd.to_datetime(val_data[date_col])

# Whole days from ED attendance to hospital discharge and to (possibly filled-in) death date.
ed_arrival = val_data['ED_adate_dt']
val_data['LOS_d'] = (val_data['HOSP_ddt'] - ed_arrival).dt.days
val_data['time_until_death'] = np.ceil((val_data['DateOfDeath'] - ed_arrival).dt.days).astype(int)

# s_days is the length of stay, except where gt_m == 1, where it is the time until death.
# NOTE(review): astype(int) raises if HOSP_ddt or ED_adate_dt is missing for any row.
val_data['s_days'] = np.ceil(val_data['LOS_d']).astype(int)
died = val_data['gt_m'] == 1
val_data.loc[died, 's_days'] = val_data.loc[died, 'time_until_death']
#val_filt = val_data[val_data['s_days']<=100]

In [None]:
# Summary statistics of the survival-time column.
val_data['s_days'].describe()

In [None]:
# (rows with s_days > 100, total rows)
len(val_data[val_data['s_days']>100]), len(val_data)

In [None]:
### Set discharge status
# Attach DischargeToCode from the tab-separated inpatient table, matched on
# patient id and episode number.
# NOTE(review): the left merge repeats val_data rows if a (ppid, EpisodeNumber)
# pair occurs more than once in trak_inp_d -- confirm uniqueness.
trak_inp_d = pd.read_csv('', sep='\t', low_memory=False)
discharge_cols = ['ppid', 'EpisodeNumber', 'DischargeToCode']
val_data = val_data.merge(trak_inp_d[discharge_cols], how='left', on=['ppid', 'EpisodeNumber'])

In [None]:
# Number of rows per discharge destination code.
val_data.DischargeToCode.value_counts()

In [None]:
### Set non-home discharge status
died_in_hospital = val_data['gt_m'] == 1
# gt_dd is cleared wherever gt_m == 1, so the two flags never overlap.
val_data['gt_dd'] = np.where(died_in_hospital, 0, val_data['gt_dd'])
# status coding: 2 where gt_m == 1, 0 where gt_dd == 1, 1 for everything else.
val_data['status'] = np.select(
    [died_in_hospital, val_data['gt_dd'] == 1],
    [2, 0],
    default=1,
)
#val_data['status'] = np.where((~val_data['DischargeToCode'].isin(['H', 'HHO', 'HA', 'HHS', 'ESDS', 'HWR']))&
                              #(val_data['gt_m']!=1), 0, val_data['status'])

print(val_data.status.value_counts())

In [None]:
# Show every column name of val_data after the merges and derived columns above.
val_data.columns.tolist()

#### CPH regression against in-hospital death

In [None]:
# Subset of covariate columns (age group, sex flag, SIMD group, condition and drug counts).
# NOTE(review): val_cov is not referenced again in the visible part of the notebook.
val_cov = val_data[['age_gr', 'Sex_F', 'simd_gr', 'total_longterm_conditions', 'total_drug_categories']]

In [None]:
# Switch to the seaborn 'darkgrid' style, then apply the same font settings used earlier
# (size 12, normal weight, serif family).
sns.set_style('darkgrid')
plt.rcParams.update({'font.size':12, 'font.weight':'normal', 'font.family':'serif'})

In [None]:
# Contact-intensity label counts among rows with s_days <= 200.
val_data.loc[val_data['s_days'] <= 200, 'total_count_all_gr'].value_counts()

#### Overall hazard ratios

In [None]:
#### Fit overall hazard function
#### Plot hazard ratios with 95% CI
#### Plot cumulative incidence by health contact level

In [None]:
# Number of patients still at risk (s_days >= day) on a 5-day grid over the
# first 100 days: one row per day, one column per contact-intensity group.
#
# The earlier version counted how many s_days values fell inside each 5-day
# bin and then sorted those counts in descending order per group. That is not
# a number-at-risk table; it only looked monotone because of the sort. Here
# the count is computed directly.
intensity_order = ['Very Low', 'Low', 'Medium', 'Medium-High', 'High']
day_grid = np.linspace(0, 100, num=21)
durations_by_group = {
    label: val_data.loc[val_data['total_count_all_gr'] == label, 's_days']
    for label in intensity_order
}
n_at_risk_all = pd.DataFrame(
    {
        label: [int((durations >= day).sum()) for day in day_grid]
        for label, durations in durations_by_group.items()
    },
    index=pd.Index(day_grid, name='Day'),
)

In [None]:
# For the 'High' intensity rows only: group by patient id, cut each patient's s_days into the
# intervals defined by np.linspace(0, 100, num=10), then count how many values land in each interval.
# NOTE(review): these are per-interval counts, not numbers at risk, and n_at_risk_h is not
# referenced again in the visible part of the notebook.
n_at_risk_h = val_data[val_data['total_count_all_gr']=='High'].groupby('ppid').apply(lambda x: pd.cut(x['s_days'], bins=np.linspace(0, 100, num=10))).value_counts()

In [None]:
# Display the table built above.
n_at_risk_all

In [None]:
def _count_at_risk(durations, groups, group_labels, times):
    """Count, per group, how many subjects have a duration >= each time point.

    Returns a DataFrame with one row per entry of `group_labels` and one
    column per entry of `times` (labelled 'Day <rounded time>').
    """
    rows = []
    for label in group_labels:
        group_durations = durations[groups == label]
        rows.append([int((group_durations >= t).sum()) for t in times])
    return pd.DataFrame(rows, index=list(group_labels),
                        columns=[f'Day {round(t)}' for t in times])


def get_overall_cr_summary(val_data, event_col='status', eoi=2, duration_col='s_days', 
                          group_labels=['Very Low', 'Low', 'Medium', 'Medium-High', 'High'],
                          colors=['#918e26', '#a1dab4','#41b6c4', '#2c7fb8', '#253494'], s_max=300, tps=20):
    """Fit and plot Aalen-Johansen cumulative incidence curves per contact-intensity group.

    One ``AalenJohansenFitter`` is fitted for each label in `group_labels`
    (rows where ``total_count_all_gr`` equals the label). Two figures are
    drawn: a grid with one panel per group, and a combined plot with a
    number-at-risk table underneath.

    Parameters
    ----------
    val_data : pandas.DataFrame
        Must contain `duration_col`, `event_col`, ``total_count_all_gr``,
        ``Sex_F``, ``total_longterm_conditions`` and ``total_drug_categories``.
        The columns ``Sex_c``, ``tltc_gr`` and ``tdr_gr`` are added to it.
    event_col : str
        Column holding the event code passed to the fitter.
    eoi : int
        Passed to the fitter as ``event_of_interest``.
    duration_col : str
        Column holding the durations.
    group_labels : list of str
        Group labels, in plotting and table-row order. Never mutated.
    colors : list of str
        One colour per group, indexed in the same order as `group_labels`.
    s_max : float
        Last time point of the number-at-risk grid.
    tps : int
        Number of grid points; the grid is ``np.linspace(0, s_max, num=tps)``.

    Returns
    -------
    dict
        Maps each group label to its fitted ``AalenJohansenFitter``.
    """
    # Derived covariate groupings. They are not used further down in this
    # function; they are kept so the frame passed in gains the same columns
    # as before.
    val_data['Sex_c'] = np.where(val_data['Sex_F']==1, 'F', 'M')
    val_data['tltc_gr'] = np.where(val_data['total_longterm_conditions']>3, 'High-count MM', 'No MM')
    val_data['tltc_gr'] = pd.Categorical(np.where((val_data['total_longterm_conditions']>1)&(val_data['total_longterm_conditions']<=3), 
                                   'Simple MM', val_data['tltc_gr']), categories=['No MM', 'Simple MM', 'High-count MM'])
    val_data['tdr_gr'] = np.where(val_data['total_drug_categories']>=7, '7+', '0')
    val_data['tdr_gr'] = np.where((val_data['total_drug_categories']>3)&(val_data['total_drug_categories']<7), '4-6', 
                                  val_data['tdr_gr'])
    val_data['tdr_gr'] = np.where((val_data['total_drug_categories']>0)&(val_data['total_drug_categories']<=3), '1-3', 
                                  val_data['tdr_gr'])

    # One cumulative-incidence fit per intensity group.
    # NOTE(review): durations are whole days, so ties are likely; the fitter
    # is created without a seed -- consider fixing one for reproducible curves
    # (confirm against the lifelines documentation).
    ajf_dict = {}
    for group in group_labels:
        mask = (val_data['total_count_all_gr']==group)
        ajf = AalenJohansenFitter()
        ajf.fit(val_data.loc[mask, duration_col],
                val_data.loc[mask, event_col],
                event_of_interest=eoi)
        ajf_dict[group] = ajf

    # Panel figure: three columns, as many rows as the number of groups needs
    # (2 x 3 for the default five groups). Unused panels are removed.
    n_cols = 3
    n_rows = max(1, -(-len(ajf_dict) // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.flatten()
    for i, (group, ajf) in enumerate(ajf_dict.items()):
        panel = axes[i]
        ajf.plot(ax=panel, color=colors[i])
        panel.set_title(f'CIF - {group} intensity')
        panel.set_xlabel('Days from ED attendance')
        panel.set_ylabel('Cumulative incidence')
        panel.get_legend().remove()
    for unused in axes[len(ajf_dict):]:
        fig.delaxes(unused)
    fig.tight_layout()
    plt.show()

    # Combined figure with all groups on one axis.
    figc, ax = plt.subplots(figsize=(10,6))
    for i, (group, ajf) in enumerate(ajf_dict.items()):
        ajf.plot(ax=ax, color=colors[i], label=f'{group} intensity', alpha=0.7)

    ### N at risk table
    # Subjects with duration >= t for every grid time t. The earlier version
    # counted durations per interval and sorted those counts in descending
    # order, which is not a number at risk. The first column (t = 0) equals
    # the group size whenever all durations are non-negative.
    times = np.linspace(0, s_max, num=tps)
    n_at_risk_all = _count_at_risk(val_data[duration_col], val_data['total_count_all_gr'],
                                   group_labels, times)
    print(n_at_risk_all)

    n_at_risk_table = ax.table(cellText=n_at_risk_all.values, rowLabels=list(n_at_risk_all.index),
                               colLabels=list(n_at_risk_all.columns), loc='bottom', bbox=[0, -0.35, 1, 0.2])
    n_at_risk_table.auto_set_font_size(False)
    n_at_risk_table.set_fontsize(9)

    ax.set_title('Relationships between incidence of in-hospital death and contact intensity level.')
    ax.set_xlabel('Days from ED attendance')
    ax.set_ylabel('Cumulative incidence')
    ax.legend(loc='upper left')
    figc.tight_layout()
    plt.show()

    return ajf_dict

In [None]:
# Cumulative incidence for rows with s_days <= 100, using the function defaults
# (event column 'status', event of interest 2).
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=100])

In [None]:
# Same subset, but with 'gt_m' as the event column and 1 as the event of interest.
# ajf_models is overwritten by each of these calls.
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=100], eoi=1, event_col='gt_m')

In [None]:
# Default event settings again, with 60 grid points for the at-risk table instead of 20.
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=100], tps=60)

In [None]:
# NOTE(review): identical to an earlier call (s_days <= 100, eoi=1, event_col='gt_m').
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=100], eoi=1, event_col='gt_m')

In [None]:
# Wider subset: rows with s_days <= 300, default event settings.
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=300])

In [None]:
# Rows with s_days <= 300, with 'gt_m' as the event column and 1 as the event of interest.
ajf_models = get_overall_cr_summary(val_data[val_data['s_days']<=300], eoi=1, event_col='gt_m')

In [None]:
# Summary statistics of s_days on the full (unfiltered) frame.
val_data.s_days.describe()

#### Subgroup analysis

In [None]:
def get_gr_cph(val_data, event_col='gt_m', duration_col='s_days',
              group_labels=['Very Low', 'Low', 'Medium', 'Medium-High', 'High'],
              plot_labels=['50-59', '60-69', '70-79', '80-89', '90+'],
              covariate='age_gr', y_up=3.7, x_up=225, colors=['#918e26', '#a1dab4', 
                                                              '#41b6c4', '#2c7fb8', '#253494'],
              leg_title='Age group'):
    """Fit one Cox PH model per contact-intensity group and plot partial effects.

    For every label in `group_labels`, a ``CoxPHFitter`` is fitted on the rows
    whose ``total_count_all_gr`` equals that label, with the formula
    ``age_gr + Sex_c + simd_gr + tltc_gr + tdr_gr`` (reference levels 'M' and
    'No MM'). The model summary and the log-likelihood ratio test are printed.
    Then, per group, the cumulative hazard is plotted for each value of
    `covariate` listed in `plot_labels`.

    Parameters
    ----------
    val_data : pandas.DataFrame
        Must contain `duration_col`, `event_col`, ``total_count_all_gr``,
        ``age_gr``, ``simd_gr``, ``Sex_F``, ``total_longterm_conditions`` and
        ``total_drug_categories``. The columns ``Sex_c``, ``tltc_gr`` and
        ``tdr_gr`` are added to it.
    event_col, duration_col : str
        Event indicator and duration columns passed to the fitter.
    group_labels : list of str
        Contact-intensity labels to fit, in order. Never mutated.
    plot_labels : list of str
        Values of `covariate` to draw; also used as legend entries.
    covariate : str
        Covariate whose partial effect is plotted.
    y_up, x_up : float
        Upper axis limits of each plot.
    colors : list of str
        Line colours passed to the plotting call.
    leg_title : str
        Legend title.

    Returns
    -------
    None
    """
    # Derived covariates. Missing counts fall into the default categories
    # ('No MM' and '0') because every comparison against NaN is False.
    val_data['Sex_c'] = np.where(val_data['Sex_F']==1, 'F', 'M')
    n_conditions = val_data['total_longterm_conditions']
    val_data['tltc_gr'] = np.select(
        [n_conditions > 3, (n_conditions > 1) & (n_conditions <= 3)],
        ['High-count MM', 'Simple MM'],
        default='No MM',
    )
    n_drugs = val_data['total_drug_categories']
    val_data['tdr_gr'] = np.select(
        [n_drugs >= 7, (n_drugs > 3) & (n_drugs < 7), (n_drugs > 0) & (n_drugs <= 3)],
        ['7+', '4-6', '1-3'],
        default='0',
    )

    cph_list = []
    for group in group_labels:
        print(f'Fitting CPH model for {group} intensity.')
        group_data = val_data[val_data['total_count_all_gr']==group]
        ### Fit CPH model
        # The strata column is constant within group_data because the rows
        # were just filtered on it.
        cph = CoxPHFitter()
        cph.fit(group_data, duration_col=duration_col, event_col=event_col, strata=['total_count_all_gr'], 
           formula='age_gr + C(Sex_c, Treatment(\'M\')) + simd_gr + C(tltc_gr, Treatment(\'No MM\')) + tdr_gr')
        print('-------')
        # print_summary() writes the table itself and returns None; wrapping
        # it in print() added a stray 'None' line to the output.
        cph.print_summary()
        print('-------')
        print(f"\nSignificance test for {group} contact intensity:")
        print(cph.log_likelihood_ratio_test())
        cph_list.append(cph)

    ### CI function
    legend_labels = list(plot_labels) + ['Base survival']
    for group, cph in zip(group_labels, cph_list):
        # NOTE(review): the plotting call below may open its own figure, in
        # which case this one stays empty -- confirm against the lifelines docs.
        plt.figure(figsize=(4,4))
        cph.plot_partial_effects_on_outcome(covariates=covariate, values=plot_labels,
                                            color=colors,
                                            y='cumulative_hazard', figsize=(4,4), alpha=0.5)
        plt.title(f"{group} contact intensity")
        plt.legend(title=leg_title, labels=legend_labels, prop={'size': 11}, loc='upper left')
        plt.xlabel('Days from ED attendance')
        plt.ylabel('Cumulative hazard')
        plt.ylim([0, y_up])
        plt.xlim([0, x_up])
        # Tighten every figure, not only the last one of the loop.
        plt.tight_layout()

    plt.show()

In [None]:
# val_filt has no live assignment earlier in the notebook (the only one is
# commented out), so this cell raised NameError on a fresh kernel. It is
# rebuilt here with the threshold from that commented-out line (s_days <= 100).
# .copy() is used because get_gr_cph adds columns to the frame it receives.
# NOTE(review): confirm that 100 days is the intended cut-off for the subgroup analysis.
val_filt = val_data[val_data['s_days'] <= 100].copy()
get_gr_cph(val_filt)

In [None]:
# Same per-group models, plotting the partial effect of the SIMD group instead of age.
get_gr_cph(val_filt, plot_labels=['1 - most deprived', '2-4', '5 - least deprived'],
              covariate='simd_gr', y_up=2.2,
          colors=['#918e26', '#41b6c4', '#253494'], leg_title='SIMD quintiles')

In [None]:
# Partial effect of the long-term-condition grouping (tltc_gr, built inside get_gr_cph).
get_gr_cph(val_filt, plot_labels=['No MM', 'Simple MM', 'High-count MM'],
              covariate='tltc_gr', y_up=2.5, colors=['#918e26', '#41b6c4', '#253494'],
          leg_title='Multimorbidity')

In [None]:
# Partial effect of the drug-category grouping (tdr_gr, built inside get_gr_cph).
get_gr_cph(val_filt, plot_labels=['0', '1-3', '4-6', '7+'],
              covariate='tdr_gr', y_up=2.5, colors=['#918e26', '#41b6c4', '#2c7fb8', '#253494'],
          leg_title='# Concurrent prescriptions')