In [27]:
import os
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.datasets import CIFAR10

from evaluate.calibration import nll, ece_loss
from evaluate.divergence import kld, disagreement
from evaluate.similarity import cka, gram_linear, gram_rbf

In [48]:
# Experiment directory and training-set size selecting which saved logit
# files are analysed by the cells below.
EXP_PATH = "./empirical/swa_20220922_185019"
TRAIN_SIZE = 8000

# glob() yields matches in arbitrary (filesystem-dependent) order; sorting
# makes the run index i used by every later cell reproducible across
# kernels and machines.
logit_pts = sorted(glob(os.path.join(EXP_PATH, f"logits_size={TRAIN_SIZE}_*.pt")))
print(f"{len(logit_pts)} predictions exist.")

5 predictions exist.


In [49]:
# Pairwise KL divergence between the softmax predictions of every pair of
# saved runs; all_klds[i, j] holds the mean of kld(run i, run j).
num_runs = len(logit_pts)

# Load and softmax each logit file exactly once. The pairwise loop below
# then reuses the arrays instead of re-reading every file from disk inside
# the inner loop.
all_probs = [
    F.softmax(torch.load(path, map_location='cpu'), dim=-1).numpy()
    for path in logit_pts
]

all_klds = np.zeros((num_runs, num_runs))
for i, prob_p in enumerate(all_probs):
    for j, prob_q in enumerate(all_probs):
        all_klds[i, j] = kld(prob_p, prob_q).mean()

# mean() averages over all n*n entries, diagonal included. Scaling by
# n / (n - 1) turns that into the average over the n*(n - 1) off-diagonal
# pairs -- assumes kld(p, p) is 0 on the diagonal; TODO confirm.
all_klds.mean() * (len(logit_pts)) / (len(logit_pts)-1)

1.050155845284462

In [50]:
# Pairwise disagreement between every pair of saved runs;
# all_disagrees[i, j] holds disagreement(run i logits, run j logits).
num_runs = len(logit_pts)

# Read each logit file once up front rather than re-loading it from disk
# on every iteration of the inner loop.
all_logits = [torch.load(path, map_location='cpu') for path in logit_pts]

all_disagrees = np.zeros((num_runs, num_runs))
for i, logit_p in enumerate(all_logits):
    for j, logit_q in enumerate(all_logits):
        all_disagrees[i, j] = disagreement(logit_p, logit_q)

# mean() averages over all n*n entries, diagonal included. Scaling by
# n / (n - 1) turns that into the average over the n*(n - 1) off-diagonal
# pairs -- assumes disagreement(p, p) is 0 on the diagonal; TODO confirm.
all_disagrees.mean() * (len(logit_pts)) / (len(logit_pts)-1)

0.45255999565124516

In [51]:
# Top-1 accuracy of each saved run against the CIFAR-10 labels loaded with
# train=False. `all_targets` is also read by later cells.
all_accs = np.zeros((len(logit_pts),))
test_set = CIFAR10("/opt/datasets/cifar10", train=False)
all_targets = np.asarray(test_set.targets)

for run_idx, logit_path in enumerate(logit_pts):
    run_logits = torch.load(logit_path, map_location='cpu')
    # Predicted class = index of the largest logit along the last axis.
    predicted = torch.argmax(run_logits, dim=-1).numpy()
    # Fraction of samples whose predicted class equals the label.
    all_accs[run_idx] = (all_targets == predicted).mean()

all_accs.mean()

0.6235799999999999

In [52]:
# Negative log-likelihood of each saved run, computed by nll() from the
# softmax probabilities and `all_targets` (defined in the accuracy cell).
all_nlls = np.zeros((len(logit_pts),))

for run_idx, logit_path in enumerate(logit_pts):
    run_probs = F.softmax(
        torch.load(logit_path, map_location='cpu'), dim=1
    ).numpy()
    all_nlls[run_idx] = nll(run_probs, all_targets)

all_nlls.mean()

1.162368929386139

In [None]:
# Expected calibration error of each saved run, using `all_targets`
# (defined in the accuracy cell) as labels.
all_eces = np.zeros((len(logit_pts),))

for i in range(len(logit_pts)):
    logit_p = torch.load(logit_pts[i], map_location='cpu')
    prob_p  = F.softmax(logit_p, dim=1).numpy()
    # The previous version of this cell was a copy of the NLL cell: it
    # called nll(), overwrote all_nlls and never filled all_eces.
    # NOTE(review): ece_loss is assumed to take (probs, targets) like nll()
    # does -- confirm against evaluate.calibration before trusting the value.
    ece_p = ece_loss(prob_p, all_targets)

    all_eces[i] = ece_p

all_eces.mean()