In [1]:
# Eval on the same base model across prompting strategies (simple, cot1nn, cot1p1) and different input and output types

In [2]:
import pandas as pd
from pathlib import Path

In [3]:
## Project root path
from typing import Optional

# Directory name that identifies the project's root.
PROJECT_NAME = 'llms4mortality'


def find_project_root(name: str = PROJECT_NAME, start: Optional[Path] = None) -> Optional[Path]:
    """Return the nearest directory called `name` at or above `start`.

    `start` defaults to the current working directory and is itself checked
    before its parents. Matching uses the full directory name (`Path.name`),
    so e.g. 'llms4mortality.bak' does not match 'llms4mortality'.
    Returns None when no such directory exists.
    """
    start = Path.cwd() if start is None else Path(start)
    return next((p for p in (start, *start.parents) if p.name == name), None)


# Hacky way of finding the project's root path. Do not rely on this, set your own pjpath!
pjpath = find_project_root()
if pjpath is None:
    # Keep pjpath a Path (never '') so later `pjpath / ...` joins still work.
    pjpath = Path.cwd()
    print(f'> WARNING: no directory named {PROJECT_NAME!r} found at or above the working directory; '
          f'falling back to {pjpath}. Set pjpath manually if this is wrong.')

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/llms4mortality


In [4]:
# Relevant paths: MIMIC-IV derived data lives under <project root>/data/mimiciv
mimicpath = pjpath / 'data' / 'mimiciv'

In [5]:
# Globals

SEED = 42

# Which precomputed cohort file to load (both values are encoded in its file name)
samp_size = 5_000       # sample size tag, appears as S<samp_size> in the file name
balanced_data = True    # load the '_balanced' variant of the cohort file

# Base LLMs to evaluate. Other candidates: 'medgenius32b', 'dsmedical8b', 'biomistral7b'
base_models = ['ll3']

# Experiment configuration grid to evaluate
input_type = 'S R RC'.split()     # input type codes
output_type = 'M PM'.split()      # output type codes
# Context window sizes in units of 1024 tokens (2k .. 32k). Larger contexts should help the LLM,
# but nothing above 32k can be tested due to GPU constraints.
num_ctx = [2 ** k for k in range(1, 6)]
# Sampling temperatures for the LLM: the higher, the more creative the answer (default 0.1)
temperatures = [0.0, 0.1, 0.3, 0.7, 1.0]

In [6]:
# Load the precomputed cohort dataframe (all columns are kept; downstream cells
# use hadm_id and the derived DIES label).
# The suffix is built outside the f-string: nesting same-type quotes inside an
# f-string is a SyntaxError before Python 3.12.
balanced_suffix = '_balanced' if balanced_data else ''
df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}{balanced_suffix}.csv.gz')

# Resolves deaths within a month after discharge: DIES is True when
# 0 < delta_days_dod <= MORTALITY_WINDOW_DAYS. Missing (NaN) or non-positive
# values (the sample output shows -1 where `dod` is empty) yield False.
MORTALITY_WINDOW_DAYS = 30
days_to_death = df['delta_days_dod']
df['DIES'] = (days_to_death > 0) & (days_to_death <= MORTALITY_WINDOW_DAYS)

In [7]:
df.head(n=5)  # peek at the first rows of the loaded cohort

Unnamed: 0,hadm_id,note_id,subject_id,charttime,text,gender,dod,anchor_age,anchor_year,admittime,...,insurance,marital_status,race,diagnose_group_description,drg_mortality,diagnose_group_mortality,drg_code,age,delta_days_dod,DIES
0,21891113,19147811-DS-5,19147811,2148-09-22,\nName: ___ Unit No: ___\n ...,F,2149-04-22,68,2148,2148-09-19,...,Other,DIVORCED,WHITE,"DISORDERS OF GALLBLADDER & BILIARY TRACT, DISO...",2,MODERATE,"[284, 445]",68,212,False
1,29643114,15193172-DS-9,15193172,2129-06-28,\nName: ___ Unit No: __...,M,2129-07-17,91,2124,2129-06-25,...,Medicare,WIDOWED,WHITE,"CARDIAC ARRHYTHMIA & CONDUCTION DISORDERS, ACU...",3,HIGH,"[201, 281]",96,19,True
2,26747385,16281465-DS-16,16281465,2144-08-25,\nName: ___ Unit No: ___\n \nAdm...,F,2144-09-06,39,2136,2144-07-13,...,Other,MARRIED,ASIAN,"DIGESTIVE MALIGNANCY, DIGESTIVE MALIGNANCY W MCC",3,HIGH,"[240, 374]",47,12,True
3,23932127,15966914-DS-9,15966914,2155-04-26,\nName: ___ Unit No: ___...,M,,57,2148,2155-04-25,...,Other,MARRIED,WHITE,"PERCUTANEOUS CORONARY INTERVENTION W AMI, PERC...",1,LOW,"[174, 247]",64,-1,False
4,27210508,15484986-DS-28,15484986,2167-11-21,\nName: ___ Unit No: ___\...,M,2167-12-11,72,2167,2167-11-16,...,Other,MARRIED,WHITE,"INTRACRANIAL HEMORRHAGE, INTRACRANIAL HEMORRHA...",3,HIGH,"[44, 64]",72,20,True


In [8]:
# Load results
# Score every (strategy, base model, input type, output type, context, temperature)
# run whose prediction CSV exists on disk, collecting one metrics row per run.

from itertools import product

from sklearn.metrics import precision_recall_fscore_support as prf
from sklearn.metrics import matthews_corrcoef as mcc

llm_exp_id = ['simple', 'cot1nn', 'cot1p1']  # prompting strategies, one results subfolder each
llm_exp_fpath = pjpath / 'exps/results/llms'

# Ground-truth labels indexed by admission id; built once instead of once per run.
labels_by_hadm = df.set_index('hadm_id')['DIES']


def result_file(lleid, base_model, i_type, o_type, nctx, temp):
    """Return the path of the prediction CSV for one experiment configuration.

    Files live at `<llm_exp_fpath>/<lleid>/<model>_<input>_<output>_<ctx>k_t<temp>.csv`,
    where the temperature has its decimal point removed (e.g. 0.1 -> '01').
    """
    # Tag computed outside the f-string: nested same-type quotes need Python 3.12+.
    temp_tag = str(temp).replace('.', '')
    return llm_exp_fpath / lleid / f'{base_model}_{i_type}_{o_type}_{nctx}k_t{temp_tag}.csv'


def score_predictions(y_true, y_pred):
    """Return (micro-F1, macro-F1, MCC) for aligned true labels and predictions."""
    return (
        prf(y_true, y_pred, average='micro')[2],
        prf(y_true, y_pred, average='macro')[2],
        mcc(y_true, y_pred),
    )


results = []

# product() iterates in the same order as the equivalent nested loops
# (strategy outermost, temperature innermost).
for lleid, base_model, i_type, o_type, nctx, temp in product(
    llm_exp_id, base_models, input_type, output_type, num_ctx, temperatures
):
    fpath = result_file(lleid, base_model, i_type, o_type, nctx, temp)

    # Skips missing results
    if not fpath.is_file():
        continue

    df_res = pd.read_csv(fpath, index_col=0)

    # Align ground truth to the order of the predictions' hadm_id index.
    y_true = labels_by_hadm.loc[df_res.index]
    # Any answer other than the exact string 'YES' counts as a negative prediction.
    y_pred = df_res['DIES'].eq('YES')

    res_uf1, res_Mf1, res_mcc = score_predictions(y_true, y_pred)

    results.append({
        'LLM STRAT': lleid,
        'BASE MODEL': base_model,
        'INPUT_TYPE': i_type,
        'OUTPUT_TYPE': o_type,
        'num_ctx': nctx * 1024,
        'temp': temp,
        'uf1': res_uf1,
        'Mf1': res_Mf1,
        'mcc': res_mcc,
    })

pd.DataFrame(results)

Unnamed: 0,LLM STRAT,BASE MODEL,INPUT_TYPE,OUTPUT_TYPE,num_ctx,temp,uf1,Mf1,mcc
0,simple,ll3,S,M,32768,0.1,0.765,0.735293,0.470631
1,cot1nn,ll3,S,M,32768,0.1,0.715,0.713791,0.536269
2,cot1p1,ll3,S,M,32768,0.1,0.76,0.68815,0.425732
