In [1]:
import json
import os
from itertools import product
from pathlib import Path

import pandas as pd
import requests


In [None]:
## Project root path
# Hacky way of finding the project's root path: walk up from the current working
# directory (the cwd itself included) looking for the project folder.
# Do not rely on this, set your own pjpath!
PROJECT_DIR_NAME = 'llms4mortality'

pjpath = next(
    (p for p in (Path.cwd(), *Path.cwd().parents) if p.stem == PROJECT_DIR_NAME),
    None,
)
if pjpath is None:
    # Fall back to the cwd as a Path, so later `pjpath / ...` joins keep working
    # (an empty-string default made them fail with a TypeError).
    pjpath = Path.cwd()
    print(f'(!) No parent folder named {PROJECT_DIR_NAME!r} found; using the current directory instead.')

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/medical-cbr


In [None]:
# Relevant paths
# Path(...) accepts pjpath both as a Path and as a plain string set by hand.
mimicpath = Path(pjpath) / 'data' / 'mimiciv'

In [None]:
# Globals

SEED = 42  # Used for the test subsample and as the LLM generation seed

# Controls which data to load (selects the precomputed cohort/split files by name)
samp_size = 20000
balanced_data = True

# Base model names; the queried model is '<base>_<input type>_<output type>_<span>'
base_models = [
    'iv_ll3_direct',
    'medgenius32b',
    'dsmedical8b',
    'biomistral7b',
]

# Controls which LLM inputs to build:
#   'R'  - (R)eport: patient data lines prepended to the report text
#   'RC' - (R)eport and (C)hart data: patient data as separate JSON fields plus the report
#   'S'  - (S)ummary: precomputed report summaries (these need to be precomputed!)
input_type = ['R', 'RC']

# Maps each type of output to the JSON schema requested in the query (Ollama `format` field)
output_type = {
    'M': {  # Just mortality
        'type': 'object',
        'properties': {
            'DIES': {'enum': ['YES', 'NO']},
        },
        'required': ['DIES'],
    },
    'PM': {  # Prognosis and mortality
        'type': 'object',
        'properties': {
            'PROGNOSIS': {'type': 'string'},
            'DIES': {'enum': ['YES', 'NO']},
        },
        'required': ['PROGNOSIS', 'DIES'],
    },
}

# Columns with the relevant patient info that are provided to the LLM together with the
# text report, remapped to more meaningful names for the prompt.
# The report text itself ('text') is added separately as 'REPORT'.
pdc_remap = {
    'age': 'AGE',
    'gender': 'GENDER',
    'marital_status': 'MARITAL STATUS',
    'race': 'RACE',
    'diagnose_group_description': 'BROAD DIAGNOSIS',
    'diagnose_group_mortality': 'MORTALITY RISK',
    'insurance': 'INSURANCE',
}

num_ctx = [16]               # Context lengths to try, in units of 1024 tokens
mortality_span = ['30days']  # How to ask the LLM for mortality: '30days' (count days) or 'amonth'
temps = [0.1]                # Temperatures to try; greater = more creative (default 0.1; also considered 0.0, 0.3, 0.7, 1.0)

# Fixed model params
top_k = 20
top_p = 0.5

# Long reports are middle-truncated to at most max_chars characters
# (the first and last max_chars // 2 characters are kept)
max_chars = 22000
# Number of test entries to sample; also part of the precomputed summaries' file name
subsamp_size = 100

# NOTE: use this only to limit the number of entries to process in the experiment.
#   Should be smaller than subsamp_size; any value <= 0 disables it.
minisample = -1

# Directory holding the precomputed note summaries (needed for input type 'S')
summaries_dir = f'{mimicpath}/summaries/'

# Columns in base dataframe to prepend to text
# NOTE(review): the query-building cell uses pdc_remap instead; this list looks unused there.
prepend_columns = ['age', 'gender', 'insurance', 'marital_status', 'race']

In [None]:
# Load the precomputed cohort (all columns are kept: the report text is needed later).
# delta_days_dod becomes the boolean target: True when the patient died within 30 days
# after discharge (0 < delta_days_dod <= 30); missing values become False.
data_suffix = f'S{samp_size}' + ('_balanced' if balanced_data else '')

df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_{data_suffix}.csv.gz')
df['delta_days_dod'] = df['delta_days_dod'].between(0, 30, inclusive='right')

# Load precomputed splits (hadm_ids per split name, e.g. 'test')
with open(mimicpath / f'hadmid_splits_{data_suffix}.json', 'r', encoding='utf-8') as ifile:
    splits_hadmids = json.load(ifile)

# Load sorted hadm_ids from disk
with open(mimicpath / f'hadmid_sorted_{data_suffix}.json', 'r', encoding='utf-8') as ifile:
    emb_hadmids = json.load(ifile)['HADM_ID']

In [None]:
# Load the test split (rows follow the hadm_id order stored in the split file)
df_test = df.set_index('hadm_id').loc[splits_hadmids['test']]

# Optional reproducible subsample. Clamped to the split size so that an oversized
# subsamp_size does not make DataFrame.sample raise a ValueError.
if subsamp_size:
    n_subsample = min(subsamp_size, len(df_test))
    print(f'> Subsampling test data to {n_subsample}...')
    df_test = df_test.sample(n_subsample, random_state=SEED)

# Middle-truncate long reports: keep the first and last max_chars // 2 characters.
# The tail uses an explicit start index because x[-0:] would keep the whole text.
half_chars = max_chars // 2
df_test['text'] = df_test['text'].apply(
    lambda x: x[:half_chars] + x[len(x) - half_chars:] if len(x) > max_chars else x
)

print(df_test.shape)
df_test.head()

In [7]:
# Optionally cap the number of test entries (a minisample <= 0 keeps all of them)
df_test = df_test.head(minisample) if minisample > 0 else df_test

In [None]:
# Define the Ollama instance and run every experiment configuration in batch:
# base model x input type x output format x (context length, mortality span, temperature).
# Parsed responses are kept in `tests` and exported to results/ex1/<test_id>.csv.

tests = {}  # test_id -> DataFrame of parsed responses (one row per test index)
instance = 'http://localhost:11434/api/generate'
# Optional oauth2-proxy cookie. Read from the environment so credentials are never
# hard-coded in the notebook (defaults to an empty cookie, as before).
auth_cookie = os.environ.get('OLLAMA_AUTH_COOKIE', '')

cookies = {'_oauth2_proxy': auth_cookie}
headers = {'Content-Type': 'application/json'}  # The request body is JSON, not form-encoded

results_dir = Path('results/ex1')
results_dir.mkdir(parents=True, exist_ok=True)  # to_csv fails if the folder is missing


def _patient_fields(row):
    """Return (prompt name, value) pairs for the patient-data columns listed in pdc_remap.

    Underscores in the values are replaced with spaces so they read naturally in the prompt.
    """
    return [(p_cremap, str(row[p_cname]).replace('_', ' ')) for p_cname, p_cremap in pdc_remap.items()]


def _build_input_queries(df, i_type):
    """Return a copy of ``df`` with an ``input_query`` column holding the JSON prompt.

    'R'  : patient data lines prepended to the report text, under a single 'REPORT' key.
    'RC' : patient data as separate JSON fields plus the report text under 'REPORT'.
    'S'  : precomputed summaries under 'REPORT'; entries without a summary fall back
           to the report text. ``df`` itself is never modified.

    Raises:
        NotImplementedError: for any other input type.
    """
    if i_type == 'R':
        df_i = df.copy()
        df_i['input_query'] = df_i.apply(
            lambda x: json.dumps({'REPORT': ''.join(f'{name}: {value}\n' for name, value in _patient_fields(x)) + x['text']}),
            axis=1,
        )
    elif i_type == 'RC':
        df_i = df.copy()
        df_i['input_query'] = df_i.apply(
            lambda x: json.dumps(dict(_patient_fields(x)) | {'REPORT': x['text']}),
            axis=1,
        )
    elif i_type == 'S':
        summaries_path = Path(summaries_dir) / f'summary_{max_chars}mc_ss{subsamp_size}.csv'
        df_summaries = pd.read_csv(summaries_path, index_col=0)

        # Compare against the summaries' index (comparing df_test with itself
        # meant missing summaries were never reported).
        missing_summaries = set(df.index) - set(df_summaries.index)
        if missing_summaries:
            print(f'(!) The following hadm_ids have no summary: {missing_summaries}')
            print('\tOriginal text will be used instead...')

        # Merge into a new frame: merging into df_test in place made the next pass
        # (next base model) produce SUMMARY_x/SUMMARY_y columns and fail.
        df_i = pd.merge(df, df_summaries, left_index=True, right_index=True, how='left')
        # Missing summaries fall back to the report text; then middle-truncate long texts
        # (explicit tail start index, since s[-0:] would keep the whole text).
        df_i['SUMMARY'] = df_i['SUMMARY'].fillna(df_i['text']).apply(
            lambda s: s if len(s) <= max_chars else s[:max_chars // 2] + s[len(s) - max_chars // 2:]
        )
        df_i['input_query'] = df_i['SUMMARY'].apply(lambda s: json.dumps({'REPORT': s}))
    else:
        raise NotImplementedError(f'Unknown input type: {i_type!r}')
    return df_i


def _query_ollama(model, prompt, o_format, n_ctx, temp):
    """POST one non-streaming generation request and return the parsed JSON answer.

    Raises:
        requests.HTTPError: if the instance answers with an error status (e.g. unknown model).
    """
    payload = {
        'model': model,  # Explicit model to use
        'options': {
            'num_ctx': n_ctx * 1024,
            'temperature': temp,
            'seed': SEED,
            'top_k': top_k,
            'top_p': top_p,
        },
        'keep_alive': -1,  # Negative value: keep the model loaded (key was misspelled 'keep-alive')
        'prompt': prompt,
        'stream': False,  # Wait and return the whole result at once
        'format': o_format,  # JSON schema the answer must follow
    }
    response = requests.post(instance, cookies=cookies, headers=headers, data=json.dumps(payload))
    # Surface HTTP errors instead of an obscure KeyError on 'response'
    response.raise_for_status()
    # The 'response' field is itself a JSON string; keep its dictionary version
    return json.loads(response.json()['response'])


# Inputs do not depend on the model, so build them once (and fail fast on unknown types)
input_queries = {i_type: _build_input_queries(df_test, i_type) for i_type in input_type}

for base_model in base_models:
    print(f'> [BASE MODEL]: {base_model}')

    for i_type_id, df_test_i in input_queries.items():
        print(f'> [INPUT_TYPE]: {i_type_id}')

        # NOTE(review): input types R and S were described as sharing the same model, yet the
        # model name uses the input type as-is (S -> '..._S_...'). Confirm which models exist.
        i_type_id_model = i_type_id

        for o_type_id, o_type_format in output_type.items():
            for n_ctx, m_span, temp in product(num_ctx, mortality_span, temps):
                # Resolves model. Models need to exist in the online instance!
                model = f'{base_model}_{i_type_id_model}_{o_type_id}_{m_span}'
                print(f'> [OUTPUT_TYPE]: {o_type_id}, ({n_ctx} ctx, {m_span} span, {temp} temp)')

                responses = {}  # test index -> parsed answer dict
                for i, (index, row) in enumerate(df_test_i.iterrows(), start=1):
                    print(f'> Processing input {i} out of {len(df_test_i)}...', end='\r')
                    responses[index] = _query_ollama(model, row['input_query'], o_type_format, n_ctx, temp)

                # Export results
                df_responses = pd.DataFrame(responses).T
                temp_id = str(temp).replace('.', '')
                test_id = f'{base_model}_{i_type_id}_{o_type_id}_{m_span}_{n_ctx}k_t{temp_id}'
                df_responses.to_csv(results_dir / f'{test_id}.csv')
                tests[test_id] = df_responses
                print('\n')