<h1>Falls in seniors: predicting their health outcomes using narratives and baseline demographic data</h1>


# Introduction

**Goals & Rationales**

Timely admission to the emergency department (ED) is an important determinant of outcomes in elderly patients. Yet unnecessary hospital visits increase the burden on health systems and can cause anxiety and stress for patients, their families, and caregivers. Accordingly, we:
- (Ia) examined whether the time until the hospital visit (nicknamed the “delay time”), as currently captured in the narratives, is predictive of patient outcomes;
- (Ib) after extracting the delay times, developed and validated prognostic survival models that predict the time to an adverse outcome from data collected at baseline (narratives, baseline data, and delay times).

**Methods**

We used data from both the primary and the supplementary datasets, as shown in the [Table below](#table). Our evaluation set consists of cases from 2013 to 2020. The resulting subsets are comparable in size to open survival datasets used in recent benchmarks, e.g. SUPPORT (n=9,105) and FLCHAIN (n=7,894).

As the problem statement instructs, we removed the parts of each narrative that are already captured in other columns, i.e. the leading characters giving age and sex, and the trailing phrases that follow the marker ```DX``` (or equivalent markers such as ```***``` and ```>>```). We then preprocessed the narratives and computed word embeddings to describe the textual data.
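A minimal sketch of this stripping step follows. The regular expressions, including the assumed age/sex prefix format (e.g. `85YOF`, `72 YO M`), are illustrative assumptions rather than the exact rules used to build `narrative_cleaned`; real narratives vary more.

```python
import re

# Assumed formats (illustrative only): a leading age/sex token such as "85YOF" or "72 YO M",
# and the diagnosis markers named above (DX, ***, >>).
AGE_SEX_PREFIX = re.compile(r"^\s*\d{1,3}\s*Y\.?\s*O\.?\s*[MF]\b\s*", re.IGNORECASE)
DX_MARKER = re.compile(r"\bDX\b|\*\*\*|>>", re.IGNORECASE)


def strip_redundant_parts(narrative: str) -> str:
    """Drop the age/sex prefix and everything from the first diagnosis marker onward."""
    text = AGE_SEX_PREFIX.sub("", narrative, count=1)
    match = DX_MARKER.search(text)
    if match:
        text = text[:match.start()]
    return text.strip(" ,.;:")


print(strip_redundant_parts("85YOF FELL AT HOME ONTO FLOOR, HIT HEAD DX: CLOSED HEAD INJURY"))
```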

**Derivations of survival times**

We inferred the number of hours from the fall to the hospital visit by searching the narratives for keywords such as “1 DAY AGO”, “YESTERDAY”, and “LAST NITE”. Cases whose survival times could not be determined were excluded from the survival analyses. Patients whose dispositions do not match the outcome definitions were treated as censored (these include patients who left the hospital before being seen).
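The keyword search can be pictured as a set of phrase-to-hours rules whose largest matching value is taken as the delay (the notebook does this with one `polars` column per rule followed by `np.nanmax`). The sketch below uses a hypothetical subset of the rules; the hour values follow the ones used later in `split_ds` (e.g. `YESTERDAY`/`LAST NITE` → 28 h, `COUPLE DAY` → 72 h, *n* days → 24·*n* h).

```python
import math
import re

# Hypothetical subset of the rules; values mirror those used in split_ds.
FIXED_RULES = {
    "YESTERDAY": 28.0,
    "LAST NITE": 28.0,
    "LAST NIGHT": 28.0,
    "COUPLE DAY": 72.0,
}
DAYS_AGO = re.compile(r"\b(\d+)\s*DAYS?\s+AGO\b")


def delay_hours(narrative: str) -> float:
    """Largest delay (hours) implied by any matched phrase, or NaN if nothing matches."""
    text = narrative.upper()
    candidates = [hours for phrase, hours in FIXED_RULES.items() if phrase in text]
    candidates += [24.0 * int(n) for n in DAYS_AGO.findall(text)]
    return max(candidates) if candidates else math.nan


print(delay_hours("FELL 3 DAYS AGO, WORSE SINCE YESTERDAY"))
```

Cases returning NaN correspond to the narratives excluded from the survival analyses.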

**Model development**

We used Bayesian optimization (the ```Optuna``` package) to tune the hyperparameters of eXtreme Gradient Boosting (XGB) survival models that predict $P$ probabilities of experiencing an adverse outcome at times 1 to $P$. As a baseline, we also fitted regularized Cox regression models. For each outcome, we fitted both model types (XGB and Cox) on different combinations of three input types:
1. patient’s baseline data
1. raw word embeddings
1. their dimensionality reduced versions.
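For the Cox baseline, a fitted proportional-hazards model yields the same kind of output, a probability of the adverse outcome by each time $t = 1, \dots, P$, via $F(t \mid x) = 1 - S_0(t)^{\exp(\eta)}$, where $S_0$ is the baseline survival function and $\eta$ the linear predictor. The sketch assumes $S_0$ has already been estimated; `event_probabilities` is a hypothetical helper (not part of `sksurv` or `lifelines`) and the numbers are made up.

```python
import math
from typing import Sequence


def event_probabilities(baseline_survival: Sequence[float], linear_predictor: float) -> list[float]:
    """P(event by t | x) = 1 - S0(t) ** exp(eta) for t = 1..P (proportional hazards)."""
    risk = math.exp(linear_predictor)
    return [1.0 - s0 ** risk for s0 in baseline_survival]


# Made-up baseline survival at t = 1, 2, 3 and a patient with doubled hazard (eta = log 2)
print(event_probabilities([0.9, 0.8, 0.5], math.log(2.0)))
```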

**Key messages**

Based on the key findings analyzed below, we recommend that the “delay time” to hospital be tracked: our analyses showed that this additional information helps predict adverse patient outcomes, with a concordance index (C-index) of up to 0.72. Note that the “delay time” differs from the hour of the fall (morning/evening), which is also not tracked in the current NEISS coding manual but has long been suggested in the literature (e.g. the [SPLATT mnemonic](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7905119/)).
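The C-index measures how often, among comparable pairs of patients, the one with the shorter observed time to the adverse outcome received the higher risk score. The notebook uses `sksurv.metrics.concordance_index_censored`; the function below is only a simplified pure-Python illustration of Harrell's C (pairs with tied times are skipped, tied risk scores count as half-concordant), not a drop-in replacement.

```python
from itertools import combinations
from typing import Sequence


def harrell_c_index(event: Sequence[bool], time: Sequence[float], risk: Sequence[float]) -> float:
    """Simplified Harrell's C for right-censored data.

    A pair is comparable when the shorter observed time ends in an event; pairs with
    tied times are skipped, and tied risk scores count as half-concordant.
    """
    concordant, comparable = 0.0, 0
    for i, j in combinations(range(len(time)), 2):
        if time[i] == time[j]:
            continue
        a, b = (i, j) if time[i] < time[j] else (j, i)  # a = shorter observed time
        if not event[a]:
            continue  # censored first: we cannot tell which patient failed first
        comparable += 1
        if risk[a] > risk[b]:
            concordant += 1.0
        elif risk[a] == risk[b]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

A C-index of 0.5 corresponds to random ranking and 1.0 to perfect ranking.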

# Objectives of this notebook

Readers will be able to:
- gain hands-on experience with natural language processing (NLP) and the encoders/transformers used below;
- review the practical uses of the following packages/tasks:

<a href id="table"></a>


| Tasks | Package used |
|:--|:--|    
| Bayesian optimization | ```optuna```  |
| Efficient dataset querying | ```polars``` |
| Dimensionality reduction & visualization |```UMAP``` |
| Survival models | ```xgboost```, ```sksurv```, ```lifelines```|
| Evaluation metrics for survival data analysis | ```sksurv.metrics.concordance_index_censored``` <br>```sksurv.metrics.brier_score```, ```scipy.stats.spearmanr``` |


## Shortening runtime


To shorten the runtime, we recommend running only the key trials, e.g.:

**Option A**: XGB Optuna loop will only try the following input settings
```
for mid in ['xgb',]:
    for input_type in [25]: # rather than the other input combinations [19,21,22,23,24,26]
        ...  # body of the XGB Optuna loop
```
**Option B**: Run Cox regression only

In [None]:
import json
from pathlib import Path

import os
import random
import pandas as pd
import numpy as np

import multiprocessing as mp
from multiprocessing import Pool

import matplotlib.pyplot as plt
import plotly.express as px

import torch

In [None]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # tf.random.set_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_everything(1119)

# Problem formulation

## Load data

In [None]:
DATA_FOLDER = Path.cwd().parent / "data"

In [None]:
def get_data(folder):
    with Path(folder / "variable_mapping.json").open("r") as f:
        # JSON object keys are always strings; they are converted to int below
        mapping = json.load(f)

    for c in mapping.keys():
        mapping[c] = {int(k): v for k, v in mapping[c].items()}

    df = pd.read_csv(folder / "primary_data.csv", parse_dates=['treatment_date'],
                   dtype={"body_part_2": "Int64", "diagnosis_2": "Int64", 'other_diagnosis_2': 'string'} )

    df2 = pd.read_csv(folder / "supplementary_data.csv",  parse_dates=['treatment_date'],
                    dtype={"body_part_2": "Int64", "diagnosis_2": "Int64", 'other_diagnosis_2': 'string' } )

    org_columns = df2.columns
    df['source']=1
    df2['source']=2

    merged_df = pd.concat((df, df2))
    merged_df.drop_duplicates(inplace=True)
    merged_df.reset_index(inplace=True)

    primary_cid = merged_df.loc[merged_df.source == 1, 'cpsc_case_number']
    supp_cid    = merged_df.loc[merged_df.source == 2, 'cpsc_case_number']

    trn_case_nums = np.intersect1d( primary_cid, supp_cid )
    tst_case_nums = np.setdiff1d( supp_cid, primary_cid )

    print( 'Trn:Tst ratio:', len(trn_case_nums)/ len(tst_case_nums),  )

    def add_cols( df2 ):

        df2['month'] = df2.treatment_date.dt.month
        df2['year'] = df2.treatment_date.dt.year

        '''
        0. Not stated
        1. White: A person having origins in any of the original peoples of Europe, the Middle East, or North Africa.
        2. Black/African American: A person having origins in any of the black racial groups of Africa.
        3. ED record indicates more than one race (e.g., multiracial, biracial)
        4. Asian: A person having origins in any of the original peoples of the Far East, Southeast Asia, or the Indian subcontinent
        5. American Indian/Alaska Native: A person having origins in any of the original peoples of North and South America (including Central America), and who maintains tribal affiliation or community attachment.
        6. Native Hawaiian/Pacific Islander: A person having origins in any of the original peoples of Hawaii, Guam, Samoa, or other Pacific Islands.
        7. White Hispanic (hispanic=1 and race=1; recoded below)
        8. Black Hispanic (hispanic=1 and race=2; recoded below)
        '''

        k = 'race_white'
        df2[k] = 0 # non-white
        q=np.where( df2['race'] == 0 )[0]
        df2.loc[ q, k] = -1 # not stated
        q=np.where( df2['race'] == 1 )[0]
        df2.loc[ q, k] = 1

        k = 'race_4'
        df2[k] = 0 # non-white
        q=np.where( df2['race'] == 0 )[0]
        df2.loc[ q, k] = -1
        q=np.where( df2['race'] == 4 )[0]
        df2.loc[ q, k] = -2
        q=np.where( df2['race'] == 1 )[0]
        df2.loc[ q, k] = 1


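        df2['race_recoded'] = df2['race'].copy()
        q=np.where( (df2['hispanic'] == 1 ) & (df2['race'] == 1) )[0]
        df2.loc[q, 'race_recoded'] = 7
        q=np.where( (df2['hispanic'] == 1 ) & (df2['race'] == 2) )[0]
        df2.loc[q, 'race_recoded'] = 8

        # disposition -> ordinal severity: 1 left without being seen, 2 held for observation,
        # 3 treated/released, 4 treated/transferred, 5 admitted, 6 died; 9 (not recorded) -> NaN.
        # Assign the result: replace(inplace=True) on a selected column may not propagate to df2.
        df2['severity'] = df2['disposition'].replace( {9: np.nan, 6: 1, 5:2,  1:3,  2:4 ,  4:5,  8: 6 } )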


        # pd.cut uses right-closed bins: (0, 65], (65, 75], (75, 85], (85, 95], (95, 150]
        df2['age_cate'] = pd.cut(
            df2.age,
            bins=[0, 65, 75, 85, 95, 150],
            labels=["1: 65 or under", "2: 66-75", "3: 76-85", "4: 86-95", "5: 96+"],
        )
        df2['age_cate_binned'] = df2.age_cate.cat.codes

        # drop the 3 cases of unknown sex and then map codes to English words
        df2 = df2[df2.sex != 0]
        return df2

    # add variables
    merged_df = add_cols( merged_df )

    for col in mapping.keys():
        if col != 'disposition':
            merged_df[col] = merged_df[col].map(mapping[col])

    return merged_df, org_columns, trn_case_nums, tst_case_nums, mapping

def load_decoded(folder):
    decoded_df2=pd.read_csv(folder / 'decoded_df2_unique.csv')

    decoded_df2.sex = (decoded_df2.sex == 'MALE').astype(int)
    for k in [ 'alcohol','fire_involvement','drug', ]:
        decoded_df2[k] = ( decoded_df2[k] == 'Yes').astype(int)
    dic = {}
    for k in [ 'location','product_1','product_2','product_3','body_part','body_part_2' ]:
        dic[k] = {k: { i:l for l,i in enumerate( decoded_df2[k].unique() ) } }
        decoded_df2.replace( dic[k], inplace=True )

    def add_race_categories( df2 ):
        df2.reset_index(inplace=True)
        k = 'race_white'
        df2[k] = 0 # non-white
        q=np.where( df2['race'] == 'N.S.' )[0]
        df2.loc[ q, k] = -1 # not stated
        q=np.where( df2['race'] == 'WHITE' )[0]
        df2.loc[ q, k] = 1

        k = 'race_4'
        df2[k] = 0 # non-white
        q=np.where( df2['race'] == 'N.S.' )[0]; print(len(q))
        df2.loc[ q, k] = -1
        q=np.where( df2['race'] == 'ASIAN' )[0]; print(len(q))
        df2.loc[ q, k] = -2
        q=np.where( df2['race'] == 'WHITE' )[0]; print(len(q))
        df2.loc[ q, k] = 1
        return df2

    decoded_df2 = add_race_categories(decoded_df2 )
    print( 'race=white?', decoded_df2['race_white'].unique() ) #check

    return decoded_df2

if 'decoded_df2' not in globals():
    _, org_columns, trn_case_nums, tst_case_nums, mapping = get_data(DATA_FOLDER)
    decoded_df2 = load_decoded(DATA_FOLDER)

att =['location','product_1','product_2','product_3','fire_involvement','body_part','drug','alcohol', 'sex', 'age_cate_binned','race_recoded','year','month']
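
As written, `get_data` puts cases that appear in both the primary and the supplementary files into the development pool (`np.intersect1d`) and supplementary-only cases into the test set (`np.setdiff1d`); primary-only cases end up in neither. A set-based sketch of the same rule (`split_case_numbers` is a hypothetical helper):

```python
def split_case_numbers(primary_ids, supplementary_ids):
    """Development = cases in both files; test = cases only in the supplementary file."""
    primary, supplementary = set(primary_ids), set(supplementary_ids)
    dev = sorted(primary & supplementary)   # like np.intersect1d: sorted, unique
    test = sorted(supplementary - primary)  # like np.setdiff1d: sorted, unique
    return dev, test


print(split_case_numbers([1, 2, 3, 3], [2, 3, 4, 5]))
```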


## Add day of week to the decoded dataframe ```decoded_df2```

In [None]:
decoded_df2['weekday'] = pd.to_datetime( decoded_df2.treatment_date ).dt.weekday
decoded_df2.head(1).transpose()

## Extract time-to-event (```time2event```) data from each narrative



In [None]:
from time import time
import pickle
import pandas as pd
import polars as pol
import numpy as np
import json

from sklearn.model_selection import GridSearchCV, KFold
from sklearn.exceptions import FitFailedWarning
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from pathlib import Path

import scipy
from sksurv.metrics import *

labels = ['3h','6h','9h','12h','15h','18h', '24h','2d','3d','1w','2w','1mo']

def split_ds( df2 ):
    df = df2.filter((pol.col('narrative').str.contains(' LAST ')|
                     pol.col('narrative').str.contains('MORNING')|
                     pol.col('narrative').str.contains('A.M.',literal=True, strict=True )|
                     pol.col('narrative').str.contains('P.M.',literal=True, strict=True )|
                     pol.col('narrative').str.contains('AFTERNOON')|
                     pol.col('narrative').str.contains(' DAY')|
                     pol.col('narrative').str.contains('TODAY') |
                     pol.col('narrative').str.contains('EARLIER TONIGHT') |
                     pol.col('narrative').str.contains('AROUND') |
                     pol.col('narrative').str.contains('YESTERDAY') |
                     pol.col('narrative').str.contains(' AGO') )&
                   ~(
                    pol.col('narrative').str.contains("LAST 2 STEPS", literal=True) |
                    pol.col('narrative').str.contains("LAST FEW STEP", literal=True) |
                    pol.col('narrative').str.contains("LAST STAIR", literal=True) |
                    pol.col('narrative').str.contains("LAST SEV STEP", literal=True) |
                    pol.col('narrative').str.contains("LAST STEP", literal=True) |
                    pol.col('narrative').str.contains("YR AGO", literal=True) |
                    pol.col('narrative').str.contains("YRS AGO", literal=True) |
                    pol.col('narrative').str.contains("YEAR AGO", literal=True) |
                    pol.col('narrative').str.contains("YEARS AGO", literal=True)
                   ))
    k  = 'narrative'
    kk = 'time2hosp_1'

    # Fixed on Nov 20 2023
    # df = df2.with_columns( pol.col('treatment_date').str.to_datetime().dt.weekday().alias('weekday') )

    df = df.with_columns(pol.when(
      pol.col(k).str.contains('SEVERAL HOUR') |
      pol.col(k).str.contains('SEV HOUR', literal=True, strict=True) |
      pol.col(k).str.contains('SEV. HR', literal=True, strict=True) |
      pol.col(k).str.contains('SEV. HOUR', literal=True, strict=True)
      ).then(5).otherwise(np.nan).alias( 'a' ))
    df = df.with_columns(pol.when(
      pol.col(k).str.contains('SEVERAL DAY') | pol.col('narrative_cleaned').str.contains('several day') |
      pol.col(k).str.contains('SEV DAY') |
      pol.col(k).str.contains('SEVRAL DAY')|
      pol.col(k).str.contains('SEV. D')
      ).then(5*24).otherwise(np.nan).alias('d' ))
    df = df.with_columns(pol.when(
      pol.col(k).str.contains('SEVERAL WK') |
      pol.col(k).str.contains('SEV. WEEKS') |
      pol.col(k).str.contains('SEVERAL WEEK')|
      pol.col(k).str.contains('SEV. WK') |
      pol.col(k).str.contains('SEV WK')
      ).then(5*7*24).otherwise(np.nan).alias('y' ))

    df = df.with_columns(pol.when(
        pol.col(k).str.contains('COUPLE HOUR') | pol.col(k).str.contains('COUPLE OF HOUR')
      ).then(3).otherwise(np.nan).alias( 'c' ))
    df = df.with_columns(pol.when(
        pol.col(k).str.contains('COUPLE DAY') | pol.col(k).str.contains('COUPLE OF DAY') | pol.col(k).str.contains('3 NITES AGO')
      ).then(3*24).otherwise(np.nan).alias( 'f'))

    df = df.with_columns(pol.when(
        pol.col(k).str.contains('FEW HOUR')
      ).then(4).otherwise(np.nan).alias( 'b' ))
    df = df.with_columns(pol.when(
        pol.col(k).str.contains('FEW DAY') | pol.col(k).str.contains('FEW NIGHT') |  pol.col(k).str.contains('FEW NITE')
      ).then(4*24).otherwise(np.nan).alias( 'e' ))
    df = df.with_columns(pol.when(
        pol.col(k).str.contains('FEW HOUR')
      ).then(4).otherwise(np.nan).alias( 'ee' ))
    df = df.with_columns(pol.when(
        pol.col(k).str.contains('FEW WEEK')
      ).then(4*7*24).otherwise(np.nan).alias( 'eee' ))
    df = df.with_columns(pol.when(
        pol.col(k).str.contains('OTHER DAY')
      ).then(9*24).otherwise(np.nan).alias( 'eeee' ))

    # LAST MONTH and "yesterday"-type phrases share column 'g'; a single when/then chain keeps the
    # second rule from overwriting the first with NaN (the longer delay takes priority).
    df = df.with_columns(pol.when(
      pol.col(k).str.contains('LAST MTH')  | pol.col(k).str.contains('LAST MONTH')
      ).then(30*24).when(
      pol.col(k).str.contains('PREVIOUS DAY') | pol.col(k).str.contains('PREV DAY') |
      pol.col(k).str.contains('YESTERDAY') | pol.col(k).str.contains('LAST AM')    | pol.col(k).str.contains('LAST PM') |
      pol.col(k).str.contains('LAST EVEN') | pol.col(k).str.contains('DAY BEFORE') |
      pol.col(k).str.contains('LAST NOC')  | pol.col(k).str.contains('LAST NIGHT') | pol.col(k).str.contains('LAST NITE') | pol.col(k).str.contains('LAST NGHT')
      ).then(28).otherwise(np.nan).alias( 'g' ))

    df = df.with_columns(pol.when(
      pol.col(k).str.contains('P.M.', literal=True, strict=True ) |
      pol.col(k).str.contains('THIS AFTERNOON') | pol.col(k).str.contains('THIS EVEN') | pol.col(k).str.contains('TONIGHT')
      ).then(9).otherwise(np.nan).alias( 'h' ))

    df = df.with_columns(pol.when(
        pol.col(k).str.contains('A.M.', literal=True, strict=True )|
      pol.col(k).str.contains('MORNING', literal=True, strict=True )
      ).then(12).otherwise(np.nan).alias( 'i' ))

    df = df.with_columns(pol.when(
        pol.col(k).str.contains('THIS MORNING') |pol.col(k).str.contains('THIS AM') | pol.col(k).str.contains('TDY')|
        pol.col(k).str.contains('TODAY')
      ).then(12).otherwise(np.nan).alias( 'j' ))

    df = df.with_columns(pol.when(
        pol.col(k).str.contains('LAST EVE') | pol.col(k).str.contains('LAST NGIHT') |  pol.col(k).str.contains('AROUND MIDNIGHT') |
        pol.col(k).str.contains('LAST WK') | pol.col(k).str.contains('LAST MIGHT') | pol.col(k).str.contains('LAST WEEK') | pol.col(k).str.contains('WK AGO')
      ).then(10.5*7).otherwise(np.nan).alias( 'k' ))

    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('1 day',literal=True, strict=True)
      ).then(24).otherwise(np.nan).alias( 'x'))

    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('1+ day',literal=True, strict=True)
      ).then(36).otherwise(np.nan).alias( 'xx'))

    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('an hour') | pol.col('narrative_cleaned').str.contains('a hour')|
        pol.col('narrative_cleaned').str.contains('1 hour')
      ).then(1).otherwise(np.nan).alias( 'am'))

    # "LAST <weekday>": days back to that weekday in the previous week = weekday + 7 - day index
    # (pandas weekday, Monday=0), converted to hours like the other keyword columns.
    weekday_phrases = [('LAST MOND', 'MONDAY'), ('LAST TUE', 'TUESDAY'), ('LAST WED', 'WEDNESDAY'),
                       ('LAST THU', 'THURSDAY'), ('LAST FRI', 'FRIDAY'), ('LAST SAT', 'SATURDAY'),
                       ('LAST SUN', 'SUNDAY')]
    for d, (abbrev, full) in enumerate(weekday_phrases):
        df = df.with_columns(pol.when(
            pol.col(k).str.contains(abbrev) | pol.col(k).str.contains(full)
          ).then((pol.col('weekday') + 7 - d) * 24).otherwise(np.nan).alias( str(d + 1) ))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('1 mon')  | pol.col('narrative_cleaned').str.contains('a mon')
      ).then(24*30).otherwise(np.nan).alias( 'ba'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('2 month')
      ).then(24*60).otherwise(np.nan).alias( 'bb'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('3 month')
      ).then(24*90).otherwise(np.nan).alias( 'bc'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('4 month')
      ).then(24*120).otherwise(np.nan).alias( 'bd'))

    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('1 week')  | pol.col('narrative_cleaned').str.contains('a week')
      ).then(24*7).otherwise(np.nan).alias( 'be'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('2 week')
      ).then(24*14).otherwise(np.nan).alias( 'bf'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('3 week')
      ).then(24*21).otherwise(np.nan).alias( 'bg'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('4 week')
      ).then(24*28).otherwise(np.nan).alias( 'bh'))
    df = df.with_columns(pol.when(
        pol.col('narrative_cleaned').str.contains('5 week')
      ).then(24*35).otherwise(np.nan).alias( 'bi'))

    df = df.with_columns(pol.when(
      pol.col('narrative').str.contains('SEVERAL MONTH') |
      pol.col('narrative').str.contains('SEV MON') |
      pol.col('narrative').str.contains('5 MOS AGO')
      ).then(24*7*4*5).otherwise(np.nan).alias( 'bj'))

    k='narrative'
    for n in range(30):
        df = df.with_columns(pol.when(
          pol.col(k).str.contains(f'{n}D AGO') |
          pol.col(k).str.contains(f'{n} DAY') |  pol.col('narrative_cleaned').str.contains(f'{n} day') |
          pol.col(k).str.contains(f'{n}DAY')
      ).then(n*24).otherwise(np.nan).alias( f'd{n}' ))

    k='narrative'
    for n in range(24):
        df = df.with_columns(pol.when(
        pol.col(k).str.contains(f'AROUND {n}:') |
        pol.col(k).str.contains(f'AROUND {n} AM') | pol.col(k).str.contains(f'AROUND {n} PM')
      ).then(24).otherwise(np.nan).alias( f'time{n}' ))

    k='narrative_cleaned'
    # One column n{n} per number n. Month, week and day phrases share these columns, so they are
    # combined in a single when/then chain (months > weeks > days); a later pattern family
    # therefore cannot overwrite an earlier match with NaN.
    for n in range(30):
        rules = []
        if n < 6: # ================ months
            rules.append((pol.col(k).str.contains(f'{n} mth'), n*30*24))
        if n < 10: # ================ weeks
            week_hit = (
                pol.col(k).str.contains(f'{n} wk', literal=True, strict=True) |
                pol.col(k).str.contains(f'{n}wk', literal=True, strict=True) |
                pol.col(k).str.contains(f'{n} week', literal=True, strict=True) |
                pol.col(k).str.contains(f'{n}week', literal=True, strict=True) |
                pol.col('narrative').str.contains(f'{n}WEEK', literal=True, strict=True) |
                pol.col('narrative').str.contains(f'{n}WKS AGO', literal=True, strict=True) |
                pol.col(k).str.contains(f'{n}weeks ago') |
                pol.col('narrative').str.contains(f'{n} WEEK', literal=True, strict=True)
            )
            rules.append((week_hit, n*7*24))
        day_hit = ( # ================ days
            pol.col(k).str.contains(f'{n}night') | pol.col(k).str.contains(f'{n} night') |
            pol.col(k).str.contains(f'{n} d ago') | pol.col(k).str.contains(f'{n} day') |
            pol.col(k).str.contains(f'{n} dy ago') | pol.col(k).str.contains(f'{n}day')
        )
        rules.append((day_hit, n*24))

        expr = pol.when(rules[0][0]).then(rules[0][1])
        for cond, hours in rules[1:]:
            expr = expr.when(cond).then(hours)
        df = df.with_columns(expr.otherwise(np.nan).alias(f'n{n}'))

    for n in range(50): # ================ HOUR (numeric phrases; "an hour"/"1 hour" also set column 'am')
        df = df.with_columns(pol.when(
          pol.col(k).str.contains(f'{n} hour') |
          pol.col(k).str.contains(f'{n}hour') |
          pol.col('narrative').str.contains(f'{n} HOURS AGO') |
          pol.col(k).str.contains(f'{n}hrs ago') |
          pol.col(k).str.contains(f'{n}h ago')
      ).then(n).otherwise(np.nan).alias( f'h{n}' ))

    for n in range(90): # ================ minutes
        df = df.with_columns(pol.when(
        pol.col(k).str.contains(f'{n} minute') | pol.col(k).str.contains(f'{n} min ago')
      ).then(n/60).otherwise(np.nan).alias( f'm{n}' ))

    for n in range(90): # ================ minutes
        df = df.with_columns(pol.when(
        pol.col('narrative').str.contains(f'{n} MIN')
      ).then(n/60).otherwise(np.nan).alias( f'M{n}' ))

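    # keyword columns added above: 35 named + 30 (d*) + 24 (time*) + 30 (n*) = 119,
    # 50 hour (h*) + 90 minute (m*) = 140, and 90 minute (M*) columns
    rr=-119-90-140
    print( 'Sample size:', df.shape, df[:,rr:].head(1) )
    # delay = largest duration implied by any matched phrase; NaN if nothing matched
    time2hosp=np.nanmax( df[:,rr:].to_numpy(),1 )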
    df = df.with_columns(pol.lit(time2hosp).alias('time2hosp'))
    p  = df.filter( pol.col( 'time2hosp') .is_nan() )

    print( '\n\n',p.shape , 'remaining')
    for r in p.sample(5).iter_rows():
        print( r[3], )
        print( r[31], '\n' )

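    # keep cases with a positive extracted delay; rows without a match (NaN) are dropped explicitly
    p2 = df.filter( pol.col('time2hosp').is_not_nan() & (pol.col('time2hosp') > 0) )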
    trn_df = p2.filter( pol.col('cpsc_case_number').is_in( trn_case_nums ) )
    tst_df = p2.filter( pol.col('cpsc_case_number').is_in( tst_case_nums ) )
    print( trn_df.shape[0],'dev samples', tst_df.shape[0], 'test samples' )

    surv_pols = {}
    surv_pols['trn'] = trn_df[0::2,:]
    surv_pols['val'] = trn_df[1::2,:]
    surv_pols['tst'] = tst_df

    return surv_pols

if 'surv_pols' not in globals():
    surv_pols = split_ds( pol.DataFrame( decoded_df2 ) )
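
The `LAST <weekday>` rules in `split_ds` use offsets from `weekday + 7` (Monday) down to `weekday + 1` (Sunday), where `weekday` is the treatment day with Monday = 0. This counts the days back to that weekday in the week before the visit, so the offset is in days and needs a factor of 24 to sit on the same hour scale as the other keyword columns. A small check with Python's `datetime` (`days_since_last` is a hypothetical helper):

```python
from datetime import date, timedelta

WEEKDAYS = ["MONDAY", "TUESDAY", "WEDNESDAY", "THURSDAY", "FRIDAY", "SATURDAY", "SUNDAY"]


def days_since_last(weekday_name: str, treatment_date: date) -> int:
    """Days back to `weekday_name` in the week before the visit: weekday + 7 - day index."""
    return treatment_date.weekday() + 7 - WEEKDAYS.index(weekday_name)


treated = date(2023, 11, 22)  # a Wednesday
lag = days_since_last("MONDAY", treated)
print(lag, treated - timedelta(days=lag))  # 9 2023-11-13
```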


In [None]:
cnums={}
for t in ['trn','val','tst']:
    cnums[t]=np.array(surv_pols[t].select('cpsc_case_number') )

In [None]:
# example
np.array([("TEXT", 1, 1), ("XXX", 2, 2)], dtype='|S4, i4, i4')

In [None]:
decoded_df2.drop(['Unnamed: 0', 'Unnamed: 0.1', 'level_0'], axis=1, inplace=True)
decoded_df2['cpsc_num'] = decoded_df2.cpsc_case_number.copy()
decoded_df2.set_index('cpsc_num',inplace=True)
decoded_df2.head()

## Calculate severity scores using the Abbreviated Injury Scale (AIS) model developed by Chung et al.

In [None]:
%%time

import matplotlib.pyplot as plt
%matplotlib inline

import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import numpy as np
import scipy.io as sio
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import confusion_matrix

### Tensorflow 2.0 version
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout
from tensorflow.keras import losses, optimizers, metrics
from tensorflow.keras import regularizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback, EarlyStopping, ModelCheckpoint

### Define models
def define_model(input_size=46, learning_rate=1e-2):
    input_data = Input(shape=(input_size, ))

    ### Model (Region 46)
    x = Dense(64, kernel_regularizer=regularizers.l2(0.0001), kernel_initializer=tf.random_normal_initializer(stddev=0.01))(input_data)
    x = keras.layers.LeakyReLU(alpha=0.1)(x)
    x = Dropout(rate=0.5)(x)
    x = Dense(32, kernel_regularizer=regularizers.l2(0.0001))(x)
    x = keras.layers.LeakyReLU(alpha=0.1)(x)
    x = Dropout(rate=0.2)(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=input_data, outputs=x)
    model.compile(optimizer=optimizers.Adam(learning_rate=learning_rate), loss=losses.binary_crossentropy)

    print('Model==============')
    model.summary()

    !wget -O Model_Region46.h5 https://raw.githubusercontent.com/HeewonChung92/AIS/main/Model_Region46.h5
    model.load_weights( 'Model_Region46.h5')

    return model


def get_ais():

    def process_dx( df, N=0 ):
        bb = pol.DataFrame( df )
        bb = bb.with_columns(pol.col("narrative").str.replace_all( r"(?i)>>|SUSTAINED| D X*|DX:|DX.|DZ:|DZ.|\b\/|\b-\b", ' DX. ').alias('processed'))
        bb = bb.with_columns(pol.col("processed").str.replace_all( '(>)(>)', ' DX. '))
        bb = bb.with_columns(pol.col("processed").str.replace_all( ',,', ' DX. '))
        bb = bb.with_columns(pol.col("processed").str.replace_all( '/', ' DX. '))

        '''
        bb.with_columns(pol.col("narrative").str.replace_all( r"\bDX*", 'DX.'))
        bb.with_columns(pol.col("narrative").str.replace_all( "\\***", 'DX.'))

        bb.with_columns(pol.col("narrative").str.replace_all( r"\bD X*", 'DX.'))
        bb.with_columns(pol.col("narrative").str.replace_all( "\/", 'DX.'))
        bb.with_columns(pol.col("narrative").str.replace_all( ',,', 'DX.'))
        ''';
        if N>0:
            a = bb.filter( ~pol.col('processed').str.contains('DX.') )
            print(a.shape[0]/bb.shape[0], 'of samples do not have Dx info;', 1- a.shape[0]/bb.shape[0]  )
            for i, r in enumerate( np.array(a.sample(N) ) ):
                print(i,r[5], '\n\t', r[-1].lower() )
        return bb

    decoded_pol = process_dx(decoded_df2.copy(), 20 )



    k='processed'
    def get_dx( df ):
        diag_pols = (
            df.with_row_count('id').with_columns(pol.col(k).str.split( "DX.").alias("split_str"))
            .explode("split_str")
            .with_columns(
                ("string_" + pol.arange(0, pol.count()).cast(pol.Utf8).str.zfill(2))
                .over("id")
                .alias("col_nm")
            )
            .pivot(
                index=['id', k],
                values='split_str',
                columns='col_nm',
            )
            .with_columns(
                pol.col('^string_.*$').fill_null("")
            )
        )
        return diag_pols

    dx_pol = get_dx( decoded_pol )
    #dx_pol.head(1)

    from tqdm import tqdm
    try:
        ais_df=pd.read_excel('AIS_codes.xlsx',sheet_name='SimpleCodingBook')
    except FileNotFoundError:
        # download the AIS coding book (the paper's supplementary spreadsheet) on first use
        !wget -O AIS_codes.xlsx https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-021-03024-1/MediaObjects/41598_2021_3024_MOESM2_ESM.xlsx
        ais_df=pd.read_excel('AIS_codes.xlsx',sheet_name='SimpleCodingBook')

    for i, kk in tqdm(enumerate(ais_df['Region-46']) ):
        k=kk.upper()
        dx_pol = dx_pol.with_columns(pol.when(
        pol.col('string_01').str.contains(k) |
        pol.col('string_02').str.contains(k) |
        pol.col('string_03').str.contains(k) |
        pol.col('string_04').str.contains(k) |
        pol.col('string_05').str.contains(k) |
        pol.col('string_06').str.contains(k) |
        pol.col('string_07').str.contains(k) |
        pol.col('string_08').str.contains(k) |
        pol.col('string_09').str.contains(k) |
        pol.col('string_10').str.contains(k) |
        pol.col('string_11').str.contains(k) |
        pol.col('string_12').str.contains(k)
          ).then(1).otherwise(0).alias( f'{k}' ))

    ais_features = np.array(dx_pol[:,-46:] )

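    # also flag head/neck injuries phrased as "INJURED HEAD" or "HEAD INJU..." anywhere in the narrative
    for k,c in zip(['HEAD','NECK'], [1,10]):
        i=np.array(
            decoded_pol.filter( pol.col('processed').str.contains(f'INJURED {k}') | pol.col('processed').str.contains(f'{k} INJU') ).select('index')
                  )
        print(k, len(i))
        ais_features[i, c] = 1 # set region feature column c (1 for HEAD, 10 for NECK)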

    #plt.imshow(  np.array(ais_features[::5000,:]) )
    ais_scores = ais_model.predict( ais_features )

    return ais_scores.squeeze()

ais_model = define_model()

ais_scores = get_ais()

ais_df = pd.DataFrame( dict( cpsc_nums= decoded_df2.index, ais_scores=ais_scores ) )
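
`get_ais` splits each narrative at the diagnosis markers and flags which Region-46 keywords from the AIS coding book appear in the segments after the first marker; that 0/1 vector is what the downloaded Keras model scores. A minimal pure-Python sketch of the feature step, with a hypothetical four-keyword list standing in for the 46-entry codebook and a reduced marker set:

```python
import re

# Hypothetical stand-ins for the Region-46 keywords read from AIS_codes.xlsx
REGION_KEYWORDS = ["HEAD", "NECK", "HIP", "WRIST"]
DX_SPLIT = re.compile(r"\bDX\b[.:]?|>>|\*\*\*")


def ais_feature_vector(narrative: str, keywords=REGION_KEYWORDS) -> list[int]:
    """1 if a keyword occurs in any segment after the first diagnosis marker, else 0."""
    segments = DX_SPLIT.split(narrative.upper())
    diagnosis_text = " ".join(segments[1:])  # segment 0 is the story before the marker
    return [int(k in diagnosis_text) for k in keywords]


print(ais_feature_vector("85YOF FELL, HIT HEAD DX: HEAD INJURY, HIP FX"))  # [1, 0, 1, 0]
```

Like `str.contains` in the notebook, this is plain substring matching, so e.g. "HEAD" would also match "FOREHEAD".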

In [None]:
px.histogram( ais_df.ais_scores)

In [None]:
px.histogram( ais_df[ais_df.ais_scores > 0.04].ais_scores )

## Define time-to-event outcomes for the "timed" cohort

In [None]:
def get_surv( dff, DEFN = 2 ):
    from sksurv.util import Surv
    ev, times, surv_inter, surv_str = {}, {}, {}, {}

    # event = severity at or above the threshold; all other visits are right-censored
    if DEFN == 1:
        thres = 3 # treated or worse
    elif DEFN == 2:
        thres = 5 # hospitalized or died
    else:
        raise ValueError(f'Unsupported outcome definition DEFN={DEFN}')

    for t in ['trn','val','tst']:
        df = dff[t].to_pandas()
        # interval labels for XGBoost AFT: observed events have lower == upper == time2hosp
        surv_inter[t] = pd.DataFrame( {'label_lower_bound': df['time2hosp'],
                                       'label_upper_bound': df['time2hosp'] } )

        ev[t] = ( df['severity'] >= thres ).values
        q = np.where( ~ev[t] )[0] # censored: below threshold or unknown severity
        surv_inter[t].iloc[q, 1] = np.inf # right-censored -> open upper bound

        print(t, np.sum(ev[t])/ len(ev[t]), 'n=',len(ev[t]) )
        times[t] = ( df['time2hosp'] ).values

        surv_str[t] = Surv.from_arrays(ev[t], times[t])
    return ev, times, surv_inter, surv_str

event_indicator, time2event, surv_inter, surv_str = get_surv( surv_pols, DEFN=2 )
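
For XGBoost's accelerated failure time (AFT) objective, each case is described by an interval: observed events get lower = upper = `time2hosp`, and right-censored cases get an infinite upper bound, which is what `surv_inter` holds. A stand-alone sketch of that rule (`aft_bounds` is hypothetical); the two lists would be passed to `xgboost.DMatrix.set_float_info` as `label_lower_bound` and `label_upper_bound`:

```python
import math


def aft_bounds(times, events):
    """Lower/upper label bounds for an AFT survival objective.

    Observed events: lower == upper == time. Right-censored: upper = +inf.
    """
    lower = [float(t) for t in times]
    upper = [float(t) if e else math.inf for t, e in zip(times, events)]
    return lower, upper


print(aft_bounds([5, 28], [True, False]))  # ([5.0, 28.0], [5.0, inf])
```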

In [None]:
mapping['body_part']

In [None]:
# distribution of injured body parts in the training split
t = 'trn'
px.histogram( decoded_df2.loc[cnums[t].squeeze() ], x='body_part' )

In [None]:
# distribution of diagnoses in the training split
t = 'trn'
px.histogram( decoded_df2.loc[cnums[t].squeeze() ], x='diagnosis' )

In [None]:
# distribution of dispositions in the training split (basis of the severity score)
t = 'trn'
px.histogram( decoded_df2.loc[cnums[t].squeeze() ], x='disposition' )

In [None]:
'''
1: unseen
2: observed
3: treated
4: treated/transferred
5: admitted/hospitalized <-
6: died
''';
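
The ordinal severity above is derived from the NEISS disposition code in `add_cols`, and `get_surv` turns it into an event flag with a threshold (`DEFN=1`: treated or worse; `DEFN=2`: admitted or died). A compact sketch of that mapping; the dictionary mirrors the `replace` call in `add_cols`, while `is_event` is a hypothetical helper:

```python
import math

# NEISS disposition code -> ordinal severity, as in add_cols (9 = not recorded -> NaN)
DISPOSITION_TO_SEVERITY = {6: 1, 5: 2, 1: 3, 2: 4, 4: 5, 8: 6, 9: math.nan}
# Thresholds used in get_surv: DEFN=1 -> severity >= 3, DEFN=2 -> severity >= 5
THRESHOLDS = {1: 3, 2: 5}


def is_event(disposition: int, defn: int = 2) -> bool:
    """True if the visit counts as the adverse outcome; everything else is censored."""
    severity = DISPOSITION_TO_SEVERITY[disposition]
    return not math.isnan(severity) and severity >= THRESHOLDS[defn]


print([is_event(code) for code in (1, 4, 8, 9)])  # [False, True, True, False]
```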

## Load pre-computed embeddings


In [None]:
%%time

src='narrative_cleaned'
EMB=[1,2,3,4]

def read_embedding_file( file, emb ):
    """Load one pickled embedding matrix; embedding 4 is stored as a torch tensor."""
    print( file )
    with open(DATA_FOLDER / file, 'rb') as handle:
        x=pickle.load(handle)
    x = x['embeddings']
    return x.numpy() if emb==4 else x

def get_embeddings( EMB, src ):
    Embeddings={}
    if 'cleaned' in src:
        # cpsc_nums are not unique, hence 533518 rows (duplicates are dropped below)
        files = ['all_embeddings_1.pkl',
                 'all_embeddings_2.pkl',
                 'all_embeddings_3.pkl',
                 'all_embeddings_4.pkl']
        with open(DATA_FOLDER / "all_cpsc.pkl", 'rb') as handle:
            cpcs_nums = pickle.load(handle)['cpsc']
    else:
        # raw narratives (n=426691); files[emb-1] holds embedding `emb`
        files = ['narrative_n426691_emb1_d768_2023-10-02.pkl',
                 'narrative_n426691_emb2_d384_2023-10-02.pkl',
                 '../input/n-raw-extract-3/narrative_n426691_emb3_d768_2023-10-03.pkl',
                 'narrative_n426691_emb4_d512_2023-10-02.pkl']
        # assumed: rows follow decoded_df2 order, as for the n426691 LEALLA files below
        cpcs_nums = decoded_df2.index

    for emb in EMB:
        x = read_embedding_file(files[emb - 1], emb)
        X = pd.DataFrame( np.hstack((np.expand_dims(cpcs_nums,1), x )))
        X = X.drop_duplicates(subset=0)   # keep the first row per case number
        X.set_index(0,inplace=True)

        for t in ['trn','val','tst']:
            c = surv_pols[t].to_pandas()['cpsc_case_number']
            Xs = X.loc[c]
            Embeddings[emb,t] = Xs
            print(emb, t, Xs.shape)

    return Embeddings,cpcs_nums

Embeddings, cpcs_nums = get_embeddings( EMB, src )

In [None]:
# express year as years since 2013 (run once: re-running the cell shifts the values again)
for t in ['trn','val','tst']:
    surv_pols[t] = surv_pols[t].with_columns( (pol.col('year') - pol.lit(2013)).alias('year') )
surv_pols[t].head()

## Load the pre-computed OpenAI embeddings provided by DrivenData

In [None]:
openai_embeddings = pd.read_parquet(DATA_FOLDER / "openai_embeddings_primary_narratives.parquet.gzip")
openai_embeddings.set_index('cpsc_case_number',inplace=True)

emb=11
# OpenAI embeddings are loaded for the trn and val splits only
for t in ['trn', 'val', ]:
    s=openai_embeddings.loc[cnums[t].squeeze()]['embedding']
    Embeddings[emb,t]=np.vstack( s.values )   # (n_cases, embedding_dim)
    assert np.array_equal( Embeddings[emb,t][0], s.values[0] )  # exact check that row order is preserved
    print('loaded OpenAI embeddings')
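The cell above turns a column holding one embedding vector per case into a single 2-D matrix. A minimal sketch of the same step on a hypothetical two-case Series with 3-dimensional vectors (the case numbers and values are made up) shows why `np.vstack` plus an element-wise `np.array_equal` check is enough; summing the differences could hide rows that differ but cancel out.

```python
import numpy as np
import pandas as pd

# hypothetical stand-in for openai_embeddings['embedding']: one 1-D vector per case
s = pd.Series([np.array([0.1, 0.2, 0.3]),
               np.array([0.4, 0.5, 0.6])],
              index=[101, 102], name='embedding')

matrix = np.vstack(s.values)                    # shape (n_cases, embedding_dim)
assert np.array_equal(matrix[0], s.values[0])   # exact element-wise check
print(matrix.shape)                             # (2, 3)
```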

## Read in LEALLA embeddings

In [None]:
emb = 6
N_SHARDS = 20  # LEALLA embeddings were computed in 20 interleaved shards

def read_6():
    """Read the sharded LEALLA embeddings and store them in Embeddings[6, split]."""
    emb = 6
    try:
        cpcs_nums = decoded_df2.cpsc_case_number
    except AttributeError:
        cpcs_nums = decoded_df2.index

    # shards d0-d1 were written on 2023-10-04, shards d2-d19 on 2023-10-05
    files = [f'narrative_cleaned_n426691_emb6_d{d}_{"2023-10-04" if d < 2 else "2023-10-05"}.pkl'
             for d in range(N_SHARDS)]

    all_embeddings = None
    for d, fname in enumerate(files):
        with open(DATA_FOLDER / "LEALLA" / fname, 'rb') as handle:
            x = pickle.load(handle)['embeddings']
        print('read embeddings', fname )

        if all_embeddings is None:   # allocate once the embedding width is known
            all_embeddings = np.zeros( (decoded_df2.shape[0], x.shape[1]) )
        # shard d holds rows d, d+20, d+40, ... of decoded_df2
        all_embeddings[d::N_SHARDS, :] = x

    print( all_embeddings.shape )
    X = pd.DataFrame( np.hstack((np.expand_dims(cpcs_nums,1), all_embeddings )))
    X.set_index(0,inplace=True)
    for t in ['trn','val','tst']:
        c = surv_pols[t].to_pandas()['cpsc_case_number']
        Xs = X.loc[c]
        Embeddings[emb,t] = Xs
        print(emb, t, Xs.shape)

read_6()
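`read_6` assumes shard `d` holds rows `d, d+20, d+40, ...`. The toy sketch below (3 shards over 7 hypothetical rows) shows why assigning each shard to `all_embeddings[d::n_shards]` restores the original row order; the way the shards are produced here is an assumption mirroring that layout, not the original extraction script.

```python
import numpy as np

n_rows, n_feats, n_shards = 7, 2, 3
original = np.arange(n_rows * n_feats, dtype=float).reshape(n_rows, n_feats)

# assumed shard layout: shard d = rows d, d + n_shards, d + 2*n_shards, ...
shards = [original[d::n_shards] for d in range(n_shards)]

# reassembly, as in read_6
restored = np.zeros((n_rows, n_feats))
for d, shard in enumerate(shards):
    restored[d::n_shards, :] = shard

print(np.array_equal(restored, original))  # True
```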

# Fit UMAP dimensionality-reduction models on the training-split word embeddings



In [None]:
EMB = [1,2,3,4,6,11]

## Visualize the reduced word embeddings, stratified by select categorical variables


In [None]:
%%time
import umap

import seaborn as sns
import matplotlib.pyplot as plt

t ='trn'

reducers = {}
for rdims in [4]: # use more than 2 dims per https://programminghistorian.org/en/lessons/clustering-visualizing-word-embeddings
    for e in EMB:
        # fit on the training split only; val/tst are projected with .transform() later
        reducers[e,rdims] = umap.UMAP(n_neighbors=25,
                                      min_dist=0.01,
                                      n_jobs=1,
                                      n_components=rdims,
                                      random_state=1119).fit( Embeddings[e,t] )
        print('Fitted UMAP for word embedding', e)

In [None]:
reducers.keys()

In [None]:
%%time

# project the training embeddings with each fitted UMAP model and color by covariates
def show_umaps( mid, rdims ):
    t='trn'
    X_embedded = reducers[mid,rdims].transform( Embeddings[mid,t]  )
    projected = pd.DataFrame( {f"Dim {i+1}": X_embedded[:,i] for i in range(X_embedded.shape[1])} )
    projected['cids']= cnums[t]

    covariates = surv_pols[t].to_pandas()
    for k in ['sex', 'age_cate_binned','body_part','location', 'diagnosis','race_white','race_4', 'severity']:
        projected[k]= covariates[k]

    fname = 'DejaVu Sans'
    tfont = {'fontname':fname, 'fontsize':12} # title font attributes
    afont = {'fontname':fname, 'fontsize':10} # axis font attributes
    lfont = {'fontname':fname, 'fontsize':8}  # legend font attributes

    def tune_figure(ax, title:str='Title'):
        ax.axis('off')
        ax.set_title(f'WB{mid} colored by {title}', **tfont)
        legend = ax.get_legend()
        legend.set_title("")
        legend.prop.set_family(lfont['fontname'])
        legend.prop.set_size(lfont['fontsize'])
        legend.get_frame().set_linewidth(0.0)

    f, axs = plt.subplots(3,2,figsize=(14,7))
    axs = axs.flatten()

    # (panel, covariate used for hue, x-axis dim, y-axis dim)
    panels = [(0, 'severity',   'Dim 2', 'Dim 3'),
              (1, 'sex',        'Dim 3', 'Dim 1'),
              (2, 'body_part',  'Dim 1', 'Dim 2'),
              (3, 'location',   'Dim 1', 'Dim 2'),
              (4, 'race_4',     'Dim 2', 'Dim 3'),
              (5, 'race_white', 'Dim 3', 'Dim 1')]
    for d, k, xdim, ydim in panels:
        sns.scatterplot(data=projected, x=xdim, y=ydim, hue=k, s=50, alpha=0.1, ax=axs[d])
        tune_figure(axs[d], k)

# view pairs of the first three UMAP dimensions for every embedding
for r in [4]:
    for s in EMB:
        show_umaps( s, r )

## Transform the word embeddings of all splits (trn, val, tst) with the fitted UMAP models

In [None]:
word_reduced={}
for emb in EMB:
    # OpenAI embeddings (11) exist only for the trn and val splits
    T = ['trn','val'] if emb == 11 else ['trn', 'val','tst']
    for r in [4]:
        for t in T:
            word_reduced[emb,t] = reducers[emb,r].transform( Embeddings[emb,t])

# More preliminary analysis

## Examine lengths of narratives (400-word narrative only introduced in a later time period)

In [None]:
for t in ['trn','tst']:
    surv_pols[t]= surv_pols[t].with_columns([
        pol.col("narrative").str.len_bytes().alias("narrative_len")
    ])
    fig=px.histogram( surv_pols[t].to_pandas(),'narrative_len')
    fig.show()

# Develop and evaluate eXtreme Gradient Boosting (XGB) survival models, with hyperparameter optimization via the ```optuna``` package


- The loop below runs a series of trials to quantify how each input data type affects each survival model candidate.
- Within each trial, ```run_xgb_optuna``` creates an ```optuna``` study that runs a Bayesian hyperparameter search (Optuna's default TPE sampler) for ```n_trials``` steps.

- **Warning**: this section takes a while to run. For a quick dry run, lower ```n_trials``` (set below), e.g.:
  ```n_trials= (100 if INTERACTIVE else 500)```
  




In [None]:
# fraction of training cases with an observed event (finite upper bound)
pos_ratio = 1-np.isinf( surv_inter['trn']['label_upper_bound'] ).sum() / surv_inter['trn'].shape[0]
print( 'event (positive) fraction:', pos_ratio )

## Adjust ```n_trials``` for a finer (more trials) or coarser (fewer trials) hyperparameter search

In [None]:
n_trials = 100

## Define the function that runs the ```optuna``` hyperparameter search for one input type

In [None]:
import optuna
import scipy.stats
import xgboost as xgb
from sklearn.metrics import confusion_matrix, log_loss, roc_auc_score
from sksurv.metrics import concordance_index_censored, brier_score

from datetime import date

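def run_xgb_optuna( T, emb, X, surv_inter ):
    """Tune an XGBoost AFT model with optuna, refit the best trial and evaluate it on splits T."""
    ds, res = {}, {}   # DMatrix per split; per-split prediction tables (returned)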
    base_params = {'verbosity': 0,
                  'objective': 'survival:aft',
                  'eval_metric': 'aft-nloglik',
                  'tree_method': 'hist'}  # Hyperparameters common to all trials
    samp_choices = ['uniform']

    for t in T:
        ds[t] = xgb.DMatrix( X[t])
        # see details https://xgboost.readthedocs.io/en/stable/tutorials/aft_survival_analysis.html
        ds[t].set_float_info('label_lower_bound', surv_inter[t]['label_lower_bound'] )
        ds[t].set_float_info('label_upper_bound', surv_inter[t]['label_upper_bound'] )

    # Split the training set into interleaved halves: even rows (trn1) for fitting,
    # odd rows (trn2) for early stopping and Optuna pruning.
    # Features and both label bounds must use the same rows.
    for t, rows in [('trn1', slice(0, None, 2)), ('trn2', slice(1, None, 2))]:
        ds[t] = xgb.DMatrix( X['trn'][rows, :] )
        ds[t].set_float_info('label_lower_bound', surv_inter['trn']['label_lower_bound'].values[rows])
        ds[t].set_float_info('label_upper_bound', surv_inter['trn']['label_upper_bound'].values[rows])
        print(t, ds[t].num_row(), 'rows')

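    if gpus:
        # XGBoost >= 2.0 selects the GPU via `device`; 'gpu_hist' is deprecated
        base_params.update( {'tree_method': 'hist', 'device':'cuda', } )
        samp_choices = ['gradient_based','uniform']  # gradient_based sampling requires the GPU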

    def tuner(trial):
        params = {'learning_rate': trial.suggest_float('learning_rate', 0.001, 1.0),  # step size (alias: eta)
                  'aft_loss_distribution': trial.suggest_categorical('aft_loss_distribution',
                                                                      ['normal', 'logistic', 'extreme']),
                  'aft_loss_distribution_scale': trial.suggest_float('aft_loss_distribution_scale', 0.1, 10.0),
                  'max_depth': trial.suggest_int('max_depth', 3, 10),
                  'booster': trial.suggest_categorical('booster',['gbtree','dart',]),
                  'scale_pos_weight': trial.suggest_float('scale_pos_weight', pos_ratio*0.1, 10*pos_ratio ),  # weight of positive cases
                  'alpha': trial.suggest_float('alpha', 1e-8, 10 ),  # L1 reg on weights
                  'lambda': trial.suggest_float('lambda', 1e-8, 10 ),  # L2 reg on weights
                  # 'eta' is not tuned separately: it is an alias of learning_rate and would conflict with it
                  'sampling_method': trial.suggest_categorical('sampling_method', samp_choices ),
                  'subsample': trial.suggest_float('subsample', 0.01, 1 ),
                  'gamma': trial.suggest_float('gamma', 1e-8, 10)  # min loss reduction required to split a leaf; larger = more conservative
        }

        params.update(base_params)
        pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')

        bst = xgb.train(params, ds['trn1'], num_boost_round=10000,
                        evals=[(ds['trn1'], 'train'), (ds['trn2'], 'valid')],  # <---- data matrices
                        early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
        if bst.best_iteration >= 25:
            return bst.best_score
        else:
            return np.inf  # Reject models with < 25 trees

    # Run hyperparameter search
    study = optuna.create_study(direction='minimize')
    study.optimize( tuner, n_trials= n_trials )
    print('Completed hyperparameter tuning with best aft-nloglik = {}.'.format(study.best_trial.value))
    params = {}
    params.update(base_params)
    params.update(study.best_trial.params)

    print('Re-running the best trial... params = {}'.format(params))
    bst = xgb.train(params, ds['trn1'], num_boost_round=10000, verbose_eval=False,
                    evals=[(ds['trn1'], 'train'), (ds['trn2'], 'valid')],
                    early_stopping_rounds=50)

    # Explore hyperparameter search space
    fig = optuna.visualization.plot_param_importances(study)
    fig.show()

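    # Evaluate on every split using only the trees up to the best (early-stopped) iteration;
    # by default predict() uses all trees of the final model
    for t in T:
        try:
            pred = bst.predict(ds[t], iteration_range=(0, bst.best_iteration + 1))  # predicted time
            res[t]= pd.DataFrame({'Label (lower bound)': surv_inter[t]['label_lower_bound'],
                       'Label (upper bound)': surv_inter[t]['label_upper_bound'],
                       'Predicted label': pred } )

            sp=scipy.stats.spearmanr( res[t].iloc[:,-2], res[t].iloc[:,-1] )
            # a shorter predicted time means higher risk, so 1/prediction serves as the risk score
            c=concordance_index_censored( event_time = time2event[t], event_indicator = event_indicator[t] , estimate=1/res[t].iloc[:,-1] )

            print(t.upper(), f'| R2:{sp[0]:.3f}; p:{sp[1]:.4} | C:{c[0]*100:.2f} ', end='|' )
            for d,h in enumerate([3,6,9,12,15,18,24,49,73, 7*24+1, 7*24*2+1, 7*24*4+1 ]):
                try:
                    # NOTE: brier_score expects survival probabilities; 1/prediction is only a risk score
                    bs = brier_score( surv_str['trn'], surv_str[t], estimate=1/res[t].iloc[:,-1], times=[h] )
                    print( end=f'{labels[d]}:{bs[1][0]:.3f} |' )
                except (ValueError, IndexError):
                    pass  # horizon outside the follow-up range of split t, or no label for it
            print()
        except Exception as err:
            print(t.upper(), 'evaluation failed:', err)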

    today = date.today()
    bst.save_model(f'aft_model_{emb}_{today}.json')
    return res
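Because `trn1` and `trn2` are interleaved halves of the training set, the feature rows and both label bounds must be sliced with the same offset and stride; otherwise labels attach to the wrong cases, and the lengths stop matching when the row count is odd. A minimal check on hypothetical arrays (`interleaved_halves` is a hypothetical helper, not part of the notebook):

```python
import numpy as np

def interleaved_halves(X, lower, upper):
    """Split rows into even (fit) and odd (validation) halves, keeping labels aligned."""
    halves = {}
    for name, rows in [('trn1', slice(0, None, 2)), ('trn2', slice(1, None, 2))]:
        halves[name] = (X[rows], lower[rows], upper[rows])
    return halves

X = np.arange(5).reshape(5, 1)                  # row i has feature value i
lower = np.array([10., 11., 12., 13., 14.])     # lower bound = 10 + row id
upper = np.array([10., np.inf, 12., np.inf, 14.])

h = interleaved_halves(X, lower, upper)
for name, (x, lo, up) in h.items():
    # each feature row still sits next to its own label
    assert np.array_equal(lo, 10. + x[:, 0]), name
    print(name, x[:, 0].tolist(), lo.tolist())
# trn1 [0, 2, 4] [10.0, 12.0, 14.0]
# trn2 [1, 3] [11.0, 13.0]
```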

## Loop over a series of trials to compare the effects of input data types

The for-loop below explores input types, depending on ```input_type```:
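- ```input_type=1``` (or 2, 3, 4, 6, 11): a single original (full-dimensional) word embedding
- ```input_type=19```: all original word embeddings (1, 2, 3, 4, 6) concatenated; ```input_type=20``` adds the baseline variables
- ```input_type=21``` (or 22, 23, 24, 26): the dimensionality-reduced version of word embedding ```input_type - 20```, *plus* baseline variables
- ```input_type=25```: all dimensionality-reduced word embeddings (1, 2, 3, 4, 6), *plus* baseline variables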


In [None]:
gpus = False

import optuna

In [None]:
# ====================================
# Begin XGB/ Optuna search
# ====================================
# fraction of training cases with an observed event (finite upper bound)
pos_ratio = 1-np.isinf( surv_inter['trn']['label_upper_bound'] ).sum() / surv_inter['trn'].shape[0]
print( 'event (positive) fraction:', pos_ratio )

for input_type in [25]:  # part 1 : 6,19,20,
    X,res ={},{}
    # OpenAI embeddings (11) exist only for the trn and val splits
    T = ['trn','val'] if input_type==11 else ['trn','val','tst']
    for t in T:
        if input_type == 19:   # all full-dimensional embeddings
            X[t] = np.hstack( [Embeddings[e,t] for e in (1,2,3,4,6)] )
        elif input_type == 20: # all full-dimensional embeddings + baseline variables
            X[t] = np.hstack( [Embeddings[e,t] for e in (1,2,3,4,6)] + [surv_pols[t][att].to_pandas()] )
        elif input_type == 25: # all dim-reduced embeddings + baseline variables
            X[t] = np.hstack( [word_reduced[e,t] for e in (1,2,3,4,6)] + [surv_pols[t][att].to_pandas()] )
        elif input_type >= 21: # one dim-reduced embedding (input_type - 20) + baseline variables,
                               # for comparison with the Cox regression in the next code block
            rr = input_type - 20
            X[t] = np.hstack( (word_reduced[rr,t], surv_pols[t][att].to_pandas() ) )
        else: # 1-4,6,11: one full-dimensional embedding
            try:
                X[t] = Embeddings[input_type,t].to_numpy()
            except AttributeError:  # OpenAI embeddings (11) are already a numpy array
                X[t] = Embeddings[input_type,t]
    res = run_xgb_optuna( T, input_type, X, surv_inter)
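The evaluation in `run_xgb_optuna` passes `1/prediction` to `brier_score`, which expects the probability of remaining event-free at each horizon. Under XGBoost's AFT model, `ln T = ln y_hat + sigma * Z`, where `y_hat` is the predicted time, `sigma` is `aft_loss_distribution_scale` and `Z` follows `aft_loss_distribution`, so `S(t) = 1 - F_Z((ln t - ln y_hat) / sigma)`. The stdlib sketch below implements that conversion; the helper names are hypothetical, and treating XGBoost's `extreme` option as the minimum extreme value (Gumbel) distribution is an assumption to verify against the XGBoost AFT tutorial.

```python
import math

def _cdf(z, dist):
    """CDF of the standardized error Z for the aft_loss_distribution options."""
    if dist == 'normal':
        return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    if dist == 'logistic':  # numerically stable sigmoid
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        return math.exp(z) / (1.0 + math.exp(z))
    if dist == 'extreme':   # assumed: minimum extreme value (Gumbel) distribution
        return 1.0 - math.exp(-math.exp(min(z, 700.0)))
    raise ValueError(f'unknown distribution: {dist}')

def aft_survival(pred_time, t, dist='normal', scale=1.0):
    """P(T > t) for one case with predicted time pred_time: ln T = ln pred_time + scale * Z."""
    z = (math.log(t) - math.log(pred_time)) / scale
    return 1.0 - _cdf(z, dist)

# when t equals the predicted time, z = 0
print(aft_survival(24.0, 24.0, 'normal'))              # 0.5
print(aft_survival(24.0, 24.0, 'logistic'))            # 0.5
print(round(aft_survival(24.0, 24.0, 'extreme'), 4))   # 0.3679, i.e. exp(-1)
```

For a vector of predictions `preds` and a horizon `h`, `[aft_survival(p, h, dist, scale) for p in preds]` would give the `estimate` column that `brier_score` expects, with `dist` and `scale` taken from the best trial's parameters.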

# Cox regression analysis for baseline comparison and preliminary interpretations


In [None]:
%matplotlib inline

from lifelines import AalenAdditiveFitter, CoxPHFitter

att =['location','product_1','product_2','product_3','fire_involvement','body_part','drug','alcohol', 'sex', 'age_cate_binned','race_recoded','year','month']

att1=att.copy() # without race
att1+=['time2hosp', 'severity']
att1.remove('fire_involvement')
att1.remove('race_recoded')

att2=att.copy()  # with race
att2+=['time2hosp', 'severity']
att2.remove('fire_involvement')
print( 'Attribute subsets\natt1:',att1, '\natt2:', att2 )

import scipy.stats
import matplotlib.pyplot as plt
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored, brier_score
from IPython.display import display, HTML

from lifelines.utils import k_fold_cross_validation

def run_cox():
    cphs={}
    # horizons (hours) at which Brier scores are reported: 3h ... 4 weeks
    TIMES = [3,6,9,12,15,18,24,49,73, 7*24+1, 7*24*2+1, 7*24*4+1 ]

    # -1: baseline variables with race; 0: without race; >0: without race + reduced embedding `emb`
    for emb in [-1,0,1,2,3,4,6,11]:
        # OpenAI embeddings (11) exist only for the trn and val splits
        T = ['trn','val'] if emb == 11 else ['trn', 'val','tst']

        print('-'*100, emb, '\n\n')
        for t in T:
            A = att2 if emb==-1 else att1
            XX = surv_pols[t].to_pandas()[A] *1.
            # lifelines casts event_col to bool, so every non-zero severity code would count
            # as an event; use the same event definition as the XGB models instead
            XX['severity'] = event_indicator[t] * 1.

            if emb>0:
                XX = pd.DataFrame( np.hstack( (XX, word_reduced[emb,t]) ), columns=A +['w1','w2','w3', 'w4'] )

            if t == 'trn':
                # one penalty per covariate (all columns except duration and event)
                penalty = np.ones( XX.shape[1]-2 )*.1
                if RUN_CV:  # optional 5-fold cross-validated log-likelihood of candidate models
                    for fitter in [CoxPHFitter(penalizer=penalty, l1_ratio=.1),
                                   AalenAdditiveFitter(coef_penalizer=0.5),
                                   AalenAdditiveFitter(coef_penalizer=10)]:
                        print(np.mean(k_fold_cross_validation(fitter, XX, duration_col='time2hosp', event_col='severity')))

                cphs[emb] = CoxPHFitter(penalizer=penalty, l1_ratio=.1)
                cphs[emb].fit(XX, 'time2hosp', 'severity')

            # cumulative hazard H(t) and survival S(t) at the reporting horizons (rows = TIMES)
            cum_haz = cphs[emb].predict_cumulative_hazard( XX, times=TIMES )
            surv = cphs[emb].predict_survival_function( XX, times=TIMES )
            c=concordance_index_censored(event_time = time2event[t],
                                         event_indicator = event_indicator[t],
                                         estimate=cum_haz.iloc[-1,:] )

            sp=scipy.stats.spearmanr( np.array( XX['time2hosp']), 1/ cum_haz.iloc[-1,:].values  )

            print(f'InputSet{emb}\n', t.upper(), f'| R2:{sp[0]:.3f}; p:{sp[1]:.4} | C:{c[0]*100:.2f} ', )

            for d,h in enumerate(TIMES):
                try:
                    # brier_score expects the probability of remaining event-free at time h
                    bs = brier_score( surv_str['trn'], surv_str[t], estimate= surv.iloc[d,:], times=[h] )
                    print( end=f'{labels[d]}:{bs[1][0]:.3f} | ' )
                except (ValueError, IndexError):
                    pass  # horizon outside the follow-up range of split t, or no label for it
            c=cphs[emb].score( XX,'concordance_index')
            print( f'\t{c:.3f}' )
        print('\n\n')
        cphs[emb].print_summary()
        cphs[emb].plot()
        plt.figure()
    return cphs   # return only after all input sets have been fitted

RUN_CV = False
cphs = run_cox()

## Interpreting the result outputs


Example output:
```
InputSet-1
 TRN | R2:0.184; p:1.201e-39 | C:58.23
3h:0.997 | 6h:0.996 | 9h:0.990 | 12h:0.858 | 15h:0.855 | 18h:0.853 | 24h:0.844 | 2d:0.634 | 3d:0.580 | 1w:0.387 | 	0.576
InputSet-1
 VAL | R2:0.156; p:5.188e-29 | C:57.99
3h:0.998 | 6h:0.995 | 9h:0.989 | 12h:0.852 | 15h:0.849 | 18h:0.847 | 24h:0.833 | 2d:0.616 | 3d:0.563 | 1w:0.364 | 	0.564
InputSet-1
 TST | R2:0.148; p:1.226e-104 | C:55.31
3h:1.000 | 6h:0.998 | 9h:0.995 | 12h:0.931 | 15h:0.928 | 18h:0.926 | 24h:0.917 | 2d:0.758 | 3d:0.711 | 1w:0.504 | 	0.559
```

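- Overall concordance indices (```C```) for the development set (```TRN``` and ```VAL```) and the evaluation set (```TST```) are 0.582, 0.580 and 0.553, respectively. ```R2``` is the Spearman rank correlation between ```time2hosp``` and the inverse predicted cumulative hazard, and the value after the tab on each horizon line (0.576, 0.564, 0.559) is the concordance index reported by the lifelines ```score``` method.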

- Brier scores (lower is better) for experiencing the outcome at the 12-hour horizon for ```TRN```, ```VAL``` and ```TST```: 0.858, 0.852 and 0.931.
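For reference, the ```C``` values above are Harrell's concordance index: among comparable pairs (the case with the shorter time had an observed event), the share in which that case also received the higher risk score, with tied risk scores counted as 1/2. A minimal stdlib sketch on hypothetical data follows (`harrell_c` is a hypothetical helper); `sksurv.metrics.concordance_index_censored`, used in the notebook, implements the same idea with additional handling of tied times and a tie tolerance.

```python
def harrell_c(times, events, risks):
    """Harrell's C: share of comparable pairs ordered correctly by risk score.

    A pair (i, j) is comparable when times[i] < times[j] and case i had an event;
    it is concordant when risks[i] > risks[j]; tied risks count as 0.5.
    """
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError('no comparable pairs')
    return concordant / comparable

# hypothetical delays (h), event flags and risk scores (e.g. 1/predicted time)
times  = [2, 5, 8, 12]
events = [True, False, True, False]
print(harrell_c(times, events, [0.9, 0.4, 0.6, 0.1]))  # 1.0
print(harrell_c(times, events, [0.5, 0.4, 0.6, 0.1]))  # 0.75
```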
