In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Dropout, MaxPooling2D, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import tensorflow_datasets as tfds

import os
import pickle
import numpy as np
from tqdm import tqdm

from src.pmi_estimators import train_critic_model, neural_pmi
from src.psi_estimators import psi_gaussian_train, psi_gaussian_val_class
from src.pvi_estimators import train_pvi_null_model, neural_pvi_class, neural_pvi_ensemble_class
import src.utils as utils
import src.metrics as metrics
import src.methods as methods
import src.temp_scaling as temp_scaling

2025-06-05 11:08:55.098565: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-05 11:08:55.098634: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-05 11:08:55.099863: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-05 11:08:55.107218: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Configure the first visible GPU: enable on-demand memory growth and expose it
# as a single virtual device with a fixed memory limit.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPU found.")
else:
    primary_gpu = gpus[0]
    try:
        tf.config.experimental.set_memory_growth(primary_gpu, True)
        memory_limit = 6 * 1024  # 6GB in MB
        virtual_device = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memory_limit)
        tf.config.experimental.set_virtual_device_configuration(primary_gpu, [virtual_device])
        print("GPU memory configuration set successfully.")
    except RuntimeError as e:
        # Only RuntimeError is reported instead of raised; any other exception
        # type still propagates.
        print(e)

GPU memory configuration set successfully.


In [3]:
# Experiment identifiers; both are embedded in the results paths used by later cells.
model_name = 'cnn'
dataset_name = 'fashion_mnist'

# Load the dataset named by `dataset_name` (rather than repeating the literal) so
# the data that is loaded can never diverge from the name used in result paths.
# Splits: first 85% of the official train split for training, the remaining 15%
# for validation, and the official test split for testing.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    dataset_name,
    split=['train[:85%]', 'train[85%:]', 'test'],
    data_dir='../tensorflow_datasets/',
    shuffle_files=False,   # keep a fixed example order across executions
    as_supervised=True,    # yield (image, label) tuples
    with_info=True
)

num_classes = 10

def preprocess(image, label):
    """Convert one (image, label) example into model-ready tensors.

    The image is cast to float32 and divided by 255; the label is one-hot
    encoded with depth `num_classes`.

    Returns:
        Tuple `(scaled_image, one_hot_label)`.
    """
    scaled_image = tf.cast(image, tf.float32) / 255.0
    return scaled_image, tf.one_hot(label, depth=num_classes)

ds_train = ds_train.map(preprocess)
ds_val = ds_val.map(preprocess)
ds_test = ds_test.map(preprocess)

# The three datasets are intentionally left UNBATCHED (one example per element).
# Any consumer that needs batches -- Keras `fit`/`evaluate`, feature extraction,
# etc. -- must call `.batch(...)` on them itself.


def _class_indices(dataset, batch_size=1024):
    """Return the integer class index of every example in `dataset`, in order.

    Args:
        dataset: unbatched dataset of (image, one_hot_label) pairs.
        batch_size: number of examples pulled per iteration; only affects speed.

    Returns:
        1-D NumPy integer array with one entry per example.
    """
    # Reading labels in large batches avoids one eager Python iteration per example.
    return np.concatenate(
        [np.argmax(one_hot.numpy(), axis=1) for _, one_hot in dataset.batch(batch_size)]
    )


true_y_train = _class_indices(ds_train)
true_y_val = _class_indices(ds_val)
true_y_test = _class_indices(ds_test)

2025-06-05 11:09:01.793159: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6144 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:0f:00.0, compute capability: 8.0


In [4]:
def create_model():
    """Build the (uncompiled) CNN classifier.

    Architecture: three stages of [Conv-Conv-MaxPool-Dropout(0.3)] with 32, 64
    and 128 filters, followed by Flatten, Dense(128, relu) and a linear
    Dense(10) layer. The final layer is linear, so the model outputs logits.

    Returns:
        A `tf.keras.Sequential` model taking inputs of shape (28, 28, 1).
    """
    conv_kwargs = dict(activation='relu', padding='same')
    return tf.keras.Sequential([
        # Stage 1: 32 filters
        Conv2D(32, (3, 3), input_shape=(28,28,1), **conv_kwargs),
        Conv2D(32, (3, 3), **conv_kwargs),
        MaxPooling2D((2, 2)),
        Dropout(0.3),
        # Stage 2: 64 filters
        Conv2D(64, (3, 3), **conv_kwargs),
        Conv2D(64, (3, 3), **conv_kwargs),
        MaxPooling2D((2, 2)),
        Dropout(0.3),
        # Stage 3: 128 filters
        Conv2D(128, (3, 3), **conv_kwargs),
        Conv2D(128, (3, 3), **conv_kwargs),
        MaxPooling2D((2, 2)),
        Dropout(0.3),
        # Classification head
        Flatten(),
        Dense(128, activation='relu'),
        Dense(10, activation='linear'),
    ])

### Train Model

In [5]:
# Train 10 independent models; each run writes its weights and training history to
# ../results/PI_Explainability/<model>_<dataset>/run_<k>/.

# Mini-batch size for supervised training.
# NOTE(review): 128 is taken from the commented-out pipeline in the data-loading
# cell -- confirm it matches the value used for the reported runs.
train_batch_size = 128

for run in range(10):
    # Seed Python, NumPy and TensorFlow per run (same `run+10` scheme as the
    # evaluation and PMI cells) so weight init, dropout and shuffling are repeatable.
    tf.keras.utils.set_random_seed(run+10)

    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    weights_dir = exp_name+'/saved_models'
    # Create the output directories BEFORE training so a bad path fails immediately
    # rather than after a full training run. makedirs also creates `exp_name`.
    if not os.path.exists(weights_dir):
        print("Making directory", weights_dir)
    os.makedirs(weights_dir, exist_ok=True)

    # `ds_train` / `ds_val` hold one example per element; Keras `fit` needs batches,
    # so batch here. Built inside the loop so the shuffle picks up this run's seed.
    train_batches = ds_train.shuffle(1000).batch(train_batch_size).prefetch(tf.data.AUTOTUNE)
    val_batches = ds_val.batch(train_batch_size).prefetch(tf.data.AUTOTUNE)

    model = create_model()

    # The model emits logits (final layer is linear), hence from_logits=True.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])

    # NOTE(review): both callbacks monitor val_accuracy with patience=20, so the
    # learning rate is reduced on the same epoch that early stopping ends training
    # (visible in the logs); the scheduler therefore never influences training.
    # Left unchanged to keep the experimental setup -- consider a shorter LR patience.
    lr_scheduler = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=20, verbose=1)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=20, restore_best_weights=True, verbose=1)
    history = model.fit(train_batches, validation_data=val_batches, epochs=300, callbacks=[lr_scheduler, early_stop])

    model.save_weights(f'{weights_dir}/trained_weights.h5')
    # NOTE(review): this pickles the whole History callback object, not just the
    # metrics dict. `history.history` would be smaller and more portable, but the
    # format is kept so existing readers of history.pickle continue to work.
    with open(f'{exp_name}/history.pickle', 'wb') as f:
        pickle.dump(history, f, protocol=pickle.HIGHEST_PROTOCOL)

Epoch 1/300


2025-06-05 10:26:55.910869: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2025-06-05 10:26:56.137040: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100
2025-06-05 10:26:57.176758: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fa6dd2eb240 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-06-05 10:26:57.176842: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0
2025-06-05 10:26:57.182794: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1749119217.316375 1051235 device_compiler.h:186] Compiled cluster using XL

Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:31:33.190408: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_1/dropout_3/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:35:40.499919: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_2/dropout_6/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:40:08.681334: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_3/dropout_9/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:46:02.415305: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_4/dropout_12/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 71: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Restoring model weigh

2025-06-05 10:48:31.430775: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_5/dropout_15/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:51:22.560858: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_6/dropout_18/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 10:56:28.505825: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_7/dropout_21/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

2025-06-05 11:00:36.558291: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_8/dropout_24/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 60: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Restoring model weights from the end of the best epoch: 40.
Epoch 60: early stopping
Epoch 1/300


2025-06-05 11:02:41.173816: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_9/dropout_27/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 7

In [6]:
# Evaluate the saved weights of every run on the train / validation / test splits
# and report the mean error (100 - accuracy, in %) with its standard deviation.

# `ds_train` / `ds_val` / `ds_test` hold one example per element, but Keras
# `evaluate` needs batches. Batch once, outside the loop; the batch size only
# affects speed.
eval_batch_size = 256
eval_batches = {
    'train': ds_train.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE),
    'validation': ds_val.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE),
    'test': ds_test.batch(eval_batch_size).prefetch(tf.data.AUTOTUNE),
}

train_acc = []
val_acc = []
test_acc = []
accuracy_by_split = {'train': train_acc, 'validation': val_acc, 'test': test_acc}

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])
    for split, batches in eval_batches.items():
        # evaluate() returns [loss, accuracy] for this compile config; keep the accuracy.
        accuracy_by_split[split].append(model.evaluate(batches, verbose=1)[1])

for split, accs in accuracy_by_split.items():
    print(f'Average {split} error: {(100-np.mean(accs)*100):.2f} ({(np.std(accs)*100):.2f})')

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10
Average train error: 1.80 (1.49)
Average validation error: 6.07 (0.19)
Average test error: 6.99 (0.36)


### PMI

In [5]:
from src.pmi_estimators import train_critic_model, neural_pmi
from tqdm import tqdm

def _classwise_pmi(dataset, encoder, critic, n_classes, estimator='variational_f_js',
                   encode_batch_size=128, pmi_batch_size=1024):
    """Score every example in `dataset` against every class with the PMI critic.

    Args:
        dataset: unbatched dataset of (image, label) pairs; labels are ignored.
        encoder: model mapping a batch of images to the features fed to the critic.
        critic: trained PMI model passed through to `neural_pmi`.
        n_classes: number of classes; one one-hot label vector is tried per class.
        estimator: estimator name forwarded to `neural_pmi`.
        encode_batch_size: batch size used when running `encoder`.
        pmi_batch_size: batch size used when calling `neural_pmi`.

    Returns:
        `np.array(per_class_scores).T`, where `per_class_scores[k]` is the list of
        PMI values for class k over all examples. Presumably this has shape
        (num_samples, n_classes) if `neural_pmi` returns one scalar per example --
        TODO confirm against `src.pmi_estimators`.
    """
    encoded_x = np.concatenate(
        [encoder(x).numpy() for x, _ in dataset.batch(encode_batch_size)]
    )
    num_samples = encoded_x.shape[0]

    pmi_class = []
    for k in range(n_classes):
        # Pair every example with the same hypothetical label k.
        y_k = tf.one_hot(tf.fill([num_samples], k), depth=n_classes)
        pmi_list = []
        for i in tqdm(range(0, num_samples, pmi_batch_size), desc=f"Computing PMI for class {k+1}"):
            pmi = neural_pmi(encoded_x[i:i+pmi_batch_size], y_k[i:i+pmi_batch_size], critic, estimator=estimator)
            pmi_list += np.array(pmi).tolist()
        pmi_class.append(pmi_list)
    return np.array(pmi_class).T


for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/separable_variational_f_js'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
    os.makedirs(exp_name, exist_ok=True)

    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    # Feature extractor exposing the output of the classifier's last layer.
    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)

    ##############################################################
    #
    # Train PMI Model
    #
    ##############################################################

    print(f'Training PMI model...')
    # Pre-compute (features, label) pairs once and cache them, so the classifier is
    # not re-run on every critic epoch.
    ds_activity_trn = ds_train.batch(128).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)
    ds_activity_val = ds_val.batch(128).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)
    train_critic_model(ds_activity_trn, ds_activity_val, critic='separable', estimator='variational_f_js', epochs=200, save_path=f'{exp_name}/pmi_output_model')

    ##############################################################
    #
    # Compute PMI for all validation and test samples
    #
    ##############################################################

    pmi_model = tf.keras.models.load_model(f'{exp_name}/pmi_output_model')
    n_classes = 10

    print(f'Computing PMI for all validation samples and for all classes...')
    np.save(f'{exp_name}/pmi_output_class_val.npy', _classwise_pmi(ds_val, int_model, pmi_model, n_classes))

    print(f'Computing PMI for all test samples and for all classes...')
    np.save(f'{exp_name}/pmi_output_class_test.npy', _classwise_pmi(ds_test, int_model, pmi_model, n_classes))

Run: 1
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]2025-06-05 11:09:14.524985: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fc7f6886f50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-06-05 11:09:14.525038: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0
2025-06-05 11:09:14.530623: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-06-05 11:09:14.572753: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100
I0000 00:00:1749121754.662716 1996598 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:57<12:33,  3.89s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [01:07<13:01,  4.25s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 148.76it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 679.28it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 663.46it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 668.91it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 666.31it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 670.41it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 668.97it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 674.24it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 668.43it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 590.77it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 276.32it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 664.62it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 698.26it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 688.02it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 700.79it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 703.19it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 699.87it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 701.87it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 695.22it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 712.03it/s]


Run: 2
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:55<17:24,  5.36s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:57<12:54,  3.99s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [01:07<12:57,  4.23s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 153.87it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 556.11it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 664.43it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 666.22it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 676.94it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 671.95it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 674.70it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 675.01it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 675.60it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 672.45it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 276.85it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 676.93it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 691.08it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 695.71it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 687.21it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 706.39it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 708.86it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 696.31it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 703.97it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 712.01it/s]


Run: 3
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 9/200 [00:59<05:52,  1.85s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|▉         | 19/200 [01:10<11:12,  3.72s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 155.57it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 659.72it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 670.99it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 679.63it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 668.30it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 677.23it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 678.47it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 689.87it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 684.46it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 681.84it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 269.97it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 662.14it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 670.51it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 679.96it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 683.61it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 678.85it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 683.27it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 684.25it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 678.80it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 682.99it/s]


Run: 4
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:52<2:54:41, 52.67s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:56<17:47,  5.47s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:58<13:12,  4.08s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 16/200 [01:09<13:14,  4.32s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 152.41it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 644.36it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 666.29it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 676.57it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 676.54it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 675.39it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 678.67it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 677.40it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 666.08it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 656.64it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 271.13it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 654.34it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 660.70it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 659.84it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 666.50it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 664.87it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 648.38it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 674.03it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 673.29it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 670.38it/s]


Run: 5
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:52<2:52:29, 52.01s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:53<1:13:19, 22.22s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 4/200 [00:55<26:19,  8.06s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   7%|▋         | 14/200 [01:06<14:46,  4.76s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 144.89it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 651.08it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 661.89it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 680.40it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 671.36it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 675.63it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 679.33it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 676.57it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 678.43it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 680.01it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 270.15it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 658.74it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 681.00it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 670.68it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 682.29it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 679.65it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 683.14it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 684.85it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 687.47it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 685.22it/s]


Run: 6
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:52<2:53:18, 52.25s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:53<1:13:39, 22.32s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▏         | 3/200 [00:54<41:53, 12.76s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▋         | 13/200 [01:05<15:45,  5.06s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 158.58it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 662.07it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 683.93it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 676.42it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 695.02it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 691.62it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 693.26it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 692.23it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 699.10it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 701.78it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 271.51it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 670.50it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 695.90it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 696.15it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 689.17it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 695.99it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 703.42it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 700.02it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 696.95it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 699.54it/s]


Run: 7
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   3%|▎         | 6/200 [00:57<12:36,  3.90s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:58<09:55,  3.09s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   8%|▊         | 17/200 [01:09<12:29,  4.09s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 149.13it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 651.84it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 679.26it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 673.81it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 673.58it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 671.75it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 674.04it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 675.36it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 672.70it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 676.90it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 273.88it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 656.47it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 675.10it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 677.28it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 676.95it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 683.84it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 674.18it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 680.53it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 673.85it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 681.29it/s]


Run: 8
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   2%|▎         | 5/200 [00:55<17:28,  5.37s/it]  

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▎         | 7/200 [00:58<09:40,  3.01s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 8/200 [00:59<07:56,  2.48s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   4%|▍         | 9/200 [01:00<06:48,  2.14s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:  10%|▉         | 19/200 [01:12<11:28,  3.80s/it]






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 146.85it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 660.57it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 681.94it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 681.75it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 683.41it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 688.85it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 688.21it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 688.47it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 684.80it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 691.58it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 249.12it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 663.14it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 679.76it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 675.07it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 686.24it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 686.85it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 689.15it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 685.37it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 690.30it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 692.69it/s]


Run: 9
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:51<2:52:20, 51.96s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   1%|          | 2/200 [00:53<1:13:17, 22.21s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▌         | 12/200 [01:04<16:44,  5.34s/it] 






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 148.40it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 626.31it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 648.44it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 647.62it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 652.57it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 663.52it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 624.17it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 651.31it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 650.57it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 649.34it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 262.17it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 642.03it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 654.44it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 654.88it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 667.40it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 668.67it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 650.91it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 650.94it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 659.97it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 630.02it/s]


Run: 10
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/pmi/separable_variational_f_js
Training PMI model...


Epochs:   0%|          | 0/200 [00:00<?, ?it/s]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   0%|          | 1/200 [00:52<2:52:39, 52.06s/it]

INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets


INFO:tensorflow:Assets written to: ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/pmi/separable_variational_f_js/pmi_output_model/assets
Epochs:   6%|▌         | 11/200 [01:03<18:05,  5.74s/it] 






Computing PMI for all validation samples and for all classes...


Computing PMI for class 1: 100%|██████████| 9/9 [00:00<00:00, 153.14it/s]
Computing PMI for class 2: 100%|██████████| 9/9 [00:00<00:00, 662.60it/s]
Computing PMI for class 3: 100%|██████████| 9/9 [00:00<00:00, 698.43it/s]
Computing PMI for class 4: 100%|██████████| 9/9 [00:00<00:00, 679.51it/s]
Computing PMI for class 5: 100%|██████████| 9/9 [00:00<00:00, 696.18it/s]
Computing PMI for class 6: 100%|██████████| 9/9 [00:00<00:00, 693.12it/s]
Computing PMI for class 7: 100%|██████████| 9/9 [00:00<00:00, 685.95it/s]
Computing PMI for class 8: 100%|██████████| 9/9 [00:00<00:00, 687.29it/s]
Computing PMI for class 9: 100%|██████████| 9/9 [00:00<00:00, 694.29it/s]
Computing PMI for class 10: 100%|██████████| 9/9 [00:00<00:00, 686.04it/s]


Computing PMI for all test samples and for all classes...


Computing PMI for class 1: 100%|██████████| 10/10 [00:00<00:00, 272.26it/s]
Computing PMI for class 2: 100%|██████████| 10/10 [00:00<00:00, 660.37it/s]
Computing PMI for class 3: 100%|██████████| 10/10 [00:00<00:00, 694.88it/s]
Computing PMI for class 4: 100%|██████████| 10/10 [00:00<00:00, 684.06it/s]
Computing PMI for class 5: 100%|██████████| 10/10 [00:00<00:00, 697.42it/s]
Computing PMI for class 6: 100%|██████████| 10/10 [00:00<00:00, 692.52it/s]
Computing PMI for class 7: 100%|██████████| 10/10 [00:00<00:00, 695.61it/s]
Computing PMI for class 8: 100%|██████████| 10/10 [00:00<00:00, 693.19it/s]
Computing PMI for class 9: 100%|██████████| 10/10 [00:00<00:00, 689.34it/s]
Computing PMI for class 10: 100%|██████████| 10/10 [00:00<00:00, 696.94it/s]


### PSI

In [6]:
N_PROJS = 500  # number of projections; also embedded in the output filenames below


def _collect_outputs(net, dataset, batch_size=256, return_labels=False):
    """Run `net` over `dataset` in batches and stack the per-batch results.

    Parameters
    ----------
    net : callable
        Model applied to each input batch (called as ``net(x_batch)``).
    dataset : tf.data.Dataset
        Unbatched dataset yielding ``(x, y)`` pairs.
    batch_size : int
        Batch size used for the forward passes.
    return_labels : bool
        When True, also collect ``argmax(y, axis=1)`` for every batch.
        # assumes y is one-hot / per-class scores -- TODO confirm against the data pipeline

    Returns
    -------
    (outputs, labels) : tuple
        ``outputs`` is the NumPy concatenation of all batch outputs along axis 0.
        ``labels`` is the matching NumPy array of class indices, or ``None``
        when ``return_labels`` is False.
    """
    output_chunks = []
    label_chunks = []
    # Single pass over the dataset so outputs and labels stay aligned.
    for x_batch, y_batch in dataset.batch(batch_size):
        output_chunks.append(net(x_batch))
        if return_labels:
            label_chunks.append(tf.argmax(y_batch, axis=1))

    outputs = tf.concat(output_chunks, axis=0).numpy()
    labels = tf.concat(label_chunks, axis=0).numpy() if return_labels else None
    return outputs, labels


for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/gaussian'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    # NOTE(review): this exposes the output of the model's last layer; whether that is
    # logits or probabilities depends on create_model() -- confirm.
    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)

    ##############################################################
    #
    # Train PSI Model
    #
    ##############################################################

    x, y = _collect_outputs(int_model, ds_train, return_labels=True)

    print(f'Training PSI model (gaussian)...')
    psi_data = psi_gaussian_train(x, y, n_projs=N_PROJS)
    psi_model_path = f'{exp_name}/gaussian_output_model_{N_PROJS}_projs.npy'
    np.save(psi_model_path, psi_data)

    ##############################################################
    #
    # Compute PSI for all validation and test samples
    #
    ##############################################################

    # Reload the file written just above so the scores are computed from exactly
    # what was persisted. allow_pickle is needed because the payload is stored as
    # a pickled object (presumably a dict); only load files produced by this notebook.
    psi_data = np.load(psi_model_path, allow_pickle=True).item()

    print(f'Computing PSI for all validation samples...')
    x, _ = _collect_outputs(int_model, ds_val)
    psi_class, _ = psi_gaussian_val_class(x, psi_data)
    np.save(f'{exp_name}/psi_output_class_{N_PROJS}_projs_val.npy', np.array(psi_class))

    print(f'Computing PSI for all test samples...')
    x, _ = _collect_outputs(int_model, ds_test)
    psi_class, _ = psi_gaussian_val_class(x, psi_data)
    np.save(f'{exp_name}/psi_output_class_{N_PROJS}_projs_test.npy', np.array(psi_class))

Run: 1
Training PSI model (gaussian)...


Projections: 500it [00:01, 257.68it/s]


Computing PSI for all validation samples...


Projections: 500it [00:02, 217.71it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.64it/s]


Run: 2
Training PSI model (gaussian)...


Projections: 500it [00:01, 271.46it/s]


Computing PSI for all validation samples...


Projections: 500it [00:02, 245.95it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 234.09it/s]


Run: 3
Training PSI model (gaussian)...


Projections: 500it [00:01, 253.22it/s]


Computing PSI for all validation samples...


Projections: 500it [00:02, 245.89it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.86it/s]


Run: 4
Training PSI model (gaussian)...


Projections: 500it [00:01, 273.26it/s]


Computing PSI for all validation samples...


Projections: 500it [00:02, 246.42it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 229.72it/s]


Run: 5
Training PSI model (gaussian)...


Projections: 500it [00:01, 274.79it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 255.89it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 238.43it/s]


Run: 6
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 276.05it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 254.56it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 236.72it/s]


Run: 7
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 274.29it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 254.52it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 233.48it/s]


Run: 8
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 273.53it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 254.94it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 236.92it/s]


Run: 9
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 274.51it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 255.96it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.42it/s]


Run: 10
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_10/calibration/psi/gaussian
Training PSI model (gaussian)...


Projections: 500it [00:01, 275.50it/s]


Computing PSI for all validation samples...


Projections: 500it [00:01, 252.42it/s]


Computing PSI for all test samples...


Projections: 500it [00:02, 235.19it/s]


### PVI

In [8]:
n_runs = 10

# Pair every run with the trained weights of a *different* run (a derangement).
# A dedicated, seeded generator is used so the pairing does not depend on
# whatever state the global NumPy RNG was left in by previously executed cells.
derangement_rng = np.random.default_rng(0)
random_runs = list(range(n_runs))
while any(random_runs[i] == i for i in range(n_runs)):
    derangement_rng.shuffle(random_runs)

# Validation labels and the zero-input ("null") view of the validation set do
# not depend on the run, so they are built once. Labels are read in batches
# rather than one example at a time.
# assumes ds_val yields one-hot / per-class label vectors in a fixed order -- TODO confirm
true_y_val = np.concatenate(
    [tf.argmax(y_batch, axis=1).numpy() for _, y_batch in ds_val.batch(128)]
)
ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))

for run in range(n_runs):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Train PVI Model
    #
    ##############################################################

    # The "PVI model" reuses already-trained weights from the paired run; they are
    # copied next to the other artefacts of this experiment.
    pvi_model = create_model()
    pvi_model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{random_runs[run]+1}/saved_models/trained_weights.h5')
    pvi_model.save_weights(f'{exp_name}/pvi_model_weights.h5')

    # The null model starts from a fresh create_model() and is trained by the project helper.
    untrained_model = create_model()
    train_pvi_null_model(ds_train, untrained_model, epochs=10, save_path=f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Compute PVI for all validation and test samples
    #
    ##############################################################

    # Both models are rebuilt from the weight files saved above, so this half of
    # the loop body only depends on what is on disk.
    pvi_model = create_model()
    pvi_model.load_weights(f'{exp_name}/pvi_model_weights.h5')
    null_model = create_model()
    null_model.load_weights(f'{exp_name}/pvi_null_model_weights.h5')

    # Fit one temperature per model on the validation split; the null model is
    # evaluated on all-zero inputs.
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_test.npy', np.array(pvi_class))

Run: 1
Epoch 1/10


2025-06-05 18:45:18.029283: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_45/dropout_75/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 2
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 3
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 4
Epoch 1/10


2025-06-05 18:46:31.209015: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_57/dropout_111/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 5
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 6
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_6/calibration/pvi/training_from_scratch
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 7
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_7/calibration/pvi/training_from_scratch
Epoch 1/10


2025-06-05 18:47:43.437386: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_69/dropout_147/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 8
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_8/calibration/pvi/training_from_scratch
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 9
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_9/calibration/pvi/training_from_scratch
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 10
Making directory ../results/PI_Explainability/cnn_fashion_mnist/run_10/cal

2025-06-05 18:48:56.057107: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:1021] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_81/dropout_183/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...


In [None]:
for run in range(10):
    print(f'Run: {run+1}')
    # Seed Python, NumPy and TensorFlow per run, as the other per-run cells do,
    # so that the fine-tuning below is reproducible.
    tf.keras.utils.set_random_seed(run+10)
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/finetuned'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Fine-tune the PVI model, starting from this run's trained classifier
    #
    ##############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    pvi_model.compile(optimizer=AdamW(learning_rate=1e-4, weight_decay=1e-4), loss=loss_fn, metrics=['accuracy'])

    # Halve the learning rate after 5 epochs without val_accuracy improvement; stop after 10
    # such epochs and roll back to the best weights seen.
    lr_scheduler = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5, verbose=1)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True, verbose=1)
    pvi_model.fit(ds_train.batch(256), validation_data=ds_val.batch(256), epochs=100, callbacks=[lr_scheduler, early_stop])

    pvi_model.save_weights(f'{exp_name}/pvi_model_weights.h5')

    # The null model is not retrained: its weights are copied over from the
    # training_from_scratch experiment of the same run (that file must already exist).
    untrained_model = create_model()
    untrained_model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch/pvi_null_model_weights.h5')
    untrained_model.save_weights(f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Compute PVI for all validation and test samples
    #
    ##############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'{exp_name}/pvi_model_weights.h5')
    null_model = create_model()
    null_model.load_weights(f'{exp_name}/pvi_null_model_weights.h5')

    # Temperatures are fitted on the validation split: the PVI model sees the real inputs,
    # the null model sees all-zero inputs of the same shape.
    true_y_val = np.argmax([y for x,y in ds_val], axis=1)
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_test.npy', np.array(pvi_class))

In [9]:
# Zero-based index of the run whose trained classifier is used as the PVI model for each run:
# every run uses index 7 (i.e. run_8), except run index 7 itself, which uses index 1 (run_2).
# NOTE(review): presumably run_8 is the best-performing classifier (outputs are named "*_best*")
# and run_2 the runner-up -- confirm, and consider deriving this from stored metrics.
pvi_runs = [1 if i == 7 else 7 for i in range(10)]

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    ##############################################################
    #
    # Select the PVI model (no training happens here): copy the trained
    # classifier weights of run pvi_runs[run]+1 into this run's experiment folder
    #
    ##############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{pvi_runs[run]+1}/saved_models/trained_weights.h5')
    pvi_model.save_weights(f'{exp_name}/pvi_model_best_weights.h5')

    # Null-model training is disabled; the cell loads `pvi_null_model_weights.h5` from
    # `exp_name` below, so that file must already exist (presumably written by the earlier
    # training_from_scratch cell -- verify before running on a fresh results tree).
#     untrained_model = create_model()
#     train_pvi_null_model(ds_train, untrained_model, epochs=10, save_path=f'{exp_name}/pvi_null_model_weights.h5')

    ##############################################################
    #
    # Compute PVI for all validation and test samples
    #
    ##############################################################

    pvi_model = create_model()
    pvi_model.load_weights(f'{exp_name}/pvi_model_best_weights.h5')
    null_model = create_model()
    null_model.load_weights(f'{exp_name}/pvi_null_model_weights.h5')

    # Validation labels as class indices (argmax over axis 1 of the stacked label vectors).
    # `true_y_val` stays in the notebook namespace after this cell and is read by later cells.
    true_y_val = np.argmax([y for x,y in ds_val], axis=1)
    # NLL-optimal temperature for the PVI model, fitted on its validation outputs.
    opt_temp_pvi = temp_scaling.temp_scaling_nll(pvi_model.predict(ds_val.batch(128), verbose=0), true_y_val)
    # The null model is evaluated on all-zero inputs with the same labels.
    ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))
    opt_temp_null = temp_scaling.temp_scaling_nll(null_model.predict(ds_null.batch(128), verbose=0), true_y_val)

    print(f'Computing PVI for all validation samples and for all classes...')
    pvi_class = neural_pvi_class(ds_val.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_best_val.npy', np.array(pvi_class))

    print(f'Computing PVI for all test samples and for all classes...')
    pvi_class = neural_pvi_class(ds_test.batch(128), pvi_model, null_model, opt_temp_pvi, opt_temp_null)
    np.save(f'{exp_name}/pvi_class_best_test.npy', np.array(pvi_class))

Run: 1
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 2
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 3
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 4
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 5
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 6
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 7
Computing PVI for all validation samples and for all classes...
Computing PVI for all test samples and for all classes...
Run: 8
Computing PVI for all validation samples and for all classes...
Computing PVI for all test

### Ensemble PVI

In [None]:
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    scratch_dir = f'{run_dir}/calibration/pvi/training_from_scratch'
    exp_name = f'{run_dir}/calibration/pvi/ensemble_no_training_training_from_scratch'
    if not os.path.exists(exp_name):
        print("Making directory", exp_name)
        os.makedirs(exp_name)

    # ------------------------------------------------------------------
    # Load the two ensemble members (nothing is trained in this cell).
    # Member 1: the run's trained classifier; member 2: the PVI model trained from scratch.
    # Both members are paired with the same null-model weights.
    # ------------------------------------------------------------------
    pvi_model_1 = create_model()
    pvi_model_1.load_weights(f'{run_dir}/saved_models/trained_weights.h5')
    null_model_1 = create_model()
    null_model_1.load_weights(f'{scratch_dir}/pvi_null_model_weights.h5')
    pvi_model_2 = create_model()
    pvi_model_2.load_weights(f'{scratch_dir}/pvi_model_weights.h5')
    null_model_2 = create_model()
    null_model_2.load_weights(f'{scratch_dir}/pvi_null_model_weights.h5')

    # No temperature scaling is applied to the ensemble members here.

    # ------------------------------------------------------------------
    # Compute per-class ensemble PVI on the validation split, then on the test split.
    # ------------------------------------------------------------------
    for split_label, split_tag, split_ds in (('validation', 'val', ds_val), ('test', 'test', ds_test)):
        print(f'Computing PVI for all {split_label} samples and for all classes...')
        pvi_class = []
        for (x_batch, y_batch) in split_ds.batch(256):
            # Each ensemble member receives the same input batch.
            pvi = neural_pvi_ensemble_class([x_batch, x_batch], [pvi_model_1, pvi_model_2], [null_model_1, null_model_2])
            pvi_class.extend(np.array(pvi).tolist())
        np.save(f'{exp_name}/pvi_class_{split_tag}.npy', np.array(pvi_class))

### Temp Scaling

In [10]:
# Ground-truth class indices of the validation set. Computed here, with the same expression
# the PVI cells use, so this cell does not depend on `true_y_val` leaking from another cell.
true_y_val = np.argmax([y for x,y in ds_val], axis=1)

for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    calib_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'
    if not os.path.exists(calib_dir):
        print("Making directory", calib_dir)
        os.makedirs(calib_dir)

    # A single forward pass over the validation set yields both the raw model outputs and
    # the predicted classes (previously the model was evaluated twice for the same values).
    scores = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(scores, axis=1)

    # Fit one temperature per calibration criterion and store each next to the run's results.
    opt_temp = temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val)
    np.save(f'{calib_dir}/softmax_opt_temp_aurc.npy', opt_temp)

    opt_temp = temp_scaling.temp_scaling_nll(scores, true_y_val)
    np.save(f'{calib_dir}/softmax_opt_temp_nll.npy', opt_temp)

    # 15 is passed as the last argument of temp_scaling_ece (presumably the number of
    # confidence bins -- confirm against src.temp_scaling).
    opt_temp = temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15)
    np.save(f'{calib_dir}/softmax_opt_temp_ece.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [None]:
# Fits ensemble temperature scaling (ETS) on the softmax model's validation outputs.
# `true_y_val` and `num_classes` are expected to be defined by earlier cells, and the
# per-run `calibration` directory must already exist.
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # A single forward pass provides both the raw outputs and the predicted classes
    # (previously the model was evaluated twice over the validation set).
    scores = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(scores, axis=1)

    # The fit returns a temperature and a set of mixture weights; both are saved per run.
    opt_temp, opt_weights = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    np.save(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/softmax_opt_temp_ets_nll.npy', opt_temp)
    np.save(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/softmax_opt_weights_ets_nll.npy', opt_weights)

In [None]:
# Fits a parameterized temperature scaling (PTS) calibrator on the softmax model's
# validation outputs. `true_y_val` is expected to be defined by an earlier cell.
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # A single forward pass provides both the raw outputs and the predicted classes.
    scores = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(scores, axis=1)

    pts = temp_scaling.PTSCalibrator(
        epochs=30,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=10,
        top_k_logits=5,
    )

    # Calibrate against the ground-truth labels. Tuning on the model's own predictions
    # (`pred_y_val`) would only teach the calibrator to agree with the model, not to
    # reflect how often it is actually right.
    # NOTE(review): `true_y_val` holds class indices, same shape and dtype as the previous
    # `pred_y_val` argument; confirm PTSCalibrator.tune does not expect one-hot labels.
    pts.tune(logits=scores, labels=true_y_val)
    pts.save(path=f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/calibration_model/')

In [11]:
# Fits AURC-, NLL- and ECE-optimal temperatures for the per-class PMI scores of every run.
# `true_y_val` is read from the notebook namespace (defined by an earlier cell).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pmi/separable_variational_f_js'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # The classifier only supplies the predicted classes; the scores being calibrated
    # are the PMI values stored on disk.
    val_outputs = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_outputs, axis=1)
    scores = np.load(f'{exp_name}/pmi_output_class_val.npy')

    # Criterion -> deferred fit, so each temperature is fitted and saved before the next.
    temperature_fits = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit_temperature in temperature_fits.items():
        opt_temp = fit_temperature()
        np.save(f'{exp_name}/pmi_opt_temp_{criterion}.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [12]:
# Fits AURC-, NLL- and ECE-optimal temperatures for the per-class PSI scores of every run.
# `true_y_val` is read from the notebook namespace (defined by an earlier cell).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/psi/gaussian'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # The classifier only supplies the predicted classes; the scores being calibrated
    # are the PSI values stored on disk.
    val_outputs = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_outputs, axis=1)
    scores = np.load(f'{exp_name}/psi_output_class_500_projs_val.npy')

    # Criterion -> deferred fit, so each temperature is fitted and saved before the next.
    temperature_fits = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit_temperature in temperature_fits.items():
        opt_temp = fit_temperature()
        np.save(f'{exp_name}/psi_opt_temp_{criterion}.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [14]:
# Fits AURC-, NLL- and ECE-optimal temperatures for the per-class PVI scores of every run.
# `true_y_val` is read from the notebook namespace (defined by an earlier cell).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/training_from_scratch'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # The classifier only supplies the predicted classes; the scores being calibrated
    # are the PVI values stored on disk.
    val_outputs = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_outputs, axis=1)
    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    # Criterion -> deferred fit, so each temperature is fitted and saved before the next.
    temperature_fits = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit_temperature in temperature_fits.items():
        opt_temp = fit_temperature()
        np.save(f'{exp_name}/pvi_opt_temp_{criterion}.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [13]:
# Fits AURC-, NLL- and ECE-optimal temperatures for the "best" per-class PVI scores
# (pvi_class_best_val.npy) of every run.
# `true_y_val` is read from the notebook namespace (defined by an earlier cell).
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/training_from_scratch'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    # The classifier only supplies the predicted classes; the scores being calibrated
    # are the PVI values stored on disk.
    val_outputs = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_outputs, axis=1)
    scores = np.load(f'{exp_name}/pvi_class_best_val.npy')

    # Criterion -> deferred fit, so each temperature is fitted and saved before the next.
    temperature_fits = {
        'aurc': lambda: temp_scaling.temp_scaling_aurc(scores, pred_y_val, true_y_val),
        'nll': lambda: temp_scaling.temp_scaling_nll(scores, true_y_val),
        'ece': lambda: temp_scaling.temp_scaling_ece(scores, pred_y_val, true_y_val, 15),
    }
    for criterion, fit_temperature in temperature_fits.items():
        opt_temp = fit_temperature()
        np.save(f'{exp_name}/pvi_best_opt_temp_{criterion}.npy', opt_temp)

Run: 1
Run: 2
Run: 3
Run: 4
Run: 5
Run: 6
Run: 7
Run: 8
Run: 9
Run: 10


In [None]:
# Fits ensemble temperature scaling (ETS) on the stored per-class PVI validation scores.
# `true_y_val` and `num_classes` are read from the notebook namespace.
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10)  # seeds Python, NumPy and TensorFlow
    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'
    exp_name = f'{run_dir}/calibration/pvi/training_from_scratch'
    model = create_model()
    model.load_weights(f'{run_dir}/saved_models/trained_weights.h5')

    val_outputs = model.predict(ds_val.batch(512), verbose=1)
    pred_y_val = np.argmax(val_outputs, axis=1)
    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    # The fit returns a temperature and mixture weights; each is stored in its own file.
    ets_fit = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    opt_temp, opt_weights = ets_fit
    for stem, fitted_value in (('temp', opt_temp), ('weights', opt_weights)):
        np.save(f'{exp_name}/pvi_opt_{stem}_ets_nll.npy', fitted_value)

In [None]:
# Fits ensemble temperature scaling (ETS) and a parameterized temperature scaling (PTS)
# calibrator on the stored per-class PVI validation scores.
# `true_y_val` and `num_classes` are expected to be defined by earlier cells.
for run in range(10):
    print(f'Run: {run+1}')
    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow
    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/training_from_scratch'
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    pred_y_val = np.argmax(model.predict(ds_val.batch(512), verbose=1), axis=1)
    scores = np.load(f'{exp_name}/pvi_class_val.npy')

    # ETS: a temperature plus mixture weights, each saved to its own file.
    opt_temp, opt_weights = temp_scaling.ensemble_temp_scaling_nll(scores, true_y_val, num_classes)
    np.save(f'{exp_name}/pvi_opt_temp_ets_nll.npy', opt_temp)
    np.save(f'{exp_name}/pvi_opt_weights_ets_nll.npy', opt_weights)

    pts = temp_scaling.PTSCalibrator(
        epochs=30,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=128,
        length_logits=10,
        top_k_logits=5,
    )

    # Calibrate against the ground-truth labels. Tuning on the classifier's own predictions
    # (`pred_y_val`) would only teach the calibrator to agree with the model, not to
    # reflect how often it is actually right.
    # NOTE(review): `true_y_val` holds class indices, same shape and dtype as the previous
    # `pred_y_val` argument; confirm PTSCalibrator.tune does not expect one-hot labels.
    pts.tune(logits=scores, labels=true_y_val)
    pts.save(path=f'{exp_name}/calibration_model/')

### Failure Detection

In [17]:
def get_confidence_scores(conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name):
    """Return one confidence score per test sample for the requested method.

    Args:
        conf_method: Method name, optionally suffixed with ``_temp_scaling_<metric>``
            (e.g. ``'pvi_temp_scaling_aurc'``). The suffix selects which fitted
            temperature file (``*_opt_temp_<metric>.npy``) is loaded.
        model: Classifier handed to the ``methods.*`` scoring helpers.
        ds_test: Test dataset handed to the ``methods.*`` scoring helpers.
        pred_y_test: Predicted class index per test sample; used to pick the score
            of the predicted class for the pmi/psi/pvi methods.
        run: Zero-based run index; files are read from ``run_{run+1}``.
        model_name: Model part of the results directory name.
        dataset_name: Dataset part of the results directory name.

    Returns:
        The value returned by the matching ``methods.*`` helper, or, for
        pmi/psi/pvi/pvi_best, a NumPy array holding the temperature-scaled softmax
        score of the predicted class for every sample.

    Raises:
        ValueError: If ``conf_method`` is unknown, or if it names a method that needs
            a fitted temperature but carries no ``_temp_scaling_<metric>`` suffix
            (previously this surfaced as a failed load of ``*_opt_temp_None.npy``).
    """
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'
    metric = conf_method.split('_')[-1] if 'temp_scaling' in conf_method else None
    method_key = conf_method.replace(f'_temp_scaling_{metric}', '') if metric else conf_method

    def _load_opt_temp(path_prefix):
        # Load `<path_prefix>_opt_temp_<metric>.npy`, failing early and clearly when the
        # caller did not say which calibration criterion to use.
        if metric is None:
            raise ValueError(
                f"Confidence method '{conf_method}' requires a '_temp_scaling_<metric>' "
                "suffix (e.g. '_temp_scaling_aurc') to select the fitted temperature"
            )
        return np.load(f'{path_prefix}_opt_temp_{metric}.npy')

    # method -> (experiment sub-directory, file with per-class test scores)
    stored_score_files = {
        'pmi': ('pmi/separable_variational_f_js', 'pmi_output_class_test.npy'),
        'psi': ('psi/gaussian', 'psi_output_class_500_projs_test.npy'),
        'pvi': ('pvi/training_from_scratch', 'pvi_class_test.npy'),
        'pvi_best': ('pvi/training_from_scratch', 'pvi_class_best_test.npy'),
    }

    if method_key == 'softmax':
        # Plain softmax works with or without a fitted temperature.
        if metric:
            opt_temp = _load_opt_temp(f'{base_path}/softmax')
            return methods.max_softmax_prob(model, ds_test, opt_temp)
        return methods.max_softmax_prob(model, ds_test)

    elif method_key in stored_score_files:
        subdir, class_file = stored_score_files[method_key]
        exp_path = f'{base_path}/{subdir}'
        opt_temp = _load_opt_temp(f'{exp_path}/{method_key}')
        scores_class = np.load(f'{exp_path}/{class_file}')
        # Temperature-scaled softmax over the per-class scores, then keep the entry that
        # belongs to the class the classifier predicted.
        scores_class = np.array([utils.softmax(x / opt_temp) for x in scores_class])
        return np.array([score[pred] for score, pred in zip(scores_class, pred_y_test)])

    elif method_key in ('softmax_margin', 'negative_entropy', 'negative_gini'):
        # These helpers share the softmax temperature; the helper in `methods` has the
        # same name as the method key.
        opt_temp = _load_opt_temp(f'{base_path}/softmax')
        return getattr(methods, method_key)(model, ds_test, opt_temp)

    elif method_key == 'max_logits':
        return methods.max_logits(model, ds_test)

    elif method_key == 'logits_margin':
        return methods.logits_margin(model, ds_test)

    elif method_key == 'isotonic_regression':
        # NOTE: `ds_val` and `true_y_val` are read from the notebook namespace, not from
        # the arguments, so they must be defined before this branch is used.
        return methods.isotonic_reg(model, ds_val, ds_test, true_y_val)

    else:
        raise ValueError(f"Unknown confidence method: {conf_method}")


def evaluate_failure_pred(ds_test, true_y_test, conf_method, n_runs=10):
    """Collect failure-prediction metrics of one confidence method over several runs.

    For each run the trained classifier is rebuilt from its saved weights, test
    predictions are made, confidence scores are obtained via
    ``get_confidence_scores`` and every ``metrics.compute_<name>`` function is
    evaluated on them.

    Args:
        ds_test: Unbatched test dataset (batched here with size 256 for prediction).
        true_y_test: Ground-truth labels passed through to the metric functions.
        conf_method: Confidence method name understood by ``get_confidence_scores``.
        n_runs: Number of runs to evaluate (run indices ``0 .. n_runs-1``).

    Returns:
        Dict mapping each metric name to a list with one value per run.

    Note:
        ``create_model``, ``model_name`` and ``dataset_name`` are read from the
        notebook namespace.
    """
    # Key order of the returned dict.
    result_keys = ("auroc", "fpr_at_95tpr", "auprc_success", "auprc_error", "aurc", "eaurc", "naurc")
    results = {key: [] for key in result_keys}
    # Order in which the metrics are evaluated for each run.
    evaluation_order = ("auroc", "auprc_success", "auprc_error", "fpr_at_95tpr", "aurc", "eaurc", "naurc")

    for run in range(n_runs):
        tf.keras.utils.set_random_seed(run + 10)
        model = create_model()
        weights_file = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5'
        model.load_weights(weights_file)

        test_outputs = model.predict(ds_test.batch(256), verbose=0)
        pred_y_test = np.argmax(test_outputs, axis=1)
        scores_test = get_confidence_scores(conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name)

        for key in evaluation_order:
            # Looked up lazily by name: metrics.compute_auroc, metrics.compute_aurc, ...
            compute_metric = getattr(metrics, f"compute_{key}")
            results[key].append(compute_metric(scores_test, pred_y_test, true_y_test))

    return results

In [18]:
# Confidence methods compared in the failure-detection table.
methods_list = ['softmax_temp_scaling_aurc','pmi_temp_scaling_aurc','psi_temp_scaling_aurc','pvi_temp_scaling_aurc',
                'softmax_margin_temp_scaling_aurc', 'max_logits', 'logits_margin', 'negative_entropy_temp_scaling_aurc',
                'negative_gini_temp_scaling_aurc']

# (printed label, key in the results dict, scale factor passed to utils.format_ci)
report_rows = (
    ('AUROC', 'auroc', 100),
    ('AUPRC (success)', 'auprc_success', 100),
    ('AUPRC (error)', 'auprc_error', 100),
    ('FPR at 95% TPR', 'fpr_at_95tpr', 100),
    ('AURC', 'aurc', 1000),
    ('EAURC', 'eaurc', 1000),
    ('NAURC', 'naurc', 1000),
)

for method in methods_list:
    print(f'Method: {method}')
    results = evaluate_failure_pred(ds_test, true_y_test, conf_method=method, n_runs=10)
    for row_label, result_key, scale in report_rows:
        # Labels are left-aligned in a 16-character column so the colons line up.
        print(f"{row_label:<16}: {utils.format_ci(results[result_key], scale=scale)}")

Method: softmax_temp_scaling_aurc
AUROC           : 92.06 (0.45)
AUPRC (success) : 99.30 (0.10)
AUPRC (error)   : 43.61 (1.12)
FPR at 95% TPR  : 47.47 (1.50)
AURC            : 9.14 (1.06)
EAURC           : 6.64 (0.97)
NAURC           : 98.34 (13.60)
Method: pmi_temp_scaling_aurc
AUROC           : 58.65 (0.71)
AUPRC (success) : 94.16 (0.15)
AUPRC (error)   : 15.71 (1.26)
FPR at 95% TPR  : N/A
AURC            : 58.46 (1.53)
EAURC           : 55.95 (1.34)
NAURC           : 830.99 (13.78)
Method: psi_temp_scaling_aurc
AUROC           : 78.46 (0.40)
AUPRC (success) : 97.04 (0.14)
AUPRC (error)   : 29.17 (0.99)
FPR at 95% TPR  : N/A
AURC            : 30.50 (1.45)
EAURC           : 27.99 (1.26)
NAURC           : 415.27 (7.81)
Method: pvi_temp_scaling_aurc
AUROC           : 91.02 (1.81)
AUPRC (success) : 99.18 (0.33)
AUPRC (error)   : 46.06 (3.89)
FPR at 95% TPR  : 46.99 (3.99)
AURC            : 10.26 (3.06)
EAURC           : 7.76 (3.07)
NAURC           : 115.52 (46.30)
Method: softmax_margin_

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Apply ensemble temperature scaling (ETS) to a batch of logits.

    The calibrated distribution is a weighted mix of three components: the
    temperature-scaled softmax, the plain softmax and the uniform distribution.

    Args:
        logits: Array-like of shape (n_samples, n_class) with raw scores.
        opt_temp: Temperature the logits are divided by for the scaled component.
        opt_weights: Three mixture weights, in the order (temperature-scaled softmax,
            plain softmax, uniform). They are renormalised to sum to one.
        n_class: Number of classes, used for the uniform component.

    Returns:
        Array of shape (n_samples, n_class) with the calibrated probabilities.
    """
    def _softmax(z):
        # Shifting each row by its max leaves the softmax unchanged but keeps exp()
        # from overflowing to inf (and the ratio from becoming NaN) for large logits.
        shifted = z - np.max(z, axis=1, keepdims=True)
        exp_z = np.exp(shifted)
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)

    logits = np.asarray(logits)
    p1 = _softmax(logits)             # uncalibrated softmax
    p0 = _softmax(logits / opt_temp)  # temperature-scaled softmax
    p2 = np.ones_like(p0) / n_class   # uniform distribution
    w = opt_weights / np.sum(opt_weights)  # renormalise in case the weights do not sum to 1
    calibrated_probs = w[0] * p0 + w[1] * p1 + w[2] * p2
    return calibrated_probs


# Failure-detection metrics for the ETS-calibrated softmax confidence, over 10 runs.
# Relies on `apply_ets` (defined above in this cell) and on `true_y_test` / `num_classes`
# from earlier cells.
method = 'softmax ETS'
print(f'Method: {method}')
results = {
    "auroc": [],
    "fpr_at_95tpr": [],
    "auprc_success": [],
    "auprc_error": [],
    "aurc": [],
    "eaurc": []
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # One forward pass over the test set: the predicted classes are derived from the same
    # logits that get calibrated (previously the model was evaluated twice).
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    # No trailing slash here, so the joined paths below do not contain '//'.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'
    opt_temp = np.load(f'{base_path}/softmax_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/softmax_opt_weights_ets_nll.npy')

    # Confidence of a sample = calibrated probability of its predicted class.
    scores_class = apply_ets(logits,opt_temp,opt_weights,num_classes)
    scores_test = np.array([score[pred] for score, pred in zip(scores_class, pred_y_test)])

    results["auroc"].append(metrics.compute_auroc(scores_test, pred_y_test, true_y_test))
    results["auprc_success"].append(metrics.compute_auprc_success(scores_test, pred_y_test, true_y_test))
    results["auprc_error"].append(metrics.compute_auprc_error(scores_test, pred_y_test, true_y_test))
    results["fpr_at_95tpr"].append(metrics.compute_fpr_at_95tpr(scores_test, pred_y_test, true_y_test))
    results["aurc"].append(metrics.compute_aurc(scores_test, pred_y_test, true_y_test))
    results["eaurc"].append(metrics.compute_eaurc(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Apply ensemble temperature scaling (ETS) to a batch of per-class scores.

    The result is a weighted mix of three distributions: the temperature-scaled
    softmax (weight ``w[0]``), the plain softmax (``w[1]``) and the uniform
    distribution over ``n_class`` classes (``w[2]``).

    Args:
        logits: 2-D array-like with one row of per-class scores per sample.
        opt_temp: temperature that divides the scores for the scaled component.
        opt_weights: the three mixture weights; they are renormalised to sum to one.
        n_class: number of classes used for the uniform component.

    Returns:
        Array with the same shape as ``logits`` holding the mixed probabilities.
    """
    def _row_softmax(scores):
        # Subtracting the row maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would yield inf / inf = nan).
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=1, keepdims=True)

    logits = np.asarray(logits)
    p1 = _row_softmax(logits)
    p0 = _row_softmax(logits / opt_temp)
    p2 = np.ones_like(p0) / n_class
    weights = np.asarray(opt_weights)
    w = weights / np.sum(weights)  # guard against weights that do not sum to one
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# AUROC / AUPRC / FPR@95TPR / AURC / E-AURC for ETS-calibrated PVI scores, over 10 runs.
method = 'PVI ETS'
print(f'Method: {method}')

# One list per metric; every run appends exactly one value.
results = {
    metric_name: []
    for metric_name in ("auroc", "fpr_at_95tpr", "auprc_success", "auprc_error", "aurc", "eaurc")
}

for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)

    # Rebuild the network and restore the weights saved for this run (folders are 1-based).
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    pred_y_test = np.argmax(
        model.predict(ds_test.batch(256), verbose=0),
        axis=1,
    )

    # Stored per-class PVI values plus the temperature / mixture weights passed to apply_ets.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')
    opt_temp = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_weights_ets_nll.npy')

    scores_class = apply_ets(pvi, opt_temp, opt_weights, num_classes)
    # Per-sample confidence: the calibrated score of the class the model predicted.
    scores_test = np.array([row[label] for row, label in zip(scores_class, pred_y_test)])

    # Every metric helper takes the same three arguments.
    for metric_key, metric_fn in (
        ("auroc", metrics.compute_auroc),
        ("auprc_success", metrics.compute_auprc_success),
        ("auprc_error", metrics.compute_auprc_error),
        ("fpr_at_95tpr", metrics.compute_fpr_at_95tpr),
        ("aurc", metrics.compute_aurc),
        ("eaurc", metrics.compute_eaurc),
    ):
        results[metric_key].append(metric_fn(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

In [None]:
# AUROC / AUPRC / FPR@95TPR / AURC / E-AURC for PTS-calibrated model outputs, over 10 runs.
method = 'softmax PTS'
print(f'Method: {method}')
results = {
    "auroc": [],
    "fpr_at_95tpr": [],
    "auprc_success": [],
    "auprc_error": [],
    "aurc": [],
    "eaurc": [],
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Single forward pass: the predicted labels are the argmax of the very same
    # outputs that get calibrated, so running the model twice is unnecessary.
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'

    # epochs=0: presumably because the calibrator is only restored from disk here,
    # not trained -- confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=10,
        top_k_logits=5,
    )
    # base_path already ends with '/', so this is .../calibration/calibration_model/
    pts_loaded.load(path=f'{base_path}calibration_model/')
    scores_class = pts_loaded.calibrate(logits)
    # Per-sample confidence: the calibrated score of the predicted class.
    scores_test = np.array([score[pred] for score, pred in zip(scores_class, pred_y_test)])

    results["auroc"].append(metrics.compute_auroc(scores_test, pred_y_test, true_y_test))
    results["auprc_success"].append(metrics.compute_auprc_success(scores_test, pred_y_test, true_y_test))
    results["auprc_error"].append(metrics.compute_auprc_error(scores_test, pred_y_test, true_y_test))
    results["fpr_at_95tpr"].append(metrics.compute_fpr_at_95tpr(scores_test, pred_y_test, true_y_test))
    results["aurc"].append(metrics.compute_aurc(scores_test, pred_y_test, true_y_test))
    results["eaurc"].append(metrics.compute_eaurc(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

In [None]:
# AUROC / AUPRC / FPR@95TPR / AURC / E-AURC for PTS-calibrated PVI scores, over 10 runs.
method = 'PVI PTS'
print(f'Method: {method}')

# One list per metric; every run appends exactly one value.
results = {
    metric_name: []
    for metric_name in ("auroc", "fpr_at_95tpr", "auprc_success", "auprc_error", "aurc", "eaurc")
}

for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)

    # The model is only needed here to obtain the predicted labels.
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')
    pred_y_test = np.argmax(
        model.predict(ds_test.batch(256), verbose=0),
        axis=1,
    )

    # Stored per-class PVI values for the test set of this run.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')

    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=10,
        top_k_logits=5
    )
    # NOTE(review): this is the same calibration_model/ directory that the
    # 'softmax PTS' cell loads -- confirm a PVI-specific PTS model is not intended.
    pts_loaded.load(path=f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/calibration_model/')
    scores_class = pts_loaded.calibrate(pvi)
    # Per-sample confidence: the calibrated score of the class the model predicted.
    scores_test = np.array([row[label] for row, label in zip(scores_class, pred_y_test)])

    # Every metric helper takes the same three arguments.
    for metric_key, metric_fn in (
        ("auroc", metrics.compute_auroc),
        ("auprc_success", metrics.compute_auprc_success),
        ("auprc_error", metrics.compute_auprc_error),
        ("fpr_at_95tpr", metrics.compute_fpr_at_95tpr),
        ("aurc", metrics.compute_aurc),
        ("eaurc", metrics.compute_eaurc),
    ):
        results[metric_key].append(metric_fn(scores_test, pred_y_test, true_y_test))

print(f"AUROC           : {utils.format_ci(results['auroc'], scale=100)}")
print(f"AUPRC (success) : {utils.format_ci(results['auprc_success'], scale=100)}")
print(f"AUPRC (error)   : {utils.format_ci(results['auprc_error'], scale=100)}")
print(f"FPR at 95% TPR  : {utils.format_ci(results['fpr_at_95tpr'], scale=100)}")
print(f"AURC            : {utils.format_ci(results['aurc'], scale=1000)}")
print(f"EAURC           : {utils.format_ci(results['eaurc'], scale=1000)}")

### Calibration

In [19]:
def get_scores_for_calibration(conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name):
    """Return ``(scores_class, scores_test)`` for the requested confidence method.

    ``conf_method`` is one of:
      * ``'softmax'`` -- scores come from ``methods.softmax_prob`` /
        ``methods.max_softmax_prob``;
      * ``'softmax_temp_scaling_<metric>'`` -- same, with the temperature stored
        in ``softmax_opt_temp_<metric>.npy``;
      * ``'pmi'``, ``'psi'``, ``'pvi'``, ``'pvi_best'`` -- saved per-class scores,
        passed through a row-wise softmax with temperature 1;
      * ``'<estimator>_temp_scaling_<metric>'`` -- as above, using the stored
        temperature for that estimator.

    ``<metric>`` is the text after the last underscore of ``conf_method``.
    All files are read from the calibration folder of run ``run + 1``.

    Raises:
        ValueError: if ``conf_method`` matches none of the forms above.
    """
    calib_root = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'

    # --- softmax variants are computed directly from the model -------------
    if conf_method == 'softmax':
        return methods.softmax_prob(model, ds_test), methods.max_softmax_prob(model, ds_test)

    if conf_method.startswith('softmax_temp_scaling'):
        metric_name = conf_method.split('_')[-1]
        softmax_temp = np.load(f'{calib_root}/softmax_opt_temp_{metric_name}.npy')
        return (
            methods.softmax_prob(model, ds_test, softmax_temp),
            methods.max_softmax_prob(model, ds_test, softmax_temp),
        )

    # --- estimator variants read saved per-class scores ---------------------
    # estimator -> (sub-directory under the calibration folder, per-class score file)
    estimator_files = {
        'pmi': ('pmi/separable_variational_f_js', 'pmi_output_class_test.npy'),
        'psi': ('psi/gaussian', 'psi_output_class_500_projs_test.npy'),
        'pvi': ('pvi/training_from_scratch', 'pvi_class_test.npy'),
        'pvi_best': ('pvi/training_from_scratch', 'pvi_class_best_test.npy'),
    }
    scaled_prefixes = ('pmi_temp_scaling', 'psi_temp_scaling', 'pvi_temp_scaling', 'pvi_best_temp_scaling')

    if conf_method in estimator_files:
        estimator, temperature = conf_method, 1.0
    elif conf_method.startswith(scaled_prefixes):
        tokens = conf_method.split('_')
        # 'pvi_best_...' keeps its two-token name; every other estimator is the first token.
        estimator = '_'.join(tokens[:2]) if 'best' in tokens else tokens[0]
        temp_dir = estimator_files[estimator][0]
        temperature = float(np.load(f'{calib_root}/{temp_dir}/{estimator}_opt_temp_{tokens[-1]}.npy'))
    else:
        raise ValueError(f"Unknown confidence method: {conf_method}")

    score_dir, score_file = estimator_files[estimator]
    raw_scores = np.load(f'{calib_root}/{score_dir}/{score_file}')
    # Row-wise softmax of the temperature-divided scores.
    scores_class = np.array([utils.softmax(row / temperature) for row in raw_scores])
    # Confidence per sample: the score assigned to the predicted class.
    scores_test = np.array([row[label] for row, label in zip(scores_class, pred_y_test)])
    return scores_class, scores_test

def evaluate_calibration(ds_test, true_y_test, conf_method, n_runs=10, n_bins=15):
    """Compute calibration metrics for one confidence method over several runs.

    For every run the trained weights are restored, test predictions are taken
    as the argmax of the model output, and the scores returned by
    ``get_scores_for_calibration`` are passed to the ``metrics`` helpers.

    Relies on the notebook-level names ``create_model``, ``model_name``,
    ``dataset_name`` and ``num_classes``.

    Args:
        ds_test: test dataset; it is batched here before prediction.
        true_y_test: ground-truth labels passed to the metric helpers.
        conf_method: identifier understood by ``get_scores_for_calibration``.
        n_runs: number of runs to evaluate (run folders are 1-based).
        n_bins: bin count handed to the binned metrics (default 15, the value
            that was previously hard-coded).

    Returns:
        dict mapping each metric name to a list with one value per run.
    """
    metric_names = (
        "ece", "cc_ece", "mce", "ace", "sce",
        "ada_ece", "ada_sce", "cc_ada_ece", "cc_ada_sce", "cc_ada_sce_rms",
        "cw_ece", "cw_sce", "cw_ada_ece", "cw_ada_sce",
        "cw_ada_ece_rms", "cw_ada_sce_rms",
        "nll", "bs", "sharpness",
    )
    results = {name: [] for name in metric_names}

    # "cw_ada_sce_rms" used to call compute_cw_adaece_rms (copy-paste slip), so it
    # merely duplicated "cw_ada_ece_rms". Prefer the SCE variant; fall back to the
    # previous call if src.metrics does not provide it.
    cw_adasce_rms = getattr(metrics, "compute_cw_adasce_rms", metrics.compute_cw_adaece_rms)

    for run in range(n_runs):
        tf.keras.utils.set_random_seed(run + 10)
        model = create_model()
        model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

        pred_y_test = np.argmax(model.predict(ds_test.batch(256), verbose=0), axis=1)
        scores_class, scores_test = get_scores_for_calibration(
            conf_method, model, ds_test, pred_y_test, run, model_name, dataset_name
        )

        # Metrics on the confidence of the predicted class.
        results["ece"].append(metrics.compute_ece(scores_test, pred_y_test, true_y_test, n_bins))
        results["cc_ece"].append(metrics.compute_cc_ece(scores_test, pred_y_test, true_y_test, n_bins))
        results["mce"].append(metrics.compute_mce(scores_test, pred_y_test, true_y_test, n_bins))
        results["ace"].append(metrics.compute_ace(scores_test, pred_y_test, true_y_test, n_bins))
        results["ada_ece"].append(metrics.compute_adaece(scores_test, pred_y_test, true_y_test, n_bins))
        results["cc_ada_ece"].append(metrics.compute_cc_adaece(scores_test, pred_y_test, true_y_test, n_bins))

        # Metrics on the full per-class score matrix.
        results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, n_bins))
        results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ece"].append(metrics.compute_cw_ece(scores_class, true_y_test, num_classes, n_bins))
        results["cw_sce"].append(metrics.compute_cw_sce(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_ece"].append(metrics.compute_cw_adaece(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_sce"].append(metrics.compute_cw_adasce(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_ece_rms"].append(metrics.compute_cw_adaece_rms(scores_class, true_y_test, num_classes, n_bins))
        results["cw_ada_sce_rms"].append(cw_adasce_rms(scores_class, true_y_test, num_classes, n_bins))

        # Metrics that take no bin count.
        results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
        results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))
        results["sharpness"].append(metrics.compute_sharpness(scores_class))

    return results

In [20]:
# Confidence methods to report: the raw estimators first, then their
# temperature-scaled ('..._temp_scaling_nll') counterparts.
methods_list = [
    'softmax', 'pmi', 'psi', 'pvi', 'pvi_best',
    'softmax_temp_scaling_nll', 'pmi_temp_scaling_nll', 'psi_temp_scaling_nll',
    'pvi_temp_scaling_nll', 'pvi_best_temp_scaling_nll',
]

# (printed label including its alignment padding, key in the results dict)
report_rows = (
    ("ECE:            ", 'ece'),
    ("CC-ECE:         ", 'cc_ece'),
    ("MCE:            ", 'mce'),
    ("ACE:            ", 'ace'),
    ("SCE:            ", 'sce'),
    ("Ada-ECE:        ", 'ada_ece'),
    ("Ada-SCE:        ", 'ada_sce'),
    ("CC-Ada-ECE:     ", 'cc_ada_ece'),
    ("CC-Ada-SCE:     ", 'cc_ada_sce'),
    ("CC-Ada-SCE-RMS: ", 'cc_ada_sce_rms'),
    ("CW-ECE:         ", 'cw_ece'),
    ("CW-SCE:         ", 'cw_sce'),
    ("CW-Ada-ECE:     ", 'cw_ada_ece'),
    ("CW-Ada-SCE:     ", 'cw_ada_sce'),
    ("CW-Ada-ECE-RMS: ", 'cw_ada_ece_rms'),
    ("CW-Ada-SCE-RMS: ", 'cw_ada_sce_rms'),
    ("NLL:            ", 'nll'),
    ("Brier Score:    ", 'bs'),
    ("Sharpness:      ", 'sharpness'),
)

for method in methods_list:
    print(f'Method: {method}')
    results = evaluate_calibration(ds_test, true_y_test, conf_method=method, n_runs=10)
    # Every metric is reported with the same scale factor of 100.
    for row_label, result_key in report_rows:
        print(f"{row_label}{utils.format_ci(results[result_key], scale=100)}")

Method: softmax
ECE:            3.06 (0.56)
CC-ECE:         3.43 (0.48)
MCE:            1.54 (0.42)
ACE:            11.16 (2.42)
SCE:            0.69 (0.10)
Ada-ECE:        3.03 (0.57)
Ada-SCE:        0.52 (0.06)
CC-Ada-ECE:     3.19 (0.47)
CC-Ada-SCE:     0.70 (0.09)
CC-Ada-SCE-RMS: 5.24 (0.32)
CW-ECE:         0.69 (0.10)
CW-SCE:         0.69 (0.10)
CW-Ada-ECE:     0.42 (0.04)
CW-Ada-SCE:     0.42 (0.04)
CW-Ada-ECE-RMS: 1.01 (0.09)
CW-Ada-SCE-RMS: 1.01 (0.09)
NLL:            23.63 (1.37)
Brier Score:    10.73 (0.21)
Sharpness:      10.21 (2.13)
Method: pmi
ECE:            3.05 (0.55)
CC-ECE:         3.67 (0.45)
MCE:            1.49 (0.42)
ACE:            13.29 (2.95)
SCE:            0.73 (0.09)
Ada-ECE:        2.65 (0.51)
Ada-SCE:        0.56 (0.05)
CC-Ada-ECE:     3.31 (0.40)
CC-Ada-SCE:     0.71 (0.10)
CC-Ada-SCE-RMS: 5.48 (0.31)
CW-ECE:         0.73 (0.09)
CW-SCE:         0.73 (0.09)
CW-Ada-ECE:     0.49 (0.06)
CW-Ada-SCE:     0.49 (0.06)
CW-Ada-ECE-RMS: 1.16 (0.14)
CW-Ada-SCE-RMS:

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Apply ensemble temperature scaling (ETS) to a batch of per-class scores.

    The result is a weighted mix of three distributions: the temperature-scaled
    softmax (weight ``w[0]``), the plain softmax (``w[1]``) and the uniform
    distribution over ``n_class`` classes (``w[2]``).

    Args:
        logits: 2-D array-like with one row of per-class scores per sample.
        opt_temp: temperature that divides the scores for the scaled component.
        opt_weights: the three mixture weights; they are renormalised to sum to one.
        n_class: number of classes used for the uniform component.

    Returns:
        Array with the same shape as ``logits`` holding the mixed probabilities.
    """
    def _row_softmax(scores):
        # Subtracting the row maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would yield inf / inf = nan).
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=1, keepdims=True)

    logits = np.asarray(logits)
    p1 = _row_softmax(logits)
    p0 = _row_softmax(logits / opt_temp)
    p2 = np.ones_like(p0) / n_class
    weights = np.asarray(opt_weights)
    w = weights / np.sum(weights)  # guard against weights that do not sum to one
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Calibration metrics (SCE variants, NLL, Brier score) for ETS-calibrated model outputs.
method = 'softmax ETS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Single forward pass over the test set; the predicted labels are derived
    # from the same outputs instead of running the model a second time.
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    # ETS temperature and mixture weights saved for this run.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    opt_temp = np.load(f'{base_path}/softmax_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/softmax_opt_weights_ets_nll.npy')

    scores_class = apply_ets(logits, opt_temp, opt_weights, num_classes)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
def apply_ets(logits, opt_temp, opt_weights, n_class):
    """Apply ensemble temperature scaling (ETS) to a batch of per-class scores.

    The result is a weighted mix of three distributions: the temperature-scaled
    softmax (weight ``w[0]``), the plain softmax (``w[1]``) and the uniform
    distribution over ``n_class`` classes (``w[2]``).

    Args:
        logits: 2-D array-like with one row of per-class scores per sample.
        opt_temp: temperature that divides the scores for the scaled component.
        opt_weights: the three mixture weights; they are renormalised to sum to one.
        n_class: number of classes used for the uniform component.

    Returns:
        Array with the same shape as ``logits`` holding the mixed probabilities.
    """
    def _row_softmax(scores):
        # Subtracting the row maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would yield inf / inf = nan).
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=1, keepdims=True)

    logits = np.asarray(logits)
    p1 = _row_softmax(logits)
    p0 = _row_softmax(logits / opt_temp)
    p2 = np.ones_like(p0) / n_class
    weights = np.asarray(opt_weights)
    w = weights / np.sum(weights)  # guard against weights that do not sum to one
    return w[0] * p0 + w[1] * p1 + w[2] * p2


# Calibration metrics (SCE variants, NLL, Brier score) for ETS-calibrated PVI scores.
method = 'PVI ETS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    # Everything this cell evaluates is read from disk: the stored per-class PVI
    # values and the ETS parameters. None of the metrics below takes predicted
    # labels, so the classifier is not rebuilt or run here.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')
    opt_temp = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_temp_ets_nll.npy')
    opt_weights = np.load(f'{base_path}/pvi/training_from_scratch/pvi_opt_weights_ets_nll.npy')

    scores_class = apply_ets(pvi, opt_temp, opt_weights, num_classes)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
# Calibration metrics (SCE variants, NLL, Brier score) for PTS-calibrated model outputs.
# The label is 'softmax PTS': this cell calibrates with PTSCalibrator, not ETS.
method = 'softmax PTS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    tf.keras.utils.set_random_seed(run + 10)
    model = create_model()
    model.load_weights(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_weights.h5')

    # Single forward pass over the test set; the predicted labels are derived
    # from the same outputs instead of running the model a second time.
    logits = model.predict(ds_test.batch(512), verbose=0)
    pred_y_test = np.argmax(logits, axis=1)

    # epochs=0: presumably because the calibrator is only restored from disk here,
    # not trained -- confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=10,
        top_k_logits=5
    )
    pts_loaded.load(path=f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/calibration_model/')
    scores_class = pts_loaded.calibrate(logits)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")

In [None]:
# Calibration metrics (SCE variants, NLL, Brier score) for PTS-calibrated PVI scores.
method = 'PVI PTS'
print(f'Method: {method}')
results = {
    "sce": [],
    "ada_sce": [],
    "cc_ada_sce": [],
    "cc_ada_sce_rms": [],
    "nll": [],
    "bs": [],
}
for run in range(10):
    # Seed kept so that constructing the calibrator is reproducible per run.
    tf.keras.utils.set_random_seed(run + 10)

    # The scores come from the stored PVI file and none of the metrics below takes
    # predicted labels, so the classifier is not rebuilt or run here.
    base_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/'
    pvi = np.load(f'{base_path}/pvi/training_from_scratch/pvi_class_test.npy')

    # epochs=0: presumably because the calibrator is only restored from disk here,
    # not trained -- confirm against temp_scaling.PTSCalibrator.
    pts_loaded = temp_scaling.PTSCalibrator(
        epochs=0,
        lr=1e-3,
        weight_decay=1e-4,
        batch_size=64,
        nlayers=2,
        n_nodes=32,
        length_logits=10,
        top_k_logits=5
    )
    # NOTE(review): this is the same calibration_model/ directory that the
    # softmax PTS cells load -- confirm a PVI-specific PTS model is not intended.
    pts_loaded.load(path=f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/calibration_model/')
    scores_class = pts_loaded.calibrate(pvi)

    results["sce"].append(metrics.compute_sce(scores_class, true_y_test, num_classes, 15))
    results["ada_sce"].append(metrics.compute_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce"].append(metrics.compute_cc_adasce(scores_class, true_y_test, num_classes, 15))
    results["cc_ada_sce_rms"].append(metrics.compute_cc_adasce_rms(scores_class, true_y_test, num_classes, 15))
    results["nll"].append(metrics.compute_nll(scores_class, true_y_test, num_classes))
    results["bs"].append(metrics.compute_brier_score(scores_class, true_y_test, num_classes))

print(f"SCE:            {utils.format_ci(results['sce'], scale=100)}")
print(f"Ada-SCE:        {utils.format_ci(results['ada_sce'], scale=100)}")
print(f"CC-Ada-SCE:     {utils.format_ci(results['cc_ada_sce'], scale=100)}")
print(f"CC-Ada-SCE-RMS: {utils.format_ci(results['cc_ada_sce_rms'], scale=100)}")
print(f"NLL:            {utils.format_ci(results['nll'], scale=100)}")
print(f"Brier Score:    {utils.format_ci(results['bs'], scale=100)}")