In [1]:
import os
from jax import numpy as jnp
import numpy as onp
import jax
import tensorflow.compat.v2 as tf
import argparse
import time
from collections import OrderedDict

from bnn_hmc.utils import data_utils
from bnn_hmc.utils import models
from bnn_hmc.utils import losses
from bnn_hmc.utils import checkpoint_utils
from bnn_hmc.utils import cmd_args_utils
from bnn_hmc.utils import logging_utils
from bnn_hmc.utils import train_utils
from bnn_hmc.utils import precision_utils
from bnn_hmc.utils import metrics

In [2]:
def get_mean_std(arr):
    arr = onp.asarray(arr)
    return arr.mean(), arr.std()

## IMDB

In [3]:
dtype = jnp.float32
train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
    "imdb", dtype)

net_apply, net_init = models.get_model("cnn_lstm", data_info)
net_apply = precision_utils.rewrite_high_precision(net_apply)

labels = test_set[1]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz


  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])


In [4]:
(_, predict_fn, _, _,_) = train_utils.get_task_specific_fns(task, data_info)

In [5]:
all_preds = []
for seed in range(12):
    try:
        checkpoint_dict = checkpoint_utils.load_checkpoint(
            "../runs/sgd/imdb/sgd_wd_3.0_stepsize_3e-07_batchsize_80_momentum_0.9_seed_{}" \
            "/model_step_499.pt".format(seed))
        _, params, net_state, _, _ = (
            checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
        predictions = onp.asarray(predict_fn(net_apply, params, net_state, test_set))
        all_preds.append(predictions.copy())
    except:
        pass
all_preds = onp.stack(all_preds)

ValueError: need at least one array to stack

In [28]:
accs = [metrics.accuracy(pred, labels) for pred in all_preds]
nlls =  [metrics.nll(pred, labels) for pred in all_preds]
eces =  [metrics.calibration_curve(pred, labels)["ece"] for pred in all_preds]
print("{:.4f} +- {:.4f}".format(*get_mean_std(accs)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(eces)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(nlls)))

0.8294 +- 0.0058
0.1299 +- 0.0112
0.7554 +- 0.1450


In [30]:
ens_preds = all_preds.mean(axis=0)
print(metrics.accuracy(ens_preds, labels))
print(metrics.calibration_curve(ens_preds, labels)["ece"])
print(metrics.nll(ens_preds, labels))

0.85608
0.04337276110887527
0.37756163


## CIFAR-100

In [4]:
dtype = jnp.float32
train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
    "cifar100", dtype)

net_apply, net_init = models.get_model("resnet20_frn_swish", data_info)
net_apply = precision_utils.rewrite_high_precision(net_apply)

labels = test_set[1]

In [5]:
(_, predict_fn, _, _,_) = train_utils.get_task_specific_fns(task, data_info)

In [6]:
all_preds = []
for seed in range(12):
    try:
        checkpoint_dict = checkpoint_utils.load_checkpoint(
            "../runs/sgd/cifar100/sgd_wd_10.0_stepsize_1e-06_batchsize_80_momentum_0.9_seed_{}/model_step_499.pt".format(seed))
        _, params, net_state, _, _ = (
            checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
        predictions = onp.asarray(predict_fn(net_apply, params, net_state, test_set))
        all_preds.append(predictions.copy())
    except:
        pass
all_preds = onp.stack(all_preds)

In [7]:
accs = [metrics.accuracy(pred, labels) for pred in all_preds]
nlls =  [metrics.nll(pred, labels) for pred in all_preds]
eces =  [metrics.calibration_curve(pred, labels)["ece"] for pred in all_preds]
print("{:.4f} +- {:.4f}".format(*get_mean_std(accs)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(eces)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(nlls)))

0.5004 +- 0.0130
0.2068 +- 0.0107
2.3171 +- 0.0924


In [8]:
ens_preds = all_preds.mean(axis=0)
print(metrics.accuracy(ens_preds, labels))
print(metrics.calibration_curve(ens_preds, labels)["ece"])
print(metrics.nll(ens_preds, labels))

0.6464
0.117816821230948
1.3727252


## CIFAR-10

In [6]:
dtype = jnp.float32
train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(
    "cifar10", dtype)

net_apply, net_init = models.get_model("resnet20_frn_swish", data_info)
net_apply = precision_utils.rewrite_high_precision(net_apply)

labels = test_set[1]

In [7]:
(_, predict_fn, _, _,_) = train_utils.get_task_specific_fns(task, data_info)

In [9]:
all_preds = []
for seed in range(12):
    try:
        checkpoint_dict = checkpoint_utils.load_checkpoint(
            "../runs/sgd/cifar10/sgd_wd_10.0_stepsize_3e-07_batchsize_80_momentum_0.9_seed_{}/model_step_499.pt".format(seed))
        _, params, net_state, _, _ = (
            checkpoint_utils.parse_sgd_checkpoint_dict(checkpoint_dict))
        predictions = onp.asarray(predict_fn(net_apply, params, net_state, test_set))
        all_preds.append(predictions.copy())
    except:
        pass
all_preds = onp.stack(all_preds)

In [10]:
accs = [metrics.accuracy(pred, labels) for pred in all_preds]
nlls =  [metrics.nll(pred, labels) for pred in all_preds]
eces =  [metrics.calibration_curve(pred, labels)["ece"] for pred in all_preds]
print("{:.4f} +- {:.4f}".format(*get_mean_std(accs)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(eces)))
print("{:.4f} +- {:.4f}".format(*get_mean_std(nlls)))

0.8548 +- 0.0049
0.1002 +- 0.0061
0.6743 +- 0.0446


In [11]:
ens_preds = all_preds.mean(axis=0)
print(metrics.accuracy(ens_preds, labels))
print(metrics.calibration_curve(ens_preds, labels)["ece"])
print(metrics.nll(ens_preds, labels))

0.89919996
0.018649099725484845
0.31519425
