In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Preprocessing
- drop NaN values
- remove 3rd gender (n=1)
- merge rare classes
- drop classes with fewer than 25 samples

In [None]:
# standard library
import pickle

# third-party
import numpy as np
import polars as pl
import torch
from pycalib.metrics import classwise_ECE, conf_ECE
from pycalib.models.calibrators import LogisticCalibration
from pycalib.visualisations import plot_reliability_diagram
from torch.utils.data import DataLoader

# project-local
from uncertainty_aware_diagnosis import ICD10data, SimpleMLP

# paths
# CSV files for the train / validation / test splits (relative to the working directory)
train_csv = "./data/lbz-train.csv"
val_csv = "./data/lbz-val.csv"
test_csv = "./data/lbz-test.csv"
# pickle file that is unpickled into `ohe_cats` when the data is loaded
ohe_pkl = "./data/ohe_cats.pkl"

In [None]:
# load data
target = "icd10_main_code"
categorical = [
    "zorginstellingnaam",
    "gender",
    "clinical_specialty",
    "DBC_specialty_code",
    "DBC_diagnosis_code",
    "icd10_subtraject_code",
]
numerical = ["age"]

# One-hot categories of the full dataset (all categories), so every split is
# encoded against the same category set.
# NOTE: pickle.load can execute arbitrary code -- only point this at a trusted file.
with open(ohe_pkl, "rb") as f:
    ohe_cats = pickle.load(f)


def load_split(csv_path, **fitted):
    """Build one ICD10data split with the feature settings shared by all splits.

    Args:
        csv_path: path of the CSV file holding the split.
        **fitted: optional already-fitted objects (``encoder``, ``scaler``)
            forwarded unchanged to ``ICD10data``.

    Returns:
        The constructed ``ICD10data`` dataset.
    """
    return ICD10data(
        csv_path=csv_path,
        numerical=numerical,
        categorical=categorical,
        high_card=[],  # fresh list per split, no high-cardinality columns
        target=target,
        dropna=True,
        use_embedding=False,
        ohe_categories=ohe_cats,  # one-hot categories of the full dataset
        **fitted,
    )


# The training split is built without a pre-fitted encoder/scaler; the other
# two splits reuse the ones exposed by the training split.
train = load_split(train_csv)
val = load_split(val_csv, encoder=train.encoder, scaler=train.scaler)
test = load_split(test_csv, encoder=train.encoder, scaler=train.scaler)

# Network dimensions are derived from the training split.
output_dim = train.classes.shape[0]
input_dim = train.X.shape[1]

print(f"Number of icd10 classes: {len(train.classes)}")
print(f"(input_dim: {input_dim}, output_dim: {output_dim})")

In [None]:
# variables dataloader
batch_size = 32
shuffle = True  # reshuffle the training samples every epoch

train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle)
# The validation loader is deliberately NOT shuffled: a fixed batch order
# keeps the validation pass identical from epoch to epoch.
# NOTE(review): assumes model.fit only evaluates on val_loader -- confirm.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

In [None]:
# variables MLP
seed = 42
num_epochs = 25
early_stopping_patience = 10
learning_rate = 1e-3
dropout = 0.2
hidden_dim = 256

# Seed torch's global RNG before the model is created so a fresh
# "Restart & Run All" repeats the same run. DataLoader shuffling draws from
# this RNG; presumably SimpleMLP's weight init and dropout do as well.
torch.manual_seed(seed)

model = SimpleMLP(
    input_dim=input_dim, hidden_dim=hidden_dim, num_classes=output_dim, dropout=dropout
)
model.fit(
    train_loader,
    val_loader,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    early_stopping_patience=early_stopping_patience,
    verbose=True,
)

# Predictions of the fitted model on the held-out test features.
y_pred = model.predict(test.X)
y_proba = model.predict_proba(test.X)

In [None]:
y_test = test.y.numpy()

# select y labels: the most common, the most rare, and middle
# value_counts(sort=True) orders the classes by descending frequency; the
# index then picks row 0 (most common), row 500 (hard-coded "middle" rank)
# and row -1 (rarest).
# NOTE(review): row 500 only exists when there are more than 500 classes.
y_select_table = (
    pl.DataFrame(train.y.numpy()).to_series().value_counts(sort=True)[0, 500, -1]
)
y_select = list(y_select_table["column_0"])

# Keep only the test samples whose true class is one of the selected classes,
# and only the probability columns of those classes (in y_select order).
y_mask = np.isin(y_test, y_select)
y_proba_rows_select = y_proba[y_mask]
y_proba_select = y_proba_rows_select[:, y_select]

# y_proba_select has one column per entry of y_select, so the labels handed to
# plot_reliability_diagram must be column positions (0, 1, 2), not the
# original class ids: pycalib binarizes the labels against
# np.arange(n_columns), and original ids would never match those columns.
y_select_ids = np.asarray(y_select)
y_test_select = (y_test[y_mask][:, None] == y_select_ids[None, :]).argmax(axis=1)

In [None]:
# Reliability diagram of the uncalibrated MLP, restricted to the selected classes.
selected_class_names = list(test.classes[y_select])
_ = plot_reliability_diagram(
    y_test_select,
    [y_proba_select],
    legend=["MLP"],
    class_names=selected_class_names,
)

In [None]:
# variables calibrator
C = 0.002
solver = "lbfgs"

# 1) class probabilities from the MLP (predict_proba output, not raw logits);
#    expected shapes are (n_val, n_classes) and (n_test, n_classes)
props_val = model.predict_proba(val.X)
probs_test = model.predict_proba(test.X)

# 2) build the pycalib logistic (Platt) calibrator and fit it on the
#    validation predictions against the validation labels
calibrator_settings = dict(
    C=C, solver=solver, multi_class="multinomial", log_transform=True
)
calibrator = LogisticCalibration(**calibrator_settings)
val_labels = val.y.numpy()
calibrator.fit(props_val, val_labels)

# 3) calibrated probabilities for the test set, expected shape (n_test, n_classes)
probs_calibrated = calibrator.predict_proba(probs_test)

In [None]:
# Rows of the selected test samples, before and after calibration.
y_proba_rows_select_test = probs_test[y_mask]
y_proba_rows_select_cali = probs_calibrated[y_mask]

# Keep only the probability columns of the selected classes.
y_proba_select_test = y_proba_rows_select_test[:, y_select]
y_proba_select_cali = y_proba_rows_select_cali[:, y_select]

# Overlay the raw and the calibrated model in one reliability diagram.
selected_scores = [y_proba_select_test, y_proba_select_cali]
selected_legend = ["MLP", "+ Calibrator"]
_ = plot_reliability_diagram(
    y_test_select,
    selected_scores,
    legend=selected_legend,
    class_names=list(test.classes[y_select]),
)

In [None]:
# Report each calibration-error metric on the test set, first for the raw
# classifier probabilities and then for the calibrated ones.
y_true_test = test.y.numpy()
for metric in (conf_ECE, classwise_ECE):  # ECE,
    print(metric.__name__)
    for line_format, scores in (
        ("Classifier = {:.3f}", probs_test),
        ("Calibrator = {:.3f}", probs_calibrated),
    ):
        print(line_format.format(metric(y_true_test, scores, bins=15)))
    print("")