In [1]:
#DATA PREPROCESSING

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import umap
import hdbscan
import gc
import torch
from transformers import LEDTokenizer, LEDModel
from bertopic import BERTopic

# Global config
PATH_X1 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/data_Hsp12.feather"
PATH_X2 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvd_metric.csv"
PATH_X3 = "/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvd_metric_raw.csv"
PATH_METRIC_DEF = "/Users/jingyi/Desktop/Trauma_LLM/metric_def.xlsx"

EMB_MODEL_NAME = "allenai/led-base-16384"
MAX_TOKENS = 16000
N_COMPONENTS_CLUST = 10   # dimension need to be changed
N_NEIGHBORS = 15

if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

Device: mps


In [2]:
# Read data
X2 = pd.read_csv(PATH_X2)
X3 = pd.read_csv(PATH_X3)

# Compare column names
x2_cols = set(X2.columns)
x3_cols = set(X3.columns)

only_in_X2 = sorted(x2_cols - x3_cols)
only_in_X3 = sorted(x3_cols - x2_cols)
common_cols = sorted(x2_cols & x3_cols)

print("Columns only in X2:")
print(only_in_X2 if only_in_X2 else "None")

print("\nColumns only in X3:")
print(only_in_X3 if only_in_X3 else "None")

#print("\nCommon columns:")
#print(common_cols if common_cols else "None")


Columns only in X2:
['QI']

Columns only in X3:
['mtrc102', 'mtrc106', 'mtrc14', 'mtrc166', 'mtrc168', 'mtrc233', 'mtrc242', 'mtrc242001', 'mtrc25', 'mtrc29', 'mtrc80', 'mtrc85', 'mtrc96', 'mtrcNA']


In [3]:
cols_only_in_x3 = [
    "mtrc102", "mtrc106", "mtrc14", "mtrc166", "mtrc168", "mtrc233",
    "mtrc242", "mtrc242001", "mtrc25", "mtrc29", "mtrc80", "mtrc85",
    "mtrc96", "mtrcNA"
]

X3 = X3.drop(columns=cols_only_in_x3, errors="ignore")
X2 = X3


In [4]:
X1 = pd.read_feather(PATH_X1)
print("Original X1 shape:", X1.shape)

cols_to_drop = [
    'mpp_121', 'mpp_125', 'mpp_16', 'mpp_168', 'mpp_236', 'mpp_242', 'mpp_71',
    'proc_01_icd', 'proc_02_icd', 'proc_03_icd', 'proc_04_icd', 'proc_05_icd',
    'proc_06_icd', 'proc_07_icd', 'proc_08_icd', 'proc_09_icd', 'proc_10_icd',
    'proc_11_icd', 'proc_12_icd', 'proc_13_icd', 'proc_14_icd', 'proc_15_icd',
    'proc_16_icd', 'proc_17_icd', 'proc_18_icd', 'proc_19_icd', 'proc_20_icd',
    'proc_21_icd', 'proc_22_icd', 'proc_23_icd', 'proc_24_icd', 'proc_25_icd',
    'proc_26_icd', 'proc_27_icd', 'proc_28_icd', 'proc_29_icd', 'proc_30_icd',
    'proc_31_icd', 'proc_32_icd', 'proc_33_icd', 'proc_34_icd', 'proc_35_icd',
    'proc_36_icd', 'proc_37_icd', 'proc_38_icd', 'proc_39_icd', 'proc_40_icd',
    'proc_41_icd', 'proc_42_icd', 'proc_43_icd', 'proc_44_icd', 'proc_45_icd',
    'proc_46_icd', 'proc_47_icd', 'proc_48_icd', 'proc_49_icd', 'proc_50_icd',
    'proc_51_icd', 'proc_52_icd', 'proc_53_icd', 'proc_54_icd', 'proc_55_icd',
    'proc_56_icd', 'proc_57_icd', 'proc_58_icd', 'proc_59_icd', 'proc_60_icd',
    'proc_61_icd', 'proc_62_icd', 'proc_63_icd', 'proc_64_icd', 'proc_65_icd',
    'proc_66_icd', 'proc_67_icd', 'proc_68_icd', 'proc_69_icd', 'proc_70_icd',
    'proc_71_icd', 'proc_72_icd', 'proc_73_icd', 'proc_74_icd', 'proc_75_icd',
    'proc_76_icd', 'proc_77_icd', 'proc_78_icd', 'proc_79_icd', 'proc_80_icd',
    'proc_81_icd', 'proc_82_icd', 'proc_83_icd', 'proc_84_icd', 'fac_key', 'disp_tx', 'scene_tx', 'leave_tx',
    'gcs40eye_s',  'gcs40ver_s', 'gcs40mot_s',  'gcs40eye_r', 'gcs40ver_r','gcs40mot_r'
]
cols_to_drop += [f"ais_sev_{i:02d}" for i in range(1, 28)]
cols_to_drop += [f"icd9_{i:02d}" for i in range(1, 28)]

# predot_01 ... predot_27
cols_to_drop += [f"predot_{i:02d}" for i in range(1, 28)]

# proc_01 ... proc_84
cols_to_drop += [f"proc_{i:02d}" for i in range(1, 85)]

# ais_01 ... ais_27
cols_to_drop += [f"ais_{i:02d}" for i in range(1, 28)]

X1 = X1.drop(columns=cols_to_drop, errors="ignore")
print("X1 shape after dropping cols:", X1.shape)

# 1b. Read X2 csv
#X2 = pd.read_csv(PATH_X2)
#X2 = X2.drop(columns="QI", errors="ignore")
# Align X2 rows so that inc_key order matches X1
if "inc_key" not in X1.columns or "inc_key" not in X2.columns:
    raise ValueError("Both X1 and X2 must contain 'inc_key' column.")

X2 = X2.set_index("inc_key").reindex(X1["inc_key"]).reset_index()
print("Aligned X2 shape:", X2.shape)

# Assumes X1 and X2 are already loaded and aligned as in your existing code.
X2_features_only = X2.drop(columns=["inc_key"])
X4 = pd.concat([X1, X2_features_only], axis=1)
print("X4 shape:", X4.shape)

Original X1 shape: (103362, 1571)
X1 shape after dropping cols: (103362, 1278)
Aligned X2 shape: (103362, 116)
X4 shape: (103362, 1393)


In [5]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize

# Load your data (Your existing code)
X1_emb = np.load("/Users/jingyi/Desktop/Trauma_LLM/all_patient/originaldata_LED_emb_103362.npy")
X2_emb = np.load("/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvdmetric_LED_emb.npy")

X1_inc_key = pd.read_csv("/Users/jingyi/Desktop/Trauma_LLM/all_patient/originaldata_inc_key.csv")
X2_inc_key = pd.read_csv("/Users/jingyi/Desktop/Trauma_LLM/all_patient/indvdmetric_inc_key.csv")

print("Original shapes:")
print(X1_emb.shape, X1_inc_key.shape)
print(X2_emb.shape, X2_inc_key.shape)

# --- NORMALIZATION STEP ---

# L2 Normalization maps all vectors to the unit sphere
# This makes Euclidean distance equivalent to Cosine distance
print("\nNormalizing embeddings...")
X1_emb = normalize(X1_emb, norm='l2')
X2_emb = normalize(X2_emb, norm='l2')

print("Normalization complete.")

Original shapes:
(103362, 768) (103362, 1)
(103362, 768) (103362, 1)

Normalizing embeddings...
Normalization complete.


In [6]:
import numpy as np

data = np.load("X1_umap_hdbscan_outputs.npz")

X1_umap_d = data["X1_umap_d"]
X1_labels = data["X1_labels"]

print("Loaded X1_umap_d shape:", X1_umap_d.shape)
print("Loaded X1_labels shape:", X1_labels.shape)
print("Unique clusters (excluding -1):", set(X1_labels) - {-1})


Loaded X1_umap_d shape: (103362, 10)
Loaded X1_labels shape: (103362,)
Unique clusters (excluding -1): {np.int32(0), np.int32(1), np.int32(2), np.int32(3), np.int32(4), np.int32(5), np.int32(6), np.int32(7)}


In [7]:
# Divide metric data into cat and num
import re
import pandas as pd
metric_def = pd.read_excel(PATH_METRIC_DEF)

metric_def["Metric_ID"] = metric_def["Metric_ID"].astype(str).str.strip()
metric_def["Variable_type"] = metric_def["Variable_type"].astype(str).str.strip().str.lower()
metric_def["Description"] = metric_def["Description"].astype(str).str.strip()

type_by_id = dict(zip(metric_def["Metric_ID"], metric_def["Variable_type"]))
desc_by_id = dict(zip(metric_def["Metric_ID"], metric_def["Description"]))

num_cols = []
cat_cols = []
rename_map = {}

for col in X2.columns:
    m = re.search(r"mtrc(\d+)", col, flags=re.IGNORECASE)
    if not m:
        continue

    metric_id = m.group(1)
    var_type = type_by_id.get(metric_id, "")

    desc = desc_by_id.get(metric_id, "")
    new_col = f"{col}:{desc}" if desc else col
    rename_map[col] = new_col

    if var_type == "numeric":
        num_cols.append(new_col)
    elif var_type in {"binary", "count"}:
        cat_cols.append(new_col)

X2 = X2.rename(columns=rename_map)

In [8]:
  cell = """import os
  from pathlib import Path

  # Clean env that can break PyCall
  os.environ.pop("PYTHONHOME", None)
  os.environ.pop("PYTHONPATH", None)

  # Tell interpretableai where Julia + sysimage are
  os.environ["IAI_JULIA"] = "/Users/jingyi/Library/Application Support/InterpretableAI/julia/1.12.2/julia-1.12.2/bin/julia"
  os.environ["IAI_SYSTEM_IMAGE"] = "/Users/jingyi/Library/Application Support/InterpretableAI/sysimage/v3.2.2/sys.dylib"

  # Important: let IAI start Julia, but disable compiled modules
  os.environ["IAI_DISABLE_COMPILED_MODULES"] = "1"

  # License
  os.environ["IAI_LICENSE_FILE"] = "/Users/jingyi/Desktop/Trauma_LLM/iai.lic"

  # Now import
  from interpretableai import iai
  print("IAI import OK")
  """
  from pathlib import Path
  Path("iai_import_cell.py").write_text(cell)
  print("Saved: iai_import_cell.py  (run in notebook with: %run iai_import_cell.py)")



Saved: iai_import_cell.py  (run in notebook with: %run iai_import_cell.py)


In [9]:
%run iai_import_cell.py



IAI import OK


In [13]:
  # Toy model
  import numpy as np
  import pandas as pd
  from interpretableai import iai

  # Synthetic binary classification dataset
  rng = np.random.default_rng(1)
  X = pd.DataFrame(
      rng.normal(size=(300, 4)),
      columns=["x1", "x2", "x3", "x4"],
  )
  y = ((X["x1"] + 0.5 * X["x2"] - 0.2 * X["x3"]) > 0).astype(int)

  (train_X, train_y), (test_X, test_y) = iai.split_data(
      "classification", X, y, seed=1
  )

  grid = iai.GridSearch(
      iai.OptimalTreeClassifier(random_seed=1),
      max_depth=range(1, 4),
  )
  grid.fit(train_X, train_y)

  print(grid.get_learner())
  print("AUC:", grid.score(test_X, test_y, criterion="auc"))




Fitted OptimalTreeClassifier:
  1) Split: x2 < 0.2299
    2) Split: x1 < 0.39
      3) Predict: 0 (96.47%), [82,3], 85 points, error 0.03529
      4) Predict: 1 (88.89%), [4,32], 36 points, error 0.1111
    5) Split: x1 < -0.4499
      6) Predict: 0 (92.59%), [25,2], 27 points, error 0.07407
      7) Predict: 1 (93.55%), [4,58], 62 points, error 0.06452
AUC: 0.9216027874564459


In [8]:
  # check the work
  # show first 10 rows
  display(X2.head(10))

Unnamed: 0,inc_key,"mtrc3:Time to first medical contact, min","mtrc4:Prehospital time, min","mtrc46:ICU length of stay, day","mtrc47:Length of stay, day","mtrc17:Time to cranial CT for patients with GCS < 14, min","mtrc28:Time to first emergent surgery, min","mtrc31:Time to surgery for patients in shock, min","mtrc172:Time to tracheostomy in SCI patients, min","mtrc255:Time to tracheostomy, min",...,mtrc71:MRI - 2 hours,mtrc250:EVD placement,mtrc16:Antibiotics for open fractures,mtrc16001:Antibiotics for open fractures within 24 hours,mtrc23:Activation of massive transfusion protovocl,"mtrc67:Convential radiology - in 15 min, level I/II; in 30 min, level III/IV","mtrc68:CT - in 15 min, level I/II; in 30 min, level III/IV",mtrc177:Percentage of severe TBI with other injury,mtrc187:Transfer rate of children with severe TBI,mtrc65:Orthopedic non-emergent availability
0,1,,,2.0,4.0,,4.0,,,,...,,,,,,0.0,0.0,0,,0
1,2,7.0,,,,,44.0,,,,...,,0.0,,,,0.0,0.0,1,,0
2,111,11.0,,,,,107.0,,,,...,,,,,,,,0,,0
3,112,11.0,,,,,81.0,,,,...,,,,,,,,0,,0
4,113,11.0,,3.0,4.0,,26.0,,,,...,,,,,,,,0,,0
5,114,,,,,,,,,,...,,,,,,,,0,,0
6,115,,,2.0,2.0,,188.0,,,,...,,0.0,,,,,,1,,0
7,116,13.0,,,,,133.0,,,,...,,,,,,,,0,,0
8,117,7.0,,1.0,5.0,,28.0,28.0,,,...,,0.0,,,,,,1,,0
9,118,,,2.0,2.0,,74.0,74.0,,,...,,,,,,,,0,,0


In [9]:
labels = np.asarray(X1_labels)

# Remove noise (-1) from cluster set
unique_clusters = sorted(set(labels) - {-1})
print("Number of clusters (excluding -1):", len(unique_clusters))
print("Cluster IDs (first 20):", unique_clusters[:20])

Number of clusters (excluding -1): 8
Cluster IDs (first 20): [np.int32(0), np.int32(1), np.int32(2), np.int32(3), np.int32(4), np.int32(5), np.int32(6), np.int32(7)]


In [10]:
  # Decision Tree for X1
  #--------------------------
  # 0) Inputs
  # --------------------------
  MISSING_TOKENS = {
      "", " ", "  ", "\t", "\n", "\r",
      "na", "n/a", "nan", "null", "none", "nil",
      ".", "..", "...",
      "<unk>", "unk", "unknown", "missing", "nan", "na"
  }

  # X1, cat_cols, num_cols already defined
  # X1_labels loaded from npz

  # --------------------------
  # 1) Drop inc_key + align columns
  # --------------------------
  X = X2.drop(columns=["inc_key"], errors="ignore").copy()

  cat_cols = [c for c in cat_cols if c in X.columns]
  num_cols = [c for c in num_cols if c in X.columns]

  # --------------------------
  # 2) Encode missingness + impute
  # --------------------------
  missing_indicator_cols = []

  for col in cat_cols + num_cols:
      s = X[col]

      # detect missing tokens (string-aware)
      s_str = s.astype(str).str.strip().str.lower()
      miss = s.isna() | s_str.isin(MISSING_TOKENS)

      # add indicator
      #miss_col = f"{col}__missing"
      #X[miss_col] = miss.astype(int)
      #missing_indicator_cols.append(miss_col)

      # impute
      if col in cat_cols:
          X.loc[miss, col] = "MISSING"
          X[col] = X[col].astype(str)
      else:
          X[col] = pd.to_numeric(s, errors="coerce")
          miss2 = miss | X[col].isna()
          X.loc[miss2, col] = 0

  # update numeric columns to include missing indicators
  #num_cols = num_cols + missing_indicator_cols

In [11]:
  # --------------------------
  # 3) Prepare labels (drop noise cluster -1 if present)
  # --------------------------
  y = X1_labels
  mask = y != -1
  X = X.loc[mask].reset_index(drop=True)
  y = y[mask]

In [12]:

  # --- fix feature types before training ---
  cat_cols = [c for c in cat_cols if c in X.columns]
  num_cols = [c for c in num_cols if c in X.columns]

  # remove any overlap
  overlap = set(cat_cols) & set(num_cols)
  if overlap:
      print("Removing from num_cols (categorical):", overlap)
      num_cols = [c for c in num_cols if c not in overlap]

  # force categorical dtype
  for c in cat_cols:
      X[c] = X[c].astype("category")

  # force numeric dtype
  for c in num_cols:
      X[c] = pd.to_numeric(X[c], errors="coerce")

In [15]:
  # --------------------------
  # 4) Train Optimal  (with raw feature dataset)
  # --------------------------
  (train_X, train_y), (test_X, test_y) = iai.split_data(
      "classification", X, y, seed=1
  )

  grid = iai.GridSearch(
      iai.OptimalTreeClassifier(random_seed=1),
      max_depth=7,
  )
  grid.fit(train_X, train_y)

  learner = grid.get_learner()
  print(learner)
  print("Test accuracy:", grid.score(test_X, test_y))

[33m[1m│ [22m[39m- mtrc81:Craniofacial expertise
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mWe recommend extreme caution when using categoric features with many levels inside Optimal Trees, for more information and advice on how to handle such features, please refer to this link:
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mhttps://docs.interpretable.ai/dev/OptimalTrees/tips/#Categorical-Variables-with-Many-Levels
[33m[1m│ [22m[39m
[33m[1m│ [22m[39m- mtrc81:Craniofacial expertise
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mWe recommend extreme caution when using categoric features with many levels inside Optimal Trees, for more information and advice on how to handle such features, please refer to this link:
[33m[1m│ [22m[39m
[33m[1m│ [22m[39mhttps://docs.interpretable.ai/dev/OptimalTrees/tips/#Categorical-Variables-with-Many-Levels
[33m[1m│ [22m[39m


Fitted OptimalTreeClassifier:
  1) Split: mtrc58:Blunt trauma mortality - multisystem in [0.0,1.0] or is missing
    2) Split: mtrc16:Antibiotics for open fractures in [0.0,1.0] or is missing
      3) Split: mtrc46:ICU length of stay, day < 1.5
        4) Predict: 0 (49.64%), [349,189,70,0,1,27,1,66], 703 points, error 0.5036
        5) Predict: 2 (62.06%), [68,131,337,0,0,5,0,2], 543 points, error 0.3794
      6) Split: mtrc3:Time to first medical contact, min < 1
        7) Split: mtrc21:ED stay < 1 hour for patients with GCS < 9 or intubated is MISSING
          8) Split: mtrc47:Length of stay, day < 7.5
            9) Split: mtrc138:Frequency of BCVI screeening is 1 or is missing
              10) Predict: 0 (52.45%), [107,18,1,0,0,65,2,11], 204 points, error 0.4755
              11) Split: mtrc4:Prehospital time, min < 8
                12) Predict: 5 (59.55%), [1483,131,5,0,187,3475,164,390], 5835 points, error 0.4045
                13) Predict: 0 (62.00%), [31,8,2,0,0,0,2,7], 5

In [13]:
import re

def _lgb_safe_name(name: str) -> str:
    # LightGBM forbids JSON special chars in feature names
    return re.sub(r'[\\\"{}\[\]:,]', "_", name)

def _dedupe(names):
    counts = {}
    out = []
    for n in names:
        k = counts.get(n, 0)
        out.append(n if k == 0 else f"{n}__{k}")
        counts[n] = k + 1
    return out

# sanitize + dedupe column names
safe_names = [_lgb_safe_name(c) for c in X.columns]
safe_names = _dedupe(safe_names)
rename_map = dict(zip(X.columns, safe_names))

X_lgb = X.rename(columns=rename_map)

# update cat_cols / num_cols to match sanitized names
cat_cols_lgb = [rename_map[c] for c in cat_cols if c in rename_map]
num_cols_lgb = [rename_map[c] for c in num_cols if c in rename_map]


In [14]:
from sklearn.model_selection import train_test_split
import lightgbm as lgb
import numpy as np

train_X, test_X, train_y, test_y = train_test_split(
    X_lgb, y, test_size=0.2, random_state=1, stratify=y
)

num_classes = len(np.unique(train_y))
is_multiclass = num_classes > 2

lgbm = lgb.LGBMClassifier(
    objective="multiclass" if is_multiclass else "binary",
    num_class=num_classes if is_multiclass else None,
    n_estimators=500,
    learning_rate=0.05,
    num_leaves=63,
    subsample=0.9,
    colsample_bytree=0.9,
    random_state=1,
)

lgbm.fit(train_X, train_y, categorical_feature=cat_cols_lgb if cat_cols_lgb else "auto")

[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.012143 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 2257
[LightGBM] [Info] Number of data points in the train set: 82689, number of used features: 114
[LightGBM] [Info] Start training from score -1.561725
[LightGBM] [Info] Start training from score -3.133042
[LightGBM] [Info] Start training from score -3.742142
[LightGBM] [Info] Start training from score -4.850496
[LightGBM] [Info] Start training from score -2.522276
[LightGBM] [Info] Start training from score -1.317430
[LightGBM] [Info] Start training from score -3.149266
[LightGBM] [Info] Start training from score -1.126610


0,1,2
,boosting_type,'gbdt'
,num_leaves,63
,max_depth,-1
,learning_rate,0.05
,n_estimators,500
,subsample_for_bin,200000
,objective,'multiclass'
,class_weight,
,min_split_gain,0.0
,min_child_weight,0.001


In [15]:
from sklearn.metrics import accuracy_score

pred = lgbm.predict(test_X)
test_acc = accuracy_score(test_y, pred)
print("Test accuracy:", test_acc)

Test accuracy: 0.6569438397910318


In [16]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# predictions
pred = lgbm.predict(test_X)

# micro accuracy = overall accuracy
micro_accuracy = accuracy_score(test_y, pred)

# per-class precision/recall/F1
labels = np.unique(test_y)
prec, rec, f1, support = precision_recall_fscore_support(
    test_y, pred, labels=labels, average=None
)

# macro accuracy = macro-averaged recall
macro_accuracy = rec.mean()

# pack per-class F1
f1_per_class = dict(zip(labels, f1))

print("micro_accuracy:", micro_accuracy)
print("macro_accuracy:", macro_accuracy)
print("f1_per_class:", f1_per_class)

micro_accuracy: 0.6569438397910318
macro_accuracy: 0.4628230219828732
f1_per_class: {np.int32(0): np.float64(0.5317235431809781), np.int32(1): np.float64(0.3202797202797203), np.int32(2): np.float64(0.5892857142857143), np.int32(3): np.float64(0.2943396226415094), np.int32(4): np.float64(0.4116548375046694), np.int32(5): np.float64(0.7396007923205851), np.int32(6): np.float64(0.25301204819277107), np.int32(7): np.float64(0.7809551671176772)}


In [30]:
  import json

  test_acc = grid.score(test_X, test_y)

  # save learner
  learner.write_json("learner.json")

  # (optional) save grid search too
  grid.write_json("grid.json")

  # save metrics
  with open("metrics.json", "w") as f:
      json.dump({"test_accuracy": float(test_acc)}, f, indent=2)

  print("Saved: learner.json, grid.json (optional), metrics.json")

Saved: learner.json, grid.json (optional), metrics.json


In [16]:
  from IPython.display import IFrame, display

  html_path = "optimal_tree.html"
  learner.write_html(html_path)

  display(IFrame(src=html_path, width=1100, height=800))

In [33]:
  #reload
  import json
  from interpretableai.iaibase import read_json

  learner = read_json("learner.json")

  with open("metrics.json", "r") as f:
      metrics = json.load(f)

  print("Reloaded test_accuracy:", metrics["test_accuracy"])


Reloaded test_accuracy: 0.5818633300009675


In [18]:
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_recall_fscore_support,
)

y_pred = learner.predict(test_X)

# Per-class precision/recall/F1
labels = np.unique(test_y)
prec, rec, f1, support = precision_recall_fscore_support(
    test_y, y_pred, labels=labels, average=None
)

f1_per_class = dict(zip(labels, f1))
recall_per_class = dict(zip(labels, rec))
precision_per_class = dict(zip(labels, prec))

# Micro accuracy = overall accuracy
micro_accuracy = accuracy_score(test_y, y_pred)

# Macro accuracy = macro-averaged recall
macro_accuracy = np.mean(rec)

metrics = {
    "micro_accuracy": micro_accuracy,
    "macro_accuracy": macro_accuracy,
    "f1_per_class": f1_per_class,
    "recall_per_class": recall_per_class,
    "precision_per_class": precision_per_class,
}

print(metrics)

{'micro_accuracy': 0.6327195330387951, 'macro_accuracy': np.float64(0.35721654826217025), 'f1_per_class': {np.int32(0): np.float64(0.4685260340212417), np.int32(1): np.float64(0.11104020421186982), np.int32(2): np.float64(0.502491103202847), np.int32(3): np.float64(0.0), np.int32(4): np.float64(0.21052631578947367), np.int32(5): np.float64(0.7436428705970992), np.int32(6): np.float64(0.0), np.int32(7): np.float64(0.7600349441353625)}, 'recall_per_class': {np.int32(0): np.float64(0.41706379707916985), np.int32(1): np.float64(0.06439674315321983), np.int32(2): np.float64(0.48027210884353744), np.int32(3): np.float64(0.0), np.int32(4): np.float64(0.12294094013660105), np.int32(5): np.float64(0.9507525586995785), np.int32(6): np.float64(0.0), np.int32(7): np.float64(0.8223062381852552)}, 'precision_per_class': {np.int32(0): np.float64(0.5344759653270291), np.int32(1): np.float64(0.4027777777777778), np.int32(2): np.float64(0.5268656716417911), np.int32(3): np.float64(0.0), np.int32(4): np.

In [32]:
  html_path = "/Users/jingyi/Desktop/Trauma_LLM/iai_X1/optimal_tree_metric7.html"
  learner.write_html(html_path)
  print("Saved to:", html_path)

Saved to: /Users/jingyi/Desktop/Trauma_LLM/iai_X1/optimal_tree_metric7.html


In [None]:
## Use metric data to run iai again
