In [1]:
#DATA PREPROCESSING

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import umap
import hdbscan
import gc
import torch
from transformers import LEDTokenizer, LEDModel
from bertopic import BERTopic

# Global config: absolute input locations on the author's machine
PATH_X1 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/data_Hsp12.feather"
PATH_X2 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvd_metric.csv"
PATH_X3 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvd_metric_raw.csv"
PATH_METRIC_DEF = "/Users/jingyi/Desktop/Trauma_LLM/metric_def.xlsx"

# Embedding / dimensionality-reduction settings
EMB_MODEL_NAME = "allenai/led-base-16384"
MAX_TOKENS = 16000
N_COMPONENTS_CLUST = 10   # dimension; the original note says it still needs to be tuned
N_NEIGHBORS = 15

# Device preference order: Apple MPS, then CUDA, then CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print("Device:", DEVICE)

Device: mps


In [2]:
# Load the processed (X2) and raw (X3) metric tables
X2 = pd.read_csv(PATH_X2)
X3 = pd.read_csv(PATH_X3)

# Column-name sets of each table
x2_cols, x3_cols = set(X2.columns), set(X3.columns)

# Names exclusive to one table, and names shared by both (shared list is kept but not printed)
only_in_X2 = sorted(x2_cols.difference(x3_cols))
only_in_X3 = sorted(x3_cols.difference(x2_cols))
common_cols = sorted(x2_cols.intersection(x3_cols))

print("Columns only in X2:")
print(only_in_X2 or "None")

print("\nColumns only in X3:")
print(only_in_X3 or "None")


Columns only in X2:
['QI']

Columns only in X3:
['mtrc102', 'mtrc106', 'mtrc14', 'mtrc166', 'mtrc168', 'mtrc233', 'mtrc242', 'mtrc242001', 'mtrc25', 'mtrc29', 'mtrc80', 'mtrc85', 'mtrc96', 'mtrcNA']


In [3]:
# Metric columns reported above as present only in the raw table (X3)
cols_only_in_x3 = [
    f"mtrc{suffix}"
    for suffix in (
        "102", "106", "14", "166", "168", "233", "242",
        "242001", "25", "29", "80", "85", "96", "NA",
    )
]

# errors="ignore" keeps this cell re-runnable after the columns are already gone
X3 = X3.drop(columns=cols_only_in_x3, errors="ignore")

# From here on X2 refers to the trimmed raw table (same object as X3, not a copy)
X2 = X3


In [4]:
X1 = pd.read_feather(PATH_X1)
print("Original X1 shape:", X1.shape)

# Columns removed from X1, assembled in groups (missing names are ignored by drop)
cols_to_drop = ['mpp_121', 'mpp_125', 'mpp_16', 'mpp_168', 'mpp_236', 'mpp_242', 'mpp_71']

# proc_01_icd ... proc_84_icd
cols_to_drop.extend(f"proc_{i:02d}_icd" for i in range(1, 85))

cols_to_drop.extend([
    'fac_key', 'disp_tx', 'scene_tx', 'leave_tx',
    'gcs40eye_s', 'gcs40ver_s', 'gcs40mot_s',
    'gcs40eye_r', 'gcs40ver_r', 'gcs40mot_r',
])

# Numbered families: <prefix>_01 ... <prefix>_<last>
for prefix, last in (("ais_sev", 27), ("icd9", 27), ("predot", 27), ("proc", 84), ("ais", 27)):
    cols_to_drop.extend(f"{prefix}_{i:02d}" for i in range(1, last + 1))

X1 = X1.drop(columns=cols_to_drop, errors="ignore")
print("X1 shape after dropping cols:", X1.shape)

# X2 was prepared in the previous cell; both tables need the join key
if not ("inc_key" in X1.columns and "inc_key" in X2.columns):
    raise ValueError("Both X1 and X2 must contain 'inc_key' column.")

# Reorder X2 rows to follow X1's inc_key order (keys absent from X2 become all-NaN rows)
X2 = X2.set_index("inc_key").reindex(X1["inc_key"]).reset_index()
print("Aligned X2 shape:", X2.shape)

# X4 = X1 columns followed by X2's non-key columns, joined side by side on the row index
X2_features_only = X2.drop(columns=["inc_key"])
X4 = pd.concat([X1, X2_features_only], axis=1)
print("X4 shape:", X4.shape)

Original X1 shape: (103362, 1571)
X1 shape after dropping cols: (103362, 1278)
Aligned X2 shape: (103362, 116)
X4 shape: (103362, 1393)


In [5]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize

# Precomputed embeddings and the inc_key table saved alongside each of them
X1_emb = np.load("/Users/jingyi/Desktop/Trauma_LLM/all_patient/originaldata_LED_emb_103362.npy")
X1_inc_key = pd.read_csv("/Users/jingyi/Desktop/Trauma_LLM/all_patient/originaldata_inc_key.csv")

X2_emb = np.load("/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvdmetric_LED_emb.npy")
X2_inc_key = pd.read_csv("/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvdmetric_inc_key.csv")

print("Original shapes:")
for emb, keys in ((X1_emb, X1_inc_key), (X2_emb, X2_inc_key)):
    print(emb.shape, keys.shape)

# --- NORMALIZATION STEP ---
# Row-wise L2 normalisation puts every vector on the unit sphere, so Euclidean
# distance between rows becomes a monotone function of their cosine distance.
print("\nNormalizing embeddings...")
X1_emb, X2_emb = (normalize(arr, norm='l2') for arr in (X1_emb, X2_emb))

print("Normalization complete.")

Original shapes:
(103362, 768) (103362, 1)
(103362, 768) (103362, 1)

Normalizing embeddings...
Normalization complete.


In [6]:
import numpy as np

# np.load on an .npz archive returns an NpzFile that holds the file open;
# the context manager closes it once both arrays have been read into memory.
with np.load("X1_umap_hdbscan_outputs.npz") as data:
    X1_umap_d = data["X1_umap_d"]
    X1_labels = data["X1_labels"]

print("Loaded X1_umap_d shape:", X1_umap_d.shape)
print("Loaded X1_labels shape:", X1_labels.shape)
# -1 is removed from the set before printing
print("Unique clusters (excluding -1):", set(X1_labels) - {-1})


Loaded X1_umap_d shape: (103362, 10)
Loaded X1_labels shape: (103362,)
Unique clusters (excluding -1): {np.int32(0), np.int32(1), np.int32(2), np.int32(3), np.int32(4), np.int32(5), np.int32(6), np.int32(7)}


In [7]:
  import numpy as np

  # np.load on an .npz archive returns an NpzFile that holds the file open;
  # the context manager closes it once both arrays have been read into memory.
  with np.load("X4_umap_hdbscan_outputs.npz") as data:
      X4_umap_d = data["X4_umap_d"]
      X4_labels = data["X4_labels"]

  print("Loaded X4_umap_d shape:", X4_umap_d.shape)
  print("Loaded X4_labels shape:", X4_labels.shape)
  # -1 is removed from the set before printing
  print("Unique clusters (excluding -1):", set(X4_labels) - {-1})

Loaded X4_umap_d shape: (103362, 10)
Loaded X4_labels shape: (103362,)
Unique clusters (excluding -1): {np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6)}


In [8]:
# Divide metric data into cat and num
import re
import pandas as pd
metric_def = pd.read_excel(PATH_METRIC_DEF)

# Normalise the three lookup columns to stripped strings; types are also lower-cased
for _name in ("Metric_ID", "Variable_type", "Description"):
    metric_def[_name] = metric_def[_name].astype(str).str.strip()
metric_def["Variable_type"] = metric_def["Variable_type"].str.lower()

# Metric id -> variable type / description
type_by_id = dict(zip(metric_def["Metric_ID"], metric_def["Variable_type"]))
desc_by_id = dict(zip(metric_def["Metric_ID"], metric_def["Description"]))

num_cols = []
cat_cols = []
rename_map = {}

metric_pattern = re.compile(r"mtrc(\d+)", flags=re.IGNORECASE)
categorical_types = {"binary", "count"}

# NOTE(review): re-running this cell on an already renamed X2 appends the
# description a second time, because renamed columns still match the pattern.
for col in X2.columns:
    found = metric_pattern.search(col)
    if found is None:
        # not a metric column (no "mtrc<digits>" in the name)
        continue

    metric_id = found.group(1)

    # Rename to "<column>:<description>" when a description is known
    desc = desc_by_id.get(metric_id, "")
    new_col = f"{col}:{desc}" if desc else col
    rename_map[col] = new_col

    # Columns whose type is neither numeric nor binary/count land in no list
    var_type = type_by_id.get(metric_id, "")
    if var_type == "numeric":
        num_cols.append(new_col)
    elif var_type in categorical_types:
        cat_cols.append(new_col)

X2 = X2.rename(columns=rename_map)

In [9]:
  # Source text of a helper script: it clears PYTHONHOME/PYTHONPATH, sets the
  # IAI_* environment variables, then imports `iai` from interpretableai.
  # The text is only stored here; nothing in it runs until the file is executed.
  cell = """import os
  from pathlib import Path

  # Clean env that can break PyCall
  os.environ.pop("PYTHONHOME", None)
  os.environ.pop("PYTHONPATH", None)

  # Tell interpretableai where Julia + sysimage are
  os.environ["IAI_JULIA"] = "/Users/jingyi/Library/Application Support/InterpretableAI/julia/1.12.2/julia-1.12.2/bin/julia"
  os.environ["IAI_SYSTEM_IMAGE"] = "/Users/jingyi/Library/Application Support/InterpretableAI/sysimage/v3.2.2/sys.dylib"

  # Important: let IAI start Julia, but disable compiled modules
  os.environ["IAI_DISABLE_COMPILED_MODULES"] = "1"

  # License
  os.environ["IAI_LICENSE_FILE"] = "/Users/jingyi/Desktop/Trauma_LLM/iai.lic"

  # Now import
  from interpretableai import iai
  print("IAI import OK")
  """
  # Write the script into the current working directory (overwrites any existing file)
  from pathlib import Path
  Path("iai_import_cell.py").write_text(cell)
  print("Saved: iai_import_cell.py  (run in notebook with: %run iai_import_cell.py)")



Saved: iai_import_cell.py  (run in notebook with: %run iai_import_cell.py)


In [10]:
# Execute the helper script written by the previous cell; later cells use the `iai` name it imports.
%run iai_import_cell.py



IAI import OK


In [11]:
  # Sanity check of the renamed metric columns:
  # render the first 10 rows of X2
  display(X2.head(10))

Unnamed: 0,inc_key,"mtrc3:Time to first medical contact, min","mtrc4:Prehospital time, min","mtrc46:ICU length of stay, day","mtrc47:Length of stay, day","mtrc17:Time to cranial CT for patients with GCS < 14, min","mtrc28:Time to first emergent surgery, min","mtrc31:Time to surgery for patients in shock, min","mtrc172:Time to tracheostomy in SCI patients, min","mtrc255:Time to tracheostomy, min",...,mtrc71:MRI - 2 hours,mtrc250:EVD placement,mtrc16:Antibiotics for open fractures,mtrc16001:Antibiotics for open fractures within 24 hours,mtrc23:Activation of massive transfusion protovocl,"mtrc67:Convential radiology - in 15 min, level I/II; in 30 min, level III/IV","mtrc68:CT - in 15 min, level I/II; in 30 min, level III/IV",mtrc177:Percentage of severe TBI with other injury,mtrc187:Transfer rate of children with severe TBI,mtrc65:Orthopedic non-emergent availability
0,1,,,2.0,4.0,,4.0,,,,...,,,,,,0.0,0.0,0,,0
1,2,7.0,,,,,44.0,,,,...,,0.0,,,,0.0,0.0,1,,0
2,111,11.0,,,,,107.0,,,,...,,,,,,,,0,,0
3,112,11.0,,,,,81.0,,,,...,,,,,,,,0,,0
4,113,11.0,,3.0,4.0,,26.0,,,,...,,,,,,,,0,,0
5,114,,,,,,,,,,...,,,,,,,,0,,0
6,115,,,2.0,2.0,,188.0,,,,...,,0.0,,,,,,1,,0
7,116,13.0,,,,,133.0,,,,...,,,,,,,,0,,0
8,117,7.0,,1.0,5.0,,28.0,28.0,,,...,,0.0,,,,,,1,,0
9,118,,,2.0,2.0,,74.0,74.0,,,...,,,,,,,,0,,0


In [20]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


# =========================
# 0) Inputs + missing tokens
# =========================
# Lower-cased, stripped cell values that are treated as "missing"
MISSING_TOKENS = (
    {"", " ", "  ", "\t", "\n", "\r"}                   # empty / whitespace
    | {"na", "n/a", "nan", "null", "none", "nil"}       # null-like words
    | {".", "..", "..."}                                # dot placeholders
    | {"<unk>", "unk", "unknown", "missing"}            # explicit unknown markers
)

# Aliases only (no copies): df_X* refer to the same objects as X1 / X2 / X4
df_X1, df_X2, df_X4 = X1, X2, X4

In [21]:
# =========================
# 1) Preprocess X2 (outside function)
# =========================
# Feature-only working copy of X2 (the join key is not a feature)
X = df_X2.drop(columns=["inc_key"], errors="ignore").copy()

# Keep only the typed columns that actually exist in X
cat_cols = [c for c in cat_cols if c in X.columns]
num_cols = [c for c in num_cols if c in X.columns]

_cat_set = set(cat_cols)

for col in cat_cols + num_cols:
    s = X[col]
    # A value is missing when it is NaN or its stripped, lower-cased text is a known token
    s_str = s.astype(str).str.strip().str.lower()
    miss = s.isna() | s_str.isin(MISSING_TOKENS)

    if col in _cat_set:
        # Cast to object BEFORE inserting the "MISSING" label: writing a string
        # into a numeric column through .loc is a dtype-incompatible assignment
        # (deprecated silent upcast in pandas 2.1, an error in later versions).
        X[col] = s.astype(object).where(~miss, "MISSING").astype(str)
    else:
        X[col] = pd.to_numeric(s, errors="coerce")
        # NOTE(review): missing numeric values are imputed with 0, which makes
        # them indistinguishable from a genuine 0 — confirm this is intended.
        X.loc[miss | X[col].isna(), col] = 0

# fix feature types: a column listed as both kinds is treated as categorical
overlap = set(cat_cols) & set(num_cols)
if overlap:
    print("Removing from num_cols (categorical):", overlap)
    num_cols = [c for c in num_cols if c not in overlap]

for c in cat_cols:
    X[c] = X[c].astype("category")

# Numeric columns were already converted inside the loop above, so no second
# pd.to_numeric pass is needed here.

In [23]:
# =========================
# 2) Align by inc_key (outside)
# =========================
# Key + cluster label tables for the X1 and X4 clusterings
df1 = df_X1[["inc_key"]].assign(x1_cluster=np.asarray(X1_labels).astype(int))
df4 = df_X4[["inc_key"]].assign(x4_cluster=np.asarray(X4_labels).astype(int))

# The merges below require inc_key to be unique in every table
for frame_name, frame in (("df_X1", df1), ("df_X4", df4), ("df_X2", df_X2)):
    if frame["inc_key"].duplicated().any():
        raise ValueError(f"{frame_name} has duplicated inc_key. Deduplicate first.")

# Re-attach the key to the preprocessed features (index-aligned with df_X2)
X_with_key = df_X2[["inc_key"]].join(X)

# Inner joins: keep only keys present in all three tables
aligned = (
    df1
    .merge(X_with_key, on="inc_key", how="inner")
    .merge(df4, on="inc_key", how="inner")
)

In [26]:
# =========================
# 3) IAI per X1 cluster → predict X4
# =========================
def run_iai_per_x1_cluster(
    aligned_df: pd.DataFrame,
    cat_cols: list[str],
    num_cols: list[str],
    output_dir: str = "./iai_trees",
    include_noise: bool = False,
    min_class_count: int = 2,
    test_size: float = 0.2,
    random_state: int = 42,
    max_depth: int = 5,
):
    """Fit one IAI optimal classification tree per X1 cluster to predict the X4 cluster.

    Parameters
    ----------
    aligned_df : DataFrame with ``inc_key``, ``x1_cluster``, ``x4_cluster`` and feature columns.
    cat_cols, num_cols : names of the feature columns to use; other columns are ignored.
    output_dir : directory (created if needed) receiving one JSON file per fitted tree.
    include_noise : when False, rows labelled -1 in either clustering are dropped.
    min_class_count : X4 classes with fewer rows than this inside a cluster are removed.
    test_size, random_state : forwarded to the stratified ``train_test_split``.
    max_depth : depth passed to the grid search and embedded in the output filename.

    Returns
    -------
    DataFrame with one row per X1 cluster: ``x1_cluster``, ``n_samples_used``,
    ``num_x4_classes`` and a ``status`` string (accuracy and path, or a SKIP reason).

    Relies on ``iai`` and ``train_test_split`` being available in the notebook namespace.
    """
    os.makedirs(output_dir, exist_ok=True)

    df = aligned_df.copy()
    if not include_noise:
        df = df[(df["x1_cluster"] != -1) & (df["x4_cluster"] != -1)].copy()

    # Features = requested columns that exist in df, in df's column order
    reserved = {"inc_key", "x1_cluster", "x4_cluster"}
    requested = set(cat_cols) | set(num_cols)
    used_cols = [c for c in df.columns if c not in reserved and c in requested]

    clusters = sorted(df["x1_cluster"].unique().tolist())
    results = []

    for c in clusters:
        sub = df[df["x1_cluster"] == c].copy()
        y = sub["x4_cluster"].astype(int)

        # A single target class leaves nothing to learn
        if y.nunique() < 2:
            results.append((c, len(sub), y.nunique(), "SKIP (no split)"))
            continue

        # Remove rare target classes, then re-check that a split is still possible
        vc = y.value_counts()
        keep = vc[vc >= min_class_count].index
        sub = sub[y.isin(keep)].copy()
        y = sub["x4_cluster"].astype(int)

        if y.nunique() < 2:
            results.append((c, len(sub), y.nunique(), f"SKIP (after filtering <{min_class_count})"))
            continue

        X_sub = sub[used_cols]

        X_train, X_test, y_train, y_test = train_test_split(
            X_sub, y, test_size=test_size, random_state=random_state, stratify=y
        )
        # Use the caller's max_depth (previously hard-coded to 5, so the depth
        # in the saved filename could disagree with the fitted tree).
        # The learner seed stays fixed at 1; random_state only drives the split.
        grid = iai.GridSearch(
            iai.OptimalTreeClassifier(random_seed=1),
            max_depth=max_depth,
        )
        grid.fit(X_train, y_train)
        learner = grid.get_learner()

        test_acc = learner.score(X_test, y_test)

        model_path = os.path.join(output_dir, f"iai_x1cluster_{c}_depth{max_depth}.json")
        learner.write_json(model_path)

        results.append((c, len(sub), y.nunique(), f"OK → acc={test_acc:.4f} | {model_path}"))

    return pd.DataFrame(results, columns=["x1_cluster", "n_samples_used", "num_x4_classes", "status"])

In [27]:
# =========================
# 4) Run
# =========================
# Settings for this run: depth-5 trees, noise rows excluded, classes need >= 2 rows
iai_run_options = {
    "output_dir": "./iai_trees_depth5",
    "include_noise": False,
    "min_class_count": 2,
    "max_depth": 5,
}

summary = run_iai_per_x1_cluster(
    aligned_df=aligned,
    cat_cols=cat_cols,
    num_cols=num_cols,
    **iai_run_options,
)

print(summary)

[33m[1m│ [22m[39m- mtrc81:Craniofacial expertise
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mWe recommend extreme caution when using categoric features with many levels inside Optimal Trees, for more information and advice on how to handle such features, please refer to this link:
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mhttps://docs.interpretable.ai/dev/OptimalTrees/tips/#Categorical-Variables-with-Many-Levels
[33m[1m│ [22m[39m
[33m[1m│ [22m[39m- mtrc81:Craniofacial expertise
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mWe recommend extreme caution when using categoric features with many levels inside Optimal Trees, for more information and advice on how to handle such features, please refer to this link:
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mhttps://docs.interpretable.ai/dev/OptimalTrees/tips/#Categorical-Variables-with-Many-Levels
[33m[1m│ [22m[39m
[33m[1m│ [22m[39m- mtrc81:Craniofacial expertise
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mWe recommend extreme

   x1_cluster  n_samples_used  num_x4_classes  \
0           0           21683               4   
1           1            4505               2   
2           2            2450               1   
3           3             809               2   
4           4            8297               3   
5           5           27681               4   
6           6            4432               4   
7           7           33502               5   

                                              status  
0  OK → acc=0.6313 | ./iai_trees_depth5/iai_x1clu...  
1  OK → acc=0.6837 | ./iai_trees_depth5/iai_x1clu...  
2                                    SKIP (no split)  
3  OK → acc=0.9877 | ./iai_trees_depth5/iai_x1clu...  
4  OK → acc=0.8928 | ./iai_trees_depth5/iai_x1clu...  
5  OK → acc=0.7226 | ./iai_trees_depth5/iai_x1clu...  
6  OK → acc=0.7971 | ./iai_trees_depth5/iai_x1clu...  
7  OK → acc=0.7297 | ./iai_trees_depth5/iai_x1clu...  


In [39]:
import json
from interpretableai.iaibase import read_json
import os
import glob

model_dir = "./iai_trees_depth5"

# sorted(): glob yields matches in filesystem-dependent order, so sort for a
# stable, readable log.
for json_path in sorted(glob.glob(os.path.join(model_dir, "iai_x1cluster_*_depth5.json"))):
    learner = read_json(json_path)
    # Swap only the extension; str.replace would rewrite every ".json" in the path
    html_path = os.path.splitext(json_path)[0] + ".html"
    learner.write_html(html_path)
    print("Wrote:", html_path)

Wrote: ./iai_trees_depth5/iai_x1cluster_4_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_1_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_7_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_0_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_5_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_6_depth5.html
Wrote: ./iai_trees_depth5/iai_x1cluster_3_depth5.html


In [40]:
import os
import re
import json
import glob
from typing import Any, Dict, List, Tuple, Optional


# --------- helpers to read JSON ----------
def load_json(path: str) -> Dict[str, Any]:
    """Read the text file at ``path`` and return its parsed JSON content."""
    with open(path, "r") as handle:
        text = handle.read()
    return json.loads(text)


def first_key(d: Dict[str, Any], keys: List[str]) -> Optional[Any]:
    """Return ``d[k]`` for the first ``k`` in ``keys`` that is present in ``d``, else ``None``."""
    return next((d[k] for k in keys if k in d), None)


# --------- tree structure handling ----------
def get_tree_root(data: Dict[str, Any]):
    """Locate the root node inside a loaded tree JSON object.

    Steps: descend into the first of ``tree`` / ``model`` / ``learner`` / ``root``
    that is present; if the result is a dict with a ``nodes`` list, index the
    nodes by their ``id`` (falling back to list position) and return the one
    named by ``root`` / ``root_id`` (default 0), or that value itself when it is
    already a dict. Otherwise the (possibly unwrapped) object is returned as is.
    """
    wrapper = next((k for k in ("tree", "model", "learner", "root") if k in data), None)
    if wrapper is not None:
        data = data[wrapper]

    # Anything without a flat "nodes" list is treated as already hierarchical
    if not (isinstance(data, dict) and "nodes" in data):
        return data

    by_id = {n.get("id", pos): n for pos, n in enumerate(data["nodes"])}
    root_ref = data["root"] if "root" in data else data.get("root_id", 0)
    return root_ref if isinstance(root_ref, dict) else by_id[root_ref]


def get_children(node: Dict[str, Any]) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Return ``(left, right)`` using the first key pair that is fully present in ``node``.

    Pairs tried in order: left/right, left_child/right_child, leftChild/rightChild, l/r.
    ``(None, None)`` is returned when no pair has both keys.
    """
    key_pairs = (
        ("left", "right"),
        ("left_child", "right_child"),
        ("leftChild", "rightChild"),
        ("l", "r"),
    )
    found = next(((lk, rk) for lk, rk in key_pairs if lk in node and rk in node), None)
    if found is None:
        return None, None
    return node[found[0]], node[found[1]]


def is_leaf(node: Dict[str, Any]) -> bool:
    """Return True for ``None``, for a node carrying a prediction key, or for a node without children."""
    if node is None:
        return True
    if any(k in node for k in ("prediction", "predicted_class", "class")):
        return True
    return all(child is None for child in get_children(node))


def get_prediction(node: Dict[str, Any]):
    """Return the value of the first prediction-like key found in ``node``, else ``None``."""
    candidates = ("prediction", "predicted_class", "class", "value")
    return next((node[k] for k in candidates if k in node), None)


def get_prob(node: Dict[str, Any]):
    """Return the value of the first probability-like key found in ``node``, else ``None``."""
    candidates = ("prob", "probability", "p", "class_probability", "class_probabilities")
    return next((node[k] for k in candidates if k in node), None)


def get_split(node: Dict[str, Any]):
    """Return ``(feature, threshold, operator, categories, missing_left)`` for ``node``.

    Each element is looked up with ``first_key`` over a list of alternative key
    names and is ``None`` when none of the names is present.
    """
    alternatives = (
        ["feature", "split_feature", "feature_name", "var"],
        ["threshold", "split_value", "value"],
        ["operator", "op", "comparison"],
        ["categories", "cat_values", "values"],
        ["missing_to_left", "default_left", "missing_left"],
    )
    return tuple(first_key(node, names) for names in alternatives)


def build_conditions(node: Dict[str, Any]) -> Tuple[str, str]:
    """Return human-readable ``(left, right)`` branch conditions for a split node.

    Without a feature both strings are ``"UNKNOWN_SPLIT"``. A categorical split
    reads ``feature in [...]`` / ``feature not in [...]``; otherwise the split is
    rendered as a comparison (operator defaults to ``<=``, threshold to ``?``).
    ``" or missing"`` is appended to the side that receives missing values when
    that flag is exactly ``True`` or ``False``.
    """
    feature, threshold, operator, categories, missing_left = get_split(node)

    if feature is None:
        return "UNKNOWN_SPLIT", "UNKNOWN_SPLIT"

    if categories is None:
        # Numeric split
        op = "<=" if operator is None else operator
        thr = "?" if threshold is None else threshold
        left_cond = f"{feature} {op} {thr}"
        right_cond = f"{feature} not({op} {thr})"
    else:
        # Categorical split
        cats = categories if isinstance(categories, list) else [categories]
        left_cond = f"{feature} in {cats}"
        right_cond = f"{feature} not in {cats}"

    if missing_left is True:
        left_cond += " or missing"
    elif missing_left is False:
        right_cond += " or missing"

    return left_cond, right_cond


def walk_paths(node: Dict[str, Any], conditions: List[str]):
    """Collect ``(conditions, prediction, probability)`` for every leaf below ``node``.

    ``conditions`` is the list of branch conditions accumulated on the way down;
    the left subtree is visited before the right one. ``None`` yields no paths.
    """
    if node is None:
        return []

    if is_leaf(node):
        return [(conditions, get_prediction(node), get_prob(node))]

    left, right = get_children(node)
    left_cond, right_cond = build_conditions(node)

    collected = []
    for child, cond in ((left, left_cond), (right, right_cond)):
        if child is not None:
            collected.extend(walk_paths(child, conditions + [cond]))
    return collected


def format_path(cluster_id: str, conds: List[str], pred, prob) -> str:
    """Render one root-to-leaf path as a sentence.

    ``conds`` are joined with " and " ("(all)" when empty). ``prob`` may be a
    class->probability dict (the largest value is shown), a single number, or
    anything else (no probability is shown). An empty dict is treated as
    "no probability" instead of raising from ``max()``.
    """
    cond_text = " and ".join(conds) if conds else "(all)"
    if isinstance(prob, dict) and prob:
        # class-probability dict: report the most likely class's probability
        max_class = max(prob, key=prob.get)
        max_p = prob[max_class]
        prob_text = f" (p={max_p:.1%})"
    elif isinstance(prob, (float, int)):
        prob_text = f" (p={prob:.1%})"
    else:
        prob_text = ""
    return f"X1 cluster {cluster_id} patient -> when {cond_text}, then predict {pred}{prob_text}"


# --------- main: load all jsons and print paths ----------
model_dir = "./iai_trees_depth5"

# sorted(): glob yields matches in filesystem-dependent order; sorting makes the
# report order stable between runs.
for json_path in sorted(glob.glob(os.path.join(model_dir, "iai_x1cluster_*_depth5.json"))):
    data = load_json(json_path)
    root = get_tree_root(data)

    # The X1 cluster id is recovered from the file name
    cluster_match = re.search(r"x1cluster_(\d+)_depth", os.path.basename(json_path))
    cluster_id = cluster_match.group(1) if cluster_match else "?"

    paths = walk_paths(root, [])
    print(f"\n=== {json_path} ===")
    if not paths:
        print("No paths found. Top-level keys:", list(data.keys()))
        continue
    # NOTE(review): the recorded output shows a single "(all) -> None" path for
    # every file, i.e. the key names probed by the helpers did not match these files.
    for conds, pred, prob in paths:
        print(format_path(cluster_id, conds, pred, prob))


=== ./iai_trees_depth5/iai_x1cluster_4_depth5.json ===
X1 cluster 4 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_1_depth5.json ===
X1 cluster 1 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_7_depth5.json ===
X1 cluster 7 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_0_depth5.json ===
X1 cluster 0 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_5_depth5.json ===
X1 cluster 5 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_6_depth5.json ===
X1 cluster 6 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_3_depth5.json ===
X1 cluster 3 patient -> when (all), then predict None


In [41]:
import json
import pprint

# Inspect the structure of one saved tree file
path = "./iai_trees_depth5/iai_x1cluster_4_depth5.json"
with open(path, "r") as f:
    data = json.load(f)

print("Top-level keys:", list(data))

# One level down: key names for dict values, lengths for list values
for k, v in data.items():
    if isinstance(v, dict):
        print(f"{k} keys:", list(v))
        continue
    if isinstance(v, list):
        print(f"{k} list length:", len(v))

split_features keys: ['All']
hyperplane_config list length: 0
tree_ keys: ['node_count', 'nodes', 'capacity']
all_trees_ list length: 100
regression_features list length: 0
prb_ keys: ['data', 'baseline']


In [43]:
import os
import re
import json
import glob
from typing import Any, Dict, List, Tuple, Optional


def node_get(node: Dict[str, Any], keys: List[str], default=None):
    """Return ``node[k]`` for the first ``k`` in ``keys`` present in ``node``, else ``default``."""
    return next((node[k] for k in keys if k in node), default)


def build_feature_names(data: Dict[str, Any]) -> List[str]:
    """Return ``data["split_features"]["All"]`` when it is a list, otherwise an empty list."""
    split_features = data.get("split_features", {})
    names = split_features.get("All", [])
    if isinstance(names, list):
        return names
    return []


def get_nodes(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return ``data["tree_"]["nodes"]``, or an empty list when either level is absent."""
    return data.get("tree_", {}).get("nodes", [])


def get_root_idx(nodes: List[Dict[str, Any]]) -> int:
    """Return the position of the first node whose parent is -1 or absent, defaulting to 0."""
    return next(
        (
            pos
            for pos, n in enumerate(nodes)
            if node_get(n, ["parent", "parent_id"], None) in (-1, None)
        ),
        0,
    )


def get_children(node: Dict[str, Any]) -> Tuple[Optional[int], Optional[int]]:
    """Return the ``(left, right)`` child references stored on ``node``.

    Each side is read independently from its list of alternative key names.
    When neither side is found, a ``children`` list with at least two items
    supplies them; otherwise ``(None, None)`` is returned.
    """
    left = node_get(node, ["left", "left_child", "left_child_id", "l"], None)
    right = node_get(node, ["right", "right_child", "right_child_id", "r"], None)
    if left is not None or right is not None:
        return left, right

    # Fallback schema: "children": [l, r]
    if "children" not in node:
        return None, None
    kids = node["children"]
    if isinstance(kids, list) and len(kids) >= 2:
        return kids[0], kids[1]
    return None, None


def is_leaf(node: Dict[str, Any]) -> bool:
    """Return True when ``node`` has a truthy ``is_leaf``/``leaf`` flag or no child references."""
    flagged = node_get(node, ["is_leaf", "leaf"], False)
    if flagged:
        return True
    return all(child is None for child in get_children(node))


def get_prediction(node: Dict[str, Any]):
    """Return the first prediction-like value stored on ``node``, or ``None``."""
    prediction_keys = ["prediction", "predicted_class", "class", "value"]
    return node_get(node, prediction_keys, None)


def get_prob(node: Dict[str, Any]):
    """Return the first probability-like value stored on ``node`` (any type), or ``None``."""
    probability_keys = ["prob", "probability", "class_probabilities", "probs"]
    return node_get(node, probability_keys, None)


def split_condition(node: Dict[str, Any], feature_names: List[str]) -> Tuple[str, str]:
    """Return human-readable ``(left, right)`` branch conditions for a split node.

    The feature reference is resolved through ``feature_names`` when it is a
    valid integer position, used directly when it is a string, and rendered as
    ``"UNKNOWN_FEATURE"`` otherwise. A categorical split reads ``in`` / ``not in``;
    any other split is rendered as a comparison (operator defaults to ``<=``).
    ``" or missing"`` is appended to the side receiving missing values when that
    flag is exactly ``True`` or ``False``.
    """
    ref = node_get(node, ["feature", "split_feature", "feature_index"], None)
    if isinstance(ref, int) and 0 <= ref < len(feature_names):
        feat_name = feature_names[ref]
    elif isinstance(ref, str):
        feat_name = ref
    else:
        feat_name = "UNKNOWN_FEATURE"

    threshold = node_get(node, ["threshold", "split_value", "value"], None)
    operator = node_get(node, ["operator", "op", "comparison"], "<=")
    categories = node_get(node, ["categories", "cat_values", "values"], None)
    missing_left = node_get(node, ["missing_to_left", "default_left", "missing_left"], None)

    if categories is None:
        left_cond = f"{feat_name} {operator} {threshold}"
        right_cond = f"{feat_name} not({operator} {threshold})"
    else:
        cats = categories if isinstance(categories, list) else [categories]
        left_cond = f"{feat_name} in {cats}"
        right_cond = f"{feat_name} not in {cats}"

    suffix = " or missing"
    if missing_left is True:
        left_cond = left_cond + suffix
    elif missing_left is False:
        right_cond = right_cond + suffix

    return left_cond, right_cond


def walk_paths(nodes: List[Dict[str, Any]], idx: int, feature_names: List[str], conds: List[str]):
    """Collect ``(conditions, prediction, probability)`` for every leaf below ``nodes[idx]``.

    Child references are used as positions into ``nodes``; the left subtree is
    visited before the right one, and ``conds`` carries the conditions
    accumulated on the way down.
    """
    current = nodes[idx]
    if is_leaf(current):
        return [(conds, get_prediction(current), get_prob(current))]

    left, right = get_children(current)
    left_cond, right_cond = split_condition(current, feature_names)

    collected = []
    for child_idx, cond in ((left, left_cond), (right, right_cond)):
        if child_idx is not None:
            collected.extend(walk_paths(nodes, child_idx, feature_names, conds + [cond]))
    return collected


def format_path(cluster_id: str, conds: List[str], pred, prob) -> str:
    """Render one root-to-leaf path as a sentence.

    ``conds`` are joined with " and " ("(all)" when empty). ``prob`` may be a
    class->probability dict or a non-empty list/tuple (the largest value is
    shown), a single number, or anything else (no probability is shown).
    An empty dict is treated like an empty list — no probability — instead of
    raising from ``max()``.
    """
    cond_text = " and ".join(conds) if conds else "(all)"
    prob_text = ""
    if isinstance(prob, dict):
        if prob:  # guard: max() on an empty dict raises ValueError
            max_class = max(prob, key=prob.get)
            prob_text = f" (p={prob[max_class]:.1%})"
    elif isinstance(prob, (list, tuple)) and len(prob) > 0:
        max_p = max(prob)
        prob_text = f" (p={max_p:.1%})"
    elif isinstance(prob, (float, int)):
        prob_text = f" (p={prob:.1%})"
    return f"X1 cluster {cluster_id} patient -> when {cond_text}, then predict {pred}{prob_text}"


model_dir = "./iai_trees_depth5"

# sorted(): glob yields matches in filesystem-dependent order; sorting makes the
# report order stable between runs.
for json_path in sorted(glob.glob(os.path.join(model_dir, "iai_x1cluster_*_depth5.json"))):
    with open(json_path, "r") as f:
        data = json.load(f)

    feature_names = build_feature_names(data)
    nodes = get_nodes(data)
    if not nodes:
        print(f"\n=== {json_path} ===")
        print("No nodes found.")
        continue

    root_idx = get_root_idx(nodes)
    # The X1 cluster id is recovered from the file name
    cluster_match = re.search(r"x1cluster_(\d+)_depth", os.path.basename(json_path))
    cluster_id = cluster_match.group(1) if cluster_match else "?"

    # NOTE(review): the recorded output shows a single "(all) -> None" path for
    # every file, i.e. the root was classified as a leaf — the child/prediction
    # key names probed by the helpers apparently do not match these files.
    paths = walk_paths(nodes, root_idx, feature_names, [])
    print(f"\n=== {json_path} ===")
    for conds, pred, prob in paths:
        print(format_path(cluster_id, conds, pred, prob))


=== ./iai_trees_depth5/iai_x1cluster_4_depth5.json ===
X1 cluster 4 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_1_depth5.json ===
X1 cluster 1 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_7_depth5.json ===
X1 cluster 7 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_0_depth5.json ===
X1 cluster 0 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_5_depth5.json ===
X1 cluster 5 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_6_depth5.json ===
X1 cluster 6 patient -> when (all), then predict None

=== ./iai_trees_depth5/iai_x1cluster_3_depth5.json ===
X1 cluster 3 patient -> when (all), then predict None
