In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats, special
import matplotlib.ticker as mtick
from matplotlib.dates import DateFormatter
from datetime import timedelta
from datetime import datetime
from tqdm import tqdm

from boruta import BorutaPy
import miceforest as mf
import missingno as msno
from statannotations.Annotator import Annotator
import warnings
warnings.filterwarnings(action='ignore')

#### Generate NHSL-specific cohort for care contacts prediction

In [None]:
# Admissions-level cohort (file path left empty in this copy of the notebook).
# Later cells read ppid, HOSP_adt and HOSP_ddt from it, among other columns.
inp_data = pd.read_csv('')

In [None]:
# Care-contact event log (path left empty). Columns used below: ppid,
# cdt (presumably the contact timestamp) and cintervention (contact code).
ind_contacts_full_gr = pd.read_csv('')

In [None]:
inp_data.shape

In [None]:
# Twenty most frequent intervention codes in the contact log.
ind_contacts_full_gr.cintervention.value_counts().head(20)

In [None]:
inp_data[['ppid', 'HOSP_adt']].tail()

In [None]:
ind_contacts_full_gr.ppid.unique()

In [None]:
#### Setup event log
# Order contacts chronologically per patient, attach the hospital admission and
# discharge timestamps, and keep only contacts strictly inside that stay.
# NOTE: merge has no `on=`, so it joins on every shared column (ppid, unless the
# contact file also carries HOSP_adt/HOSP_ddt).
ind_contacts_full_gr = (
    ind_contacts_full_gr
    .sort_values(['ppid', 'cdt'])
    .merge(inp_data[['ppid', 'HOSP_adt', 'HOSP_ddt']], how='left')
    .loc[lambda contacts: (contacts['cdt'] > contacts['HOSP_adt']) & (contacts['cdt'] < contacts['HOSP_ddt'])]
)
ind_contacts_full_gr.head(10)

In [None]:
# Distinct patients that still have at least one in-stay contact after filtering.
ind_contacts_full_gr.ppid.nunique()

In [None]:
##### Set additional lookup dates
# Per patient, record the first (F) and last (L) in-stay contact timestamp for
# four contact families:
#   CC = NURSE/PT/OT/SLT codes, AS = 'TRAK-' codes,
#   SM = 'SPECMV-' codes,        WM = 'WARDMV-' codes.
# Contacts were sorted by ['ppid', 'cdt'] above, so first/last per ppid are the
# earliest/latest non-missing cdt.
# Values are mapped by ppid rather than assigned positionally from a merge result,
# so they stay correct even if inp_data does not carry a default RangeIndex.
# str.startswith uses na=False so a missing intervention code is "no match"
# instead of a NaN entry, which pandas refuses to use as a boolean mask.
_interventions = ind_contacts_full_gr['cintervention']
_contact_families = {
    'CC': _interventions.isin(['NURSE', 'PT', 'OT', 'SLT']),
    'AS': _interventions.str.startswith('TRAK-', na=False),
    'SM': _interventions.str.startswith('SPECMV-', na=False),
    'WM': _interventions.str.startswith('WARDMV-', na=False),
}
for _code, _mask in _contact_families.items():
    _contact_times = ind_contacts_full_gr.loc[_mask].groupby('ppid')['cdt']
    inp_data[f'HOSP_F{_code}_dt'] = inp_data['ppid'].map(_contact_times.first())
    inp_data[f'HOSP_L{_code}_dt'] = inp_data['ppid'].map(_contact_times.last())

In [None]:
# Missing-value count per pathway timestamp, before validation.
inp_data.loc[:, ['ppid',
                 'ED_adate_dt', 'ED_ddate_dt', 'HOSP_adt', 'HOSP_ddt', 'breq_dt', 'triage_dt',
                 'CC_adate', 'CC_ddate', 'DateOfDeath',
                 'HOSP_FCC_dt', 'HOSP_LCC_dt', 'HOSP_FAS_dt', 'HOSP_LAS_dt',
                 'HOSP_FSM_dt', 'HOSP_LSM_dt', 'HOSP_FWM_dt', 'HOSP_LWM_dt']].isna().sum()

In [None]:
##### Preprocess and validate timestamps
def _null_outside(frame, col, upper, lower=None):
    """Return frame[col] as an array, with NaN wherever the value is after
    frame[upper] or (when `lower` is given) before frame[lower]."""
    values = frame[col]
    out_of_range = values > frame[upper]
    if lower is not None:
        out_of_range = out_of_range | (values < frame[lower])
    return np.where(out_of_range, np.nan, values)


# ED arrival after hospital admission is discarded.
inp_data['ED_adate_dt'] = _null_outside(inp_data, 'ED_adate_dt', upper='HOSP_adt')
# ED departure after hospital admission is clipped to the admission time.
inp_data['ED_ddate_dt'] = np.where(inp_data['ED_ddate_dt'] > inp_data['HOSP_adt'],
                                   inp_data['HOSP_adt'], inp_data['ED_ddate_dt'])
# Remaining rules as (column, upper bound, lower bound). They run in this order
# because later rules read bounds cleaned by earlier ones (breq_dt uses triage_dt;
# several use the cleaned ED_adate_dt).
for _col, _upper, _lower in [
    ('triage_dt', 'HOSP_adt', 'ED_adate_dt'),
    ('breq_dt', 'HOSP_adt', 'triage_dt'),
    ('CC_adate', 'HOSP_ddt', 'ED_adate_dt'),
    ('CC_ddate', 'HOSP_ddt', 'ED_adate_dt'),
    ('HOSP_FCC_dt', 'HOSP_ddt', 'HOSP_adt'),
    ('HOSP_LCC_dt', 'HOSP_ddt', 'HOSP_adt'),
    ('HOSP_FAS_dt', 'HOSP_ddt', 'HOSP_adt'),
    ('HOSP_LAS_dt', 'HOSP_ddt', 'HOSP_adt'),
    ('HOSP_FSM_dt', 'HOSP_ddt', 'ED_adate_dt'),
    ('HOSP_LSM_dt', 'HOSP_ddt', 'ED_adate_dt'),
    ('HOSP_FWM_dt', 'HOSP_ddt', 'ED_adate_dt'),
    ('HOSP_LWM_dt', 'HOSP_ddt', 'ED_adate_dt'),
]:
    inp_data[_col] = _null_outside(inp_data, _col, upper=_upper, lower=_lower)

In [None]:
# Missing-value count per pathway timestamp, after validation blanked bad values.
inp_data.loc[:, ['ppid',
                 'ED_adate_dt', 'ED_ddate_dt', 'HOSP_adt', 'HOSP_ddt', 'breq_dt', 'triage_dt',
                 'CC_adate', 'CC_ddate', 'DateOfDeath',
                 'HOSP_FCC_dt', 'HOSP_LCC_dt', 'HOSP_FAS_dt', 'HOSP_LAS_dt',
                 'HOSP_FSM_dt', 'HOSP_LSM_dt', 'HOSP_FWM_dt', 'HOSP_LWM_dt']].isna().sum()

#### Deal with missingness in timestamps

In [None]:
# Rows without a usable ED arrival time, shown with the rest of the pathway.
inp_data.loc[inp_data['ED_adate_dt'].isna(),
             ['ED_adate_dt', 'triage_dt', 'breq_dt', 'ED_ddate_dt', 'HOSP_adt',
              'HOSP_FCC_dt', 'HOSP_FAS_dt', 'HOSP_FSM_dt', 'HOSP_FWM_dt', 'HOSP_ddt']].head(10)

In [None]:
#### Setup predictor vars
# Timestamp columns that are imputed below and later written back to inp_data.
pred_vars = ['ED_adate_dt', 'ED_ddate_dt', 'HOSP_adt', 'HOSP_ddt', 'breq_dt', 'triage_dt',
             'CC_adate', 'CC_ddate', 'DateOfDeath', 'HOSP_FCC_dt', 'HOSP_LCC_dt',
             'HOSP_FAS_dt', 'HOSP_LAS_dt', 'HOSP_FSM_dt', 'HOSP_LSM_dt', 'HOSP_FWM_dt', 'HOSP_LWM_dt']


def _epoch_seconds(ts):
    """Return a Timestamp as POSIX seconds (float); missing values become NaN."""
    return np.nan if pd.isnull(ts) else ts.timestamp()


# Working frame: patient id plus every timestamp as float seconds, so gaps and
# their medians are plain arithmetic.
inp_data_imp = inp_data[['ppid'] + pred_vars].copy()
for _col in pred_vars:
    inp_data_imp[_col] = pd.to_datetime(inp_data_imp[_col]).apply(_epoch_seconds)

In [None]:
# Sanity check: rows where ED arrival is still later than hospital admission.
inp_data.loc[inp_data['ED_adate_dt'] > inp_data['HOSP_adt'], ['HOSP_adt', 'ED_adate_dt']].head(20)

In [None]:
# Gaps (seconds) between pathway timestamps; their medians become imputation offsets.
# Each entry: new column -> (minuend column, subtrahend column).
_gap_columns = {
    'ED_adate_diff': ('HOSP_adt', 'ED_adate_dt'),
    'ED_ddate_diff': ('HOSP_adt', 'ED_ddate_dt'),
    'ED_triage_diff': ('HOSP_adt', 'triage_dt'),
    'ED_breq_diff': ('breq_dt', 'triage_dt'),
    'HOSP_FCC_diff': ('HOSP_FCC_dt', 'HOSP_adt'),
    'HOSP_LCC_diff': ('HOSP_ddt', 'HOSP_LCC_dt'),
    'HOSP_FWM_diff': ('HOSP_FWM_dt', 'ED_adate_dt'),
    'HOSP_LWM_diff': ('HOSP_ddt', 'HOSP_LWM_dt'),
}
for _gap_name, (_minuend, _subtrahend) in _gap_columns.items():
    inp_data_imp[_gap_name] = inp_data_imp[_minuend] - inp_data_imp[_subtrahend]

In [None]:
# Median gap in seconds from triage to hospital admission (NaNs skipped).
inp_data_imp['ED_triage_diff'].median()

In [None]:
#### Set median time difference as imputation target
# Fill each missing timestamp as anchor -/+ the median observed gap.
# Order matters: breq_dt is anchored on the already-filled triage_dt, and
# HOSP_FWM_dt on the already-filled ED_adate_dt.
for _target, _anchor, _gap, _subtract in [
    ('ED_adate_dt', 'HOSP_adt', 'ED_adate_diff', True),
    ('ED_ddate_dt', 'HOSP_adt', 'ED_ddate_diff', True),
    ('triage_dt', 'HOSP_adt', 'ED_triage_diff', True),
    ('breq_dt', 'triage_dt', 'ED_breq_diff', False),
    ('HOSP_FCC_dt', 'HOSP_adt', 'HOSP_FCC_diff', False),
    ('HOSP_LCC_dt', 'HOSP_ddt', 'HOSP_LCC_diff', True),
    ('HOSP_FWM_dt', 'ED_adate_dt', 'HOSP_FWM_diff', False),
    ('HOSP_LWM_dt', 'HOSP_ddt', 'HOSP_LWM_diff', True),
]:
    _offset = inp_data_imp[_gap].median()
    if _subtract:
        _fill_values = inp_data_imp[_anchor] - _offset
    else:
        _fill_values = inp_data_imp[_anchor] + _offset
    inp_data_imp[_target] = inp_data_imp[_target].fillna(_fill_values)

In [None]:
# Convert the (possibly imputed) epoch seconds back to datetimes; bad values -> NaT.
inp_data_imp = inp_data_imp.assign(
    **{col: pd.to_datetime(inp_data_imp[col], unit='s', errors='coerce') for col in pred_vars}
)

In [None]:
# Spot-check the first ten rows of imputed pathway timestamps.
inp_data_imp.head(10).loc[:, ['ED_adate_dt', 'triage_dt', 'breq_dt', 'ED_ddate_dt', 'HOSP_adt',
                              'HOSP_FCC_dt', 'HOSP_FAS_dt', 'HOSP_FSM_dt', 'HOSP_FWM_dt', 'HOSP_ddt']]

In [None]:
# Replace inp_data's raw timestamp columns with their imputed versions.
# After dropping pred_vars, ppid is the only column the two frames share, so the
# join key is spelled out explicitly.
inp_data_tg = inp_data_imp[['ppid'] + pred_vars]
inp_data_f = inp_data.drop(pred_vars, axis=1).merge(inp_data_tg, how='left', on='ppid')

In [None]:
# Remaining missingness after median imputation. This inspects the merged frame
# inp_data_f: inp_data itself is unchanged since the post-validation count shown
# earlier, so counting it again would only repeat that output.
inp_data_f[['ppid'] + pred_vars].isna().sum()

In [None]:
#### Check assumptions
# Re-apply the ordering rules to the imputed timestamps in inp_data_f. Unlike the
# earlier validation pass (which blanked every violation), most violations here are
# repaired by substituting a neighbouring timestamp. Statements run in order, so
# later rules see values already corrected above (e.g. triage_dt is compared with
# the updated ED_adate_dt).
# ED arrival after hospital admission -> NaT.
inp_data_f['ED_adate_dt'] = np.where(inp_data_f['ED_adate_dt']>inp_data_f['HOSP_adt'], np.datetime64('NaT'), inp_data_f['ED_adate_dt'])
# ED departure after hospital admission -> clipped to HOSP_adt.
inp_data_f['ED_ddate_dt'] = np.where(inp_data_f['ED_ddate_dt']>inp_data_f['HOSP_adt'], inp_data_f['HOSP_adt'], inp_data_f['ED_ddate_dt'])
# Triage after admission or before ED arrival -> ED arrival time.
inp_data_f['triage_dt'] = np.where((inp_data_f['triage_dt']>inp_data_f['HOSP_adt'])|(inp_data_f['triage_dt']<inp_data_f['ED_adate_dt']),
                                 inp_data_f['ED_adate_dt'], inp_data_f['triage_dt'])
# breq_dt after admission or before triage -> triage time.
inp_data_f['breq_dt'] = np.where((inp_data_f['breq_dt']>inp_data_f['HOSP_adt'])|(inp_data_f['breq_dt']<inp_data_f['triage_dt']),
                                 inp_data_f['triage_dt'], inp_data_f['breq_dt'])
# CC_adate / CC_ddate outside [ED arrival, hospital discharge] -> NaT.
inp_data_f['CC_adate'] = np.where((inp_data_f['CC_adate']>inp_data_f['HOSP_ddt'])|(inp_data_f['CC_adate']<inp_data_f['ED_adate_dt']),
                                 np.datetime64('NaT'), inp_data_f['CC_adate'])
inp_data_f['CC_ddate'] = np.where((inp_data_f['CC_ddate']>inp_data_f['HOSP_ddt'])|(inp_data_f['CC_ddate']<inp_data_f['ED_adate_dt']),
                                 np.datetime64('NaT'), inp_data_f['CC_ddate'])
# First care contact outside [HOSP_adt, HOSP_ddt] -> HOSP_adt.
inp_data_f['HOSP_FCC_dt'] = np.where((inp_data_f['HOSP_FCC_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_FCC_dt']<inp_data_f['HOSP_adt']),
                                 inp_data_f['HOSP_adt'], inp_data_f['HOSP_FCC_dt'])
# Last care contact outside the stay -> the (already repaired) first care contact.
inp_data_f['HOSP_LCC_dt'] = np.where((inp_data_f['HOSP_LCC_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_LCC_dt']<inp_data_f['HOSP_adt']),
                                 inp_data_f['HOSP_FCC_dt'], inp_data_f['HOSP_LCC_dt'])
# HOSP_FAS_dt / HOSP_LAS_dt outside the stay -> NaT (these were not imputed).
inp_data_f['HOSP_FAS_dt'] = np.where((inp_data_f['HOSP_FAS_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_FAS_dt']<inp_data_f['HOSP_adt']),
                                 np.datetime64('NaT'), inp_data_f['HOSP_FAS_dt'])
inp_data_f['HOSP_LAS_dt'] = np.where((inp_data_f['HOSP_LAS_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_LAS_dt']<inp_data_f['HOSP_adt']),
                                 np.datetime64('NaT'), inp_data_f['HOSP_LAS_dt'])
# HOSP_FSM_dt / HOSP_LSM_dt outside [ED arrival, hospital discharge] -> NaT.
inp_data_f['HOSP_FSM_dt'] = np.where((inp_data_f['HOSP_FSM_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_FSM_dt']<inp_data_f['ED_adate_dt']),
                                 np.datetime64('NaT'), inp_data_f['HOSP_FSM_dt'])
inp_data_f['HOSP_LSM_dt'] = np.where((inp_data_f['HOSP_LSM_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_LSM_dt']<inp_data_f['ED_adate_dt']),
                                 np.datetime64('NaT'), inp_data_f['HOSP_LSM_dt'])
# HOSP_FWM_dt outside [ED arrival, hospital discharge] -> HOSP_adt;
# HOSP_LWM_dt outside that window -> the (already repaired) HOSP_FWM_dt.
inp_data_f['HOSP_FWM_dt'] = np.where((inp_data_f['HOSP_FWM_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_FWM_dt']<inp_data_f['ED_adate_dt']),
                                 inp_data_f['HOSP_adt'], inp_data_f['HOSP_FWM_dt'])
inp_data_f['HOSP_LWM_dt'] = np.where((inp_data_f['HOSP_LWM_dt']>inp_data_f['HOSP_ddt'])|(inp_data_f['HOSP_LWM_dt']<inp_data_f['ED_adate_dt']),
                                 inp_data_f['HOSP_FWM_dt'], inp_data_f['HOSP_LWM_dt'])

In [None]:
# Missingness per timestamp column after the consistency repairs above.
inp_data_f[['ppid'] + pred_vars].isna().sum()

##### Set outcomes

In [None]:
### Get unscheduled re-admission flags
# NOTE(review): despite this heading, the cells that follow derive geriatric-medicine
# and discharge-destination flags from these extracts; no re-admission flag is built
# in this section — confirm the heading.
# Tab-separated TRAK extracts (paths left empty in this copy). trak_inp_d supplies
# DischargeToCode and trak_mv supplies the movement rows used below.
trak_inp_a = pd.read_csv('', sep='\t', low_memory=False)
trak_inp_d = pd.read_csv('', sep='\t', low_memory=False)
trak_ed = pd.read_csv('', sep='\t', low_memory=False, encoding='iso-8859-1')
trak_mv = pd.read_csv('', sep='\t', low_memory=False)

In [None]:
# Build one movement timestamp from the separate start date and start time fields
# (a missing date or time gives NaT).
# NOTE(review): pd.to_datetime infers the format; if StartDate is day-first
# (dd/mm/yyyy), pass dayfirst=True or an explicit format — confirm.
trak_mv['cdt'] = pd.to_datetime(trak_mv['StartDate'] + ' ' + trak_mv['StartTime'])

In [None]:
# Keep only movements under the 'Medicine of the Elderly' specialty and flag them.
# .assign builds an explicit new frame from the filtered rows, instead of setting a
# column on a filtered slice (the SettingWithCopy pattern that the notebook-wide
# warnings filter would hide).
trak_mv = trak_mv.loc[trak_mv['NatSpecialty'] == 'Medicine of the Elderly'].assign(gt_eld=1)

In [None]:
# Geriatric movements with no start timestamp (missing StartDate or StartTime).
trak_mv.cdt.isnull().sum()

In [None]:
# NOTE(review): this reloads inp_data_f from a CSV (path empty here), replacing the
# frame built and repaired in the cells above — confirm the file is that frame's export.
# Only HOSP_adt is parsed to datetime; other timestamp columns are read as raw values.
inp_data_f = pd.read_csv('')
inp_data_f['HOSP_adt'] = pd.to_datetime(inp_data_f['HOSP_adt'])

In [None]:
# Attach geriatric-medicine movements per episode, then keep one row per patient:
# the earliest admission (HOSP_adt). Sorting on cdt as well makes the kept row the
# earliest geriatric movement of that admission (rows without one have cdt = NaT,
# which sorts last), rather than whichever movement came first in the TRAK extract.
inp_data_f = (
    pd.merge(inp_data_f, trak_mv[['ppid', 'EpisodeNumber', 'cdt', 'gt_eld']],
             how='left', on=['ppid', 'EpisodeNumber'])
    .sort_values(['ppid', 'HOSP_adt', 'cdt'])
    .drop_duplicates(['ppid'], keep='first')
)
inp_data_f['gt_eld'] = inp_data_f['gt_eld'].fillna(0).astype(np.int8)
# Hours from hospital admission to that geriatric movement (NaN if none).
inp_data_f['eld_dist'] = (inp_data_f['cdt'] - inp_data_f['HOSP_adt']) / pd.Timedelta(hours=1)
# gt_eld_d{N}: -1 when the movement happened within N*24 hours after admission
# (strictly after it); otherwise a copy of gt_eld.
for _days in (1, 2, 3):
    _early = (inp_data_f['eld_dist'] < 24 * _days) & (inp_data_f['eld_dist'] > 0)
    inp_data_f[f'gt_eld_d{_days}'] = np.where(_early, -1, inp_data_f['gt_eld']).astype(np.int8)

In [None]:
# Rows and columns left after keeping one row per ppid.
inp_data_f.shape

In [None]:
### Discharge disposition
# DischargeToCode values counted as a discharge home (gt_dd = 0). Any other code,
# including a missing one, is treated as a non-home discharge.
home_disch_codes = ['H', 'HWR', 'HA', 'CHAHP01', 'CHAHP03', 'ESDS']

In [None]:
# Columns of trak_inp_d (the frame that supplies DischargeToCode further down).
trak_inp_d.columns

In [None]:
# Columns available on the reloaded cohort frame.
inp_data_f.columns

In [None]:
### Binary outcomes
# gt_cc: any ICU/HDU stay, i.e. LOS_CC present and not the -1 sentinel.
inp_data_f['gt_cc'] = np.where((inp_data_f['LOS_CC'].isna())|(inp_data_f['LOS_CC']==-1), 0, 1)
# gt_es_hosp: hospital length of stay above the cohort's 80th percentile.
inp_data_f['gt_es_hosp'] = np.where(inp_data_f['LOS_hosp'] > inp_data_f.LOS_hosp.quantile(0.8), 1, 0)
# gt_m: death dated within [HOSP_adt, HOSP_ddt]. After the CSV reload only HOSP_adt
# is a datetime column; ordering a datetime column against raw string dates raises
# TypeError in pandas, so DateOfDeath and HOSP_ddt are parsed for the comparison.
_death_dt = pd.to_datetime(inp_data_f['DateOfDeath'])
_discharge_dt = pd.to_datetime(inp_data_f['HOSP_ddt'])
inp_data_f['gt_m'] = np.where((_death_dt >= inp_data_f['HOSP_adt']) & (_death_dt <= _discharge_dt), 1, 0)
# NOTE(review): gt_dd (next cell) needs DischargeToCode, but the merge that would add
# it is disabled — confirm the reloaded CSV already carries that column.
#inp_data_f = pd.merge(inp_data_f, trak_inp_d[['ppid', 'EpisodeNumber', 'DischargeToCode']],
                     #how='left', on=['ppid', 'EpisodeNumber'])

In [None]:
# gt_dd (non-home discharge): 0 if DischargeToCode is in home_disch_codes, else 1
# (a missing code therefore counts as non-home).
inp_data_f['gt_dd'] = np.where(inp_data_f['DischargeToCode'].isin(home_disch_codes), 0, 1)

In [None]:
# 1 when total_count_rehab is positive; zero or missing counts give 0.
inp_data_f['received_rehab'] = (inp_data_f['total_count_rehab'] > 0).astype(int)

In [None]:
# Outcome prevalence as shares of rows: non-home discharge.
inp_data_f.gt_dd.value_counts(normalize=True)

In [None]:
# In-hospital death.
inp_data_f.gt_m.value_counts(normalize=True)

In [None]:
# Length of stay above the 80th percentile.
inp_data_f.gt_es_hosp.value_counts(normalize=True)

In [None]:
# Any ICU/HDU stay.
inp_data_f.gt_cc.value_counts(normalize=True)

In [None]:
# Any geriatric-medicine movement.
inp_data_f.gt_eld.value_counts(normalize=True)

In [None]:
# -1 marks a geriatric movement within 24 h after admission.
inp_data_f.gt_eld_d1.value_counts(normalize=True)

In [None]:
# -1 marks a geriatric movement within 48 h after admission.
inp_data_f.gt_eld_d2.value_counts(normalize=True)

In [None]:
# -1 marks a geriatric movement within 72 h after admission.
inp_data_f.gt_eld_d3.value_counts(normalize=True)

In [None]:
inp_data_f.received_rehab.value_counts(normalize=True)

In [None]:
# Spot-check the rows that gt_eld_d1 marks as -1 (movement within 24 h of admission).
inp_data_f[['EpisodeNumber', 'HOSP_adt', 'HOSP_ddt', 'cdt', 'gt_eld', 'gt_eld_d1', 'eld_dist']][(inp_data_f.eld_dist < 24)&(inp_data_f.eld_dist>0)].head(20)

In [None]:
#### Annual distribution
### Set age groups
# Decade bands from 50 upwards. Ages below 50, or missing, get no group (None) and
# so drop out of the group-by summaries, instead of being mislabelled as '90+'
# (which the previous default-'90+' chain did).
_age = inp_data_f['age_at_admission']
inp_data_f['age_gr'] = np.select(
    [(_age >= 50) & (_age < 60),
     (_age >= 60) & (_age < 70),
     (_age >= 70) & (_age < 80),
     (_age >= 80) & (_age < 90),
     _age >= 90],
    ['50-59', '60-69', '70-79', '80-89', '90+'],
    default=None,
)

In [None]:
### Load full cohort
inp_data_full = pd.read_csv('')
# Outcome flags, defined as for inp_data_f.
inp_data_full['gt_cc'] = np.where((inp_data_full['LOS_CC'].isna())|(inp_data_full['LOS_CC']==-1), 0, 1)
inp_data_full['gt_es_hosp'] = np.where(inp_data_full['LOS_hosp'] > inp_data_full.LOS_hosp.quantile(0.8), 1, 0)
# In-hospital death, compared on parsed datetimes rather than raw CSV strings, so the
# result does not depend on every column sharing one lexically sortable format.
_death_dt = pd.to_datetime(inp_data_full['DateOfDeath'])
inp_data_full['gt_m'] = np.where((_death_dt >= pd.to_datetime(inp_data_full['HOSP_adt']))
                                 & (_death_dt <= pd.to_datetime(inp_data_full['HOSP_ddt'])), 1, 0)
# Discharge destination from TRAK; codes outside home_disch_codes (or missing) -> 1.
inp_data_full = pd.merge(inp_data_full, trak_inp_d[['ppid', 'EpisodeNumber', 'DischargeToCode']],
                         how='left', on=['ppid', 'EpisodeNumber'])
inp_data_full['gt_dd'] = np.where(inp_data_full['DischargeToCode'].isin(home_disch_codes), 0, 1)
# Geriatric-medicine movements: keep the earliest admission per patient and, within
# it, the earliest movement (cdt = NaT sorts last).
inp_data_full = (
    pd.merge(inp_data_full, trak_mv[['ppid', 'EpisodeNumber', 'cdt', 'gt_eld']],
             how='left', on=['ppid', 'EpisodeNumber'])
    .sort_values(['ppid', 'HOSP_adt', 'cdt'])
    .drop_duplicates(['ppid'], keep='first')
)
# The flag must come from this frame itself. inp_data_f is a different cohort, and
# assigning its column here would align its values onto unrelated rows by index.
inp_data_full['gt_eld'] = inp_data_full['gt_eld'].fillna(0).astype(np.int8)
inp_data_full['received_rehab'] = np.where(inp_data_full['total_count_rehab']>0, 1, 0)

In [None]:
# Figure-wide font: 12 pt, normal weight, serif family.
plt.rc('font', size=12, weight='normal', family='serif')

In [None]:
#### Annual event rates of each adverse outcome (full cohort)
def _annual_outcome_shares(data, outcome):
    """Return, per admission year, the share of admissions at each level of `outcome`.

    Columns: adm_year, annual (the outcome level), Count, Total (admissions that
    year) and Percentage (Count / Total, rounded to 4 decimals). Rows with a
    missing adm_year or outcome value are dropped by the group-by.
    """
    shares = (data.groupby(['adm_year', outcome]).size()
                  .reset_index(name='Count')
                  .rename(columns={outcome: 'annual'}))
    shares['Total'] = shares.groupby('adm_year')['Count'].transform('sum')
    shares['Percentage'] = round(shares['Count'] / shares['Total'], 4)
    return shares


def _plot_annual_outcome(ax, shares, title, labels, hide_left_ticks=False):
    """Draw one stacked 100% bar per admission year on `ax`, with `labels` as legend text."""
    (shares.pivot_table(index='adm_year', columns='annual', values='Percentage', sort=True)
           .plot(kind='bar', stacked=True, title=title, ax=ax,
                 color=['turquoise', 'darkred', 'darkgreen', 'khaki', 'aquamarine', 'gray']))
    ax.legend(title='Outcome', labels=labels, loc='lower right')
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    ax.set_xlabel('')
    ax.set_ylabel('% of admissions')
    ax.tick_params(axis='x', labelrotation=45)
    if hide_left_ticks:
        ax.tick_params(left=False)


inp_data_full['adm_year'] = pd.to_datetime(inp_data_full['HOSP_adt']).dt.year
# gt_es_hosp is "LOS above the 80th percentile", so the legend quotes that cut-off
# rather than a hard-coded 14 days. Assumes LOS_hosp is measured in days — confirm.
_los_cut = inp_data_full['LOS_hosp'].quantile(0.8)
_panels = [
    ('gt_m', 'In-hospital death.', ['Survived', 'Died']),
    ('gt_cc', 'Admitted to ICU/HDU', ['No', 'Yes']),
    ('gt_es_hosp', 'Prolonged hospital stay', [f'<= {_los_cut:g} days', f'> {_los_cut:g} days']),
    ('gt_dd', 'Discharge disposition', ['Home discharge', 'Non-home discharge']),
    ('gt_eld', 'Geriatric Medicine services', ['No', 'Yes']),
    ('received_rehab', 'Received rehabilitation', ['No', 'Yes']),
]

# The previous per-plot figsize=(8, 8) resized the figure anyway, so set it once here.
fig, ax = plt.subplots(3, 2, figsize=(8, 8), sharey=True)
plt.suptitle('Annual event rate of adverse outcomes across the older population.')
fig.supxlabel('Year of emergency admission')
for _i, (_outcome, _title, _labels) in enumerate(_panels):
    events_long = _annual_outcome_shares(inp_data_full, _outcome)
    if _outcome == 'gt_m':
        print(events_long)
    # ax.flat is row-major, so odd positions are the right-hand column; it shares
    # the left column's y axis, so only there are the y tick marks hidden.
    _plot_annual_outcome(ax.flat[_i], events_long, _title, _labels, hide_left_ticks=(_i % 2 == 1))
plt.tight_layout()
plt.show()

In [None]:
#### Age-group mix within each SIMD quintile, among patients with each adverse outcome
# For each outcome flag: keep rows with the flag set and a known SIMD quintile
# (simd_quint > -1), count rows per (simd_quint, age_gr), and turn the counts into
# the share of each age group within its quintile. Despite its name, the
# '# Diagnoses' column holds row counts per cell.
inp_data_fm = inp_data_f[(inp_data_f.gt_m==1)&(inp_data_f.simd_quint>-1)]
events_long = pd.melt(inp_data_fm, id_vars=['simd_quint'], value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
#ax[0][0].yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _:'{:.0%}'.format(y)))
#### Plot
# NOTE(review): each pandas .plot call below passes figsize=(7,8) together with ax=,
# which resizes this figure, so figsize=(12,8) here is likely overridden — confirm.
fig, ax = plt.subplots(3, 2, figsize=(12,8), sharey=True)
plt.suptitle('Sociodemographic characteristics in patients with adverse outcome.')
fig.supxlabel('SIMD (1 - most deprived, 5 - least deprived)')
# Panel [0][0]: in-hospital death. Its legend is suppressed; the age-group legend
# for the whole figure is drawn on panel [0][1].
ax[0][0] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='In-hospital death',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[0][0], legend=False,
                                                     color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
#ax[0][0].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         #loc='lower right')
ax[0][0].set_xlabel('')
ax[0][0].set_ylabel('')
ax[0][0].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# Panel [0][1]: ICU/HDU admission (gt_cc); carries the shared age-group legend.
events_long = pd.melt(inp_data_f[(inp_data_f['gt_cc']==1)&(inp_data_f.simd_quint>-1)], id_vars=['simd_quint'],
                      value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
ax[0][1] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='ICU/HDU admission',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[0][1],
                                                         color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
# NOTE(review): these labels are positional; they match the bars only if all five
# age groups occur among these patients — confirm.
ax[0][1].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         bbox_to_anchor=(1,1))
ax[0][1].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
ax[0][1].tick_params(left=False)
ax[0][1].set_xlabel('')

# Panel [1][0]: extended hospital stay (gt_es_hosp).
events_long = pd.melt(inp_data_f[(inp_data_f['gt_es_hosp']==1)&(inp_data_f.simd_quint>-1)], id_vars=['simd_quint'],
                      value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
ax[1][0] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='Extended hospital stay',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[1][0], legend=False,
                                                         color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
#ax[1][0].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         #loc='lower right')
ax[1][0].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
ax[1][0].tick_params(left=False)
ax[1][0].set_xlabel('')


# Panel [1][1]: non-home discharge (gt_dd).
events_long = pd.melt(inp_data_f[(inp_data_f['gt_dd']==1)&(inp_data_f.simd_quint>-1)], id_vars=['simd_quint'],
                      value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
ax[1][1] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='Non-home discharge',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[1][1], legend=False,
                                                         color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
#ax[1][1].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         #loc='lower right')
ax[1][1].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
ax[1][1].tick_params(left=False)
ax[1][1].set_xlabel('')

# Panel [2][0]: geriatric services (gt_eld).
events_long = pd.melt(inp_data_f[(inp_data_f['gt_eld']==1)&(inp_data_f.simd_quint>-1)], id_vars=['simd_quint'],
                      value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
ax[2][0] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='Geriatric services',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[2][0], legend=False,
                                                         color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
#ax[1][1].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         #loc='lower right')
ax[2][0].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
ax[2][0].tick_params(left=False)
ax[2][0].set_xlabel('')
ax[2][0].set_ylabel('')
events_long = pd.melt(inp_data_f[(inp_data_f['received_rehab']==1)&(inp_data_f.simd_quint>-1)], id_vars=['simd_quint'],
                      value_vars=['age_gr'], value_name='age_simd')
events_long = events_long.groupby(['simd_quint', 'age_simd']).size().reset_index(name='# Diagnoses')
events_y_counts = events_long.groupby('simd_quint')['# Diagnoses'].apply(lambda x: x.sum()).reset_index().rename(columns={'# Diagnoses':'Overall patients'})
events_long = events_long.merge(events_y_counts, how='left', on='simd_quint')
events_long['Percentage'] = round(events_long['# Diagnoses'] / events_long['Overall patients'], 4)
ax[2][1] = pd.pivot_table(events_long[['simd_quint', 'age_simd', 'Percentage']], columns=['age_simd'],
                    index=['simd_quint'], sort=True).plot(title='Received rehabilitation',
                                                     kind='bar',
                                                     figsize=(7,8),
                                                     stacked=True, ax=ax[2][1], legend=False,
                                                         color=['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000'])
#ax[1][1].legend(title='Age groups', labels=['50-59', '60-69', '70-79', '80-89', '90+'],
         #loc='lower right')
ax[2][1].yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
ax[2][1].tick_params(left=False)
ax[2][1].set_xlabel('')
ax[2][1].set_ylabel('')

ind = 0
for i,j in ax:
    plt.sca(i)
    plt.xticks(rotation=0, ha='center')
    #plt.xlabel('SIMD')
    #plt.ylabel('% of patients with outcome')
    plt.sca(j)
    plt.xticks(rotation=0, ha='center')
    #plt.xlabel('SIMD')
    #plt.ylabel('% of patients with outcome')
plt.tight_layout()
plt.show()

In [None]:
#### Age-group mix by sex among patients with each adverse outcome
# Each bar shows, for the patients of one sex who had the outcome, the share that fall
# into each age band (every bar sums to 100%).
AGE_BAND_COLOURS = ['#fef0d9', '#fdcc8a', '#fc8d59', '#e34a33', '#b30000']


def age_mix_by_group(data, group_col):
    """Return the within-group age-band shares of ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        Patient-level rows; must contain ``group_col`` and ``age_gr``.
    group_col : str
        Column whose levels become the bars (e.g. ``'Sex'``).

    Returns
    -------
    pd.DataFrame
        Index: levels of ``group_col`` (sorted); columns: age bands (sorted);
        values: fraction of the group's rows in that age band, rounded to 4 d.p.
        Age bands absent from a group are NaN.
    """
    counts = (data.groupby([group_col, 'age_gr'], observed=True)
                  .size()
                  .reset_index(name='n_patients'))
    group_totals = counts.groupby(group_col)['n_patients'].transform('sum')
    counts['Percentage'] = (counts['n_patients'] / group_totals).round(4)
    return counts.pivot_table(index=group_col, columns='age_gr', values='Percentage', sort=True)


# Patients who died in hospital; kept as a module-level name (it may be referenced elsewhere).
inp_data_fm = inp_data_f[(inp_data_f.gt_m==1)]

# (outcome flag, panel title, subplot position). Every panel uses all patients with the
# outcome: an unknown SIMD code is irrelevant to a sex breakdown, so no SIMD filter is
# applied (previously only the non-home-discharge panel dropped simd_quint == -1).
SEX_PANELS = [
    ('gt_m', 'In-hospital death', (0, 0)),
    ('gt_cc', 'ICU/HDU admission', (0, 1)),
    ('gt_es_hosp', 'Extended hospital stay', (1, 0)),
    ('gt_dd', 'Non-home discharge', (1, 1)),
    ('gt_eld', 'Admission to MoE', (2, 0)),
]

fig, ax = plt.subplots(3, 2, figsize=(7, 8), sharey=True)
plt.suptitle('Demographic characteristics in patients with adverse outcomes.')
fig.supxlabel('Sex')
for outcome_col, panel_title, (row, col) in SEX_PANELS:
    panel_ax = ax[row][col]
    shares = age_mix_by_group(inp_data_f[inp_data_f[outcome_col] == 1], 'Sex')
    shares.plot(kind='bar', stacked=True, ax=panel_ax, legend=False, rot=0,
                title=panel_title, color=AGE_BAND_COLOURS)
    panel_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    panel_ax.set_xlabel('')
    if (row, col) != (0, 0):
        panel_ax.tick_params(left=False)

# Single legend for the whole figure, beside the top-right panel. The labels come from the
# age-band column names, so they always match the bar colours.
ax[0][1].legend(title='Age groups', bbox_to_anchor=(1, 1))
ax[2][1].axis('off')
plt.tight_layout()
plt.show()

#### Define continuous intensity outcomes

In [None]:
#### Continuous outcomes
# Number of distinct disciplines with at least one recorded contact, counted from the
# per-discipline n_* columns (vectorized instead of a row-wise apply).
discipline_count_cols = ['n_PT', 'n_OT', 'n_SLT', 'n_GMD', 'n_NURSE']
inp_data_f['total_n_disciplines'] = (inp_data_f[discipline_count_cols] > 0).sum(axis=1)
# Floor at 1 -- presumably so the Box-Cox transform of this column (which needs strictly
# positive input) can be applied later in the notebook.
inp_data_f['total_n_disciplines'] = inp_data_f['total_n_disciplines'].clip(lower=1)
for outcome_col in ['total_count_all', 'total_count_ooh_all', 'total_mins_all', 'total_n_disciplines']:
    print(inp_data_f[outcome_col].describe())

In [None]:
# Global font settings for all subsequent figures: 12pt, normal weight, serif family.
plt.rc('font', size=12, weight='normal', family='serif')

In [None]:
#### Box-Cox transform of total contacts, lambda fixed at 0 (i.e. natural log)
# NOTE(review): unlike the minutes / out-of-hours transforms there is no +1 offset here, so this
# assumes total_count_all >= 1 for every patient (a zero count has no finite log) -- confirm.
tf_data = stats.boxcox(inp_data_f['total_count_all'], lmbda=0.0)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# kdeplot(fill=True) draws the same filled density as the deprecated distplot(hist=False, kde=True).
sns.kdeplot(x=inp_data_f['total_count_all'], fill=True, linewidth=2,
            label='Original', color='crimson', ax=ax[0])
ax[0].set_title("Original health contacts data")
ax[0].set_xlabel('# Contacts')
sns.kdeplot(x=tf_data, fill=True, linewidth=2, label='BCT', color='darkblue', ax=ax[1])
ax[1].set_title("Transformed health contacts data")
ax[1].set_xlabel('log(# Contacts)')
plt.show()

### Normality test on the original and transformed values
for data_name, values in (('original', inp_data_f['total_count_all']), ('transformed', tf_data)):
    stat, p = stats.shapiro(values)
    print(f"Shapiro-Wilk normality test for {data_name} data: stat={stat}, p={p}")
inp_data_f['total_count_all_tf'] = tf_data

In [None]:
#### Box-Cox transform of cumulative contact minutes: lambda fixed at 0 with a +1 offset (= log1p)
tf_data = stats.boxcox(inp_data_f['total_mins_all'] + 1, lmbda=0.0)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# kdeplot(fill=True) draws the same filled density as the deprecated distplot(hist=False, kde=True).
sns.kdeplot(x=inp_data_f['total_mins_all'], fill=True, linewidth=2,
            label='Original', color='crimson', ax=ax[0])
ax[0].set_title("Original health contacts data")
ax[0].set_xlabel('Cumulative contact minutes')
sns.kdeplot(x=tf_data, fill=True, linewidth=2, label='BCT', color='darkblue', ax=ax[1])
ax[1].set_title("Transformed health contacts data")
ax[1].set_xlabel('log(1 + cumulative contact minutes)')
plt.show()

### Normality test on the original and transformed values
for data_name, values in (('original', inp_data_f['total_mins_all']), ('transformed', tf_data)):
    stat, p = stats.shapiro(values)
    print(f"Shapiro-Wilk normality test for {data_name} data: stat={stat}, p={p}")
inp_data_f['total_mins_all_tf'] = tf_data

In [None]:
#### Box-Cox transform of out-of-hours contacts: lambda fixed at 0 with a +1 offset (= log1p)
tf_data = stats.boxcox(inp_data_f['total_count_ooh_all'] + 1, lmbda=0.0)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# kdeplot(fill=True) draws the same filled density as the deprecated distplot(hist=False, kde=True).
sns.kdeplot(x=inp_data_f['total_count_ooh_all'], fill=True, linewidth=2,
            label='Original', color='crimson', ax=ax[0])
ax[0].set_title("Original health contacts data")
ax[0].set_xlabel('# Out-of-hours contacts')
sns.kdeplot(x=tf_data, fill=True, linewidth=2, label='BCT', color='darkblue', ax=ax[1])
ax[1].set_title("Transformed health contacts data")
ax[1].set_xlabel('log(1 + # out-of-hours contacts)')
plt.show()

### Normality test on the original and transformed values
for data_name, values in (('original', inp_data_f['total_count_ooh_all']), ('transformed', tf_data)):
    stat, p = stats.shapiro(values)
    print(f"Shapiro-Wilk normality test for {data_name} data: stat={stat}, p={p}")
inp_data_f['total_count_ooh_all_tf'] = tf_data

In [None]:
#### Box-Cox transform of the number of distinct disciplines (lambda estimated by scipy)
tf_data, ld = stats.boxcox(inp_data_f['total_n_disciplines'])
print(f"Lambda value used for transformation: {ld}")

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# kdeplot(fill=True) draws the same filled density as the deprecated distplot(hist=False, kde=True).
sns.kdeplot(x=inp_data_f['total_n_disciplines'], fill=True, linewidth=2,
            label='Original', color='crimson', ax=ax[0])
ax[0].set_title("Original health contacts data")
ax[0].set_xlabel('# Distinct disciplines')
sns.kdeplot(x=tf_data, fill=True, linewidth=2, label='BCT', color='darkblue', ax=ax[1])
ax[1].set_title("Transformed health contacts data")
ax[1].set_xlabel('Box-Cox transformed # distinct disciplines')
plt.show()

### Normality test on the original and transformed values
for data_name, values in (('original', inp_data_f['total_n_disciplines']), ('transformed', tf_data)):
    stat, p = stats.shapiro(values)
    print(f"Shapiro-Wilk normality test for {data_name} data: stat={stat}, p={p}")
# The transformed disciplines count is only inspected here; it is not stored in inp_data_f.

In [None]:
#### Inverse-transform sanity check for the total-contacts transform
# Invert the stored lambda = 0 transform of total_count_all and compare with the raw column.
# The stored column is used explicitly because `tf_data` was last overwritten by the
# fitted-lambda transform of total_n_disciplines, which cannot be inverted with lambda = 0.
og_data_rec = special.inv_boxcox(inp_data_f['total_count_all_tf'], 0.0)
print(np.allclose(og_data_rec, inp_data_f['total_count_all']))

In [None]:
#### Plot: total contacts among patients seen by each discipline, died vs survived
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6), sharey=True)
plt.suptitle('Distribution of health contacts in those with and without in-hospital death.')
discipline_panels = (
    (ax1, 'n_NURSE', 'Nursing'),
    (ax2, 'n_PT', 'Physiotherapy'),
    (ax3, 'n_OT', 'Occupational Therapy'),
    (ax4, 'n_SLT', 'Speech and Language Therapy'),
)
death_groups = ((1, 'Died', 'crimson'), (0, 'Survived', 'turquoise'))
for panel_ax, discipline_col, panel_title in discipline_panels:
    has_discipline = inp_data_f[discipline_col] > 0
    # Died is drawn before Survived: the label-only legend call pairs labels with artists in draw order.
    for gt_value, group_label, group_colour in death_groups:
        sns.histplot(data=inp_data_f[has_discipline & (inp_data_f['gt_m'] == gt_value)],
                     x='total_count_all', element='step', stat='percent',
                     log_scale=True, bins=15, ax=panel_ax,
                     label=group_label, alpha=0.4, color=group_colour)
    panel_ax.legend(title='Outcome', labels=['Died', 'Survived'], loc='best')
    panel_ax.set_title(panel_title)
# The y-axis is shared (sharey=True), so this percent formatter is shared by every panel.
ax4.yaxis.set_major_formatter("{x:1.0f}%")

ind = 0
for a in (ax1, ax2, ax3, ax4):
    plt.sca(a)
    if ind % 2 == 1:  # right-hand column: hide duplicate y tick marks
        plt.tick_params(left=False)
    plt.xlabel('Log-transformed total health contacts' if ind >= 2 else '')
    plt.ylabel('Percentage')
    ind += 1

plt.tight_layout()
plt.show()

In [None]:
#### Plot: total contacts among patients seen by each discipline, extended stay vs not
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6), sharey=True)
plt.suptitle('Distribution of health contacts in those with and without extended stay.')
# Raw strings: '\g' and '\l' are invalid escape sequences in ordinary string literals.
# gt_es_hosp == 0 is the under-14-day group ('<14 days'), so it must not be labelled '<=14',
# which would overlap the '>=14 days' group.
long_stay_label = r'$\geq$14 days'
short_stay_label = r'$<$14 days'
stay_groups = ((1, long_stay_label, 'crimson'), (0, short_stay_label, 'turquoise'))
discipline_panels = (
    (ax1, 'n_NURSE', 'Nursing'),
    (ax2, 'n_PT', 'Physiotherapy'),
    (ax3, 'n_OT', 'Occupational Therapy'),
    (ax4, 'n_SLT', 'Speech and Language Therapy'),
)
for panel_ax, discipline_col, panel_title in discipline_panels:
    has_discipline = inp_data_f[discipline_col] > 0
    # Extended stay is drawn first: the label-only legend call pairs labels with artists in draw order.
    for gt_value, group_label, group_colour in stay_groups:
        sns.histplot(data=inp_data_f[has_discipline & (inp_data_f['gt_es_hosp'] == gt_value)],
                     x='total_count_all', element='step', stat='percent',
                     log_scale=True, bins=10, ax=panel_ax,
                     label=group_label, alpha=0.4, color=group_colour)
    panel_ax.legend(title='Outcome', labels=[long_stay_label, short_stay_label], loc='best')
    panel_ax.set_title(panel_title)
# The y-axis is shared (sharey=True), so this percent formatter is shared by every panel.
ax4.yaxis.set_major_formatter("{x:1.0f}%")

for ind, a in enumerate((ax1, ax2, ax3, ax4)):
    if ind % 2 == 1:  # right-hand column: hide duplicate y tick marks
        a.tick_params(left=False)
    a.set_xlabel('Log-transformed total health contacts' if ind >= 2 else '')
    a.set_ylabel('Percentage')

plt.tight_layout()
plt.show()

In [None]:
#### Plot: total contacts among patients seen by each discipline, non-home vs home discharge
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6), sharey=True)
plt.suptitle('Distribution of health contacts in those with and without home discharge.')
panel_specs = [
    (ax1, 'n_NURSE', 'Nursing'),
    (ax2, 'n_PT', 'Physiotherapy'),
    (ax3, 'n_OT', 'Occupational Therapy'),
    (ax4, 'n_SLT', 'Speech and Language Therapy'),
]
discharge_groups = [(1, 'Non-home', 'crimson'), (0, 'Home', 'turquoise')]
for target_ax, discipline, subplot_title in panel_specs:
    seen = inp_data_f[discipline] > 0
    # gt_dd == 1 (non-home) is drawn first; the label-only legend below relies on that draw order.
    for flag, legend_text, shade in discharge_groups:
        subset = inp_data_f[seen & (inp_data_f['gt_dd'] == flag)]
        sns.histplot(data=subset, x='total_count_all', element='step', stat='percent',
                     log_scale=True, bins=10, ax=target_ax, label=legend_text,
                     alpha=0.4, color=shade)
    target_ax.legend(title='Outcome', labels=['Non-home', 'Home'], loc='best')
    target_ax.set_title(subplot_title)
# The y-axis is shared (sharey=True), so this percent formatter is shared by every panel.
ax4.yaxis.set_major_formatter("{x:1.0f}%")

ind = 0
for a in (ax1, ax2, ax3, ax4):
    plt.sca(a)
    if ind % 2 == 1:  # right-hand column: hide duplicate y tick marks
        plt.tick_params(left=False)
    plt.xlabel('Log-transformed total health contacts' if ind >= 2 else '')
    plt.ylabel('Percentage')
    ind += 1

plt.tight_layout()
plt.show()

In [None]:
# Summary statistics of the raw (untransformed) total contact counts.
inp_data_f['total_count_all'].describe()

#### Plot the relationships between age/sex/SIMD and health contacts

In [None]:
#### Age at admission vs total contacts: Pearson correlation with a fitted linear trend
r, p = stats.pearsonr(inp_data_f['total_count_all'], inp_data_f['age_at_admission'])
# Report the computed p-value rather than always printing 'p<0.001'.
p_text = 'p<0.001' if p < 0.001 else f'p={p:.3f}'
plt.figure(figsize=(6,4))
sns.regplot(x=inp_data_f['age_at_admission'], y=inp_data_f['total_count_all'], ci=95,
            scatter_kws=dict(color='darkblue'), line_kws=dict(color='crimson'), order=1)
plt.title('Linear relationship between health contacts and age at admission.')
plt.xlabel('Age')
plt.ylabel('Total health contacts')
# xy must be an ordered (x, y) tuple; a set literal such as {0.05, 0.9} has no guaranteed order.
plt.annotate(f'r = {r:.2f}, {p_text}', xy=(0.05, 0.9), xycoords='axes fraction')
plt.show()

In [None]:
# Distribution of SIMD quintile codes, including -1 (excluded from the SIMD analyses), shown
# as counts instead of dumping the whole column.
inp_data_f['simd_quint'].value_counts(dropna=False).sort_index()

In [None]:
#### Log-transformed total contacts by age group, sex and SIMD quintile
# Pairwise Mann-Whitney U tests (Bonferroni-corrected) are annotated above each panel.
inp_data_f['age_gr'] = pd.Categorical(inp_data_f['age_gr'])
# Cast this frame's own SIMD column; reading inp_data['simd_quint'] instead would only line up
# with inp_data_f through index alignment.
inp_data_f['simd_quint'] = inp_data_f['simd_quint'].astype(np.int8)
# Drop SIMD code -1. .copy() gives an independent frame so the categorical cast below does not
# write into a slice of inp_data_f.
inp_data_simd = inp_data_f[inp_data_f.simd_quint != -1].copy()
inp_data_simd['simd_quint'] = pd.Categorical(inp_data_simd['simd_quint'])

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 4), sharey=True)
ax1 = sns.boxplot(data=inp_data_f, x='age_gr', y='total_count_all_tf', ax=ax1, palette='Blues')
ax1.set_ylabel('Log-transformed total health contacts')
ax1.set_xlabel('Age group')
# The plotted order must match the order given to the Annotator below.
ax2 = sns.boxplot(data=inp_data_f, x='Sex', y='total_count_all_tf', order=['M', 'F'], ax=ax2)
ax2.set_xlabel('Sex')
ax2.set_ylabel('')
ax3 = sns.boxplot(data=inp_data_simd, x='simd_quint', y='total_count_all_tf', ax=ax3, palette='PuOr')
ax3.set_xlabel('SIMD (Quintiles)')
ax3.set_ylabel('')
for a in (ax1, ax2, ax3):
    plt.setp(a.get_xticklabels(), rotation=45, ha='center')
    if a is not ax1:
        a.tick_params(left=False)

gr_pairs = [('50-59', '60-69'), ('50-59', '70-79'), ('50-59', '80-89'), ('50-59', '90+')]
gr_pairs2 = [('M', 'F')]
gr_pairs3 = [(1, 2), (1, 3), (1, 4), (1, 5)]
# NOTE(review): '*' marks p < 0.1 here rather than the conventional p < 0.05 -- confirm intended.
star_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
# Thresholds are passed through the public configure() API instead of the private _pvalue_format.
annot_config = dict(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
                    comparisons_correction='bonferroni', pvalue_thresholds=star_thresholds)

annot = Annotator(ax1, data=inp_data_f, x='age_gr', y='total_count_all_tf',
                  order=inp_data_f.age_gr.cat.categories, pairs=gr_pairs)
annot.configure(**annot_config)
annot.apply_and_annotate()

annot2 = Annotator(ax2, data=inp_data_f, x='Sex', y='total_count_all_tf', order=['M', 'F'],
                   pairs=gr_pairs2)
annot2.configure(**annot_config)
annot2.apply_and_annotate()

annot3 = Annotator(ax3, data=inp_data_simd, x='simd_quint', y='total_count_all_tf',
                   order=[1, 2, 3, 4, 5], pairs=gr_pairs3)
annot3.configure(**annot_config)
annot3.apply_and_annotate()
plt.show()

In [None]:
fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(7, 8), sharey=True)
inp_data_vis = inp_data_f.copy()
# Recode 0/1 outcome flags as readable group labels for the x axes.
for flag_col in ['gt_m', 'gt_cc', 'gt_eld', 'received_rehab']:
    inp_data_vis[flag_col] = np.where(inp_data_vis[flag_col]==0, 'N', 'Y')
# gt_dd == 1 marks non-home discharge, so 'Home discharge' is its negation (Y when gt_dd == 0).
inp_data_vis['gt_dd'] = pd.Categorical(np.where(inp_data_vis['gt_dd']==0, 'Y', 'N'), categories=['Y', 'N'])
inp_data_vis['gt_es_hosp'] = np.where(inp_data_vis['gt_es_hosp']==0, '<14 days', '>=14 days')
fig.suptitle('Distribution of health contacts by outcome group.')
# One shared y label for the whole figure; per-panel y labels are blanked below.
fig.supylabel('Log-transformed total health contacts')

outcome_panels = [
    (ax1, 'gt_m', 'In-hospital death'),
    (ax2, 'gt_cc', 'ICU/HDU admission'),
    (ax3, 'gt_es_hosp', 'Extended hospital stay'),
    (ax4, 'gt_dd', 'Home discharge'),
    (ax5, 'gt_eld', 'Geriatric Medicine services'),
    (ax6, 'received_rehab', 'Received rehabilitation'),
]
for panel_ax, outcome_col, x_label in outcome_panels:
    sns.boxplot(data=inp_data_vis, x=outcome_col, y='total_count_all_tf', ax=panel_ax, palette='deep')
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel('')
    # Only the left-column panels (ax1, ax3, ax5) keep their y tick marks.
    if panel_ax not in (ax1, ax3, ax5):
        panel_ax.tick_params(left=False)

gr_pairs = [('N', 'Y')]
gr_pairs2 = [('<14 days', '>=14 days')]
# NOTE(review): '*' marks p < 0.1 here rather than the conventional p < 0.05 -- confirm intended.
star_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
# Thresholds are passed through the public configure() API instead of the private _pvalue_format.
annot_config = dict(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
                    comparisons_correction='bonferroni', pvalue_thresholds=star_thresholds)

annot = Annotator(ax1, data=inp_data_vis, x='gt_m', y='total_count_all_tf', order=['N', 'Y'],
                  pairs=gr_pairs)
annot.configure(**annot_config)
annot.apply_and_annotate()

annot2 = Annotator(ax2, data=inp_data_vis, x='gt_cc', y='total_count_all_tf', order=['N', 'Y'],
                   pairs=gr_pairs)
annot2.configure(**annot_config)
annot2.apply_and_annotate()

annot3 = Annotator(ax3, data=inp_data_vis, x='gt_es_hosp', y='total_count_all_tf', order=['<14 days', '>=14 days'],
                              pairs=gr_pairs2)
annot3.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
               comparisons_correction='bonferroni')
annot3._pvalue_format.pvalue_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
annot3.apply_and_annotate()

annot4 = Annotator(ax4, data=inp_data_vis, x='gt_dd', y='total_count_all_tf', order=['N', 'Y'],
                              pairs=gr_pairs)
annot4.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
               comparisons_correction='bonferroni')
annot4._pvalue_format.pvalue_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
annot4.apply_and_annotate()

annot5 = Annotator(ax5, data=inp_data_vis, x='gt_eld', y='total_count_all_tf', order=['N', 'Y'],
                              pairs=gr_pairs)
annot5.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
               comparisons_correction='bonferroni')
annot5._pvalue_format.pvalue_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
annot5.apply_and_annotate()

annot6 = Annotator(ax6, data=inp_data_vis, x='received_rehab', y='total_count_all_tf', order=['N', 'Y'],
                              pairs=gr_pairs)
annot6.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=2,
               comparisons_correction='bonferroni')
annot6._pvalue_format.pvalue_thresholds = [[0.001, '***'], [0.01, '**'], [0.1, '*'], [1, 'ns']]
annot6.apply_and_annotate()

plt.tight_layout()
plt.show()

In [None]:
#### Export inp_data_f to CSV, without the DataFrame index.
# NOTE(review): the output path is blank in this copy — set it before running.
inp_data_f.to_csv(path_or_buf='', index=False)