In [10]:
# Eval on same base model, direct strategy, and different input and output types
# 30 days vs 1 month, num_ctx 8k vs 16k

In [11]:
import pandas as pd
from pathlib import Path

In [12]:
## Project root path
pjpath = ''

# Hacky way of finding the project's root path. Do not rely on this, set your own pjpath!
for p in Path.cwd().parents:
    if p.stem == 'medical-cbr':
        pjpath = p
        break

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud/unsync/medical-cbr


In [13]:
# Relevant paths
mimicpath = pjpath / 'datasets/mimiciv'

In [14]:
# Globals

SEED = 42

# Controls which data to load
samp_size = 20000
balanced_data = True

#base_model = 'iv_ll3_direct'

base_models = [
    'iv_ll3_direct',
    'medgenius32b',
    'dsmedical8b',
    'biomistral7b'
]

# Controls which llm to fire
input_type = ['S', 'R', 'RC']  # R, RC
output_type = ['M', 'PM']
mortality_span = ['30days', 'amonth']   # 30days, amonth; LLM might behave better by not relying on numeric semantics
num_ctx = [2, 4, 8, 16, 32] # (X 1024) LLM should behave better allowing bigger context. We cannot test anything above 32k due to GPU constraints
temperatures = [0.0, 0.1, 0.3, 0.7, 1.0] #[0.0, 0.1, 0.3, 0.7, 1.0]   # Temperature option for the LLM. The greater, the more creative the answer (def 0.1)


# TODO: Experiments with different values for temeprature (0, *0.1, 0.3, 0.7, 1.0) ~ creativity

In [15]:
# Load precomputed dataframe. Keeps only hadm_id and delta_days_dod (to find patients that died after n days discharge)
# Transform to boolean (patient died within 30 days after discharge)
#df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}.csv.gz')[['hadm_id', 'delta_days_dod']]
df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}{'_balanced' if balanced_data else ''}.csv.gz')

# Resolves deaths within a month
df['DIES'] = df['delta_days_dod'].apply(lambda x: x > 0 and x <= 30)

In [16]:
df.head()

Unnamed: 0,hadm_id,note_id,subject_id,charttime,text,gender,dod,anchor_age,anchor_year,admittime,admission_type,insurance,marital_status,race,diagnose_group_description,drg_mortality,diagnose_group_mortality,age,delta_days_dod,DIES
0,26871283,16456098-DS-16,16456098,2204-09-26,\nName: ___ Unit No: ___\n ...,F,,46,2194,2204-09-21,EW EMER.,Medicare,SINGLE,BLACK/AFRICAN AMERICAN,"CELLULITIS & OTHER SKIN INFECTIONS, MAJOR SKIN...",2.0,MODERATE,56,-1,False
1,25555836,13307227-DS-8,13307227,2182-02-12,\nName: ___ Unit No: ___\...,F,,71,2182,2182-02-06,ELECTIVE,Other,SINGLE,WHITE,CARDIAC VALVE PROCEDURES W/O AMI OR COMPLEX PD...,2.0,MODERATE,71,-1,False
2,22932284,15254575-DS-8,15254575,2112-12-05,\nName: ___ Unit No: ___\n \nA...,F,2113-03-27,86,2111,2112-12-01,EW EMER.,Medicare,WIDOWED,WHITE - RUSSIAN,"CONTUSION, OPEN WOUND & OTHER TRAUMA TO SKIN &...",2.0,MODERATE,87,112,False
3,29518468,16428221-DS-19,16428221,2139-06-16,\nName: ___ Unit No: ___\n...,F,,53,2137,2139-06-11,EW EMER.,Medicare,WIDOWED,BLACK/AFRICAN AMERICAN,"HEART FAILURE, HEART FAILURE & SHOCK W/O CC/MCC",2.0,MODERATE,55,-1,False
4,28741084,16025083-DS-18,16025083,2131-05-26,\nName: ___ Unit No: ___\...,M,,37,2131,2131-05-23,OBSERVATION ADMIT,Other,SINGLE,OTHER,"DRUG & ALCOHOL ABUSE OR DEPENDENCE, LEFT AGAIN...",1.0,LOW,37,-1,False


In [None]:
# Load results

from sklearn.metrics import precision_recall_fscore_support as prf
from sklearn.metrics import matthews_corrcoef as mcc

#results = pd.DataFrame()

base_models = ['_']

results = []
for base_model in base_models:
    for i_type in input_type:
        for o_type in output_type:
            for m_span in mortality_span:
                for nctx in num_ctx:
                    for temp in temperatures:
                        
                        fname = f'{i_type}_{o_type}_direct_{m_span}_{nctx}k_t{str(temp).replace('.', '')}.csv'
                        #fname = f'{base_model}_{i_type}_{o_type}_{m_span}_{nctx}k_t{str(temp).replace('.', '')}.csv'
                        
                        fpath = f'results/ex1_bak4/{fname}'

                        # Skips missing results
                        if not Path(fpath).is_file():
                            #print(fpath)
                            #print(f'(i) No results for {fname}, skipping...')
                            continue

                        #df_res = pd.read_csv(fpath, index_col=0).iloc[:200]
                        df_res = pd.read_csv(fpath, index_col=0)

                        y_true = df.set_index('hadm_id').loc[df_res.index].DIES
                        #print(fname)
                        y_pred = df_res.DIES.apply(lambda x: True if x =='YES' else False)

                        res_uf1 = prf(y_true, y_pred, average='micro')[2]
                        res_Mf1 = prf(y_true, y_pred, average='macro')[2]
                        res_mcc = mcc(y_true, y_pred)
                        
                        res = {
                            #'BASE MODEL': base_model,
                            'BASE MODEL': 'llm3',
                            'INPUT_TYPE': i_type,
                            'OUTPUT_TYPE': o_type,
                            #'span': m_span,
                            'num_ctx': nctx * 1024,
                            'temp': temp,
                            'uf1': res_uf1,
                            'Mf1': res_Mf1,
                            'mcc': res_mcc
                        }
                        results.append(res)

                        #print(df_res.index[26:31])

pd.DataFrame(results)

Unnamed: 0,BASE MODEL,INPUT_TYPE,OUTPUT_TYPE,num_ctx,temp,uf1,Mf1,mcc
0,_,S,M,4096,0.1,0.88,0.84,0.681139
1,_,S,M,8192,0.1,0.88,0.84,0.681139
2,_,S,M,16384,0.1,0.88,0.84,0.681139
3,_,S,M,32768,0.1,0.88,0.84,0.681139
4,_,S,M,4096,0.1,0.86,0.813333,0.627758
5,_,S,M,16384,0.1,0.86,0.813333,0.627758
6,_,S,M,32768,0.1,0.87,0.828925,0.658123
7,_,S,PM,4096,0.1,0.67,0.654776,0.421155
8,_,S,PM,8192,0.1,0.67,0.654776,0.421155
9,_,S,PM,16384,0.1,0.67,0.654776,0.421155


In [9]:
results

[]

In [9]:
from itertools import product

num_ctx = [4, 8, 16]
mortality_span = ['30days', 'amonth']

for n_ctx, m_span in product(num_ctx, mortality_span):
    print(n_ctx)
    print(m_span)
    print('.')

4
30days
.
4
amonth
.
8
30days
.
8
amonth
.
16
30days
.
16
amonth
.


In [10]:
# 1. Seleccionar columnas que queremos comunicar en json
# 2. Arreglar formato nombre columnas
#   to upper case
#   remove underscore
#   rename
# 3. DF to dict
# 4. dict to json
# 5. save to new column (input_query) so we can reuse the rest of the code

In [11]:
df.head()

Unnamed: 0,hadm_id,note_id,subject_id,charttime,text,gender,dod,anchor_age,anchor_year,admittime,admission_type,insurance,marital_status,race,diagnose_group_description,drg_mortality,diagnose_group_mortality,age,delta_days_dod,DIES
0,26871283,16456098-DS-16,16456098,2204-09-26,\nName: ___ Unit No: ___\n ...,F,,46,2194,2204-09-21,EW EMER.,Medicare,SINGLE,BLACK/AFRICAN AMERICAN,"CELLULITIS & OTHER SKIN INFECTIONS, MAJOR SKIN...",2,MODERATE,56,-1,False
1,25555836,13307227-DS-8,13307227,2182-02-12,\nName: ___ Unit No: ___\...,F,,71,2182,2182-02-06,ELECTIVE,Other,SINGLE,WHITE,CARDIAC VALVE PROCEDURES W/O AMI OR COMPLEX PD...,2,MODERATE,71,-1,False
2,22932284,15254575-DS-8,15254575,2112-12-05,\nName: ___ Unit No: ___\n \nA...,F,2113-03-27,86,2111,2112-12-01,EW EMER.,Medicare,WIDOWED,WHITE - RUSSIAN,"CONTUSION, OPEN WOUND & OTHER TRAUMA TO SKIN &...",2,MODERATE,87,112,False
3,29518468,16428221-DS-19,16428221,2139-06-16,\nName: ___ Unit No: ___\n...,F,,53,2137,2139-06-11,EW EMER.,Medicare,WIDOWED,BLACK/AFRICAN AMERICAN,"HEART FAILURE, HEART FAILURE & SHOCK W/O CC/MCC",2,MODERATE,55,-1,False
4,28741084,16025083-DS-18,16025083,2131-05-26,\nName: ___ Unit No: ___\...,M,,37,2131,2131-05-23,OBSERVATION ADMIT,Other,SINGLE,OTHER,"DRUG & ALCOHOL ABUSE OR DEPENDENCE, LEFT AGAIN...",1,LOW,37,-1,False
