In [None]:
%load_ext autoreload
%autoreload 2

# Preprocessing
- drop NaN values
- remove 3rd gender (n=1)
- merge rare classes (~4000 classes --> 1000)
- drop classes with less than 3x25 samples
- stratified 70% train, 15% calibration, 15% test split


data_config: dictionary containing the following keys:
```
target: target column name
numerical_features: list of numerical feature names
categorical_features: list of categorical feature names
high_cardinality_features: list of high cardinality feature names that must be embedded (feature space too large for one-hot encoding (ohe))
use_embedding: bool whether to use embedding for high cardinality features

train_csv: str path to training csv file
val_csv: str path to validation csv file
test_csv: str path to test csv file
ohe_pkl: str path to pickle file containing ohe categories
```

In [None]:
# import data configuration
from data.config import data_config

# relative imports
from uncertainty_aware_diagnosis import(
    ICD10data, 
    SimpleMLP, 
    PlattCalibrator,
    TemperatureScaling,
)

# absolute imports
import polars as pl
import pickle
import numpy as np
import torch
from torchmetrics import F1Score, Recall
from torch.utils.data import DataLoader
from sklearn.metrics import brier_score_loss
from pycalib.visualisations import plot_reliability_diagram
from pycalib.models.calibrators import LogisticCalibration
from pycalib.metrics import classwise_ECE, conf_ECE

# paths (all taken from the shared data configuration)
train_csv = data_config['train_csv']
val_csv = data_config['val_csv']
test_csv = data_config['test_csv']
ohe_pkl = data_config['ohe_pkl']

# reproducibility: seed the NumPy and PyTorch global RNGs once, up front, so
# that anything drawing from them (e.g. DataLoader shuffling, which uses the
# PyTorch default generator) repeats after "Restart Kernel -> Run All"
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# variables dataloader
batch_size = 32
shuffle = True

# Experiment settings

In [None]:
# different experiment settings
# predict the coarser "hoofdgroep" target column instead of the target configured in data_config
predict_hoofdgroepen_only = False
# reload the train/val/test splits from their "*_1hosp.csv" variants
use_single_hospital = False
use_subset = False # use subset of data for faster training
# number of leading rows kept from the train and validation splits when use_subset is True
subset_size = 6400

### Load data

In [None]:
# Target column: the coarse "hoofdgroep" label (a simpler multi-class problem)
# or the target configured in data_config (ICD10 principal diagnosis code).
target = "hoofdgroep" if predict_hoofdgroepen_only else data_config['target']

# Feature groups as listed in the data configuration.
numerical = data_config['numerical_features']
categorical = data_config['categorical_features']
high_cardinality_features = data_config['high_cardinality_features']
use_embedding = False

# One-hot-encoding categories stored in the pickle file.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open(ohe_pkl, "rb") as ohe_file:
    ohe_cats = pickle.load(ohe_file)

In [None]:
# Keyword arguments shared by all three splits. `high_card` is passed
# separately at each call so every dataset receives its own empty list.
split_kwargs = dict(
    numerical=numerical,
    categorical=categorical,
    target=target,
    dropna=True,
    use_embedding=False,
    ohe_categories=ohe_cats,  # one-hot encoded categories of the full dataset
)

# The training split is built first; its encoder and scaler are then handed
# to the validation and test splits.
train = ICD10data(csv_path=train_csv, high_card=[], **split_kwargs)
val = ICD10data(
    csv_path=val_csv,
    high_card=[],
    encoder=train.encoder,
    scaler=train.scaler,
    **split_kwargs,
)
test = ICD10data(
    csv_path=test_csv,
    high_card=[],
    encoder=train.encoder,
    scaler=train.scaler,
    **split_kwargs,
)

# Model dimensions are derived from the training split.
input_dim = train.X.shape[1]
output_dim = train.classes.shape[0]

print(f"Number of icd10 classes: {len(train.classes)}")
print(f"(input_dim: {input_dim}, output_dim: {output_dim})")

In [None]:
# Optionally keep only the first `subset_size` rows of the train and
# validation splits (the test split is left untouched).
if use_subset:
    for split in (train, val):
        split.X = split.X[:subset_size]
        split.y = split.y[:subset_size]

# Both loaders share the batch size and shuffle setting from the config cell.
loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle)
train_loader = DataLoader(train, **loader_kwargs)
val_loader = DataLoader(val, **loader_kwargs)

### select single hospital (if True)

In [None]:
if use_single_hospital:
    # Single-hospital files are addressed as "<name>_1hosp.csv": the slice
    # drops the 4-character ".csv" suffix before "_1hosp.csv" is appended.
    train_csv_1hosp = train_csv[:-4] + "_1hosp" + ".csv"
    val_csv_1hosp = val_csv[:-4] + "_1hosp" + ".csv"
    test_csv_1hosp = test_csv[:-4] + "_1hosp" + ".csv"

    # Keyword arguments shared by all three splits; `high_card` is passed
    # per call so each dataset receives its own empty list.
    hosp_kwargs = dict(
        numerical=numerical,
        categorical=categorical,
        target=target,
        dropna=True,
        use_embedding=False,
        ohe_categories=ohe_cats,  # one-hot encoded categories of the full dataset
    )

    # Replace the previously loaded splits; validation and test reuse the
    # encoder and scaler of the single-hospital training split.
    train = ICD10data(csv_path=train_csv_1hosp, high_card=[], **hosp_kwargs)
    val = ICD10data(
        csv_path=val_csv_1hosp,
        high_card=[],
        encoder=train.encoder,
        scaler=train.scaler,
        **hosp_kwargs,
    )
    test = ICD10data(
        csv_path=test_csv_1hosp,
        high_card=[],
        encoder=train.encoder,
        scaler=train.scaler,
        **hosp_kwargs,
    )

    input_dim = train.X.shape[1]
    output_dim = train.classes.shape[0]

    print(f"Number of icd10 classes: {len(train.classes)}")
    print(f"(input_dim: {input_dim}, output_dim: {output_dim})")

    # Honour `use_subset` for the reloaded splits as well. Without this the
    # flag was silently ignored whenever `use_single_hospital` was enabled,
    # because the subsetted splits and loaders were replaced in this cell.
    if use_subset:
        for split in (train, val):
            split.X = split.X[:subset_size]
            split.y = split.y[:subset_size]

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=shuffle)

# start training

In [None]:
# --- MLP architecture and training hyperparameters ---
hidden_dim, dropout = 256, 0.2
learning_rate = 1e-3
num_epochs, early_stopping_patience = 150, 20
k_folds = 3

# Model sized from the training split (input_dim features, output_dim classes).
model = SimpleMLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_classes=output_dim,
    dropout=dropout,
)

# Fit with k-fold cross validation on the training loader.
# Alternative kept for reference: model.fit(train_loader, val_loader, ...)
# with the same num_epochs / learning_rate / early_stopping_patience / verbose
# arguments fits against the validation/calibration set instead.
cv_settings = dict(
    k_folds=k_folds,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    early_stopping_patience=early_stopping_patience,
    verbose=True,
)
model.fit_cv(train_loader, **cv_settings)

# Test-set outputs of the fitted model and the test labels as a NumPy array.
y_pred = model.predict(test.X)
y_proba = model.predict_proba(test.X)
y_test = test.y.numpy()

# Fit Platt's scaling

In [None]:
# Calibrator hyperparameters
C = 0.03
solver = "lbfgs"

# 1) class probabilities from the MLP (predict_proba output, not raw logits)
probs_val = model.predict_proba(val.X)  # expected shape (n_val, n_classes)
probs_test = model.predict_proba(test.X)  # expected shape (n_test, n_classes)

# 2) build the pycalib logistic (Platt) calibrator and fit it on the
#    validation probabilities and labels
calibrator_settings = dict(
    C=C,
    solver=solver,
    multi_class="multinomial",
    log_transform=True,
)
calibrator = LogisticCalibration(**calibrator_settings)
val_labels = val.y.numpy()
calibrator.fit(probs_val, val_labels)

# 3) calibrated probabilities for the test set
probs_calibrated = calibrator.predict_proba(probs_test)  # expected shape (n_test, n_classes)

In [None]:
# Calibrator hyperparameters
C = 0.03
solver = "lbfgs"

# Class probabilities from the MLP (predict_proba output, not raw logits)
probs_val = model.predict_proba(val.X)  # expected shape (n_val, n_classes)
probs_test = model.predict_proba(test.X)  # expected shape (n_test, n_classes)

# Project-local Platt calibrator, fitted on the validation split
platt_calibrator = PlattCalibrator(
    C=C,
    solver=solver,
    multi_class="multinomial",
    log_transform=True,
)
val_labels = val.y.numpy()
platt_calibrator.fit(probs_val, val_labels)

# Calibrated probabilities for the test set
probs_calibrated_platt = platt_calibrator.predict_proba(probs_test)


## Compute calibration metrics

In [None]:
bins = 5  # number of bins passed to the ECE metrics

# compute ECEs for the uncalibrated classifier and the calibrated probabilities
for metric in (conf_ECE, classwise_ECE):
    print(metric.__name__)
    print("Classifier = {:.3f}".format(metric(test.y.numpy(), probs_test, bins=bins)))
    print(
        "Calibrator = {:.3f}".format(metric(test.y.numpy(), probs_calibrated, bins=bins))
    )
    print("")

# compute brier score loss of the top-label confidence.
# The 0/1 "prediction was correct" indicator must come from the same
# probabilities whose confidence is scored: the calibrator can change the
# arg-max class, so reusing the uncalibrated indicator for the calibrated
# confidences would pair confidences with the wrong outcomes.
true_corr = (np.argmax(probs_test, axis=1) == y_test).astype(int)
true_corr_calibrated = (np.argmax(probs_calibrated, axis=1) == y_test).astype(int)
print('brier_score_loss')
print("Classifier       = {:.3f}".format(brier_score_loss(true_corr, probs_test.max(axis=1))))
print("Calibrated      = {:.3f}".format(brier_score_loss(true_corr_calibrated, probs_calibrated.max(axis=1))))

## Plot reliability diagram (Baseline vs. Platt's scaling)

In [None]:
# Reliability diagram comparing the uncalibrated MLP probabilities with the
# calibrated ones (show_gaps / show_histogram are left at their defaults).
score_sets = [probs_test, probs_calibrated]
curve_names = ["MLP (reduced label space)", "Calibrated"]

fig = plot_reliability_diagram(
    labels=y_test,
    scores=score_sets,
    legend=curve_names,
    confidence=True,
    bins=11,
)

# Compute test scores (comparing baseline to platt's scaling)

In [None]:
# Predicted classes: arg-max over the class axis of each probability matrix
y_pred_base = torch.argmax(torch.tensor(probs_test), dim=1)
y_pred_calibrated = torch.argmax(torch.tensor(probs_calibrated), dim=1)

n_classes = len(test.classes)

# Score the base model first, then the calibrated probabilities. Fresh metric
# objects are created per run so no state carries over between the two.
for header, predictions, blank_line_after in (
    ("Test scores of the basemodel", y_pred_base, True),
    ("Test scores of the calibrated improvement", y_pred_calibrated, False),
):
    f1_macro = F1Score(task="multiclass", average="macro", num_classes=n_classes)
    recall_macro = Recall(task="multiclass", average="macro", num_classes=n_classes)

    # Feed predictions and true labels, then read out the final values
    f1_macro.update(predictions, test.y)
    recall_macro.update(predictions, test.y)
    final_f1 = f1_macro.compute()
    final_recall = recall_macro.compute()

    print(header)
    print(f"F1 Macro: {final_f1:.4f}")
    print(f"Recall Macro: {final_recall:.4f}")
    if blank_line_after:
        print("")

# Temperature scaling
Multiclass Platt's scaling is more flexible but potentially more data-hungry (one weight and one bias per class). Temperature scaling is a single-parameter rescaling (a single scalar T that uniformly "softens" or "sharpens" all logits), so it might work better in the current setting. Its drawback is low flexibility: with a single parameter it can only adjust calibration globally instead of for each class separately. For the same reason it carries a lower risk of overfitting. It is therefore promising when the network is systematically over- or under-confident across all classes, which is the case here, and it is suitable when the validation set is small.
Platt's scaling is better suited to class-specific miscalibration (not the case here, since the model is under-confident across all classes) and to settings with plenty of validation data.




In [None]:
# Raw logits for the validation and test splits
logits_val = model.predict_logits(val.X)  # expected shape (n_val, n_classes)
logits_test = model.predict_logits(test.X)  # expected shape (n_test, n_classes)

# Create the temperature scaler on the device that holds the model's
# parameters, then fit it on the validation logits and labels
model_device = next(model.parameters()).device
temp_scaler = TemperatureScaling(device=model_device)
val_labels = val.y.numpy()
temp_scaler.fit(logits_val, val_labels)

# Temperature-scaled probabilities for the test set
probs_temp = temp_scaler.predict_proba(logits_test)

## Compute calibration metrics

In [None]:
# (output template, probabilities) pairs reported for every metric below
reported = (
    ("Classifier       = {:.3f}", probs_test),
    ("Temp-scaled      = {:.3f}", probs_temp),
)

# Expected calibration errors, using the bin count set in the earlier metrics cell
test_labels = test.y.numpy()
for metric in (conf_ECE, classwise_ECE):
    print(metric.__name__)
    for template, probs in reported:
        print(template.format(metric(test_labels, probs, bins=bins)))
    print("")

# Brier score of the top-label confidence against the 0/1 indicator of
# whether the uncalibrated arg-max prediction was correct
true_corr = (np.argmax(probs_test, axis=1) == y_test).astype(int)
print('brier_score_loss')
for template, probs in reported:
    print(template.format(brier_score_loss(true_corr, probs.max(axis=1))))

## Compute test scores (comparing baseline to temperature scaling)

In [None]:
# Predicted classes: arg-max over the class axis of each probability matrix
y_pred_base = torch.argmax(torch.tensor(probs_test), dim=1)
y_pred_calibrated = torch.argmax(torch.tensor(probs_temp), dim=1)

n_classes = len(test.classes)

# Score the base model first, then the temperature-scaled probabilities.
# Fresh metric objects are created per run so no state carries over.
for header, predictions, blank_line_after in (
    ("Test scores of the basemodel", y_pred_base, True),
    ("Test scores of the calibrated improvement", y_pred_calibrated, False),
):
    f1_macro = F1Score(task="multiclass", average="macro", num_classes=n_classes)
    recall_macro = Recall(task="multiclass", average="macro", num_classes=n_classes)

    # Feed predictions and true labels, then read out the final values
    f1_macro.update(predictions, test.y)
    recall_macro.update(predictions, test.y)
    final_f1 = f1_macro.compute()
    final_recall = recall_macro.compute()

    print(header)
    print(f"F1 Macro: {final_f1:.4f}")
    print(f"Recall Macro: {final_recall:.4f}")
    if blank_line_after:
        print("")

## Plot reliability diagram (baseline vs. Platt's vs. Temperature)

In [None]:
# Reliability diagram with three curves: the uncalibrated MLP, the logistic
# (Platt) calibrator and temperature scaling (show_gaps / show_histogram are
# left at their defaults).
score_sets = [probs_test, probs_calibrated, probs_temp]
curve_names = ["MLP (original)", "Platt", "Temp"]

fig = plot_reliability_diagram(
    labels=y_test,
    scores=score_sets,
    legend=curve_names,
    confidence=True,
    bins=15,
)

In [None]:
# compare ECEs (15 bins, matching the reliability diagram above)
print("Temp‐scaled ECE:", conf_ECE(test.y.numpy(), probs_temp, bins=15))
print("Logistic ECE:", conf_ECE(test.y.numpy(), probs_calibrated, bins=15))

# compare brier score loss of the top-label confidence.
# Each calibrator is scored against the correctness of its OWN arg-max
# prediction: the logistic calibrator can change the predicted class, so the
# indicator derived from the uncalibrated probabilities does not apply to it.
# Computing both indicators here also removes the dependency on a `true_corr`
# variable left behind by an earlier cell.
true_corr_temp = (np.argmax(probs_temp, axis=1) == y_test).astype(int)
true_corr_platt = (np.argmax(probs_calibrated, axis=1) == y_test).astype(int)
print("Temp‐scaled brier score loss:", brier_score_loss(true_corr_temp, probs_temp.max(axis=1)))
print("Logistic brier score loss:", brier_score_loss(true_corr_platt, probs_calibrated.max(axis=1)))