In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### preprocessing
- drop NaN values
- remove the third gender category (n=1)
- merge rare classes
- drop classes with fewer than 25 samples

In [None]:
# relative imports
from uncertainty_aware_diagnosis import(
    ICD10data, 
    SimpleMLP, 
    TemperatureScaling,
    BaseCalibrator,
    PlattCalibrator,
    MulticlassTemperatureScaling,
    TopLabelTemperatureScaling,
)
from torch.utils.data import DataLoader

# absolute imports
import polars as pl
import pickle
import numpy as np
import torch
from torchmetrics import F1Score, Recall
from sklearn.metrics import brier_score_loss
from pycalib.visualisations import plot_reliability_diagram
from pycalib.models.calibrators import LogisticCalibration
from pycalib.metrics import classwise_ECE, conf_ECE

# paths
# CSV files of the train / validation / test splits. These are kept as plain
# strings because the single-hospital cell derives its file names by slicing
# off the ".csv" suffix.
train_csv = "./data/lbz-train.csv"
val_csv = "./data/lbz-val.csv"
test_csv = "./data/lbz-test.csv"
# pickled one-hot-encoding categories, loaded with `pickle` in the data-loading cell
ohe_pkl = "./data/ohe_cats.pkl"

In [None]:
# Experiment switches read by the cells below.
# True -> use the "hoofdgroep" column as target instead of "icd10_main_code"
predict_hoofdgroepen_only = False
# True -> keep only the first 6400 rows of the train and validation sets
use_subset = False
# True -> reload train / val / test from the "*_1hosp.csv" files
use_single_hospital = False

In [None]:
# Load the train / validation / test splits.
# Predicting "hoofdgroep" simplifies the multi-class problem compared to
# predicting the specific ICD-10 codes.
target = "hoofdgroep" if predict_hoofdgroepen_only else "icd10_main_code"

# feature columns
categorical = [
    "zorginstellingnaam",
    "gender",
    "clinical_specialty",
    "DBC_specialty_code",
    "DBC_diagnosis_code",
    "icd10_subtraject_code",
]
numerical = ["age"]

# One-hot-encoding categories of the full dataset (all categories), so every
# split is encoded against the same category set.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(ohe_pkl, "rb") as f:
    ohe_cats = pickle.load(f)

# keyword arguments shared by all three splits
shared_kwargs = dict(
    numerical=numerical,
    categorical=categorical,
    target=target,
    dropna=True,
    use_embedding=False,
    ohe_categories=ohe_cats,
)

train = ICD10data(csv_path=train_csv, high_card=[], **shared_kwargs)

# validation and test reuse the encoder and scaler of the training dataset
val = ICD10data(
    csv_path=val_csv,
    high_card=[],
    encoder=train.encoder,
    scaler=train.scaler,
    **shared_kwargs,
)
test = ICD10data(
    csv_path=test_csv,
    high_card=[],
    encoder=train.encoder,
    scaler=train.scaler,
    **shared_kwargs,
)

# model dimensions: number of feature columns and number of target classes
input_dim, output_dim = train.X.shape[1], train.classes.shape[0]

print(f"Number of icd10 classes: {len(train.classes)}")
print(f"(input_dim: {input_dim}, output_dim: {output_dim})")

In [None]:
# variables dataloader
batch_size = 32
shuffle = True  # reshuffle the training samples every epoch

if use_subset:
    # quick-experiment mode: keep only the first `subset_size` rows of the
    # training and validation sets (the test set is left untouched)
    subset_size = 6400
    train.X, train.y = train.X[:subset_size], train.y[:subset_size]
    val.X, val.y = val.X[:subset_size], val.y[:subset_size]

train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle)
# The validation set is only evaluated, never trained on, so it is iterated in
# a fixed order; shuffling it adds nothing and makes per-batch results vary.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

### select single hospital

In [None]:
if use_single_hospital:
    # Replace the datasets and loaders by their single-hospital counterparts,
    # read from "<split>_1hosp.csv" next to the original CSV files.
    # NOTE(review): the target is set to "icd10_main_code" here regardless of
    # `predict_hoofdgroepen_only`, and `use_subset` is not applied — confirm
    # that this is intended.
    target = "icd10_main_code"
    numerical = ["age"]
    categorical = [
        "zorginstellingnaam",
        "gender",
        "clinical_specialty",
        "DBC_specialty_code",
        "DBC_diagnosis_code",
        "icd10_subtraject_code",
    ]

    # one-hot-encoding categories of the full dataset (all categories)
    with open(ohe_pkl, "rb") as f:
        ohe_cats = pickle.load(f)

    # keyword arguments shared by all three splits
    hosp_kwargs = dict(
        numerical=numerical,
        categorical=categorical,
        target=target,
        dropna=True,
        use_embedding=False,
        ohe_categories=ohe_cats,
    )

    train = ICD10data(
        csv_path=train_csv[:-4] + "_1hosp" + ".csv",
        high_card=[],
        **hosp_kwargs,
    )
    # validation and test reuse the encoder and scaler of the training dataset
    val = ICD10data(
        csv_path=val_csv[:-4] + "_1hosp" + ".csv",
        high_card=[],
        encoder=train.encoder,
        scaler=train.scaler,
        **hosp_kwargs,
    )
    test = ICD10data(
        csv_path=test_csv[:-4] + "_1hosp" + ".csv",
        high_card=[],
        encoder=train.encoder,
        scaler=train.scaler,
        **hosp_kwargs,
    )

    # model dimensions for the single-hospital data
    input_dim, output_dim = train.X.shape[1], train.classes.shape[0]

    print(f"Number of icd10 classes: {len(train.classes)}")
    print(f"(input_dim: {input_dim}, output_dim: {output_dim})")

    # variables dataloader
    batch_size, shuffle = 32, True

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=shuffle)

# start training

In [None]:
# variables MLP
num_epochs = 150
early_stopping_patience = 20
learning_rate = 1e-3
dropout = 0.2
hidden_dim = 256

# Seed the torch and NumPy global RNGs before the model is created, so the
# RNG-driven steps of training (e.g. weight initialisation and the shuffling
# done by the DataLoader) are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

model = SimpleMLP(
    input_dim=input_dim, hidden_dim=hidden_dim, num_classes=output_dim, dropout=dropout
)
model.fit(
    train_loader,
    val_loader,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    early_stopping_patience=early_stopping_patience,
    verbose=True,
)

# predictions of the trained model on the held-out test features
y_pred = model.predict(test.X)
y_proba = model.predict_proba(test.X)

In [None]:
# ground-truth test labels as a NumPy array, used by the evaluation cells below
y_test = test.y.numpy()

In [None]:
# y_test = test.y.numpy()

# # select y labels: the most common, the most rare, and middle
# y_select_table = (
#     pl.DataFrame(train.y.numpy()).to_series().value_counts(sort=True)[0, 500, -1]
# )
# y_select = list(y_select_table["column_0"])
# y_mask = np.isin(y_test, y_select)
# y_test_select = y_test[y_mask]
# y_proba_rows_select = y_proba[y_mask]
# y_proba_select = y_proba_rows_select[:, y_select]

In [None]:
# _ = plot_reliability_diagram(
#     y_test_select,
#     [
#         y_proba_select,
#     ],
#     legend=[
#         "MLP",
#     ],
#     class_names=list(test.classes[y_select]),
# )

In [None]:
# variables calibrator
C = 0.03
solver = "lbfgs"

# 1) get the predicted class probabilities (not logits) from the MLP;
#    the calibrator is fitted on the validation split and applied to the test split
probs_val = model.predict_proba(val.X)  # shape (n_val, n_classes)
probs_test = model.predict_proba(test.X)  # shape (n_test, n_classes)

# 2) instantiate & fit the pycalib logistic (Platt) calibrator
#    NOTE(review): log_transform=True presumably makes the calibrator operate on
#    log-probabilities — confirm against the pycalib documentation.
calibrator = LogisticCalibration(
    C=C, solver=solver, multi_class="multinomial", log_transform=True
)
calibrator.fit(probs_val, val.y.numpy())

# 3) use it to get calibrated probabilities on the test set
probs_calibrated = calibrator.predict_proba(probs_test)  # shape (n_test, n_classes)

In [None]:
# Expected calibration error (confidence and class-wise) before and after
# logistic calibration; `bins` is reused by the later ECE cell.
bins = 5
for metric in conf_ECE, classwise_ECE:  # ECE,
    print(metric.__name__)
    print("Classifier = {:.3f}".format(metric(test.y.numpy(), probs_test, bins=bins)))
    print(
        "Calibrator = {:.3f}".format(metric(test.y.numpy(), probs_calibrated, bins=bins))
    )
    print("")

# Top-label Brier score: compares the confidence of the predicted class with
# whether that prediction was correct. The correctness indicator must come from
# the SAME probability matrix as the confidence: multinomial logistic
# calibration can change the arg-max class, so the calibrated scores get their
# own indicator instead of reusing the uncalibrated one.
true_corr = (np.argmax(probs_test, axis=1) == y_test).astype(int)
true_corr_calibrated = (np.argmax(probs_calibrated, axis=1) == y_test).astype(int)
print('brier_score_loss')
print("Classifier       = {:.3f}".format(brier_score_loss(true_corr, probs_test.max(axis=1))))
print(
    "Calibrated      = {:.3f}".format(
        brier_score_loss(true_corr_calibrated, probs_calibrated.max(axis=1))
    )
)

In [None]:
# Confidence reliability diagram: uncalibrated MLP vs. the logistic calibrator.
# (optional extras of plot_reliability_diagram: show_gaps=True, show_histogram=True)
fig = plot_reliability_diagram(
    labels=y_test,
    scores=[probs_test, probs_calibrated],
    legend=["MLP (reduced label space)", "Calibrated"],
    confidence=True,
    bins=11,
)

In [None]:
# Hard predictions (arg-max class) of the base model and of the logistic calibrator
y_pred_base = torch.argmax(torch.tensor(probs_test), dim=1)
y_pred_calibrated = torch.argmax(torch.tensor(probs_calibrated), dim=1)

# (report title, predicted classes) for each model to score
reports = (
    ("Test scores of the basemodel", y_pred_base),
    ("Test scores of the calibrated improvement", y_pred_calibrated),
)

for idx, (title, y_hat) in enumerate(reports):
    # fresh metric objects per report so no state is carried over
    f1_macro = F1Score(task="multiclass", average="macro", num_classes=len(test.classes))
    recall_macro = Recall(task="multiclass", average="macro", num_classes=len(test.classes))

    # accumulate predictions against the true test labels
    f1_macro.update(y_hat, test.y)
    recall_macro.update(y_hat, test.y)

    final_f1 = f1_macro.compute()
    final_recall = recall_macro.compute()

    print(title)
    print(f"F1 Macro: {final_f1:.4f}")
    print(f"Recall Macro: {final_recall:.4f}")
    if idx < len(reports) - 1:
        # blank line between the reports, none after the last one
        print("")

# Temperature scaling
Multiclass Platt scaling is more flexible but potentially more data-hungry (one weight and one bias per class). Temperature scaling instead rescales all logits with a single scalar $T$ that uniformly "softens" or "sharpens" them, so it might work better in the current setting. Its drawback is low flexibility: with a single parameter it can only adjust calibration globally, not for each class separately. For the same reason it carries a lower risk of overfitting and is suitable when the validation set is small. It is therefore promising when the network is systematically over- or under-confident across all classes, which is the case here.
Platt scaling is better suited to class-specific miscalibration (not the case here, since the model is under-confident across all classes) and to settings with plenty of validation data.




In [None]:
# Logits of the validation and test splits, each of shape (n_samples, n_classes)
logits_val, logits_test = (model.predict_logits(split.X) for split in (val, test))

# Fit the temperature on the validation split, on the same device as the model
temp_scaler = TemperatureScaling(device=next(model.parameters()).device)
temp_scaler.fit(logits_val, val.y.numpy())

# temperature-scaled probabilities for the test split
probs_temp = temp_scaler.predict_proba(logits_test)

In [None]:
# # compare on your selected subset exactly as you did before
# y_proba_rows_select_test  = probs_test[y_mask]
# y_proba_select_test      = y_proba_rows_select_test[:, y_select]
# y_proba_rows_select_cali = probs_calibrated[y_mask]
# y_proba_select_cali      = y_proba_rows_select_cali[:, y_select]

# # re-plot reliability
# _ = plot_reliability_diagram(
#     y_test_select,
#     [y_proba_select_test, y_proba_select_cali],
#     legend=["MLP", "+ Temp-scale"],
#     class_names=list(test.classes[y_select]),
# )

In [None]:
# ECE (confidence and class-wise) of the raw classifier vs. temperature scaling
for metric in (conf_ECE, classwise_ECE):
    print(metric.__name__)
    for row_fmt, probs in (
        ("Classifier       = {:.3f}", probs_test),
        ("Temp-scaled      = {:.3f}", probs_temp),
    ):
        print(row_fmt.format(metric(test.y.numpy(), probs, bins=bins)))
    print("")

# Top-label Brier score: confidence of the predicted class vs. whether the
# base model's arg-max prediction was correct.
true_corr = (np.argmax(probs_test, axis=1) == y_test).astype(int)
print('brier_score_loss')
for row_fmt, probs in (
    ("Classifier       = {:.3f}", probs_test),
    ("Temp-scaled      = {:.3f}", probs_temp),
):
    print(row_fmt.format(brier_score_loss(true_corr, probs.max(axis=1))))

In [None]:
# Hard predictions (arg-max class) of the base model and of temperature scaling
y_pred_base = torch.argmax(torch.tensor(probs_test), dim=1)
y_pred_calibrated = torch.argmax(torch.tensor(probs_temp), dim=1)

# (report title, predicted classes) for each model to score
reports = (
    ("Test scores of the basemodel", y_pred_base),
    ("Test scores of the calibrated improvement", y_pred_calibrated),
)

for idx, (title, y_hat) in enumerate(reports):
    # fresh metric objects per report so no state is carried over
    f1_macro = F1Score(task="multiclass", average="macro", num_classes=len(test.classes))
    recall_macro = Recall(task="multiclass", average="macro", num_classes=len(test.classes))

    # accumulate predictions against the true test labels
    f1_macro.update(y_hat, test.y)
    recall_macro.update(y_hat, test.y)

    final_f1 = f1_macro.compute()
    final_recall = recall_macro.compute()

    print(title)
    print(f"F1 Macro: {final_f1:.4f}")
    print(f"Recall Macro: {final_recall:.4f}")
    if idx < len(reports) - 1:
        # blank line between the reports, none after the last one
        print("")

In [None]:
# Confidence reliability diagram comparing the raw MLP with both calibrators.
# (optional extras of plot_reliability_diagram: show_gaps=True, show_histogram=True)
fig = plot_reliability_diagram(
    labels=y_test,
    scores=[probs_test, probs_calibrated, probs_temp],
    legend=["MLP (original)", "Platt", "Temp"],
    confidence=True,
    bins=15,
)

In [None]:
# Confidence-ECE with 15 bins for the temperature-scaled and the logistic
# calibrator; requires temp_scaler and calibrator to have been fitted above.
for label, probs in (
    ("Temp‐scaled ECE:", probs_temp),
    ("Logistic ECE:", probs_calibrated),
):
    print(label, conf_ECE(test.y.numpy(), probs, bins=15))

# base calibrator

In [None]:
#  a) Get validation & test data
logits_val = model.predict_logits(val.X)  # (n_val, K)
probs_val = model.predict_proba(val.X)  # (n_val, K)
y_val = val.y.numpy()

logits_test = model.predict_logits(test.X)  # (n_test, K)
probs_test = model.predict_proba(test.X)  # (n_test, K)
y_test = test.y.numpy()

# run the calibrators on the same device as the model
device = next(model.parameters()).device

#  b) Instantiate both
#  The keys use a plain ASCII hyphen so that the lookup in step d) matches;
#  a typographic hyphen (U+2010) in the key would raise a KeyError there.
calibrators = {
    "Multiclass-TS": MulticlassTemperatureScaling(device=device, init_temp=1.0),
    "TopLabel-TS": TopLabelTemperatureScaling(device=device, init_temp=1.0),
}

#  c) Fit on the validation split & apply to the test split
calibrated = {}
for name, cal in calibrators.items():
    cal.fit(logits_val, probs_val, y_val)
    calibrated[name] = cal.predict_proba(logits_test, probs_test)

#  d) Evaluate
print(calibrated)

# For multiclass you might compute ECE or NLL; for the example below we
# just show top-label Brier for both (binary):
# NOTE(review): assumes TopLabelTemperatureScaling.predict_proba returns one
# confidence per test sample (1-D) — confirm against its implementation.
true_corr = (np.argmax(probs_test, axis=1) == y_test).astype(int)
print("=== Brier top-label ===")
print("Uncalibrated:", brier_score_loss(true_corr, probs_test.max(axis=1)))
print("TopLabel-TS:", brier_score_loss(true_corr, calibrated["TopLabel-TS"]))

# Group common, uncommon and rare classes

In [None]:
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

In [None]:
# Per-class reliability curves for a random sample of frequent, medium and
# rare classes (frequency measured on the training labels).

# 1) Get raw labels & probs
# ------------------------------------------------
# Assuming your ICD10data objects expose .y as a torch.Tensor
y_train = train.y.numpy()  # shape (n_train,)
y_test = test.y.numpy()  # shape (n_test,)
probs_test = model.predict_proba(test.X)  # shape (n_test, num_classes)
num_classes = probs_test.shape[1]

# 2) Compute class frequencies & sorted order
# ------------------------------------------------
class_counts = np.bincount(y_train, minlength=num_classes)
sorted_idxs = np.argsort(class_counts)[::-1]  # descending by count
cum_counts = class_counts[sorted_idxs].cumsum()
total = class_counts.sum()

# Position of the first class at which the cumulative count reaches 33% / 66%
# of all training examples. The class at that position falls into the NEXT
# bucket, so a bucket can be empty if one class dominates.
cut1 = np.searchsorted(cum_counts, total * 0.33)
cut2 = np.searchsorted(cum_counts, total * 0.66)

high_freq_idxs = sorted_idxs[:cut1]
mid_freq_idxs = sorted_idxs[cut1:cut2]
low_freq_idxs = sorted_idxs[cut2:]

# 3) Randomly sample up to K from each bucket
# ------------------------------------------------
rng = np.random.default_rng(42)  # fixed seed -> same classes on every run
K_high, K_mid, K_low = 5, 5, 5

sel_high = rng.choice(
    high_freq_idxs, size=min(K_high, len(high_freq_idxs)), replace=False
)
sel_mid = rng.choice(mid_freq_idxs, size=min(K_mid, len(mid_freq_idxs)), replace=False)
sel_low = rng.choice(low_freq_idxs, size=min(K_low, len(low_freq_idxs)), replace=False)

selected_classes = np.concatenate([sel_high, sel_mid, sel_low])

# 4) Compute per-class calibration curves
# ------------------------------------------------
n_bins = 15
calib_data = {}

for cls in selected_classes:
    # binary ground-truth for "is this class?"
    y_true_bin = (y_test == cls).astype(int)
    # predicted probability for that class
    y_prob_cls = probs_test[:, cls]

    # calibration_curve returns (prob_true, prob_pred): the fraction of
    # positives comes FIRST, the mean predicted probability second.
    frac_true, prob_pred = calibration_curve(
        y_true_bin, y_prob_cls, n_bins=n_bins, strategy="uniform"
    )
    calib_data[cls] = (prob_pred, frac_true)

# 5) Plot them together
# ------------------------------------------------
plt.figure(figsize=(8, 8))
plt.plot([0, 1], [0, 1], "k--", label="Ideal")

for cls in selected_classes:
    prob_pred, frac_true = calib_data[cls]
    # x = mean predicted probability, y = fraction of positives (matches the axis labels)
    plt.plot(
        prob_pred,
        frac_true,
        marker="o",
        linestyle="-",
        label=f"class {cls} (freq={class_counts[cls]})",
    )

plt.xlabel("Mean predicted probability")
plt.ylabel("Fraction of positives")
plt.title(f"Reliability diagram for {len(selected_classes)} sample classes")
plt.legend(loc="lower right", fontsize="small", ncol=1)
plt.grid(True)
plt.tight_layout()
plt.show()