In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from datetime import datetime, timedelta
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from mimiciii_db import DB
from mimiciii_db.config import db_url

In [None]:
# Build the MIMIC-III database handle from the project-configured connection URL.
# NOTE(review): DB.from_url may connect lazily; the message below only confirms
# that the handle was created, not that a query has succeeded.
connection_url = db_url()
db = DB.from_url(connection_url)
print("Database connected successfully!")

In [None]:
# Load the filtered, age-binned cohort together with its comorbidity columns.
# NOTE(review): later cells assume an `age_bin` column plus one indicator column per
# Elixhauser comorbidity (averaged into a prevalence) — confirm against the view.
query = """
SELECT * FROM filtered_patients_agegrouped_with_morbidities
"""

df = db.query_df(query)

# The melt/heatmap cells group on this column; fail early if the view's schema changed.
assert "age_bin" in df.columns, "expected an `age_bin` column in the query result"

# Preview only a few rows: the frame may be large and carries patient identifiers
# (subject_id / hadm_id), which should not be rendered wholesale into the notebook.
df.head()

In [None]:
# Columns of the cohort frame that are not comorbidity indicators; they must not be
# melted into the indicator column or averaged as if they were prevalences.
NON_COMORBIDITY_COLS = ["hadm_id", "age", "subject_id"]

# Reshape to long form: one row per (cohort row, comorbidity) holding the indicator
# value. Only indicator columns are melted, so IDs never mix into `present`.
df_long = df.melt(
    id_vars=["age_bin"],
    value_vars=[
        c for c in df.columns
        if c != "age_bin" and c not in NON_COMORBIDITY_COLS
    ],
    var_name="comorbidity",
    value_name="present",
)

# Prevalence (%) per age bracket = mean of the indicator within the bracket x 100.
# NOTE(review): this reads as a prevalence only if indicators are 0/1 — confirm;
# NaN indicators are skipped by .mean().
heat = (
    df_long.groupby(["age_bin", "comorbidity"])["present"]
    .mean()
    .mul(100)
    .reset_index(name="prevalence_pct")
)
heat

In [None]:
# Row order of the heatmap (top to bottom). Comorbidities absent from this list are
# not plotted.
disease_order = [
    "other_neurological", "coagulopathy", "depression", "liver_disease",
    "alcohol_abuse", "drug_abuse", "deficiency_anemias", "paralysis", "weight_loss",
    "rheumatoid_arthritis", "solid_tumor", "lymphoma", "peptic_ulcer",
    "blood_loss_anemia", "psychoses", "aids", "metastatic_cancer", "diabetes_complicated",
    "obesity", "renal_failure", "valvular_disease", "hypothyroidism",
    "peripheral_vascular", "pulmonary_circulation",
    "chronic_pulmonary", "diabetes_uncomplicated", "congestive_heart_failure",
    "fluid_electrolyte", "hypertension", "cardiac_arrhythmias",
]

# Column order of the heatmap (youngest to oldest bracket).
AGE_BIN_ORDER = ["16-24", "25-44", "45-64", "65-84", "≥85"]

# Compute the observed labels once instead of calling .unique() per candidate.
observed_comorbidities = set(heat["comorbidity"])
observed_age_bins = set(heat["age_bin"])

present_rows = [r for r in disease_order if r in observed_comorbidities]
present_cols = [c for c in AGE_BIN_ORDER if c in observed_age_bins]

heat_matrix = (
    heat.pivot(index="comorbidity", columns="age_bin", values="prevalence_pct")
    .reindex(index=present_rows, columns=present_cols)
    # NOTE(review): a NaN cell (no non-missing indicator values) is shown as 0%,
    # which is indistinguishable from a true 0% prevalence — confirm this is intended.
    .fillna(0.0)
)

fig, ax = plt.subplots(figsize=(10, 12))
sns.heatmap(
    heat_matrix,
    cmap="inferno",
    linewidths=0.75,
    annot=True,
    # seaborn's default annotation format ".2g" prints 100 as "1e+02" and mixes
    # precisions ("46" vs "5.3"); use one decimal place for every percentage.
    fmt=".1f",
    cbar_kws={"label": "Prevalence (%)"},
    ax=ax,
)
ax.set_xlabel("Age Bracket")
ax.set_ylabel("")
ax.set_title("Elixhauser Comorbidity Prevalence by Age Group")
fig.tight_layout()
plt.show()

In [None]:
# Export the plotted heatmap matrix (rows = comorbidities, columns = age brackets) to CSV
# so the figure, including the dendrogram, can be reproduced at higher quality in R.
HEAT_MATRIX_CSV = Path("../data/heat_matrix_1b.csv")

# Quick sanity check of what is about to be written.
print(heat_matrix.shape)
print(heat_matrix.index[:5])
print(heat_matrix.columns)

# An empty export would silently break the downstream R script; stop here instead.
assert not heat_matrix.empty, "heat_matrix is empty; nothing to export"

# to_csv does not create missing directories, so make sure the target exists.
HEAT_MATRIX_CSV.parent.mkdir(parents=True, exist_ok=True)

# utf-8-sig prepends a BOM (the headers include the non-ASCII "≥85" label).
# NOTE(review): in R, read with fileEncoding = "UTF-8-BOM" so the first header
# is not prefixed with BOM bytes.
heat_matrix.to_csv(HEAT_MATRIX_CSV, index=True, encoding="utf-8-sig", float_format="%.4f")