In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

from mimiciii_db import DB
from mimiciii_db.config import db_url

In [None]:
# Build the database handle used by every query cell below, from the URL
# returned by the project's db_url() helper.
# NOTE(review): the message is printed whenever DB.from_url() returns without
# raising; whether a connection is actually opened at this point depends on
# DB.from_url — confirm before relying on it as a connectivity check.
db = DB.from_url(db_url())
print("Database connected successfully!")

In [None]:
# Percentage of rows with two or more morbidities (morbidity_count >= 2) per
# age bracket, plus the number of rows in each bracket.
#
# Result columns: age_bin, pct_multimorbid (rounded to 1 decimal), n_patients.
#
# Rows with a NULL age are excluded: the bracket CASE ends in an ELSE branch,
# so without the filter they would be counted in the oldest bracket.
#
# NOTE(review): n_patients is COUNT(*) over the source relation; it equals a
# patient count only if that relation holds one row per patient — confirm.
# NOTE(review): in PostgreSQL, casting a fractional number to int rounds to the
# nearest integer; if age holds fractional years, confirm rounding (rather than
# truncation) is the intended way to assign brackets.
sql_multimorbidity_from_mvs = """
WITH base AS (
  SELECT
    age::int AS age,
    COALESCE(morbidity_count, 0) AS morbidity_count
  FROM mimiciii.filtered_patients_with_morbidity_counts
  -- unknown ages must not fall through to the ELSE bracket below
  WHERE age IS NOT NULL
),
binned AS (
  SELECT
    CASE
      WHEN age BETWEEN 0  AND 24 THEN '0–24'
      WHEN age BETWEEN 25 AND 44 THEN '25–44'
      WHEN age BETWEEN 45 AND 64 THEN '45–64'
      WHEN age BETWEEN 65 AND 84 THEN '65–84'
      -- remaining ages, i.e. 85 and above
      ELSE '>85'
    END AS age_bin,
    (morbidity_count >= 2)::int AS multimorbid
  FROM base
)
SELECT age_bin,
       ROUND(100.0 * SUM(multimorbid)::numeric / COUNT(*), 1) AS pct_multimorbid,
       COUNT(*) AS n_patients
FROM binned
GROUP BY age_bin
-- explicit rank so the brackets come back youngest to oldest
ORDER BY CASE age_bin
  WHEN '0–24' THEN 1
  WHEN '25–44' THEN 2
  WHEN '45–64' THEN 3
  WHEN '65–84' THEN 4
  ELSE 5 END;
"""
df_multi = db.query_df(sql_multimorbidity_from_mvs)
df_multi


In [None]:
# Prevalence (percent, 1 decimal) of each of the 30 Elixhauser comorbidity
# flags within every age bracket, returned in long form:
# one row per (age_bin, comorbidity) with column prevalence_pct.
#
# Rows with a NULL age are excluded: the bracket CASE ends in an ELSE branch,
# so without the filter they would be counted in the oldest bracket.
#
# The comorbidity labels in the first ARRAY and the flag columns in the second
# ARRAY are matched by position; keep both lists in the same order.
#
# NOTE(review): the join is on hadm_id; if filtered_patients holds more than
# one row per admission (it also carries icustay_id), that admission is counted
# once per row in the averages — confirm the intended unit of analysis.
sql_heatmap_from_mvs = """
WITH cohort AS (
  SELECT
    hadm_id,
    age::int AS age
  FROM mimiciii.filtered_patients
  -- unknown ages must not fall through to the ELSE bracket below
  WHERE age IS NOT NULL
),
-- take the 30 binary Elixhauser flags (1/0) per admission
flags AS (
  SELECT
    hadm_id,
    congestive_heart_failure,
    cardiac_arrhythmias,
    valvular_disease,
    pulmonary_circulation,
    peripheral_vascular,
    hypertension,
    paralysis,
    other_neurological,
    chronic_pulmonary,
    diabetes_uncomplicated,
    diabetes_complicated,
    hypothyroidism,
    renal_failure,
    liver_disease,
    peptic_ulcer,              -- note: column is 'peptic_ulcer' in elixhauser_quan
    aids,
    lymphoma,
    metastatic_cancer,
    solid_tumor,
    rheumatoid_arthritis,
    coagulopathy,
    obesity,
    weight_loss,
    fluid_electrolyte,         -- note: column is 'fluid_electrolyte' in elixhauser_quan
    blood_loss_anemia,
    deficiency_anemias,
    alcohol_abuse,
    drug_abuse,
    psychoses,
    depression
  FROM mimiciii.elixhauser_quan
),
joined AS (
  SELECT
    CASE
      WHEN c.age BETWEEN 0  AND 24 THEN '0–24'
      WHEN c.age BETWEEN 25 AND 44 THEN '25–44'
      WHEN c.age BETWEEN 45 AND 64 THEN '45–64'
      WHEN c.age BETWEEN 65 AND 84 THEN '65–84'
      -- remaining ages, i.e. 85 and above
      ELSE '>85'
    END AS age_bin,
    -- bring in flags; LEFT JOIN keeps patients even if no Elix rows (treated as 0 later)
    -- hadm_id is taken from f.* only, so the CTE has no duplicate column name
    f.*
  FROM cohort c
  LEFT JOIN flags f USING (hadm_id)
),
-- long/tidy: (age_bin, comorbidity_name, present_flag)
long AS (
  SELECT age_bin, name AS comorbidity, present
  FROM joined j,
       unnest(
         ARRAY[
           'other neurological disorder','coagulopathy','depression','liver disease',
           'alcohol abuse','drug abuse','deficiency anemias','paralysis','weight loss',
           'rheumatoid arthritis','solid tumor','lymphoma','peptic ulcer disease',
           'blood loss anemia','psychoses','aids','metastatic cancer','diabetes complicated',
           'obesity','renal failure','valvular disease','hypothyroidism',
           'peripheral vascular disease','pulmonary circulation disorder',
           'chronic pulmonary disease','diabetes uncomplicated','congestive heart failure',
           'fluid electrolyte disorder','hypertension','cardiac arrhythmias'
         ],
         ARRAY[
           COALESCE(j.other_neurological,0), COALESCE(j.coagulopathy,0), COALESCE(j.depression,0), COALESCE(j.liver_disease,0),
           COALESCE(j.alcohol_abuse,0), COALESCE(j.drug_abuse,0), COALESCE(j.deficiency_anemias,0), COALESCE(j.paralysis,0), COALESCE(j.weight_loss,0),
           COALESCE(j.rheumatoid_arthritis,0), COALESCE(j.solid_tumor,0), COALESCE(j.lymphoma,0), COALESCE(j.peptic_ulcer,0),
           COALESCE(j.blood_loss_anemia,0), COALESCE(j.psychoses,0), COALESCE(j.aids,0), COALESCE(j.metastatic_cancer,0), COALESCE(j.diabetes_complicated,0),
           COALESCE(j.obesity,0), COALESCE(j.renal_failure,0), COALESCE(j.valvular_disease,0), COALESCE(j.hypothyroidism,0),
           COALESCE(j.peripheral_vascular,0), COALESCE(j.pulmonary_circulation,0),
           COALESCE(j.chronic_pulmonary,0), COALESCE(j.diabetes_uncomplicated,0), COALESCE(j.congestive_heart_failure,0),
           COALESCE(j.fluid_electrolyte,0), COALESCE(j.hypertension,0), COALESCE(j.cardiac_arrhythmias,0)
         ]
       ) AS u(name, present)
)
SELECT
  age_bin,
  comorbidity,
  ROUND(100.0 * AVG(present)::numeric, 1) AS prevalence_pct
FROM long
GROUP BY age_bin, comorbidity
ORDER BY comorbidity, age_bin;
"""
df_heat = db.query_df(sql_heatmap_from_mvs)
df_heat


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Heatmap of comorbidity prevalence (rows) by age bracket (columns), with rows
# laid out in three fixed groups separated by white lines.

# Pivot the long prevalence table into a comorbidity x age-bracket matrix.
# Cast to float in case the driver returns SQL numeric values as Decimal
# objects, which would leave the matrix with an object dtype.
mat = (
    df_heat
    .pivot(index="comorbidity", columns="age_bin", values="prevalence_pct")
    .astype(float)
    .fillna(0)
)

# Row layout of the figure, top to bottom. Groups are kept separate so the
# separator positions below are derived from them instead of hard-coded.
comorbidity_groups = {
    "Group 3": [
        'other neurological disorder','coagulopathy','depression','liver disease',
        'alcohol abuse','drug abuse',
    ],
    "Group 2": [
        'deficiency anemias','paralysis','weight loss','rheumatoid arthritis',
        'solid tumor','lymphoma','peptic ulcer disease','blood loss anemia',
        'psychoses','aids','metastatic cancer','diabetes complicated','obesity',
    ],
    "Group 1": [
        'renal failure','valvular disease','hypothyroidism','peripheral vascular disease',
        'pulmonary circulation disorder','chronic pulmonary disease',
        'diabetes uncomplicated','congestive heart failure','fluid electrolyte disorder',
        'hypertension','cardiac arrhythmias',
    ],
}
paper_order = [name for names in comorbidity_groups.values() for name in names]

# reindex() would silently insert all-NaN rows for labels missing from the
# data and silently drop rows that are not listed, so check the labels first.
missing = sorted(set(paper_order) - set(mat.index))
unexpected = sorted(set(mat.index) - set(paper_order))
if missing or unexpected:
    raise ValueError(
        "comorbidity labels do not match the figure layout; "
        f"missing from df_heat: {missing}, not in paper_order: {unexpected}"
    )

# apply that order
mat = mat.reindex(paper_order)

# plot
g = sns.clustermap(
    mat,
    row_cluster=False,  # disable reordering since we fixed order manually
    col_cluster=False,
    cmap="inferno",
    vmin=0, vmax=80,
    figsize=(7, 9),
    cbar_kws={"label": "Prevalence (%)"},
)

ax = g.ax_heatmap

# show outer frame
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color("black")

# add white horizontal separators between the groups: one line after the last
# row of every group except the final one (6 and 19 for the layout above)
group_sizes = [len(names) for names in comorbidity_groups.values()]
separators = np.cumsum(group_sizes)[:-1].tolist()
for y in separators:
    ax.hlines(y, *ax.get_xlim(), colors="white", linewidth=3)

ax.set_xlabel("Age Bracket")
ax.set_ylabel("")
plt.show()
