In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from statistics import mean, median
import os
import seaborn as sns
import matplotlib.dates as mdates
import matplotlib.ticker as mtick
import scipy.stats as stats
from scipy.special import gammaln
import gc

### ML packages
from sklearn.metrics import f1_score, mean_squared_error, mean_absolute_error, make_scorer, accuracy_score, balanced_accuracy_score, cohen_kappa_score, confusion_matrix, classification_report
from confidenceinterval import roc_auc_score, ppv_score, npv_score, tnr_score, tpr_score
import confidenceinterval as cfi
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.isotonic import IsotonicRegression
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold
from sklearn.preprocessing import KBinsDiscretizer
from xgboost import XGBClassifier, XGBRegressor
import xgboost as xgb
from xgboost import cv
import shap
import ml_insights as mli

##### Load care intensity training data

In [None]:
# Columns kept aside as lookup / ground-truth fields; the loading cell drops all
# of them from the model feature matrix (this includes the regression target
# 'total_count_all_tf').
lkup_fields = [
    'ppid',
    'EpisodeNumber',
    'AdmissionDate',
    'ED_adate_dt',
    'IndexAttDate',
    'HOSP_adt',
    'DischargeDate',
    'DateOfDeath',
    'HOSP_ddt',
    'breq_dt',
    'HOSP_FCC_dt',
    'HOSP_FAS_dt',
    'gt_m',
    'gt_cc',
    'gt_es_hosp',
    'gt_dd',
    'gt_eld',
    'gt_eld_d1',
    'gt_eld_d2',
    'gt_eld_d3',
    'gt_rehab',
    'total_count_all',
    'total_count_rehab',
    'total_count_all_tf',
    'total_n_disciplines',
    'total_count_ooh_all',
    'total_n_disciplines_gr',
    'age_gr',
    'total_count_cts_gr',
]

In [None]:
# Mapping from raw model feature (column) names to human-readable labels used
# when plotting (e.g. passed to DataFrame.rename before the SHAP summary plot).
# Each key appears exactly once: earlier revisions repeated 'urea_v',
# 'basophil_count_rm', 'c-reactive_prot_rs' and
# 'trQ_mobility_bathing_INDEPENDENT', where the last occurrence silently won.
# NOTE(review): most '*_INDEPENDENT' features are labelled '... dependence';
# confirm whether '... independence' is intended for them as well.
feature_names = {
    'trQ_waterlow_score': 'Waterlow score',
    'AgeAtAdmission': 'Age',
    'arrival_mode_B': 'Arrival - NHSL Bus',
    'arrival_mode_E': 'Arrival - Emergency Ambulance',
    'arrival_mode_O': 'Arrival - Other',
    'arrival_mode_PU': 'Arrival - Public Transport',
    'arrival_mode_U': 'Arrival - GP Ambulance',
    'arrival_mode_Unk': 'Arrival - Unknown',
    'arrival_mode_W': 'Arrival - Walked',
    'simd_dec': 'SIMD (most to least deprived)',
    'arrival_mode_PR': 'Arrival - Private Transport',
    'trQ_bwm_urinary_catheterisation': 'Urinary Catheterisation',
    'trQ_bwm_urinary_incontinence': 'Urinary Incontinence',
    'trQ_bwm_dysuria': 'Dysuria',
    'trQ_bwm_>6times_per_day': 'Bowel Movement >6 times per day',
    'trQ_bwm_nocturia_>2_per_night': 'Nocturia >2 per night',
    'trQ_bwm_faeces_incontinence': 'Faeces Incontinence',
    'trQ_bwm_constipation': 'Constipation',
    'trQ_bwm_diarrhoea': 'Diarrhoea',
    'trQ_bwm_blood_in_stools': 'Blood in stools',
    'trQ_bwm_medication': 'Bowel movement medication',
    'trQ_falls_within_6_months': 'Fall within last 6 months',
    'trQ_falls_clinical_risk': 'At clinical risk of falls',
    'trQ_nutr_food_allergies': 'Food allergies',
    'trQ_nutr_swallowing_difficulty': 'Swallowing difficulty',
    'trQ_mrsa_infection_prevention': 'Infection prevention measures',
    'trQ_mrsa_transfer_with_norovirus': 'MRSA Norovirus',
    'trQ_mrsa_resp_or_fever': 'MRSA with Respiratory issues or Fever',
    'trQ_mrsa_rash_fever_or_flu': 'MRSA with Rash, Fever or Flu',
    'trQ_mrsa_infectious_diseases_contact': 'MRSA contact with infection diseases',
    'trQ_rub_nursing_falls_risk_assessment': 'Nursing Falls risk assessment',
    'trQ_rub_at_risk_of_bed_fall': 'At risk of bed fall',
    'trQ_MUST_score': 'MUST Score',
    'trQ_mobility_walking_ASSISTANCE': 'Walking assistance',
    'trQ_mobility_walking_BED_REST': 'Walking (Bed rest)',
    'trQ_mobility_walking_INDEPENDENT': 'Walking dependence',
    'trQ_mobility_toileting_ASSISTANCE': 'Toileting assistance',
    'trQ_mobility_toileting_BED_REST': 'Toileting (Bed rest)',
    'trQ_mobility_toileting_INDEPENDENT': 'Toileting dependence',
    'trQ_mobility_bathing_ASSISTANCE': 'Bathing assistance',
    'trQ_mobility_bathing_BED_REST': 'Bathing (Bed rest)',
    # Effective label: the former duplicate entry near the end of the dict
    # overrode the original 'Bathing dependence'.
    'trQ_mobility_bathing_INDEPENDENT': 'Mobility (Bathing independence)',
    'trQ_mobility_bed_rolling_ASSISTANCE': 'Rolling in bed assistance',
    'trQ_mobility_bed_rolling_INDEPENDENT': 'Rolling in bed dependence',
    'trQ_mobility_bed_moveup_ASSISTANCE': 'Moving up bed assistance',
    'trQ_mobility_bed_moveup_INDEPENDENT': 'Moving up bed dependence',
    'trQ_mobility_bed_out_ASSISTANCE': 'Moving out of bed assistance',
    'trQ_mobility_bed_out_BED_REST': 'Moving out of bed (Bed rest)',
    'trQ_mobility_bed_out_INDEPENDENT': 'Moving out of bed dependence',
    'trQ_mobility_bed_in_ASSISTANCE': 'Moving in bed assistance',
    'trQ_mobility_bed_in_BED_REST': 'Moving in bed (Bed rest)',
    'trQ_mobility_bed_in_INDEPENDENT': 'Moving in bed dependence',
    'trQ_mobility_sss_ASSISTANCE': 'Sit-stand-sit assistance',
    'trQ_mobility_sss_BED_REST': 'Sit-stand-sit (Bed rest)',
    'trQ_mobility_sss_INDEPENDENT': 'Sit-stand-sit dependence',
    'trQ_mobility_lateral_ASSISTANCE': 'Lateral movement assistance',
    'trQ_mobility_lateral_BED_REST': 'Lateral movement (Bed rest)',
    'trQ_mobility_lateral_INDEPENDENT': 'Lateral movement dependence',
    'trQ_mobility_floorup_ASSISTANCE': 'Floor-up movement assistance',
    'trQ_mobility_floorup_BED_REST': 'Floor-up movement (Bed rest)',
    'trQ_mobility_floorup_INDEPENDENT': 'Floor-up movement dependence',
    'num_inp_attendances_lyr': 'Scheduled inpatient attendances last year',
    'total_longterm_conditions': '# unique long-term conditions',
    'num_outp_att_CB': 'Outpatient visits (Urology)',
    'lactate_v': 'Lactate (mmol/L) - last value',
    'lactate_rm': 'Lactate (mmol/L) - moving average',
    'dsl_outp_att': 'Last outpatient attendance (days)',
    'haemoglobin_nl': 'Haemoglobin - low',
    'dsl_physltc_pulmonary_fibrosis': 'Pulmonary fibrosis (days)',
    'hba1c_(ifcc)_rs': 'HbA1c (IFCC, mmol/mol) - moving std',
    'urea_v': 'Urea (mmol/L) - last value',
    'dsl_antipsychotics': 'Antipsychotics (days since last)',
    'red_cell_count_nl': 'Red Cell Count - low',
    'bilirubin_nh': 'Bilirubin - high',
    'num_outp_att_AR': 'Outpatient visits (Rheumatology)',
    'num_outp_att_F2': 'Outpatient visits (Gynaecology)',
    'hba1c_(ifcc)_v': 'HbA1c (IFCC, mmol/mol) - last value',
    'c-reactive_prot_nh': 'CRP - high',
    'n_presc_anticoagulant_protamine_drugs': 'Anticoagulants and protamine (# prescribed)',
    'num_outp_att_G1': 'Outpatient visits (General Psychiatry)',
    'ggt_v': 'GGT (U/L) - last value',
    'num_outp_att_C11': 'Outpatient visits (General Surgery)',
    'dsl_antidementia_drugs': 'Antidementia drugs (days since last)',
    'dsl_anti_hypertension_hf_drugs': 'Antihypertensive drugs (days since last)',
    'num_outp_att_G4': 'Outpatient visits (Psychiatry Of Old Age)',
    'dsl_antidepressant_drugs': 'Antidepressant drugs (days since last)',
    'num_outp_att_A1': 'Outpatient visits (General Medicine)',
    'phys_men_multimorbidity': 'Physical-mental multimorbidity',
    'num_outp_att_C3': 'Outpatient visits (Anaesthetics)',
    'hba1c_(ifcc)_nh': 'HbA1c (IFCC) - high',
    'n_presc_nausea_vertigo_drugs': 'Nausea and vertigo drugs (# prescribed)',
    'albumin_nl': 'Albumin - low',
    'total_menlongterm_conditions': '# Unique mental chronic conditions',
    'num_outp_att_AB': 'Outpatient visits (Geriatric Medicine)',
    'num_outp_att_R5': 'Outpatient visits (Physiotherapy)',
    'dsl_physltc_prog_neur_disease': 'Progressive neurological disease (days)',
    'ast_rm': 'AST (U/L) - moving average',
    'num_inp_attendances': '# Scheduled inpatient attendances',
    'ferritin_nl': 'Ferritin - low',
    'dsl_physltc_arthritis_arthropathy': 'Arthritis or other arthropathy (days)',
    'num_outp_att_C7': 'Outpatient visits (Ophthalmology)',
    'dsl_physltc_heart_failure': 'Heart Failure (days)',
    'num_outp_att_AG': 'Outpatient visits (Renal Medicine)',
    'num_outp_att_A9': 'Outpatient visits (Gastroenterology)',
    'total_drug_categories': '# Unique prescribed drug categories',
    'num_outp_att_A82': 'Outpatient visits (Diabetes)',
    'urea_rs': 'Urea (mmol/L) - moving std',
    'num_outp_att_C8': 'Outpatient visits (Trauma and Orthopaedic Surgery)',
    'ferritin_nh': 'Ferritin - high',
    'bilirubin_nl': 'Bilirubin - low',
    'num_outp_att_A2': 'Outpatient visits (Cardiology)',
    'num_outp_att_C5': 'Outpatient visits (ENT)',
    'n_presc_antidepressant_drugs': 'Antidepressant drugs (# prescribed)',
    'urea_nh': 'Urea - high',
    'num_outp_att_A81': 'Outpatient visits (Endocrine)',
    'dsl_physltc_liver_disease': 'Liver disease (days)',
    'dsl_antiplatelet_drugs': 'Antiplatelet drugs (days since last)',
    'num_inp_att_AG': 'Inpatient visits (Renal Medicine)',
    'white_cell_count_nh': 'White Cell Count - high',
    'ggt_rs': 'GGT (U/L) - moving std',
    'mean_cell_volume_nh': 'MCV - high',
    'dsl_physltc_chronic_renal_disease': 'Chronic renal disease (days)',
    'bilirubin_rs': 'Bilirubin (umol/L) - moving std',
    'c-reactive_prot_rs': 'CRP (mg/L) - moving std',
    'dsl_physltc_stroke': 'Stroke - (days)',
    'dsl_physltc_atrial_fibrillation': 'Atrial fibrillation (days)',
    'dsl_physltc_copd': 'COPD (days)',
    'dsl_physltc_per_vascular_disease': 'Peripheral Vascular Disease (days)',
    'n_presc_parkinsonism_drugs': 'Parkinsonism drugs (# prescribed)',
    'num_outp_att_dna_AG': 'Outpatient failed visits (Renal Medicine)',
    'c-reactive_prot_rm': 'CRP (mg/L) - moving average',
    'n_presc_nitrates_ccb_drugs': 'Nitrates and CCBs (# prescribed)',
    'n_presc_beta_blockers': 'Beta blockers (# prescribed)',
    'ferritin_rs': 'Ferritin (ug/L) - moving std',
    'num_inp_att_C8': 'Inpatient visits (Trauma and Orthopaedic Surgery)',
    'n_presc_antiplatelet_drugs': 'Antiplatelet drugs (# prescribed)',
    'monocyte_count_nh': 'Monocyte count - high',
    'num_outp_att_C9': 'Outpatient visits (Plastic Surgery)',
    'lactate_rs': 'Lactate (mmol/L) - moving std',
    'neutrophil_count_nh': 'Neutrophil count - high',
    'n_presc_diuretics': 'Diuretics (# prescribed)',
    'num_outp_attendances': '# Outpatient visits',
    'dsl_bone_metabolism_affecting_drugs': 'Bone/metabolism-affecting drugs (days since last)',
    'ggt_rm': 'GGT (U/L) - moving average',
    'hba1c_(ifcc)_rm': 'HbA1c (IFCC) - moving average',
    'dsl_menltc_depression': 'Depression (days)',
    'dsl_physltc_hypertension': 'Hypertension (days)',
    'sodium_nl': 'Sodium - low',
    'ggt_nh': 'GGT - high',
    'n_presc_lipid_regulators': 'Lipid regulators (# prescribed)',
    'n_presc_antipsychotics': 'Antipsychotics (# prescribed)',
    'dsl_physltc_inf_bowel_disease': 'Inflammatory Bowel Disease (days)',
    'calcium_nl': 'Calcium - low',
    'dsl_physltc_ischaemic_heart_disease': 'Ischaemic Heart Disease (days)',
    'n_presc_anti_hypertension_hf_drugs': 'Antihypertensive drugs (# prescribed)',
    'n_presc_antidementia_drugs': 'Anti-dementia drugs (# prescribed)',
    'lymphocyte_count_nl': 'Lymphocyte count - low',
    'egfr_(/1.73m2)_nl': 'eGFR (/1.73m2) - low',
    'egfr_(/1.73m2)_v': 'eGFR (/1.73m2) - last value',
    'egfr_(/1.73m2)_rs': 'eGFR (/1.73m2) - moving std',
    'platelet_count_nl': 'Platelet Count - low',
    'dsl_diuretics': 'Diuretics (days since last)',
    'n_presc_bone_metabolism_affecting_drugs': 'Bone/metabolism-affecting drugs (# prescribed)',
    'ferritin_v': 'Ferritin (ug/L) - last value',
    'num_outp_att_AQ': 'Outpatient visits (Respiratory Medicine)',
    'esr_v': 'ESR (mm/hr) - last value',
    'dsl_physltc_epilepsy': 'Epilepsy (days)',
    'num_outp_att_AD': 'Outpatient visits (Medical Oncology)',
    'dsl_physltc_diabetes': 'Diabetes (days)',
    'dsl_menltc_alcohol_substance_misuse': 'Alcohol/substance misuse (days)',
    'dsl_physltc_obesity': 'Obesity (days)',
    'num_outp_att_AP': 'Outpatient visits (Rehabilitation Medicine)',
    'num_outp_att_H2': 'Outpatient visits (Clinical Oncology)',
    'c-reactive_prot_v': 'CRP (mg/L) - last value',
    'ck_v': 'CK (IU/L) - last value',
    'dsl_physltc_asthma': 'Asthma (days)',
    'hs_troponin_i_v': 'HS Troponin I (ng/L) - last value',
    'basophil_count_rm': 'Basophil Count - moving average',
    'num_outp_att_dna_A9': 'Outpatient failed visits (Gastroenterology)',
    'num_outp_att_AH': 'Outpatient visits (Neurology)',
    'alt_nl': 'ALT - low',
    'dsl_menltc_chronic_psychiatry_disorder': 'Chronic Psychiatric Disorder (days)',
    'hs_troponin_i_rm': 'HS Troponin I (ng/L) - moving average',
    'dsl_physltc_osteoporosis': 'Osteoporosis (days)',
    'basophil_count_v': 'Basophil Count - last value',
    'albumin_rm': 'Albumin (g/L) - moving average',
    'creatinine_nl': 'Creatinine - low',
    'ferritin_rm': 'Ferritin (ug/L) - moving average',
    'dsl_physltc_hip_fracture': 'Hip fracture (days)',
    'num_outp_att_dna_C7': 'Outpatient failed visits (Ophthalmology)',
    'creatinine_nh': 'Creatinine - high',
    'alk.phos_nh': 'Alkaline Phosphatase - high',
    'eosinophil_count_nl': 'Eosinophil Count - low',
    'dsl_physltc_historical_or_active_cancer': 'Historical or Active Cancer (days)',
    'bilirubin_rm': 'Bilirubin (umol/L) - moving average',
    'Sex_F': 'Sex (Female)',
    'esr_rm': 'ESR (mm/hr) - moving average',
    'urea_rm': 'Urea (mmol/L) - moving average',
    'ck_rm': 'CK (IU/L) - moving average',
    'triage_code': 'ED triage code',
    'tco2_v': 'tCO2 (mmol/L) - last value',
    'arrival_mode_R': 'Arrival - Routine Ambulance',
    'dsl_inp_att': 'Last scheduled inpatient attendance',
    'trQ_4at': '4AT Score',
    'albumin_v': 'Albumin (g/L) - last value',
    'hs_troponin_t_v': 'HS Troponin T (ng/L) - last value',
    'albumin_rs': 'Albumin (g/L) - moving std',
    'tco2_rs': 'tCO2 (mmol/L) - moving std',
    'bilirubin_v': 'Bilirubin (umol/L) - last value',
    'haematocrit_nl': 'Haematocrit - low',
}

In [None]:
#### Load features while specifying data types for memory efficiency
# Two-column CSV (feature name, dtype); its header row is skipped and replaced
# by the explicit column names below.
dem_types = pd.read_csv('', names=['item', 'dtype'], skiprows=1)
dtype_dict = dict(zip(dem_types['item'], dem_types['dtype']))

base_path = ''
model_path = ''
# NOTE(review): only the validation read passes dtype=dtype_dict; the training
# read lets pandas infer dtypes. Confirm whether that asymmetry is intended.
train_data = pd.read_csv(os.path.join(base_path, ''), low_memory=True)
val_data = pd.read_csv(os.path.join(base_path, ''), low_memory=True, dtype=dtype_dict)

# Sanitise column names: '<' becomes '_below_' and ',' becomes '_'
# (presumably because XGBoost rejects such characters in feature names - confirm).
train_data.columns = [name.replace('<', '_below_').replace(',', '_') for name in train_data.columns]
val_data.columns = [name.replace('<', '_below_').replace(',', '_') for name in val_data.columns]

### Shuffle data when using time-series split
train_data = train_data.sample(frac=1, random_state=42).reset_index(drop=True)
val_data = val_data.sample(frac=1, random_state=42).reset_index(drop=True)

### Lookup fields
train_lkup_cts = train_data[lkup_fields]
val_lkup_cts = val_data[lkup_fields]

### GT fields
train_y_cts = train_data['total_count_all_tf']
val_y_cts = val_data['total_count_all_tf']

### XGBoost features: everything that is not a lookup / ground-truth column
train_x_cts = train_data.drop(list(train_lkup_cts.columns), axis=1)
val_x_cts = val_data.drop(list(val_lkup_cts.columns), axis=1)
print('Training features')
print(train_x_cts.columns.tolist())
print(train_x_cts.shape, val_x_cts.shape, train_y_cts.shape, val_y_cts.shape)

### Create XGBoost objects
train_dm_cts = xgb.DMatrix(train_x_cts, label=train_y_cts)
val_dm_cts = xgb.DMatrix(val_x_cts, label=val_y_cts)

##### Setup hyperparameters

In [None]:
# Defaults picked up by train_optimize_model: stop after this many rounds
# without improvement on the last evaluation set, and cap the total number of
# boosting rounds.
early_stopping = 50
rounds = 20_000
# A class-imbalance weight (pos_weight) was previously computed here; it is not
# used by the regression setup in this notebook.

def negative_binomial_loss(y_true, y_pred, alpha):
    """Negative log-likelihood of counts under a negative binomial model.

    Uses the mean/dispersion parameterisation: ``y_pred`` is the predicted mean
    and ``alpha`` the dispersion, i.e. size ``r = 1/alpha`` and success
    probability ``p = 1 / (1 + alpha * y_pred)``.

    Parameters
    ----------
    y_true : array-like
        Observed counts.
    y_pred : array-like
        Predicted means (must be > 0 for a finite result).
    alpha : float
        Dispersion parameter (> 0).

    Returns
    -------
    float
        Sum over all samples of the negative log-likelihood.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    inv_alpha = 1.0 / alpha
    scaled_mean = alpha * y_pred
    # The first term is gammaln(y + 1/alpha); the previous code had a stray
    # comma that passed 1/alpha as the ufunc's `out` argument instead.
    log_lik = (
        gammaln(y_true + inv_alpha)
        - gammaln(y_true + 1)
        - gammaln(inv_alpha)
        + y_true * np.log(scaled_mean / (1 + scaled_mean))
        + inv_alpha * np.log(1 / (1 + scaled_mean))
    )
    return -np.sum(log_lik)

def negative_binomial_eval(y_pred, dtrain, alpha=2.0):
    """Return per-sample (gradient, hessian) arrays in XGBoost custom-objective form.

    Labels are read from ``dtrain.get_label()``; ``alpha`` is the dispersion
    constant used in both expressions.

    NOTE(review): despite the ``_eval`` suffix this returns grad/hess rather
    than a metric, and the expressions do not obviously match the derivative of
    ``negative_binomial_loss`` with respect to ``y_pred`` - verify the
    derivation before using it as a training objective.
    """
    labels = dtrain.get_label()
    denom = 1 + alpha * y_pred
    gradient = alpha * y_pred / denom - alpha * labels / denom
    hessian = alpha * (1 + alpha * labels) / denom**2
    return gradient, hessian
    
# Default XGBoost booster parameters (used as the default `param_grid` of
# train_optimize_model). Options that were tried but are currently disabled:
# eval_metric='mape', colsample_bytree=.5, alpha=1, lambda=2, subsample=.6 and
# scale_pos_weight=pos_weight (for imbalanced data).
params_def = dict(
    max_depth=3,
    objective='reg:pseudohubererror',
    nthread=25,
    eta=0.01,
)

In [None]:
# Global matplotlib font settings applied to every figure in this notebook.
font_settings = {'font.size': 12, 'font.weight': 'normal', 'font.family': 'serif'}
plt.rcParams.update(font_settings)

##### Train/eval pipeline

In [None]:
def train_optimize_model(train_dm, val_dm, save_path, target, version,
                         param_grid=params_def,
                         es=early_stopping, rounds=rounds):
    """Train an XGBoost model with early stopping, refit it to the best round and save it.

    Parameters
    ----------
    train_dm, val_dm : xgb.DMatrix
        Training and validation matrices; both are monitored during the first fit.
    save_path : str
        Path prefix; the model is written to
        ``save_path + '_' + target + '_' + version + '.model'``.
    target, version : str
        Used for logging and to build the output file name.
    param_grid : dict
        Booster parameters passed to ``xgb.train``.
    es : int
        Early-stopping patience (rounds).
    rounds : int
        Maximum number of boosting rounds for the first fit.

    Returns
    -------
    (xgb.Booster, dict)
        The refitted model and the evaluation history of the first fit.
    """
    print('Training baseline model for target: ', target)
    evals_result = {}
    watchlist = [(train_dm, 'train'), (val_dm, 'validation')]
    model = xgb.train(param_grid, train_dm, num_boost_round=rounds, early_stopping_rounds=es,
                      evals=watchlist, evals_result=evals_result)
    # best_iteration is zero-based, so the best model contains best_iteration + 1
    # trees. The refit previously used num_boost_round=best_iteration, which
    # trained one round fewer than the model whose score is reported here.
    n_best_rounds = model.best_iteration + 1
    print('Best Score: {:.3f} with {} rounds'.format(model.best_score, n_best_rounds))
    print('Refitting model to best iteration...')
    # The refit trains exactly n_best_rounds rounds; early stopping and
    # evaluation sets are not needed because the round count is already fixed.
    best_model = xgb.train(param_grid, train_dm, num_boost_round=n_best_rounds)
    print('Training complete. Storing baseline to disk.')
    best_model.save_model(save_path + '_' + target + '_' + version + '.model')
    return best_model, evals_result

def plot_learning_curve(model, evals_result, metric='mphe'):
    """Plot training vs. validation loss per boosting round.

    Parameters
    ----------
    model : object
        Unused; kept for backward compatibility with existing callers.
    evals_result : dict or None
        Evaluation history as filled in by ``xgb.train`` with the
        ``'train'`` and ``'validation'`` evaluation sets. When it is ``None``
        or empty nothing is plotted.
    metric : str
        Key of the metric to plot inside each evaluation-set entry.

    Returns
    -------
    None
    """
    # Nothing to plot for a model loaded from disk / evaluated without history.
    if not evals_result:
        return
    train_curve = evals_result['train'][metric]
    val_curve = evals_result['validation'][metric]
    x_axis = range(len(train_curve))
    lw = 1.5
    plt.figure(figsize=(6, 6))
    plt.plot(x_axis, train_curve, color='darkorange', lw=lw, label='Training loss')
    plt.plot(x_axis, val_curve, color='navy', lw=lw, label='Validation loss')
    plt.xlabel('# Epochs')
    plt.ylabel('Huber loss')
    plt.title('XGBoost Regression learning curves')
    # The curve labels above are only rendered once a legend is drawn.
    plt.legend()

def bootstrap_metric(labels_true, labels_pred, metric_func, n_iter=1000, random_state=None):
    """Bootstrap the sampling distribution of a metric.

    Draws ``n_iter`` resamples (with replacement, same size as the input) of the
    paired true/predicted values and evaluates ``metric_func`` on each.

    Parameters
    ----------
    labels_true, labels_pred : array-like
        Paired true and predicted values of equal length. They are converted to
        NumPy arrays so resampling is always positional (integer-array indexing
        of a pandas Series is label-based and fails or mis-samples when the
        index is not 0..n-1).
    metric_func : callable
        ``metric_func(sample_true, sample_pred) -> float``.
    n_iter : int
        Number of bootstrap resamples.
    random_state : int or None
        Seed for reproducible resampling. ``None`` (default) draws from NumPy's
        global random state, as before.

    Returns
    -------
    np.ndarray
        Array of length ``n_iter`` with the metric value of each resample.

    Raises
    ------
    ValueError
        If the inputs are empty or have different lengths.
    """
    y_true = np.asarray(labels_true)
    y_pred = np.asarray(labels_pred)
    n = len(y_true)
    if n == 0:
        raise ValueError('Cannot bootstrap an empty sample.')
    if len(y_pred) != n:
        raise ValueError('labels_true and labels_pred must have the same length.')
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    res = np.zeros(n_iter)
    for i in range(n_iter):
        ind = rng.randint(0, n, n)
        res[i] = metric_func(y_true[ind], y_pred[ind])
    return res

def compute_ci(bootstrap_res, ci=0.95):
    """Percentile confidence interval of a bootstrap distribution.

    Returns a length-2 array ``[lower, upper]`` holding the ``(1 - ci) / 2`` and
    ``1 - (1 - ci) / 2`` percentiles of ``bootstrap_res``, rounded to 3 decimals.
    """
    lower_tail = (1 - ci) / 2
    upper_tail = 1 - lower_tail
    bounds = np.percentile(bootstrap_res, [lower_tail * 100, upper_tail * 100])
    return np.round(bounds, 3)

def rmse(labels_true, labels_pred):
    """Root mean squared error between true and predicted values."""
    mse = mean_squared_error(labels_true, labels_pred)
    return np.sqrt(mse)

def mae(labels_true, labels_pred):
    """Mean absolute error between true and predicted values."""
    abs_error = mean_absolute_error(labels_true, labels_pred)
    return abs_error

def mape(labels_true, labels_pred):
    """Symmetric mean absolute percentage error (sMAPE), in percent.

    Each term is ``2 * |true - pred| / (|true| + |pred|)``. Terms where both the
    true and the predicted value are zero (a perfect prediction with a zero
    denominator) contribute 0 instead of producing NaN.

    Parameters
    ----------
    labels_true, labels_pred : array-like
        Paired true and predicted values.

    Returns
    -------
    float
        Mean of the per-sample terms multiplied by 100.
    """
    y_true = np.asarray(labels_true, dtype=float)
    y_pred = np.asarray(labels_pred, dtype=float)
    numer = 2 * np.abs(y_true - y_pred)
    denom = np.abs(y_true) + np.abs(y_pred)
    # denom == 0 only when both values are 0, in which case numer is 0 as well.
    terms = np.divide(numer, denom, out=np.zeros_like(numer), where=denom != 0)
    return np.mean(terms) * 100

def mape_c(labels_true, labels_pred):
    """Mean absolute percentage error (in percent) over samples with a non-zero true value."""
    nonzero = labels_true != 0
    true_nz = labels_true[nonzero]
    relative_error = np.abs((true_nz - labels_pred[nonzero]) / true_nz)
    return np.mean(relative_error) * 100

def f1_cs(labels_true, labels_pred):
    """Macro-averaged F1 score (thin wrapper with the (true, pred) metric signature)."""
    macro_f1 = f1_score(labels_true, labels_pred, average='macro')
    return macro_f1

def kappa_cs(labels_true, labels_pred):
    """Quadratically weighted Cohen's kappa (thin wrapper with the (true, pred) metric signature)."""
    weighted_kappa = cohen_kappa_score(labels_true, labels_pred, weights='quadratic')
    return weighted_kappa

def evaluate_model(val_features, labels_val, model, evals_result, 
                  task='Total health contacts', tgt='Total health contacts', tp='ED attendance'):
    """Evaluate a regression model with RMSE, MAE and MAPE plus bootstrap 95% CIs.

    Parameters
    ----------
    val_features : pd.DataFrame
        Validation features; an 'age_gr' column, if present, is dropped first.
    labels_val : array-like
        True target values (converted to a NumPy array before scoring).
    model : xgb.Booster
        Trained model used for prediction.
    evals_result : dict or None
        Training history passed on to ``plot_learning_curve``.
    task : str
        Name used in the log message.
    tgt, tp : str
        Stored in the result under 'target' and 'timepoint'.

    Returns
    -------
    dict
        'timepoint', 'target' and, for each of RMSE / MAE / MAPE, the point
        estimate plus '<metric>-upper' and '<metric>-lower' CI bounds.
    """
    print('Evaluating model for target: ' + task)
    res_dict = {'timepoint': tp, 'target': tgt}
    if 'age_gr' in val_features:
        val_features = val_features.drop('age_gr', axis=1)
    plot_learning_curve(model, evals_result)
    labels_true = np.asarray(labels_val)
    labels_pred_val = model.predict(xgb.DMatrix(val_features))
    #### Get performance measures with 95% CI
    metric_funcs = (('RMSE', rmse), ('MAE', mae), ('MAPE', mape_c))
    intervals = {name: compute_ci(bootstrap_metric(labels_true, labels_pred_val, func))
                 for name, func in metric_funcs}
    for name, func in metric_funcs:
        point = round(func(labels_true, labels_pred_val), 3)
        # compute_ci returns [lower percentile, upper percentile]; the bounds
        # were previously stored the other way round.
        lower, upper = intervals[name]
        print(f'{name}: {point}, 95% CI: {intervals[name]}')
        res_dict[name] = point
        res_dict[f'{name}-upper'] = upper
        res_dict[f'{name}-lower'] = lower
    print('Evaluation complete.')
    return res_dict

def get_shap_feature_importance(val_x, model, out_path, subset='Clinically-supervised model',
                                task='Any dementia diagnosis', 
                                fn='adem_diag', n_feat=20, plot_type='bar', feature_names=None):
    """Show a SHAP summary plot of global feature importance for a tree model.

    Parameters
    ----------
    val_x : pd.DataFrame
        Data on which SHAP values are computed.
    model
        Tree model handed to ``shap.TreeExplainer``.
    task : str
        Name used in the log message.
    n_feat : int
        Maximum number of features displayed.
    feature_names : list or None
        Optional display names forwarded to ``shap.summary_plot``.
    out_path, subset, fn, plot_type
        Accepted but not used by the current implementation (saving the figure
        is disabled and ``plot_type`` is not forwarded to SHAP).
    """
    print('Getting global feature importances for task:', task)
    shap.initjs()
    shap_values = shap.TreeExplainer(model).shap_values(val_x)
    shap.summary_plot(
        shap_values,
        val_x,
        show=False,
        plot_size=(8, 7),
        max_display=n_feat,
        title='',
        feature_names=feature_names,
    )
    plt.title('')
    plt.show()

In [None]:
# Train the contacts model; the refitted booster is also written to disk as
# model_path + '_health_cts_v1.1.ADM.model'. `evals` holds the training history
# that the evaluation cell below passes to evaluate_model.
xgb_cts, evals = train_optimize_model(train_dm_cts, val_dm_cts, model_path, 'health_cts', 'v1.1.ADM')

In [None]:
# Load a previously saved booster from disk. This rebinds `xgb_cts`, so the
# model trained in the cell above is replaced by the loaded one.
xgb_cts = xgb.Booster()
xgb_cts.load_model('')

In [None]:
# Overall regression performance on the validation set.
# NOTE: `evals` is only defined by the training cell; run it first.
# tp is set explicitly so this row carries the same 'Hospital admission'
# timepoint as the per-class and stratified results produced further down
# (evaluate_model's default is 'ED attendance').
res_dict = evaluate_model(val_x_cts, val_y_cts, xgb_cts, evals, task='Total health contacts',
                          tgt='Total health contacts', tp='Hospital admission')

In [None]:
# Global SHAP summary with human-readable column names.
# NOTE(review): get_shap_feature_importance does not forward `plot_type` to
# shap.summary_plot, so plot_type='bar' has no effect on the figure here.
get_shap_feature_importance(val_x_cts.rename(columns=feature_names), xgb_cts, out_path=None, subset='All features model',
                                task='Total health contacts', 
                                fn='Total health contacts', n_feat=20, plot_type='bar')

#### Evaluate performance on discrete bins

In [None]:
def discretize(y, nb=5):
    """Bin a 1-D NumPy array into ``nb`` quantile bins and return the ordinal bin codes.

    The bin edges are fitted on ``y`` itself, so two arrays discretised
    separately are each binned by their own quantiles.
    """
    binner = KBinsDiscretizer(n_bins=nb, encode='ordinal', strategy='quantile')
    column = y.reshape(-1, 1)
    return binner.fit_transform(column).flatten()

def evaluate_disc_performance_multi(model, val_data, true_labels, n_bins=5,
                                   bin_labels=['Very Low', 'Low', 'Medium', 'Medium-High', 'High']):
    """Evaluate the regression model as a multi-class classifier after quantile binning.

    Predictions and true values are each discretised into ``n_bins`` quantile
    bins (fitted separately), then compared via a row-normalised confusion
    matrix (printed and plotted, in percent), a classification report and four
    summary metrics with bootstrap 95% confidence intervals.

    Parameters
    ----------
    model : xgb.Booster
        Trained model used for prediction.
    val_data : pd.DataFrame
        Validation features; an 'age_gr' column, if present, is dropped first.
    true_labels : array-like
        True target values.
    n_bins : int
        Number of quantile bins.
    bin_labels : list of str
        Display names of the bins; must contain ``n_bins`` entries.

    Returns
    -------
    dict
        Accuracy, balanced accuracy, quadratic Cohen's kappa and macro F1,
        each with '<prefix>_upper' and '<prefix>_lower' CI bounds.

    Raises
    ------
    ValueError
        If ``bin_labels`` does not have ``n_bins`` entries.
    """
    if len(bin_labels) != n_bins:
        raise ValueError('bin_labels must contain exactly n_bins entries.')
    if 'age_gr' in val_data.columns:
        val_data = val_data.drop('age_gr', axis=1)
    labels_pred = model.predict(xgb.DMatrix(val_data))
    labels_pred_disc = discretize(labels_pred, nb=n_bins)
    labels_true_disc = discretize(np.asarray(true_labels), nb=n_bins)
    ### Setup confusion matrix
    cm = confusion_matrix(labels_true_disc, labels_pred_disc, normalize='true')
    cm_df = round(pd.DataFrame(cm, index=bin_labels, columns=bin_labels) * 100, 1)
    print('Overall confusion matrix')
    print(cm_df)
    plt.figure(figsize=(4,4))
    sns.heatmap(cm_df, annot=True, fmt='.1f', cmap='Blues')
    plt.title('Overall confusion matrix after intensity discretisation.')
    plt.ylabel('True category')
    plt.xlabel('Predicted category')
    plt.xticks(rotation=60)
    plt.show()
    ### Setup classification report
    clr = classification_report(labels_true_disc, labels_pred_disc, target_names=bin_labels)
    print('Classification report')
    print(clr)
    ### Overall metrics: (result key, CI key prefix, metric function)
    metric_specs = (
        ('Accuracy', 'Acc', accuracy_score),
        ('Balanced Accuracy', 'Bacc', balanced_accuracy_score),
        ('Cohen\'s Kappa Score', 'CKS', kappa_cs),
        ('F1-Score', 'F1', f1_cs),
    )
    res_dict = {}
    for key, prefix, func in metric_specs:
        ci = compute_ci(bootstrap_metric(labels_true_disc, labels_pred_disc, func))
        point = func(labels_true_disc, labels_pred_disc)
        res_dict[key] = round(point, 3)
        # compute_ci returns [lower percentile, upper percentile]; the bounds
        # were previously stored the other way round.
        res_dict[prefix + '_upper'] = round(ci[1], 3)
        res_dict[prefix + '_lower'] = round(ci[0], 3)
    print(res_dict)
    return res_dict

In [None]:
# Multi-class metrics after quantile-binning predictions and true contact counts.
disc_dict = evaluate_disc_performance_multi(xgb_cts, val_x_cts, val_y_cts)

In [None]:
def evaluate_disc_performance_across(model, val_data, true_labels, n_bins=5,
                                    bin_labels=['Very Low', 'Low', 'Medium', 'Medium-High', 'High']):
    """Evaluate each quantile bin one-vs-rest after discretising predictions and truth.

    Predictions and true values are each discretised into ``n_bins`` quantile
    bins (fitted separately). For every bin ``i`` the problem is reduced to
    "in bin i" vs "not in bin i" and scored with accuracy, balanced accuracy,
    quadratic Cohen's kappa and macro F1, each with a bootstrap 95% CI.

    Parameters
    ----------
    model : xgb.Booster
        Trained model used for prediction.
    val_data : pd.DataFrame
        Validation features; an 'age_gr' column, if present, is dropped first.
    true_labels : array-like
        True target values.
    n_bins : int
        Number of quantile bins.
    bin_labels : list of str
        Display names of the bins (indexed 0..n_bins-1).

    Returns
    -------
    pd.DataFrame
        One row per bin (index 0..n_bins-1) with column 'Class' and the metric
        columns.
    """
    if 'age_gr' in val_data.columns:
        val_data = val_data.drop('age_gr', axis=1)
    labels_pred = model.predict(xgb.DMatrix(val_data))
    labels_pred_disc = discretize(labels_pred, nb=n_bins)
    labels_true_disc = discretize(np.asarray(true_labels), nb=n_bins)
    # (result column, CI column prefix, metric function)
    metric_specs = (
        ('Accuracy', 'Acc', accuracy_score),
        ('Balanced Accuracy', 'Bacc', balanced_accuracy_score),
        ('Cohen\'s Kappa Score', 'CKS', kappa_cs),
        ('F1-Score', 'F1', f1_cs),
    )
    records = []
    for i in range(n_bins):
        print(f'Processing {bin_labels[i]} category..')
        pred_bin = (labels_pred_disc==i).astype(int)
        true_bin = (labels_true_disc==i).astype(int)
        row = {'Class': bin_labels[i]}
        for key, prefix, func in metric_specs:
            ci = compute_ci(bootstrap_metric(true_bin, pred_bin, func))
            row[key] = round(func(true_bin, pred_bin), 3)
            # compute_ci returns [lower percentile, upper percentile]; the
            # bounds were previously stored the other way round.
            row[prefix + '_upper'] = round(ci[1], 3)
            row[prefix + '_lower'] = round(ci[0], 3)
        records.append(row)
    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame(records)

In [None]:
# Per-bin (one-vs-rest) metrics; the result is a DataFrame with one row per bin.
multiv_dict = evaluate_disc_performance_across(xgb_cts, val_x_cts, val_y_cts)

In [None]:
# Tag the per-bin results with the prediction timepoint before export.
multiv_dict['timepoint'] = 'Hospital admission'
multiv_dict.head()

In [None]:
# Merge regression and discretised metrics into a single one-row DataFrame
# (the dict union operator `|` requires Python 3.9+).
perf_dict = pd.DataFrame(res_dict | disc_dict, index=[0])

In [None]:
# Display the combined overall performance row.
perf_dict

#### Evaluate stratified performance

In [None]:
def evaluate_age_groups(val_cts, model, val_lkup_cts,
                  task='Total health contacts', tgt='Total health contacts', tp='Hospital admission',
                       gr_labels=['50-59', '60-69', '70-79', '80-89', '90+']):
    """Evaluate the model separately within each age band of 'AgeAtAdmission'.

    Bands are [50, 60), [60, 70), [70, 80), [80, 90) and >= 90. For every
    non-empty band the regression metrics (``evaluate_model``) and the
    discretised metrics (``evaluate_disc_performance_multi``) are merged into
    one row.

    Parameters
    ----------
    val_cts : pd.DataFrame
        Full validation frame, including 'AgeAtAdmission', the target
        'total_count_all_tf' and the lookup columns.
    model : xgb.Booster
        Trained model.
    val_lkup_cts : pd.DataFrame
        Frame whose columns are dropped from ``val_cts`` to obtain the features.
    task, tgt, tp : str
        Forwarded to ``evaluate_model``.
    gr_labels : list of str
        Display label of each age band, in band order.

    Returns
    -------
    pd.DataFrame
        One row per evaluated band (indexed by band position) with a 'Group'
        column; empty if no band contains data.
    """
    print('Evaluating model by age group for target: ' + task)
    age = val_cts['AgeAtAdmission']
    # (inclusive lower bound, exclusive upper bound); None means open-ended.
    age_bounds = [(50, 60), (60, 70), (70, 80), (80, 90), (90, None)]
    lookup_cols = val_lkup_cts.columns.tolist()
    group_frames = []
    ### Eval procedures
    for i, (lower, upper) in enumerate(age_bounds):
        in_group = (age >= lower) if upper is None else (age >= lower) & (age < upper)
        subset = val_cts[in_group]
        if subset.empty:
            # Nothing to score (and the bootstrap cannot resample zero rows).
            print(f'Skipping group with no validation rows: {gr_labels[i]}')
            continue
        group_y = subset['total_count_all_tf']
        group_x = subset.drop(lookup_cols, axis=1)
        print(f'Evaluating performance for group: {gr_labels[i]}')
        res_dict = evaluate_model(group_x, np.array(group_y.values), model, evals_result=None, 
                  task=task, tgt=tgt, tp=tp)
        disc_dict = evaluate_disc_performance_multi(model, group_x, group_y)
        perf_dict = pd.DataFrame(res_dict | disc_dict, index=[i])
        perf_dict['Group'] = gr_labels[i]
        group_frames.append(perf_dict)

    # Concatenate once rather than growing a DataFrame inside the loop.
    return pd.concat(group_frames, axis=0) if group_frames else pd.DataFrame()

In [None]:
# Stratified performance by age band (uses the full validation frame so that
# the age and target columns are available for slicing).
age_dict = evaluate_age_groups(val_data, xgb_cts, val_lkup_cts)

In [None]:
# Display the per-age-band results.
age_dict

In [None]:
def evaluate_simd_groups(val_cts, model, val_lkup_cts,
                  task='Total health contacts', tgt='Total health contacts', tp='Hospital admission',
                       gr_labels=['1 - most deprived', '2-4', '5 - least deprived']):
    """Evaluate the model separately within three bands of the 'simd_dec' column.

    Bands are simd_dec in [1, 3), [3, 9) and >= 9. For every non-empty band the
    regression metrics (``evaluate_model``) and the discretised metrics
    (``evaluate_disc_performance_multi``) are merged into one row.

    Parameters
    ----------
    val_cts : pd.DataFrame
        Full validation frame, including 'simd_dec', the target
        'total_count_all_tf' and the lookup columns.
    model : xgb.Booster
        Trained model.
    val_lkup_cts : pd.DataFrame
        Frame whose columns are dropped from ``val_cts`` to obtain the features.
    task, tgt, tp : str
        Forwarded to ``evaluate_model``.
    gr_labels : list of str
        Display label of each band, in band order.

    Returns
    -------
    pd.DataFrame
        One row per evaluated band (indexed by band position) with a 'Group'
        column; empty if no band contains data.
    """
    print('Evaluating model by SIMD group for target: ' + task)
    simd = val_cts['simd_dec']
    # (inclusive lower bound, exclusive upper bound); None means open-ended.
    simd_bounds = [(1, 3), (3, 9), (9, None)]
    lookup_cols = val_lkup_cts.columns.tolist()
    group_frames = []
    ### Eval procedures
    for i, (lower, upper) in enumerate(simd_bounds):
        in_group = (simd >= lower) if upper is None else (simd >= lower) & (simd < upper)
        subset = val_cts[in_group]
        if subset.empty:
            # Nothing to score (and the bootstrap cannot resample zero rows).
            print(f'Skipping group with no validation rows: {gr_labels[i]}')
            continue
        group_y = subset['total_count_all_tf']
        group_x = subset.drop(lookup_cols, axis=1)
        print(f'Evaluating performance for group: {gr_labels[i]}')
        res_dict = evaluate_model(group_x, np.array(group_y.values), model, evals_result=None, 
                  task=task, tgt=tgt, tp=tp)
        disc_dict = evaluate_disc_performance_multi(model, group_x, group_y)
        perf_dict = pd.DataFrame(res_dict | disc_dict, index=[i])
        perf_dict['Group'] = gr_labels[i]
        group_frames.append(perf_dict)

    # Concatenate once rather than growing a DataFrame inside the loop.
    return pd.concat(group_frames, axis=0) if group_frames else pd.DataFrame()

In [None]:
# Stratified performance by SIMD band (uses the full validation frame so that
# the 'simd_dec' and target columns are available for slicing).
simd_dict = evaluate_simd_groups(val_data, xgb_cts, val_lkup_cts)

In [None]:
# Display the per-SIMD-band results.
simd_dict

In [None]:
# Stack the age and SIMD results; the 'Group' column identifies each band
# (row index values repeat across the two inputs).
group_dict = pd.concat([age_dict, simd_dict], axis=0)
group_dict

##### Export results

In [None]:
# Append this run's results to previously exported result tables.
# NOTE(review): if these are the same files written by the next cell, every
# re-run appends the same rows again - confirm and de-duplicate if needed.
perf_dict_og = pd.read_csv('')
perf_dict_f = pd.concat([perf_dict_og, perf_dict], axis=0)
multiv_dict_og = pd.read_csv('')
multiv_dict_f = pd.concat([multiv_dict_og, multiv_dict], axis=0)
group_dict_og = pd.read_csv('')
group_dict_f = pd.concat([group_dict_og, group_dict], axis=0)

In [None]:
# Write the combined result tables to CSV without the row index.
perf_dict_f.to_csv('', index=False)
multiv_dict_f.to_csv('', index=False)
group_dict_f.to_csv('', index=False)