In [1]:
# CoT mortality prediction using precomputed summaries for both the test entries and their neighbour examples

In [2]:
import json
import requests
import pandas as pd
import numpy as np
from pathlib import Path
import re

from src.minicbr import MiniCBR

In [3]:
## Project root path
def find_project_root(start, name='llms4mortality'):
    """Return the closest directory called `name` at or above `start`.

    Parameters
    ----------
    start : str or Path
        Directory where the upward search begins.
    name : str
        Directory name (compared against ``Path.stem``) that marks the project root.

    Returns
    -------
    Path or None
        The matching directory, or ``None`` when nothing on the way up matches.
    """
    start = Path(start)
    # `start` itself is checked first: `Path.parents` does not include it, so a
    # notebook launched directly from the project root would otherwise never match.
    for candidate in (start, *start.parents):
        if candidate.stem == name:
            return candidate
    return None


# Hacky way of finding the project's root path. Do not rely on this, set your own pjpath!
# Falls back to '' (the historical default) when the lookup fails.
pjpath = find_project_root(Path.cwd()) or ''

if pjpath == '':
    # Make the failure visible here instead of as a TypeError in a later cell
    print("(!) Could not find a 'llms4mortality' directory at or above the current working directory. Set pjpath manually.")

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/llms4mortality


In [4]:
# Locations used by the rest of the notebook
mimicpath = pjpath / 'data' / 'mimiciv'   # Folder holding the MIMIC-IV derived files loaded below

In [None]:
# Globals

SEED = 42   # Shared by the test subsampling and by the LLM sampling options

## Controls which data file to load
samp_size = 5000        # Sample size tag that appears in the data file names
balanced_data = True    # When True, the '_balanced' variant of each file is loaded
max_chars = 22000       # Anything longer than this will be middle-truncated

## CBR configs
#knn = 1
weighted = True
summary_mode = False    # Whether to use summaries in the CBR part of the query (ie, finding neighbours)
withprepended = True    # Whether to load embeddings that were computed out of texts with prepended patient info
# Embedding model, by name and truncation side
modname = 'nazyrova/clinicalBERT'
mod_truncation = 'middle'
# Default feature map for dynamic weighted
# NOTE(review): each value looks like (distance metric, flag, weight) - confirm the meaning against MiniCBR
feature_map = dict(
    emb=('cosine', False, 0.33),
    age=('euclidean', True, 0.33),
    #gender=('jaccard', False, .07),
    #admission_type=('jaccard', False, .07),
    #insurance=('jaccard', False, .07),
    #marital_status=('jaccard', False, .07),
    #race=('jaccard', False, .07),
    drg_mortality=('euclidean', True, 0.33),
    #drg_code=('jaccard', False, .07),
)

## Controls which LLM model to fire
# Every uncommented entry is run in turn
base_models = [
    'llama3',
    #'medgenius32b',
    #'dsmedical8b',
    #'biomistral7b'
]

## Paths to your system prompt json files
# Each file holds a dictionary keyed by a short prompt-type id; every value is another dictionary with at least "prompt"
sysprompt_dir = pjpath / 'ollama/sysprompts'
sysprompt_in_fpath = sysprompt_dir / 'sysprompt_in.json'
sysprompt_out_fpath = sysprompt_dir / 'sysprompt_out.json'   # Entries here also need a "format" key with the expected output format

# Type of input and output system prompts. Define only a single value for each.
i_type_id = 'S'
o_type_id = 'M'

## LLM params
n_ctx = 32      # Context length (x 1024)
temp = 0.1      #[0.0, *0.1, 0.3, 0.7, 1.0]   # Temperature option for the LLM. The greater, the more creative the answer (def 0.1)
top_k = 20
top_p = 0.5
ss_size = 200   # Number of test entries to subsample; a falsy value keeps the whole test split

## Additional data prepro options
# Columns with the relevant patient info that goes to the LLM together with the text report,
# renamed to something more meaningful inside the prompt
pdc_remap = dict(
    age='AGE',
    gender='GENDER',
    marital_status='MARITAL STATUS',
    race='RACE',
    diagnose_group_description='BROAD DIAGNOSIS',
    diagnose_group_mortality='MORTALITY RISK',
    insurance='INSURANCE',
    #text='REPORT'
)

In [6]:
# Load data
# All data files share the same '<sample size><balanced suffix>' tag. It is resolved once,
# outside of any f-string: reusing the enclosing quote type inside an f-string
# replacement field only parses on Python >= 3.12.
balanced_suffix = '_balanced' if balanced_data else ''
data_tag = f'S{samp_size}{balanced_suffix}'

# Load precomputed dataframe.
dfname = f'mimiciv_4_mortality_{data_tag}'
df = pd.read_csv(mimicpath / f'{dfname}.csv.gz')

# We also need the dummies for CBR
df_dummies = pd.read_csv(mimicpath / f'd_{dfname}.csv.gz')

# Resolve target (ie, mortality within 30 days of discharge).
# Vectorised form of `0 < delta_days_dod <= 30`; a missing delta compares False on
# both sides, so those rows end up with DIES == False.
for frame in (df, df_dummies):
    days_to_death = frame['delta_days_dod']
    frame['DIES'] = (days_to_death > 0) & (days_to_death <= 30)

# Load precomputed splits
with open(mimicpath / f'hadmid_splits_{data_tag}.json', 'r') as ifile:
    splits_hadmids = json.load(ifile)

# Load sorted hadm_ids from disk
with open(mimicpath / f'hadmid_sorted_{data_tag}.json', 'r') as ifile:
    emb_hadmids = json.load(ifile)['HADM_ID']

In [7]:
# Resolve test data

# From here on the main frame is addressed by hadm_id
df = df.set_index('hadm_id')
# Test rows are taken before the target is re-encoded, so df_test keeps the boolean DIES
df_test = df.loc[splits_hadmids['test']]
# The full frame stores the target as 'YES'/'NO' text instead (need to do this for json)
df['DIES'] = ['YES' if died else 'NO' for died in df['DIES']]
# Optional subsampling of the test split; a falsy ss_size leaves it untouched
if ss_size:
    df_test = df_test.sample(ss_size, random_state=SEED)

In [8]:
# Loads embeddings

# Path where embeddings are located
embpath = mimicpath / 'embeddings'

# The file name is assembled from the experiment config. Optional tags are resolved up
# front (nesting the enclosing quote type inside an f-string only parses on Python >= 3.12).
modname = re.sub('[^a-zA-Z0-9]+', '', modname)
summary_tag = 'summary_' if summary_mode else ''
balanced_tag = '_balanced' if balanced_data else ''
prepended_tag = '_PR' if withprepended else ''
mod_fname = f'embeddings_{modname}_{summary_tag}S{samp_size}_T{mod_truncation}{balanced_tag}{prepended_tag}.npy'
print(f'> Loading embeddings from {mod_fname}...')
embeddings = np.load(embpath / mod_fname)

# hadmid-index mappings (and back)
hadm2idx = {hadm: i for i, hadm in enumerate(emb_hadmids)}
idx2hadm = {i: hadm for hadm, i in hadm2idx.items()}

# One embedding per row, labelled with the hadm_id found at the same position.
# The frame is built in a single call; growing it cell by cell re-allocates on every insert.
df_embs = pd.DataFrame(
    {'emb': list(embeddings)},
    index=pd.Index([idx2hadm[i] for i in range(len(embeddings))], name='hadm_id'),
)

# Merge df_embs with actual data df (only entries present on both sides are kept)
_df = pd.merge(df_dummies, df_embs, on='hadm_id', how='inner')

# Columns to keep in main dataframe for the experiment
target_prefixes = ['age', 'gender', 'admission_type', 'insurance', 'marital_status', 'race', 'drg_mortality', 'drg_code', 'emb']
sol_prefixes = ['DIES']

target_columns = [c for c in _df.columns if c.startswith(tuple(target_prefixes))]

# Prepares data for CBR model
# Index by hadm_id once, then slice rows (case base / test) and columns (features / solution)
_df_by_hadm = _df.set_index('hadm_id')
cb_rows = _df_by_hadm.loc[splits_hadmids['cb']]
test_rows = _df_by_hadm.loc[df_test.index]

dfd_cb_X = cb_rows[target_columns]
dfd_test_X = test_rows[target_columns]

dfd_cb_y = cb_rows[sol_prefixes]
dfd_test_y = test_rows[sol_prefixes]

# Initializes MiniCBR and resolves 1NN neighbours for test data
print('> Finding neighbours for test entries...')
mcbr = MiniCBR(df_X=dfd_cb_X, df_y=dfd_cb_y, feature_map=feature_map)
ndists, nidxs = mcbr.find(dfd_test_X, 1)

# Translate natural indices into actual hadm_id from neighbours. Save as a separate df
# The index of this are the IDs of the test entries, while neigh_hadmid maps to the ID of the neighbour
# NOTE(review): assumes nidxs holds, per test entry, positions into dfd_cb_X - confirm against MiniCBR.find
neigh_hadm_id = [dfd_cb_X.index[idx[0]] for idx in nidxs]
df_neighs = pd.DataFrame({'neigh_hadmid': neigh_hadm_id}, index=dfd_test_X.index)

> Loading embeddings from embeddings_nazyrovaclinicalBERT_S5000_Tmiddle_balanced_PR.npy...
> Finding neighbours for test entries...


In [None]:
# Load precomputed summaries
# This is required for CoT

sum_model = 'llama3'    # Model tag that appears in the summaries file name
summaries_fpath = mimicpath / 'summaries'
balanced_part = '_balanced' if balanced_data else ''
summary_id = f'summary_S{samp_size}{balanced_part}_{sum_model}_mc{max_chars}'
# First csv column becomes the index; later cells look entries up by hadm_id
df_summaries = pd.read_csv(summaries_fpath / f'{summary_id}.csv', index_col=0)

In [10]:
# Sets ollama instance and run

def middle_truncate(text, limit, label='Text', entry=None):
    """Drop the middle of `text` when it is longer than `limit` characters.

    Parameters
    ----------
    text : str
        Text to check.
    limit : int
        Maximum number of characters allowed.
    label : str
        How the text is called in the log message.
    entry : object
        Identifier of the entry being processed (only used in the log message).

    Returns
    -------
    str
        `text` untouched when it fits, otherwise its first and last ``limit // 2`` characters.
    """
    if len(text) <= limit:
        return text
    print(f'(i) {label} exceeds the max char limit ({len(text)}) in entry {entry}. Middle-truncating to {limit}...')
    half = limit // 2
    text = text[:half] + text[-half:]
    print(f'... Result truncate: {len(text)}')
    return text


# First we load the system prompt configs from disk.
# These will be used to build the right system prompt for each experiment
with open(sysprompt_in_fpath, 'r') as ifile:
    sprompt_data_in = json.load(ifile)

with open(sysprompt_out_fpath, 'r') as ifile:
    sprompt_data_out = json.load(ifile)

# Getting related prompts for the specified input and output mode
sprompt_in = sprompt_data_in[i_type_id]['prompt']
sprompt_out = sprompt_data_out[o_type_id]['prompt']
o_type_format = sprompt_data_out[o_type_id]['format']
# Builds full system prompt
sprompt = sprompt_in + " " + sprompt_out

instance = 'http://localhost:11434/api/generate'
auth_cookie = ''

res_fpath = Path(f'{pjpath}/exps/results/llms/cot1nn')
res_fpath.mkdir(parents=True, exist_ok=True)

# Request pieces that are the same for every row
cookies = {'_oauth2_proxy': auth_cookie}
headers = {'Content-Type': 'application/json'}   # The body below is a JSON document

for base_model in base_models:
    print(f'> [BASE MODEL]: {base_model}')
    responses = {}
    for i, index in enumerate(df_test.index, start=1):
        print(f'>> Processing row {i} out of {len(df_test)}', end='\r')

        # Get summary text from entry (first column of the summaries file),
        # middle-truncated if longer than max_chars
        text = middle_truncate(df_summaries.loc[index].iloc[0], max_chars, 'Text', index)

        # Get neighbour summary and neighbour outcome for this entry
        neigh_hadm_id = df_neighs.loc[index, 'neigh_hadmid']
        neigh_text = middle_truncate(df_summaries.loc[neigh_hadm_id, 'SUMMARY'], max_chars, 'Neighbour text', index)
        # The example's solution must be the NEIGHBOUR's outcome. Looking it up with the
        # test entry's own id would hand the model the answer it is being asked to predict.
        neigh_dies = df.loc[neigh_hadm_id, 'DIES']

        formatted_input = json.dumps({'REPORT': text})
        formatted_input_neigh = json.dumps({'REPORT': neigh_text})
        formatted_output_neigh = json.dumps({'DIES': neigh_dies})

        # Builds full CoT prompt with example
        cot1nn_prompt = "\nBelow is an example of the medical text report from which you will have to decide if the patient will die within 30 days of their medical discharge:\n"
        cot1nn_prompt += formatted_input_neigh
        cot1nn_prompt += "\nAnd here's its solution:\n"
        cot1nn_prompt += formatted_output_neigh
        cot1nn_prompt += "\nThe previous example is from a medical case that was marked as similar by another expert system. You must decide if the exemplar case is similar enough to another case you will be provided below, and if so, predict the mortality of the new case from its report and the known mortality from the similar case. If the cases are not similar enough, ignore the mortality of the example and try to predict the mortality of the case just from its text report.\n"
        cot1nn_prompt += formatted_input

        payload = {'model': base_model,  # Explicit model to use
            'options': {
                'num_ctx': n_ctx * 1024,
                'temperature': temp,
                'seed': SEED,
                'top_k': top_k,
                'top_p': top_p
                },
            # NOTE(review): Ollama documents this request field as `keep_alive`; the hyphenated key is
            # left untouched - confirm whether unloading the model after every call is really intended
            'keep-alive': 0,
            'system': sprompt,
            'prompt': cot1nn_prompt,
            'stream': False,  # Wait and return all the result at once
            'format': o_type_format
        }

        # Sends query
        cot1nn_response = requests.post(instance, cookies=cookies, headers=headers, data=json.dumps(payload))
        # Stop on HTTP errors here; an error body has no usable 'response' field to parse
        cot1nn_response.raise_for_status()
        cot1nn_response_json = cot1nn_response.json()['response']
        responses[index] = json.loads(cot1nn_response_json) # Keeps the dictionary version of the json response

    # Export results (one row per test entry, one column per key of the LLM answer)
    df_responses = pd.DataFrame(responses).T
    temp_tag = str(temp).replace('.', '')
    test_id = f'{base_model}_{i_type_id}_{o_type_id}_{n_ctx}k_t{temp_tag}'
    df_responses.to_csv(res_fpath / f'{test_id}.csv')
    print('\n')

> [BASE MODEL]: llama3
>> Processing row 200 out of 200

