In [1]:
import os
import numpy as np
import pandas as pd
import pickle
from operator import itemgetter
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
import inflect
import base64
import tiktoken
import json
import re
import warnings
warnings.filterwarnings('ignore')

## Load Data

In [2]:
# Office PC
# NOTE: unpickling can execute arbitrary code; only load pickle files from a trusted source.
# The context manager closes the file handle once the stays are loaded.
with open('./Data/processed_icu_24h.pkl','rb') as pkl_file:
    icus = pickle.load(pkl_file)
# The first CSV column is used as the DataFrame index.
icus_metadata = pd.read_csv('./Data/metadata_24h.csv',index_col=0)

In [3]:
# Train-val-test split over the metadata index values.
# 20% of the entries are held out for testing; 12.5% of the remaining 80%
# (i.e. 10% of all entries) becomes the validation set, leaving 70% for training.
# Both calls use a fixed random_state so the split is reproducible.
idx_train_val, idx_test = train_test_split(
    icus_metadata.index.values,
    test_size=0.2,
    random_state=0,
)
idx_train, idx_val = train_test_split(
    idx_train_val,
    test_size=0.125,
    random_state=0,
)

In [4]:
# Test
# A comprehension always yields a list, even when idx_test holds a single index
# (operator.itemgetter called with one key returns the bare item, not a tuple).
icus_test = [icus[i] for i in idx_test]
# Select the metadata rows by label, in the same order as idx_test, so that
# metadata_test is positionally aligned with icus_test.
metadata_test = icus_metadata.loc[idx_test,:].reindex(idx_test)
assert len(icus_test) == len(metadata_test), "icus_test and metadata_test are not aligned"

## Generate patient profiles in text

In [5]:
# Prompt Generation
class Generator:
    """Builds a plain-text patient profile for every ICU stay.

    A profile is the newline-joined concatenation of the sections available
    for a stay, in this order: demographics, recent vital signs, chest X-ray
    note, radiology reports.
    """

    def __init__(self,icus,icus_metadata):
        """Store the ICU stays and their metadata.

        Args:
            icus: Sequence of ICU stay objects. Each one is read through its
                ``tabular``, ``time_series``, ``images`` and ``notes`` attributes.
            icus_metadata: DataFrame with one row per stay, positionally
                aligned with ``icus``; the ``stay_id`` and ``intime`` columns
                are read.
        """
        self.icus = icus
        self.icus_metadata = icus_metadata
        # Not used by any method of this class; kept in case external code reads it.
        self.inflect_engine = inflect.engine()

    @staticmethod
    def _any_recorded(values):
        """Return True when at least one of ``values`` is not missing."""
        return any(not pd.isna(value) for value in values)

    def generate_demographic_input(self, patient_demographics):
        """Render the demographics section.

        Args:
            patient_demographics: Object whose ``age``, ``gender``, ``race``,
                ``marital_status``, ``language`` and ``insurance`` entries each
                hold exactly one value (``.item()`` is called on each).

        Returns:
            The section text. Every bullet line starts with eight spaces
            because the indentation of the literal below is part of the string.
        """
        # Patient demographics
        demographics_text = "Patient Demographics:"
        demographics_text += f"""
        - Age: {patient_demographics['age'].item()}
        - Gender: {patient_demographics['gender'].item()}
        - Race: {patient_demographics['race'].item()}
        - Marital Status: {patient_demographics['marital_status'].item()}
        - Language: {patient_demographics['language'].item()}
        - Insurance: {patient_demographics['insurance'].item()}
        """
        return demographics_text

    def generate_vital_signs_input(self, vital_signs):
        """Render the vital-signs section, one entry per row of ``vital_signs``.

        Single-value measurements are listed only when present. Grouped
        measurements (blood pressures, GCS) are listed when at least one
        component is present; missing components of a listed group are printed
        as they are (e.g. ``nan``).

        Args:
            vital_signs: DataFrame with a ``frac_charttime`` column and the
                vital-sign columns referenced below.

        Returns:
            The section text.
        """
        lines = ["Vital Signs:\n"]
        for _, row in vital_signs.iterrows():
            lines.append(f"- Time from ICU Admission: {row['frac_charttime']} hours\n")

            heart_rate = row['Heart Rate']
            if not pd.isna(heart_rate):
                lines.append(f"    - Heart Rate: {heart_rate}\n")

            respiratory_rate = row['Respiratory Rate']
            if not pd.isna(respiratory_rate):
                lines.append(f"    - Respiratory Rate: {respiratory_rate}\n")

            spo2 = row['O2 saturation pulseoxymetry']
            if not pd.isna(spo2):
                lines.append(f"    - Peripheral Oxygen Saturation: {spo2}\n")

            nibp = (
                row['Non Invasive Blood Pressure systolic'],
                row['Non Invasive Blood Pressure diastolic'],
                row['Non Invasive Blood Pressure mean'],
            )
            if self._any_recorded(nibp):
                lines.append(f"    - Non-invasive Blood Pressure (Systolic/Diastolic/Mean): {nibp[0]} / {nibp[1]} / {nibp[2]}\n")

            abp = (
                row['Arterial Blood Pressure systolic'],
                row['Arterial Blood Pressure diastolic'],
                row['Arterial Blood Pressure mean'],
            )
            if self._any_recorded(abp):
                lines.append(f"    - Arterial Blood Pressure (Systolic/Diastolic/Mean): {abp[0]} / {abp[1]} / {abp[2]}\n")

            gcs = (
                row['GCS - Eye Opening'],
                row['GCS - Verbal Response'],
                row['GCS - Motor Response'],
            )
            if self._any_recorded(gcs):
                lines.append(f"    - GCS: Eye Opening: {gcs[0]}, Verbal Response: {gcs[1]}, Motor Response: {gcs[2]}\n")

            temperature = row['Temperature Fahrenheit']
            if not pd.isna(temperature):
                lines.append(f"    - Temperature Fahrenheit: {temperature}\n")

        return "".join(lines)

    def generate_cxr_input(self, images,icu_intime):
        """Render the chest X-ray section.

        Args:
            images: Table of images; one fixed sentence is emitted per row.
            icu_intime: Accepted for interface compatibility; the emitted text
                does not depend on it.

        Returns:
            The section text.
        """
        image_text = "Chest X-Ray Images:\n"
        # The sentence carries no per-image information, so it is simply repeated.
        image_text += "- The uploaded image is the most recent chest X-ray image.\n" * len(images)
        return image_text

    def generate_report_input(self, reports, icu_intime):
        """Render the radiology-reports section.

        Args:
            reports: DataFrame with ``charttime`` and ``text`` columns; rows
                that repeat the same (charttime, text) pair are listed once.
            icu_intime: ICU admission time, subtracted from each ``charttime``.

        Returns:
            The section text, with each report's offset from admission in
            hours rounded to two decimals.
        """
        unique_reports = reports[['charttime','text']].drop_duplicates()
        lines = ["Radiology Reports:\n"]
        for _, report in unique_reports.iterrows():
            # Calculate time difference between report recording time and ICU admission time
            hours_from_admission = round((report['charttime']-icu_intime).total_seconds()/3600,2)
            lines.append(f"- Time from ICU Admission: {hours_from_admission} hours\n")
            lines.append(f"    - Report Text: {report['text']}\n")

        return "".join(lines)

    def generate_patient_profiles(self, window_hours=2):
        """Build one text profile per ICU stay.

        Args:
            window_hours: Vital-sign rows are included when their
                ``frac_charttime`` is at most this many hours before the
                ``frac_charttime`` of the last time-series row.

        Returns:
            List of profile strings, in the same order as ``self.icus``.

        Raises:
            ValueError: If the ``stay_id`` of a stay differs from the
                ``stay_id`` of the metadata row at the same position.
        """
        profiles = []

        for i, icu in enumerate(tqdm(self.icus)):
            # Metadata row at the same position as the ICU stay
            icu_metadata = self.icus_metadata.iloc[i]
            stay_id = icu.tabular['stay_id'].item()
            if stay_id != icu_metadata['stay_id']:
                # Fail loudly: a partial list would be misaligned with the inputs.
                raise ValueError(
                    f"Error in aligning metadata: stay ID mismatch at position {i} "
                    f"(ICU stay {stay_id}, metadata {icu_metadata['stay_id']})"
                )
            # Parsed into a local so the caller's metadata is never written to.
            icu_intime = pd.to_datetime(icu_metadata['intime'])

            # Generate textual input for each modality
            # Demographics
            sections = [self.generate_demographic_input(icu.tabular)]

            # Vital signs
            ts = icu.time_series
            if not ts.empty:
                # Time of the last measurement, looked up by column name
                # rather than by column position.
                last_time = ts['frac_charttime'].iloc[-1]
                # Select records within the window before the last measurement
                selected_ts = ts[ts['frac_charttime'] >= last_time - window_hours]
                sections.append(self.generate_vital_signs_input(selected_ts))

            # CXR
            image_metadata = icu.images['metadata']
            if not image_metadata.empty:
                image = pd.merge(image_metadata, icu.images['image_path'], how='left', on=['dicom_id', 'study_id', 'subject_id'])
                # Only the last row is described
                sections.append(self.generate_cxr_input(image.iloc[-1:], icu_intime))

            # Radiology reports
            reports = icu.notes['radiology']
            if not reports.empty:
                sections.append(self.generate_report_input(reports, icu_intime))

            profiles.append("\n".join(sections))

        return profiles

In [6]:
generator = Generator(icus_test,metadata_test)
patient_profiles = generator.generate_patient_profiles()
# The dataset cell below indexes patient_profiles by position alongside icus_test,
# so a shorter-than-expected result must not pass silently.
assert len(patient_profiles) == len(icus_test), "patient_profiles is not aligned with icus_test"

100%|██████████| 14637/14637 [00:28<00:00, 519.48it/s]


## Generate image path

In [7]:
def extract_path(metadata):
    """Build the relative JPG path of a single image from its metadata.

    Args:
        metadata: Object whose ``subject_id``, ``study_id`` and ``dicom_id``
            entries each hold exactly one value (``.item()`` is called on
            each), e.g. a one-row DataFrame.

    Returns:
        ``files/p<first two digits of subject>/p<subject>/s<study>/<dicom>.jpg``.
        The original comment describes this as the path of the JPG file in
        Google Cloud.
    """
    # Subject and study ids are cast to int first, so float ids such as 123.0
    # do not leave a ".0" in the path.
    subject = str(int(metadata['subject_id'].item()))
    study = str(int(metadata['study_id'].item()))
    dicom = str(metadata['dicom_id'].item())
    path_parts = ["files", f"p{subject[:2]}", f"p{subject}", f"s{study}", f"{dicom}.jpg"]
    return "/".join(path_parts)

In [8]:
# One entry per test stay, in the order of icus_test: the relative JPG path built
# from the last row of the stay's image metadata, or None when the stay has no images.
images = []
for icu in tqdm(icus_test):
    image_metadata = icu.images['metadata']
    if image_metadata.empty:
        images.append(None)
    else:
        images.append(extract_path(image_metadata.iloc[-1:]))

100%|██████████| 14637/14637 [00:00<00:00, 67698.33it/s]


## Create Dataset

In [24]:
nfiles = len(icus_test)
# Profiles, image paths and labels are combined by position, so they must all
# describe the same number of stays.
assert len(patient_profiles) == len(images) == len(metadata_test) == nfiles, \
    "patient_profiles, images and metadata_test are not aligned with icus_test"

# The same two questions are asked for every stay.
questions = {'hospital_expire_flag':  "Will the patient die during current hospital admission?",
             'los_binary': "Will this patient's ICU stay exceed three days?"}
# Read each label column once instead of building a row Series per stay.
expire_flags = metadata_test['hospital_expire_flag'].tolist()
los_flags = metadata_test['los_binary'].tolist()

qa_dataset_test = []
for idx in tqdm(range(nfiles)):
    qa_dataset_test.append({
        'id': idx + 1,  # ids are 1-based
        'context': patient_profiles[idx],
        'image': images[idx],
        # Copied so that records do not share one mutable dict
        'question': dict(questions),
        'answer': {'hospital_expire_flag': int(expire_flags[idx]),
                   'los_binary': int(los_flags[idx])},
    })

100%|██████████| 14637/14637 [00:00<00:00, 20974.37it/s]


In [25]:
output_path = 'Data/qa_dataset_test.json'
# An explicit encoding keeps the written file independent of the platform default.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(qa_dataset_test, f, indent=4)