In [1]:
# CoT mortality prediction using precomputed summaries of both the test entries and their nearest-neighbour examples

In [2]:
import json
import requests
import pandas as pd
from pathlib import Path

In [None]:
## Project root path
# Set pjpath to your project's root (str or Path) to skip the auto-detection below.
pjpath = None
PROJECT_DIRNAME = 'llms4mortality'

if pjpath is None:
    # Hacky way of finding the project's root path: walk up from the working directory
    # (including the working directory itself) until a folder named PROJECT_DIRNAME appears.
    # Do not rely on this, set your own pjpath!
    cwd = Path.cwd()
    pjpath = next((p for p in (cwd, *cwd.parents) if p.name == PROJECT_DIRNAME), None)
    if pjpath is None:
        # Fail here with a clear message instead of a confusing TypeError when paths are joined later.
        raise FileNotFoundError(
            f'Could not find a {PROJECT_DIRNAME!r} directory at or above {cwd}; set pjpath manually.'
        )

# Later cells join paths with `/`, so normalise a manually-set str to a Path.
pjpath = Path(pjpath)

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud-UPM/CBR/medical-cbr


In [None]:
# Relevant paths: directory holding the MIMIC data used by this notebook (<project root>/data/mimiciv)
mimicpath = pjpath / 'data' / 'mimiciv'

In [None]:
# Globals

SEED = 42  # Sent to the LLM as its sampling seed

# Controls which data to load
# NOTE(review): samp_size and balanced_data are not referenced in the cells shown here;
# presumably kept for parity with the data-preparation notebooks — confirm.
samp_size = 5000
balanced_data = True

# Base model names. Each is combined with the input/output type ids to form the name of the
# model queried on the server (e.g. 'll3_S_M'); that model needs to be available there.
base_models = [
    'll3',
]

# Slice long input: keep at most max_chars characters of each text (middle-truncated).
# Also part of the precomputed summary file names.
max_chars = 22000
subsamp_size = 200  # Number of entries to test model with (also part of the summary file names)

# LLM parameters
# 'S' = precomputed (S)ummary of the report. Other input codes mentioned previously:
# (R)eport, (R)eport and (C)hart data as json — TODO confirm their exact ids.
input_type = 'S'
# Maps each type of output to the JSON schema requested from the LLM (sent as the query's 'format')
output_type = {
    'M': {  # Just mortality
        'type': 'object',
        'properties': {
            'DIES': {
                'enum': ['YES', 'NO'],
            },
        },
        'required': [
            'DIES',
        ],
    },
}

# This is the collection of columns that contains the relevant patient info that can be provided
# to the LLM with the text report, remapped to more meaningful names for the prompt.
# The report text itself is sent separately, under the 'REPORT' key.
# NOTE(review): not referenced in the cells shown here (the prompt only carries summaries).
pdc_remap = {
    'age': 'AGE',
    'gender': 'GENDER',
    'marital_status': 'MARITAL STATUS',
    'race': 'RACE',
    'diagnose_group_description': 'BROAD DIAGNOSIS',
    'diagnose_group_mortality': 'MORTALITY RISK',
    'insurance': 'INSURANCE',
}

# n_ctx and temp also appear in the results file name.
n_ctx = 32   # Context length, in units of 1024 tokens (sent as num_ctx = n_ctx * 1024)
temp = 0.0   # Sampling temperature: the greater, the more varied the answer; 0.0 requests greedy decoding
top_k = 20
top_p = 0.5

In [6]:
# Load precomputed summaries (both from test entries and their neighbours).
# Both files are read with their first column as index; the loop below looks up each
# test entry's neighbour in df_neigh_summaries by that shared index.
summaries_dir = Path(mimicpath) / 'summaries'
summary_id = f'summary_{max_chars}mc_ss{subsamp_size}'
neigh_summary_id = f'neighbour_summary_{max_chars}mc_ss{subsamp_size}'
df_test_summaries = pd.read_csv(summaries_dir / f'{summary_id}.csv', index_col=0)
df_neigh_summaries = pd.read_csv(summaries_dir / f'{neigh_summary_id}.csv', index_col=0)

# Sanity checks: fail here, before any LLM call, rather than midway through a long run.
assert 'SUMMARY' in df_test_summaries.columns, f'{summary_id}.csv has no SUMMARY column'
assert {'SUMMARY', 'DIES'} <= set(df_neigh_summaries.columns), \
    f'{neigh_summary_id}.csv needs SUMMARY and DIES columns'
# Duplicate labels would make .loc[index] return a frame instead of a single row.
assert df_neigh_summaries.index.is_unique, f'{neigh_summary_id}.csv has duplicate index values'
missing_neigh = df_test_summaries.index.difference(df_neigh_summaries.index)
assert missing_neigh.empty, \
    f'{len(missing_neigh)} test entries have no neighbour summary, e.g. {list(missing_neigh[:5])}'

In [None]:
# Fire instance: for every base model, query the local LLM server once per test entry with a
# one-shot (1-NN) prompt made of the entry's neighbour summary, that neighbour's DIES label and
# finally the entry's own summary; the parsed JSON answers are exported to CSV per model.
# NOTE: Assuming fixed input type (S) and output (M)
i_type_id = input_type
o_type_id = 'M'
o_type_format = output_type[o_type_id]

instance = 'http://localhost:11434/api/generate'
auth_cookie = ''
cookies = {'_oauth2_proxy': auth_cookie}  # Same for every request, so built once

results_dir = Path('../results/ex2')
results_dir.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories


def middle_truncate(text, limit):
    """Return ``text`` cut to at most ``limit`` chars by dropping its middle part.

    Text that already fits is returned unchanged. The head keeps the extra
    character when ``limit`` is odd, so the result is exactly ``limit`` long.
    """
    if len(text) <= limit:
        return text
    tail_len = limit // 2
    head_len = limit - tail_len
    # Slice the tail by absolute position: text[-0:] would return the whole string.
    return text[:head_len] + text[len(text) - tail_len:]


def truncate_for_prompt(text, what, index):
    """Middle-truncate ``text`` to ``max_chars``, logging when truncation happens.

    ``what`` names the text in the log line ('Text' / 'Neighbour text');
    ``index`` is the test entry being processed.
    """
    if len(text) > max_chars:
        print(f'(i) {what} exceeds the max char limit ({len(text)}) in entry {index}. Middle-truncating to {max_chars}...')
        text = middle_truncate(text, max_chars)
        print(f'... Result truncate: {len(text)}')
    return text


def build_cot1nn_prompt(text, neigh_text, neigh_dies):
    """Build the one-shot prompt: neighbour report, its solution, then the report to classify."""
    return (
        "\nBelow is an example of the medical text report from which you will have to decide if the patient will die within 30 days of their medical discharge:\n"
        + json.dumps({'REPORT': neigh_text})
        + "\nAnd here's its solution:\n"
        + json.dumps({'DIES': neigh_dies})
        + "\nNow decide if the following text report corresponds to a patient who is likely to die within 30 days of medical discharge as per your main instructions, and considering that the previous example corresponds to a medical case that is similar to the one you have to resolve now:\n"
        + json.dumps({'REPORT': text})
    )


for base_model in base_models:
    print(f'> [BASE MODEL]: {base_model}')
    model = f'{base_model}_{i_type_id}_{o_type_id}'  # Needs to exist on the server
    responses = {}

    for i, (index, row) in enumerate(df_test_summaries.iterrows(), start=1):
        print(f'>> Processing row {i} out of {len(df_test_summaries)}', end='\r')

        text = truncate_for_prompt(row['SUMMARY'], 'Text', index)

        # Neighbour (similar example) of this entry, looked up by the shared index
        neigh_row = df_neigh_summaries.loc[index]
        neigh_text = truncate_for_prompt(neigh_row['SUMMARY'], 'Neighbour text', index)
        neigh_dies = neigh_row['DIES']

        payload = {
            'model': model,  # Explicit model to use
            'options': {
                'num_ctx': n_ctx * 1024,
                'temperature': temp,
                'seed': SEED,
                'top_k': top_k,
                'top_p': top_p,
            },
            # Ollama's field is `keep_alive` (the old 'keep-alive' key was ignored);
            # a negative value keeps the model loaded in memory between requests.
            'keep_alive': -1,
            'prompt': build_cot1nn_prompt(text, neigh_text, neigh_dies),
            'stream': False,  # Wait and return all the result at once
            'format': o_type_format,  # JSON schema the answer must follow
        }

        # json= serialises the payload and sets Content-Type: application/json
        cot1nn_response = requests.post(instance, cookies=cookies, json=payload)
        cot1nn_response.raise_for_status()  # Surface HTTP errors instead of a KeyError below
        cot1nn_response_json = cot1nn_response.json()['response']
        responses[index] = json.loads(cot1nn_response_json)  # Keeps the dictionary version of the json response

    # Export results: one row per test entry, one column per answer key
    df_responses = pd.DataFrame.from_dict(responses, orient='index')
    # Double-quoted outer string: reusing ' inside an f-string is a SyntaxError before Python 3.12
    test_id = f"{base_model}_{i_type_id}_{o_type_id}_{n_ctx}k_t{str(temp).replace('.', '')}"
    df_responses.to_csv(results_dir / f'{test_id}.csv')
    print('\n')

> [BASE MODEL]: iv_ll3_direct
>> Processing row 100 out of 100

> [BASE MODEL]: medgenius32b
>> Processing row 100 out of 100

> [BASE MODEL]: dsmedical8b
>> Processing row 100 out of 100

> [BASE MODEL]: biomistral7b
>> Processing row 100 out of 100

