In [None]:
# Re-import edited modules before each cell runs, so changes to imported code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import time

import numpy as np
import pandas as pd

from fiber.cohort import Cohort
from fiber.condition import Procedure, Diagnosis, Drug, VitalSign, Patient, LabValue
from fiber.database.hana import engine, Session, print_sqla
from fiber.database.table import fact

In [None]:
def trim_func(x):
    """Return the part of ``x`` before its first ``'.'`` (all of ``x`` if it has none).

    Passed as ``trim_func`` to ``has_onset`` / ``has_precondition`` in this notebook,
    e.g. ``trim_func("584.9") == "584"``.
    """
    # A named def (rather than a lambda bound to a name, PEP 8 E731) gives a real
    # __name__ in tracebacks and room for this docstring.
    return x.split('.', 1)[0]

# Cohort Definition

### TODO
- Capture whether the patient was in the ICU

In [None]:
# Adults only: 18 years expressed in days (leap days ignored).
ADULT_AGE_THRESHOLD_DAYS = 18 * 365
min_age = Patient.age_in_days > ADULT_AGE_THRESHOLD_DAYS

# Two procedure-code prefixes, each restricted to adults via .with_(min_age).
# NOTE(review): presumably ICD-9 procedure codes (35.x heart valve/septum operations,
# 36.1x coronary bypass) — confirm against the code book.
heart_surgery_condition = (
    Procedure(code='35.%').with_(min_age)
    | Procedure(code='36.1%').with_(min_age)
)

In [None]:
# Build the study cohort from the adult heart-surgery condition.
heart_surgery_cohort = Cohort(heart_surgery_condition)

In [None]:
# Cohort size (presumably the number of patients — confirm what Cohort.__len__ counts).
len(heart_surgery_cohort)

# Demographics

In [None]:
# Cohort demographics, indexed below as demographics[<attribute>]["figure"]
# for the "age" and "gender" attributes.
demographics = heart_surgery_cohort.demographics

In [None]:
# Age figure for the cohort.
demographics["age"]["figure"]

In [None]:
# Gender figure for the cohort.
demographics["gender"]["figure"]

# Onsets
- Mortality at 0, 7, 14 and 28 days
- Rehospitalization at 7, 14 and 28 days
- Acute kidney injury (AKI): ICD-9 code `%584%` or an AKI phenotype
- Stroke (cerebrovascular event) at 0, 7, 14 and 28 days:
    - Occlusion and stenosis of precerebral arteries: ICD-9 code `433%`
    - Occlusion of cerebral arteries: ICD-9 code `434%`
    - Acute but ill-defined cerebrovascular disease: ICD-9 code `436%`


### TODO: 
- mortality
- rehospitalization
- AKI phenotype -> AKIN, KDIGO (more important)

In [None]:
# AKI onset: any ICD-9 584.x diagnosis. This follows the documented `%584%` definition.
# An exact "584.9" matched only the unspecified-AKI code and missed the other
# 584.x subcodes.
# NOTE(review): time_deltas are presumably day windows after the index surgery — confirm.
aki = heart_surgery_cohort.has_onset(
    name="aki",
    condition=Diagnosis(code="584.%", context="ICD-9"),
    time_deltas=[1, 7, 14, 28],
    trim_func=trim_func,
)
aki

In [None]:
# Stroke onset: any of three cerebrovascular diagnosis code families.
stroke_condition = (
    Diagnosis(code='433.%')    # occlusion and stenosis of precerebral arteries
    | Diagnosis(code='434.%')  # occlusion of cerebral arteries
    | Diagnosis(code='436.%')  # acute but ill-defined cerebrovascular disease
)
stroke = heart_surgery_cohort.has_onset(
    name="stroke",
    condition=stroke_condition,
    time_deltas=[1, 7, 14, 28],
    trim_func=trim_func,
)
stroke

# Preconditions

In [None]:
# Maps each diagnosis-category name to the cohort's has_precondition(...) result
# for that category; filled by the loop over `diagnoses`.
preconditions = {}

In [None]:
# Comorbidity categories to evaluate as preconditions. Each name is passed to
# Diagnosis.from_condition_store(name=...) when building `preconditions`.
# NOTE(review): the names look like the Elixhauser comorbidity categories — confirm.
diagnoses = [
    "congestive heart failure",
    "fluid and electrolyte disorders",
    "liver disease",
    "rheumatoid arthritis/collagen vascular diseases",
    # "AIDS/HIV",  # TODO: excluded without a recorded reason — document or restore
    "alcohol abuse",
    "blood loss anemia",
    "cardiac arrhythmia",
    "chronic pulmonary disease",
    "coagulopathy",
    "deficiency anemia",
    "depression",
    "diabetes complicated",
    "diabetes uncomplicated",
    "drug abuse",
    "hypertension complicated",
    "hypertension uncomplicated",
    "lymphoma",
    "metastatic cancer",
    "obesity",
    "other neurological disorders",
    "paralysis",
    "peptic ulcer disease excluding bleeding",
    "peripheral vascular disorders",
    "psychoses",
    "pulmonary circulation disorders",
    "renal failure",
    "solid tumor without metastasis",
    "valvular disease",
    "weight loss"
]

In [None]:
# One precondition result per diagnosis category, each resolved by name from the
# condition store. Re-running overwrites the same keys.
for diagnosis_name in diagnoses:
    diagnosis_condition = Diagnosis.from_condition_store(name=diagnosis_name)
    preconditions[diagnosis_name] = heart_surgery_cohort.has_precondition(
        condition=diagnosis_condition,
        trim_func=trim_func,
    )

# Lab Values - Example for Glucose

In [None]:
# Glucose test names matching GLUCOSE%: counts without any cohort restriction,
# then the top counts among heart-surgery cohort patients only.
glucose_name_counts_all = LabValue(test_name="GLUCOSE%").value_counts("test_name")
glucose_name_counts_cohort = (
    LabValue(test_name="GLUCOSE%")
    .get_data(inclusion_mrns=heart_surgery_cohort)
    .test_name
    .value_counts()
    .head()
)
print(glucose_name_counts_all, "\n\n", glucose_name_counts_cohort)

In [None]:
# Glucose results for cohort patients, relative to (before) the heart-surgery condition.
glucose_lab = LabValue(test_name="GLUCOSE")
lv_glucose = heart_surgery_cohort.results_for(
    glucose_lab,
    before=heart_surgery_condition,
)

In [None]:
# Number of glucose results per day offset, restricted to offsets below 4 days.
glucose_days_before = lv_glucose["occurs_x_days_before"]
glucose_days_before[glucose_days_before < 4].value_counts()

## Other lab values

In [None]:
def search_lab(search_string, rows=5):
    """Look up lab test names containing ``search_string``.

    The string is wrapped in ``%`` wildcards (presumably SQL LIKE semantics), so it
    matches anywhere in the name; short terms such as "PT" may also hit longer
    names that contain them.

    Returns the first ``rows`` entries of
    ``LabValue(test_name=...).value_counts("test_name")``.
    """
    like_pattern = "".join(("%", search_string, "%"))
    name_counts = LabValue(test_name=like_pattern).value_counts("test_name")
    return name_counts.head(rows)

In [None]:
# Candidate test names for each lab of interest. Tuples are (search term[, rows]);
# a missing rows value means search_lab's default is used.
for search_args in [
    ("NITROGEN",),
    ("CREATININE", 6),
    ("ANION", 6),
    ("BILIRUBIN",),
    ("ALBUMIN",),
    ("CHLORIDE",),
    ("GLUCOSE", 10),
    ("HEMATOCRIT",),
    ("HEMOGLOBIN",),
    ("LACTATE", 6),
    ("PLATELET",),
    ("POTASSIUM",),
    ("SODIUM",),
]:
    print(search_lab(*search_args), "\n")

# White blood cell count appears under two spellings; show both together.
print(search_lab("WBC"), "\n", search_lab("WHITE BLOOD", 1), "\n")

for search_term in ("PT", "PTT", "INR"):
    print(search_lab(search_term), "\n")

In [None]:
# Maps each lab display name to its filtered results DataFrame; filled by the
# fetch loop over `lv_cond`.
lab_values = {}

In [None]:
# Display name -> LabValue condition for each lab feature. Test names are passed
# explicitly as `test_name=`, as in the other LabValue calls in this notebook,
# rather than relying on positional argument order.
lv_cond = {
    "Blood Urea Nitrogen": LabValue(test_name="UREA NITROGEN-BLD"),
    "Blood Creatinine": LabValue(test_name="CREATININE-SERUM"),
    "Anion Gap": LabValue(test_name="ANION GAP"),
    "Bilirubin": LabValue(test_name="BILIRUBIN TOTAL"),
    "Albumin": LabValue(test_name="ALBUMIN, BLD"),
    "Chloride": LabValue(test_name="CHLORIDE-BLD"),
    "Glucose": LabValue(test_name="GLUCOSE"),
    "Hematocrit": LabValue(test_name="HEMATOCRIT"),
    "Hemoglobin": LabValue(test_name="HEMOGLOBIN"),
    "Platelet Count": LabValue(test_name="PLATELET"),
    # NOTE(review): no hyphen, unlike the other "-BLD" names — confirm against
    # search_lab("POTASSIUM") output.
    "Potassium": LabValue(test_name="POTASSIUMBLD"),
    "Sodium": LabValue(test_name="SODIUM-BLD"),
    # Two spellings of the same test, combined into one condition.
    "White Blood Cell Count": LabValue(test_name="WHITE BLOOD CELL") | LabValue(test_name="WBC"),
    "INR": LabValue(test_name="INR"),
    "PTT": LabValue(test_name="APTT"),
    "PT": LabValue(test_name="PRO TIME"),
    "Lactate": LabValue(test_name="WB LACTATE-ART (POCT)"),
}

In [None]:
# Only results fewer than this many days before the heart-surgery condition are kept.
LAB_WINDOW_DAYS = 4

# Fetch each lab's results relative to the heart-surgery condition and keep the
# rows inside the window. Timing uses time.perf_counter(); no `Timer` helper is
# imported in this notebook, so the timed fetch would otherwise fail on a fresh kernel.
for name, cond in lv_cond.items():
    started_at = time.perf_counter()
    results_for_test = heart_surgery_cohort.results_for(cond, before=heart_surgery_condition)
    lab_values[name] = results_for_test[results_for_test.occurs_x_days_before < LAB_WINDOW_DAYS]
    elapsed = time.perf_counter() - started_at
    print(f'Aggregating {name} day-wise done in {elapsed:.2f}s')

In [None]:
# TODO: lt or eq
def pivot_lab_values(
    name,
    x,
    limit,
    comparator="eq",
):
    if comparator == "lt":
        mask = x.occurs_x_days_before <= limit
    else:
        mask = x.occurs_x_days_before == limit

    return pd.Series({
        f"{name}_value_{limit}_days": x[mask]["numeric_value"].mean(),
        f"{name}_result_flag_{limit}_days": x[mask]["result_flag"].value_counts().index[0] if x[mask]["result_flag"].any() else np.nan,
        f"{name}_abnormal_{limit}_days": x[mask]["abnormal"].median(),
    })

In [None]:
# Pivot each lab's long-format results into one row per
# (medical_record_number, age_in_days). There are three windows: <= 1 day
# (cumulative), exactly 2 days and exactly 3 days.
# pd.concat joins the three summary Series; Series.append was removed in pandas 2.0.
pivoted_values = {}
for name, lab_df in lab_values.items():
    pivoted_values[name] = (
        lab_df
        .groupby(["medical_record_number", "age_in_days"])
        .apply(
            # Bind the lab name explicitly rather than relying on the loop variable.
            lambda group, lab_name=name: pd.concat([
                pivot_lab_values(lab_name, group, 1, "lt"),
                pivot_lab_values(lab_name, group, 2),
                pivot_lab_values(lab_name, group, 3),
            ])
        )
        .reset_index()
    )

In [None]:
# Combine the per-lab pivoted frames into one table via Cohort.merge_lab_values.
# NOTE(review): presumably joined on medical_record_number/age_in_days — confirm.
merged_and_pivoted_lab_values = heart_surgery_cohort.merge_lab_values(*pivoted_values.values())

In [None]:
# Assemble onset results followed by every precondition result into one dataset.
onset_results = [aki, stroke]
final_df = heart_surgery_cohort.build_data(*onset_results, *preconditions.values())

In [None]:
# Summary statistics of the combined dataset.
# NOTE(review): this is an inner join (the merge default), so patients with no lab
# rows in the window are dropped from the summary; consider how="left" if they should remain.
final_df.merge(
    merged_and_pivoted_lab_values,
    on=['medical_record_number', 'age_in_days'],
).describe()