In [1]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
from tqdm import tqdm

In [2]:
# Load ADMISSIONS data
df_adm = pd.read_csv('ADMISSIONS_sorted.csv')
# Display the first few rows of the dataframe
# print(df_adm.head())
df_adm.sample(5)
df_adm.info()
df_adm.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12911 entries, 0 to 12910
Data columns (total 19 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   ROW_ID                12911 non-null  int64 
 1   SUBJECT_ID            12911 non-null  int64 
 2   HADM_ID               12911 non-null  int64 
 3   ADMITTIME             12911 non-null  object
 4   DISCHTIME             12911 non-null  object
 5   DEATHTIME             1287 non-null   object
 6   ADMISSION_TYPE        12911 non-null  object
 7   ADMISSION_LOCATION    12911 non-null  object
 8   DISCHARGE_LOCATION    12911 non-null  object
 9   INSURANCE             12911 non-null  object
 10  LANGUAGE              3445 non-null   object
 11  RELIGION              12728 non-null  object
 12  MARITAL_STATUS        9714 non-null   object
 13  ETHNICITY             12911 non-null  object
 14  EDREGTIME             5932 non-null   object
 15  EDOUTTIME             5932 non-null 

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,HOSPITAL_EXPIRE_FLAG,HAS_CHARTEVENTS_DATA
count,12911.0,12911.0,12911.0,12911.0,12911.0
mean,6456.0,5294.025482,149468.268453,0.099682,0.969096
std,3727.228998,3043.092031,28803.637575,0.299588,0.173064
min,1.0,2.0,100006.0,0.0,0.0
25%,3228.5,2664.0,124511.5,0.0,1.0
50%,6456.0,5308.0,149196.0,0.0,1.0
75%,9683.5,7899.5,174239.5,0.0,1.0
max,12911.0,10566.0,199986.0,1.0,1.0


In [3]:
# Convert admission, discharge, and death times to datetime format
df_adm.ADMITTIME = pd.to_datetime(df_adm.ADMITTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
df_adm.DISCHTIME = pd.to_datetime(df_adm.DISCHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
df_adm.DEATHTIME = pd.to_datetime(df_adm.DEATHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')

# Sort by SUBJECT_ID and ADMITTIME
df_adm = df_adm.sort_values(['SUBJECT_ID', 'ADMITTIME']).reset_index(drop=True)

# Shift the ADMITTIME and ADMISSION_TYPE columns for the next admission
df_adm['NEXT_ADMITTIME'] = df_adm.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1)
df_adm['NEXT_ADMISSION_TYPE'] = df_adm.groupby('SUBJECT_ID')['ADMISSION_TYPE'].shift(-1)

# Remove 'ELECTIVE' next admission types
rows = df_adm.NEXT_ADMISSION_TYPE == 'ELECTIVE'
df_adm.loc[rows, 'NEXT_ADMITTIME'] = pd.NaT
df_adm.loc[rows, 'NEXT_ADMISSION_TYPE'] = np.nan

# Fill missing values in NEXT_ADMITTIME and NEXT_ADMISSION_TYPE by backward fill
df_adm[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']] = df_adm.groupby('SUBJECT_ID')[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']].fillna(method='bfill')

# Calculate the number of days until next admission
df_adm['DAYS_NEXT_ADMIT'] = (df_adm.NEXT_ADMITTIME - df_adm.DISCHTIME).dt.total_seconds() / (24 * 60 * 60)

# Output label for whether the next admission is within 30 days
df_adm['OUTPUT_LABEL'] = (df_adm.DAYS_NEXT_ADMIT < 30).astype('int')

# Filter out newborn admissions and deaths
df_adm = df_adm[df_adm['ADMISSION_TYPE'] != 'NEWBORN']
df_adm = df_adm[df_adm.DEATHTIME.isnull()]

# Calculate the duration of admission in days
df_adm['DURATION'] = (df_adm['DISCHTIME'] - df_adm['ADMITTIME']).dt.total_seconds() / (24 * 60 * 60)

  df_adm[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']] = df_adm.groupby('SUBJECT_ID')[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']].fillna(method='bfill')
  df_adm[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']] = df_adm.groupby('SUBJECT_ID')[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']].fillna(method='bfill')
  df_adm[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']] = df_adm.groupby('SUBJECT_ID')[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']].fillna(method='bfill')


In [4]:
import csv

with open('NOTEEVENTS_sorted.csv', 'r', newline='', encoding='utf-8') as infile, \
     open('fixed_NOTEEVENTS_sorted.csv', 'w', newline='', encoding='utf-8') as outfile:
    reader = csv.reader(infile)
    writer = csv.writer(outfile, quotechar='"', quoting=csv.QUOTE_MINIMAL)
    
    for row in reader:
        writer.writerow(row)
df_notes = pd.read_csv('fixed_NOTEEVENTS_sorted.csv')
df_notes.info()
df_notes.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 269674 entries, 0 to 269673
Data columns (total 11 columns):
 #   Column       Non-Null Count   Dtype  
---  ------       --------------   -----  
 0   ROW_ID       269674 non-null  int64  
 1   SUBJECT_ID   269674 non-null  int64  
 2   HADM_ID      241986 non-null  float64
 3   CHARTDATE    269674 non-null  object 
 4   CHARTTIME    232921 non-null  object 
 5   STORETIME    173805 non-null  object 
 6   CATEGORY     269674 non-null  object 
 7   DESCRIPTION  269674 non-null  object 
 8   CGID         173805 non-null  float64
 9   ISERROR      26 non-null      float64
 10  TEXT         269674 non-null  object 
dtypes: float64(3), int64(2), object(6)
memory usage: 22.6+ MB


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CGID,ISERROR
count,269674.0,269674.0,241986.0,173805.0,26.0
mean,1135995.0,2926.876206,149700.346334,17702.098898,1.0
std,539469.7,1716.323612,28926.696032,2151.30381,0.0
min,26.0,2.0,100009.0,14020.0,1.0
25%,819240.2,1402.0,125058.0,15805.0,1.0
50%,1286760.0,2900.0,149339.0,17593.0,1.0
75%,1693897.0,4474.0,174406.0,19528.0,1.0
max,1761315.0,5841.0,199971.0,21570.0,1.0


In [5]:
# Load NOTEEVENTS data
# df_notes = pd.read_csv('NOTEEVENTS_sorted.csv')
# df_notes.info()
# Sort notes by SUBJECT_ID, HADM_ID, and CHARTDATE
df_notes = df_notes.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'CHARTDATE'])

# Merge ADMISSIONS data with NOTEEVENTS data
df_adm_notes = pd.merge(
    df_adm[
        ['SUBJECT_ID', 'HADM_ID', 'ADMITTIME', 'DISCHTIME', 'DAYS_NEXT_ADMIT', 'NEXT_ADMITTIME', 'ADMISSION_TYPE', 'DEATHTIME', 'OUTPUT_LABEL', 'DURATION']
    ],
    df_notes[['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'TEXT', 'CATEGORY']],
    on=['SUBJECT_ID', 'HADM_ID'],
    how='left',
)

# Preprocess admission times and chart dates
df_adm_notes['ADMITTIME_C'] = pd.to_datetime(
    df_adm_notes.ADMITTIME.apply(lambda x: str(x).split(' ')[0]), format='%Y-%m-%d', errors='coerce'
)
df_adm_notes['CHARTDATE'] = pd.to_datetime(df_adm_notes['CHARTDATE'], format='%Y-%m-%d', errors='coerce')


In [6]:
# Extract discharge summaries
df_discharge = df_adm_notes[df_adm_notes['CATEGORY'] == 'Discharge summary']
df_discharge = df_discharge.groupby(['SUBJECT_ID', 'HADM_ID']).nth(-1).reset_index()
df_discharge = df_discharge[df_discharge['TEXT'].notnull()]


In [7]:
# Function to extract early notes (less than n days)
def less_n_days_data(df_adm_notes, n):
    df_less_n = df_adm_notes[
        ((df_adm_notes['CHARTDATE'] - df_adm_notes['ADMITTIME_C']).dt.total_seconds() / (24 * 60 * 60)) < n
    ]
    df_less_n = df_less_n[df_less_n['TEXT'].notnull()]
    df_concat = pd.DataFrame(df_less_n.groupby('HADM_ID')['TEXT'].apply(lambda x: "%s" % ' '.join(x))).reset_index()
    df_concat['OUTPUT_LABEL'] = df_concat['HADM_ID'].apply(
        lambda x: df_less_n[df_less_n['HADM_ID'] == x].OUTPUT_LABEL.values[0]
    )
    return df_concat

df_less_2 = less_n_days_data(df_adm_notes, 2)
df_less_3 = less_n_days_data(df_adm_notes, 3)


In [8]:
# Function to preprocess text data
def preprocess1(x):
    y = re.sub(r'\[\*\*(.*?)\*\*\]', '', x)  # Remove de-identified brackets
    y = re.sub(r'[0-9]+\.', '', y)  # Remove 1.2. segments
    y = re.sub(r'dr\.', 'doctor', y)
    y = re.sub(r'm\.d\.', 'md', y)
    y = re.sub(r'admission date:', '', y)
    y = re.sub(r'discharge date:', '', y)
    y = re.sub(r'--|__|==', '', y)
    return y

# Function to preprocess and chunk text data
def preprocessing(df_less_n):
    df_less_n['TEXT'] = df_less_n['TEXT'].fillna(' ')
    df_less_n['TEXT'] = df_less_n['TEXT'].str.replace('\n', ' ').str.replace('\r', ' ').str.strip().str.lower()
    df_less_n['TEXT'] = df_less_n['TEXT'].apply(preprocess1)

    chunks = []
    for i in tqdm(range(len(df_less_n))):
        x = df_less_n.TEXT.iloc[i].split()
        n = len(x) // 318
        for j in range(n):
            chunks.append({'TEXT': ' '.join(x[j * 318 : (j + 1) * 318]), 'Label': df_less_n.OUTPUT_LABEL.iloc[i], 'ID': df_less_n.HADM_ID.iloc[i]})
        if len(x) % 318 > 10:
            chunks.append({'TEXT': ' '.join(x[-(len(x) % 318) :]), 'Label': df_less_n.OUTPUT_LABEL.iloc[i], 'ID': df_less_n.HADM_ID.iloc[i]})
    return pd.DataFrame(chunks)

df_discharge = preprocessing(df_discharge)
df_less_2 = preprocessing(df_less_2)
df_less_3 = preprocessing(df_less_3)


100%|██████████| 4772/4772 [00:01<00:00, 2790.72it/s]
100%|██████████| 4845/4845 [00:01<00:00, 2527.46it/s]
100%|██████████| 4897/4897 [00:02<00:00, 2429.35it/s]


In [9]:
# Split data into train, validation, and test sets
readmit_ID = df_adm[df_adm.OUTPUT_LABEL == 1].HADM_ID
not_readmit_ID = df_adm[df_adm.OUTPUT_LABEL == 0].HADM_ID

# Equalize the number of not readmit and readmit IDs
not_readmit_ID_use = not_readmit_ID.sample(n=len(readmit_ID), random_state=1)

# Split into validation and test sets
id_val_test_t = readmit_ID.sample(frac=0.2, random_state=1)
id_val_test_f = not_readmit_ID_use.sample(frac=0.2, random_state=1)

id_train_t = readmit_ID.drop(id_val_test_t.index)
id_train_f = not_readmit_ID_use.drop(id_val_test_f.index)

id_val_t = id_val_test_t.sample(frac=0.5, random_state=1)
id_test_t = id_val_test_t.drop(id_val_t.index)

id_val_f = id_val_test_f.sample(frac=0.5, random_state=1)
id_test_f = id_val_test_f.drop(id_val_f.index)

id_test = pd.concat([id_test_t, id_test_f])
id_val = pd.concat([id_val_t, id_val_f])
id_train = pd.concat([id_train_t, id_train_f])


In [10]:
# Final dataset preparation
discharge_train = df_discharge[df_discharge.ID.isin(id_train)]
discharge_val = df_discharge[df_discharge.ID.isin(id_val)]
discharge_test = df_discharge[df_discharge.ID.isin(id_test)]

# Save datasets to CSV files
discharge_train.to_csv('./discharge/train.csv', index=False)
discharge_val.to_csv('./discharge/val.csv', index=False)
discharge_test.to_csv('./discharge/test.csv')
