In [5]:
import torch
from torchvision import models
from torchvision import datasets, transforms
from datasets import Split_Dataset
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np
import torch.nn.functional as F
from torchvision.datasets import ImageFolder, CIFAR10, CIFAR100
from datasets_v08 import Flowers102

c100_norm = transforms.Normalize([0.50707516,  0.48654887,  0.44091784], [0.26733429,  0.25643846,  0.27615047])
flowers_norm = transforms.Normalize([0.5153, 0.4172, 0.3444], [0.2981, 0.2516, 0.2915])

num_classes = 100

c100_val_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            c100_norm
        ])

flowers_val_transforms = transforms.Compose([
        transforms.Resize(224+32),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        flowers_norm
    ])

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

test_dataset = datasets.ImageFolder('/gpfs/u/locker/200/CADS/datasets/ImageNet/val', transform=val_transforms)

val_dataset = Split_Dataset('/gpfs/u/locker/200/CADS/datasets/ImageNet',  \
                    f'./calib_splits/am_imagenet_5percent_val.txt',
                    transform=val_transforms)

test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )
val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )
c100_dataset = CIFAR100("/gpfs/u/home/BNSS/BNSSlhch/scratch/datasets/", train=False, transform=c100_val_transforms,download=False)
flowers_dataset = Flowers102("/gpfs/u/home/BNSS/BNSSlhch/scratch/datasets/", split='test', transform=flowers_val_transforms)

In [6]:
c100_loader = torch.utils.data.DataLoader(
            c100_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )
flowers_loader = torch.utils.data.DataLoader(
            flowers_dataset, batch_size=256, shuffle=False,
            num_workers=20, pin_memory=True,
        )

In [7]:

def load_1_model(ckpt_path, full_path=False, num_classes=1000):
    model1 = models.resnet50(num_classes=num_classes).cuda()
    if not full_path:
        sd = torch.load(f"./dist_models/{ckpt_path}/checkpoint_best.pth", map_location="cpu")
    else:
        sd = torch.load(ckpt_path, map_location="cpu")
    ckpt = {k.replace("members.0.",""):v for k,v in sd['model'].items()}
    model1.load_state_dict(ckpt)
    print(f"loaded {ckpt_path}")
    model1.eval()
    return model1

def rollout_loader(model, loader):
    targets = []
    outputs = []
    for it, (img, target) in enumerate(loader):
        target = target.cuda(non_blocking=True)
        img = img.cuda(non_blocking=True)
        with torch.no_grad():
            output1 = model(img)
#             ece_1 = ece_criterion(output1.softmax(-1), target)
            targets.append(target)
            outputs.append(output1)
    return torch.cat(outputs), torch.cat(targets)

import torch.nn.functional as F
import inspect
from netcal.metrics import ECE

cecriterion = torch.nn.CrossEntropyLoss().cuda()
nll_criterion = torch.nn.CrossEntropyLoss().cuda()
# ece_criterion = _ECELoss().cuda()
ece_netcal = ECE(15)

def get_metrics(outs, tars, names, printing=True, input_softmax=False, num_classes=1000):

    for out, tar,name in zip(outs,tars,names):
        correct_per_class = torch.zeros(num_classes).to(tar.device)
        total_per_class = torch.zeros(num_classes).to(tar.device)

        if not input_softmax:
            out = out.softmax(-1)
        ece1 = ece_netcal.measure(out.cpu().numpy(), tar.cpu().numpy())
#         ece2 = ece_criterion(out, tar)
        loss = F.nll_loss(torch.log(out), tar)
        _, pred = out.max(-1)
        correct_vec = (pred == tar)
        ind_per_class = (tar.unsqueeze(1) == torch.arange(num_classes).to(tar.device)) # indicator variable for each class
        correct_per_class = (correct_vec.unsqueeze(1) * ind_per_class).sum(0)
        total_per_class = ind_per_class.sum(0)

        acc = (correct_vec.sum()) / len(tar)
        acc_per_class = correct_per_class / total_per_class
        if printing:
            print(name)
            print(f"NLL: {loss.item()} | ECE: {ece1}")
            print("Acc:", acc.item())
    return loss.item(), ece1, acc.item(), acc_per_class


In [8]:
class KLD(torch.nn.Module):
    def __init__(self):
        super(KLD, self).__init__()
        self.kl = torch.nn.KLDivLoss(reduction='sum', log_target=True)

    def forward(self, p: torch.tensor, q: torch.tensor):
        p = F.log_softmax(p, dim=-1)
        q = F.log_softmax(q, dim=-1)
        return self.kl(p,q)

kl_div = KLD()

def compute_pair_consensus(pair_preds, target):
    agree = (pair_preds[0] == pair_preds[1])
    agree_correct = agree & (pair_preds[0] == target)
    agree_wrong = agree & (pair_preds[0] != target)
    disagree = (pair_preds[0] != pair_preds[1])
    disagree_both_wrong = disagree & (pair_preds[0] != target) & (pair_preds[1] != target)
    disagree_one_correct = disagree & (pair_preds[0] != target) & (pair_preds[1] == target) 
    disagree_one_correct2 = disagree & (pair_preds[1] != target) & (pair_preds[0] == target) 
    return agree.sum(), disagree.sum(), agree_correct.sum(), agree_wrong.sum(), disagree_both_wrong.sum(), disagree_one_correct.sum()+disagree_one_correct2.sum()

def get_div_metrics(output1,output2,output3,target):
    preds = torch.stack([output1,output2,output3])
    avg_std_logits = torch.std(preds, dim=0).mean(dim=-1).mean() # std over members, mean over classes, sum over samples (mean taken later))
    avg_std = torch.std(preds.softmax(-1), dim=0).mean(dim=-1).mean() # std over members, mean over classes, sum over samples (mean taken later))
    _, all_preds = preds.max(-1)
    ag_p, dag_p, ag_c_p, ag_w_p, dag_w_p, dag_c_p = 0, 0, 0, 0, 0, 0
    kld = 0.
    pairs = ([0,1], [0,2], [1,2])
    for p in pairs:
        ag, dag, ag_c, ag_w, dag_w, dag_c = compute_pair_consensus(all_preds[p,:], target)
        ag_p += ag
        dag_p += dag
        ag_c_p += ag_c
        ag_w_p += ag_w
        dag_c_p += dag_c
        dag_w_p += dag_w
        kld += kl_div(preds[p[0]], preds[p[1]])

    ag_sum = ag_p/len(pairs)
    dag_sum = dag_p/len(pairs)
    ag_c_sum = ag_c_p/len(pairs)
    ag_w_sum = ag_w_p/len(pairs)
    dag_c_sum = dag_c_p/len(pairs)
    dag_w_sum = dag_w_p/len(pairs)
    kld_sum = kld/len(pairs)
    print(f"Diversity agree: {ag_sum/len(output1)} | disagree: {dag_sum/len(output1)}") 
#     print(f"Ensemble Variance Logits: {avg_std_logits} ") 
#     print(f"Ensemble Variance: {avg_std}") 
#     print(f"KL div: {kld_sum/len(output1)}") 
    return ag_sum/len(output1), dag_sum/len(output1), kld_sum/len(output1), avg_std_logits, avg_std

In [9]:
## get classwise stats
def get_classwise(acc_base, acc_rotinv, acc_roteq, num_classes=1000):
    print("use order B, I, E")
    y = torch.stack([v for v in [acc_base, acc_rotinv, acc_roteq]], dim=-1)

    fac = num_classes/100
    # all 3 models equally good
    best_base_inv_eq = (y[:,0] == y[:,1]) & (y[:,1] == y[:,2])
    # 2 models equally good and is better
    best_base_inv = (y[:,0] == y[:,1]) & (y[:,0] > y[:,2])
    best_base_eq = (y[:,0] == y[:,2]) & (y[:,0] > y[:,1])
    best_inv_eq = (y[:,1] == y[:,2]) & (y[:,1] > y[:,0])
    # 2 models equally good and is worse
    worse_base_inv = (y[:,0] == y[:,1]) & (y[:,0] < y[:,2]) # best eq
    worse_base_eq = (y[:,0] == y[:,2]) & (y[:,0] < y[:,1]) # best inv
    worse_inv_eq = (y[:,1] == y[:,2]) & (y[:,1] < y[:,0]) # best base
    all_diff = (y[:,0] != y[:,1]) & (y[:,1] != y[:,2]) & (y[:,0] != y[:,2])
    print(f"all equal best: {(best_base_inv_eq.sum())/fac:.1f}%")
    print(f"B,I equal best: {(best_base_inv.sum())/fac:.1f}%")
    print(f"B,E equal best: {(best_base_eq.sum())/fac:.1f}%")
    print(f"I,E equal best: {(best_inv_eq.sum())/fac:.1f}%")
    # print(f"all diff: {all_diff.sum()}")
    total = best_base_inv_eq.sum() + best_base_inv.sum() + best_base_eq.sum() + best_inv_eq.sum() + all_diff.sum() + worse_inv_eq.sum() + worse_base_eq.sum() + worse_base_inv.sum()
    assert total == num_classes

    # for all diff 
    best_base = (y[:,0] > y[:,1]) & (y[:,0] > y[:,2]) & all_diff
    best_inv = (y[:,1] > y[:,0]) & (y[:,1] > y[:,2]) & all_diff
    best_eq = (y[:,2] > y[:,0]) & (y[:,2] > y[:,1]) & all_diff
    total_unique = best_base.sum()+best_inv.sum()+best_eq.sum()
    assert total_unique == all_diff.sum()

    # single model uniquely best
    b_uniq = best_base | worse_inv_eq
    i_uniq = best_inv | worse_base_eq
    e_uniq = best_eq | worse_base_inv
    print(f"B uniquely best: {b_uniq.sum()/fac:.1f}%")
    print(f"I uniquely best: {(best_inv.sum() + worse_base_eq.sum())/fac:.1f}%")
    print(f"E uniquely best: {(best_eq.sum() + worse_base_inv.sum())/fac:.1f}%")

    B_good = b_uniq | best_base_inv_eq | best_base_inv | best_base_eq
    I_good = i_uniq | best_base_inv_eq | best_base_inv | best_inv_eq
    E_good = e_uniq | best_base_inv_eq | best_base_eq | best_inv_eq

In [22]:
# Flowers
dataset_name = 'flowers102'
num_classes = 102
lr = '0.5'
loader = flowers_loader

b69 = load_1_model(f"trans_{dataset_name}_base69_cos_lr0.8_bs256", num_classes=num_classes)
b69_out_f, b69_tar = rollout_loader(b69, loader)
eq69 = load_1_model(f"trans_{dataset_name}_eq69_cos_lr0.8_bs256", num_classes=num_classes)
eq69_out_f, eq69_tar = rollout_loader(eq69, loader)
inv69 = load_1_model(f"trans_{dataset_name}_inv69_cos_lr0.8_bs256", num_classes=num_classes)
inv69_out_f, inv69_tar = rollout_loader(inv69, loader)
assert(torch.equal(b69_tar,eq69_tar))

b24 = load_1_model(f"trans_{dataset_name}_base24_cos_lr{lr}_bs256", num_classes=num_classes)
b24_out_f, _ = rollout_loader(b24, loader)
eq24 = load_1_model(f"trans_{dataset_name}_eq24_cos_lr0.8_bs256", num_classes=num_classes)
eq24_out_f, _ = rollout_loader(eq24, loader)
inv24 = load_1_model(f"trans_{dataset_name}_inv24_cos_lr{lr}_bs256", num_classes=num_classes)
inv24_out_f, _ = rollout_loader(inv24, loader)

b31 = load_1_model(f"trans_{dataset_name}_base31_cos_lr{lr}_bs256", num_classes=num_classes)
b31_out_f, _ = rollout_loader(b31, loader)
eq31 = load_1_model(f"trans_{dataset_name}_eq31_cos_lr{lr}_bs256", num_classes=num_classes)
eq31_out_f, _ = rollout_loader(eq24, loader)
inv31 = load_1_model(f"trans_{dataset_name}_inv31_cos_lr0.3_bs256", num_classes=num_classes)
inv31_out_f, _ = rollout_loader(inv31, loader)

tar_f = b69_tar

loaded trans_flowers102_base69_cos_lr0.8_bs256
loaded trans_flowers102_eq69_cos_lr0.8_bs256
loaded trans_flowers102_inv69_cos_lr0.8_bs256
loaded trans_flowers102_base24_cos_lr0.5_bs256
loaded trans_flowers102_eq24_cos_lr0.8_bs256
loaded trans_flowers102_inv24_cos_lr0.5_bs256
loaded trans_flowers102_base31_cos_lr0.5_bs256
loaded trans_flowers102_eq31_cos_lr0.5_bs256
loaded trans_flowers102_inv31_cos_lr0.3_bs256


In [35]:
eq42 = load_1_model(f"trans_{dataset_name}_eq42_cos_lr0.8_bs256", num_classes=num_classes)
eq42_out_f, _ = rollout_loader(eq24, loader)

loaded trans_flowers102_eq42_cos_lr0.8_bs256


In [51]:
from eval_metrics import get_metrics as get_new_metrics

def ensem_BEI(all_eq,all_base,all_inv,same_tar,num_E=0, num_B=0, num_I=0, num_comb=5, err='std'):
    ee_nll = []
    ee_ece = []
    ee_acc = []
    for i in range(num_comb):
        eq_list = np.random.choice(all_eq, num_E, replace=False)
        base_list = np.random.choice(all_base, num_B, replace=False)
        inv_list = np.random.choice(all_inv, num_I, replace=False)
        out_list = list(eq_list) + list(base_list) + list(inv_list)
        out_list = [torch.Tensor(x.cpu()) for x in out_list]
        ee_out = torch.stack(out_list).softmax(-1).mean(dim=0).cuda()
        nll, ece, acc, _ = get_new_metrics([ee_out],[same_tar],[f'EE_comb{i}'], printing=False, input_softmax=True)    
        ee_nll.append(nll)
        ee_ece.append(ece)  
        ee_acc.append(acc)
    print("E"*num_E + "B"*num_B + "I"*num_I)
    if err=='std':
        print(f"NLL: {np.mean(ee_nll):.4f} +/- {np.std(ee_nll):.4f}")
        print(f"ECE: {np.mean(ee_ece):.4f} +/- {np.std(ee_ece):.4f}")
        print(f"Acc: {np.mean(ee_acc):.4f} +/- {np.std(ee_acc):.4f}")
    elif err=='var':
        print(f"NLL: {np.mean(ee_nll):.4f} +/- {np.var(ee_nll):.4f}")
        print(f"ECE: {np.mean(ee_ece):.4f} +/- {np.var(ee_ece):.4f}")
        print(f"Acc: {np.mean(ee_acc):.4f} +/- {np.var(ee_acc):.4f}")



In [42]:
# 1 models
all_eq_f = [eq69_out_f, eq24_out_f, eq31_out_f, eq42_out_f]
all_base_f = [b69_out_f, b24_out_f, b31_out_f]
all_inv_f = [inv69_out_f, inv24_out_f, inv31_out_f]

ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_E=1)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_B=1)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_I=1)

E
NLL: 0.3570 +/- 0.0112
ECE: 0.0269 +/- 0.0020
Acc: 0.9191 +/- 0.0003
B
NLL: 0.3079 +/- 0.0156
ECE: 0.0109 +/- 0.0028
Acc: 0.9217 +/- 0.0019
I
NLL: 0.3398 +/- 0.0177
ECE: 0.0170 +/- 0.0099
Acc: 0.9132 +/- 0.0017


In [46]:
# 2 models
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_E=2)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_E=1, num_I=1)

# 3 models
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_E=3)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_f,num_E=2, num_I=1)

EE
NLL: 0.3326 +/- 0.0367
ECE: 0.0216 +/- 0.0078
Acc: 0.9225 +/- 0.0043
EI
NLL: 0.2829 +/- 0.0084
ECE: 0.0149 +/- 0.0024
Acc: 0.9276 +/- 0.0012
EEE
NLL: 0.3072 +/- 0.0277
ECE: 0.0151 +/- 0.0064
Acc: 0.9236 +/- 0.0023
EEI
NLL: 0.2778 +/- 0.0146
ECE: 0.0132 +/- 0.0040
Acc: 0.9295 +/- 0.0038


In [None]:
same_tar = b69_tar
num_comb = 1
all_eq_exR = [eq24_out, eq69_out, eq31_out]
all_base_exR = [b24_out, b69_out, b31_out]
all_inv_exR = [inv24_out, inv69_out, inv31_out]
for i in range(num_comb):
    [eq1] = np.random.choice(all_eq_exR, 1)
    [b1] = np.random.choice(all_base_exR, 1)
    [inv1] = np.random.choice(all_inv_exR, 1) 
    _,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
    _,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
    _,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
    get_classwise(acc_pc_b, acc_pc_inv, acc_pc_eq, num_classes=num_classes)
    


In [18]:
# CIFAR-100
dataset_name = 'cifar100'
num_classes = 100
lr = '0.2'
loader = c100_loader

b69 = load_1_model(f"trans_{dataset_name}_base69_cos_lr{lr}_bs256", num_classes=num_classes)
b69_out, b69_tar = rollout_loader(b69, loader)
same_tar = b69_tar
get_metrics([b69_out], [same_tar],['b69'], num_classes=num_classes)
eq69 = load_1_model(f"trans_{dataset_name}_eq69_cos_lr{lr}_bs256", num_classes=num_classes)
eq69_out, eq69_tar = rollout_loader(eq69, loader)
get_metrics([eq69_out], [same_tar],['eq69'], num_classes=num_classes)

inv69 = load_1_model(f"trans_{dataset_name}_inv69_cos_lr{lr}_bs256", num_classes=num_classes)
inv69_out, inv69_tar = rollout_loader(inv69, loader)
get_metrics([inv69_out], [same_tar],['inv69'], num_classes=num_classes)

assert(torch.equal(b69_tar,eq69_tar))

b24 = load_1_model(f"trans_{dataset_name}_base24_cos_lr{lr}_bs256", num_classes=num_classes)
b24_out, _ = rollout_loader(b24, loader)
get_metrics([b24_out], [same_tar],['b24'], num_classes=num_classes)

eq24 = load_1_model(f"trans_{dataset_name}_eq24_cos_lr{lr}_bs256", num_classes=num_classes)
eq24_out, _ = rollout_loader(eq24, loader)
get_metrics([eq24_out], [same_tar],['eq24'], num_classes=num_classes)

inv24 = load_1_model(f"trans_{dataset_name}_inv24_cos_lr{lr}_bs256", num_classes=num_classes)
inv24_out, _ = rollout_loader(inv24, loader)
get_metrics([inv24_out], [same_tar],['inv24'], num_classes=num_classes)


b31 = load_1_model(f"trans_{dataset_name}_base31_cos_lr{lr}_bs256", num_classes=num_classes)
b31_out, _ = rollout_loader(b31, loader)
get_metrics([b31_out], [same_tar],['b31'], num_classes=num_classes)

eq31 = load_1_model(f"trans_{dataset_name}_eq31_cos_lr{lr}_bs256", num_classes=num_classes)
eq31_out, _ = rollout_loader(eq31, loader)
get_metrics([eq31_out], [same_tar],['eq31'], num_classes=num_classes)

inv31 = load_1_model(f"trans_{dataset_name}_inv31_cos_lr{lr}_bs256", num_classes=num_classes)
inv31_out, _ = rollout_loader(inv31, loader)
get_metrics([inv31_out], [same_tar],['inv31'], num_classes=num_classes)



loaded trans_cifar100_base69_cos_lr0.2_bs256
b69
NLL: 0.7387407422065735 | ECE: 0.0890038620114327
Acc: 0.8535999655723572
loaded trans_cifar100_eq69_cos_lr0.2_bs256
eq69
NLL: 0.7336560487747192 | ECE: 0.09125010451376434
Acc: 0.8532999753952026
loaded trans_cifar100_inv69_cos_lr0.2_bs256
inv69
NLL: 0.8228658437728882 | ECE: 0.10278177570402622
Acc: 0.8385999798774719
loaded trans_cifar100_base24_cos_lr0.2_bs256
b24
NLL: 0.7335842847824097 | ECE: 0.08970591485351326
Acc: 0.8545999526977539
loaded trans_cifar100_eq24_cos_lr0.2_bs256
eq24
NLL: 0.7076037526130676 | ECE: 0.08781001860648402
Acc: 0.8554999828338623
loaded trans_cifar100_inv24_cos_lr0.2_bs256
inv24
NLL: 0.8170745372772217 | ECE: 0.10017117341011764
Acc: 0.8402000069618225
loaded trans_cifar100_base31_cos_lr0.2_bs256
b31
NLL: 0.7173638939857483 | ECE: 0.08865366537570948
Acc: 0.8531000018119812
loaded trans_cifar100_eq31_cos_lr0.2_bs256
eq31
NLL: 0.6678038835525513 | ECE: 0.08116961831152442
Acc: 0.85589998960495
loaded trans

(0.8248342871665955,
 0.10126517752557994,
 0.8409000039100647,
 tensor([0.9700, 0.9300, 0.7700, 0.7500, 0.7800, 0.8000, 0.9400, 0.8300, 0.9100,
         0.9000, 0.7600, 0.6300, 0.9100, 0.7300, 0.8800, 0.8700, 0.8800, 0.8300,
         0.8400, 0.8200, 0.8900, 0.9400, 0.9100, 0.9000, 0.9200, 0.8000, 0.8200,
         0.8300, 0.7800, 0.8700, 0.7600, 0.8200, 0.8000, 0.8000, 0.8300, 0.5700,
         0.9000, 0.8400, 0.8200, 0.9500, 0.8400, 0.9400, 0.8400, 0.8900, 0.8000,
         0.7400, 0.7000, 0.6900, 0.9700, 0.9200, 0.7100, 0.8500, 0.7100, 0.9900,
         0.9000, 0.7000, 0.9400, 0.9000, 0.9400, 0.6800, 0.9100, 0.8400, 0.8700,
         0.7900, 0.6400, 0.7400, 0.9100, 0.8300, 0.9700, 0.9100, 0.8700, 0.8900,
         0.7100, 0.8100, 0.7200, 0.9600, 0.9500, 0.8100, 0.8000, 0.9200, 0.8200,
         0.8500, 0.9600, 0.8100, 0.8400, 0.9100, 0.8700, 0.9400, 0.8700, 0.9500,
         0.9200, 0.8900, 0.8000, 0.7400, 0.9800, 0.8200, 0.7200, 0.8700, 0.7000,
         0.8500], device='cuda:0'))

In [48]:
dataset_name = 'cifar100'
num_classes = 100
lr = '0.2'
loader = c100_loader

eq42 = load_1_model(f"trans_{dataset_name}_eq42_cos_lr{lr}_bs256", num_classes=num_classes)
eq42_out, eq42_tar = rollout_loader(eq42, loader)
get_metrics([eq42_out], [eq42_tar],['eq42'], num_classes=num_classes)


loaded trans_cifar100_eq42_cos_lr0.2_bs256
eq42
NLL: 0.6688125729560852 | ECE: 0.08455638369619847
Acc: 0.858299970626831


(0.6688125729560852,
 0.08455638369619847,
 0.858299970626831,
 tensor([0.9400, 0.9400, 0.7500, 0.7800, 0.7400, 0.8300, 0.9300, 0.8600, 0.9500,
         0.9200, 0.7500, 0.6600, 0.8900, 0.8300, 0.8900, 0.9200, 0.8900, 0.9200,
         0.8700, 0.8800, 0.9100, 0.9400, 0.9300, 0.9500, 0.9000, 0.8400, 0.7800,
         0.8400, 0.8700, 0.8900, 0.8600, 0.8500, 0.8000, 0.8000, 0.8900, 0.6500,
         0.9100, 0.9100, 0.8500, 0.9500, 0.8800, 0.9200, 0.8500, 0.9000, 0.8300,
         0.8000, 0.7000, 0.7100, 0.9800, 0.9300, 0.7400, 0.9500, 0.6800, 0.9900,
         0.9100, 0.7600, 0.9300, 0.9000, 0.9700, 0.7200, 0.8900, 0.7900, 0.8600,
         0.7800, 0.6900, 0.8300, 0.9000, 0.7900, 0.9700, 0.9200, 0.8800, 0.9100,
         0.6900, 0.8200, 0.7200, 0.9500, 0.9600, 0.8500, 0.8500, 0.9100, 0.8300,
         0.8700, 0.9700, 0.8700, 0.8300, 0.9500, 0.8400, 0.9500, 0.8400, 0.9400,
         0.9100, 0.9300, 0.7400, 0.8400, 0.9800, 0.8300, 0.7300, 0.9000, 0.7300,
         0.8800], device='cuda:0'))

In [58]:
same_tar = eq42_tar
all_eq = [eq69_out, eq24_out, eq31_out, eq42_out]
all_base = [b69_out, b24_out, b31_out]
all_inv = [inv69_out, inv24_out, inv31_out]

# 1 models
ensem_BEI(all_eq, all_base, all_inv, same_tar, num_E=1)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_B=1)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_I=1)

# 2 models
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=2)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=1, num_I=1)

# 3 models
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=3)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=2, num_I=1)


E
NLL: 0.6943 +/- 0.0321
ECE: 0.0859 +/- 0.0046
Acc: 0.8553 +/- 0.0019
B
NLL: 0.7174 +/- 0.0000
ECE: 0.0887 +/- 0.0000
Acc: 0.8531 +/- 0.0000
I
NLL: 0.8229 +/- 0.0030
ECE: 0.1013 +/- 0.0008
Acc: 0.8403 +/- 0.0009
EE
NLL: 0.5501 +/- 0.0086
ECE: 0.0489 +/- 0.0023
Acc: 0.8674 +/- 0.0023
EI
NLL: 0.5687 +/- 0.0093
ECE: 0.0446 +/- 0.0017
Acc: 0.8654 +/- 0.0011
EEE
NLL: 0.5104 +/- 0.0068
ECE: 0.0403 +/- 0.0025
Acc: 0.8705 +/- 0.0013
EEI
NLL: 0.5145 +/- 0.0026
ECE: 0.0342 +/- 0.0017
Acc: 0.8703 +/- 0.0013


In [60]:
(91.2+91.9)/2

91.55000000000001

In [61]:
(91.2+91.9+91.9)/3

91.66666666666667

In [66]:
87.03-(85.5+84.0+85.5)/3

2.030000000000001

In [65]:
86.54-(85.5+84.0)/2

1.7900000000000063

In [13]:
b1 = b24_out
eq1 = eq42_out
inv1 = inv69_out

_,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
_,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
_,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
get_classwise(acc_pc_b, acc_pc_inv, acc_pc_eq, num_classes=num_classes)
  

base
NLL: 0.7335842847824097 | ECE: 0.08970591485351326
Acc: 0.8545999526977539
eq
NLL: 0.6688125729560852 | ECE: 0.08455638369619847
Acc: 0.858299970626831
inv
NLL: 0.8228658437728882 | ECE: 0.10278177570402622
Acc: 0.8385999798774719
use order B, I, E
all equal best: 5.0%
B,I equal best: 2.0%
B,E equal best: 5.0%
I,E equal best: 5.0%
B uniquely best: 33.0%
I uniquely best: 13.0%
E uniquely best: 37.0%


In [105]:
same_tar = b69_tar
num_comb = 1
all_eq_exR = [eq24_out, eq69_out, eq31_out]
all_base_exR = [b24_out, b69_out, b31_out]
all_inv_exR = [inv24_out, inv69_out, inv31_out]
for i in range(num_comb):
    [eq1] = np.random.choice(all_eq_exR, 1)
    [b1] = np.random.choice(all_base_exR, 1)
    [inv1] = np.random.choice(all_inv_exR, 1) 
    _,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
    _,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
    _,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
    get_classwise(acc_pc_b, acc_pc_inv, acc_pc_eq, num_classes=num_classes)
    


base
NLL: 0.7387407422065735 | ECE: 0.0890038620114327
Acc: 0.8535999655723572
eq
NLL: 0.7336560487747192 | ECE: 0.09125010451376434
Acc: 0.8532999753952026
inv
NLL: 0.8170745372772217 | ECE: 0.10017117341011764
Acc: 0.8402000069618225
use order B, I, E
all equal best: 1.0%
B,I equal best: 0.0%
B,E equal best: 13.0%
I,E equal best: 8.0%
B uniquely best: 33.0%
I uniquely best: 21.0%
E uniquely best: 24.0%


In [106]:
num_comb =1 
eee_nll = []
eee_ece = []
eee_acc = []
for i in range(num_comb):
    [eq1, eq2, eq3] = np.random.choice(all_eq_exR, 3, replace=False)
    eee_out = (eq1.softmax(-1) + eq2.softmax(-1) + eq3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([eee_out],[same_tar],[f'EEE_comb{i}'], input_softmax=True, num_classes=num_classes)    
    ag, dag, kld, std_logits, std = get_div_metrics(eq1,eq2,eq3, same_tar)

    eee_nll.append(nll)
    eee_ece.append(ece)  
    eee_acc.append(acc)   

bbb_nll = []
bbb_ece = []
bbb_acc = []
for i in range(num_comb):
    [b1, b2, b3] = np.random.choice(all_base_exR, 3, replace=False)
    bbb_out = (b1.softmax(-1) + b2.softmax(-1) + b3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([bbb_out],[same_tar],[f'BBB_comb{i}'], input_softmax=True, num_classes=num_classes)      
    ag, dag, kld, std_logits, std = get_div_metrics(b1,b2,b3, same_tar)

    bbb_nll.append(nll)
    bbb_ece.append(ece)  
    bbb_acc.append(acc) 

iii_nll = []
iii_ece = []
iii_acc = []
for i in range(num_comb):
    [i1, i2, i3] = np.random.choice(all_inv_exR, 3, replace=False)
    iii_out = (i1.softmax(-1) + i2.softmax(-1) + i3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([iii_out],[same_tar],[f'III_comb{i}'], input_softmax=True, num_classes=num_classes)      
    ag, dag, kld, std_logits, std = get_div_metrics(i1,i2,i3, same_tar)
    
    iii_nll.append(nll)
    iii_ece.append(ece)  
    iii_acc.append(acc) 
    
bei_nll = []
bei_ece = []
bei_acc = []
for i in range(num_comb):
#     [eq1] = np.random.choice(all_eq_exR, 1)
#     [base1] = np.random.choice(all_base_exR, 1)
#     [inv1] = np.random.choice(all_inv_exR, 1)
    [eq1] = np.random.choice([eq31_out], 1)
    [base1] = np.random.choice([b31_out], 1)
    [inv1] = np.random.choice([inv24_out], 1)
    bei_out = (eq1.softmax(-1) + base1.softmax(-1) + inv1.softmax(-1))/3
    
    nll, ece, acc, _ = get_metrics([bei_out],[same_tar],[f'BEI_comb{i}'], input_softmax=True, num_classes=num_classes)     
    ag, dag, kld, std_logits, std = get_div_metrics(eq1,base1,inv1, same_tar)
    
    bei_nll.append(nll)
    bei_ece.append(ece)  
    bei_acc.append(acc) 

print(f"\nEEE Acc: {np.mean(eee_acc)} +/- {np.std(eee_acc)}")
print(f"EEE ECE: {np.mean(eee_ece)} +/- {np.std(eee_ece)}")
print(f"EEE NLL: {np.mean(eee_nll)} +/- {np.std(eee_nll)}")

print(f"\nBBB Acc: {np.mean(bbb_acc)} +/- {np.std(bbb_acc)}")
print(f"BBB ECE: {np.mean(bbb_ece)} +/- {np.std(bbb_ece)}")
print(f"BBB NLL: {np.mean(bbb_nll)} +/- {np.std(bbb_nll)}")

print(f"\nIII Acc: {np.mean(iii_acc)} +/- {np.std(iii_acc)}")
print(f"III ECE: {np.mean(iii_ece)} +/- {np.std(iii_ece)}")
print(f"III NLL: {np.mean(iii_nll)} +/- {np.std(iii_nll)}")

print(f"\nBEI Acc: {np.mean(bei_acc)} +/- {np.std(bei_acc)}")
print(f"BEI ECE: {np.mean(bei_ece)} +/- {np.std(bei_ece)}")
print(f"BEI NLL: {np.mean(bei_nll)} +/- {np.std(bei_nll)}")


EEE_comb0
NLL: 0.5832626223564148 | ECE: 0.06168382893055683
Acc: 0.8614999651908875
Diversity agree: 0.9251333475112915 | disagree: 0.07486666738986969
BBB_comb0
NLL: 0.5279425978660583 | ECE: 0.04105125807672736
Acc: 0.8693000078201294
Diversity agree: 0.882099986076355 | disagree: 0.11789999902248383
III_comb0
NLL: 0.5802863240242004 | ECE: 0.043603131787478946
Acc: 0.8579999804496765
Diversity agree: 0.8637666702270508 | disagree: 0.13623332977294922
BEI_comb0
NLL: 0.521447479724884 | ECE: 0.03460804491192103
Acc: 0.8698999881744385
Diversity agree: 0.8691666722297668 | disagree: 0.13083332777023315

EEE Acc: 0.8614999651908875 +/- 0.0
EEE ECE: 0.06168382893055683 +/- 0.0
EEE NLL: 0.5832626223564148 +/- 0.0

BBB Acc: 0.8693000078201294 +/- 0.0
BBB ECE: 0.04105125807672736 +/- 0.0
BBB NLL: 0.5279425978660583 +/- 0.0

III Acc: 0.8579999804496765 +/- 0.0
III ECE: 0.043603131787478946 +/- 0.0
III NLL: 0.5802863240242004 +/- 0.0

BEI Acc: 0.8698999881744385 +/- 0.0
BEI ECE: 0.0346080449

In [68]:
ag, dag, kld, std_logits, std = get_div_metrics(b24_out,eq24_out,inv24_out, same_tar)


agree: 0.9224806427955627 | disagree: 0.07751938700675964
Ensemble Variance Logits: 1.2094038724899292
Ensemble Variance: 0.0014868759317323565
KL div: 0.22746145725250244


In [69]:
ag, dag, kld, std_logits, std = get_div_metrics(eq69_out,eq24_out,eq31_out, same_tar)


agree: 0.9501274228096008 | disagree: 0.049872610718011856
Ensemble Variance Logits: 0.9655865430831909
Ensemble Variance: 0.0009998121531680226
KL div: 0.15915538370609283


In [70]:
ag, dag, kld, std_logits, std = get_div_metrics(inv69_out,inv24_out,inv31_out, same_tar)


agree: 0.9201496243476868 | disagree: 0.07985038310289383
Ensemble Variance Logits: 1.0752143859863281
Ensemble Variance: 0.0014482238329946995
KL div: 0.21660903096199036


In [72]:
ag, dag, kld, std_logits, std = get_div_metrics(b69_out,b24_out,b31_out, same_tar)


agree: 0.9288774132728577 | disagree: 0.0711226761341095
Ensemble Variance Logits: 1.0374211072921753
Ensemble Variance: 0.0013705750461667776
KL div: 0.19209891557693481
