In [9]:
import pickle

import torch

from utils_eenn_avcs import *

import robustness_metrics as rm

2023-12-27 11:00:18.729814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pickled ImageNet results; the pickle holds a 3-tuple that is
# unpacked into `logits`, `targets` and `ARGS`.
# NOTE(security): pickle.load can execute arbitrary code while deserialising,
# so only run this cell on a file from a trusted source.
with open('ImageNet.p', 'rb') as f:
    data = pickle.load(f)
logits, targets, ARGS = data

In [18]:
# Unpack the three dimensions of `logits`. Later cells iterate exits with
# range(L), shuffle/split samples along axis 1 using N, and apply softmax over
# axis 2 (size C) — so the layout appears to be (exits, samples, classes).
L, N, C = logits.shape

In [7]:
# Turn logits into probabilities by normalising over axis 2.
probs = logits.softmax(dim=2)
# Per-exit predictions and accuracies come from the project helpers.
preds = get_preds_per_exit(probs)
acc = get_acc_per_exit(preds, targets)

In [8]:
# Display the accuracy returned by get_acc_per_exit (one entry per exit).
acc

[tensor(0.5663),
 tensor(0.6514),
 tensor(0.6842),
 tensor(0.6977),
 tensor(0.7134)]

In [10]:
# One accuracy entry exists per exit, so its length gives the number of exits.
L = len(acc)
labels_np = targets.numpy()

# Expected calibration error (15 bins) of each exit, in exit order.
eces = []
for exit_idx in range(L):
    ece_metric = rm.metrics.ExpectedCalibrationError(num_bins=15)
    ece_metric.add_batch(probs[exit_idx].numpy(), label=labels_np)
    eces.append(ece_metric.result()['ece'])

2023-12-27 11:00:34.900323: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22036 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6




In [11]:
# Display the list of ECE values currently stored in `eces` (one per exit).
eces

[0.013025594875216484,
 0.010537293739616871,
 0.009827039204537868,
 0.008983026258647442,
 0.021733099594712257]

## 1) Temperature scaling

In [23]:
# Use a dedicated, seeded generator so the random validation/test split is
# identical on every run (Restart & Run All reproduces the same numbers).
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(N, generator=split_generator)

# Split indices into two groups: first half -> validation, remainder -> test.
n_val = N // 2
indices1, indices2 = indices[:n_val], indices[n_val:]

# Samples live on axis 1 of `logits` and axis 0 of `targets`.
logits_val = logits.index_select(1, indices1)
logits_test = logits.index_select(1, indices2)

targets_val = targets.index_select(0, indices1)
targets_test = targets.index_select(0, indices2)

In [28]:
# Candidate temperatures for the grid search (T = 1 leaves the logits unchanged).
T_arr = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 1., 1.1, 1.25, 1.5, 1.75, 2., 2.25, 2.5, 5., 10.]

# The labels do not depend on T, so convert them to numpy once.
targets_val_np = targets_val.numpy()

# ECE_dict[T][l] is the 15-bin ECE of exit l on the validation split after
# dividing the logits by T.
ECE_dict = {}
for T in T_arr:
    logits_temper = logits_val / T
    probs_temp = torch.softmax(logits_temper, dim=2)
    # Deliberately NOT named `eces`: that name holds the unscaled per-exit
    # ECEs computed on the full data, and reusing it here would overwrite them.
    eces_at_T = []
    for l in range(L):
        ece = rm.metrics.ExpectedCalibrationError(num_bins=15)
        ece.add_batch(probs_temp[l, :, :].numpy(), label=targets_val_np)
        eces_at_T.append(ece.result()['ece'])
    ECE_dict[T] = eces_at_T

In [29]:
# For every exit, keep the temperature whose validation ECE is smallest.
# Ties are resolved in favour of the temperature inserted first into ECE_dict.
best_temperatures = []
for exit_idx in range(L):
    lowest_ece, best_T = 1e10, None
    for candidate_T, ece_by_exit in ECE_dict.items():
        candidate_ece = ece_by_exit[exit_idx]
        if candidate_ece < lowest_ece:
            lowest_ece, best_T = candidate_ece, candidate_T
    best_temperatures.append(best_T)

T_star = torch.tensor(best_temperatures)

In [30]:
# Display the selected temperature for each exit.
T_star

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.1000])

##### MSDNet for ImageNet already seems well calibrated: the optimal temperature at most exits is 1

## 2) Credible sets

In [12]:
# Each credible set must accumulate at least 1 - alpha probability mass.
alpha = 0.05
L, N, C = probs.shape
probs_num = probs.numpy()

mass_threshold = 1 - alpha

# cred_sets[n][l]: class indices, most probable first, forming the smallest
# top-k set of exit l for sample n whose cumulative probability reaches the threshold.
cred_sets = []
for sample_idx in range(N):
    sets_for_sample = []
    for exit_idx in range(L):
        class_probs = probs_num[exit_idx, sample_idx]
        ranking = np.argsort(class_probs)[::-1]  # descending probability
        cumulative_mass = np.cumsum(class_probs[ranking])
        # First position at which the threshold is reached, turned into a count.
        set_size = np.where(cumulative_mass >= mass_threshold)[0][0] + 1
        sets_for_sample.append(ranking[:set_size].tolist())
    cred_sets.append(sets_for_sample)

In [13]:
# Apply the project's running-intersection helper to each sample's list of
# per-exit credible sets (presumably intersecting each exit's set with those of
# the earlier exits — confirm in utils_eenn_avcs).
cred_sets_intersect = [
    running_intersection_classification(cred_sets[i]) for i in range(N)
]

In [14]:
def _sizes_and_coverage(sets_per_sample, labels, num_samples, num_exits):
    """Compute set sizes and label coverage for every (sample, exit) pair.

    Args:
        sets_per_sample: sets_per_sample[i][l] is the list of class indices
            kept for sample i at exit l.
        labels: sequence of plain Python labels, one per sample.
        num_samples: number of samples to process.
        num_exits: number of exits to process.

    Returns:
        (sizes, coverage): two arrays with one row per sample and one column
        per exit; `sizes` holds the set lengths and `coverage` holds booleans
        telling whether the sample's label is inside the set.
    """
    sizes = np.array([
        [len(sets_per_sample[i][l]) for l in range(num_exits)]
        for i in range(num_samples)
    ])
    coverage = np.array([
        [labels[i] in sets_per_sample[i][l] for l in range(num_exits)]
        for i in range(num_samples)
    ])
    return sizes, coverage


# Convert labels to Python scalars once: testing `tensor in list` would run a
# tensor comparison against every list element, which is far slower.
target_labels = targets.tolist()

# Statistics for the raw credible sets.
sizes_cred, coverage_cred = _sizes_and_coverage(cred_sets, target_labels, N, L)
consistency_cred = np.array([consistency_classifciation(cred_sets[i]) for i in range(N)])

# The same statistics for the running-intersection sets.
sizes_cred_intersect, coverage_cred_intersect = _sizes_and_coverage(
    cred_sets_intersect, target_labels, N, L
)
consistency_cred_intersect = np.array([consistency_classifciation(cred_sets_intersect[i]) for i in range(N)])

In [15]:
# Mean and std over samples (axis 0) of the set size at each exit:
# first for the raw credible sets, then for the running intersections.
sizes_cred.mean(axis=0), sizes_cred.std(axis=0), sizes_cred_intersect.mean(axis=0), sizes_cred_intersect.std(axis=0)

(array([43.6209 , 27.93284, 22.82906, 19.85656, 18.32186]),
 array([66.71544037, 53.34832452, 48.2882818 , 43.03334132, 42.71981351]),
 array([43.6209 , 18.6902 , 11.92742,  9.168  ,  7.60886]),
 array([66.71544037, 34.72678136, 24.40029574, 19.33324536, 16.71180749]))

In [16]:
# Fraction of samples whose label lies inside the set, per exit:
# raw credible sets first, running intersections second.
coverage_cred.mean(axis=0), coverage_cred_intersect.mean(axis=0)

(array([0.9495 , 0.95552, 0.95732, 0.95792, 0.95582]),
 array([0.9495 , 0.93166, 0.9211 , 0.91474, 0.90846]))

In [17]:
# Mean and std over samples (axis 0) of the values returned by
# consistency_classifciation, for the raw credible sets and then for the
# running intersections.
consistency_cred.mean(axis=0), consistency_cred.std(axis=0), consistency_cred_intersect.mean(axis=0), consistency_cred_intersect.std(axis=0)

(array([1.        , 0.83629903, 0.77660886, 0.73972643, 0.7565029 ]),
 array([0.        , 0.21893355, 0.26666782, 0.28745661, 0.29318572]),
 array([1., 1., 1., 1., 1.]),
 array([0., 0., 0., 0., 0.]))