In [1]:
import torch
from torchvision import models
from torchvision import datasets, transforms
from datasets import Split_Dataset
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

test_dataset = datasets.ImageFolder('/gpfs/u/locker/200/CADS/datasets/ImageNet/val', transform=val_transforms)

val_dataset = Split_Dataset('/gpfs/u/locker/200/CADS/datasets/ImageNet',  \
                    f'./calib_splits/am_imagenet_5percent_val.txt',
                    transform=val_transforms)

test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=True,
            num_workers=20, pin_memory=True,
        )
val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )

In [2]:
def load_3_models(list_ckpts):
    model1 = models.resnet50().cuda()
    model2 = models.resnet50().cuda()
    model3 = models.resnet50().cuda()
    sd = torch.load(f"./dist_models/{list_ckpts[0]}/checkpoint_best.pth", map_location="cpu")
    ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
    model1.load_state_dict(ckpt)
    model1.eval()

    sd = torch.load(f"./dist_models/{list_ckpts[1]}/checkpoint_best.pth", map_location="cpu")
    ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
    model2.load_state_dict(ckpt)
    model2.eval()

    sd = torch.load(f"./dist_models/{list_ckpts[2]}/checkpoint_best.pth", map_location="cpu")
    ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
    model3.load_state_dict(ckpt)
    model3.eval()

    return model1, model2, model3

In [3]:
import torch.nn.functional as F

class JSD(torch.nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = torch.nn.KLDivLoss(reduction='sum', log_target=True)

    def forward(self, p: torch.tensor, q: torch.tensor):
        p = F.log_softmax(p, dim=-1)
        q = F.log_softmax(q, dim=-1)
        
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log()))

class KLD(torch.nn.Module):
    def __init__(self):
        super(KLD, self).__init__()
        self.kl = torch.nn.KLDivLoss(reduction='sum', log_target=True)

    def forward(self, p: torch.tensor, q: torch.tensor):
        p = F.log_softmax(p, dim=-1)
        q = F.log_softmax(q, dim=-1)
        return self.kl(p,q)

kl_div = KLD()
js_div = JSD()

class _ECELoss(torch.nn.Module):

    def __init__(self, n_bins=20):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, softmaxes, labels):
#         softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=softmaxes.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

nll_criterion = torch.nn.CrossEntropyLoss().cuda()
ece_criterion = _ECELoss().cuda()

def compute_pair_consensus(pair_preds):
    agree = (pair_preds[0] == pair_preds[1])
    agree_correct = agree & (pair_preds[0] == target)
    agree_wrong = agree & (pair_preds[0] != target)
    disagree = (pair_preds[0] != pair_preds[1])
    disagree_both_wrong = disagree & (pair_preds[0] != target) & (pair_preds[1] != target)
    disagree_one_correct = disagree & (pair_preds[0] != target) & (pair_preds[1] == target) 
    disagree_one_correct2 = disagree & (pair_preds[1] != target) & (pair_preds[0] == target) 
    return agree.sum(), disagree.sum(), agree_correct.sum(), agree_wrong.sum(), disagree_both_wrong.sum(), disagree_one_correct.sum()+disagree_one_correct2.sum()


In [4]:
ckpts = ['ft_baseR_cos_lr0.003_bs256', 'ft_eqR_cos_lr0.003_bs256', 'ft_inv_cos_lr0.003_bs256']

In [5]:
model1, model2, model3 = load_3_models(ckpts)

w_acc = 0
n_acc = 0
ag_sum = 0
ag_c_sum = 0
ag_w_sum = 0
dag_sum = 0
dag_c_sum = 0
dag_w_sum = 0
avg_std_logits = 0.
avg_std = 0.

kld_sum = 0.
# js_div = 0.

ece_ensem, ece1, ece2, ece3 = 0., 0., 0., 0.
nll_ensem, nll1, nll2, nll3 = 0., 0., 0., 0.

pairs = ([0,1], [0,2], [1,2])
targets = []
for it, (img,target) in enumerate(test_loader):
    target = target.cuda(non_blocking=True)
    img = img.cuda(non_blocking=True)
    with torch.no_grad():
        output1 = model1(img)
        output2 = model2(img)
        output3 = model3(img)
        preds = torch.stack([output1,output2,output3])
        avg_std_logits += torch.std(preds, dim=0).mean(dim=-1).sum() # std over members, mean over classes, sum over samples (mean taken later))
        avg_std += torch.std(preds.softmax(-1), dim=0).mean(dim=-1).sum() # std over members, mean over classes, sum over samples (mean taken later))
        _, all_preds = preds.max(-1)
        ag_p, dag_p, ag_c_p, ag_w_p, dag_w_p, dag_c_p = 0, 0, 0, 0, 0, 0
        kld = 0.
        for p in pairs:
            ag, dag, ag_c, ag_w, dag_w, dag_c = compute_pair_consensus(all_preds[p,:])
            ag_p += ag
            dag_p += dag
            ag_c_p += ag_c
            ag_w_p += ag_w
            dag_c_p += dag_c
            dag_w_p += dag_w
            kld += kl_div(preds[p[0]], preds[p[1]])
        
        ag_sum += ag_p/len(pairs)
        dag_sum += dag_p/len(pairs)
        ag_c_sum += ag_c_p/len(pairs)
        ag_w_sum += ag_w_p/len(pairs)
        dag_c_sum += dag_c_p/len(pairs)
        dag_w_sum += dag_w_p/len(pairs)
        kld_sum += kld/len(pairs)
        
        label_matrix = (all_preds == target).float().T
        logit = label_matrix.T.unsqueeze(2).repeat(1,1,1000) * preds.softmax(dim=-1)
        weighted_ensem = logit.sum(dim=0)
        naive_ensem = preds.softmax(dim=-1).mean(dim=0)
        _, w_ensem_pred = weighted_ensem.max(-1)
        _, n_ensem_pred = naive_ensem.max(-1)
        w_acc += (w_ensem_pred == target).sum()
        n_acc += (n_ensem_pred == target).sum()
        
        ece1 += ece_criterion(output1.softmax(-1), target)
        
    targets.append(label_matrix)
print(f"UB: {w_acc/len(test_dataset)} | ensem: {n_acc/len(test_dataset)}")
print(f"agree: {ag_sum/len(test_dataset)} | disagree: {dag_sum/len(test_dataset)}") 
print(f"agree_correct: {ag_c_sum/len(test_dataset)} | agree_wrong: {ag_w_sum/len(test_dataset)}") 
print(f"disagree_1correct: {dag_c_sum/len(test_dataset)} | disagree_2wrong: {dag_w_sum/len(test_dataset)}") 
print(f"Ensemble Variance Logits: {avg_std_logits/len(test_dataset)}") 
print(f"Ensemble Variance: {avg_std/len(test_dataset)}") 
print(f"KL div: {kld_sum/len(test_dataset)}") 

UB: 0.8425399661064148 | ensem: 0.7884399890899658
agree: 0.826686680316925 | disagree: 0.17331331968307495
agree_correct: 0.7130333185195923 | agree_wrong: 0.11365331709384918
disagree_1correct: 0.10659998655319214 | disagree_2wrong: 0.06671332567930222
Ensemble Variance Logits: 0.8963876366615295
Ensemble Variance: 0.00029170859488658607
KL div: 0.34884077310562134
