In [1]:
import os
import sys
import torch
from tabulate import tabulate
from IPython.display import display, HTML
sys.path.append("../")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


from giung2.config import get_cfg
from giung2.data.build import build_dataloaders
from giung2.modeling.build import build_model
from giung2.evaluation import (
    evaluate_acc, evaluate_nll, evaluate_bs, evaluate_ece,
    get_optimal_temperature,
)

In [2]:
def get_de_predictions(model, dataloader, ensemble_size, get_ith_weight_file):
    """Collect per-member logits of a deep ensemble over a dataloader.

    A single ``model`` instance is reused: for every batch, each member's
    weights are loaded into it in turn and the batch is forwarded once per
    member. The dataloader is iterated exactly once.

    Args:
        model: module supporting ``load_state_dict`` whose forward pass
            returns a mapping with a ``"logits"`` entry (indexed here as a
            2-D tensor).
        dataloader: iterable yielding ``(images, labels)`` pairs; ``images``
            is moved to the GPU with ``.cuda()``.
        ensemble_size: number of ensemble members, must be at least 1.
        get_ith_weight_file: callable mapping a member index to the path of
            a checkpoint file that holds a ``"model_state_dict"`` entry.

    Returns:
        Tuple ``(pred_logits, true_labels)`` of CPU tensors: the logits with
        a member axis inserted at dim 1, i.e. ``[num_examples, ens_size,
        num_classes]``, and the labels concatenated over all batches.

    Raises:
        ValueError: if ``ensemble_size`` is smaller than 1.
    """
    if ensemble_size < 1:
        raise ValueError(f"ensemble_size must be >= 1, got {ensemble_size}")

    # Read every checkpoint from disk exactly once. Re-reading them inside
    # the batch loop costs num_batches * ensemble_size file loads. The state
    # dicts stay on the CPU, so this trades host memory for disk I/O.
    member_states = [
        torch.load(get_ith_weight_file(idx), map_location="cpu")["model_state_dict"]
        for idx in range(ensemble_size)
    ]

    pred_logits = []  # per batch: [batch_size, ens_size, num_classes]
    true_labels = []  # per batch: labels as yielded by the dataloader

    # Pure inference: do not rely on the caller having disabled autograd,
    # otherwise every stored logit tensor would keep its graph alive.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.cuda()
            member_logits = []
            for state in member_states:
                model.load_state_dict(state)
                # [:, None, :] inserts the member axis at dim 1.
                member_logits.append(model(images)["logits"][:, None, :].cpu())
            pred_logits.append(torch.cat(member_logits, dim=1))
            true_labels.append(labels.cpu())
    return torch.cat(pred_logits), torch.cat(true_labels)


def get_be_predictions(model, dataloader, ensemble_size):
    """Collect per-member logits of a batch-ensemble model over a dataloader.

    Each batch is tiled ``ensemble_size`` times along dim 0 and forwarded in
    a single pass; the output rows are then split back into ``ensemble_size``
    equal chunks that become the member axis.
    NOTE(review): presumably the model routes the e-th tiled copy through its
    e-th ensemble member - confirm against the model implementation.

    Args:
        model: module whose forward pass returns a mapping with a
            ``"logits"`` entry.
        dataloader: iterable yielding ``(images, labels)`` pairs; ``images``
            must be 4-D (it is tiled with ``repeat(ensemble_size, 1, 1, 1)``)
            and is moved to the GPU with ``.cuda()``.
        ensemble_size: number of ensemble members, must be at least 1.

    Returns:
        Tuple ``(pred_logits, true_labels)`` of CPU tensors: the logits with
        the member axis at dim 1, i.e. ``[num_examples, ens_size, ...]``, and
        the labels concatenated over all batches.

    Raises:
        ValueError: if ``ensemble_size`` is smaller than 1.
    """
    if ensemble_size < 1:
        raise ValueError(f"ensemble_size must be >= 1, got {ensemble_size}")

    pred_logits = []  # per batch: [batch_size, ens_size, num_classes]
    true_labels = []  # per batch: labels as yielded by the dataloader

    # Pure inference: do not rely on the caller having disabled autograd,
    # otherwise every stored logit tensor would keep its graph alive.
    with torch.no_grad():
        for images, labels in dataloader:
            # Tiled layout: copy e occupies rows [e * B, (e + 1) * B).
            images = images.cuda().repeat(ensemble_size, 1, 1, 1)
            logits = model(images)["logits"]
            # Undo the tiling: chunk e holds copy e; stacking the chunks on
            # dim 1 yields [B, ens_size, ...].
            logits = torch.stack(
                torch.split(logits, logits.size(0) // ensemble_size), dim=1
            )
            pred_logits.append(logits.cpu())
            true_labels.append(labels.cpu())
    return torch.cat(pred_logits), torch.cat(true_labels)

# WRN28x10 on CIFAR-100

In [3]:
# Rows of the results table; each evaluation cell below appends to this list
# and then prints the whole table.
# NOTE(review): re-running an evaluation cell without re-running this cell
# first appends duplicate rows.
DATA = []

## DeepEns-4

In [4]:
ensemble_size = 4

# WRN28x10 / CIFAR-100 configuration, restricted to a single GPU.
cfg = get_cfg()
cfg.merge_from_file("../configs/C100_WRN28x10_SGD.yaml", allow_unsafe=True)
cfg.NUM_GPUS = 1

# One network instance in eval mode; get_de_predictions swaps each member's
# weights into it.
model = build_model(cfg).cuda().eval()

dataloaders = build_dataloaders(cfg, root="../datasets")


def get_ith_weight_file(idx):
    """Return the checkpoint path for ensemble member ``idx``."""
    return os.path.join(f"../outputs/C100_WRN28x10_SGD_{idx}", "best_acc1.pth.tar")


# Inference only: autograd is switched off globally (stays off afterwards).
torch.set_grad_enabled(False)

# Validation split: per-member logits -> per-member class probabilities.
val_pred_logits, val_true_labels = get_de_predictions(
    model, dataloaders["val_loader"], ensemble_size, get_ith_weight_file,
)
val_confidences = torch.softmax(val_pred_logits, dim=2)

# Test split: same procedure.
tst_pred_logits, tst_true_labels = get_de_predictions(
    model, dataloaders["tst_loader"], ensemble_size, get_ith_weight_file,
)
tst_confidences = torch.softmax(tst_pred_logits, dim=2)

# One table row per ensemble prefix (first 1, 2, ..., ensemble_size members).
for e in range(ensemble_size):
    # Ensemble prediction = mean of the first e+1 members' probabilities.
    val_mean = val_confidences[:, :e+1, :].mean(1)
    tst_mean = tst_confidences[:, :e+1, :].mean(1)

    # Temperature is fitted on the validation split, then applied to test
    # log-probabilities, which are renormalised via log_softmax.
    t_opt = get_optimal_temperature(val_mean, val_true_labels)
    tst_mean_cal = torch.log_softmax(tst_mean.log() / t_opt, dim=1).exp()

    row = [f"DeepEns-{e+1}", evaluate_acc(tst_mean, tst_true_labels) * 100]
    for probs in (tst_mean, tst_mean_cal):
        row.append(evaluate_nll(probs, tst_true_labels))
        row.append(evaluate_bs(probs, tst_true_labels))
        row.append(evaluate_ece(probs, tst_true_labels))
    DATA.append(row)

table_headers = ["Label", "ACC", "NLL", "BS", "ECE", "cNLL", "cBS", "cECE",]
table_floatfmt = ["", ".2f",] + [".3f"] * 6
print(tabulate(DATA, headers=table_headers, floatfmt=table_floatfmt))
print()

Label        ACC    NLL     BS    ECE    cNLL    cBS    cECE
---------  -----  -----  -----  -----  ------  -----  ------
DeepEns-1  80.22  0.789  0.282  0.042   0.789  0.282   0.041
DeepEns-2  81.90  0.713  0.261  0.033   0.708  0.260   0.031
DeepEns-3  82.46  0.684  0.253  0.032   0.673  0.251   0.027
DeepEns-4  82.54  0.670  0.249  0.033   0.655  0.246   0.026



## BatchEns-4

In [5]:
ensemble_size = 4

# BatchEnsemble (BE4) WRN28x10 / CIFAR-100 configuration, single GPU.
cfg = get_cfg()
cfg.merge_from_file("../configs/C100_WRN28x10_BE4.yaml", allow_unsafe=True)
cfg.NUM_GPUS = 1

# Build the network in eval mode and restore the "KD" run's checkpoint.
model = build_model(cfg).cuda().eval()
checkpoint = torch.load("../outputs/C100_WRN28x10_BE4_KD_0/best_acc1.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

dataloaders = build_dataloaders(cfg, root="../datasets")

# Inference only: autograd is switched off globally (stays off afterwards).
torch.set_grad_enabled(False)

# Validation split: per-member probabilities, used to fit the temperature
# on the member-averaged prediction.
val_pred_logits, val_true_labels = get_be_predictions(model, dataloaders["val_loader"], ensemble_size)
val_confidences = torch.softmax(val_pred_logits, dim=2)
t_opt = get_optimal_temperature(val_confidences.mean(1), val_true_labels)

# Test split: per-member probabilities.
tst_pred_logits, tst_true_labels = get_be_predictions(model, dataloaders["tst_loader"], ensemble_size)
tst_confidences = torch.softmax(tst_pred_logits, dim=2)

# Member-averaged test probabilities, raw and temperature-scaled
# (log-probabilities divided by t_opt, renormalised via log_softmax).
tst_mean = tst_confidences.mean(1)
tst_mean_cal = torch.log_softmax(tst_mean.log() / t_opt, dim=1).exp()

row = ["BatchEns-4 (KD)", evaluate_acc(tst_mean, tst_true_labels) * 100]
for probs in (tst_mean, tst_mean_cal):
    row.append(evaluate_nll(probs, tst_true_labels))
    row.append(evaluate_bs(probs, tst_true_labels))
    row.append(evaluate_ece(probs, tst_true_labels))
DATA.append(row)

table_headers = ["Label", "ACC", "NLL", "BS", "ECE", "cNLL", "cBS", "cECE",]
table_floatfmt = ["", ".2f",] + [".3f"] * 6
print(tabulate(DATA, headers=table_headers, floatfmt=table_floatfmt))
print()

Label              ACC    NLL     BS    ECE    cNLL    cBS    cECE
---------------  -----  -----  -----  -----  ------  -----  ------
DeepEns-1        80.22  0.789  0.282  0.042   0.789  0.282   0.041
DeepEns-2        81.90  0.713  0.261  0.033   0.708  0.260   0.031
DeepEns-3        82.46  0.684  0.253  0.032   0.673  0.251   0.027
DeepEns-4        82.54  0.670  0.249  0.033   0.655  0.246   0.026
BatchEns-4 (KD)  80.40  0.804  0.286  0.072   0.750  0.277   0.021



In [6]:
ensemble_size = 4

# BatchEnsemble (BE4) WRN28x10 / CIFAR-100 configuration, single GPU.
cfg = get_cfg()
cfg.merge_from_file("../configs/C100_WRN28x10_BE4.yaml", allow_unsafe=True)
cfg.NUM_GPUS = 1

# Build the network in eval mode and restore the "KDGaussian" run's checkpoint.
model = build_model(cfg).cuda().eval()
checkpoint = torch.load("../outputs/C100_WRN28x10_BE4_KDGaussian_0/best_acc1.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

dataloaders = build_dataloaders(cfg, root="../datasets")

# Inference only: autograd is switched off globally (stays off afterwards).
torch.set_grad_enabled(False)

# Validation split: per-member probabilities, used to fit the temperature
# on the member-averaged prediction.
val_pred_logits, val_true_labels = get_be_predictions(model, dataloaders["val_loader"], ensemble_size)
val_confidences = torch.softmax(val_pred_logits, dim=2)
t_opt = get_optimal_temperature(val_confidences.mean(1), val_true_labels)

# Test split: per-member probabilities.
tst_pred_logits, tst_true_labels = get_be_predictions(model, dataloaders["tst_loader"], ensemble_size)
tst_confidences = torch.softmax(tst_pred_logits, dim=2)

# Member-averaged test probabilities, raw and temperature-scaled
# (log-probabilities divided by t_opt, renormalised via log_softmax).
tst_mean = tst_confidences.mean(1)
tst_mean_cal = torch.log_softmax(tst_mean.log() / t_opt, dim=1).exp()

row = ["BatchEns-4 (KD + Gaussian)", evaluate_acc(tst_mean, tst_true_labels) * 100]
for probs in (tst_mean, tst_mean_cal):
    row.append(evaluate_nll(probs, tst_true_labels))
    row.append(evaluate_bs(probs, tst_true_labels))
    row.append(evaluate_ece(probs, tst_true_labels))
DATA.append(row)

table_headers = ["Label", "ACC", "NLL", "BS", "ECE", "cNLL", "cBS", "cECE",]
table_floatfmt = ["", ".2f",] + [".3f"] * 6
print(tabulate(DATA, headers=table_headers, floatfmt=table_floatfmt))
print()

Label                         ACC    NLL     BS    ECE    cNLL    cBS    cECE
--------------------------  -----  -----  -----  -----  ------  -----  ------
DeepEns-1                   80.22  0.789  0.282  0.042   0.789  0.282   0.041
DeepEns-2                   81.90  0.713  0.261  0.033   0.708  0.260   0.031
DeepEns-3                   82.46  0.684  0.253  0.032   0.673  0.251   0.027
DeepEns-4                   82.54  0.670  0.249  0.033   0.655  0.246   0.026
BatchEns-4 (KD)             80.40  0.804  0.286  0.072   0.750  0.277   0.021
BatchEns-4 (KD + Gaussian)  80.04  0.816  0.288  0.075   0.760  0.277   0.020



In [7]:
ensemble_size = 4

# BatchEnsemble (BE4) WRN28x10 / CIFAR-100 configuration, single GPU.
cfg = get_cfg()
cfg.merge_from_file("../configs/C100_WRN28x10_BE4.yaml", allow_unsafe=True)
cfg.NUM_GPUS = 1

# Build the network in eval mode and restore the "KDODS" run's checkpoint.
model = build_model(cfg).cuda().eval()
checkpoint = torch.load("../outputs/C100_WRN28x10_BE4_KDODS_0/best_acc1.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

dataloaders = build_dataloaders(cfg, root="../datasets")

# Inference only: autograd is switched off globally (stays off afterwards).
torch.set_grad_enabled(False)

# Validation split: per-member probabilities, used to fit the temperature
# on the member-averaged prediction.
val_pred_logits, val_true_labels = get_be_predictions(model, dataloaders["val_loader"], ensemble_size)
val_confidences = torch.softmax(val_pred_logits, dim=2)
t_opt = get_optimal_temperature(val_confidences.mean(1), val_true_labels)

# Test split: per-member probabilities.
tst_pred_logits, tst_true_labels = get_be_predictions(model, dataloaders["tst_loader"], ensemble_size)
tst_confidences = torch.softmax(tst_pred_logits, dim=2)

# Member-averaged test probabilities, raw and temperature-scaled
# (log-probabilities divided by t_opt, renormalised via log_softmax).
tst_mean = tst_confidences.mean(1)
tst_mean_cal = torch.log_softmax(tst_mean.log() / t_opt, dim=1).exp()

row = ["BatchEns-4 (KD + ODS)", evaluate_acc(tst_mean, tst_true_labels) * 100]
for probs in (tst_mean, tst_mean_cal):
    row.append(evaluate_nll(probs, tst_true_labels))
    row.append(evaluate_bs(probs, tst_true_labels))
    row.append(evaluate_ece(probs, tst_true_labels))
DATA.append(row)

table_headers = ["Label", "ACC", "NLL", "BS", "ECE", "cNLL", "cBS", "cECE",]
table_floatfmt = ["", ".2f",] + [".3f"] * 6
print(tabulate(DATA, headers=table_headers, floatfmt=table_floatfmt))
print()

Label                         ACC    NLL     BS    ECE    cNLL    cBS    cECE
--------------------------  -----  -----  -----  -----  ------  -----  ------
DeepEns-1                   80.22  0.789  0.282  0.042   0.789  0.282   0.041
DeepEns-2                   81.90  0.713  0.261  0.033   0.708  0.260   0.031
DeepEns-3                   82.46  0.684  0.253  0.032   0.673  0.251   0.027
DeepEns-4                   82.54  0.670  0.249  0.033   0.655  0.246   0.026
BatchEns-4 (KD)             80.40  0.804  0.286  0.072   0.750  0.277   0.021
BatchEns-4 (KD + Gaussian)  80.04  0.816  0.288  0.075   0.760  0.277   0.020
BatchEns-4 (KD + ODS)       81.92  0.685  0.258  0.026   0.682  0.258   0.026



In [8]:
ensemble_size = 4

# BatchEnsemble (BE4) WRN28x10 / CIFAR-100 configuration, single GPU.
cfg = get_cfg()
cfg.merge_from_file("../configs/C100_WRN28x10_BE4.yaml", allow_unsafe=True)
cfg.NUM_GPUS = 1

# Build the network in eval mode and restore the "KDConfODS" run's checkpoint.
model = build_model(cfg).cuda().eval()
checkpoint = torch.load("../outputs/C100_WRN28x10_BE4_KDConfODS_0/best_acc1.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

dataloaders = build_dataloaders(cfg, root="../datasets")

# Inference only: autograd is switched off globally (stays off afterwards).
torch.set_grad_enabled(False)

# Validation split: per-member probabilities, used to fit the temperature
# on the member-averaged prediction.
val_pred_logits, val_true_labels = get_be_predictions(model, dataloaders["val_loader"], ensemble_size)
val_confidences = torch.softmax(val_pred_logits, dim=2)
t_opt = get_optimal_temperature(val_confidences.mean(1), val_true_labels)

# Test split: per-member probabilities.
tst_pred_logits, tst_true_labels = get_be_predictions(model, dataloaders["tst_loader"], ensemble_size)
tst_confidences = torch.softmax(tst_pred_logits, dim=2)

# Member-averaged test probabilities, raw and temperature-scaled
# (log-probabilities divided by t_opt, renormalised via log_softmax).
tst_mean = tst_confidences.mean(1)
tst_mean_cal = torch.log_softmax(tst_mean.log() / t_opt, dim=1).exp()

row = ["BatchEns-4 (KD + ConfODS)", evaluate_acc(tst_mean, tst_true_labels) * 100]
for probs in (tst_mean, tst_mean_cal):
    row.append(evaluate_nll(probs, tst_true_labels))
    row.append(evaluate_bs(probs, tst_true_labels))
    row.append(evaluate_ece(probs, tst_true_labels))
DATA.append(row)

table_headers = ["Label", "ACC", "NLL", "BS", "ECE", "cNLL", "cBS", "cECE",]
table_floatfmt = ["", ".2f",] + [".3f"] * 6
print(tabulate(DATA, headers=table_headers, floatfmt=table_floatfmt))
print()

Label                         ACC    NLL     BS    ECE    cNLL    cBS    cECE
--------------------------  -----  -----  -----  -----  ------  -----  ------
DeepEns-1                   80.22  0.789  0.282  0.042   0.789  0.282   0.041
DeepEns-2                   81.90  0.713  0.261  0.033   0.708  0.260   0.031
DeepEns-3                   82.46  0.684  0.253  0.032   0.673  0.251   0.027
DeepEns-4                   82.54  0.670  0.249  0.033   0.655  0.246   0.026
BatchEns-4 (KD)             80.40  0.804  0.286  0.072   0.750  0.277   0.021
BatchEns-4 (KD + Gaussian)  80.04  0.816  0.288  0.075   0.760  0.277   0.020
BatchEns-4 (KD + ODS)       81.92  0.685  0.258  0.026   0.682  0.258   0.026
BatchEns-4 (KD + ConfODS)   82.25  0.670  0.253  0.023   0.665  0.252   0.023

