In [1]:
# Baseline model using CBR based on PRR recovery

In [2]:
import pandas as pd
import numpy as np
import scipy
import json
from pathlib import Path
import re

from sklearn.metrics import precision_recall_fscore_support as prf
from sklearn.metrics import matthews_corrcoef as mcc

from src.minicbr import MiniCBR

In [3]:
# Globals

SEED = 42

# Data info. Will be used to load the appropriate precomputed data form disk
samp_size = 5000
balanced_data = True
max_chars = 22000       # Character limit on text notes applied in embedding generation
asdummies = True

# Load by name and truncation side
modnames = [
    ('nazyrova/clinicalBERT', 'middle'),
    #('emilyalsentzer/Bio_Discharge_Summary_BERT', 'right'),
    #('all-distilroberta-v1', 'right'),
    #('medicalai/ClinicalBERT', 'right'),

    # Big models:
    #('brandonhcheung04/bart', 'right'),
    #('brandonhcheung04/bart', 'middle'),
    #('Simonlee711/Clinical_ModernBERT', 'right'),
    #('Simonlee711/Clinical_ModernBERT', 'middle'),
]

# Default feature map. Weights will be adjusted dynamically
feature_map = {
    'emb': ('cosine', False, 0.33),
    'age': ('euclidean', True, 0.33),
    #'gender': ('jaccard', False, .07),
    #'admission_type': ('jaccard', False, .07),
    #'insurance': ('jaccard', False, .07),
    #'marital_status': ('jaccard', False, .07),
    #'race': ('jaccard', False, .07),
    'drg_mortality': ('euclidean', True, 0.33),
    #'drg_code': ('jaccard', False, .07),
}
fit_features = False    # Handle to fit prr features

# PatientRecordRecovery fixed params
prr_delta_weight = 0.1      # .2, .1
prr_score_metric = 'mcc'    # uf1, mf1, mcc
# NOTE: k & case weight is inherited from each experiment (ie, we use the same values to fit the recovery)

In [4]:
# CBR hyperparams
knn = range(1, 200, 2)  # 1, 21 TODO
weighted = [True]
summaries = [False]
withprepended = [True]   # Wherther to use embeddings that were generated from text with prepended patient data

In [5]:
## Project root path
pjpath = ''

# Hacky way of finding the project's root path. Do not rely on this, set your own pjpath!
for p in Path.cwd().parents:
    if p.stem == 'llms4mortality':
        pjpath = p
        break

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/llms4mortality


In [6]:
# Set this to your MIMIC-IV path where discharge, patients and admissions tables are located
mimicpath = pjpath / 'data/mimiciv'

In [7]:
# Load precomputed dataframe.
df = pd.read_csv(mimicpath / f'{'d_' if asdummies else ''}mimiciv_4_mortality_S{samp_size}{'_balanced' if balanced_data else ''}.csv.gz')

# Resolve target (ie, mortality within 30 days of discharge)
df['DIES'] = df['delta_days_dod'].apply(lambda x: x > 0 and x <= 30)

# Load precomputed splits
with open(mimicpath / f'hadmid_splits_S{samp_size}{'_balanced' if balanced_data else ''}.json', 'r') as ifile:
    splits_hadmids = json.load(ifile)

# Load sorted hadm_ids from disk
with open(mimicpath / f'hadmid_sorted_S{samp_size}{'_balanced' if balanced_data else ''}.json', 'r') as ifile:
    emb_hadmids = json.load(ifile)['HADM_ID']

In [8]:
df.head()

Unnamed: 0,hadm_id,note_id,subject_id,charttime,text,dod,anchor_age,anchor_year,admittime,diagnose_group_description,...,drg_code_975,drg_code_976,drg_code_977,drg_code_981,drg_code_982,drg_code_983,drg_code_987,drg_code_988,drg_code_989,DIES
0,21891113,19147811-DS-5,19147811,2148-09-22,\nName: ___ Unit No: ___\n ...,2149-04-22,68,2148,2148-09-19,"DISORDERS OF GALLBLADDER & BILIARY TRACT, DISO...",...,0,0,0,0,0,0,0,0,0,False
1,29643114,15193172-DS-9,15193172,2129-06-28,\nName: ___ Unit No: __...,2129-07-17,91,2124,2129-06-25,"CARDIAC ARRHYTHMIA & CONDUCTION DISORDERS, ACU...",...,0,0,0,0,0,0,0,0,0,True
2,26747385,16281465-DS-16,16281465,2144-08-25,\nName: ___ Unit No: ___\n \nAdm...,2144-09-06,39,2136,2144-07-13,"DIGESTIVE MALIGNANCY, DIGESTIVE MALIGNANCY W MCC",...,0,0,0,0,0,0,0,0,0,True
3,23932127,15966914-DS-9,15966914,2155-04-26,\nName: ___ Unit No: ___...,,57,2148,2155-04-25,"PERCUTANEOUS CORONARY INTERVENTION W AMI, PERC...",...,0,0,0,0,0,0,0,0,0,False
4,27210508,15484986-DS-28,15484986,2167-11-21,\nName: ___ Unit No: ___\...,2167-12-11,72,2167,2167-11-16,"INTRACRANIAL HEMORRHAGE, INTRACRANIAL HEMORRHA...",...,0,0,0,0,0,0,0,0,0,True


In [9]:
# Initializes results file in disk
respath = f'_res_baseline_cbr_prr.csv'
respath = f'results/cbr_embeddings_S{samp_size}{'_balanced' if balanced_data else ''}.csv'
res = pd.DataFrame(columns=['modname', 'is_summary', 'has_prepended', 'k', 'weighted', 'f1_micro', 'f1_macro', 'mcc', 'prr_delta', 'prr_fweights'])
res.to_csv(respath, mode='w', header=True)

# Path where embeddings are located
embpath = mimicpath / 'embeddings'

e_n = len(modnames) * len(knn) * len(weighted)  # TODO: In addition to these experiments we will try: modstrat (R, S, PR(chartdata embedded)), 
e_count = 1
for modname, mod_truncation in modnames:

    for summary_mode in summaries:
        for preprend_mode in withprepended:
            # Loads embeddings
            modname = re.sub('[^a-zA-Z0-9]+', '', modname)
            mod_fname = f'embeddings_{modname}_{'summary_' if summary_mode else ''}S{samp_size}_T{mod_truncation}{'_balanced' if balanced_data else ''}{'_PR' if withprepended else ''}.npy'
            print(f'> Loading embeddings from {mod_fname}...')
            embeddings = np.load(embpath / mod_fname)

            # hadmid-index mappings (and back)
            hadm2idx = {hadm: i for i, hadm in enumerate(emb_hadmids)}
            idx2hadm = {i: hadm for hadm, i in hadm2idx.items()}

            # Put embeddings in df iteratively taking their hadm_id into account
            df_embs = pd.DataFrame(columns=['emb'])
            df_embs.index.name = 'hadm_id'
            for i, emb in enumerate(embeddings):
                #df_embs.at[len(df_embs), 'emb'] = emb
                df_embs.at[idx2hadm[i], 'emb'] = emb

            # Merge df_embs with actual data df
            _df = df.copy()
            _df = pd.merge(_df, df_embs, on='hadm_id', how='inner')

            # Columns to keep in main dataframe for the experiment
            target_prefixes = ['age', 'gender', 'admission_type', 'insurance', 'marital_status', 'race', 'drg_mortality', 'drg_code', 'emb']
            sol_prefixes = ['DIES']

            # Resolve target columns from target prefixes
            target_columns = [c for c in _df.columns if c.startswith(tuple(target_prefixes))]

            df_cb_X = _df.set_index('hadm_id').loc[splits_hadmids['cb']][target_columns]
            df_test_X = _df.set_index('hadm_id').loc[splits_hadmids['test']][target_columns]
            df_train_X = _df.set_index('hadm_id').loc[splits_hadmids['train']][target_columns]

            df_cb_y = _df.set_index('hadm_id').loc[splits_hadmids['cb']][sol_prefixes]
            df_test_y = _df.set_index('hadm_id').loc[splits_hadmids['test']][sol_prefixes]
            df_train_y = _df.set_index('hadm_id').loc[splits_hadmids['train']][sol_prefixes]

            for k in knn:
                for w in weighted:
                    print(f'>> EX IDX: {e_count}/{e_n} - {modname} - {k} - {w}')

                    # CBR initialization and fit
                    mcbr = MiniCBR(df_X=df_cb_X, df_y=df_cb_y, feature_map=feature_map)
                    if fit_features:
                        mcbr.fit(df_train_X, df_train_y, k=k, delta_weight=prr_delta_weight, weighted_voting=w, score_metric=prr_score_metric, verbose=True)

                    # Call find with test set to get neighbour distances and indices
                    ndists, nidxs = mcbr.find(df_test_X, k, verbose=True)

                    # Aggregate neighbours into a single solution
                    # NOTE TODO: Using count strat, but avg would be interesting, but to test this we must first implement it in prr.find()
                    y_candidate_sols = mcbr.aggregate(ndists, nidxs, weighted_voting=w, strat='count')

                    # Eval
                    res = pd.Series([
                        modname,
                        summary_mode,
                        preprend_mode,
                        k,
                        w,
                        prf(df_test_y, y_candidate_sols, average='micro')[2],
                        prf(df_test_y, y_candidate_sols, average='macro')[2],
                        mcc(df_test_y, y_candidate_sols),
                        prr_delta_weight,
                        str([f'{k}: {v[-1]}' for k, v in mcbr.feature_map.items()])
                    ]
                    ).to_frame().T

                    # Save to disk
                    res.to_csv(respath, mode='a', header=False)
                    e_count += 1      

> Loading embeddings from embeddings_nazyrovaclinicalBERT_S5000_Tmiddle_balanced_PR.npy...
>> EX IDX: 1/100 - nazyrovaclinicalBERT - 1 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 2/100 - nazyrovaclinicalBERT - 3 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 3/100 - nazyrovaclinicalBERT - 5 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 4/100 - nazyrovaclinicalBERT - 7 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 5/100 - nazyrovaclinicalBERT - 9 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 6/100 - nazyrovaclinicalBERT - 11 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 7/100 - nazyrovaclinicalBERT - 13 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 8/100 - nazyrovaclinicalBERT - 15 - True
>> Finding neighbours for query 1004 out of 1005...

>> EX IDX: 9/100 - nazyrovaclinicalBERT - 17 - True
>> Finding neighbours 