In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import tensorflow_datasets as tfds

import os
import pickle
import numpy as np
from tqdm import tqdm

from src.pmi_estimators import train_critic_model, neural_pmi
from src.psi_estimators import psi_gaussian_train, psi_gaussian_val_class
from src.pvi_estimators import train_pvi_null_model, neural_pvi_class, neural_pvi_ensemble_class
import src.utils as utils
import src.metrics as metrics
import src.methods as methods
import src.temp_scaling as temp_scaling

2025-06-05 10:13:49.193148: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-05 10:13:49.193226: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-05 10:13:49.194566: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-05 10:13:49.201809: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
model_name = 'mlp'
dataset_name = 'mnist'

# 85% / 15% train/validation split of the official training set; the official
# test split is held out. shuffle_files=False keeps the split contents stable
# across runs.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    dataset_name,
    split=['train[:85%]', 'train[85%:]', 'test'],
    data_dir='../tensorflow_datasets/',
    shuffle_files=False,
    as_supervised=True,
    with_info=True,
)

# Number of label classes, read from the dataset metadata instead of hard-coded.
num_classes = ds_info.features['label'].num_classes


def preprocess(image, label):
    """Scale pixel values to float32 in [0, 1] and one-hot encode the integer label."""
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.one_hot(label, depth=num_classes)
    return image, label


# NOTE: the datasets are intentionally left UNBATCHED (one (image, label) example
# per element). Every downstream cell batches its own local copy so that cells do
# not depend on each other's dataset state.
ds_train = ds_train.map(preprocess)
ds_val = ds_val.map(preprocess)
ds_test = ds_test.map(preprocess)


def integer_labels(ds, batch_size=1024):
    """Return the integer class labels of an unbatched, one-hot-labelled dataset as a 1-D array.

    Iterates in batches rather than element by element, which is much cheaper in
    eager mode; the result is identical to argmax-ing every label individually.
    """
    return np.concatenate([np.argmax(y.numpy(), axis=1) for _, y in ds.batch(batch_size)])


true_y_train = integer_labels(ds_train)
true_y_val = integer_labels(ds_val)
true_y_test = integer_labels(ds_test)

2025-06-05 10:13:52.469347: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0


In [3]:
def create_model(num_classes=10, hidden_units=512, num_hidden_layers=3, input_shape=(28, 28, 1)):
    """Build an uncompiled fully connected classifier.

    Architecture: Flatten(input_shape) -> `num_hidden_layers` x Dense(`hidden_units`, relu)
    -> Dense(`num_classes`, linear). The last layer has no activation, so the model
    outputs logits and must be trained with a loss that uses `from_logits=True`.

    The defaults reproduce the original MNIST MLP (3 hidden layers of 512 units,
    10 output classes), so `create_model()` still matches previously saved weights.

    Args:
        num_classes: Number of output logits.
        hidden_units: Width of every hidden Dense layer.
        num_hidden_layers: Number of hidden Dense layers.
        input_shape: Shape of a single input example (without the batch axis).

    Returns:
        A `tf.keras.Sequential` model.
    """
    model = tf.keras.Sequential()
    model.add(Flatten(input_shape=input_shape))
    for _ in range(num_hidden_layers):
        model.add(Dense(hidden_units, activation='relu'))
    model.add(Dense(num_classes, activation='linear'))
    return model

### Train the model (10 independent runs)

In [4]:
# Train 10 independently seeded copies of the MLP and save, per run, the model
# weights and the Keras training history under
# ../results/PI_Explainability/<model>_<dataset>/run_<k>/.
train_batch_size = 128
for run in range(10):
    # Seed Python, NumPy and TensorFlow so weight init and data shuffling are
    # reproducible per run (training was previously the only unseeded step).
    tf.keras.utils.set_random_seed(run)

    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    # ds_train / ds_val yield single (image, one-hot label) examples, but Keras
    # fit() expects a dataset of batches; passing them directly makes Keras treat
    # one 28x28x1 image as a batch and fail. Build batched local copies (after
    # seeding, so the shuffle order is reproducible) instead of mutating the
    # shared datasets that later cells batch themselves.
    train_batches = ds_train.shuffle(1000).batch(train_batch_size).prefetch(tf.data.AUTOTUNE)
    val_batches = ds_val.batch(train_batch_size).prefetch(tf.data.AUTOTUNE)

    model = create_model()

    # The last layer is linear (logits), hence from_logits=True.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])

    # Halve the LR after 5 epochs without val_accuracy improvement; stop after 10
    # and restore the weights of the best epoch.
    lr_scheduler = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5, verbose=1)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True, verbose=1)
    history = model.fit(train_batches, validation_data=val_batches, epochs=100, callbacks=[lr_scheduler, early_stop])

    save_dir = f'{exp_name}/saved_models'
    if not os.path.exists(save_dir):
        print("Making directory", save_dir)
        os.makedirs(save_dir)

    model.save_weights(f'{save_dir}/trained_weights.h5')
    # NOTE(review): this pickles the whole History callback object (which holds a
    # reference to the model). `history.history` alone would be a lighter,
    # Keras-version-independent artifact; the object is kept for compatibility
    # with any existing readers of history.pickle.
    with open(f'{exp_name}/history.pickle', 'wb') as f:
        pickle.dump(history, f, protocol=pickle.HIGHEST_PROTOCOL)

Epoch 1/100


2025-06-05 08:44:08.732839: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f1dcd2ed4d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-06-05 08:44:08.732884: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0
2025-06-05 08:44:08.738574: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-06-05 08:44:08.777874: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100
I0000 00:00:1749113048.873410   10741 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 30: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 35: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Restoring model weights from the end of the best epoch: 25.
Epoch 35: early stopping
Making directory ../results/PI_Explainability/mlp_mnist/run_1/saved_models
Making directory ../results/PI_Explainability/mlp_mnist/run_2
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoc

In [5]:
# Reload the saved weights of every run, measure accuracy on each split and
# report the mean (std) error in percent across runs.
train_acc = []
val_acc = []
test_acc = []

# The shared datasets are unbatched (one example per element), while Keras
# evaluate() expects batches. Batch local copies once, outside the run loop.
eval_batch_size = 128
train_batches = ds_train.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE)
val_batches = ds_val.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE)
test_batches = ds_test.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE)

# Loop-invariant: the model outputs logits, hence from_logits=True.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    # Compilation is only needed so evaluate() knows the loss and metric; the
    # optimizer is never stepped here.
    model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])
    # evaluate() returns [loss, accuracy]; keep the accuracy.
    train_acc.append(model.evaluate(train_batches, verbose=1)[1])
    val_acc.append(model.evaluate(val_batches, verbose=1)[1])
    test_acc.append(model.evaluate(test_batches, verbose=1)[1])

print(f'Average train error: {(100-np.mean(train_acc)*100):.2f} ({(np.std(train_acc)*100):.2f})')
print(f'Average validation error: {(100-np.mean(val_acc)*100):.2f} ({(np.std(val_acc)*100):.2f})')
print(f'Average test error: {(100-np.mean(test_acc)*100):.2f} ({(np.std(test_acc)*100):.2f})')

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10
Average train error: 0.00 (0.00)
Average validation error: 1.94 (0.10)
Average test error: 1.97 (0.07)


### PMI: pointwise mutual information between the classifier's logits and the class labels

In [5]:
# For each trained run: fit a separable critic with the variational f-JS estimator
# on (classifier logits, one-hot label) pairs, then score every validation and
# test sample against every class and save the resulting PMI matrices as .npy files.
# (train_critic_model, neural_pmi and tqdm are imported in the first cell.)
n_classes = num_classes
batch_size = 1024  # samples per neural_pmi call


def encode_dataset(encoder, ds, batch_size=128):
    """Pass an unbatched (input, label) dataset through `encoder` in batches.

    Labels are ignored. Returns the encoder outputs for every example,
    concatenated along axis 0 in dataset order.
    """
    return np.concatenate([encoder(x).numpy() for x, _ in ds.batch(batch_size)])


def pmi_all_classes(encoded_x, pmi_model, n_classes, batch_size=1024, estimator='variational_f_js'):
    """Estimate the PMI of every encoded sample paired with every class label.

    For each class k, every sample is paired with the one-hot label k and scored by
    `neural_pmi` in chunks of `batch_size`. The per-class lists are stacked and
    transposed, so rows index samples and columns index classes (assuming
    `neural_pmi` returns one value per input row; TODO confirm).
    """
    num_samples = encoded_x.shape[0]
    per_class = []
    for k in range(n_classes):
        y_k = tf.one_hot(tf.fill([num_samples], k), depth=n_classes)
        values = []
        for start in tqdm(range(0, num_samples, batch_size), desc=f"Computing PMI for class {k+1}"):
            stop = start + batch_size
            pmi = neural_pmi(encoded_x[start:stop], y_k[start:stop], pmi_model, estimator=estimator)
            values += np.array(pmi).tolist()
        per_class.append(values)
    return np.array(per_class).T


for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pmi/separable_variational_f_js'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    # Feature extractor for the critic: the output of the classifier's last
    # (linear) layer, i.e. its logits.
    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)

    # ---- Train PMI model -------------------------------------------------
    print(f'Training PMI model...')
    ds_activity_trn = ds_train.batch(128).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)
    ds_activity_val = ds_val.batch(128).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)
    train_critic_model(ds_activity_trn, ds_activity_val, critic='separable', estimator='variational_f_js', epochs=200, save_path=f'{exp_name}/pmi_output_model')

    # ---- Compute PMI for all validation and test samples --------------------
    pmi_model = tf.keras.models.load_model(f'{exp_name}/pmi_output_model')

    print(f'Computing PMI for all validation samples and for all classes...')
    pmi_val = pmi_all_classes(encode_dataset(int_model, ds_val), pmi_model, n_classes, batch_size=batch_size)
    np.save(f'{exp_name}/pmi_output_class_val.npy', pmi_val)

    print(f'Computing PMI for all test samples and for all classes...')
    pmi_test = pmi_all_classes(encode_dataset(int_model, ds_test), pmi_model, n_classes, batch_size=batch_size)
    np.save(f'{exp_name}/pmi_output_class_test.npy', pmi_test)

Run: 1
Making directory ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:04<13:54,  4.19s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:05<08:27,  2.56s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:07<06:56,  2.11s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:08<06:00,  1.84s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:10<05:28,  1.69s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   5%|▌         | 10/200 [00:15<03:29,  1.10s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▌         | 11/200 [00:16<03:47,  1.20s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▌         | 12/200 [00:18<03:58,  1.27s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▋         | 13/200 [00:19<04:05,  1.31s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  12%|█▏        | 23/200 [00:30<03:57,  1.34s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 133.03it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 603.29it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 627.52it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 637.94it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 637.09it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 627.59it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 636.56it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 634.98it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 651.41it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 649.84it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 255.47it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 618.65it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 633.93it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 637.08it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 672.23it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 675.76it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 652.69it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 585.62it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 658.41it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 663.56it/s]


Run: 2
Making directory ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:05<07:24,  2.25s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:06<06:36,  2.01s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:09<04:52,  1.50s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:10<04:44,  1.46s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:11<04:39,  1.45s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 8/200 [00:13<04:37,  1.45s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [00:21<03:04,  1.00s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 17/200 [00:22<03:26,  1.13s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   9%|▉         | 18/200 [00:24<03:39,  1.21s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|▉         | 19/200 [00:25<03:49,  1.27s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|█         | 21/200 [00:28<03:34,  1.20s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  16%|█▌        | 31/200 [00:39<03:33,  1.26s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 137.99it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 624.25it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 644.00it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 642.17it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 650.58it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 592.91it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 679.64it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 664.51it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 677.02it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 675.93it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 41.41it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 643.41it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 659.21it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 666.66it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 661.34it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 666.04it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 669.18it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 664.31it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 670.60it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 664.44it/s]


Run: 3
Making directory ../results/PI_Explainability/mlp_mnist/run_3/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   5%|▌         | 10/200 [00:14<04:29,  1.42s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 151.16it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 641.55it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 664.84it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 664.78it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 675.96it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 685.92it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 664.79it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 688.64it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 681.89it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 680.00it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 258.02it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 656.11it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 682.78it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 689.10it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 688.03it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 683.23it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 672.54it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 683.96it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 686.85it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 680.67it/s]


Run: 4
Making directory ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:05<07:25,  2.25s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:07<04:55,  1.51s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:08<04:45,  1.47s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:10<04:40,  1.45s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:11<04:36,  1.43s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 8/200 [00:13<04:32,  1.42s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   9%|▉         | 18/200 [00:24<04:05,  1.35s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 149.20it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 620.95it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 639.95it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 654.81it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 660.01it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 670.61it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 655.58it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 660.24it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 667.13it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 663.88it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 263.43it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 613.35it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 649.92it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 641.69it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 652.86it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 651.80it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 634.93it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 668.36it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 663.26it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 664.32it/s]


Run: 5
Making directory ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:05<05:21,  1.63s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:07<05:00,  1.54s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:09<04:12,  1.30s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:11<04:16,  1.33s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 17/200 [00:21<03:54,  1.28s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 147.95it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 642.67it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 677.33it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 685.43it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 677.57it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 678.97it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 668.30it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 686.93it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 662.10it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 697.84it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 252.77it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 668.36it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 688.73it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 677.44it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 691.59it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 688.61it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 686.29it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 694.18it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 697.99it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 612.67it/s]


Run: 6
Making directory ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:04<07:19,  2.22s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:06<06:30,  1.98s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:09<04:43,  1.45s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:10<04:38,  1.43s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:11<04:34,  1.42s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 8/200 [00:13<04:32,  1.42s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 9/200 [00:14<04:30,  1.42s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   5%|▌         | 10/200 [00:16<04:28,  1.41s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|▉         | 19/200 [00:24<02:55,  1.03it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|█         | 20/200 [00:26<03:18,  1.10s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  12%|█▏        | 23/200 [00:29<03:08,  1.07s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  12%|█▏        | 24/200 [00:31<03:23,  1.16s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  13%|█▎        | 26/200 [00:33<03:19,  1.15s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  14%|█▍        | 29/200 [00:37<03:13,  1.13s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  20%|█▉        | 39/200 [00:47<03:18,  1.23s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 150.99it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 627.55it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 654.01it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 654.28it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 658.25it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 662.92it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 670.79it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 662.86it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 665.93it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 658.76it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 248.85it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 646.55it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 657.58it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 666.84it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 671.01it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 668.83it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 668.68it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 672.61it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 661.85it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 679.45it/s]


Run: 7
Making directory ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:04<13:18,  4.01s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:06<05:52,  1.79s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:07<05:20,  1.63s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:09<05:01,  1.55s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:10<04:49,  1.49s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 15/200 [00:19<03:03,  1.01it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  12%|█▎        | 25/200 [00:30<03:33,  1.22s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 144.57it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 631.64it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 614.54it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 653.62it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 648.68it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 660.55it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 658.21it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 657.95it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 657.79it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 653.96it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 256.67it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 646.86it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 619.54it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 666.84it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 670.09it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 674.31it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 673.33it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 629.01it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 658.33it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 651.41it/s]


Run: 8
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:05<08:01,  2.43s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:06<06:25,  1.96s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:10<04:07,  1.28s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:11<04:13,  1.31s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 8/200 [00:12<04:15,  1.33s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▌         | 12/200 [00:17<03:26,  1.10s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▋         | 13/200 [00:18<03:43,  1.19s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   7%|▋         | 14/200 [00:20<03:55,  1.26s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 15/200 [00:21<04:03,  1.31s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [00:22<04:05,  1.33s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  13%|█▎        | 26/200 [00:34<03:47,  1.31s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 145.73it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 630.78it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 655.77it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 661.37it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 675.48it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 673.96it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 675.02it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 676.36it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 677.06it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 680.27it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 263.61it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 654.22it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 659.10it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 678.51it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 681.93it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 681.61it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 675.65it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 678.47it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 672.77it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 669.05it/s]


Run: 9
Making directory ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:04<07:19,  2.22s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:06<06:00,  1.83s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:07<05:26,  1.67s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   7%|▋         | 14/200 [00:18<04:06,  1.33s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 149.82it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 650.89it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 685.38it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 680.61it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 686.68it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 687.27it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 682.68it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 670.33it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 672.81it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 676.14it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 251.72it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 652.84it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 634.67it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 670.71it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 683.22it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 671.60it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 675.40it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 675.38it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 677.54it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 677.69it/s]


Run: 10
Making directory ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:07<03:56,  1.21s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:09<04:06,  1.27s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:10<04:11,  1.30s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   7%|▋         | 14/200 [00:17<03:03,  1.01it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 15/200 [00:19<03:44,  1.21s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [00:20<03:52,  1.27s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/mlp_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  13%|█▎        | 26/200 [00:31<03:30,  1.21s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 149.68it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 636.60it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 637.08it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 661.42it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 667.01it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 676.68it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 670.64it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 672.25it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 680.92it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 672.74it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 266.52it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 661.42it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 677.63it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 679.15it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 680.98it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 692.98it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 680.54it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 689.09it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 555.54it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 661.69it/s]


### PSI

In [6]:
##############################################################
#
# PSI: for each run, fit a Gaussian PSI model on the outputs of that run's
# trained classifier over the training set, then score every validation and
# test sample with the saved PSI model.
#
# NOTE(review): `int_model` uses `model.layers[-1].output`, which is the same
# tensor as the model's final output. If `create_model()` ends in a softmax,
# the values called "logits" here are probabilities -- confirm against
# `create_model`.
#
# #############################################################

def _collect_outputs(net, dataset, batch_size=256, with_labels=False):
    """Run `net` over `dataset` batch by batch and stack the results.

    Args:
        net: Keras model to evaluate; called directly as `net(x_batch)`.
        dataset: unbatched tf.data.Dataset yielding (x, y) pairs.
        batch_size: batch size used while iterating the dataset.
        with_labels: if True, also collect `tf.argmax(y, axis=1)` per batch
            (i.e. labels are expected to be one-hot rows).

    Returns:
        NumPy array of stacked outputs, or an (outputs, labels) tuple of NumPy
        arrays when `with_labels` is True.
    """
    outputs, labels = [], []
    for x_batch, y_batch in dataset.batch(batch_size):
        outputs.append(net(x_batch))
        if with_labels:
            labels.append(tf.argmax(y_batch, axis=1))
    stacked = tf.concat(outputs, axis=0).numpy()
    if with_labels:
        return stacked, tf.concat(labels, axis=0).numpy()
    return stacked


for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/psi/gaussian'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)

    ##############################################################
    #
    # Train PSI Model
    #
    # #############################################################

    x, y = _collect_outputs(int_model, ds_train, with_labels=True)

    print('Training PSI model (gaussian)...')
    psi_data = psi_gaussian_train(x, y, n_projs=500)
    np.save(f'{exp_name}/gaussian_output_model_500_projs.npy', psi_data)

    ##############################################################
    #
    # Compute PSI for all validation and test samples
    #
    # #############################################################

    # Reload from disk so the scores below come from exactly what was saved.
    # np.save stored a Python object, hence allow_pickle=True and .item().
    psi_data = np.load(f'{exp_name}/gaussian_output_model_500_projs.npy', allow_pickle=True).item()

    print('Computing PSI for all validation samples...')
    x = _collect_outputs(int_model, ds_val)
    psi_class, _ = psi_gaussian_val_class(x, psi_data)
    np.save(f'{exp_name}/psi_output_class_500_projs_val.npy', np.array(psi_class))

    print('Computing PSI for all test samples...')
    x = _collect_outputs(int_model, ds_test)
    psi_class, _ = psi_gaussian_val_class(x, psi_data)
    np.save(f'{exp_name}/psi_output_class_500_projs_test.npy', np.array(psi_class))

Run: 1
Making directory ../results/PI_Explainability/mlp_mnist/run_1/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 267.07it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 253.20it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 234.78it/s]


Run: 2
Making directory ../results/PI_Explainability/mlp_mnist/run_2/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 266.59it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 252.07it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 234.32it/s]


Run: 3
Making directory ../results/PI_Explainability/mlp_mnist/run_3/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 269.75it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 251.00it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.55it/s]


Run: 4
Making directory ../results/PI_Explainability/mlp_mnist/run_4/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 269.43it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 251.22it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 233.41it/s]


Run: 5
Making directory ../results/PI_Explainability/mlp_mnist/run_5/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 268.66it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 252.26it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 232.83it/s]


Run: 6
Making directory ../results/PI_Explainability/mlp_mnist/run_6/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 267.78it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 251.58it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 233.55it/s]


Run: 7
Making directory ../results/PI_Explainability/mlp_mnist/run_7/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 265.04it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 253.09it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 231.53it/s]


Run: 8
Making directory ../results/PI_Explainability/mlp_mnist/run_8/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 266.58it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 252.56it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 233.98it/s]


Run: 9
Making directory ../results/PI_Explainability/mlp_mnist/run_9/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 266.50it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 251.52it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 233.23it/s]


Run: 10
Making directory ../results/PI_Explainability/mlp_mnist/run_10/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 265.22it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 252.08it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.07it/s]


### PVI

In [12]:
##############################################################
#
# PVI: for each run, the "PVI model" is the trained classifier of a
# *different* run (a derangement of run indices), so a run's samples are never
# scored by that run's own classifier. A null model is trained from scratch
# with `train_pvi_null_model`. Both models are temperature-scaled on the
# validation set (the null model on zeroed-out inputs) before scoring.
#
# #############################################################

N_RUNS = 10
PVI_DERANGEMENT_SEED = 0  # fixed so the run -> PVI-model pairing is reproducible

# Draw a permutation of run indices with no fixed points.
# Use a dedicated seeded generator: shuffling the global NumPy RNG here (before
# any seed is set in this cell) made the pairing depend on whatever earlier
# cells had consumed, so re-runs could pair different models with the same
# run folder.
_derangement_rng = np.random.default_rng(PVI_DERANGEMENT_SEED)
random_runs = list(range(N_RUNS))
while any(random_runs[i] == i for i in range(N_RUNS)):
    random_runs = [int(i) for i in _derangement_rng.permutation(N_RUNS)]

# Loop-invariant validation quantities: computed once instead of once per run.
# assumes ds_val labels are one-hot rows (argmax over axis=1) -- TODO confirm
true_y_val = np.argmax([y for x, y in ds_val], axis=1)
ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))

for run in range(N_RUNS):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Train PVI Model
    #
    # #############################################################

    source_run = random_runs[run] + 1
    print(f'Using trained model of run {source_run} as PVI model')
    pvi_model = create_model()
    pvi_model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{source_run}/saved_models/trained_weights.h5')
    pvi_model.save_weights(f'{exp_name}/pvi_model_weights.h5')

    untrained_model = create_model()
    train_pvi_null_model(ds_train, untrained_model, epochs=10, save_path=f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Compute PVI for all validation and test samples
    #
    # #############################################################

    # Reload both models from the weight files just written for this run.
    pvi_model = create_model()
    pvi_model.load_weights(f'{exp_name}/pvi_model_weights.h5')
    null_model = create_model()
    null_model.load_weights(f'{exp_name}/pvi_null_model_weights.h5')

    # Per-model temperatures fitted by NLL on the validation set.
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_test.npy', np.array(pvi_class))

Run: 1
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 2
Making directory ../results/PI_Explainability/mlp_mnist/run_2/calibration/pvi/training_from_scratch
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 3
Making directory ../results/PI_Explainability/mlp_mnist/run_3/calibration/pvi/training_from_scratch
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 4
Making directory ../results/PI_Explainability/mlp_mnist/run_4/calibration

In [None]:
# PVI "finetuned": each run's own trained classifier is finetuned further and used as the PVI
# model; the null model is re-used from the training-from-scratch cell.
N_RUNS = 10
results_root = f'../results/PI_Explainability/{model_name}_{dataset_name}'

# Loop-invariant inputs: validation class indices (labels are one-hot, hence the argmax) and
# the validation split with every input replaced by zeros, which is what the null model sees.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)
ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

for run in range(N_RUNS):
    print(f'Run: {run+1}')
    # Seed like every other cell; this finetuning loop was previously the only unseeded one.
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'{results_root}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/finetuned'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Finetune PVI model; copy the null model
    #
    # #############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    # A fresh optimizer and fresh callbacks per run, so no state leaks between runs.
    pvi_model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])
    lr_scheduler = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5, verbose=1)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True, verbose=1)
    pvi_model.fit(ds_train.batch(256), validation_data=ds_val.batch(256), epochs=100, callbacks=[lr_scheduler, early_stop])
    # The in-memory weights are the saved ones, so no reload from disk is needed afterwards.
    pvi_model.save_weights(f'{exp_name}/pvi_model_weights.h5')

    null_model = create_model()
    null_model.load_weights(f'{run_dir}/calibration/pvi/training_from_scratch/pvi_null_model_weights.h5')
    null_model.save_weights(f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Per-class PVI for all validation and test samples
    #
    # #############################################################

    # Each model gets its own NLL-optimal temperature fitted on the validation split.
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_test.npy', np.array(pvi_class))

In [13]:
# PVI "best": every run uses the trained classifier of run 7 (index 6) as its PVI model,
# except run 7 itself, which uses run 5 (index 4), so the PVI model is never the evaluated one.
# NOTE(review): run 7 is presumably the best-performing classifier; confirm.
# Nothing is trained here: the null model is the one saved by the training-from-scratch cell.
BEST_RUN_IDX, FALLBACK_RUN_IDX = 6, 4
pvi_runs = [FALLBACK_RUN_IDX if i == BEST_RUN_IDX else BEST_RUN_IDX for i in range(10)]
results_root = f'../results/PI_Explainability/{model_name}_{dataset_name}'

# Loop-invariant inputs: validation class indices (labels are one-hot, hence the argmax) and
# the validation split with every input replaced by zeros, which is what the null model sees.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)
ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'{results_root}/run_{run+1}/calibration/pvi/training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Load the "best" PVI model and the existing null model
    #
    # #############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'{results_root}/run_{pvi_runs[run]+1}/saved_models/trained_weights.h5')
    pvi_model.save_weights(f'{exp_name}/pvi_model_best_weights.h5')

    null_model = create_model()
    null_model.load_weights(f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Per-class PVI for all validation and test samples
    #
    # #############################################################

    # Each model gets its own NLL-optimal temperature fitted on the validation split.
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_best_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_best_test.npy', np.array(pvi_class))

Run: 1
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 2
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 3
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 4
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 5
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 6
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 7
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 8
Computing PVI for all validation samples and for all classes...
Computing PVI for all test

### Ensemble PVI

In [None]:
# Ensemble PVI (no extra training). Two PVI models are used: the run's own trained classifier
# and the PVI model saved by the training-from-scratch cell. Each is paired with that cell's
# null model. No temperature scaling is applied.
def _ensemble_pvi_class(ds, pvi_models, null_models, batch_size=256):
    """Return per-class ensemble PVI for every sample of the unbatched dataset `ds` as one array."""
    rows = []
    for x_batch, _ in ds.batch(batch_size):
        # Every ensemble member receives the same input batch.
        pvi = neural_pvi_ensemble_class([x_batch] * len(pvi_models), pvi_models, null_models)
        rows += np.array(pvi).tolist()
    return np.array(rows)


for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    scratch_dir = f'{run_dir}/calibration/pvi/training_from_scratch'
    exp_name = f'{run_dir}/calibration/pvi/ensemble_no_training_training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Load ensemble members (nothing is trained in this cell)
    #
    # #############################################################

    pvi_model_1 = create_model()
    pvi_model_1.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    pvi_model_2 = create_model()
    pvi_model_2.load_weights(f'{scratch_dir}/pvi_model_weights.h5')
    # Both members use the same null-model weights, so load them once and share the model.
    null_model = create_model()
    null_model.load_weights(f'{scratch_dir}/pvi_null_model_weights.h5')
    pvi_models, null_models = [pvi_model_1, pvi_model_2], [null_model, null_model]

    ##############################################################
    #
    # Per-class ensemble PVI for all validation and test samples
    #
    # #############################################################

    print(f'Computing PVI for all validation samples and for all classes...')
    np.save(f'{exp_name}/pvi_class_val.npy', _ensemble_pvi_class(ds_val, pvi_models, null_models))

    print(f'Computing PVI for all test samples and for all classes...')
    np.save(f'{exp_name}/pvi_class_test.npy', _ensemble_pvi_class(ds_test, pvi_models, null_models))

### Temperature Scaling

In [14]:
# Fit softmax temperatures on the validation split under three criteria: AURC, NLL and ECE
# (15 bins). Labels are computed here (one-hot -> class index) rather than relying on a
# `true_y_val` left over from the PVI cells.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    calib_dir = f'{run_dir}/calibration'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    if not os.path.exists(calib_dir):
        print("Making directory", calib_dir)
        os.makedirs(calib_dir)

    # A single forward pass: the logits are both the calibration scores and the source of the
    # predictions (the previous version ran the identical prediction twice).
    scores = model.predict(ds_val.batch(512), verbose=0)
    pred_y_val = np.argmax(scores, axis=1)

    opt_temp = temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val)
    np.save(f'{calib_dir}/softmax_opt_temp_aurc.npy', opt_temp)

    opt_temp = temp_scaling.temp_scaling_nll(scores, true_y_val)
    np.save(f'{calib_dir}/softmax_opt_temp_nll.npy', opt_temp)

    opt_temp = temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15)
    np.save(f'{calib_dir}/softmax_opt_temp_ece.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [None]:
# Fit ensemble temperature scaling (ETS, NLL criterion) to the validation logits of each run.
# Labels are computed here (one-hot -> class index) rather than relying on a leftover variable.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # ETS only needs the logits and the true labels. The previous version also ran a second
    # forward pass to compute predicted labels that were never used.
    scores = model.predict(ds_val.batch(512), verbose=0)

    opt_temp, opt_weights = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    np.save(f'{run_dir}/calibration/softmax_opt_temp_ets_nll.npy', opt_temp)
    np.save(f'{run_dir}/calibration/softmax_opt_weights_ets_nll.npy', opt_weights)

In [None]:
# Fit a parameterized temperature scaling (PTS) calibrator to the validation logits of each run.
# Labels are computed here (one-hot -> class index) rather than relying on a leftover variable.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    scores = model.predict(ds_val.batch(512), verbose=0)

    pts = temp_scaling.PTSCalibrator(
        epochs=30,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=num_classes,
        top_k_logits=5,
    )

    # Calibrate against the TRUE validation labels. Tuning on the model's own predicted labels
    # (as before) makes every prediction look correct, so the calibrator can only learn to sharpen.
    pts.tune(logits=scores, labels=true_y_val)
    pts.save(path=f'{run_dir}/calibration/calibration_model/')

In [15]:
# Fit temperatures for the saved per-class PMI validation scores under AURC, NLL and ECE (15 bins).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pmi/separable_variational_f_js'

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    val_logits = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_logits, axis=1)

    scores = np.load(f'{exp_name}/pmi_output_class_val.npy')

    # Fitted in this fixed order; each fitter runs right before its result is saved.
    fitters = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit in fitters.items():
        np.save(f'{exp_name}/pmi_opt_temp_{criterion}.npy', fit())

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [16]:
# Fit temperatures for the saved per-class PSI (500 projections) validation scores under
# AURC, NLL and ECE (15 bins).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/psi/gaussian'

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    val_logits = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_logits, axis=1)

    scores = np.load(f'{exp_name}/psi_output_class_500_projs_val.npy')

    # Fitted in this fixed order; each fitter runs right before its result is saved.
    fitters = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit in fitters.items():
        np.save(f'{exp_name}/psi_opt_temp_{criterion}.npy', fit())

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [17]:
# Fit temperatures for the saved per-class PVI (training-from-scratch) validation scores
# under AURC, NLL and ECE (15 bins).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/training_from_scratch'

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    val_logits = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_logits, axis=1)

    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    # Fitted in this fixed order; each fitter runs right before its result is saved.
    fitters = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit in fitters.items():
        np.save(f'{exp_name}/pvi_opt_temp_{criterion}.npy', fit())

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [18]:
# Fit temperatures for the saved per-class "best" PVI validation scores under AURC, NLL and
# ECE (15 bins).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/training_from_scratch'

    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    val_logits = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_logits, axis=1)

    scores = np.load(f'{exp_name}/pvi_class_best_val.npy')

    # Fitted in this fixed order; each fitter runs right before its result is saved.
    fitters = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit in fitters.items():
        np.save(f'{exp_name}/pvi_best_opt_temp_{criterion}.npy', fit())

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [None]:
# Fit ensemble temperature scaling (ETS, NLL criterion) to the validation PVI scores of each run.
# Labels are computed here (one-hot -> class index) rather than relying on a leftover variable.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'

    # ETS only uses the stored PVI scores and the true labels. The previous version loaded the
    # classifier and ran a forward pass for predicted labels that were never used.
    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    opt_temp, opt_weights = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    np.save(f'{exp_name}/pvi_opt_temp_ets_nll.npy', opt_temp)
    np.save(f'{exp_name}/pvi_opt_weights_ets_nll.npy', opt_weights)

In [None]:
# Fit ETS (NLL criterion) and a parameterized temperature scaling (PTS) calibrator to the
# validation PVI scores of each run.
# Labels are computed here (one-hot -> class index) rather than relying on a leftover variable.
true_y_val = np.argmax([y for x, y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'

    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    # Same ETS fit as the standalone PVI-ETS cell, kept so this cell is self-contained.
    opt_temp, opt_weights = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    np.save(f'{exp_name}/pvi_opt_temp_ets_nll.npy', opt_temp)
    np.save(f'{exp_name}/pvi_opt_weights_ets_nll.npy', opt_weights)

    pts = temp_scaling.PTSCalibrator(
        epochs=30,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=128,
        length_logits=num_classes,
        top_k_logits=5,
    )

    # Calibrate against the TRUE validation labels. Tuning on the classifier's predicted labels
    # (as before) makes every prediction look correct, so the calibrator can only learn to sharpen.
    # With that fixed, the classifier is no longer needed in this cell.
    pts.tune(logits=scores, labels=true_y_val)
    pts.save(path=f'{exp_name}/calibration_model/')

### Failure Detection

In [20]:
def get_confidence_scores(conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name):
    """Return one confidence score per test sample for the requested confidence method.

    `conf_method` is a method key, optionally suffixed with `_temp_scaling_<metric>`
    (e.g. 'pvi_temp_scaling_aurc'). The suffix selects which saved optimal temperature
    (`<prefix>_opt_temp_<metric>.npy`) is loaded. Without the suffix, the temperature-based
    methods use a temperature of 1.0 (no scaling).

    Args:
        conf_method: confidence method name (see the branches below).
        model: trained classifier (its outputs are used as logits by the `methods` helpers).
        ds_test: unbatched test dataset.
        pred_y_test: predicted class index for each test sample, aligned with the saved score files.
        run: zero-based run index; results are read from `run_{run+1}`.
        model_name, dataset_name: used to build the results path.

    Returns:
        Array of confidence scores, one per test sample.

    Raises:
        ValueError: if `conf_method` is not recognised.
    """
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'
    metric = conf_method.split('_')[-1] if 'temp_scaling' in conf_method else None
    method_key = conf_method.replace(f'_temp_scaling_{metric}', '') if metric else conf_method

    def load_temp(prefix, directory=base_path):
        """Load the saved temperature for `metric`, or 1.0 (no scaling) when none was requested."""
        # Without a `_temp_scaling_<metric>` suffix there is no file to load. The previous code
        # tried to open `<prefix>_opt_temp_None.npy` and crashed with FileNotFoundError.
        if metric is None:
            return 1.0
        return np.load(f'{directory}/{prefix}_opt_temp_{metric}.npy')

    # Saved per-class pointwise-information scores: (sub-directory, test-score file).
    pi_sources = {
        'pmi': ('pmi/separable_variational_f_js', 'pmi_output_class_test.npy'),
        'psi': ('psi/gaussian', 'psi_output_class_500_projs_test.npy'),
        'pvi': ('pvi/training_from_scratch', 'pvi_class_test.npy'),
        'pvi_best': ('pvi/training_from_scratch', 'pvi_class_best_test.npy'),
    }

    if method_key == 'softmax':
        if metric:
            return methods.max_softmax_prob(model, ds_test, load_temp('softmax'))
        return methods.max_softmax_prob(model, ds_test)

    if method_key in pi_sources:
        sub_dir, class_file = pi_sources[method_key]
        exp_path = f'{base_path}/{sub_dir}'
        opt_temp = load_temp(method_key, exp_path)
        scores_class = np.load(f'{exp_path}/{class_file}')
        scores_class = np.array([utils.softmax(x / opt_temp) for x in scores_class])
        # Temperature-scaled probability assigned to each sample's predicted class.
        return np.array([score[pred] for score, pred in zip(scores_class, pred_y_test)])

    # NOTE(review): the temperature argument of the `methods.*` helpers below is assumed to divide
    # the logits, so 1.0 means "uncalibrated" when no `_temp_scaling_<metric>` suffix is given. Confirm.
    if method_key == 'softmax_margin':
        return methods.softmax_margin(model, ds_test, load_temp('softmax'))

    if method_key == 'max_logits':
        return methods.max_logits(model, ds_test)

    if method_key == 'logits_margin':
        return methods.logits_margin(model, ds_test)

    if method_key == 'negative_entropy':
        return methods.negative_entropy(model, ds_test, load_temp('softmax'))

    if method_key == 'negative_gini':
        return methods.negative_gini(model, ds_test, load_temp('softmax'))

    if method_key == 'isotonic_regression':
        # Reads the notebook-level `ds_val` and `true_y_val`, which are not parameters of this function.
        return methods.isotonic_reg(model, ds_val, ds_test, true_y_val)

    raise ValueError(f"Unknown confidence method: {conf_method}")


def evaluate_failure_pred(ds_test, true_y_test, conf_method, n_runs=10):
    """Evaluate failure prediction of `conf_method` over the first `n_runs` trained models.

    For each run the classifier is reloaded from its saved weights and predicts the test
    split. Its predictions are scored with `get_confidence_scores`, and every
    failure-detection metric is computed from (scores, predictions, true labels).

    Returns:
        dict mapping each metric name to a list holding one value per run.
    """
    results = {name: [] for name in
               ("auroc", "fpr_at_95tpr", "auprc_success", "auprc_error", "aurc", "eaurc", "naurc")}
    # (results key, metric function), evaluated in this order for every run.
    metric_fns = (
        ("auroc", metrics.compute_auroc),
        ("auprc_success", metrics.compute_auprc_success),
        ("auprc_error", metrics.compute_auprc_error),
        ("fpr_at_95tpr", metrics.compute_fpr_at_95tpr),
        ("aurc", metrics.compute_aurc),
        ("eaurc", metrics.compute_eaurc),
        ("naurc", metrics.compute_naurc),
    )

    for run_idx in range(n_runs):
        tf.keras.utils.set_random_seed(run_idx + 10)
        clf = create_model()
        clf.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run_idx+1}/saved_models/trained_weights.h5')

        predicted = np.argmax(clf.predict(ds_test.batch(256), verbose=0), axis=1)
        confidences = get_confidence_scores(conf_method, clf, ds_test, predicted, run_idx, model_name, dataset_name)

        for name, metric_fn in metric_fns:
            results[name].append(metric_fn(confidences, predicted, true_y_test))

    return results

In [21]:
# Confidence methods compared for failure prediction. The `_temp_scaling_aurc` suffix selects
# the AURC-optimal temperature saved by the temperature-scaling cells.
methods_list = [
    'softmax_temp_scaling_aurc', 'pmi_temp_scaling_aurc', 'psi_temp_scaling_aurc', 'pvi_temp_scaling_aurc',
    'softmax_margin_temp_scaling_aurc', 'max_logits', 'logits_margin', 'negative_entropy_temp_scaling_aurc',
    'negative_gini_temp_scaling_aurc',
]
# (padded label, results key, display scale): rates are shown x100, AURC-family values x1000.
report_rows = (
    ("AUROC           ", 'auroc', 100),
    ("AUPRC (success) ", 'auprc_success', 100),
    ("AUPRC (error)   ", 'auprc_error', 100),
    ("FPR at 95% TPR  ", 'fpr_at_95tpr', 100),
    ("AURC            ", 'aurc', 1000),
    ("EAURC           ", 'eaurc', 1000),
    ("NAURC           ", 'naurc', 1000),
)
for method in methods_list:
    print(f'Method: {method}')
    results = evaluate_failure_pred(ds_test, true_y_test, conf_method=method, n_runs=10)
    for label, key, scale in report_rows:
        print(f"{label}: {utils.format_ci(results[key], scale=scale)}")

Method: softmax_temp_scaling_aurc
AUROC           : 94.99 (1.13)
AUPRC (success) : 99.84 (0.05)
AUPRC (error)   : 40.84 (2.75)
FPR at 95% TPR  : 16.07 (4.52)
AURC            : 1.72 (0.51)
EAURC           : 1.53 (0.52)
NAURC           : 79.10 (27.28)
Method: pmi_temp_scaling_aurc
AUROC           : 61.97 (0.73)
AUPRC (success) : 98.50 (0.05)
AUPRC (error)   : 16.61 (1.80)
FPR at 95% TPR  : N/A
AURC            : 15.04 (0.45)
EAURC           : 14.85 (0.45)
NAURC           : 762.00 (14.63)
Method: psi_temp_scaling_aurc
AUROC           : 77.70 (0.98)
AUPRC (success) : 99.12 (0.04)
AUPRC (error)   : 23.05 (2.09)
FPR at 95% TPR  : N/A
AURC            : 8.82 (0.43)
EAURC           : 8.62 (0.42)
NAURC           : 442.53 (19.85)
Method: pvi_temp_scaling_aurc
AUROC           : 93.02 (5.90)
AUPRC (success) : 99.79 (0.26)
AUPRC (error)   : 41.60 (4.64)
FPR at 95% TPR  : 23.12 (3.27)
AURC            : 2.29 (2.60)
EAURC           : 2.09 (2.59)
NAURC           : 105.11 (126.64)
Method: softmax_margin_t

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Ensemble temperature scaling: mix scaled, unscaled and uniform class probabilities.

    Args:
        logits: 2-D array (samples x classes) of logits or per-class scores.
        opt_temp: fitted temperature (scalar, or broadcastable against `logits`).
        opt_weights: three mixture weights for (temperature-scaled, unscaled, uniform).
            They are renormalised to sum to 1.
        n_class: number of classes; defines the uniform component.

    Returns:
        Array shaped like `logits` holding the calibrated class probabilities.
    """
    def _softmax(z):
        # Subtracting the row maximum leaves the result unchanged but prevents np.exp from
        # overflowing to inf (and the ratio to nan) for large logits, which the old code did.
        z = z - np.max(z, axis=1, keepdims=True)
        e = np.exp(z)
        return e / np.sum(e, axis=1, keepdims=True)

    logits = np.asarray(logits)
    p0 = _softmax(logits / opt_temp)   # temperature-scaled
    p1 = _softmax(logits)              # unscaled
    p2 = np.ones_like(p0) / n_class    # uniform
    w = np.asarray(opt_weights) / np.sum(opt_weights)  # renormalise in case they do not sum to 1
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Failure-prediction metrics for ETS-calibrated softmax confidences (temperature and weights
# fitted by NLL in the softmax-ETS cell). This reports the same metric set as
# `evaluate_failure_pred`; NAURC was previously missing here.
method = 'softmax ETS'
print(f'Method: {method}')
results = {key: [] for key in ("auroc", "fpr_at_95tpr", "auprc_success", "auprc_error", "aurc", "eaurc", "naurc")}
ets_metric_fns = (
    ("auroc", metrics.compute_auroc),
    ("auprc_success", metrics.compute_auprc_success),
    ("auprc_error", metrics.compute_auprc_error),
    ("fpr_at_95tpr", metrics.compute_fpr_at_95tpr),
    ("aurc", metrics.compute_aurc),
    ("eaurc", metrics.compute_eaurc),
    ("naurc", metrics.compute_naurc),
)
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # A single forward pass supplies both the logits and the predictions (previously the
    # test split was predicted twice, with batch sizes 256 and 512).
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    calib_dir = f'{run_dir}/calibration'
    opt_temp = np.load(f'{calib_dir}/softmax_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{calib_dir}/softmax_opt_weights_ets_nll.npy')

    scores_class = apply_ets(logits, opt_temp, opt_weights, num_classes)
    # Calibrated probability of each sample's predicted class.
    scores_test = scores_class[np.arange(len(pred_y_test)), pred_y_test]

    for key, metric_fn in ets_metric_fns:
        results[key].append(metric_fn(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")
print(f"NAURC           : {utils.format_ci(results['naurc'], scale=1000)}")

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Calibrate scores with Ensemble Temperature Scaling (ETS).

    Returns the mixture
        w[0] * softmax(logits / opt_temp) + w[1] * softmax(logits) + w[2] * uniform(n_class)
    where ``w`` is ``opt_weights`` renormalised to sum to one. Softmax is taken along axis 1.

    Args:
        logits: 2-D array (samples x classes) of scores; in this notebook either model
            logits or per-class PVI values.
        opt_temp: fitted temperature (scalar or array broadcastable against ``logits``).
        opt_weights: mixture weights; only the first three entries are used.
        n_class: number of classes for the uniform component.

    Returns:
        float64 array with the same shape as ``logits``. Rows sum to one when ``n_class``
        equals the number of columns.

    NOTE(review): identical copies of this function are defined in several cells; the most
    recently executed one wins. Consider keeping a single definition.
    """
    logits = np.asarray(logits, dtype=np.float64)

    def _softmax_rows(z):
        # Subtracting the row maximum leaves softmax unchanged but keeps exp() from
        # overflowing to inf (which made the naive form return inf/inf = NaN) for large
        # or temperature-sharpened scores.
        z = z - np.max(z, axis=1, keepdims=True)
        exp_z = np.exp(z)
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)

    p0 = _softmax_rows(logits / opt_temp)   # temperature-scaled component
    p1 = _softmax_rows(logits)              # uncalibrated component
    p2 = np.full_like(p0, 1.0 / n_class)    # uniform component
    weights = np.asarray(opt_weights, dtype=np.float64)
    w = weights / np.sum(weights)           # renormalise in case the weights do not sum to 1
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Failure-prediction metrics for per-class PVI scores calibrated with ETS, over 10 runs.
method = 'PVI ETS'
print(f'Method: {method}')

# Metric name -> function(scores_test, pred_y_test, true_y_test); insertion order fixes the
# key order of `results`.
failure_metric_fns = {
    "auroc": metrics.compute_auroc,
    "fpr_at_95tpr": metrics.compute_fpr_at_95tpr,
    "auprc_success": metrics.compute_auprc_success,
    "auprc_error": metrics.compute_auprc_error,
    "aurc": metrics.compute_aurc,
    "eaurc": metrics.compute_eaurc,
}
results = {key: [] for key in failure_metric_fns}

for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    pred_y_test = np.argmax(model.predict(ds_test.batch(256), verbose=0), axis=1)

    base_path = f'{run_dir}/calibration/'
    pvi_dir = f'{base_path}/pvi/training_from_scratch'
    pvi = np.load(f'{pvi_dir}/pvi_class_test.npy')
    opt_temp = np.load(f'{pvi_dir}/pvi_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{pvi_dir}/pvi_opt_weights_ets_nll.npy')

    scores_class = apply_ets(pvi, opt_temp, opt_weights, num_classes)
    # Confidence = calibrated probability of the predicted class.
    scores_test = np.array([row[p] for row, p in zip(scores_class, pred_y_test)])

    for key, fn in failure_metric_fns.items():
        results[key].append(fn(scores_test, pred_y_test, true_y_test))

# (label, results key, display scale)
for label, key, scale in (
    ("AUROC", "auroc", 100),
    ("AUPRC (success)", "auprc_success", 100),
    ("AUPRC (error)", "auprc_error", 100),
    ("FPR at 95% TPR", "fpr_at_95tpr", 100),
    ("AURC", "aurc", 1000),
    ("EAURC", "eaurc", 1000),
):
    print(f"{label:<16}: {utils.format_ci(results[key], scale=scale)}")

In [None]:
# Failure-prediction metrics for softmax scores calibrated with Parameterized Temperature
# Scaling (PTS), over the 10 trained runs.
method = 'softmax PTS'
print(f'Method: {method}')
results = {
    "auroc": [],
    "fpr_at_95tpr": [],
    "auprc_success": [],
    "auprc_error": [],
    "aurc": [],
    "eaurc": []
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Single forward pass: the same output feeds both the predicted labels and the
    # calibrator (previously the test set was run through the model twice per run).
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    # The ETS temperature/weights this cell used to load were never used by PTS, so they are
    # no longer read (which also removes a needless dependency on those files).

    # epochs=0: presumably no fitting happens here because the calibrator state is restored
    # with .load() right below — TODO confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=num_classes,  # was hard-coded to 10 (== num_classes for MNIST)
        top_k_logits=5
    )
    pts_loaded.load(path=f'{base_path}calibration_model/')
    scores_class = pts_loaded.calibrate(logits)
    # Confidence = calibrated probability of the predicted class.
    scores_test = np.asarray(scores_class)[np.arange(len(pred_y_test)), pred_y_test]

    results["auroc"].append(metrics.compute_auroc(scores_test, pred_y_test, true_y_test))
    results["auprc_success"].append(metrics.compute_auprc_success(scores_test, pred_y_test, true_y_test))
    results["auprc_error"].append(metrics.compute_auprc_error(scores_test, pred_y_test, true_y_test))
    results["fpr_at_95tpr"].append(metrics.compute_fpr_at_95tpr(scores_test, pred_y_test, true_y_test))
    results["aurc"].append(metrics.compute_aurc(scores_test, pred_y_test, true_y_test))
    results["eaurc"].append(metrics.compute_eaurc(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

In [None]:
# Failure-prediction metrics for per-class PVI scores calibrated with Parameterized
# Temperature Scaling (PTS), over the 10 trained runs.
method = 'PVI PTS'
print(f'Method: {method}')
results = {
    "auroc": [],
    "fpr_at_95tpr": [],
    "auprc_success": [],
    "auprc_error": [],
    "aurc": [],
    "eaurc": []
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # The classifier is only needed for the predicted labels.
    pred_y_test = np.argmax(model.predict(ds_test.batch(256), verbose=0), axis=1)

    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')

    # epochs=0: presumably no fitting happens here because the calibrator state is restored
    # with .load() right below — TODO confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=num_classes,  # was hard-coded to 10 (== num_classes for MNIST)
        top_k_logits=5
    )
    # NOTE(review): this restores the same `calibration/calibration_model/` checkpoint that the
    # softmax-PTS cell applies to model logits, whereas PVI-ETS uses PVI-specific files under
    # `pvi/training_from_scratch/`. Confirm this calibrator was fitted on PVI scores; otherwise
    # point it at a PVI-specific checkpoint.
    pts_loaded.load(path=f'{base_path}calibration_model/')
    scores_class = pts_loaded.calibrate(pvi)
    # Confidence = calibrated probability of the predicted class.
    scores_test = np.asarray(scores_class)[np.arange(len(pred_y_test)), pred_y_test]

    results["auroc"].append(metrics.compute_auroc(scores_test, pred_y_test, true_y_test))
    results["auprc_success"].append(metrics.compute_auprc_success(scores_test, pred_y_test, true_y_test))
    results["auprc_error"].append(metrics.compute_auprc_error(scores_test, pred_y_test, true_y_test))
    results["fpr_at_95tpr"].append(metrics.compute_fpr_at_95tpr(scores_test, pred_y_test, true_y_test))
    results["aurc"].append(metrics.compute_aurc(scores_test, pred_y_test, true_y_test))
    results["eaurc"].append(metrics.compute_eaurc(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

### Calibration

In [24]:
def get_scores_for_calibration(conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name):
    """Return ``(scores_class, scores_test)`` for one confidence method and one run.

    ``scores_class`` holds per-class probabilities; ``scores_test`` holds, per sample, the
    entry of ``scores_class`` at the predicted class (``pred_y_test``).

    Accepted ``conf_method`` values:
      * ``'softmax'`` — classifier softmax via ``methods.softmax_prob`` / ``max_softmax_prob``;
      * ``'softmax_temp_scaling_<metric>'`` — same, with the stored optimal temperature;
      * ``'pmi'``, ``'psi'``, ``'pvi'``, ``'pvi_best'`` — stored per-class scores passed
        row-wise through ``utils.softmax``;
      * ``'<pmi|psi|pvi|pvi_best>_temp_scaling_<metric>'`` — same, each row first divided by
        the stored optimal temperature.

    Raises:
        ValueError: for any other ``conf_method``.
    """
    calib_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'

    if conf_method == 'softmax':
        return methods.softmax_prob(model, ds_test), methods.max_softmax_prob(model, ds_test)

    if conf_method.startswith('softmax_temp_scaling'):
        suffix = conf_method.split('_')[-1]
        temperature = np.load(f'{calib_dir}/softmax_opt_temp_{suffix}.npy')
        return (methods.softmax_prob(model, ds_test, temperature),
                methods.max_softmax_prob(model, ds_test, temperature))

    # Score name -> (sub-directory under calibration/, per-class test-score file).
    sources = {
        'pmi': ('pmi/separable_variational_f_js', 'pmi_output_class_test.npy'),
        'psi': ('psi/gaussian', 'psi_output_class_500_projs_test.npy'),
        'pvi': ('pvi/training_from_scratch', 'pvi_class_test.npy'),
        'pvi_best': ('pvi/training_from_scratch', 'pvi_class_best_test.npy'),
    }

    if conf_method in sources:
        base_method, temperature = conf_method, 1.0
    elif conf_method.startswith(tuple(f'{name}_temp_scaling' for name in sources)):
        tokens = conf_method.split('_')
        # 'pvi_best_temp_scaling_x' keeps two leading tokens; the others keep one.
        base_method = '_'.join(tokens[:2]) if 'best' in tokens else tokens[0]
        subdir = sources[base_method][0]
        temperature = float(np.load(f'{calib_dir}/{subdir}/{base_method}_opt_temp_{tokens[-1]}.npy'))
    else:
        raise ValueError(f"Unknown confidence method: {conf_method}")

    subdir, class_file = sources[base_method]
    raw_scores = np.load(f'{calib_dir}/{subdir}/{class_file}')
    scores_class = np.array([utils.softmax(row / temperature) for row in raw_scores])
    scores_test = np.array([row[pred] for row, pred in zip(scores_class, pred_y_test)])
    return scores_class, scores_test

def evaluate_calibration(ds_test, true_y_test, conf_method, n_runs=10, n_bins=15):
    """Compute calibration metrics for one confidence method across ``n_runs`` trained models.

    For each run the classifier is rebuilt with ``create_model()`` and its saved weights are
    loaded. Test predictions are the argmax of ``model.predict``, and
    ``get_scores_for_calibration`` supplies the per-class probabilities (``scores_class``)
    and the predicted-class confidence (``scores_test``).

    Args:
        ds_test: test dataset; batched here with ``.batch(256)`` (presumably unbatched on input).
        true_y_test: ground-truth labels aligned with ``ds_test``.
        conf_method: confidence-method key understood by ``get_scores_for_calibration``.
        n_runs: number of runs to evaluate (folders ``run_1`` ... ``run_{n_runs}``).
        n_bins: number of bins passed to every binned metric (previously hard-coded to 15).

    Returns:
        dict mapping metric name -> list with one value per run.

    Note:
        Reads the notebook globals ``model_name``, ``dataset_name``, ``num_classes`` and
        ``create_model``.
    """
    results = {
        "ece": [],
        "cc_ece": [],
        "mce": [],
        "ace": [],
        "sce": [],
        "ada_ece": [],
        "ada_sce": [],
        "cc_ada_ece": [],
        "cc_ada_sce": [],
        "cc_ada_sce_rms": [],
        "cw_ece": [],
        "cw_sce": [],
        "cw_ada_ece": [],
        "cw_ada_sce": [],
        "cw_ada_ece_rms": [],
        "cw_ada_sce_rms": [],
        "nll": [],
        "bs": [],
        "sharpness": [],
    }

    for run in range(n_runs):
        tf.keras.utils.set_random_seed(run + 10)  # same seed offset as the other evaluation cells
        model = create_model()
        model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

        pred_y_test = np.argmax(model.predict(ds_test.batch(256), verbose=0), axis=1)
        scores_class, scores_test = get_scores_for_calibration(
            conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name
        )

        # Metrics on the predicted-class confidence.
        results["ece"].append(metrics.compute_ece(scores_test, pred_y_test, true_y_test, n_bins))
        results["cc_ece"].append(metrics.compute_cc_ece(scores_test, pred_y_test, true_y_test, n_bins))
        results["mce"].append(metrics.compute_mce(scores_test, pred_y_test, true_y_test, n_bins))
        results["ace"].append(metrics.compute_ace(scores_test, pred_y_test, true_y_test, n_bins))
        results["ada_ece"].append(metrics.compute_adaece(scores_test, pred_y_test, true_y_test, n_bins))
        results["cc_ada_ece"].append(metrics.compute_cc_adaece(scores_test, pred_y_test, true_y_test, n_bins))

        # Metrics on the full per-class distribution.
        results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, n_bins))
        results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ece"].append(metrics.compute_cw_ece(scores_class, true_y_test, num_classes, n_bins))
        results["cw_sce"].append(metrics.compute_cw_sce(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_ece"].append(metrics.compute_cw_adaece(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_sce"].append(metrics.compute_cw_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_ece_rms"].append(metrics.compute_cw_adaece_rms(scores_class, true_y_test, num_classes, n_bins))
        # NOTE(review): the "cw_ada_sce_rms" entry is computed with compute_cw_adaece_rms, i.e. the
        # same function as "cw_ada_ece_rms", so the two rows always print identical values. This
        # looks like a copy-paste slip for compute_cw_adasce_rms; switch once that function is
        # confirmed to exist in src.metrics.
        results["cw_ada_sce_rms"].append(metrics.compute_cw_adaece_rms(scores_class, true_y_test, num_classes, n_bins))

        # Unbinned scores.
        results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
        results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))
        results["sharpness"].append(metrics.compute_sharpness(scores_class))

    return results

In [25]:
# Confidence methods to evaluate: each base score raw, then with NLL-fitted temperature scaling.
_base_conf_methods = ['softmax', 'pmi', 'psi', 'pvi', 'pvi_best']
methods_list = _base_conf_methods + [f'{name}_temp_scaling_nll' for name in _base_conf_methods]

# (printed label, key of the dict returned by evaluate_calibration); all shown x100.
calibration_report_rows = (
    ("ECE:", "ece"),
    ("CC-ECE:", "cc_ece"),
    ("MCE:", "mce"),
    ("ACE:", "ace"),
    ("SCE:", "sce"),
    ("Ada-ECE:", "ada_ece"),
    ("Ada-SCE:", "ada_sce"),
    ("CC-Ada-ECE:", "cc_ada_ece"),
    ("CC-Ada-SCE:", "cc_ada_sce"),
    ("CC-Ada-SCE-RMS:", "cc_ada_sce_rms"),
    ("CW-ECE:", "cw_ece"),
    ("CW-SCE:", "cw_sce"),
    ("CW-Ada-ECE:", "cw_ada_ece"),
    ("CW-Ada-SCE:", "cw_ada_sce"),
    ("CW-Ada-ECE-RMS:", "cw_ada_ece_rms"),
    ("CW-Ada-SCE-RMS:", "cw_ada_sce_rms"),
    ("NLL:", "nll"),
    ("Brier Score:", "bs"),
    ("Sharpness:", "sharpness"),
)

for method in methods_list:
    print(f'Method: {method}')
    results = evaluate_calibration(ds_test, true_y_test, conf_method=method, n_runs=10)
    for label, key in calibration_report_rows:
        print(f"{label:<16}{utils.format_ci(results[key], scale=100)}")

Method: softmax
ECE:            1.26 (0.08)
CC-ECE:         1.46 (0.09)
MCE:            0.78 (0.08)
ACE:            17.37 (1.82)
SCE:            0.30 (0.02)
Ada-ECE:        1.22 (0.09)
Ada-SCE:        0.16 (0.01)
CC-Ada-ECE:     1.23 (0.09)
CC-Ada-SCE:     0.15 (0.01)
CC-Ada-SCE-RMS: 1.79 (0.11)
CW-ECE:         0.30 (0.02)
CW-SCE:         0.30 (0.02)
CW-Ada-ECE:     0.08 (0.01)
CW-Ada-SCE:     0.08 (0.01)
CW-Ada-ECE-RMS: 0.24 (0.04)
CW-Ada-SCE-RMS: 0.24 (0.04)
NLL:            9.19 (0.65)
Brier Score:    3.26 (0.09)
Sharpness:      1.97 (0.21)
Method: pmi
ECE:            0.49 (0.12)
CC-ECE:         1.01 (0.08)
MCE:            0.21 (0.09)
ACE:            8.11 (1.36)
SCE:            0.19 (0.01)
Ada-ECE:        0.37 (0.14)
Ada-SCE:        0.11 (0.01)
CC-Ada-ECE:     0.60 (0.08)
CC-Ada-SCE:     0.25 (0.02)
CC-Ada-SCE-RMS: 2.76 (0.15)
CW-ECE:         0.19 (0.01)
CW-SCE:         0.19 (0.01)
CW-Ada-ECE:     0.07 (0.01)
CW-Ada-SCE:     0.07 (0.01)
CW-Ada-ECE-RMS: 0.21 (0.03)
CW-Ada-SCE-RMS: 0.2

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Calibrate scores with Ensemble Temperature Scaling (ETS).

    Returns the mixture
        w[0] * softmax(logits / opt_temp) + w[1] * softmax(logits) + w[2] * uniform(n_class)
    where ``w`` is ``opt_weights`` renormalised to sum to one. Softmax is taken along axis 1.

    Args:
        logits: 2-D array (samples x classes) of scores; in this notebook either model
            logits or per-class PVI values.
        opt_temp: fitted temperature (scalar or array broadcastable against ``logits``).
        opt_weights: mixture weights; only the first three entries are used.
        n_class: number of classes for the uniform component.

    Returns:
        float64 array with the same shape as ``logits``. Rows sum to one when ``n_class``
        equals the number of columns.

    NOTE(review): identical copies of this function are defined in several cells; the most
    recently executed one wins. Consider keeping a single definition.
    """
    logits = np.asarray(logits, dtype=np.float64)

    def _softmax_rows(z):
        # Subtracting the row maximum leaves softmax unchanged but keeps exp() from
        # overflowing to inf (which made the naive form return inf/inf = NaN) for large
        # or temperature-sharpened scores.
        z = z - np.max(z, axis=1, keepdims=True)
        exp_z = np.exp(z)
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)

    p0 = _softmax_rows(logits / opt_temp)   # temperature-scaled component
    p1 = _softmax_rows(logits)              # uncalibrated component
    p2 = np.full_like(p0, 1.0 / n_class)    # uniform component
    weights = np.asarray(opt_weights, dtype=np.float64)
    w = weights / np.sum(weights)           # renormalise in case the weights do not sum to 1
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Distribution-level calibration metrics for softmax scores calibrated with ETS, over 10 runs.
method = 'softmax ETS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Only the raw model outputs are needed: every metric below takes the full calibrated
    # distribution, so the separate (never used) argmax prediction pass was dropped.
    logits = model.predict(ds_test.batch(512), verbose=0)

    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    opt_temp = np.load(f'{base_path}/softmax_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/softmax_opt_weights_ets_nll.npy')

    scores_class = apply_ets(logits, opt_temp, opt_weights, num_classes)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Calibrate scores with Ensemble Temperature Scaling (ETS).

    Returns the mixture
        w[0] * softmax(logits / opt_temp) + w[1] * softmax(logits) + w[2] * uniform(n_class)
    where ``w`` is ``opt_weights`` renormalised to sum to one. Softmax is taken along axis 1.

    Args:
        logits: 2-D array (samples x classes) of scores; in this notebook either model
            logits or per-class PVI values.
        opt_temp: fitted temperature (scalar or array broadcastable against ``logits``).
        opt_weights: mixture weights; only the first three entries are used.
        n_class: number of classes for the uniform component.

    Returns:
        float64 array with the same shape as ``logits``. Rows sum to one when ``n_class``
        equals the number of columns.

    NOTE(review): identical copies of this function are defined in several cells; the most
    recently executed one wins. Consider keeping a single definition.
    """
    logits = np.asarray(logits, dtype=np.float64)

    def _softmax_rows(z):
        # Subtracting the row maximum leaves softmax unchanged but keeps exp() from
        # overflowing to inf (which made the naive form return inf/inf = NaN) for large
        # or temperature-sharpened scores.
        z = z - np.max(z, axis=1, keepdims=True)
        exp_z = np.exp(z)
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)

    p0 = _softmax_rows(logits / opt_temp)   # temperature-scaled component
    p1 = _softmax_rows(logits)              # uncalibrated component
    p2 = np.full_like(p0, 1.0 / n_class)    # uniform component
    weights = np.asarray(opt_weights, dtype=np.float64)
    w = weights / np.sum(weights)           # renormalise in case the weights do not sum to 1
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Distribution-level calibration metrics for per-class PVI scores calibrated with ETS, over 10 runs.
method = 'PVI ETS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    # The PVI scores and ETS parameters are read from disk, and none of the metrics below uses
    # the classifier's predictions. So the model is no longer built, loaded and run here;
    # previously that only produced an unused `pred_y_test`.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')
    opt_temp = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_weights_ets_nll.npy')

    scores_class = apply_ets(pvi, opt_temp, opt_weights, num_classes)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
# Distribution-level calibration metrics for softmax scores calibrated with Parameterized
# Temperature Scaling (PTS), over 10 runs.
# (This cell was previously labelled 'softmax ETS' although it applies the PTS calibrator.)
method = 'softmax PTS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Only the raw model outputs are needed; the separate argmax prediction pass was never used.
    logits = model.predict(ds_test.batch(512), verbose=0)

    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    # epochs=0: presumably no fitting happens here because the calibrator state is restored
    # with .load() right below — TODO confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=num_classes,  # was hard-coded to 10 (== num_classes for MNIST)
        top_k_logits=5
    )
    pts_loaded.load(path=f'{base_path}calibration_model/')
    scores_class = pts_loaded.calibrate(logits)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
# Distribution-level calibration metrics for per-class PVI scores calibrated with Parameterized
# Temperature Scaling (PTS), over 10 runs.
method = 'PVI PTS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    # The PVI scores are read from disk, and none of the metrics below uses the classifier's
    # predictions. So the model is no longer built, loaded and run here; previously that only
    # produced an unused `pred_y_test`.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')

    # epochs=0: presumably no fitting happens here because the calibrator state is restored
    # with .load() right below — TODO confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=num_classes,  # was hard-coded to 10 (== num_classes for MNIST)
        top_k_logits=5
    )
    # NOTE(review): this restores the `calibration/calibration_model/` checkpoint that the
    # softmax-PTS cells apply to model logits, whereas PVI-ETS uses PVI-specific files under
    # `pvi/training_from_scratch/`. Confirm this calibrator was fitted on PVI scores; otherwise
    # point it at a PVI-specific checkpoint.
    pts_loaded.load(path=f'{base_path}calibration_model/')
    scores_class = pts_loaded.calibrate(pvi)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")