In [27]:
# --- Clustered bar chart of late cases by category and delay ranges (pivot fix + legend title) ---
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# ===== CONFIG =====
# Input CSV, the numeric delay column, the late-flag column, and how many
# category levels each chart keeps.
DATA_PATH = "C:/Users/egtay/Downloads/Telegram Desktop/Final_Cleaned_Dataset_OPTIC_7.csv"
VALUE_COL, LATE_COL, TOP_K = "KNIFE_START_DELAY", "Is_Late", 10

# Dropdown labels in display order.
_DROPDOWN_LABELS = (
    "Delay_Category",
    "LOCATION",
    "ROOM",
    "EQUIPMENT",
    "EMERGENCY_PRIORITY",
    "DISCIPLINE",
    "ANESTHESIA",
    "ADMISSION_CLASS_TYPE",
    "ADMISSION_WARD",
    "ADMISSION_BED",
    "AOH",
    "BLOOD",
    "IMPLANT",
    "CANCER_INDICATOR",
    "TRAUMA_INDICATOR",
)
# Labels that read a derived column instead of the column of the same name.
_DERIVED_COLUMNS = {
    "EQUIPMENT": "__EQUIPMENT_SPLIT__",
    "IMPLANT": "__IMPLANT_BIN__",
}
# label -> dataframe column backing that label
CATEGORY_MAP = {name: _DERIVED_COLUMNS.get(name, name) for name in _DROPDOWN_LABELS}

# ===== LOAD =====
df = pd.read_csv(DATA_PATH)
# Values that cannot be parsed as numbers become NaN; those rows are dropped.
df[VALUE_COL] = pd.to_numeric(df[VALUE_COL], errors="coerce")
df = df.dropna(subset=[VALUE_COL]).copy()

if LATE_COL not in df.columns:
    # Name the configured column instead of a fixed literal so the message
    # stays correct if LATE_COL is changed.
    raise KeyError(f"Column {LATE_COL!r} not found in dataset.")

# Keep only LATE cases
df = df[df[LATE_COL] == 1].copy()

# ===== EQUIPMENT split =====
# EQUIPMENT is split on commas; df_eq carries one row per (case, item) so
# every item can be treated as its own category level.
if "EQUIPMENT" in df.columns:
    eq = (
        df[["OPERATION_ID", "EQUIPMENT"]]
        .assign(EQUIPMENT=lambda d: d["EQUIPMENT"].astype(str).fillna("").str.split(","))
        .explode("EQUIPMENT")
    )
    eq["EQUIPMENT"] = eq["EQUIPMENT"].astype(str).str.strip()
    # Blank tokens and textual spellings of "missing" are not real items.
    eq.loc[eq["EQUIPMENT"].isin(["", "nan", "None", "NULL", "none", "null"]), "EQUIPMENT"] = np.nan
    eq = eq.dropna(subset=["EQUIPMENT"])
    # Left merge: cases without any usable item stay in df_eq with NaN.
    # The per-item value arrives as EQUIPMENT_x; the raw column keeps its name.
    df_eq = df.merge(eq, on="OPERATION_ID", how="left", suffixes=("", "_x"))
    df_eq["__EQUIPMENT_SPLIT__"] = df_eq["EQUIPMENT_x"]
else:
    df_eq = df.copy()
    df_eq["__EQUIPMENT_SPLIT__"] = np.nan

# ===== IMPLANT binary =====
if "IMPLANT" in df.columns:
    def to_implant_bin(s):
        """Map a raw IMPLANT value to "No implant" or "Has implant".

        A value that parses as the number 0 is "No implant"; anything else,
        including text that is not a number, is "Has implant".
        """
        try:
            v = float(s)
        except (TypeError, ValueError):
            # Only a failed conversion means "not a number"; a bare except
            # would also swallow KeyboardInterrupt and unrelated bugs.
            return "Has implant"
        # NOTE(review): NaN parses as a float that is != 0, so missing values
        # land in "Has implant" — confirm this is intended.
        return "No implant" if v == 0 else "Has implant"
    df["__IMPLANT_BIN__"] = df["IMPLANT"].astype(str).map(to_implant_bin)
else:
    df["__IMPLANT_BIN__"] = np.nan

# ===== Delay bins =====
# Finite bin edges; the last bin is open-ended.
_bin_edges = [0, 30, 60, 90, 120]
bins = [*_bin_edges, np.inf]
# One "lo-hi" label per pair of neighbouring edges, then "<last edge>+".
bin_labels = [f"{lo}-{hi}" for lo, hi in zip(_bin_edges, _bin_edges[1:])]
bin_labels.append(f"{_bin_edges[-1]}+")

def get_working_df(label: str):
    """Return ``(frame_copy, column_name)`` for a dropdown label.

    The exploded equipment frame is used for the equipment column and the
    main frame for everything else. When the column exists, its values are
    stringified and stripped, and blank / missing-like values are replaced
    with "Unspecified".
    """
    col = CATEGORY_MAP[label]
    if col == "__EQUIPMENT_SPLIT__":
        w = df_eq.copy()
    else:
        w = df.copy()

    if col not in w.columns:
        return w, col

    missing_tokens = ["", "nan", "None", "NULL", "none", "null"]
    cleaned = w[col].astype(str).str.strip()
    # Blank out the missing-like tokens, then give every gap one shared label.
    w[col] = cleaned.mask(cleaned.isin(missing_tokens)).fillna("Unspecified")
    return w, col

def select_top_by_volume(w: pd.DataFrame, col: str, k: int = TOP_K) -> pd.DataFrame:
    """Keep only rows whose ``col`` value is among the ``k`` most frequent.

    The frame is returned unchanged when it has ``k`` or fewer distinct values.
    """
    level_counts = w[col].value_counts()  # sorted most frequent first
    if len(level_counts) <= k:
        return w
    top_levels = level_counts.index[:k]
    return w[w[col].isin(top_levels)]

def make_clustered_pivot(label: str):
    """Build the per-level, per-delay-bin tables behind one clustered bar chart.

    Returns:
      order_levels: list of category levels kept (top-K by volume), largest first
      pivot_pct: DataFrame indexed by levels, columns=bin_labels, values=percentages (filled 0)
      pivot_cnt: same shape with counts
      pivot_med: same shape with medians rounded to 1 decimal (NaN where empty)

    An empty list and three empty DataFrames are returned when the category
    column is missing or there are no rows.

    Percentages are relative to all rows of the level with a non-null delay,
    so rows whose delay falls outside every bin (below 0) still count in the
    denominator.
    """
    w, col = get_working_df(label)
    if col not in w.columns or w.empty:
        return [], pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    w = select_top_by_volume(w, col, TOP_K)

    # assign() returns a new frame. Writing the column in place would modify
    # what may be a filtered slice and raise SettingWithCopyWarning.
    w = w.assign(
        Delay_Bin=pd.cut(w[VALUE_COL], bins=bins, labels=bin_labels, right=False, include_lowest=True)
    )

    # observed=False keeps every bin for every level, including empty ones.
    gb = w.groupby([col, "Delay_Bin"], observed=False)[VALUE_COL]
    agg = gb.agg(count="count", median="median").reset_index()

    totals = w.groupby(col, observed=False)[VALUE_COL].count().rename("total").reset_index()
    agg = agg.merge(totals, on=col, how="left")
    agg["pct"] = (agg["count"] / agg["total"] * 100).astype(float)

    order_levels = totals.sort_values("total", ascending=False)[col].astype(str).tolist()

    def _pivot(value_name):
        # Level x bin table for one statistic, in display order.
        wide = agg.pivot(index=col, columns="Delay_Bin", values=value_name)
        return wide.reindex(index=order_levels, columns=bin_labels)

    pivot_pct = _pivot("pct").fillna(0.0)
    pivot_cnt = _pivot("count").fillna(0).astype(int)
    pivot_med = _pivot("median").round(1)

    return order_levels, pivot_pct, pivot_cnt, pivot_med

# Build figure with dropdown
fig = go.Figure()
buttons = []      # dropdown entries, filled in further down
masks = []        # one visibility list per label, in CATEGORY_MAP order
bundle = {}       # label -> (order_levels, pct, cnt, med)
trace_offset = 0  # number of traces added so far

for label in CATEGORY_MAP:
    pivots = make_clustered_pivot(label)
    bundle[label] = pivots
    order_levels, pivot_pct, pivot_cnt, pivot_med = pivots

    hidden_before = [False] * trace_offset
    if not order_levels:
        # Nothing to draw for this label: it owns no traces.
        masks.append(hidden_before)
        continue

    # One bar trace per delay bin, each spanning all kept levels.
    for bin_name in bin_labels:
        medians = pivot_med[bin_name].values  # may contain NaN
        median_text = np.where(np.isnan(medians), "", np.round(medians, 1))
        hover = (
            f"Delayed Minutes (Knife-to-Skin): <b>{bin_name}</b><br>"
            "%{x}<br>"
            "Pct late: %{y:.1f}%<br>"
            "Cases: %{customdata[0]}<br>"
            "Median: %{customdata[1]} min<extra></extra>"
        )
        fig.add_trace(go.Bar(
            x=order_levels,
            y=pivot_pct[bin_name].values,
            customdata=np.c_[pivot_cnt[bin_name].values, median_text],
            name=bin_name,
            hovertemplate=hover,
        ))

    added = len(bin_labels)
    masks.append(hidden_before + [True] * added)
    trace_offset += added

# Pad masks to full trace count
total_traces = trace_offset
masks = [m + [False] * (total_traces - len(m)) for m in masks]

# Dropdown buttons
def _late_chart_title(label):
    """Chart title for one category; the Top-K figure follows TOP_K."""
    return f"Late Cases — % by Delay Range for {label} (Top {TOP_K} by volume)"


for label, mask in zip(CATEGORY_MAP.keys(), masks):
    buttons.append(dict(
        label=label,
        method="update",
        args=[
            {"visible": mask},
            {
                "title": _late_chart_title(label),
                "yaxis": {"title": "% of Late Cases"},
                "xaxis": {"title": label},
                "barmode": "group",
                "legend": {"title": {"text": "Delayed Minutes (Knife-to-Skin)"}},
            },
        ]
    ))

# ===== Initial visibility fix: show ALL 5 bins for Delay_Category on first load =====
# Find index of Delay_Category within CATEGORY_MAP and apply its mask
labels_list = list(CATEGORY_MAP.keys())
initial_idx = labels_list.index("Delay_Category") if "Delay_Category" in labels_list else 0

# If that category had no data (mask all False), fall back to the first non-empty mask
if not any(masks[initial_idx]):
    initial_idx = next((i for i, m in enumerate(masks) if any(m)), initial_idx)

# Apply initial visibility mask
for trace, is_visible in zip(fig.data, masks[initial_idx]):
    trace.visible = is_visible

# Set initial layout to match the initial category. The x-axis title is set
# here as well so the first view matches what every dropdown button shows.
initial_label = labels_list[initial_idx]
fig.update_layout(
    title=_late_chart_title(initial_label),
    xaxis_title=initial_label,
    yaxis_title="% of Late Cases",
    barmode="group",
    legend_title_text="Delayed Minutes (Knife-to-Skin)",
    updatemenus=[dict(
        type="dropdown",
        x=1.0, xanchor="right",
        y=1.0, yanchor="top",
        buttons=buttons,
        showactive=True
    )],
    margin=dict(l=60, r=30, t=70, b=80)
)

fig.show()

In [28]:
# --- Median KNIFE_START_DELAY bar chart with dropdown, dynamic y-axis ---
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# ====== CONFIG ======
# Input CSV, the numeric delay column, and how many category levels to keep.
DATA_PATH = "C:/Users/egtay/Downloads/Telegram Desktop/Final_Cleaned_Dataset_OPTIC_7.csv"
VALUE_COL, TOP_K = "KNIFE_START_DELAY", 10

# Dropdown labels in display order.
_DROPDOWN_LABELS = (
    "Delay_Category",
    "LOCATION",
    "ROOM",
    "EQUIPMENT",
    "EMERGENCY_PRIORITY",
    "DISCIPLINE",
    "ANESTHESIA",
    "ADMISSION_CLASS_TYPE",
    "ADMISSION_WARD",
    "ADMISSION_BED",
    "AOH",
    "BLOOD",
    "IMPLANT",
    "CANCER_INDICATOR",
    "TRAUMA_INDICATOR",
)
# Labels that read a derived column instead of the column of the same name.
_DERIVED_COLUMNS = {
    "EQUIPMENT": "__EQUIPMENT_SPLIT__",
    "IMPLANT": "__IMPLANT_BIN__",
}
# label -> dataframe column backing that label
CATEGORY_MAP = {name: _DERIVED_COLUMNS.get(name, name) for name in _DROPDOWN_LABELS}

# ====== LOAD ======
df = pd.read_csv(DATA_PATH)
# Values that cannot be parsed as numbers become NaN; those rows are dropped.
df[VALUE_COL] = pd.to_numeric(df[VALUE_COL], errors="coerce")
df = df.dropna(subset=[VALUE_COL]).copy()

# ====== EQUIPMENT split ======
# EQUIPMENT is split on commas; df_eq carries one row per (case, item) so
# every item can be treated as its own category level.
if "EQUIPMENT" in df.columns:
    eq = (
        df[["OPERATION_ID", "EQUIPMENT"]]
        .assign(EQUIPMENT=lambda d: d["EQUIPMENT"].astype(str).fillna("").str.split(","))
        .explode("EQUIPMENT")
    )
    eq["EQUIPMENT"] = eq["EQUIPMENT"].astype(str).str.strip()
    # Blank tokens and textual spellings of "missing" are not real items.
    eq.loc[eq["EQUIPMENT"].isin(["", "nan", "None", "NULL"]), "EQUIPMENT"] = np.nan
    eq = eq.dropna(subset=["EQUIPMENT"])
    # Left merge: cases without any usable item stay in df_eq with NaN.
    # The per-item value arrives as EQUIPMENT_x; the raw column keeps its name.
    df_eq = df.merge(eq, on="OPERATION_ID", how="left", suffixes=("", "_x"))
    df_eq["__EQUIPMENT_SPLIT__"] = df_eq["EQUIPMENT_x"]
else:
    df_eq = df.copy()
    df_eq["__EQUIPMENT_SPLIT__"] = np.nan

# ====== IMPLANT binary ======
if "IMPLANT" in df.columns:
    def to_implant_bin(s):
        """Map a raw IMPLANT value to "No implant" or "Has implant".

        A value that parses as the number 0 is "No implant"; anything else,
        including text that is not a number, is "Has implant".
        """
        try:
            v = float(s)
        except (TypeError, ValueError):
            # Only a failed conversion means "not a number"; a bare except
            # would also swallow KeyboardInterrupt and unrelated bugs.
            return "Has implant"
        # NOTE(review): NaN parses as a float that is != 0, so missing values
        # land in "Has implant" — confirm this is intended.
        return "No implant" if v == 0 else "Has implant"
    df["__IMPLANT_BIN__"] = df["IMPLANT"].astype(str).map(to_implant_bin)
else:
    df["__IMPLANT_BIN__"] = np.nan

# Helpers
def get_working_df(label: str):
    """Return ``(frame_copy, column_name)`` for a dropdown label.

    The exploded equipment frame is copied for the equipment column; the main
    frame is copied for every other column.
    """
    col = CATEGORY_MAP[label]
    if col == "__EQUIPMENT_SPLIT__":
        source = df_eq
    else:
        source = df
    return source.copy(), col

def select_top_by_volume(w: pd.DataFrame, col: str, k: int = TOP_K) -> pd.DataFrame:
    """Drop rows with a missing ``col`` and keep the ``k`` most frequent levels.

    ``col`` is stringified and stripped before counting. All remaining rows
    are returned when there are ``k`` or fewer distinct levels.
    """
    kept = w.dropna(subset=[col]).copy()
    kept[col] = kept[col].astype(str).str.strip()
    level_counts = kept[col].value_counts()  # sorted most frequent first
    if len(level_counts) <= k:
        return kept
    return kept[kept[col].isin(level_counts.index[:k])]

def make_agg(label: str) -> pd.DataFrame:
    """Return median delay and case count per level for one dropdown label.

    The result has the category column plus ``median_delay`` and ``count``,
    restricted to the top-K levels and sorted by median delay, highest first.
    An empty frame with those columns is returned when the category column
    does not exist.
    """
    frame, col = get_working_df(label)
    if col not in frame.columns:
        return pd.DataFrame(columns=[col, "median_delay", "count"])

    top = select_top_by_volume(frame, col, TOP_K)
    per_level = top.groupby(col)[VALUE_COL].agg(median_delay="median", count="count")
    result = per_level.reset_index().sort_values("median_delay", ascending=False)
    result[col] = result[col].astype(str)
    return result

# Build figure
fig = go.Figure()
buttons = []      # dropdown entries, filled in further down
masks = []        # one visibility list per label, in CATEGORY_MAP order
all_aggs = {}     # label -> aggregated frame (possibly empty)
trace_offset = 0  # number of traces added so far

for label in CATEGORY_MAP:
    agg = make_agg(label)
    all_aggs[label] = agg

    hidden_before = [False] * trace_offset
    if agg.empty:
        # Nothing to draw for this label: it owns no trace.
        masks.append(hidden_before)
        continue

    cat_col = CATEGORY_MAP[label]
    hover = (
        "<b>" + label + "</b>: %{x}<br>"
        "Median: %{y:.1f} min<br>"
        "Cases: %{customdata[0]}<extra></extra>"
    )
    # A single bar trace per label; the counts ride along as a one-column
    # customdata array for the hover text.
    fig.add_trace(go.Bar(
        x=agg[cat_col],
        y=agg["median_delay"],
        customdata=agg["count"].to_numpy().reshape(-1, 1),
        text=agg["median_delay"].round(1).astype(str),
        textposition="outside",
        hovertemplate=hover,
        name=label,
    ))

    masks.append(hidden_before + [True])
    trace_offset += 1

# pad masks
total = trace_offset
masks = [m + [False] * (total - len(m)) for m in masks]

# dropdown
def _median_y_range(agg):
    """Return the ``[lower, upper]`` y-axis range for one aggregated frame.

    Defaults to [0, 100]. The range widens by 20 beyond the data when the
    medians go negative or exceed 100.
    """
    if agg.empty:
        return [0, 100]
    y_min, y_max = agg["median_delay"].min(), agg["median_delay"].max()
    lower = 0 if y_min >= 0 else y_min - 20
    upper = 100 if y_max <= 100 and y_min >= 0 else y_max + 20
    return [lower, upper]


labels_in_order = list(CATEGORY_MAP.keys())

for label, mask in zip(labels_in_order, masks):
    buttons.append(dict(
        label=label,
        method="update",
        args=[
            {"visible": mask},
            {
                "title": f"Median KNIFE_START_DELAY by {label} (Top {TOP_K})",
                # Dotted key: only the range is updated. Passing a whole
                # "yaxis" object here would replace the axis and drop its title.
                "yaxis.range": _median_y_range(all_aggs[label]),
                "xaxis": {"title": label},
            },
        ]
    ))

# init view
if masks:
    for tr, is_visible in zip(fig.data, masks[0]):
        tr.visible = is_visible

first_label = labels_in_order[0]
fig.update_layout(
    title=f"Median KNIFE_START_DELAY by {first_label} (Top {TOP_K})",
    # Set on first load as well, so it matches what the buttons show.
    xaxis_title=first_label,
    yaxis_title="Median KNIFE_START_DELAY (minutes)",
    bargap=0.25,
    updatemenus=[dict(
        type="dropdown",
        x=1.0, xanchor="right",
        y=1.15, yanchor="top",
        buttons=buttons,
        showactive=True
    )],
    margin=dict(l=60, r=30, t=70, b=80)
)

# default y-range for initial category
fig.update_yaxes(range=_median_y_range(all_aggs[first_label]))

fig.show()

In [29]:
# --- Box plot for KNIFE_START_DELAY across multiple categories with dropdown ---
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ====== CONFIG ======
# Input CSV, the numeric delay column, the late-flag column, and how many
# category levels each chart keeps.
DATA_PATH = "C:/Users/egtay/Downloads/Telegram Desktop/Final_Cleaned_Dataset_OPTIC_7.csv"
VALUE_COL, LATE_COL, TOP_K = "KNIFE_START_DELAY", "Is_Late", 10

# Fixed y-range and filter bounds (symmetric around zero)
Y_MAX = 1440
Y_MIN = -Y_MAX

# Dropdown labels in display order.
_DROPDOWN_LABELS = (
    "Delay_Category",
    "LOCATION",
    "ROOM",
    "EQUIPMENT",
    "EMERGENCY_PRIORITY",
    "DISCIPLINE",
    "ANESTHESIA",
    "ADMISSION_CLASS_TYPE",
    "ADMISSION_WARD",
    "ADMISSION_BED",
    "AOH",
    "BLOOD",
    "IMPLANT",
    "CANCER_INDICATOR",
    "TRAUMA_INDICATOR",
)
# Labels that read a derived column instead of the column of the same name.
_DERIVED_COLUMNS = {
    "EQUIPMENT": "__EQUIPMENT_SPLIT__",
    "IMPLANT": "__IMPLANT_BIN__",
}
# label -> dataframe column backing that label
CATEGORY_MAP = {name: _DERIVED_COLUMNS.get(name, name) for name in _DROPDOWN_LABELS}

# Display order for the EMERGENCY_PRIORITY levels.
EMERGENCY_PRIORITY_ORDER = ["0", "P0", "P1", "P2A", "P2B", "P3B"]

# ====== LOAD ======
df = pd.read_csv(DATA_PATH)
# Values that cannot be parsed as numbers become NaN; those rows are dropped.
df[VALUE_COL] = pd.to_numeric(df[VALUE_COL], errors="coerce")
df = df.dropna(subset=[VALUE_COL]).copy()
# A missing late flag column is created as all zeros; unparseable flags
# also become 0.
if LATE_COL not in df.columns:
    df[LATE_COL] = 0
df[LATE_COL] = pd.to_numeric(df[LATE_COL], errors="coerce").fillna(0).astype(int)

# ---- EQUIPMENT preprocessing ----
# EQUIPMENT is split on commas; df_eq carries one row per (case, item) so
# every item can be treated as its own category level.
if "EQUIPMENT" in df.columns:
    eq = (
        df[["OPERATION_ID", "EQUIPMENT"]]
        .assign(EQUIPMENT=lambda d: d["EQUIPMENT"].astype(str).fillna("").str.split(","))
        .explode("EQUIPMENT")
    )
    eq["EQUIPMENT"] = eq["EQUIPMENT"].astype(str).str.strip()
    # Blank tokens and textual spellings of "missing" are not real items.
    eq.loc[eq["EQUIPMENT"].isin(["", "nan", "None", "NULL"]), "EQUIPMENT"] = np.nan
    eq = eq.dropna(subset=["EQUIPMENT"])
    # Left merge: cases without any usable item stay in df_eq with NaN.
    # The per-item value arrives as EQUIPMENT_x; the raw column keeps its name.
    df_eq = df.merge(eq, on="OPERATION_ID", how="left", suffixes=("", "_x"))
    df_eq["__EQUIPMENT_SPLIT__"] = df_eq["EQUIPMENT_x"]
else:
    df_eq = df.copy()
    df_eq["__EQUIPMENT_SPLIT__"] = np.nan  # will drop as NaN later

# ---- IMPLANT preprocessing ----
if "IMPLANT" in df.columns:
    def to_implant_bin(s):
        """Map a raw IMPLANT value to "No implant" or "Has implant".

        A value that parses as the number 0 is "No implant"; anything else,
        including text that is not a number, is "Has implant".
        """
        try:
            v = float(s)
        except (TypeError, ValueError):
            # Only a failed conversion means "not a number"; a bare except
            # would also swallow KeyboardInterrupt and unrelated bugs.
            return "Has implant"
        # NOTE(review): NaN parses as a float that is != 0, so missing values
        # land in "Has implant" — confirm this is intended.
        return "No implant" if v == 0 else "Has implant"
    df["__IMPLANT_BIN__"] = df["IMPLANT"].astype(str).map(to_implant_bin)
else:
    df["__IMPLANT_BIN__"] = np.nan

def get_category_df(label):
    """Return ``(frame_copy, column_name)`` for a dropdown label.

    The exploded equipment frame is copied for the equipment column; the main
    frame is copied for every other column.
    """
    col = CATEGORY_MAP[label]
    if col == "__EQUIPMENT_SPLIT__":
        source = df_eq
    else:
        source = df
    return source.copy(), col

def filter_top_levels(working, cat_col, top_k=TOP_K):
    """Drop rows with a missing category and keep the ``top_k`` most frequent levels.

    The category column is stringified and stripped before counting. All
    remaining rows are returned when there are ``top_k`` or fewer levels.
    """
    kept = working.dropna(subset=[cat_col]).copy()
    kept[cat_col] = kept[cat_col].astype(str).str.strip()
    level_counts = kept[cat_col].value_counts()  # sorted most frequent first
    if len(level_counts) <= top_k:
        return kept
    return kept[kept[cat_col].isin(level_counts.index[:top_k])]

def clamp_value_range(frame, col, lo=Y_MIN, hi=Y_MAX):
    """Return the rows of ``frame`` whose ``col`` value lies in ``[lo, hi]``.

    Both bounds are inclusive. Rows outside the range are removed, not clipped.
    """
    in_range = frame[col].between(lo, hi)
    return frame[in_range]

fig = go.Figure()
buttons = []            # dropdown entries, filled in further down
visibility_blocks = []  # one visibility list per label, in CATEGORY_MAP order
trace_offset = 0        # number of traces added so far

for label in CATEGORY_MAP:
    working, cat_col = get_category_df(label)
    hidden_before = [False] * trace_offset

    if cat_col not in working.columns:
        visibility_blocks.append(hidden_before)
        continue

    # Exclude "No Delay" for Delay_Category BEFORE Top-K
    if label == "Delay_Category":
        normalized = working[cat_col].astype(str).str.strip().str.lower()
        working = working[normalized != "no delay"]

    # Top-K levels by volume first, then drop delays outside the fixed range.
    sub = clamp_value_range(filter_top_levels(working, cat_col), VALUE_COL, Y_MIN, Y_MAX)
    if sub.empty:
        visibility_blocks.append(hidden_before)
        continue

    # Category order: fixed priority order for EMERGENCY_PRIORITY (only the
    # levels actually present), otherwise most frequent first.
    if label == "EMERGENCY_PRIORITY":
        seen = set(sub[cat_col].astype(str).str.upper().str.strip().unique())
        order = [level for level in EMERGENCY_PRIORITY_ORDER if level.upper() in seen]
    else:
        order = sub[cat_col].value_counts().index.tolist()

    # Only box plot (no dots); the traces are moved into the shared figure.
    fig_box = px.box(
        sub,
        x=cat_col,
        y=VALUE_COL,
        points=False,
        category_orders={cat_col: order},
    )
    for box_trace in fig_box.data:
        fig.add_trace(box_trace)

    added = len(fig_box.data)
    visibility_blocks.append(hidden_before + [True] * added)
    trace_offset += added

# Pad every visibility list to the final trace count.
total_traces = trace_offset
visibility_blocks = [vb + [False] * (total_traces - len(vb)) for vb in visibility_blocks]

def _box_title(label):
    """Chart title for one category, noting the "No Delay" exclusion where it applies."""
    excl = " — excluding No Delay" if label == "Delay_Category" else ""
    return f"KNIFE_START_DELAY — Box Plot by {label} (Top {TOP_K}){excl}"


for label, vb in zip(CATEGORY_MAP.keys(), visibility_blocks):
    buttons.append(dict(
        label=label,
        method="update",
        args=[
            {"visible": vb},
            {
                "title": _box_title(label),
                # Dotted key: only the range is updated. Passing a whole
                # "yaxis" object here would replace the axis and drop its title.
                "yaxis.range": [Y_MIN, Y_MAX],
            },
        ],
    ))

# Initial state: first label
initial_viz = visibility_blocks[0] if visibility_blocks else []
for idx, tr in enumerate(fig.data):
    tr.visible = initial_viz[idx] if idx < len(initial_viz) else False

fig.update_layout(
    # Same title the first label's button produces, so the exclusion note is
    # present on first load too.
    title=_box_title(list(CATEGORY_MAP.keys())[0]),
    yaxis_title="KNIFE_START_DELAY (minutes)",
    updatemenus=[dict(
        type="dropdown", x=1, xanchor="right", y=1, yanchor="top",
        buttons=buttons, showactive=True
    )],
    margin=dict(l=60, r=30, t=70, b=80),
)

fig.update_yaxes(range=[Y_MIN, Y_MAX])
fig.show()


KeyboardInterrupt: 

In [30]:
# --- Stacked bar: total cases by category with On-time vs Late  ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

# ===== CONFIG =====
# Input CSV, the numeric delay column, the late-flag column, and how many
# category levels each chart keeps.
DATA_PATH = "C:/Users/egtay/Downloads/Telegram Desktop/Final_Cleaned_Dataset_OPTIC_7.csv"
VALUE_COL, LATE_COL, TOP_K = "KNIFE_START_DELAY", "Is_Late", 10
# Inclusive bounds used to filter the delay values (symmetric around zero).
Y_MAX = 1440
Y_MIN = -Y_MAX

# Dropdown labels in display order.
_DROPDOWN_LABELS = (
    "LOCATION",
    "ROOM",
    "EQUIPMENT",
    "EMERGENCY_PRIORITY",
    "DISCIPLINE",
    "ANESTHESIA",
    "ADMISSION_CLASS_TYPE",
    "ADMISSION_WARD",
    "ADMISSION_BED",
    "AOH",
    "BLOOD",
    "IMPLANT",
    "CANCER_INDICATOR",
    "TRAUMA_INDICATOR",
)
# Labels that read a derived column instead of the column of the same name.
_DERIVED_COLUMNS = {
    "EQUIPMENT": "__EQUIPMENT_SPLIT__",
    "IMPLANT": "__IMPLANT_BIN__",
}
# label -> dataframe column backing that label
CATEGORY_MAP = {name: _DERIVED_COLUMNS.get(name, name) for name in _DROPDOWN_LABELS}

# ===== LOAD =====
df = pd.read_csv(DATA_PATH)
# Values that cannot be parsed as numbers become NaN; those rows are dropped.
df[VALUE_COL] = pd.to_numeric(df[VALUE_COL], errors="coerce")
df = df.dropna(subset=[VALUE_COL]).copy()
# A missing late flag column is created as all zeros; unparseable flags
# also become 0.
if LATE_COL not in df.columns:
    df[LATE_COL] = 0
df[LATE_COL] = pd.to_numeric(df[LATE_COL], errors="coerce").fillna(0).astype(int)

# ---- EQUIPMENT preprocessing ----
# EQUIPMENT is split on commas; df_eq carries one row per (case, item) so
# every item can be treated as its own category level.
if "EQUIPMENT" in df.columns:
    eq = (
        df[["OPERATION_ID", "EQUIPMENT"]]
        .assign(EQUIPMENT=lambda d: d["EQUIPMENT"].astype(str).fillna("").str.split(","))
        .explode("EQUIPMENT")
    )
    eq["EQUIPMENT"] = eq["EQUIPMENT"].astype(str).str.strip()
    # Blank tokens and textual spellings of "missing" are not real items.
    eq.loc[eq["EQUIPMENT"].isin(["", "nan", "None", "NULL"]), "EQUIPMENT"] = np.nan
    eq = eq.dropna(subset=["EQUIPMENT"])
    # Left merge: cases without any usable item stay in df_eq with NaN.
    # The per-item value arrives as EQUIPMENT_x; the raw column keeps its name.
    df_eq = df.merge(eq, on="OPERATION_ID", how="left", suffixes=("", "_x"))
    df_eq["__EQUIPMENT_SPLIT__"] = df_eq["EQUIPMENT_x"]
else:
    df_eq = df.copy()
    df_eq["__EQUIPMENT_SPLIT__"] = np.nan

# ---- IMPLANT preprocessing ----
if "IMPLANT" in df.columns:
    def to_implant_bin(s):
        """Map a raw IMPLANT value to "No implant" or "Has implant".

        A value that parses as the number 0 is "No implant"; anything else,
        including text that is not a number, is "Has implant".
        """
        try:
            v = float(s)
        except (TypeError, ValueError):
            # Only a failed conversion means "not a number"; a bare except
            # would also swallow KeyboardInterrupt and unrelated bugs.
            return "Has implant"
        # NOTE(review): NaN parses as a float that is != 0, so missing values
        # land in "Has implant" — confirm this is intended.
        return "No implant" if v == 0 else "Has implant"
    df["__IMPLANT_BIN__"] = df["IMPLANT"].astype(str).map(to_implant_bin)
else:
    df["__IMPLANT_BIN__"] = np.nan

# ===== Helpers =====
def get_category_df(label: str):
    """Return ``(frame_copy, column_name)`` for a dropdown label.

    The exploded equipment frame is copied for the equipment column; the main
    frame is copied for every other column.
    """
    col = CATEGORY_MAP[label]
    if col == "__EQUIPMENT_SPLIT__":
        source = df_eq
    else:
        source = df
    return source.copy(), col

def clamp_value_range(frame, col, lo=Y_MIN, hi=Y_MAX):
    """Return the rows of ``frame`` whose ``col`` value lies in ``[lo, hi]``.

    Both bounds are inclusive. Rows outside the range are removed, not clipped.
    """
    in_range = frame[col].between(lo, hi)
    return frame[in_range]

def excluded_value_for_label(label: str) -> str:
    """Return the category value that is left out of the bars for ``label``.

    "ADMISSION_BED" and "BLOOD" have their own placeholder values; every
    other label uses "0".
    """
    special_placeholders = {
        "ADMISSION_BED": "Not Admitted",
        "BLOOD": "NIL",
    }
    return special_placeholders.get(label, "0")

def filter_top_levels_excluding_value(working, cat_col, label_name, top_k=TOP_K):
    """Split the rows into the plotted top-K subset and the full cleaned set.

    Rows with a missing category are dropped and the category is stringified
    and stripped. The first return value leaves out the label's excluded
    value and keeps only the ``top_k`` most frequent remaining levels. The
    second is the full cleaned set (excluded value included), used for the
    annotation statistics.
    """
    cleaned = working.dropna(subset=[cat_col]).copy()
    cleaned[cat_col] = cleaned[cat_col].astype(str).str.strip()

    placeholder = excluded_value_for_label(label_name)
    plotted = cleaned[cleaned[cat_col] != placeholder].copy()

    level_counts = plotted[cat_col].value_counts()  # sorted most frequent first
    if len(level_counts) > top_k:
        plotted = plotted[plotted[cat_col].isin(level_counts.index[:top_k])]
    return plotted, cleaned

def build_excluded_annotation(cat_df_incl_excl, cat_col, label_name):
    """Return a one-element list with the note about the excluded value.

    The note reports how many rows carry the label's excluded value, how many
    of those are late, and what share of all rows they make up. An empty list
    is returned when the frame is empty or the excluded value never occurs.
    """
    if cat_df_incl_excl.empty:
        return []

    excl_val = excluded_value_for_label(label_name)
    is_excluded = cat_df_incl_excl[cat_col].astype(str).str.strip() == excl_val
    n_excluded = int(is_excluded.sum())
    if n_excluded == 0:
        return []

    # From here on n_excluded >= 1, so both denominators are non-zero.
    n_late = int(cat_df_incl_excl.loc[is_excluded, LATE_COL].sum())
    n_rows = int(cat_df_incl_excl.shape[0])
    late_share = 100.0 * n_late / n_excluded
    overall_share = 100.0 * n_excluded / n_rows

    headline = f"{excl_val}: Total Cases: {n_excluded}   Late Cases: {n_late} ({late_share:.1f}%)"
    footnote = f"*{overall_share:.1f}% of total cases in {excl_val}"

    # Paper coordinates: anchored above the plot area on the left.
    note = {
        "text": headline + "<br>" + footnote,
        "xref": "paper",
        "yref": "paper",
        "x": 0,
        "y": 1.08,
        "xanchor": "left",
        "yanchor": "bottom",
        "align": "left",
        "showarrow": False,
        "font": {"size": 12},
        "bgcolor": "rgba(255,255,255,0.85)",
        "bordercolor": "rgba(0,0,0,0.25)",
        "borderwidth": 1,
        "borderpad": 4,
        "name": "excluded_note",
    }
    return [note]

# ===== Build stacked bars =====
# One shared figure holds the traces of every category; the three lists below
# get exactly one entry per CATEGORY_MAP label, in the same order, so they can
# be zipped with the labels when the dropdown is built.
fig_stacked = go.Figure()
buttons_s, visibility_blocks_s, annotations_by_label = [], [], []
trace_offset_s = 0  # number of traces added to fig_stacked so far
COLOR_MAP = {"On-time": "#9ecae1", "Late": "#08519c"}

for label in CATEGORY_MAP.keys():
    working, cat_col = get_category_df(label)
    if cat_col not in working.columns:
        # Column missing: the label owns no traces and has no annotation.
        visibility_blocks_s.append([False]*trace_offset_s)
        annotations_by_label.append([])
        continue

    # Drop rows whose delay lies outside [Y_MIN, Y_MAX] before any counting.
    working = clamp_value_range(working, VALUE_COL, Y_MIN, Y_MAX)

    # sub_nonexcl: rows that are plotted; sub_incl_excl: all cleaned rows,
    # including the excluded value, used for the annotation and grand total.
    sub_nonexcl, sub_incl_excl = filter_top_levels_excluding_value(working, cat_col, label, TOP_K)

    if sub_nonexcl.empty:
        # Nothing to plot, but the excluded value may still deserve a note.
        visibility_blocks_s.append([False]*trace_offset_s)
        annotations_by_label.append(build_excluded_annotation(sub_incl_excl, cat_col, label))
        continue

    # Per level: number of rows and number of late rows (sum of the late flag).
    summary = (
        sub_nonexcl.groupby(cat_col)
        .agg(total_cases=(VALUE_COL, "size"),
             late_cases=(LATE_COL, "sum"))
        .reset_index()
    )
    # On-time is whatever is not late.
    summary["ontime_cases"] = summary["total_cases"] - summary["late_cases"]
    summary["late_pct"]   = (summary["late_cases"]   / summary["total_cases"] * 100).round(1)
    summary["ontime_pct"] = (summary["ontime_cases"] / summary["total_cases"] * 100).round(1)
    # Share of each level among ALL cleaned rows of this category, i.e. the
    # denominator also counts the excluded value and levels outside the top K.
    grand_total = int(sub_incl_excl.shape[0])
    summary["overall_pct"] = (summary["total_cases"] / grand_total * 100).round(1)

    # Long format: one row per (level, Status) with a shared count/pct schema.
    long_df = pd.concat([
        summary[[cat_col, "ontime_cases", "ontime_pct", "total_cases", "overall_pct"]]
            .assign(Status="On-time")
            .rename(columns={"ontime_cases": "count", "ontime_pct": "pct"}),
        summary[[cat_col, "late_cases",   "late_pct",   "total_cases", "overall_pct"]]
            .assign(Status="Late")
            .rename(columns={"late_cases": "count", "late_pct": "pct"}),
    ], ignore_index=True)

    # Levels are shown largest first.
    order = summary.sort_values("total_cases", ascending=False)[cat_col].tolist()

    fig_bar = px.bar(
        long_df,
        x=cat_col, y="count",
        color="Status",
        category_orders={cat_col: order, "Status": ["On-time", "Late"]},
        color_discrete_map=COLOR_MAP,
        text=long_df["pct"].map(lambda v: f"{v:.1f}%"),
        hover_data=None,
        custom_data=["Status", "pct", "total_cases", "overall_pct"],
    )
    fig_bar.update_traces(textposition="inside")
    fig_bar.update_layout(barmode="stack")

    # customdata indices follow the custom_data list above:
    # 0=Status, 1=pct, 2=total_cases, 3=overall_pct. %{meta} is the label.
    hovertemplate = (
        "%{x} in %{meta}<br>"
        "Status: %{customdata[0]}<br>"
        "Count: %{y}<br>"
        "Share within %{x}: %{customdata[1]:.1f}% of total in %{x}<br>"
        "Total in %{x} as % of ALL in %{meta}: %{customdata[3]:.1f}%"
        "<extra></extra>"
    )

    # Move the traces into the shared figure; only the traces are carried
    # over, not fig_bar's layout.
    for tr in fig_bar.data:
        tr.hovertemplate = hovertemplate
        tr.meta = label
        fig_stacked.add_trace(tr)

    # This label's traces are the n_new traces starting at trace_offset_s.
    n_new = len(fig_bar.data)
    visibility = [False]*trace_offset_s + [True]*n_new
    visibility_blocks_s.append(visibility)
    trace_offset_s += n_new

    annotations_by_label.append(build_excluded_annotation(sub_incl_excl, cat_col, label))

# Pad visibilities
# Earlier labels were built before later traces existed; extend each list
# with False so every list covers all traces.
total_traces_s = trace_offset_s
visibility_blocks_s = [vb+[False]*(total_traces_s-len(vb)) for vb in visibility_blocks_s]

# Dropdown buttons
for label, vb, anns in zip(CATEGORY_MAP.keys(), visibility_blocks_s, annotations_by_label):
    buttons_s.append(dict(
        label=label,
        method="update",
        args=[
            {"visible": vb},
            {
                "title": f"Cases by {label} (Top {TOP_K}) — On-time vs Late",
                "yaxis": {"title": "Number of Cases"},
                "barmode": "stack",
                # Button args are passed to the renderer as-is, so the legend
                # title is given in its object form ({"text": ...}).
                "legend": {"title": {"text": "Status"}},
                # An empty list clears the previous category's note.
                "annotations": anns
            },
        ],
    ))

# Initial view = LOCATION
initial_label = "LOCATION"
labels_order = list(CATEGORY_MAP.keys())
initial_idx = labels_order.index(initial_label) if initial_label in labels_order else 0
initial_viz = visibility_blocks_s[initial_idx] if visibility_blocks_s else []
for i, tr in enumerate(fig_stacked.data):
    tr.visible = initial_viz[i] if i < len(initial_viz) else False

fig_stacked.update_layout(
    title=f"Cases by {initial_label} (Top {TOP_K}) — On-time vs Late",
    yaxis_title="Number of Cases",
    # Same legend title the buttons set, so the first view matches.
    legend_title_text="Status",
    updatemenus=[dict(
        type="dropdown", x=1, xanchor="right", y=1.08, yanchor="top",
        buttons=buttons_s, showactive=True
    )],
    margin=dict(l=60, r=30, t=70, b=80),
    barmode="stack",
    annotations=annotations_by_label[initial_idx] if annotations_by_label else []
)

fig_stacked.show()