Skip to content

Commit

Permalink
feat(Multi med administrations): Support for multiple medication admi…
Browse files Browse the repository at this point in the history
…nistrations

1. Update the MIMIC parsers to handle multiple medications
2. Generate medication–lab test data with multiple medications
3. Remove the respiratory rate item ID because it clashes with a medication item ID in HiRID
  • Loading branch information
PavanReddy28 committed Jun 11, 2023
1 parent edaa1f4 commit f64e27d
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 100 deletions.
198 changes: 120 additions & 78 deletions src/modeling/querier.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
import pandas as pd
import os
import tqdm

from src.utils.utils import AnalysisUtils, get_normalized_trend
from src.utils import constants


class DatasetQuerier(AnalysisUtils):

# Constructor: stores the lab table and the medication table(s) on the instance and
# forwards data/res plus the gender, age and ethnicity filters to AnalysisUtils.__init__.
# NOTE(review): this hunk shows both sides of the diff without +/- markers. Presumably the
# (t_med1, t_med2) signature with self.final/self.temp/self.t_med1/self.t_med2 is the
# removed side and the single `meds` argument is the added side — confirm against the raw diff.
# NOTE(review): the `meds` signature is not positionally compatible with the other one;
# a call passing (data, res, t_labs, t_med1, t_med2) would bind the last table to `gender`.
def __init__(self, data, res, t_labs, t_med1, t_med2, gender="MF", age_b=0, age_a=100, ethnicity="WHITE", lab_mapping=None):
self.final = None
self.temp = None
def __init__(self, data, res, t_labs, meds, gender="MF", age_b=0, age_a=100, ethnicity="WHITE", lab_mapping=None):
# One entry per medication table is appended to these lists by generate_med_lab_data().
self.final_pairs_data, self.interim_pairs_data = [], []
self.t_labs = t_labs
self.t_med1 = t_med1
self.t_med2 = t_med2
# `meds` is indexed positionally (self.meds[i-1], self.meds[i], self.meds[i+1]) by the
# methods below, so it must be a sequence; presumably ordered by administration number.
self.meds = meds
super().__init__(data, res, gender=gender, age_b=age_b, age_a=age_a, ethnicity=ethnicity, lab_mapping=lab_mapping)

# Membership test used by get_vals: True when the given medication table contains a row
# with the same HADM_ID and the same ITEMID as `row`, False otherwise.
# NOTE(review): both the `check_med2` and the `check_medk` lines are shown because the diff
# markers were lost; presumably `check_medk` is the version kept by the commit.
def check_med2(self, t_med2, row):
def check_medk(self, t_medk, row):
"""
Check if a 2nd medication was administered to patients
Check if a kth medication was administered to patients
"""
if row["HADM_ID"] in t_med2["HADM_ID"].to_list():
if row["ITEMID"] in t_med2[t_med2["HADM_ID"]==row["HADM_ID"]]["ITEMID"].to_list():
# None means "no neighbouring medication table": generate_med_lab_data passes None as the
# previous table for the first entry of self.meds and as the next table for the last one.
if t_medk is None:
return False
# NOTE(review): .to_list() materialises every HADM_ID on each call, and get_vals calls this
# once per window for every row it is applied to; a boolean mask with .any() would avoid
# rebuilding the list each time.
if row["HADM_ID"] in t_medk["HADM_ID"].to_list():
if row["ITEMID"] in t_medk[t_medk["HADM_ID"]==row["HADM_ID"]]["ITEMID"].to_list():
return True
return False

# NOTE(review): both the `get_med2` and the `get_medk` lines are shown because the diff
# markers were lost; presumably `get_medk` is the version kept by the commit.
def get_med2(self, t_med2, row):
def get_medk(self, t_medk, row):
'''
Return the first row of the kth medication table that has the same HADM_ID and ITEMID as row
'''
temp = t_med2[t_med2["HADM_ID"]==row["HADM_ID"]]
# Narrow to the same admission first, then to the same medication item.
temp = t_medk[t_medk["HADM_ID"]==row["HADM_ID"]]
# .iloc[0] raises IndexError when nothing matches; the visible callers in get_vals only
# reach this after check_medk() returned True for the same table and row.
return temp[temp["ITEMID"]==row["ITEMID"]].iloc[0]

def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
def get_vals(self, r, t_labs, t_med0, t_med2, before_windows, after_windows):
"""
Calculate the lab test values in time windows before and after medication administration. Return a dataframe with labtest values of before and after windows as a dict
Params:
Expand All @@ -42,6 +43,7 @@ def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
row = r.copy()
for b_w in before_windows:
lab_vals = t_labs[t_labs["HADM_ID"]==row["HADM_ID"]]
med0_bool = self.check_medk(t_med0, row)
lab_vals = lab_vals[lab_vals["LabTimeFromAdmit"].dt.total_seconds()<row["MedTimeFromAdmit"].total_seconds()]

b_window_start = row["MedTimeFromAdmit"].total_seconds() - (b_w[0]*3600)
Expand All @@ -50,6 +52,10 @@ def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
lab_vals = lab_vals[lab_vals["LabTimeFromAdmit"].dt.total_seconds()>b_window_end]
lab_vals["hours_from_med"] = (row["STARTTIME"]-lab_vals["CHARTTIME"]).dt.total_seconds()/3600
lab_vals = lab_vals.sort_values(["ITEMID", "hours_from_med"])

if med0_bool:
med0_val = self.get_medk(t_med0, row)
lab_vals = lab_vals[lab_vals["LabTimeFromAdmit"].dt.total_seconds()>med0_val["MedTimeFromAdmit"].total_seconds()]

t = lab_vals.groupby(["ITEMID"]).count()[["HADM_ID"]]
val_counts_m = t[t["HADM_ID"]>=1]
Expand All @@ -70,7 +76,7 @@ def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
for a_w in after_windows:

lab_vals = t_labs[t_labs["HADM_ID"]==row["HADM_ID"]]
med2_bool = self.check_med2(t_med2, row)
med2_bool = self.check_medk(t_med2, row)
lab_vals = lab_vals[lab_vals["LabTimeFromAdmit"].dt.total_seconds()>row["MedTimeFromAdmit"].total_seconds()]
a_window_start = row["MedTimeFromAdmit"].total_seconds() + (a_w[0]*3600)
a_window_end = row["MedTimeFromAdmit"].total_seconds() + (a_w[1])*3600
Expand All @@ -80,7 +86,7 @@ def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
lab_vals = lab_vals.sort_values(["ITEMID", "hours_from_med"])

if med2_bool:
med2_val = self.get_med2(t_med2, row)
med2_val = self.get_medk(t_med2, row)
lab_vals = lab_vals[lab_vals["LabTimeFromAdmit"].dt.total_seconds()<med2_val["MedTimeFromAdmit"].total_seconds()]

t = lab_vals.groupby(["ITEMID"]).count()[["HADM_ID"]]
Expand All @@ -103,41 +109,60 @@ def get_vals(self, r, t_labs, t_med1, t_med2, before_windows, after_windows):
return row

# Builds, for every medication table, the before/after-window lab values via get_vals(),
# writes the intermediate and the flattened result to CSV files under self.res, and
# collects both per table.
# NOTE(review): this hunk shows both sides of the diff without +/- markers. Presumably the
# single-table code (self.t_med1/self.t_med2, self.temp, self.final, `return final, temp`)
# is the removed side and the `for i in tqdm.tqdm(...)` loop is the added side — confirm.
# NOTE(review) on the loop version:
#   * the docstring lists `use_id`, which is not a parameter here; `lab_parts` is the
#     undocumented one and is only interpolated into the CSV file names;
#   * with a single table in self.meds the `i==0` branch reads self.meds[i+1] and raises
#     IndexError;
#   * the inner `for i in range(1, len(col_vals))` rebinds the medication index `i`, so the
#     `{i}` / `{i+1}` in the "trends" CSV name no longer identify the medication table and
#     every iteration writes to the same file name;
#   * the return value is a pair of lists, while the other version returns a pair of
#     DataFrames, so callers that unpack the result need updating.
def generate_med_lab_data(self, before_windows, after_windows, lab_parts=(0,50)):
"""Generate lab test values in before and after windows of medication
Args:
before_windows (list of tuples): before windows (in hours) Ex: [(1,2), (2,3)]
after_windows (list of tuples): after windows (in hours) Ex: [(1,2), (2,3)]
use_id (bool, optional):
Returns:
list(pd.DataFrame): Med lab pair values are present in this dataframe. Each row contains the medication value and the before/after lab test value.
list(pd.DataFrame): Contains columns with dictionaries of lab values. Columns are named based on before and after window
"""
Generate lab test values in before and after windows of medication
"""

t_labs, t_med1, t_med2 = self.t_labs, self.t_med1, self.t_med2
# Start from empty result lists so repeated calls do not accumulate earlier results.
self.final_pairs_data = []
self.interim_pairs_data = []

all_types = set(["abs", "mean", "std", "trends", "time"])
cols_b = [f"before_{t}_{b_w}" for b_w in before_windows for t in all_types]
cols_a = [f"after_{t}_{a_w}" for a_w in after_windows for t in all_types]
cols = cols_b.copy()
cols.extend(cols_a)
temp = t_med1.copy()

self.temp = temp.apply(lambda r : self.get_vals(r, t_labs, t_med1, t_med2, before_windows, after_windows), axis=1)
self.temp.to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_first_val_{self.stratify_prefix}_doc_eval_new_win_{lab_parts}.csv"))

col_vals = []
for col in cols:
col_vals.append(
temp.assign(dict=temp[col].dropna().map(lambda d: d.items())).explode("dict", ignore_index=True).assign(
LAB_ITEMID=lambda df: df.dict.str.get(0),
temp=lambda df: df.dict.str.get(1)
).drop(columns=["dict"]+cols).astype({'temp':'float64'}).rename(columns={"temp":f"{col}_sp"}).dropna(subset=["LAB_ITEMID"])
)
for i in range(1, len(col_vals)):
col_vals[i] = pd.merge(col_vals[i-1], col_vals[i], how="outer", on=list(t_med1.columns)+["LAB_ITEMID"])

final = col_vals[-1][list(t_med1.columns)+["LAB_ITEMID"]+[f"{col}_sp" for col in cols]]
final["LAB_NAME"] = final["LAB_ITEMID"]
final = final.rename(columns={"ITEMID":"MED_NAME"})
self.final = final

final.to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_trends_first_val_{self.stratify_prefix}_doc_eval_win_{lab_parts}.csv"))
t_labs = self.t_labs
for i in tqdm.tqdm(range(len(self.meds))):

# Pick the previous (t_med0), current (t_med1) and next (t_med2) medication tables;
# None marks a missing neighbour at either end of self.meds.
if i==0 :
# NOTE(review): self.meds[i+1] raises IndexError when self.meds has exactly one table.
t_med0, t_med1, t_med2 = None, self.meds[i], self.meds[i+1]
elif i== len(self.meds)-1:
t_med0, t_med1, t_med2 = self.meds[i-1], self.meds[i], None
else:
t_med0, t_med1, t_med2 = self.meds[i-1], self.meds[i], self.meds[i+1]

# One column name per (window, statistic) pair. The names do not depend on `i`, so this
# list is rebuilt identically on every iteration.
all_types = set(["abs", "mean", "std", "trends", "time"])
cols_b = [f"before_{t}_{b_w}" for b_w in before_windows for t in all_types]
cols_a = [f"after_{t}_{a_w}" for a_w in after_windows for t in all_types]
cols = cols_b.copy()
cols.extend(cols_a)
temp = t_med1.copy()

# Row-wise pass over the current medication table; the result of each table is kept in
# self.interim_pairs_data and also written to its own CSV file.
self.interim_pairs_data.append(temp.apply(lambda r : self.get_vals(r, t_labs, t_med0, t_med2, before_windows, after_windows), axis=1))
self.interim_pairs_data[-1].to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_first_val_{self.stratify_prefix}_doc_eval_new_win_lab{lab_parts}_med({i}, {i+1}).csv"))
temp = self.interim_pairs_data[-1]

# Each window column holds dicts; explode every dict into one row per item, with the key
# in LAB_ITEMID and the value in "<col>_sp", then outer-merge all columns back together
# on the medication columns plus LAB_ITEMID.
col_vals = []
for col in cols:
col_vals.append(
temp.assign(dict=temp[col].dropna().map(lambda d: d.items())).explode("dict", ignore_index=True).assign(
LAB_ITEMID=lambda df: df.dict.str.get(0),
temp=lambda df: df.dict.str.get(1)
).drop(columns=["dict"]+cols).astype({'temp':'float64'}).rename(columns={"temp":f"{col}_sp"}).dropna(subset=["LAB_ITEMID"])
)
# NOTE(review): this loop reuses the name `i` of the enclosing medication loop.
for i in range(1, len(col_vals)):
col_vals[i] = pd.merge(col_vals[i-1], col_vals[i], how="outer", on=list(t_med1.columns)+["LAB_ITEMID"])

final = col_vals[-1][list(t_med1.columns)+["LAB_ITEMID"]+[f"{col}_sp" for col in cols]]
final["LAB_NAME"] = final["LAB_ITEMID"]
final = final.rename(columns={"ITEMID":"MED_NAME"})
self.final_pairs_data.append(final)

# NOTE(review): `i` here is the value left by the merge loop above, not the medication index.
final.to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_trends_first_val_{self.stratify_prefix}_doc_eval_win_lab{lab_parts}_med({i}, {i+1}).csv"))

return self.final_pairs_data, self.interim_pairs_data

# Same pipeline as generate_med_lab_data(), restricted to one medication `med` and one lab
# test `lab`; both are matched against constants.ID_COL when use_id is True and against
# constants.NAME_ID_COL otherwise.
# NOTE(review): this hunk shows both sides of the diff without +/- markers (and part of the
# docstring is elided). Presumably the med1_filtered/med2_filtered/labs_filtered code ending
# in `return final, temp` is the removed side and the n_meds / `for i in tqdm.tqdm(...)`
# code is the added side — confirm against the raw diff.
# NOTE(review) on the n_meds version:
#   * `n_meds` is read inside the comprehension that first assigns it, which raises
#     UnboundLocalError; the iterable was presumably meant to be self.meds;
#   * the comprehension already drops empty tables, so `n_meds[0].shape[0]==0` can never be
#     true, and when nothing matches `n_meds[0]` raises IndexError instead of printing;
#   * the loop runs over len(n_meds) but reads the unfiltered self.meds tables, so the
#     medication filter has no effect on the rows that are processed;
#   * with a single table the `i==0` branch reads self.meds[i+1] and raises IndexError;
#   * the inner `for i in range(1, len(col_vals))` rebinds the medication index used in the
#     "trends" CSV file name;
#   * the two early exits return None while the normal path returns a pair of lists.
def query(self, med, lab, before_windows, after_windows, use_id=False):
"""Query lab test value for a given medication
Expand All @@ -150,44 +175,61 @@ def query(self, med, lab, before_windows, after_windows, use_id=False):
use_id (bool, optional): Set to True to use Original labels given in MIMIC names. Defaults to False, ie, to use MIMIC Extract labels and HIRID labels.
Returns:
pd.DataFrame: Med lab pair values are present in this dataframe. Each row contains the medication value and the before/after lab test value.
pd.DataFrame: Contains columns with dictionaries of lab values. Columns are named based on before and after window
list(pd.DataFrame): Med lab pair values are present in this dataframe. Each row contains the medication value and the before/after lab test value.
list(pd.DataFrame): Contains columns with dictionaries of lab values. Columns are named based on before and after window
"""
filter_col = constants.ID_COL if use_id else constants.NAME_ID_COL
med1_filtered = self.t_med1[self.t_med1[filter_col]==med]
med2_filtered = self.t_med2[self.t_med2[filter_col]==med]
labs_filtered = self.t_labs[self.t_labs[filter_col]==lab]

t_labs = self.t_labs[self.t_labs[filter_col]==lab]
# NOTE(review): `n_meds` is not defined yet when this comprehension iterates over it.
n_meds = [med1[med1[filter_col]==med] for med1 in n_meds if med1[med1[filter_col]==med].shape[0]>0]


if med1_filtered.shape[0]==0:
if n_meds[0].shape[0]==0:
print(f"No data found for the given medication {med}")
return
if labs_filtered.shape[0]==0:
if t_labs.shape[0]==0:
print(f"No data found for the given lab test {lab}")
return

temp = med1_filtered.copy()
temp = temp.apply(lambda r : self.get_vals(r, labs_filtered, med1_filtered, med2_filtered, before_windows, after_windows), axis=1)
# temp.to_csv(os.path.join(self.res, f"before_after_windows_med_lab_first_val_{self.stratify_prefix}_{med}_{lab}_doc_eval_new_win.csv"))
final_pairs_data = []
interim_pairs_data = []

all_types = set(["abs", "mean", "std", "trends", "time"])
cols_b = [f"before_{t}_{b_w}" for b_w in before_windows for t in all_types]
cols_a = [f"after_{t}_{a_w}" for a_w in after_windows for t in all_types]
cols = cols_b.copy()
cols.extend(cols_a)

col_vals = []
for col in cols:
col_vals.append(
temp.assign(dict=temp[col].dropna().map(lambda d: d.items())).explode("dict", ignore_index=True).assign(
LAB_ITEMID=lambda df: df.dict.str.get(0),
temp=lambda df: df.dict.str.get(1)
).drop(columns=["dict"]+cols).astype({'temp':'float64'}).rename(columns={"temp":f"{col}_sp"}).dropna(subset=["LAB_ITEMID"])
)
for i in range(1, len(col_vals)):
col_vals[i] = pd.merge(col_vals[i-1], col_vals[i], how="outer", on=list(med1_filtered.columns)+["LAB_ITEMID"])

final = col_vals[-1][list(med1_filtered.columns)+["LAB_ITEMID"]+[f"{col}_sp" for col in cols]]
final["LAB_NAME"] = final["LAB_ITEMID"]
final = final.rename(columns={"ITEMID":"MED_NAME"})

return final, temp
for i in tqdm.tqdm(range(len(n_meds))):

# Previous / current / next medication tables, with None for a missing neighbour.
# NOTE(review): these read self.meds, not the filtered n_meds built above.
if i==0 :
t_med0, t_med1, t_med2 = None, self.meds[i], self.meds[i+1]
elif i== len(self.meds)-1:
t_med0, t_med1, t_med2 = self.meds[i-1], self.meds[i], None
else:
t_med0, t_med1, t_med2 = self.meds[i-1], self.meds[i], self.meds[i+1]

all_types = set(["abs", "mean", "std", "trends", "time"])
cols_b = [f"before_{t}_{b_w}" for b_w in before_windows for t in all_types]
cols_a = [f"after_{t}_{a_w}" for a_w in after_windows for t in all_types]
cols = cols_b.copy()
cols.extend(cols_a)
temp = t_med1.copy()

# Row-wise pass over the current medication table against the lab rows selected for `lab`;
# the result is kept in interim_pairs_data and also written to a CSV file under self.res.
interim_pairs_data.append(temp.apply(lambda r : self.get_vals(r, t_labs, t_med0, t_med2, before_windows, after_windows), axis=1))
interim_pairs_data[-1].to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_first_val_{self.stratify_prefix}_doc_eval_new_win_lab-{lab}_med-{med}({i}, {i+1}).csv"))
temp = interim_pairs_data[-1]

# Explode each dict-valued window column into (LAB_ITEMID, "<col>_sp") rows, then
# outer-merge the per-column frames on the medication columns plus LAB_ITEMID.
col_vals = []
for col in cols:
col_vals.append(
temp.assign(dict=temp[col].dropna().map(lambda d: d.items())).explode("dict", ignore_index=True).assign(
LAB_ITEMID=lambda df: df.dict.str.get(0),
temp=lambda df: df.dict.str.get(1)
).drop(columns=["dict"]+cols).astype({'temp':'float64'}).rename(columns={"temp":f"{col}_sp"}).dropna(subset=["LAB_ITEMID"])
)
# NOTE(review): this loop reuses the name `i` of the enclosing medication loop.
for i in range(1, len(col_vals)):
col_vals[i] = pd.merge(col_vals[i-1], col_vals[i], how="outer", on=list(t_med1.columns)+["LAB_ITEMID"])

final = col_vals[-1][list(t_med1.columns)+["LAB_ITEMID"]+[f"{col}_sp" for col in cols]]
final["LAB_NAME"] = final["LAB_ITEMID"]
final = final.rename(columns={"ITEMID":"MED_NAME"})
final_pairs_data.append(final)

# NOTE(review): `i` here is the value left by the merge loop above, not the medication index.
final.to_csv(os.path.join(self.res, f"before_after_windows_main_med_lab_trends_first_val_{self.stratify_prefix}_doc_eval_win_lab-{lab}_med-{med}({i}, {i+1}).csv"))

return final_pairs_data, interim_pairs_data
Loading

0 comments on commit f64e27d

Please sign in to comment.