In [1]:
import torch
from torchvision import models
from torchvision import datasets, transforms
from datasets import Split_Dataset
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np
import torch.nn.functional as F
from torchvision.datasets import ImageFolder, CIFAR10, CIFAR100
from datasets_v08 import Flowers102
from datasets import INaturalist
import os

# Make GPU 4 the current CUDA device for every later `.cuda()` call.
# NOTE(review): the device index is hard-coded; this fails on hosts with fewer than five GPUs.
torch.cuda.set_device('cuda:4')

# Per-channel mean / std handed to Normalize.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
inat_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
num_classes = 1010

# Evaluation-time preprocessing: Resize(256), CenterCrop(224), ToTensor, then normalization.
inat_val_transforms = transforms.Compose([
        transforms.Resize(224+32),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        inat_norm
    ])

# NOTE(review): absolute, user-specific dataset path; consider a configurable data directory.
inat_dataset = INaturalist('/data/scratch/swhan/data/inat-1k/', version='2019', transform=inat_val_transforms)
# File holding the saved train/val index split (its 'val' entry is used below).
split_idx_path = './misc/inat-1k-train-val-split-idx.pth'

# Load the saved train/val index split (described by the original author as 90 - 10)
# and keep only the validation part of the dataset.
# The split file is only ever read here, never created, so a missing file is reported
# explicitly instead of surfacing as a NameError on `split_idx` one line later.
if not os.path.exists(split_idx_path):
    raise FileNotFoundError(f"train/val split index file not found: {split_idx_path}")
# NOTE(review): torch.load unpickles the file; only load split files from a trusted source.
split_idx = torch.load(split_idx_path)
inat_dataset = Subset(inat_dataset, split_idx['val'])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Loader over the validation subset. Shuffling is off, so every model that is
# rolled out over this loader sees the samples in the same order.
inat_loader = torch.utils.data.DataLoader(
    dataset=inat_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

In [3]:

def load_1_model(ckpt_path, full_path=False, num_classes=1000):
    """Build a ResNet-50 on the current CUDA device and fill it from a checkpoint.

    Args:
        ckpt_path: Run directory name under ``./dist_models`` (its
            ``checkpoint_best.pth`` is read), or the checkpoint file itself
            when ``full_path`` is true.
        full_path: When true, ``ckpt_path`` is used verbatim as the file path.
        num_classes: Size of the classification head of the ResNet-50.

    Returns:
        The model, switched to eval mode, with the checkpoint weights loaded.
    """
    net = models.resnet50(num_classes=num_classes).cuda()
    if full_path:
        ckpt_file = ckpt_path
    else:
        ckpt_file = f"./dist_models/{ckpt_path}/checkpoint_best.pth"
    checkpoint = torch.load(ckpt_file, map_location="cpu")
    # Parameter names may carry a "members.0." fragment; drop it so the keys
    # line up with the plain torchvision ResNet-50 parameter names.
    weights = {}
    for key, value in checkpoint['model'].items():
        weights[key.replace("members.0.", "")] = value
    net.load_state_dict(weights)
    print(f"loaded {ckpt_path}")
    net.eval()
    return net

def rollout_loader(model, loader):
    """Run ``model`` over every batch of ``loader`` and gather the results.

    Each item of ``loader`` is an ``(images, labels)`` pair. Both are moved
    with ``.cuda(non_blocking=True)`` and the forward pass runs under
    ``torch.no_grad()``.

    Returns:
        ``(outputs, targets)``: the per-batch model outputs and the labels,
        each concatenated along dim 0 in loader order.
    """
    gathered_outputs, gathered_targets = [], []
    for images, labels in loader:
        labels = labels.cuda(non_blocking=True)
        images = images.cuda(non_blocking=True)
        with torch.no_grad():
            batch_output = model(images)
        gathered_targets.append(labels)
        gathered_outputs.append(batch_output)
    return torch.cat(gathered_outputs), torch.cat(gathered_targets)

import torch.nn.functional as F
import inspect
from netcal.metrics import ECE

# Two cross-entropy criteria kept at module level.
# NOTE(review): neither `cecriterion` nor `nll_criterion` is referenced in the visible cells.
cecriterion = torch.nn.CrossEntropyLoss().cuda()
nll_criterion = torch.nn.CrossEntropyLoss().cuda()
# ece_criterion = _ECELoss().cuda()
# netcal ECE estimator used by get_metrics.
# NOTE(review): 15 is presumably the number of confidence bins — confirm against netcal's ECE signature.
ece_netcal = ECE(15)

def get_metrics(outs, tars, names, printing=True, input_softmax=False, num_classes=1000):
    """Compute NLL, ECE, accuracy and per-class accuracy for (output, target) pairs.

    Args:
        outs: Sequence of 2-D score tensors (samples x classes): logits, or
            probabilities when ``input_softmax`` is true.
        tars: Sequence of 1-D integer label tensors, one per entry of ``outs``.
        names: Sequence of labels printed before each pair's metrics.
        printing: Print the metrics of every pair when true.
        input_softmax: Set when ``outs`` already holds probabilities.
        num_classes: Number of class columns used for the per-class statistics.

    Returns:
        ``(nll, ece, acc, acc_per_class)`` for the LAST pair only (earlier
        pairs are only printed). ``acc_per_class`` is a tensor with one entry
        per class; classes that do not occur in the targets yield NaN (0/0).

    Raises:
        ValueError: if there is no (output, target, name) triple to evaluate.
    """
    last_result = None
    for out, tar, name in zip(outs, tars, names):
        if input_softmax:
            probs = out
            log_probs = torch.log(out)
        else:
            probs = out.softmax(-1)
            # log_softmax is the numerically stable form of log(softmax(x)); it
            # avoids -inf log-probabilities when a softmax entry underflows to 0.
            log_probs = F.log_softmax(out, dim=-1)

        ece1 = ece_netcal.measure(probs.cpu().numpy(), tar.cpu().numpy())
        loss = F.nll_loss(log_probs, tar)

        _, pred = probs.max(-1)
        correct_vec = (pred == tar)
        # (samples x classes) indicator: entry [n, c] is True when sample n has label c.
        class_ids = torch.arange(num_classes).to(tar.device)
        ind_per_class = (tar.unsqueeze(1) == class_ids)
        correct_per_class = (correct_vec.unsqueeze(1) * ind_per_class).sum(0)
        total_per_class = ind_per_class.sum(0)

        acc = (correct_vec.sum()) / len(tar)
        acc_per_class = correct_per_class / total_per_class
        if printing:
            print(name)
            print(f"NLL: {loss.item()} | ECE: {ece1}")
            print("Acc:", acc.item())
        last_result = (loss.item(), ece1, acc.item(), acc_per_class)

    if last_result is None:
        raise ValueError("get_metrics needs at least one (output, target, name) triple")
    return last_result


In [4]:
class KLD(torch.nn.Module):
    """Summed KL divergence between two batches of logits.

    Both inputs are normalised with ``log_softmax`` over the last dimension and
    handed to ``KLDivLoss(reduction='sum', log_target=True)`` as
    ``(input=p, target=q)``. The result is therefore
    ``sum(softmax(q) * (log_softmax(q) - log_softmax(p)))`` over all elements,
    i.e. KL(q || p) summed over the batch (not averaged).
    """

    def __init__(self):
        super().__init__()
        self.kl = torch.nn.KLDivLoss(reduction='sum', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the summed KL(q || p) for logits ``p`` and ``q`` of equal shape."""
        log_p = F.log_softmax(p, dim=-1)
        log_q = F.log_softmax(q, dim=-1)
        return self.kl(log_p, log_q)


kl_div = KLD()

def compute_pair_consensus(pair_preds, target):
    """Count how two models' predicted labels agree with each other and the target.

    Args:
        pair_preds: Indexable holding the two prediction tensors at [0] and [1].
        target: Ground-truth labels, comparable elementwise with the predictions.

    Returns:
        Six counts, in order: agreeing samples, disagreeing samples, agreeing
        and correct, agreeing and wrong, disagreeing with both wrong,
        disagreeing with exactly one correct.
    """
    first, second = pair_preds[0], pair_preds[1]
    same = first == second
    differ = ~same
    first_ok = first == target
    second_ok = second == target

    n_agree_correct = (same & first_ok).sum()
    n_agree_wrong = (same & ~first_ok).sum()
    n_both_wrong = (differ & ~first_ok & ~second_ok).sum()
    # Exactly one of the two disagreeing models hit the target.
    n_only_second = (differ & ~first_ok & second_ok).sum()
    n_only_first = (differ & ~second_ok & first_ok).sum()
    return (
        same.sum(),
        differ.sum(),
        n_agree_correct,
        n_agree_wrong,
        n_both_wrong,
        n_only_second + n_only_first,
    )

def get_div_metrics(output1,output2,output3,target):
    """Diversity statistics for a three-member ensemble.

    Args:
        output1, output2, output3: Logit tensors of the three members, all of
            the same shape with classes on the last dimension.
        target: Ground-truth labels passed on to compute_pair_consensus.

    Returns:
        ``(agree, disagree, kld, std_logits, std_probs)``. The first three are
        averaged over the three member pairs and divided by ``len(output1)``.
        The last two are the standard deviation across members of the logits
        and of the softmax probabilities, averaged over classes and samples.
    """
    n_samples = len(output1)
    stacked = torch.stack([output1, output2, output3])
    # std across the member axis, then mean over classes, then mean over samples
    avg_std_logits = torch.std(stacked, dim=0).mean(dim=-1).mean()
    avg_std = torch.std(stacked.softmax(-1), dim=0).mean(dim=-1).mean()
    hard_preds = stacked.max(-1)[1]

    pairs = ([0,1], [0,2], [1,2])
    agree_total, disagree_total, kld_total = 0, 0, 0.
    for pair in pairs:
        consensus = compute_pair_consensus(hard_preds[pair,:], target)
        agree_total += consensus[0]
        disagree_total += consensus[1]
        kld_total += kl_div(stacked[pair[0]], stacked[pair[1]])

    n_pairs = len(pairs)
    agree_rate = agree_total / n_pairs / n_samples
    disagree_rate = disagree_total / n_pairs / n_samples
    kld_rate = kld_total / n_pairs / n_samples
    print(f"Diversity agree: {agree_rate} | disagree: {disagree_rate}") 
    return agree_rate, disagree_rate, kld_rate, avg_std_logits, avg_std

In [5]:
## get classwise stats
def get_classwise(acc_base, acc_rotinv, acc_roteq, num_classes=1000):
    """Print, per tie pattern, the share of classes on which each model is best.

    The three arguments are 1-D per-class accuracy tensors of equal length for
    the base (B), rotation-invariant (I) and rotation-equivariant (E) models.
    Every class is sorted into exactly one pattern: all three tied, two tied
    and better, two tied and worse, or all different. Percentages are counts
    divided by ``num_classes / 100``.

    Args:
        acc_base: Per-class accuracy of model B.
        acc_rotinv: Per-class accuracy of model I.
        acc_roteq: Per-class accuracy of model E.
        num_classes: Expected number of classes (length of the inputs).

    Returns:
        None. Results are printed.

    Raises:
        AssertionError: if the patterns do not partition the classes, e.g.
            when an accuracy entry is NaN (comparisons with NaN are False).
    """
    print("use order B, I, E")
    y = torch.stack([acc_base, acc_rotinv, acc_roteq], dim=-1)
    base, inv, eq = y[:, 0], y[:, 1], y[:, 2]

    fac = num_classes/100
    # all 3 models equally good
    best_base_inv_eq = (base == inv) & (inv == eq)
    # 2 models equally good and better than the third
    best_base_inv = (base == inv) & (base > eq)
    best_base_eq = (base == eq) & (base > inv)
    best_inv_eq = (inv == eq) & (inv > base)
    # 2 models equally good and worse than the third
    worse_base_inv = (base == inv) & (base < eq)  # E is best
    worse_base_eq = (base == eq) & (base < inv)  # I is best
    worse_inv_eq = (inv == eq) & (inv < base)  # B is best
    all_diff = (base != inv) & (inv != eq) & (base != eq)
    print(f"all equal best: {(best_base_inv_eq.sum())/fac:.1f}%")
    print(f"B,I equal best: {(best_base_inv.sum())/fac:.1f}%")
    print(f"B,E equal best: {(best_base_eq.sum())/fac:.1f}%")
    print(f"I,E equal best: {(best_inv_eq.sum())/fac:.1f}%")
    total = (
        best_base_inv_eq.sum() + best_base_inv.sum() + best_base_eq.sum()
        + best_inv_eq.sum() + all_diff.sum() + worse_inv_eq.sum()
        + worse_base_eq.sum() + worse_base_inv.sum()
    )
    assert total == num_classes, (
        f"tie patterns cover {int(total)} classes, expected {num_classes}; "
        "check num_classes and look for NaN accuracies"
    )

    # among the all-different classes, exactly one model is the strict winner
    best_base = (base > inv) & (base > eq) & all_diff
    best_inv = (inv > base) & (inv > eq) & all_diff
    best_eq = (eq > base) & (eq > inv) & all_diff
    total_unique = best_base.sum() + best_inv.sum() + best_eq.sum()
    assert total_unique == all_diff.sum(), (
        "some all-different classes have no strict winner; "
        "a possible cause is NaN accuracies (classes without samples)"
    )

    # single model uniquely best: strict winner, or the other two tied below it
    b_uniq = best_base | worse_inv_eq
    i_uniq = best_inv | worse_base_eq
    e_uniq = best_eq | worse_base_inv
    print(f"B uniquely best: {b_uniq.sum()/fac:.1f}%")
    print(f"I uniquely best: {i_uniq.sum()/fac:.1f}%")
    print(f"E uniquely best: {e_uniq.sum()/fac:.1f}%")
    
def get_classwise_ei(acc_rotinv, acc_roteq, num_classes=1000):
    """Print the fraction of classes where I beats E, E beats I, or they tie.

    Args:
        acc_rotinv: 1-D per-class accuracy tensor of model I.
        acc_roteq: 1-D per-class accuracy tensor of model E (same length).
        num_classes: Denominator used to turn the counts into fractions.
    """
    paired = torch.stack([acc_rotinv, acc_roteq], dim=-1)
    inv_col, eq_col = paired[:, 0], paired[:, 1]
    report = (
        ("I uniquely best", inv_col > eq_col),
        ("E uniquely best", inv_col < eq_col),
        ("I and E equal", inv_col == eq_col),
    )
    for label, mask in report:
        print(f"{label}: {mask.sum() / num_classes}")

In [6]:
# INat
# Configuration for the iNat evaluation.
# NOTE(review): `dataset_name` and `lr` are not used in this cell; the checkpoint paths below are hard-coded.
dataset_name = 'inat-1k'
num_classes = 1010
lr = '5.0'
loader = inat_loader

# Load seven checkpoints: one "rot-base", three "roteq" and three "rotinv" runs
# (names as they appear in the checkpoint directory names).
b31 = load_1_model('./checkpoints/inat-rot-base31-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)

e24 = load_1_model('./checkpoints/inat-roteq-seed24-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)
e31 = load_1_model('./checkpoints/inat-roteq-base31-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)
e69 = load_1_model('./checkpoints/inat-roteq-seed69-lp-lr5.0-cosine//checkpoint_best.pth', full_path=True, num_classes=num_classes)

i31 = load_1_model('./checkpoints/inat-rotinv-base31-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)
i24 = load_1_model('./checkpoints/inat-rotinv-seed24-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)
i69 = load_1_model('./checkpoints/inat-rotinv-seed69-lp-lr5.0-cosine/checkpoint_best.pth', full_path=True, num_classes=num_classes)

# Run every model over the same loader; each call returns (outputs, targets).
b31_out_inat, b31_tar = rollout_loader(b31, loader)
e24_out_inat, e24_tar = rollout_loader(e24, loader)
e31_out_inat, e31_tar = rollout_loader(e31, loader)
e69_out_inat, e69_tar = rollout_loader(e69, loader)
i24_out_inat, i24_tar = rollout_loader(i24, loader)
i31_out_inat, i31_tar = rollout_loader(i31, loader)
i69_out_inat, i69_tar = rollout_loader(i69, loader)

# Sanity check that the sample order is stable across rollouts.
# NOTE(review): only the b31 / e24 targets are compared; the other five are not checked.
assert(torch.equal(b31_tar, e24_tar))

# Shared target tensor used by the iNat evaluation cells.
tar_inat = b31_tar

loaded ./checkpoints/inat-rot-base31-lp-lr5.0-cosine/checkpoint_best.pth
loaded ./checkpoints/inat-roteq-seed24-lp-lr5.0-cosine/checkpoint_best.pth
loaded ./checkpoints/inat-roteq-base31-lp-lr5.0-cosine/checkpoint_best.pth
loaded ./checkpoints/inat-roteq-seed69-lp-lr5.0-cosine//checkpoint_best.pth
loaded ./checkpoints/inat-rotinv-base31-lp-lr5.0-cosine/checkpoint_best.pth
loaded ./checkpoints/inat-rotinv-seed24-lp-lr5.0-cosine/checkpoint_best.pth
loaded ./checkpoints/inat-rotinv-seed69-lp-lr5.0-cosine/checkpoint_best.pth


In [7]:
from eval_metrics import get_metrics as get_new_metrics

def ensem_BEI(all_eq,all_base,all_inv,same_tar,num_E=0, num_B=0, num_I=0, num_comb=5, err='std'):
    """Evaluate random ensembles drawn from the E / B / I model pools.

    For each of ``num_comb`` rounds, ``num_E`` / ``num_B`` / ``num_I`` members
    are drawn without replacement from the respective pool, their softmax
    outputs are averaged, and NLL / ECE / accuracy are computed with
    ``get_new_metrics``. The mean and spread over the rounds are printed.

    Args:
        all_eq, all_base, all_inv: Lists of logit tensors, one per model.
        same_tar: Target labels shared by all models.
        num_E, num_B, num_I: Number of members drawn from each pool.
        num_comb: Number of random ensembles to evaluate.
        err: ``'std'`` or ``'var'`` selects the spread printed after ``+/-``.
            Any other value prints only the ensemble label.

    Returns:
        None. Results are printed.

    Raises:
        ValueError: if no members are requested at all, or (from NumPy) if
            more members are requested than a pool contains.
    """
    if num_E + num_B + num_I == 0:
        raise ValueError("ensem_BEI needs at least one ensemble member (num_E + num_B + num_I > 0)")

    def _draw(pool, count):
        # Sample positions rather than the tensors themselves: passing a list
        # of tensors to np.random.choice makes NumPy coerce them into an array
        # first, which is slow and version-dependent. Sampling indices uses the
        # same global NumPy RNG.
        positions = np.random.choice(len(pool), count, replace=False)
        return [pool[pos] for pos in positions]

    ee_nll = []
    ee_ece = []
    ee_acc = []
    for i in range(num_comb):
        members = _draw(all_eq, num_E) + _draw(all_base, num_B) + _draw(all_inv, num_I)
        members = [torch.Tensor(x.cpu()) for x in members]
        # Ensemble prediction = mean of the members' softmax probabilities.
        ee_out = torch.stack(members).softmax(-1).mean(dim=0).cuda()
        # NOTE(review): num_classes is not forwarded; get_new_metrics' default is not visible here.
        nll, ece, acc, _ = get_new_metrics([ee_out],[same_tar],[f'EE_comb{i}'], printing=False, input_softmax=True)
        ee_nll.append(nll)
        ee_ece.append(ece)
        ee_acc.append(acc)

    print("E"*num_E + "B"*num_B + "I"*num_I)
    spread = {'std': np.std, 'var': np.var}.get(err)
    if spread is not None:
        print(f"NLL: {np.mean(ee_nll):.4f} +/- {spread(ee_nll):.4f}")
        print(f"ECE: {np.mean(ee_ece):.4f} +/- {spread(ee_ece):.4f}")
        print(f"Acc: {np.mean(ee_acc):.4f} +/- {spread(ee_acc):.4f}")



In [8]:
# 1 models
# Model pools for the iNat ensembles (outputs moved to CPU).
all_eq_f = [e69_out_inat.cpu(), e31_out_inat.cpu(), e24_out_inat.cpu()]
# Only one base model exists, so its output is listed twice to give the pool more than one entry.
all_base_f = [b31_out_inat.cpu(), b31_out_inat.cpu()] #duplicated to prevent err
# Three distinct inv models; nothing is duplicated in this pool despite the original remark.
all_inv_f = [i24_out_inat.cpu(), i31_out_inat.cpu(), i69_out_inat.cpu()] #duplicated to prevent err

# Single-model "ensembles": three random draws from each pool.
# The draws are not seeded, so the reported spread varies between runs.
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=3)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_B=1, num_comb=3)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_I=1, num_comb=3)

  
  
  if __name__ == "__main__":
  if __name__ == "__main__":
  # Remove the CWD from sys.path while we load stuff.
  # Remove the CWD from sys.path while we load stuff.


E
NLL: 1.8833 +/- 0.0060
ECE: 0.0450 +/- 0.0024
Acc: 0.5504 +/- 0.0019
B
NLL: 1.8776 +/- 0.0000
ECE: 0.0308 +/- 0.0000
Acc: 0.5489 +/- 0.0000
I
NLL: 1.8389 +/- 0.0112
ECE: 0.0435 +/- 0.0013
Acc: 0.5633 +/- 0.0021


In [9]:
# 2 models
# Two-member ensembles (num_comb is left at its default of 5 random draws).
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=2)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_I=1)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_I=2)

# 3 models
# With num_E=3 or num_I=3 every draw selects the whole three-model pool,
# so those rows always contain the same three members.
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=3)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=2, num_I=1)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_I=2)
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat, num_I=3)

  
  
  if __name__ == "__main__":
  if __name__ == "__main__":
  # Remove the CWD from sys.path while we load stuff.
  # Remove the CWD from sys.path while we load stuff.


EE
NLL: 1.7280 +/- 0.0041
ECE: 0.0256 +/- 0.0022
Acc: 0.5834 +/- 0.0007
EI
NLL: 1.6762 +/- 0.0078
ECE: 0.0376 +/- 0.0013
Acc: 0.5973 +/- 0.0028
II
NLL: 1.6951 +/- 0.0041
ECE: 0.0208 +/- 0.0010
Acc: 0.5937 +/- 0.0006
EEE
NLL: 1.6726 +/- 0.0000
ECE: 0.0494 +/- 0.0000
Acc: 0.5976 +/- 0.0000
EEI
NLL: 1.6237 +/- 0.0034
ECE: 0.0620 +/- 0.0015
Acc: 0.6098 +/- 0.0008
EII
NLL: 1.6184 +/- 0.0050
ECE: 0.0599 +/- 0.0017
Acc: 0.6117 +/- 0.0015
III
NLL: 1.6410 +/- 0.0000
ECE: 0.0419 +/- 0.0000
Acc: 0.6061 +/- 0.0000


In [None]:
# 1 models
# Per-model metrics: each model's output is placed (twice) in the "E" pool and
# evaluated as a one-member ensemble.
# NOTE(review): this overwrites `all_eq_f`; the last three assignments put inv-model
# outputs into it, and the printed label is "E" for all six calls. Cells that read
# `all_eq_f` afterwards see whatever was assigned last.
all_eq_f = [e69_out_inat.cpu(), e69_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

all_eq_f = [e31_out_inat.cpu(),e31_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

all_eq_f = [e24_out_inat.cpu(), e24_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

all_eq_f = [i31_out_inat.cpu(), i31_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

all_eq_f = [i24_out_inat.cpu(), i24_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

all_eq_f = [i69_out_inat.cpu(), i69_out_inat.cpu()]
ensem_BEI(all_eq_f, all_base_f, all_inv_f,tar_inat,num_E=1, num_comb=1)

In [10]:
# Class-wise comparison of one randomly drawn E model against one randomly drawn I model.
# Targets are moved to CPU because the pooled outputs are CPU tensors.
same_tar = b31_tar.cpu()
num_comb = 1
# all_eq_exR = [eq24_out, eq69_out, eq31_out]
# all_base_exR = [b24_out, b24_out, b31_out]
# all_inv_exR = [inv24_out, inv69_out, inv31_out]
# NOTE(review): the pools are whatever `all_eq_f` / `all_inv_f` currently hold, so the
# result depends on which earlier cells were run last.
all_eq_exR = all_eq_f
all_base_exR = all_base_f
all_inv_exR = all_inv_f
for i in range(num_comb):
    # NOTE(review): np.random.choice is given a list of tensors and is unseeded;
    # the chosen models are not recorded anywhere.
    [eq1] = np.random.choice(all_eq_exR, 1)
    # [b1] = np.random.choice(all_base_exR, 1)
    [inv1] = np.random.choice(all_inv_exR, 1) 
    # _,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
    # Only the per-class accuracy (4th return value) is kept from each call.
    _,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
    _,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
    get_classwise_ei(acc_pc_inv, acc_pc_eq, num_classes=num_classes)
    


  # Remove the CWD from sys.path while we load stuff.
  # Remove the CWD from sys.path while we load stuff.
  if sys.path[0] == "":
  if sys.path[0] == "":


eq
NLL: 1.8738282918930054 | ECE: 0.04541622946110569
Acc: 0.5531028509140015
inv
NLL: 1.8309305906295776 | ECE: 0.04260874650000008
Acc: 0.5648025274276733
I uniquely best: 0.42178216576576233
E uniquely best: 0.3316831588745117
I and E equal: 0.24653466045856476


In [None]:
# CIFAR-100
# CIFAR-100: load base / eq / inv models for seeds 69, 24 and 31, roll each one out
# over the loader and print its metrics against the targets of the first rollout.
# NOTE(review): `c100_loader` is not defined in any visible cell; on a fresh kernel
# this cell raises NameError unless the loader is created elsewhere.
# NOTE(review): re-binds `num_classes`, `lr`, `loader`, `same_tar` and `b31`, which
# the iNat cells also use.
dataset_name = 'cifar100'
num_classes = 100
lr = '0.2'
loader = c100_loader

b69 = load_1_model(f"trans_{dataset_name}_base69_cos_lr{lr}_bs256", num_classes=num_classes)
b69_out, b69_tar = rollout_loader(b69, loader)
same_tar = b69_tar
get_metrics([b69_out], [same_tar],['b69'], num_classes=num_classes)
eq69 = load_1_model(f"trans_{dataset_name}_eq69_cos_lr{lr}_bs256", num_classes=num_classes)
eq69_out, eq69_tar = rollout_loader(eq69, loader)
get_metrics([eq69_out], [same_tar],['eq69'], num_classes=num_classes)

inv69 = load_1_model(f"trans_{dataset_name}_inv69_cos_lr{lr}_bs256", num_classes=num_classes)
inv69_out, inv69_tar = rollout_loader(inv69, loader)
get_metrics([inv69_out], [same_tar],['inv69'], num_classes=num_classes)

# Only the b69 / eq69 target order is verified; later rollouts discard their targets.
assert(torch.equal(b69_tar,eq69_tar))

b24 = load_1_model(f"trans_{dataset_name}_base24_cos_lr{lr}_bs256", num_classes=num_classes)
b24_out, _ = rollout_loader(b24, loader)
get_metrics([b24_out], [same_tar],['b24'], num_classes=num_classes)

eq24 = load_1_model(f"trans_{dataset_name}_eq24_cos_lr{lr}_bs256", num_classes=num_classes)
eq24_out, _ = rollout_loader(eq24, loader)
get_metrics([eq24_out], [same_tar],['eq24'], num_classes=num_classes)

inv24 = load_1_model(f"trans_{dataset_name}_inv24_cos_lr{lr}_bs256", num_classes=num_classes)
inv24_out, _ = rollout_loader(inv24, loader)
get_metrics([inv24_out], [same_tar],['inv24'], num_classes=num_classes)


b31 = load_1_model(f"trans_{dataset_name}_base31_cos_lr{lr}_bs256", num_classes=num_classes)
b31_out, _ = rollout_loader(b31, loader)
get_metrics([b31_out], [same_tar],['b31'], num_classes=num_classes)

eq31 = load_1_model(f"trans_{dataset_name}_eq31_cos_lr{lr}_bs256", num_classes=num_classes)
eq31_out, _ = rollout_loader(eq31, loader)
get_metrics([eq31_out], [same_tar],['eq31'], num_classes=num_classes)

inv31 = load_1_model(f"trans_{dataset_name}_inv31_cos_lr{lr}_bs256", num_classes=num_classes)
inv31_out, _ = rollout_loader(inv31, loader)
get_metrics([inv31_out], [same_tar],['inv31'], num_classes=num_classes)



In [None]:
# Additional CIFAR-100 "eq" model (seed 42), evaluated against its own rollout targets.
# NOTE(review): `c100_loader` is not defined in any visible cell.
dataset_name = 'cifar100'
num_classes = 100
lr = '0.2'
loader = c100_loader

eq42 = load_1_model(f"trans_{dataset_name}_eq42_cos_lr{lr}_bs256", num_classes=num_classes)
eq42_out, eq42_tar = rollout_loader(eq42, loader)
get_metrics([eq42_out], [eq42_tar],['eq42'], num_classes=num_classes)


In [None]:
# CIFAR-100 ensembles. The E pool has four models, the B and I pools three each.
# Requires the outputs produced by the two CIFAR-100 loading cells.
same_tar = eq42_tar
all_eq = [eq69_out, eq24_out, eq31_out, eq42_out]
all_base = [b69_out, b24_out, b31_out]
all_inv = [inv69_out, inv24_out, inv31_out]

# 1 models
ensem_BEI(all_eq, all_base, all_inv, same_tar, num_E=1)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_B=1)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_I=1)

# 2 models
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=2)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=1, num_I=1)

# 3 models
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=3)
ensem_BEI(all_eq, all_base, all_inv, same_tar,num_E=2, num_I=1)


In [None]:
# Scratch arithmetic: mean of two hand-entered values.
# NOTE(review): the origin of these numbers is not recorded in the notebook.
(91.2+91.9)/2

In [None]:
# Scratch arithmetic: mean of three hand-entered values.
# NOTE(review): the origin of these numbers is not recorded in the notebook.
(91.2+91.9+91.9)/3

In [None]:
# Scratch arithmetic: a hand-entered value minus the mean of three others.
# NOTE(review): the origin of these numbers is not recorded in the notebook.
87.03-(85.5+84.0+85.5)/3

In [None]:
# Scratch arithmetic: a hand-entered value minus the mean of two others.
# NOTE(review): the origin of these numbers is not recorded in the notebook.
86.54-(85.5+84.0)/2

In [None]:
# Class-wise comparison of one hand-picked model per family.
b1 = b24_out
eq1 = eq42_out
inv1 = inv69_out

# Only the per-class accuracy (4th return value) is kept from each call.
_,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
_,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
_,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
# Argument order is (B, I, E), matching the order get_classwise prints.
get_classwise(acc_pc_b, acc_pc_inv, acc_pc_eq, num_classes=num_classes)
  

In [None]:
# Class-wise comparison with one randomly drawn model per family.
same_tar = b69_tar
num_comb = 1
all_eq_exR = [eq24_out, eq69_out, eq31_out]
all_base_exR = [b24_out, b69_out, b31_out]
all_inv_exR = [inv24_out, inv69_out, inv31_out]
for i in range(num_comb):
    # NOTE(review): np.random.choice is given lists of tensors and is unseeded;
    # the chosen models are not recorded anywhere.
    [eq1] = np.random.choice(all_eq_exR, 1)
    [b1] = np.random.choice(all_base_exR, 1)
    [inv1] = np.random.choice(all_inv_exR, 1) 
    _,_,_,acc_pc_b = get_metrics([b1], [same_tar],['base'], num_classes=num_classes)
    _,_,_,acc_pc_eq = get_metrics([eq1], [same_tar],['eq'], num_classes=num_classes)
    _,_,_,acc_pc_inv = get_metrics([inv1], [same_tar],['inv'], num_classes=num_classes)
    # Argument order is (B, I, E), matching the order get_classwise prints.
    get_classwise(acc_pc_b, acc_pc_inv, acc_pc_eq, num_classes=num_classes)
    


In [None]:
# Three-member ensembles: EEE, BBB, III (members drawn from the *_exR pools) and one
# fixed BEI combination. For each, the averaged softmax output is scored with
# get_metrics and the members' diversity is printed by get_div_metrics.
# Depends on `same_tar`, `num_classes` and the *_exR pools from the previous cell.
# NOTE(review): the four loops are near-identical copies; the values returned by
# get_div_metrics (ag, dag, kld, std_logits, std) are overwritten and never stored.
num_comb =1 
eee_nll = []
eee_ece = []
eee_acc = []
for i in range(num_comb):
    [eq1, eq2, eq3] = np.random.choice(all_eq_exR, 3, replace=False)
    # Ensemble prediction = mean of the three members' softmax probabilities.
    eee_out = (eq1.softmax(-1) + eq2.softmax(-1) + eq3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([eee_out],[same_tar],[f'EEE_comb{i}'], input_softmax=True, num_classes=num_classes)    
    ag, dag, kld, std_logits, std = get_div_metrics(eq1,eq2,eq3, same_tar)

    eee_nll.append(nll)
    eee_ece.append(ece)  
    eee_acc.append(acc)   

bbb_nll = []
bbb_ece = []
bbb_acc = []
for i in range(num_comb):
    [b1, b2, b3] = np.random.choice(all_base_exR, 3, replace=False)
    bbb_out = (b1.softmax(-1) + b2.softmax(-1) + b3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([bbb_out],[same_tar],[f'BBB_comb{i}'], input_softmax=True, num_classes=num_classes)      
    ag, dag, kld, std_logits, std = get_div_metrics(b1,b2,b3, same_tar)

    bbb_nll.append(nll)
    bbb_ece.append(ece)  
    bbb_acc.append(acc) 

iii_nll = []
iii_ece = []
iii_acc = []
for i in range(num_comb):
    [i1, i2, i3] = np.random.choice(all_inv_exR, 3, replace=False)
    iii_out = (i1.softmax(-1) + i2.softmax(-1) + i3.softmax(-1))/3
    nll, ece, acc, _ = get_metrics([iii_out],[same_tar],[f'III_comb{i}'], input_softmax=True, num_classes=num_classes)      
    ag, dag, kld, std_logits, std = get_div_metrics(i1,i2,i3, same_tar)
    
    iii_nll.append(nll)
    iii_ece.append(ece)  
    iii_acc.append(acc) 
    
bei_nll = []
bei_ece = []
bei_acc = []
for i in range(num_comb):
#     [eq1] = np.random.choice(all_eq_exR, 1)
#     [base1] = np.random.choice(all_base_exR, 1)
#     [inv1] = np.random.choice(all_inv_exR, 1)
    # Each pool below has a single entry, so this combination is fixed: eq31 + b31 + inv24.
    [eq1] = np.random.choice([eq31_out], 1)
    [base1] = np.random.choice([b31_out], 1)
    [inv1] = np.random.choice([inv24_out], 1)
    bei_out = (eq1.softmax(-1) + base1.softmax(-1) + inv1.softmax(-1))/3
    
    nll, ece, acc, _ = get_metrics([bei_out],[same_tar],[f'BEI_comb{i}'], input_softmax=True, num_classes=num_classes)     
    ag, dag, kld, std_logits, std = get_div_metrics(eq1,base1,inv1, same_tar)
    
    bei_nll.append(nll)
    bei_ece.append(ece)  
    bei_acc.append(acc) 

# Summary: mean and standard deviation over the num_comb rounds of each ensemble type.
print(f"\nEEE Acc: {np.mean(eee_acc)} +/- {np.std(eee_acc)}")
print(f"EEE ECE: {np.mean(eee_ece)} +/- {np.std(eee_ece)}")
print(f"EEE NLL: {np.mean(eee_nll)} +/- {np.std(eee_nll)}")

print(f"\nBBB Acc: {np.mean(bbb_acc)} +/- {np.std(bbb_acc)}")
print(f"BBB ECE: {np.mean(bbb_ece)} +/- {np.std(bbb_ece)}")
print(f"BBB NLL: {np.mean(bbb_nll)} +/- {np.std(bbb_nll)}")

print(f"\nIII Acc: {np.mean(iii_acc)} +/- {np.std(iii_acc)}")
print(f"III ECE: {np.mean(iii_ece)} +/- {np.std(iii_ece)}")
print(f"III NLL: {np.mean(iii_nll)} +/- {np.std(iii_nll)}")

print(f"\nBEI Acc: {np.mean(bei_acc)} +/- {np.std(bei_acc)}")
print(f"BEI ECE: {np.mean(bei_ece)} +/- {np.std(bei_ece)}")
print(f"BEI NLL: {np.mean(bei_nll)} +/- {np.std(bei_nll)}")


In [None]:
# Diversity of a mixed trio: the base, eq and inv models of seed 24.
ag, dag, kld, std_logits, std = get_div_metrics(b24_out,eq24_out,inv24_out, same_tar)


In [None]:
# Diversity among the three eq models (seeds 69, 24, 31).
ag, dag, kld, std_logits, std = get_div_metrics(eq69_out,eq24_out,eq31_out, same_tar)


In [None]:
# Diversity among the three inv models (seeds 69, 24, 31).
ag, dag, kld, std_logits, std = get_div_metrics(inv69_out,inv24_out,inv31_out, same_tar)


In [None]:
# Diversity among the three base models (seeds 69, 24, 31).
ag, dag, kld, std_logits, std = get_div_metrics(b69_out,b24_out,b31_out, same_tar)
