In [17]:
import numpy as np
import pandas as pd

# Lookup tables: vocabulary entry -> its position in the vocabulary list.
# If a vocabulary contains duplicates, the last position wins.
icd_index = dict((code, pos) for pos, code in enumerate(icd_vocab))
# Normalise lab itemids to plain Python ints so int(itemid) lookups match.
lab_vocab_items = list(map(int, lab_vocab_items))
lab_index = dict((item, pos) for pos, item in enumerate(lab_vocab_items))

def _to_list_safe(x):
    """Ép mọi kiểu (list/tuple/np.ndarray/scalar/None) thành list an toàn."""
    if x is None:
        return []
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, np.ndarray):
        return x.tolist()
    # scalar/string -> 1 phần tử? Ở đây coi như rỗng để tránh sai
    return []

def true_icd_indices_from_row(row_dict):
    """Return the vocabulary indices of the ground-truth ICD codes for one row.

    Prefers the vector ``row_dict["y_icd"]``: every position whose value is
    >= 0.5 counts as a true label. When that key is absent, None, or a
    missing scalar such as NaN/NA (what a DataFrame row yields through
    ``to_dict()`` for an empty cell), falls back to mapping
    ``row_dict["icd_blocks"]`` through the module-level ``icd_index``; codes
    not in the vocabulary are ignored.

    Returns:
        1-D integer numpy array of indices in ascending order.
    """
    y = row_dict.get("y_icd", None)
    # A NaN/NA scalar means "no vector stored for this row", not "all
    # negatives"; without this check NaN >= 0.5 silently yields no labels.
    if y is not None and not (np.ndim(y) == 0 and pd.isna(y)):
        arr = np.asarray(y).astype(float).ravel()
        return np.where(arr >= 0.5)[0]
    # Fallback: derive indices from the raw ICD codes.
    blocks = _to_list_safe(row_dict.get("icd_blocks", []))
    idxs = {icd_index[b] for b in blocks if b in icd_index}
    return np.array(sorted(idxs), dtype=int)

def true_lab_indices_from_row(row_dict):
    """Return the vocabulary indices of the ground-truth lab items for one row.

    Prefers the vector ``row_dict["y_lab"]``: every position whose value is
    >= 0.5 counts as a true label. When that key is absent, None, or a
    missing scalar such as NaN/NA (what a DataFrame row yields through
    ``to_dict()`` for an empty cell), falls back to ``row_dict["lab_items"]``.
    Each itemid is converted with ``int()`` (itemids may be stored as strings)
    and mapped through the module-level ``lab_index``. Values that cannot be
    converted, or are not in the vocabulary, are skipped.

    Returns:
        1-D integer numpy array of indices in ascending order.
    """
    y = row_dict.get("y_lab", None)
    # A NaN/NA scalar means "no vector stored for this row", not "all
    # negatives"; without this check NaN >= 0.5 silently yields no labels.
    if y is not None and not (np.ndim(y) == 0 and pd.isna(y)):
        arr = np.asarray(y).astype(float).ravel()
        return np.where(arr >= 0.5)[0]
    # Fallback: derive indices from the raw itemids.
    items = _to_list_safe(row_dict.get("lab_items", []))
    idxs = set()
    for it in items:
        try:
            it_int = int(it)
        except (TypeError, ValueError, OverflowError):
            # Not an integer-like itemid (e.g. None, NaN, inf, "abc"): skip it.
            continue
        if it_int in lab_index:
            idxs.add(lab_index[it_int])
    return np.array(sorted(idxs), dtype=int)

def topk(arr, k=5):
    """Return the indices and values of the ``k`` largest entries of ``arr``.

    Args:
        arr: array-like of scores; assumed 1-D (one row of per-class
            probabilities in this notebook). Lists are accepted.
        k: number of entries to keep. ``k <= 0`` gives empty results; a
            ``k`` larger than ``len(arr)`` returns every entry.

    Returns:
        ``(idx, values)``: ``idx`` holds positions ordered by descending
        value, and ``values`` is ``arr[idx]``.
    """
    arr = np.asarray(arr)
    # Reverse the ascending argsort *before* slicing: the previous form
    # argsort(arr)[-k:] returned every element when k == 0 (since -0 == 0).
    idx = np.argsort(arr)[::-1][:max(int(k), 0)]
    return idx, arr[idx]

def _lab_ref(j):
    """Return ``(itemid, label)`` for lab-vocabulary index ``j``.

    The label comes from ``itemid_to_label``; if it has no entry, the itemid
    rendered as a string is used instead.
    """
    item = int(lab_vocab_items[j])
    return item, itemid_to_label.get(item, str(lab_vocab_items[j]))


# Build one report row per sampled admission: ground truth, top-K predictions,
# thresholded predictions, and the true labels that landed in the top-K.
# NOTE(review): the "*_top5" column names are fixed strings, but the lists
# hold TOP_K entries; they only match when TOP_K == 5.
rows = []
for i in range(len(sampled)):
    # Positional row i; assumed aligned with probs_icd[i] / probs_lab[i].
    row = sampled.iloc[i]
    p_icd, p_lab = probs_icd[i], probs_lab[i]

    # Predictions: top-K ranking and everything above the threshold.
    icd_idx_top, icd_p_top = topk(p_icd, TOP_K)
    lab_idx_top, lab_p_top = topk(p_lab, TOP_K)
    icd_thr_idx = np.where(p_icd >= THR_ICD)[0]
    lab_thr_idx = np.where(p_lab >= THR_LAB)[0]

    # Ground truth.
    row_dict = row.to_dict()
    icd_true_idx = true_icd_indices_from_row(row_dict)
    lab_true_idx = true_lab_indices_from_row(row_dict)

    # Map vocabulary indices to readable codes / (itemid, label) tuples.
    icd_top5 = [(icd_vocab[j], float(p)) for j, p in zip(icd_idx_top, icd_p_top)]
    icd_thr = [(icd_vocab[j], float(p_icd[j])) for j in icd_thr_idx]
    icd_true = [icd_vocab[j] for j in icd_true_idx]

    lab_top5 = [(*_lab_ref(j), float(p)) for j, p in zip(lab_idx_top, lab_p_top)]
    lab_thr = [(*_lab_ref(j), float(p_lab[j])) for j in lab_thr_idx]
    lab_true = [_lab_ref(j) for j in lab_true_idx]

    # Hits: true labels that also appear in the top-K (ordered by vocab index).
    icd_hits = sorted(set(icd_true_idx).intersection(icd_idx_top.tolist()))
    lab_hits = sorted(set(lab_true_idx).intersection(lab_idx_top.tolist()))

    rows.append({
        "subject_id": int(row["subject_id"]),
        "hadm_id": int(row["hadm_id"]),
        "ICD_true": icd_true,
        "ICD_top5": icd_top5,
        f"ICD_thr>={THR_ICD:.2f}": icd_thr,
        "ICD_hits_in_top5": [icd_vocab[j] for j in icd_hits],
        "LAB_true": lab_true,
        "LAB_top5": lab_top5,
        f"LAB_thr>={THR_LAB:.2f}": lab_thr,
        "LAB_hits_in_top5": [_lab_ref(j) for j in lab_hits],
    })

result_df = pd.DataFrame(rows)
pd.set_option("display.max_colwidth", 180)
display(result_df)


Unnamed: 0,subject_id,hadm_id,ICD_true,ICD_top5,ICD_thr>=0.50,ICD_hits_in_top5,LAB_true,LAB_top5,LAB_thr>=0.50,LAB_hits_in_top5
0,16868333,20895093,"[401, 427, 276, 428, 584]","[(276, 0.5127730965614319), (427, 0.44929900765419006), (285, 0.43124574422836304), (401, 0.42651453614234924), (428, 0.4082546830177307)]","[(276, 0.5127730965614319)]","[401, 427, 276, 428]","[(50971, Potassium), (50983, Sodium), (51221, Hematocrit), (50902, Chloride), (50912, Creatinine), (51006, Urea Nitrogen), (50882, Bicarbonate), (50868, Anion Gap), (50931, Glu...","[(50804, Calculated Total CO2, 0.8115944266319275), (50818, pCO2, 0.8100318312644958), (50802, Base Excess, 0.806288480758667), (50821, pO2, 0.798486590385437), (52033, Specime...","[(50971, Potassium, 0.7228435277938843), (50983, Sodium, 0.7052118182182312), (51221, Hematocrit, 0.7289817333221436), (50902, Chloride, 0.7522096037864685), (50912, Creatinine...","[(52033, Specimen Type), (50821, pO2), (50802, Base Excess), (50804, Calculated Total CO2), (50818, pCO2)]"
1,14268926,29880083,"[E87, 530]","[(401, 0.4859779477119446), (272, 0.28625693917274475), (530, 0.25700268149375916), (E87, 0.254139244556427), (V15, 0.2479652762413025)]",[],"[E87, 530]",[],"[(50822, Potassium, Whole Blood, 0.29035520553588867), (50820, pH, 0.27227362990379333), (52033, Specimen Type, 0.2429913431406021), (50808, Free Calcium, 0.2409135401248932), ...",[],[]
2,11684950,22257790,[D64],"[(F41, 0.16395187377929688), (Y92, 0.11136190593242645), (F32, 0.08898287266492844), (Z68, 0.08400328457355499), (E87, 0.08071349561214447)]",[],[],[],"[(51221, Hematocrit, 0.20850510895252228), (51265, Platelet Count, 0.19227702915668488), (51248, MCH, 0.19037726521492004), (51279, Red Blood Cells, 0.18588398396968842), (5125...",[],[]
3,18667552,20224691,"[Z79, Z87, Y92, I48, D64]","[(Z79, 0.3766069710254669), (E78, 0.3704931139945984), (Z87, 0.3187647759914398), (I10, 0.265815407037735), (I48, 0.22200794517993927)]",[],"[Z79, Z87, I48]",[],"[(51222, Hemoglobin, 0.2534600496292114), (52172, RDW-SD, 0.23453214764595032), (51250, MCV, 0.2269824892282486), (50971, Potassium, 0.2202129364013672), (50882, Bicarbonate, 0...",[],[]
4,19267993,27709255,"[E87, Z79, Y92, Z68, Z86, N17]","[(E87, 0.28885138034820557), (Z87, 0.2615728974342346), (Z68, 0.25958457589149475), (F41, 0.22869989275932312), (I10, 0.21826675534248352)]",[],"[E87, Z68]","[(50971, Potassium), (50983, Sodium), (51221, Hematocrit), (50902, Chloride), (50912, Creatinine), (51006, Urea Nitrogen), (50882, Bicarbonate), (50868, Anion Gap), (50931, Glu...","[(50902, Chloride, 0.2290140986442566), (51221, Hematocrit, 0.22898733615875244), (50931, Glucose, 0.22130648791790009), (51277, RDW, 0.2147906869649887), (51265, Platelet Coun...",[],"[(51221, Hematocrit), (50902, Chloride), (50931, Glucose), (51265, Platelet Count), (51277, RDW)]"
