In [3]:
import json
import requests
import pandas as pd
from pathlib import Path

from itertools import product


In [4]:
# NOTE: (On report length and limitations of LLMs)
"""
    While we can control the context size for most ollama models tampering with num_ctx, this is not
    the full story. num_ctx includes the size of the answer given by the model, so
    when dealing with complex and long prompts this might lead to incomplete, short or incoherent answers.
    We opted for truncating long reports in favor of avoiding this issues. max_chars parameter controls this truncation
"""

'\n    While we can control the context size for most ollama models tampering with num_ctx, this is not\n    the full story. num_ctx includes the size of the answer given by the model, so\n    when dealing with complex and long prompts this might lead to incomplete, short or incoherent answers.\n    We opted for truncating long reports in favor of avoiding this issues. max_chars parameter controls this truncation\n'

In [5]:
## Project root path
pjpath = ''

# Hacky way of finding the project's root path. Do not rely on this, set your own pjpath!
for p in Path.cwd().parents:
    if p.stem == 'llms4mortality':
        pjpath = p
        break

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/llms4mortality


In [6]:
# Relevant paths
mimicpath = pjpath / 'data/mimiciv'

In [7]:
# Globals

SEED = 42

# Controls which data to load
samp_size = 5000
balanced_data = True
target_split = False  # Split set from loaded dataframe to generate summaries from. If None, we disregard inner splits and take entries directly from the main loaded dataframe

base_models = [
    'll3',
    #'medgenius32b',
    #'dsmedical8b',
    #'biomistral7b'
]

# Slice long input: Just keep up to max_words of each text
max_chars = 22000
subsamp_size = False  # 200, 100 Number of entries to test model with. Or False to disregard it

# This is the collection of columns that contains the relevant patient info
#   Values remap column name to an alternative and more readable name (might be useful if using LLMs)
pdc_remap = {
    'age': 'AGE',
    'gender': 'GENDER',
    'marital_status': 'MARITAL STATUS',
    'race': 'RACE',
    'diagnose_group_description': 'BROAD DIAGNOSIS',
    'diagnose_group_mortality': 'MORTALITY RISK',
    'insurance': 'INSURANCE',
    #'text': 'REPORT'
}

prepend_extended_patient_data = True    # If set True, preprends additional categorical data from pdf_remap to the beginning of the text entry for each patient

# Ollama hyperparams
n_ctx = 32   # Context length (x 1024)
temp = 0.0 # Temperature option for the LLM. The greater, the more creative the answer (def 0.1)
top_k = 20
top_p = 0.5


In [8]:
# Load precomputed dataframe.
df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}{'_balanced' if balanced_data else ''}.csv.gz')

In [9]:
# Load precomputed dataframe.
df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}{'_balanced' if balanced_data else ''}.csv.gz')

if target_split:
    print(f'>> (i) Performing summarization only on entries from {target_split} split...')
    # Load precomputed splits
    with open(mimicpath / f'hadmid_splits_S{samp_size}{'_balanced' if balanced_data else ''}.json', 'r') as ifile:
        splits_hadmids = json.load(ifile)

    # Loads target split
    df = df.set_index('hadm_id').loc[splits_hadmids['test']]

# Do further subsamplig (do this just to speed up computations)
if subsamp_size:
    print(f'>> (!) Subsampling total entries to {subsamp_size}!')
    print(f'> Subsampling test data to {subsamp_size}...')
    df = df.sample(subsamp_size, random_state=SEED)

# Preprends additional data to the text entries
if prepend_extended_patient_data:
    print(f'>> (!) Preprending additional patient data to text!')
    df['text'] = df.apply(lambda x: ''.join([f'{p_cremap}: {str(x[p_cname]).replace('_', ' ')}\n' for p_cname, p_cremap in pdc_remap.items()]) + x['text'], axis=1)

# We are only interested in texts (and hadm_id ad index)
df = df.set_index('hadm_id')['text'].to_frame()

>> (!) Preprending additional patient data to text!


In [10]:
df.head()

Unnamed: 0_level_0,text
hadm_id,Unnamed: 1_level_1
21891113,AGE: 68\nGENDER: F\nMARITAL STATUS: DIVORCED\n...
29643114,AGE: 96\nGENDER: M\nMARITAL STATUS: WIDOWED\nR...
26747385,AGE: 47\nGENDER: F\nMARITAL STATUS: MARRIED\nR...
23932127,AGE: 64\nGENDER: M\nMARITAL STATUS: MARRIED\nR...
27210508,AGE: 72\nGENDER: M\nMARITAL STATUS: MARRIED\nR...


In [None]:
instance = 'http://localhost:11434/api/generate'
auth_cookie = ''

for base_model in base_models:
    print(f'> [BASE MODEL]: {base_model}')

    model = f'{base_model}_summarizer'

    responses = {}
    i=1

    # Resolves responses disk path
    summary_id = f'summary_S{str(samp_size)}{'_balanced' if balanced_data else ''}{'_sp' + target_split if target_split else ''}_{base_model}_mc{str(max_chars)}{'_ss' + str(subsamp_size) if subsamp_size else ''}'
    summary_path = Path(f'{mimicpath}/{summary_id}.csv')

    if summary_path.is_file():
        # Loads existing file and assumes it contains the same data structure as the generated output (ie, is a compatible dataframe)
        print(f'>> (i) Target file already exists. Parsing contents and updating entries to process...')
        
        _df_existing_responses = pd.read_csv(summary_path, index_col=0)
        assert list(_df_existing_responses.columns) == ['SUMMARY']

        precomputed_indices = _df_existing_responses.index
        print(f'>> (i) {len(precomputed_indices)} indices were found in precomputed results file and will be ommitted from current execution...')
        df = df.loc[list(set(df.index) - set(precomputed_indices))]


    else:
        # Initializes empty df where summaries will be saved online:
        pd.DataFrame(columns=['SUMMARY']).to_csv(summary_path, mode='w', header=True)

    for index, row in df.iterrows():
        print(f'>> Processing row {i} out of {len(df)}', end='\r')

        # Get text from entry
        text = row['text']

        # Truncate middle if resulting text is longer than max_chars
        if len(text) > max_chars:
            print(f'>> (!) Text exceeds the max char limit ({len(text)}) in entry {index}. Middle-truncating to {max_chars}...')
            text = text[:(max_chars//2)] + text[-(max_chars//2):]
            print(f'\t... Result truncate: {len(text)}')

        formatted_input = json.dumps({'REPORT': text})
        
        data = {'model': model,  # Explicit model to use
                'options': {
                                'num_ctx': n_ctx * 1024,
                                'temperature': temp, # 0?
                                'seed': SEED,
                                'top_k': top_k,
                                'top_p': top_p
                                },
                'keep-alive': -1,  # Keep connection open
                'prompt': formatted_input,
                'stream': False,  # Wait and return all the result at once
                'format': {  # Prognosis and mortality      
                'type': 'object',
                'properties': {
                    'SUMMARY': {
                        'type': 'string'
                    }
                },
                'required': [
                    'SUMMARY'
                ]
                }
            }
        # Prepares query
        data = json.dumps(data)
        cookies = {
            '_oauth2_proxy': auth_cookie}
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
        }

        response = requests.post(instance, cookies=cookies, headers=headers, data=data)
        response = json.loads(response.text)['response']
        #responses[index] = json.loads(response) # Keeps the dictionary version of the json response

        # Save online
        df_response = pd.Series({index: json.loads(response)['SUMMARY']}).to_frame()
        df_response.to_csv(summary_path, mode='a', header=False)
        i+=1