In [1]:
import pandas as pd
import numpy as np
import pickle
from sentence_transformers import SentenceTransformer
from pathlib import Path
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Project root path
# Best-effort lookup of the project's root folder. Do not rely on this, set your own pjpath!
pjpath = None   # Set this to your project root (str or Path) to skip the automatic lookup

if pjpath is None:
    cwd = Path.cwd()
    # Path.parents does not include the path itself, so the current directory is checked explicitly;
    # otherwise running the notebook from the project root would never match.
    for p in (cwd, *cwd.parents):
        if p.stem == 'llms4mortality':
            pjpath = p
            break
    else:
        # No match: fall back to the current directory so later path joins still work,
        # and make the situation visible instead of silently carrying an empty string around.
        pjpath = cwd
        print("> (!) No 'llms4mortality' folder found in the current path; falling back to the current directory")

# Always expose a Path so that `pjpath / 'subdir'` works regardless of how pjpath was set
pjpath = Path(pjpath)

print(f'> Project path is {pjpath}')

> Project path is /home/daucco/ownCloud/unsync/_entregabledata/llms4mortality


In [3]:
# Globals

# --- Sampling: these have to agree with the precomputed dataframe on disk
samp_size = 5000
balanced_data = True

# --- Summaries: only relevant when embeddings are computed from summaries
#     (the summaries need to be precomputed and available on disk)
summaries = False        # When set, tries to load summaries from the summary path and use them as the text column
target_split = False     # Name of the split to get the embeddings from, or False to get embeddings from the whole dataframe
base_summ_model = 'll3'  # ollama summarization model. Only relevant when "summaries" is set
max_chars = 22000        # Max chars used when generating summaries with ollama
subsamp_size = False     # When set, uses a subsample (over samp_size)

# --- Models: each entry is (model_name, batch_size, truncation_side)
# Candidates:
#   brandonhcheung04/bart                       fine-tuned facebook/bart-base for abstractive summarization of clinical notes (MIMIC-IV)
#   Simonlee711/Clinical_ModernBERT             encoder-based transformer for biomedical/clinical text, context length up to 8192 tokens
#   xyla/Clinical-T5                            T5 variants on the union of MIMIC-III and MIMIC-IV - NEED TO DOWNLOAD THE MODELS SEPARATELY!
#   all-distilroberta-v1                        non-specific
#   medicalai/ClinicalBERT                      healthcare-specific
#   emilyalsentzer/Bio_Discharge_Summary_BERT   MIMIC-III discharge notes
#   nazyrova/clinicalBERT                       MIMIC-IV discharge notes
modnames = [
    ## Big models:
    #('Simonlee711/Clinical_ModernBERT', 4, 'left'),
    #(f'{pjpath}/src/models/Clinical-T5-Scratch', 1, 'right')            # Too big for a commercial gpu
    #('brandonhcheung04/bart', 32, 'left'),

    ## Reasonable models
    #('emilyalsentzer/Bio_Discharge_Summary_BERT', 256, 'right'),
    #('all-distilroberta-v1', 256, 'right'),
    #('medicalai/ClinicalBERT', 256, 'right'),
    ('nazyrova/clinicalBERT', 256, 'right'),  # Found best
]

# --- Encoding
# Where to apply truncation when encoding: left, right or middle.
# If middle, applies left & right and concatenates results (2 passes).
# NOTE(review): the encoding loop in this notebook takes the truncation side from each
# modnames entry and does not read this variable - confirm whether it is still needed.
truncation_side = 'middle'

# Columns in the base dataframe whose values are prepended to the text
prepend_columns = [
    'age',
    'gender',
    'insurance',
    'marital_status',
    'race',
    'diagnose_group_description',
    'diagnose_group_mortality',
]

In [4]:
# Set this to your MIMIC-IV path where discharge, patients and admissions tables are located.
# pjpath is wrapped in Path() so the join also works when pjpath is a plain string
# (str / str raises TypeError).
mimicpath = Path(pjpath) / 'data/mimiciv'

In [5]:
# Load precomputed dataframe. Keeps only hadm_id, text and the columns that get prepended to the text.
# File-name tags are built up front: nesting the same quote type inside an f-string
# only parses on Python >= 3.12, so the pieces are kept out of the f-strings.
balanced_tag = '_balanced' if balanced_data else ''
df = pd.read_csv(mimicpath / f'mimiciv_4_mortality_S{samp_size}{balanced_tag}.csv.gz')[['hadm_id', 'text', *prepend_columns]]

# Expected summary name
split_tag = f'_sp{target_split}' if target_split else ''
subsamp_tag = f'_ss{subsamp_size}' if subsamp_size else ''
summary_id = f'summary_S{samp_size}{balanced_tag}{split_tag}_{base_summ_model}_mc{max_chars}{subsamp_tag}'

if summaries:
    print('> (!) Using summaries')
    # Use precomputed summaries instead of text. Load them from disk
    df_summ = pd.read_csv(mimicpath / f'summaries/{summary_id}.csv', index_col=0)
    df_summ.index = df_summ.index.rename('hadm_id')

    # Inner join: admissions without a matching summary row are dropped
    n_rows_before = len(df)
    df = pd.merge(df, df_summ, on='hadm_id', how='inner')
    if len(df) != n_rows_before:
        print(f'> (!) Row count changed from {n_rows_before} to {len(df)} after merging summaries')

    df['text'] = df['SUMMARY']


def _patient_header(row):
    """Builds the 'name: value' lines (one per column in prepend_columns) for a dataframe row.

    Underscores are replaced with spaces in both the column name and its value.
    Returns an empty string when prepend_columns is empty.
    """
    return ''.join(
        f"{col.replace('_', ' ')}: {str(row[col]).replace('_', ' ')}\n"
        for col in prepend_columns
    )


# Prepends patient data to the text column
df['text'] = df.apply(lambda row: _patient_header(row) + row['text'], axis=1)

df.head()

Unnamed: 0,hadm_id,text,age,gender,insurance,marital_status,race,diagnose_group_description,diagnose_group_mortality
0,21891113,age: 68\ngender: F\ninsurance: Other\nmarital ...,68,F,Other,DIVORCED,WHITE,"DISORDERS OF GALLBLADDER & BILIARY TRACT, DISO...",MODERATE
1,29643114,age: 96\ngender: M\ninsurance: Medicare\nmarit...,96,M,Medicare,WIDOWED,WHITE,"CARDIAC ARRHYTHMIA & CONDUCTION DISORDERS, ACU...",HIGH
2,26747385,age: 47\ngender: F\ninsurance: Other\nmarital ...,47,F,Other,MARRIED,ASIAN,"DIGESTIVE MALIGNANCY, DIGESTIVE MALIGNANCY W MCC",HIGH
3,23932127,age: 64\ngender: M\ninsurance: Other\nmarital ...,64,M,Other,MARRIED,WHITE,"PERCUTANEOUS CORONARY INTERVENTION W AMI, PERC...",LOW
4,27210508,age: 72\ngender: M\ninsurance: Other\nmarital ...,72,M,Other,MARRIED,WHITE,"INTRACRANIAL HEMORRHAGE, INTRACRANIAL HEMORRHA...",HIGH


In [None]:
# Gets sentence embeddings for each model in modnames and exports them to disk as .npy files
encoder_kwargs = {
    # batch_size is passed per model (second element of each modnames entry)
    'output_value': 'sentence_embedding',
    'show_progress_bar': True,
    'convert_to_numpy': True
}


def _encode_with_truncation(modname, texts, batch_size, side):
    """Loads `modname` with the tokenizer truncating on `side` and encodes `texts`.

    Args:
        modname: Model name or path handed to SentenceTransformer.
        texts: List of strings to encode.
        batch_size: Batch size handed to `encode`.
        side: Tokenizer truncation side, 'left' or 'right'.

    Returns:
        Whatever `model.encode` returns under encoder_kwargs (numpy output is requested).
    """
    model = SentenceTransformer(modname, tokenizer_kwargs={'truncation_side': side})
    return model.encode(texts, batch_size=batch_size, **encoder_kwargs)


# Loop-invariant pieces: the texts to encode and the file-name tags.
# The tags are kept out of the f-string below because nesting the same quote type
# inside an f-string only parses on Python >= 3.12.
texts = df['text'].to_list()
summary_tag = 'summary_' if summaries else ''
balanced_tag = '_balanced' if balanced_data else ''
prepend_tag = '_PR' if len(prepend_columns) > 0 else ''

for m_count, (modname, mod_bsize, mod_truncation) in enumerate(modnames, start=1):
    print(f'> Processing model {m_count}/{len(modnames)} - ({modname})')

    # Fail before loading any model if the configured truncation side is not supported
    if mod_truncation not in ('left', 'right', 'middle'):
        raise ValueError(f"Unsupported truncation side '{mod_truncation}' for model {modname}")

    if mod_truncation == 'middle':
        # Truncates right, gets left-side embeddings
        embeddings_l = _encode_with_truncation(modname, texts, mod_bsize, 'right')

        # Truncates left, gets right-side embeddings
        embeddings_r = _encode_with_truncation(modname, texts, mod_bsize, 'left')

        # Concatenates both embedding arrays horizontally (along axis 1)
        embeddings = np.concatenate((embeddings_l, embeddings_r), axis=1)
    else:
        embeddings = _encode_with_truncation(modname, texts, mod_bsize, mod_truncation)

    # Exports result to disk. The model name is reduced to its alphanumeric characters for the file name
    print(f'> Exporting resulting embeddings to {mimicpath}...')
    model_tag = re.sub('[^a-zA-Z0-9]+', '', modname)
    out_name = f'embeddings_{model_tag}_{summary_tag}S{samp_size}_T{mod_truncation}{balanced_tag}{prepend_tag}.npy'
    with open(mimicpath / out_name, 'wb') as ofile:
        np.save(ofile, embeddings)