## Package loading

In [1]:
import pandas as pd
import numpy as np
import sys
import re
import pickle

## Test gen_data_df.py

### Function definitions

In [2]:
def read_patients_table():
    p = pd.read_csv("data/PATIENTS.csv.gz")
    p = p[
        [
            "SUBJECT_ID",
            "GENDER",
            "DOB",
            "DOD",
        ]
    ]
    p["DOB"] = pd.to_datetime(p["DOB"])
    p["DOD"] = pd.to_datetime(p["DOD"])
    return p

def read_icd_diagnoses_table():
    codes = pd.read_csv("data/D_ICD_DIAGNOSES.csv.gz")
    codes = codes[["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"]]
    diagnoses = pd.read_csv("data/DIAGNOSES_ICD.csv.gz")
    diagnoses = diagnoses.merge(
        codes, how="inner", left_on="ICD9_CODE", right_on="ICD9_CODE"
    )
    diagnoses[["SUBJECT_ID", "HADM_ID", "SEQ_NUM"]] = diagnoses[
        ["SUBJECT_ID", "HADM_ID", "SEQ_NUM"]
    ].astype(int)
    return diagnoses

def read_icd_procedures_table():
    codes = pd.read_csv("data/D_ICD_PROCEDURES.csv.gz")
    codes = codes[["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"]]
    procedures = pd.read_csv("data/PROCEDURES_ICD.csv.gz")
    procedures = procedures.merge(
        codes, how="inner", left_on="ICD9_CODE", right_on="ICD9_CODE"
    )
    procedures[["SUBJECT_ID", "HADM_ID", "SEQ_NUM"]] = procedures[
        ["SUBJECT_ID", "HADM_ID", "SEQ_NUM"]
    ].astype(int)
    return procedures

def read_cptevents_table():
    cpt = pd.read_csv("data/CPTEVENTS.csv.gz")
    cpt = cpt[
        [
            "SUBJECT_ID",
            "HADM_ID",
            "CPT_CD",
        ]
    ]
    return cpt

def read_prescriptions_table():
    prescription = pd.read_csv("data/PRESCRIPTIONS.csv.gz")
    prescription = prescription[~prescription["NDC"].isna()]
    prescription = prescription[["SUBJECT_ID", "HADM_ID", "NDC"]].astype(int)
    prescription = prescription.dropna()
    return prescription

def read_icustays_table():
    icu = pd.read_csv("data/ICUSTAYS.csv.gz")
    icu["INTIME"] = pd.to_datetime(icu["INTIME"])
    icu["OUTTIME"] = pd.to_datetime(icu["OUTTIME"])
    return icu

DataFrame = pd.DataFrame

def filter_codes(df, code:str, min_=5, max_=np.inf) -> DataFrame:
    t = df.groupby(code)[code].transform(len) > min_
    num_codes = len(set(df[code]))
    num_codes_after = len(set(df.loc[t, code]))
    print(
        "removing {} codes occuring less than {} times. \n num codes before filter: {} after filtering: {}".format(
            code, min_, num_codes, num_codes_after
        )
    )
    return df[t]

def remove_min_admissions(t, min_admits=1):
    tt = t.groupby("SUBJECT_ID").SUBJECT_ID.transform(len) >= min_admits
    t = t[tt]
    print(
        "num of subjects with min_admits of {} is {}".format(
            min_admits, len(set(t["SUBJECT_ID"]))
        )
    )
    return t

def group_by_return_col_list(t, groupby, col, col_name=""):
    if col_name == "":
        col_name = col
    return (
        t.groupby(groupby)
        .apply(lambda x: x[col].values.tolist())
        .reset_index(name=col_name)
    )

def merge_on_subject(
    t1,
    t2,
    how="left",
    left_on=["SUBJECT_ID", "HADM_ID"],
    right_on=["SUBJECT_ID", "HADM_ID"],
):
    return t1.merge(
        t2,
        how=how,
        left_on=left_on,
        right_on=right_on,
    )

def add_age_to_icustays(stays):
    dob = pd.to_datetime(stays["DOB"])
    dob = dob.values.astype("datetime64[s]")
    intime = pd.to_datetime(stays["INTIME"])
    intime = intime.values.astype("datetime64[s]")
    age = intime - dob
    g = lambda x: x / np.timedelta64(1, "s") / 60 / 60 / 24 / 365
    stays["AGE"] = np.asarray(list(map(g, age)))
    idxs = stays.AGE < 0
    stays.loc[idxs, "AGE"] = 90
    return stays

### If Less than n days on admission notes (Early notes)
def less_n_days_data(df_adm_notes, n):
    df_less_n = df_adm_notes[
        (
            (
                df_adm_notes["CHARTDATE"] - df_adm_notes["ADMITTIME_C"]
            ).dt.total_seconds()
            / (24 * 60 * 60)
        )
        < n
    ]
    df_less_n = df_less_n[df_less_n["TEXT"].notnull()]
    return df_less_n

def preprocess1(x):
    y = re.sub("\\[(.*?)\\]", "", x)  # remove de-identified brackets
    y = re.sub(
        "[0-9]+\.", "", y
    )  # remove 1.2. since the segmenter segments based on this
    y = re.sub("dr\.", "doctor", y)
    y = re.sub("m\.d\.", "md", y)
    y = re.sub("admission date:", "", y)
    y = re.sub("discharge date:", "", y)
    y = re.sub("--|__|==", "", y)
    return y

def preprocessing(df, col = 'TEXT'):
    df[col] = df[col].fillna(" ")
    df[col] = df[col].str.replace("\n", " ")
    df[col] = df[col].str.replace("\r", " ")
    df[col] = df[col].apply(str.strip)
    df[col] = df[col].str.lower()
    df[col] = df[col].apply(lambda x: preprocess1(x))
    df[col] = df[col].apply(lambda x: " ".join(x.split()))
    return df

def append_text(df):
    group = df.groupby('HADM_ID')
    df_list = []
    for idx in group.groups.values():
        tmp = df.loc[idx]
        discharge = tmp[tmp["CATEGORY"] == "Discharge summary"]
        rest = tmp[tmp["CATEGORY"] != "Discharge summary"]
        discharge = " ".join(discharge["TEXT"])
        rest = " ".join(rest["TEXT"])
        tmp = tmp.iloc[0]
        tmp['TEXT_DISCHARGE'] = discharge
        tmp['TEXT_REST'] = rest
        df_list.append(tmp)
    
    df = pd.concat(df_list, axis = 1)
    return df.T

def compute_time_delta(df):
    df["TIMEDELTA"] = (
        df.sort_values(["SUBJECT_ID", "ADMITTIME"])
        .groupby(["SUBJECT_ID"])["ADMITTIME"]
        .diff()
    )

    return df

### Generate all data including text

In [39]:
df_adm = pd.read_csv("data/ADMISSIONS.csv.gz")
df_adm.ADMITTIME = pd.to_datetime(
  df_adm.ADMITTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
 )
df_adm.DISCHTIME = pd.to_datetime(
  df_adm.DISCHTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
 )
df_adm.DEATHTIME = pd.to_datetime(
  df_adm.DEATHTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
 )

In [40]:
df_adm = df_adm.sort_values(["SUBJECT_ID", "ADMITTIME"])
df_adm = df_adm.reset_index(drop=True)

In [None]:
ids = df_adm['SUBJECT_ID'].values
print(len(ids), len(set(ids)))

In [41]:
df_adm["NEXT_ADMITTIME"] = df_adm.groupby("SUBJECT_ID").ADMITTIME.shift(-1)
df_adm["NEXT_ADMISSION_TYPE"] = df_adm.groupby("SUBJECT_ID").ADMISSION_TYPE.shift(-1)

In [42]:
rows = df_adm.NEXT_ADMISSION_TYPE == "ELECTIVE"
df_adm.loc[rows, "NEXT_ADMITTIME"] = pd.NaT
df_adm.loc[rows, "NEXT_ADMISSION_TYPE"] = np.NaN
df_adm = df_adm.sort_values(["SUBJECT_ID", "ADMITTIME"])

In [43]:
df_adm["DAYS_NEXT_ADMIT"] = (
        df_adm.NEXT_ADMITTIME - df_adm.DISCHTIME
    ).dt.total_seconds() / (24 * 60 * 60)
df_adm["readmission_label"] = (df_adm.DAYS_NEXT_ADMIT < 30).astype("int")
### filter out newborn and death
df_adm = df_adm[df_adm["ADMISSION_TYPE"] != "NEWBORN"]
df_adm["DURATION"] = (
        df_adm["DISCHTIME"] - df_adm["ADMITTIME"]
    ).dt.total_seconds() / (24 * 60 * 60)

In [None]:
df_notes = pd.read_csv("data/NOTEEVENTS.csv.gz")
df_notes = df_notes.sort_values(by=["SUBJECT_ID", "HADM_ID", "CHARTDATE"])

df_adm_notes = pd.merge(
        df_adm[
            [
                "SUBJECT_ID",
                "HADM_ID",
                "ADMITTIME",
                "DISCHTIME",
                "DAYS_NEXT_ADMIT",
                "NEXT_ADMITTIME",
                "ADMISSION_TYPE",
                "DEATHTIME",
                "readmission_label",
                "DURATION",
                "DIAGNOSIS",
                "MARITAL_STATUS",
                "ETHNICITY",
                "DISCHARGE_LOCATION",
            ]
        ],
        df_notes[["SUBJECT_ID", "HADM_ID", "CHARTDATE", "TEXT", "CATEGORY"]],
        on=["SUBJECT_ID", "HADM_ID"],
        how="left",
    )

In [None]:
# Adding clinical codes to dataset

# add diagnoses
code = "ICD9_CODE"
diagnoses = read_icd_diagnoses_table()
diagnoses = filter_codes(diagnoses, code = code)
diagnoses = group_by_return_col_list(
            diagnoses, ["SUBJECT_ID", "HADM_ID"], code
            )

# add procedures
procedures = read_icd_procedures_table()
procedures = filter_codes(procedures, code = code)
procedures = group_by_return_col_list(
            procedures, ["SUBJECT_ID", "HADM_ID"], code, "ICD9_CODE_PROCEDURE"
            )

# add cptevents
code = "CPT_CD"
cptevents = read_cptevents_table()
cptevents = filter_codes(cptevents, code = code)
cptevents = group_by_return_col_list(cptevents, ["SUBJECT_ID", "HADM_ID"], code)

# add prescriptions
code = "NDC"
prescriptions = read_prescriptions_table()
prescriptions = filter_codes(prescriptions, code = code)
prescriptions = group_by_return_col_list(
                prescriptions, ["SUBJECT_ID", "HADM_ID"], code
                )


In [None]:
patients = read_patients_table()
stays = read_icustays_table()
stays = merge_on_subject(
        stays, patients, how="inner", left_on=["SUBJECT_ID"], right_on=["SUBJECT_ID"]
)
stays = merge_on_subject(stays, diagnoses)
stays = merge_on_subject(stays, cptevents)
stays = merge_on_subject(stays, prescriptions)
stays = merge_on_subject(stays, procedures)
stays = add_age_to_icustays(stays)

In [None]:
filters=[
            "Discharge summary",
            "ECG",
            "Pharmacy",
            "Physician",
            "Radiology",
            "Respiratory",
        ]

df_adm_notes = pd.merge(
                df_adm_notes, stays, on=["SUBJECT_ID", "HADM_ID"], how="left"
            )
filt = df_adm_notes["ICD9_CODE"].isna() & df_adm_notes["CPT_CD"].isna()
df_adm_notes = df_adm_notes[~filt]

df_adm_notes["ADMITTIME_C"] = df_adm_notes.ADMITTIME.apply(lambda x: str(x).split(" ")[0])
df_adm_notes["ADMITTIME_C"] = pd.to_datetime(
                                df_adm_notes.ADMITTIME_C, 
                                format="%Y-%m-%d", 
                                errors="coerce"
                            )
df_adm_notes["CHARTDATE"] = pd.to_datetime(
                            df_adm_notes.CHARTDATE, 
                            format="%Y-%m-%d", 
                            errors="coerce"
                            )

filt = df_adm_notes["CATEGORY"].apply(lambda x: x in filters)
df_adm_notes = df_adm_notes[filt]

In [None]:
### If Discharge Summary
df_discharge = df_adm_notes[df_adm_notes["CATEGORY"] == "Discharge summary"]

# multiple discharge summary for one admission -> after examination -> replicated summary -> replace with the last one
df_discharge = (
    df_discharge.groupby(["SUBJECT_ID", "HADM_ID"]).nth(-1)
).reset_index()
df_discharge = df_discharge[df_discharge["TEXT"].notnull()]
df_discharge = remove_min_admissions(df_discharge, min_admits=2)
df_adm_notes = df_adm_notes[df_adm_notes["CATEGORY"] != "Discharge summary"]

In [None]:
df_less_1 = less_n_days_data(df_adm_notes, 1)
df_less_2 = less_n_days_data(df_adm_notes, 2)

In [None]:
df_less_1 = df_less_1.append(df_discharge).reset_index()
df_less_1 = preprocessing(df_less_1)

In [None]:
df_adm_notes.columns

In [None]:
df_adm_notes.TEXT

## Data without text 

In [None]:
# format date time
df_adm = pd.read_csv("data/ADMISSIONS.csv.gz")
df_adm.ADMITTIME = pd.to_datetime(
    df_adm.ADMITTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
)
df_adm.DISCHTIME = pd.to_datetime(
    df_adm.DISCHTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
)
df_adm.DEATHTIME = pd.to_datetime(
    df_adm.DEATHTIME, format="%Y-%m-%d %H:%M:%S", errors="coerce"
)

df_adm = df_adm.sort_values(["SUBJECT_ID", "ADMITTIME"])
df_adm = df_adm.reset_index(drop=True)
# one task in the paper is to predict re-admission within 30 days
df_adm["NEXT_ADMITTIME"] = df_adm.groupby("SUBJECT_ID").ADMITTIME.shift(periods=-1)
df_adm["NEXT_ADMISSION_TYPE"] = df_adm.groupby("SUBJECT_ID").ADMISSION_TYPE.shift(
    periods=-1
)

In [None]:
rows = df_adm.NEXT_ADMISSION_TYPE == "ELECTIVE"
df_adm.loc[rows, "NEXT_ADMITTIME"] = pd.NaT
df_adm.loc[rows, "NEXT_ADMISSION_TYPE"] = np.NaN

df_adm = df_adm.sort_values(["SUBJECT_ID","ADMITTIME"])

# When we filter out the "ELECTIVE",
# we need to correct the next admit time
# for these admissions since there might
# be 'emergency' next admit after "ELECTIVE"
df_adm[["NEXT_ADMITTIME", "NEXT_ADMISSION_TYPE"]] = df_adm.groupby(["SUBJECT_ID"])[
    ["NEXT_ADMITTIME", "NEXT_ADMISSION_TYPE"]
].fillna(method="bfill")
df_adm["DAYS_NEXT_ADMIT"] = (
    df_adm.NEXT_ADMITTIME - df_adm.DISCHTIME
).dt.total_seconds() / (24 * 60 * 60)
df_adm["readmission_label"] = (df_adm.DAYS_NEXT_ADMIT < 30).astype("int")
### filter out newborn and death
df_adm = df_adm[df_adm["ADMISSION_TYPE"] != "NEWBORN"]
df_adm["DURATION"] = (
    df_adm["DISCHTIME"] - df_adm["ADMITTIME"]
).dt.total_seconds() / (24 * 60 * 60)

In [None]:
# Adding clinical codes to dataset
# add diagnoses
code = "ICD9_CODE"
diagnoses = read_icd_diagnoses_table()
diagnoses = filter_codes(diagnoses, code=code)
diagnoses = group_by_return_col_list(diagnoses, ["SUBJECT_ID", "HADM_ID"], code)

# add procedures
procedures = read_icd_procedures_table()
procedures = filter_codes(procedures, code=code)
procedures = group_by_return_col_list(
    procedures, ["SUBJECT_ID", "HADM_ID"], code, "ICD9_CODE_PROCEDURE"
)

# add cptevents
code = "CPT_CD"
cptevents = read_cptevents_table()
cptevents = filter_codes(cptevents, code=code)
cptevents = group_by_return_col_list(cptevents, ["SUBJECT_ID", "HADM_ID"], code)

# add prescriptions
code = "NDC"
prescriptions = read_prescriptions_table()
prescriptions = filter_codes(prescriptions, code=code)
prescriptions = group_by_return_col_list(
    prescriptions, ["SUBJECT_ID", "HADM_ID"], code
)

In [None]:
patients = read_patients_table()
stays = read_icustays_table()
stays = stays.merge(patients, how='inner', left_on=['SUBJECT_ID'], right_on=["SUBJECT_ID"])
cols = ["SUBJECT_ID", "HADM_ID"]
stays = stays.merge(diagnoses, how="inner", left_on=cols, right_on=cols)
stays = stays.merge(procedures, how="inner", left_on=cols, right_on=cols)

In [None]:
s1 = set(stays.SUBJECT_ID)

In [None]:
d1, c1, p1, m1 = (set(diagnoses.SUBJECT_ID), set(cptevents.SUBJECT_ID), 
                  set(procedures.SUBJECT_ID), set(prescriptions.SUBJECT_ID))

In [None]:
len(d1 & s1), len(p1 & s1), len(p1 & s1 & d1), len((p1 & s1) & (d1 & d))

In [None]:
code = "ICD9_CODE"
diagnoses = read_icd_diagnoses_table()
diagnoses = filter_codes(diagnoses, code=code)
diagnoses['ICD9_SHORT'] = diagnoses['ICD9_CODE'].apply(lambda x: x[:3])

In [None]:
d1 = group_by_return_col_list(diagnoses, ["SUBJECT_ID", "HADM_ID"], code)
d2 = group_by_return_col_list(diagnoses, ["SUBJECT_ID", "HADM_ID"], 'ICD9_SHORT')

In [None]:
d2

In [None]:
# add procedures
procedures = read_icd_procedures_table()
procedures = filter_codes(procedures, code=code)

In [None]:
procedures['ICD9_PROC_SHORT'] = procedures['ICD9_CODE'].apply(lambda x: str(x)[:2])

In [None]:
stays.ICD9_CODE.isna().sum(), stays.CPT_CD.isna().sum(), stays.NDC.isna().sum(), stays.ICD9_CODE_PROCEDURE.isna().sum()

In [None]:
stays = add_age_to_icustays(stays)

In [None]:
len(set(df_adm.SUBJECT_ID) & set(procedures.SUBJECT_ID))

In [None]:
len(set(df_adm.SUBJECT_ID) & set(stays.SUBJECT_ID))

In [None]:
df_adm = pd.merge(
        df_adm, stays, on=["SUBJECT_ID", "HADM_ID"], how="inner"
    )
df_adm["ADMITTIME_C"] = df_adm.ADMITTIME.apply(
    lambda x: str(x).split(" ")[0]
)
df_adm["ADMITTIME_C"] = pd.to_datetime(
    df_adm.ADMITTIME_C, format="%Y-%m-%d", errors="coerce"
)

In [None]:
stays.shape

In [None]:
df_adm.shape

In [None]:
df_adm.ICD9_CODE.isna().sum(), df_adm.ICD9_CODE_PROCEDURE.isna().sum()

In [None]:
import itertools
def flatten(x):
    return itertools.chain.from_iterable(x)

In [None]:
len(set(flatten(df_adm["ICD9_CODE"].dropna())))

In [None]:
demographic_cols = {
        "AGE": [],
        "GENDER": [],
        "LAST_CAREUNIT": [],
        "MARITAL_STATUS": [],
        "ETHNICITY": [],
        "DISCHARGE_LOCATION": [],
    }
df["GENDER"], demographic_cols["GENDER"] = pd.factorize(df["GENDER"])

## Code embedding

In [124]:
code = "ICD9_CODE"
diagnoses = read_icd_diagnoses_table()
diagnoses = filter_codes(diagnoses, code=code)
diagnoses['ICD9_SHORT'] = diagnoses['ICD9_CODE'].apply(lambda x: x[:3])

removing ICD9_CODE codes occuring less than 5 times. 
 num codes before filter: 6841 after filtering: 3511


In [160]:
a=set(diagnoses.columns)

In [159]:
b=set(['ROW_ID'])

In [162]:
list(a-b)

['HADM_ID',
 'ICD9_SHORT',
 'ICD9_CODE',
 'SEQ_NUM',
 'LONG_TITLE',
 'SHORT_TITLE',
 'SUBJECT_ID']

In [None]:
procedures = read_icd_procedures_table()
procedures = filter_codes(procedures, code=code)
procedures['ICD9_PROC_SHORT'] = procedures['ICD9_CODE'].apply(lambda x: str(x)[:2])

In [17]:
proc_codes, diag_codes = list(set(procedures['ICD9_PROC_SHORT'])), list(set(diagnoses['ICD9_SHORT']))

In [18]:
# establish a mapping from codes to integers
map_proc = {}
for i, key in enumerate(proc_codes):
    map_proc[key] = i
procedures['ICD9_PROC_SHORT'] = procedures['ICD9_PROC_SHORT'].apply(lambda x: map_proc[x])

In [22]:
map_diag = {}
mapping_shift = len(proc_codes) # make sure the mapping will not mix
for i, key in enumerate(diag_codes):
    map_diag[key] = i + mapping_shift
diagnoses['ICD9_SHORT'] = diagnoses['ICD9_SHORT'].apply(lambda x: map_diag[x])

In [25]:
mapping_shift, len(diag_codes)

(87, 722)

In [23]:
diagnoses['ICD9_SHORT'].describe()

count    627575.000000
mean        452.904784
std         219.516260
min          87.000000
25%         247.000000
50%         442.000000
75%         685.000000
max         808.000000
Name: ICD9_SHORT, dtype: float64

In [21]:
len(p1), len(set(procedures['ICD9_CODE'])), len(proc_codes)

(480, 1055, 87)

In [None]:
from data_loader.utils.vocab import Vocab
cpt_vocab = Vocab()
diag_vocab = Vocab()
med_vocab = Vocab()
proc_vocab = Vocab()
diag_vocab._build_from_file("vocab/diag.vocab")
cpt_vocab._build_from_file("vocab/cpt.vocab")

In [None]:
cpt = list(set(cptevents['CPT_CD']))
diag = list(set(diagnoses["ICD9_CODE"]))

In [None]:
ctok = [cpt_vocab.convert_to_ids(str(c), "C", False) for c in cpt]

In [None]:
dtok = [diag_vocab.convert_to_ids(d, "D", True) for d in diag]

In [None]:
pids = list(set(df["SUBJECT_ID"]))

In [None]:
pid_df = df[df["SUBJECT_ID"] == pids[0]]

In [None]:
cptevents = group_by_return_col_list(cptevents, ["SUBJECT_ID", "HADM_ID"], code)

In [None]:
p1, d1 = set(procedures['ICD9_PROC_SHORT']), set(diagnoses['ICD9_SHORT'])
total = list(p1 | d1)
dic_map = {}
for i, key in enumerate(total):
    dic_map[key] = i

In [None]:
len(dic_map.keys())

In [None]:
procedures = group_by_return_col_list(
        procedures, ["SUBJECT_ID", "HADM_ID"], 'ICD9_PROC_SHORT'
    )
diagnoses = group_by_return_col_list(diagnoses, ["SUBJECT_ID", "HADM_ID"], 'ICD9_SHORT')

In [None]:
patients = read_patients_table()
stays = read_icustays_table()
stays = stays.merge(patients, how='inner', left_on=['SUBJECT_ID'], right_on=["SUBJECT_ID"])
cols = ["SUBJECT_ID", "HADM_ID"]
stays = stays.merge(diagnoses, how="inner", left_on=cols, right_on=cols)
stays = stays.merge(procedures, how="inner", left_on=cols, right_on=cols)
stays = add_age_to_icustays(stays)

In [None]:
df_adm = pd.merge(
        df_adm, stays, on=["SUBJECT_ID", "HADM_ID"], how="inner"
    )

In [None]:
p1 = tmp['ICD9_PROC_SHORT'].values
d1 = tmp['ICD9_SHORT'].values

In [None]:
c2, u2 = pd.factorize(procedures['ICD9_PROC_SHORT'])

In [None]:
cols = ["SUBJECT_ID", "HADM_ID"]
tmp = procedures.merge(diagnoses, how='inner', left_on=cols, right_on=cols)

In [None]:
t = stays[['ICD9_PROC_SHORT', 'ICD9_SHORT']].sum(axis=1)

In [None]:
df = pd.read_pickle('adm.pkl')

In [None]:
df.shape, len(set(df['SUBJECT_ID'])), len(set(df['HADM_ID']))

In [None]:
p1=procedures['ICD9_PROC_SHORT'].apply(lambda x: dic_map[x])

In [None]:
d1=diagnoses['ICD9_SHORT'].apply(lambda x: dic_map[x])

In [None]:
set(p1) & set(d1)

In [None]:
p1[['ICD9_PROC_SHORT', 'ICD9_SHORT']].values.sum(axis=1)

In [None]:
diagnoses.ICD9_SHORT

In [None]:
procedures.ICD9_PROC_SHORT

In [91]:
import torch
import torch.utils.data as data
import os
import pickle
import numpy as np
import itertools

class SeqCodeDataset(data.Dataset):
    def __init__(
        self,
        data_path,
        batch_size,
        train=True,
        med=False,
        diag=False,
        proc=False,
        split_num=2,
    ):
        self.proc = proc
        self.med = med
        self.diag = diag
        
        self.train = train
        self.batch_size = batch_size

        self.data = pickle.load(open(os.path.join(data_path, "data_icd.pkl"), "rb"))
        self.data_info = self.data["info"]
        self.data = self.data["data"]

        data_split_path = os.path.join(
            data_path, "splits", "split_{}.pkl".format(split_num)
        )
        if os.path.exists(data_split_path):
            self.train_idx, self.valid_idx = pickle.load(open(data_split_path, "rb"))

        self.keys = self._get_keys()

        self.max_len = self._findmax_len()

        self.num_dcodes = self.data_info['num_icd9_codes']
        self.num_pcodes = self.data_info['num_proc_codes']
    
        self.num_codes = (
            self.diag * self.num_dcodes
            + self.proc * self.num_pcodes
        )

        self.demographics_shape = self.data_info["demographics_shape"]

    def _gen_idx(self, keys, min_adm=2):
        idx = []
        for k in keys:
            v = self.data[k]
            if len(v) < min_adm:
                continue
            for i, _ in enumerate(v):
                idx.append((k, i))
        return idx

    def _get_keys(self, min_adm=2):
        keys = []
        for k, v in self.data.items():
            if len(v) < min_adm:
                continue
            keys.append(k)
        return keys

    def _findmax_len(self):
        m = 0
        for v in self.data.values():
            if len(v) > m:
                m = len(v)
        return m

    def __len__(self):
        if self.train:
            return len(self.keys)
        else:
            return 0

    def __getitem__(self, k):
        x = self.preprocess(self.data[k])
        return x

    
    def preprocess(self, seq):
        """create one hot vector of idx in seq, with length self.num_codes

        Args:
            seq: list of ideces where code should be 1

        Returns:
            x: one hot vector
            ivec: vector for learning code representation
            jvec: vector for learning code representation
        """

        icd_one_hot = torch.zeros((self.num_codes, self.max_len), dtype=torch.long)
        demo_one_hot = torch.zeros((self.demographics_shape, self.max_len), dtype=torch.long)
        mask = torch.zeros((self.max_len,), dtype=torch.long)
        ivec = []
        jvec = []
        for i, s in enumerate(seq):
            icd = s['icd']
            demo = s["demographics"]
            l = [
                 s["diagnoses"] * self.diag, 
                 s["procedures"] * self.proc
            ]
            icd = list(itertools.chain.from_iterable(l))
            
            icd_one_hot[icd, i] = 1
            demo_one_hot[:, i] = torch.Tensor(demo)
            mask[i] = 1
            for j in icd:
                for k in icd:
                    if j == k:
                        continue
                    ivec.append(j)
                    jvec.append(k)
            print(icd)
                
        return icd_one_hot.t(), mask, torch.LongTensor(ivec), torch.LongTensor(jvec), demo_one_hot.t()

In [92]:
data_path = "./data/output"
batch_size = 32
train = True

d = SeqCodeDataset(data_path,
                      batch_size,
                      diag = True,
                      proc = True
                  )

In [93]:
x, m, i, j, dx = d.preprocess(d.data[17])

[680, 734, 325, 407, 39, 56, 29]
[165, 674, 734, 782, 680, 352, 273, 764, 74, 56, 54]


In [96]:
dx.shape

torch.Size([29, 54])

In [80]:
with open("data/output/splits/split_1.pkl", "rb") as f:
    splits = pickle.load(f)

In [89]:
len(splits[1])/(len(splits[0])+len(splits[0]))

0.08337452628110067

In [116]:
with open("data/output/data_icd.pkl", "rb") as f:
    data = pickle.load(f)

In [None]:
def flatten(x):
    return itertools.chain.from_iterable(x)
len(set(flatten(df1["ICD9_PROC_SHORT"].dropna()))), len(set(flatten(df1["ICD9_SHORT"].dropna())))

## Sequence dataloader for classification

In [3]:
import torch
import torch.utils.data as data
import os
import pickle
import numpy as np
import itertools

class SeqClassificationDataset(data.Dataset):
    def __init__(
        self,
        data_path,
        batch_size,
        y_label="los",
        train=True,
        balanced_data=False,
        validation_split=0.0,
        split_num=1,
        med=False,
        diag=True,
        proc=True,
        cptcode=False
    ):
        super(SeqClassificationDataset).__init__()
        self.proc = proc
        self.med = med
        self.diag = diag
        self.cpt = cptcode

        self.data_path = data_path
        self.batch_size = batch_size
        self.train = train
        self.y_label = y_label
        self.validation_split = validation_split
        self.balanced_data = balanced_data
        self.data = pickle.load(open(os.path.join(self.data_path, "data_icd.pkl"), "rb"))
        self.data_info = self.data["info"]
        self.data = self.data["data"]

        self.demographics_shape = self.data_info["demographics_shape"]

        self.keys = list(map(int, self.data.keys()))
        self.max_len = self._findmax_len()

        self.num_dcodes = self.data_info['num_icd9_codes']
        self.num_pcodes = self.data_info['num_proc_codes']
        self.num_mcodes = self.data_info['num_med_codes']
        self.num_ccodes = self.data_info['num_cpt_codes']
        
        self.num_codes = (
            self.diag * self.num_dcodes
            + self.cpt * self.num_ccodes
            + self.proc * self.num_pcodes
            + self.med * self.num_mcodes
        )  

        data_split_path = os.path.join(
            self.data_path, "splits", "split_{}.pkl".format(split_num)
        )
        if os.path.exists(data_split_path):
            self.train_idx, self.valid_idx = pickle.load(open(data_split_path, "rb"))
            # select patients with at least two admissions
            self.train_indices = self._gen_indices(self.train_idx)
            self.valid_indices = self._gen_indices(self.valid_idx)
            # re-label the patient ID!
            # only patients with at least two visits are kept
            self.train_idx = np.arange(len(self.train_indices))
            self.valid_idx = len(self.train_indices) + np.arange(len(self.valid_indices))

            if self.balanced_data:
                self.train_idx = self._gen_balanced_indices(self.train_idx)
                #self.valid_idx = self._gen_balanced_indices(self.valid_idx)
        else:
            # TODO: data index logic if train, validation splits are not provided
            pass

    def _gen_balanced_indices(self, indices):
        """Generate a balanced set of indices"""
        ind_idx = {}

        for idx in indices:
            label = self.get_label(idx)
            if label not in ind_idx:
                ind_idx[label] = [idx]
            else:
                ind_idx[label].append(idx)

        tr = []
        te = []

        lens = sorted([len(v) for v in ind_idx.values()])

        if len(lens) > 3:
            num_samples = lens[-2]
        else:
            num_samples = lens[0]

        for v in ind_idx.values():
            v = np.asarray(v)

            if len(v) > num_samples:
                v = v[np.random.choice(np.arange(len(v)), num_samples)]

            # train, test = train_test_split(v, test_size=self.validation_split, random_state=1)
            # te.append(test)

            tr.append(v)

        train = np.concatenate(tr)
        # test = np.concatenate(te)
        return train  # , test

    def _gen_indices(self, keys):
        indices = []
        for k in keys:
            v = self.data[k]
            for j in range(len(v)):
                if (j + 1) == len(v):
                    continue
                indices.append([k, j + 1])
        return indices
    
    def _findmax_len(self):
        """Find the max number of visits of any patients

        Returns:
            [int]: the max number of visits
        """        
        m = 0
        for v in self.data.values():
            if len(v) > m:
                m = len(v)
        return m

    def __getitem__(self, index):
        if index in self.train_idx:
            idx = self.train_indices[index]
        else:
            idx = self.valid_indices[index - len(self.train_indices)]
        x = self.preprocess(idx)
        return x

    def preprocess(self, idx):
        """n: total # of visits per each patients minus one
            it's also the index for the last visits for extracting label y[n]
        Args:
            idx ([type]): [description]

        Returns:
            [type]: [description]
        """        
        seq = self.data[idx[0]]
        n = idx[1]
        x_codes = torch.zeros((self.num_codes, self.max_len), dtype=torch.float)
        demo = torch.Tensor(seq[n]["demographics"])
        for i in range(n):
            if (i + 1) == len(seq):
                continue
            s = seq[i]   
            codes = [
                 s["diagnoses"] * self.diag, 
                 s["procedures"] * self.proc
            ]
            codes = list(itertools.chain.from_iterable(codes))
            x_codes[codes, i] = 1

        x_cl = torch.Tensor(
            [
                n,
            ]
        )
       
        if self.y_label == "los":
            los = seq[n]["los"]
            if los != los:
                los = 9
            y = torch.Tensor([los - 1])
        elif self.y_label == "readmission":
            y = torch.Tensor([seq[n]["readmission"]])
        else:
            y = torch.Tensor([seq[n]["mortality"]])

        return (x_codes.t(), x_cl, demo, y)

    def get_label(self, idx):
        if idx in self.train_idx:
            idx = self.train_indices[idx]
        else:
            idx = self.valid_indices[idx - len(self.train_indices)]
        seq = self.data[idx[0]]
        n = idx[1]
        if self.y_label == "los":
            los = seq[n]["los"]
            if los != los:
                los = 9
            y = torch.Tensor([los - 1])
        elif self.y_label == "readmission":
            y = torch.Tensor([seq[n]["readmission"]])
        else:
            y = torch.Tensor([seq[n]["mortality"]])
        y = y.item()
        return y

    def __len__(self):
        l = 0
        if self.train:
            l = len(self.train_idx)
        else:
            l = len(self.valid_idx)

        return l

def collate_fn(data):
    x_codes, x_cl,  demo, y_code = zip(*data)
    x_codes = torch.stack(x_codes, dim=1)
    demo = torch.stack(demo, dim=0)
    y_code = torch.stack(y_code, dim=1).long()
    x_cl = torch.stack(x_cl, dim=0).long()
    b_is = torch.arange(x_cl.shape[0]).reshape(tuple(x_cl.shape)).long()
    return (
        x_codes,
        x_cl.squeeze(),
        b_is.squeeze(),
        demo,
    ), y_code.squeeze()

In [30]:
data_path = "./data/output"
data_split_path
batch_size = 32
train = True

d = SeqClassificationDataset(data_path,
                      batch_size,
                      y_label="mortality",
                      diag = True,
                      proc = True
                  )

In [10]:
split_num = 0
data_split_path = os.path.join(
            data_path, "splits", "split_{}.pkl".format(split_num))

In [11]:
data_split_path

'./data/output/splits/split_0.pkl'

In [12]:
train_idx, valid_idx = pickle.load(open(data_split_path, "rb"))

In [13]:
train_idx

array([54994, 93381, 23990, ...,  7059,   376, 82202])

In [16]:
d.train_idx

array([   0,    1,    2, ..., 7630, 7631, 7632])

In [23]:
d.y_label

'mortality'

In [28]:
np.array([d.data[x][0][d.y_label] for x in train_idx]).sum()

144

In [29]:
train_idx.shape

(6069,)

In [119]:
x_codes, x_cl,  demo, y_code = d[21]

In [120]:
x_codes.shape

torch.Size([29, 809])

In [121]:
x_cl

tensor([1.])

In [122]:
demo.shape

torch.Size([54])

## Sequence classification trainning

In [31]:
import numpy as np
import torch
from base import BaseTrainer
from model.metric import roc_auc_1, pr_auc_1, pr_auc, roc_auc


class ClassificationTrainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
    """

    def __init__(
        self,
        model,
        loss,
        metrics,
        optimizer,
        resume,
        config,
        data_loader,
        valid_data_loader=None,
        lr_scheduler=None,
        train_logger=None,
    ):
        super(ClassificationTrainer, self).__init__(
            model, loss, metrics, optimizer, resume, config, train_logger
        )
        self.config = config
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = config["trainer"].get(
            "log_step", int(np.sqrt(data_loader.batch_size))
        )

        if self.config["model"]["args"]["num_classes"] == 1:
            weight_0 = self.config["trainer"].get("class_weight_0", 1.0)
            weight_1 = self.config["trainer"].get("class_weight_1", 1.0)
            self.weight = [weight_0, weight_1]
            self.loss = lambda output, target: loss(output, target, self.weight)
        self.prauc_flag = pr_auc in self.metrics and roc_auc in self.metrics

    def _eval_metrics(self, output, target, **kwargs):
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(output, target, **kwargs)
            self.writer.add_scalar(f"{metric.__name__}", acc_metrics[i])
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        self.model.train()
        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        all_t = []
        all_o = []
        for batch_idx, (data, target) in enumerate(self.data_loader):

            all_t.append(target.numpy())

            target = target.to(self.device)
            self.optimizer.zero_grad()
            output, _ = self.model(data, device=self.device)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)
            self.writer.add_scalar("loss", loss.item())

            total_loss += loss
            total_metrics += self._eval_metrics(output, target)
            all_o.append(output.detach().cpu().numpy())

            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.logger.info(
                    "Train Epoch: {} [{}/{} ({:.0f}%)] {}: {:.6f}".format(
                        epoch,
                        batch_idx * self.data_loader.batch_size,
                        self.data_loader.n_samples,
                        100.0 * batch_idx / len(self.data_loader),
                        "loss",
                        loss,
                    )
                )

        total_metrics = total_metrics / len(self.data_loader)
        if self.prauc_flag:
            all_o = np.hstack(all_o)
            all_t = np.hstack(all_t)
            total_metrics[-2] = pr_auc_1(all_o, all_t)
            total_metrics[-1] = roc_auc_1(all_o, all_t)

        log = {
            "loss": total_loss / len(self.data_loader),
            "metrics": total_metrics,
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()
        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))
        all_t = []
        all_o = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                all_t.append(target.numpy())
                target = target.to(self.device)

                output, _ = self.model(data, self.device)
                loss = self.loss(
                    output,
                    target.reshape(
                        -1,
                    ),
                )
                all_o.append(output.detach().cpu().numpy())

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, "valid"
                )
                self.writer.add_scalar("loss", loss.item())
                total_val_loss += loss.item()
                total_val_metrics += self._eval_metrics(output, target)

        total_val_metrics = (total_val_metrics / len(self.valid_data_loader)).tolist()
        if self.prauc_flag:
            all_o = np.hstack(all_o)
            all_t = np.hstack(all_t)
            total_val_metrics[-2] = pr_auc_1(all_o, all_t)
            total_val_metrics[-1] = roc_auc_1(all_o, all_t)

        return {
            "val_loss": total_val_loss / len(self.valid_data_loader),
            "val_metrics": total_val_metrics,
        }

In [None]:
config = json.load(open('configs/taper/seq_mortality.json'))
data_loader = 
cl = ClassificationTrainer()