In [1]:
import json
import requests
import pandas as pd
from pathlib import Path
from itertools import product

In [2]:
## Project root path
# Name of the project's root directory, searched upwards starting at the current working directory.
# This is a convenience heuristic: if it picks the wrong place, set `pjpath` explicitly (as a Path).
PROJECT_DIRNAME = 'llms4mortality'

cwd_path = Path.cwd()
# `Path.parents` does not include the path itself, so cwd is checked first (e.g. running from the root)
pjpath = next((p for p in (cwd_path, *cwd_path.parents) if p.name == PROJECT_DIRNAME), None)

if pjpath is None:
    # Keep pjpath a Path so later `pjpath / ...` joins still work; paths then resolve under cwd
    pjpath = cwd_path
    print(f'(!) No "{PROJECT_DIRNAME}" directory found above {cwd_path}; using it as the project path.')

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/llms4mortality


In [3]:
# Relevant paths
# Path() also accepts a plain string, so this works when pjpath was set by hand as a str
mimicpath = Path(pjpath) / 'data' / 'mimiciv'

In [None]:
# Globals
#
# Experiment configuration: data selection, models to query, input/output formats and sampling options.

SEED = 42  # Used both for test-set subsampling and as the LLM sampling seed

## Controls which data file to load
samp_size = 5000        # Sample size, encoded as S<samp_size> in the data file names
balanced_data = True    # Selects the '_balanced' variant of the data files
max_chars = 22000       # Anything longer than this will be middle-truncated
ss_size = 200           # Test-set subsample size (falsy -> use the whole test split)

## Controls which LLM model to fire (uncomment to add more base models)
base_models = [
    'll3',
    # 'medgenius32b',
    # 'dsmedical8b',
    # 'biomistral7b',
]

# Input formats: (S)ummary (must be precomputed!), (R)eport with patient data prepended as text,
# and (R)eport plus patient/chart data (C) as separate keys of a JSON object
input_type = ['S', 'R', 'RC']

# Output types, each mapped to the JSON schema requested as the answer format of the query
output_type = {
    'M': {  # Just mortality
        'type': 'object',
        'properties': {'DIES': {'enum': ['YES', 'NO']}},
        'required': ['DIES'],
    },
    'PM': {  # Prognosis and mortality
        'type': 'object',
        'properties': {
            'PROGNOSIS': {'type': 'string'},
            'DIES': {'enum': ['YES', 'NO']},
        },
        'required': ['PROGNOSIS', 'DIES'],
    },
}

## LLM params. num_ctx and temps admit multiple values; every combination runs as a separate experiment
num_ctx = [32]      # Context length, in multiples of 1024
temps = [0.1]       # Sampling temperature(s); higher is more creative. Values tried: 0.0, 0.1 (default), 0.3, 0.7, 1.0
top_k = 20
top_p = 0.5

## Additional data prepro options
# Patient columns given to the LLM together with the text report, renamed to more meaningful prompt labels
pdc_remap = {
    'age': 'AGE',
    'gender': 'GENDER',
    'marital_status': 'MARITAL STATUS',
    'race': 'RACE',
    'diagnose_group_description': 'BROAD DIAGNOSIS',
    'diagnose_group_mortality': 'MORTALITY RISK',
    'insurance': 'INSURANCE',
    # 'text': 'REPORT',
}

In [5]:
# Load precomputed dataframe (all columns are kept) and turn `delta_days_dod` into the binary target:
# True iff the patient died within MORTALITY_WINDOW_DAYS after discharge (0 < delta <= window).
# NaN deltas (e.g. no recorded date of death) compare False and therefore become False.
MORTALITY_WINDOW_DAYS = 30
data_tag = f'S{samp_size}{"_balanced" if balanced_data else ""}'  # Shared part of the data file names

df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_{data_tag}.csv.gz')
delta_days = df['delta_days_dod']
df['delta_days_dod'] = (delta_days > 0) & (delta_days <= MORTALITY_WINDOW_DAYS)

# Load precomputed hadm_id splits (the 'test' entry is used below)
with open(mimicpath / f'hadmid_splits_{data_tag}.json', 'r', encoding='utf-8') as ifile:
    splits_hadmids = json.load(ifile)

# Load sorted hadm_ids from disk (stored under the 'HADM_ID' key)
with open(mimicpath / f'hadmid_sorted_{data_tag}.json', 'r', encoding='utf-8') as ifile:
    emb_hadmids = json.load(ifile)['HADM_ID']

In [6]:
def middle_truncate(text, limit):
    """Returns `text` cut down to at most `limit` characters by dropping its middle part.

    Texts no longer than `limit` are returned unchanged. The head keeps `limit // 2` characters and
    the tail keeps the rest, so the result is exactly `limit` characters long even for odd limits.
    """
    if len(text) <= limit:
        return text
    head = limit // 2
    # Slice from an explicit start index: `text[-0:]` would return the whole string
    return text[:head] + text[len(text) - (limit - head):]


# Loads the test split and, when `ss_size` is set, draws a reproducible subsample of it
df_test = df.set_index('hadm_id').loc[splits_hadmids['test']]
if ss_size:
    # Sampling is without replacement: clamp so a size larger than the split does not raise
    df_test = df_test.sample(min(ss_size, len(df_test)), random_state=SEED)

# Middle-truncates long reports to at most `max_chars` characters
df_test['text'] = df_test['text'].apply(lambda x: middle_truncate(x, max_chars))

print(df_test.shape)
df_test.head()

(200, 19)


Unnamed: 0_level_0,note_id,subject_id,charttime,text,gender,dod,anchor_age,anchor_year,admittime,admission_type,insurance,marital_status,race,diagnose_group_description,drg_mortality,diagnose_group_mortality,drg_code,age,delta_days_dod
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
21647706,17450544-DS-18,17450544,2161-05-11,\nName: ___ Unit No: ___\...,M,2161-05-13,76,2161,2161-05-07,EW EMER.,Other,WIDOWED,WHITE,"INTESTINAL OBSTRUCTION, G.I. OBSTRUCTION W CC",3,HIGH,"[247, 389]",76,True
25789302,16492756-DS-19,16492756,2186-02-12,\nName: ___ Unit No: ___\n...,M,2186-05-21,69,2185,2186-01-20,DIRECT EMER.,Medicare,MARRIED,WHITE,"SEPTICEMIA & DISSEMINATED INFECTIONS, SEPTICEM...",2,MODERATE,"[720, 871]",70,False
29954698,11324139-DS-25,11324139,2191-12-07,\nName: ___ Unit No: ___\n...,F,2192-01-20,60,2182,2191-11-28,OBSERVATION ADMIT,Medicare,MARRIED,WHITE,"ACUTE KIDNEY INJURY, RENAL FAILURE W CC",3,HIGH,"[469, 683]",69,False
21699201,18218643-DS-15,18218643,2158-06-27,\nName: ___ Unit No: ___\n \...,M,,18,2158,2158-06-24,ELECTIVE,Other,SINGLE,WHITE,"DEPRESSION EXCEPT MAJOR DEPRESSIVE DISORDER, D...",1,LOW,"[754, 881]",18,False
28251820,13387877-DS-19,13387877,2152-05-26,\nName: ___ Unit No: __...,M,2152-06-07,60,2151,2152-05-24,DIRECT EMER.,Other,MARRIED,WHITE - OTHER EUROPEAN,"OTHER RESPIRATORY DIAGNOSES EXCEPT SIGNS, SYMP...",2,MODERATE,"[143, 189]",61,True


In [None]:
# Sets ollama instance and run in batch

tests = {}  # Will keep raw responses
instance = 'http://localhost:11434/api/generate'
auth_cookie = ''

res_fpath = Path(f'{pjpath}/exps/results/llms/simple')
res_fpath.mkdir(parents=True, exist_ok=True)
summaries_fpath = mimicpath / 'summaries'

for base_model in base_models:
    print(f'> [BASE MODEL]: {base_model}')

    summary_mod = base_model

    for i_type in input_type:

        i_type_id = i_type

        if i_type == 'R':
            # Raw report notes as input
            # Preprends patient data to text. Replaces underscore with spaces in both, feature name and value
            df_test_i = df_test.copy()
            df_test_i['input_query'] = df_test_i.apply(lambda x: json.dumps({'REPORT': ''.join([f'{p_cremap}: {str(x[p_cname]).replace('_', ' ')}\n' for p_cname, p_cremap in pdc_remap.items()]) + x['text']}), axis=1)

        elif i_type == 'S':            
            ## Load summary
            sum_model = 'll3'
            summary_id = f'summary_S{samp_size}{'_balanced' if balanced_data else ''}_{sum_model}_mc{max_chars}'
            df_summaries = pd.read_csv(f'{summaries_fpath}/{summary_id}.csv', index_col=0)

            missing_summaries = set(df_test.index) - set(df_test.index)
            if len(missing_summaries) > 0:
                print(f'(!) The following hadm_ids have no summary: {missing_summaries}')
                print(f'\tOriginal text will be used insted...')

            # Updates dataframe with summaries (only with the ones available)
            df_test = pd.merge(df_test, df_summaries, left_index=True, right_index=True, how='left')
            df_test['SUMMARY'] = df_test.apply(lambda x: x['text'] if x['SUMMARY'] != x['SUMMARY'] else x['SUMMARY'], axis=1)

            # Loads summaries middle-truncating long texts...
            df_test['SUMMARY'] = df_test['SUMMARY'].apply(lambda x: x[:(max_chars//2)] + x[-(max_chars//2):] if len(x) > max_chars else x)

            df_test_i = df_test.copy()
            df_test_i['input_query'] = df_test_i['SUMMARY'].apply(lambda x: json.dumps({'REPORT': x}))

        elif i_type == 'RC':
            # Input as json with relevant patient data in separate fields
            df_test_i = df_test.copy()
            df_test_i['input_query'] = df_test_i.apply(lambda x: json.dumps({p_cremap: str(x[p_cname]).replace('_', ' ') for p_cname, p_cremap in pdc_remap.items()} | {'REPORT': x['text']}), axis=1)

        else:
            print(i_type)
            raise NotImplementedError

        print(f'> [INPUT_TYPE]: {i_type}')

        for o_type_id, o_type_format in output_type.items():
            for n_ctx, temp in product(num_ctx, temps):

                # Resolves model
                # # Models need to be ready in the online instance!
                model = f'{base_model}_{i_type}_{o_type_id}'

                print(f'> [OUTPUT_TYPE]: {o_type_id}, ({n_ctx} ctx, {temp} temp)')

                responses = {}
                i=1
                for index, row in df_test_i.iterrows():
                    # Iterates every query in the test set
                    print(f'> Processing input {i} out of {len(df_test_i)}...', end='\r')

                    fromatted_input = row['input_query']
                    data = {'model': model,  # Explicit model to use
                            'options': {
                                'num_ctx': n_ctx * 1024,
                                'temperature': temp, # 0?
                                'seed': SEED,
                                'top_k': top_k,
                                'top_p': top_p
                                },
                            'keep-alive': -1,  # Keep connection open
                            'prompt': fromatted_input,
                            'stream': False,  # Wait and return all the result at once
                            'format': o_type_format,
                        }
                    # Prepares query
                    data = json.dumps(data)
                    cookies = {
                        '_oauth2_proxy': auth_cookie}
                    headers = {
                        'Content-Type': 'application/x-www-form-urlencoded',
                    }

                    response = requests.post(instance, cookies=cookies, headers=headers, data=data)
                    json_response = json.loads(response.text)['response']
                    responses[index] = json.loads(json_response) # Keeps the dictionary version of the json response
                    i+=1

                # Export results
                df_responses = pd.DataFrame(responses).T
                test_id = f'{base_model}_{i_type}_{o_type_id}_{n_ctx}k_t{str(temp).replace('.', '')}'
                df_responses.to_csv(res_fpath / f'{test_id}.csv')
                tests[test_id] = df_responses
                print(f'\n')

> [BASE MODEL]: ll3
> [INPUT_TYPE]: S
> [OUTPUT_TYPE]: PM, (32 ctx, 0.1 temp)
> Processing input 16 out of 200...