In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
# plt.style.use('ggplot')
%matplotlib inline

In [2]:
import os
import logging
from functools import partial

from jax import random
from torchvision import transforms
import wandb
from flax.training.checkpoints import save_checkpoint, restore_checkpoint

from src.models import make_Hard_OvR_Ens_loss as make_prod_loss
from src.models import make_Hard_OvR_Ens_MNIST_plots as make_plots
from src.models import make_Cls_Ens_loss as make_ens_loss
from src.data import get_image_dataset, NumpyLoader
from src.utils.training import setup_training, train_loop
from experiments.configs.mnist_hard_ovr_classification import get_config

2022-10-07 14:45:35.232271: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-07 14:45:35.803557: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64
2022-10-07 14:45:35.803640: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64


In [3]:
# Pass the `--xla_gpu_force_compilation_parallelism=1` flag to XLA through the
# environment. NOTE(review): this runs after `jax` has been imported; it
# presumably still takes effect because no computation has run yet — confirm.
os.environ.update(XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1")

In [4]:
# W&B doesn't know how to handle VS Code notebooks, so the notebook name is
# supplied explicitly through the environment before logging in.
os.environ.update(WANDB_NOTEBOOK_NAME='cls_comparison_figures.ipynb')

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmetodj[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Experiment configuration, built by `get_config` from
# `experiments.configs.mnist_hard_ovr_classification`.
config = get_config()

In [6]:
def _load_splits():
    """Load the (train, test, val) datasets described by `config`.

    Images are flattened and no train-time augmentation is applied.
    Augmentations that were tried and are currently disabled:
        transforms.RandomCrop(28, padding=2)
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05), shear=5)
        transforms.RandomHorizontalFlip()
    """
    return get_image_dataset(
        dataset_name=config.dataset_name,
        val_percent=config.val_percent,
        flatten_img=True,
        train_augmentations=[],
    )


def _make_loader(dataset):
    """Wrap `dataset` in a `NumpyLoader` using the configured batch size."""
    return NumpyLoader(dataset, config.batch_size, num_workers=8)


# Loaders for the ensemble baseline, built from their own call to the dataset factory.
train_dataset, _, val_dataset = _load_splits()
ens_train_loader = _make_loader(train_dataset)
ens_val_loader = _make_loader(val_dataset)

# Loaders for the product models; a second, identically-configured load whose
# datasets are the ones left bound to `train_dataset` / `val_dataset` / `test_dataset`.
train_dataset, test_dataset, val_dataset = _load_splits()
train_loader = _make_loader(train_dataset)
val_loader = _make_loader(val_dataset)
test_loader = _make_loader(test_dataset)

In [7]:
# Configuration for the classification-ensemble baseline: a resolved copy of
# the base config with a different model name and without the β schedule.
ens_config = config.copy_and_resolve_references()
ens_config.model_name = 'Cls_Ens'
delattr(ens_config, 'β_schedule')

In [8]:
best_ens_states = []
best_prod_states = []

# Switches selecting which model families are (re)trained by this cell.
# A family whose switch is False is only initialised, never trained or saved.
RECOMPUTE_PROD = False
RECOMPUTE_ENS = False
RECOMPUTE_PROD_STD = True
# When True, the product model is initialised from the parameters of the
# ensemble trained in the same iteration (so RECOMPUTE_ENS must be True too).
PRETRAIN_PROD = False

# Fail fast: without RECOMPUTE_ENS there is no freshly trained ensemble state to
# copy parameters from (previously this surfaced as a NameError, or silently
# reused a stale `ens_state` left over from an earlier kernel run).
if RECOMPUTE_PROD and PRETRAIN_PROD and not RECOMPUTE_ENS:
    raise RuntimeError(
        'PRETRAIN_PROD=True needs the ensemble to be trained in the same run; '
        'set RECOMPUTE_ENS=True as well.'
    )


def _train_and_save(model, state, cfg, rng, make_loss, train_dl, val_dl, ckpt_dir):
    """Run `train_loop` in offline W&B mode and checkpoint the state it returns.

    Args:
        model, state: the pair returned by `setup_training`.
        cfg: config passed through to `train_loop`.
        rng: PRNG key passed through to `train_loop`.
        make_loss: loss factory, used for both the train and the validation loss.
        train_dl, val_dl: training / validation loaders.
        ckpt_dir: directory handed to `save_checkpoint` (step 1, overwriting).

    Returns:
        The second element of `train_loop`'s return value.
    """
    _, trained_state = train_loop(
        model, state, cfg, rng, make_loss, make_loss, train_dl, val_dl,
        wandb_kwargs={'mode': 'offline'},
    )
    save_checkpoint(ckpt_dir, trained_state, 1, overwrite=True)
    return trained_state


# Example input/target used only to initialise the models.
init_x = train_dataset[0][0]
init_y = train_dataset[0][1]

for i in range(0, 3):
    # `PRNGKey(i)` is deterministic, so one split yields the same setup/training
    # keys that were previously re-derived before every model family.
    setup_rng, rng = random.split(random.PRNGKey(i))

    # --- Classification ensemble baseline ---------------------------------
    ens_model, state = setup_training(ens_config, setup_rng, init_x, init_y)
    ens_state = None
    if RECOMPUTE_ENS:
        ens_state = _train_and_save(
            ens_model, state, ens_config, rng, make_ens_loss,
            ens_train_loader, ens_val_loader, f'dynNN_redux/ens_model_{i}',
        )

    # --- Product model with per-member loss weight 0.5 ---------------------
    prod_model, state = setup_training(config, setup_rng, init_x, init_y)
    if RECOMPUTE_PROD:
        if PRETRAIN_PROD:
            # The result of `replace` must be re-bound: the call returns a new
            # state object, and the original code discarded it, so pre-training
            # never actually happened.
            state = state.replace(params=ens_state.params)
            # TODO: Also replace BN (model) state?

        prod_state = _train_and_save(
            prod_model, state, config, rng,
            partial(make_prod_loss, per_member_loss=0.5),
            train_loader, val_loader, f'dynNN_redux/prod_model_{i}',
        )

    # --- Product model with per-member loss weight 0 ("standard" loss) -----
    stdprod_model, state = setup_training(config, setup_rng, init_x, init_y)
    if RECOMPUTE_PROD_STD:
        stdprod_state = _train_and_save(
            stdprod_model, state, config, rng,
            partial(make_prod_loss, per_member_loss=0.00),
            train_loader, val_loader, f'dynNN_redux/stdprod_model_{i}',
        )

+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

2022-10-07 14:45:50.528627: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-07 14:45:51.066706: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64
2022-10-07 14:45:51.066787: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64


  0%|          | 0/50 [00:00<?, ?it/s]

epoch:   1 - train loss: 0.45753, val loss: 0.19387, train err: 0.1678, val err: 0.0805, β: 2.1141, lr: 0.00360
epoch:   2 - train loss: 0.16268, val loss: 0.14535, train err: 0.0749, val err: 0.0667, β: 2.1391, lr: 0.00420
epoch:   3 - train loss: 0.11565, val loss: 0.11977, train err: 0.0603, val err: 0.0578, β: 2.1695, lr: 0.00480
epoch:   4 - train loss: 0.08809, val loss: 0.10431, train err: 0.0509, val err: 0.0510, β: 2.2065, lr: 0.00540
epoch:   5 - train loss: 0.06621, val loss: 0.09590, train err: 0.0427, val err: 0.0458, β: 2.2513, lr: 0.00600
epoch:   6 - train loss: 0.05286, val loss: 0.09098, train err: 0.0385, val err: 0.0447, β: 2.3058, lr: 0.00660
epoch:   7 - train loss: 0.04177, val loss: 0.08554, train err: 0.0343, val err: 0.0440, β: 2.3717, lr: 0.00720
epoch:   8 - train loss: 0.03357, val loss: 0.08675, train err: 0.0307, val err: 0.0430, β: 2.4513, lr: 0.00780
epoch:   9 - train loss: 0.02604, val loss: 0.08403, train err: 0.0281, val err: 0.0425, β: 2.5473, lr: 

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
learning_rate,▄▄▅▅▆▇▇██████▇▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁
train/err,█▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/err,█▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▃▃▂▂▂▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val/loss,█▅▃▂▂▁▁▁▁▁▁▁▂▂▂▂▃▄▄▅▄▅▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
β,▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▅▅▆▆▆▇▇▇▇▇██████████

0,1
best_epoch,44.0
best_val_err,0.02233
epoch,50.0
learning_rate,0.0003
train/err,0.00333
train/loss,0.00246
val/err,0.02383
val/loss,0.12618
β,15.90612


+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

+---------------------------------------------+------------+---------+-----------+--------+
| Name                                        | Shape      | Size    | Mean      | Std    |
+---------------------------------------------+------------+---------+-----------+--------+
| batch_stats/nets_0/layer_0/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_0/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_1/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_2/BatchNorm_0/var  | (200,)     | 200     | 1.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/mean | (200,)     | 200     | 0.0       | 0.0    |
| batch_stats/nets_0/layer_3/BatchNorm_0/var  | (200,)     | 200     | 1.0      

2022-10-07 14:48:30.150530: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-07 14:48:30.677486: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64
2022-10-07 14:48:30.677570: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.8/lib64


  0%|          | 0/50 [00:00<?, ?it/s]

epoch:   1 - train loss: 0.44737, val loss: 0.19776, train err: 0.1719, val err: 0.0830, β: 2.1141, lr: 0.00360
epoch:   2 - train loss: 0.16008, val loss: 0.14825, train err: 0.0767, val err: 0.0680, β: 2.1391, lr: 0.00420
epoch:   3 - train loss: 0.11423, val loss: 0.12469, train err: 0.0616, val err: 0.0603, β: 2.1695, lr: 0.00480
epoch:   4 - train loss: 0.08682, val loss: 0.10914, train err: 0.0529, val err: 0.0540, β: 2.2065, lr: 0.00540
epoch:   5 - train loss: 0.06820, val loss: 0.09900, train err: 0.0467, val err: 0.0497, β: 2.2513, lr: 0.00600
epoch:   6 - train loss: 0.05291, val loss: 0.09347, train err: 0.0404, val err: 0.0488, β: 2.3058, lr: 0.00660
epoch:   7 - train loss: 0.04043, val loss: 0.08398, train err: 0.0356, val err: 0.0455, β: 2.3717, lr: 0.00720
epoch:   8 - train loss: 0.03240, val loss: 0.08412, train err: 0.0320, val err: 0.0428, β: 2.4513, lr: 0.00780
epoch:   9 - train loss: 0.02629, val loss: 0.08163, train err: 0.0288, val err: 0.0413, β: 2.5473, lr: 

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


KeyboardInterrupt: 

## Paper plots

In [None]:
from  functools import partial
from itertools import combinations

import matplotlib
import pandas as pd
import jax
import jax.numpy as jnp
import flax.linen as nn
import distrax
from chex import assert_rank, assert_shape, assert_equal_shape

In [None]:
def _restore_states(prefix, n_seeds=3):
    """Restore the saved state of every seed for one model family.

    Args:
        prefix: checkpoint directory prefix, e.g. 'prod_model'; seed `k` is read
            from 'dynNN_redux/{prefix}_{k}'.
        n_seeds: number of seeds to restore.

    Returns:
        A list with one restored state per seed. `target=None` makes
        `restore_checkpoint` return the deserialised state dict as-is, which is
        what the evaluation code indexes with ['params'] / ['model_state'].
        (The previous code passed `1` positionally as `target`, which looks like
        it was meant as a step number.)

    Raises:
        FileNotFoundError: if no checkpoint exists for a seed, instead of
            failing later with an unrelated indexing error.
    """
    states = []
    for seed in range(n_seeds):
        ckpt_dir = f'dynNN_redux/{prefix}_{seed}'
        restored = restore_checkpoint(ckpt_dir, target=None)
        if restored is None:
            raise FileNotFoundError(f'No checkpoint found in {ckpt_dir!r}')
        states.append(restored)
    return states


restored_prod_models = _restore_states('prod_model')
restored_ens_models = _restore_states('ens_model')
restored_stdprod_models = _restore_states('stdprod_model')

In [None]:
# Figure geometry used by the paper plots.
text_width = 6.75133 # in  --> Confirmed with template explanation
line_width = 3.25063  # presumably also in inches (single-column width) — confirm
dpi = 400

fs_m1 = 7  # for figure ticks
fs = 8  # for regular figure text
fs_p1 = 9 #  figure titles

matplotlib.rc('font', size=fs)          # controls default text sizes
matplotlib.rc('axes', titlesize=fs)     # fontsize of the axes title
matplotlib.rc('axes', labelsize=fs)    # fontsize of the x and y labels
matplotlib.rc('xtick', labelsize=fs_m1)    # fontsize of the tick labels
matplotlib.rc('ytick', labelsize=fs_m1)    # fontsize of the tick labels
matplotlib.rc('legend', fontsize=fs_m1)    # legend fontsize
matplotlib.rc('figure', titlesize=fs_p1)  # fontsize of the figure title


# Render all text through LaTeX with a Palatino serif font.
matplotlib.rc('font', **{'family':'serif', 'serif': ['Palatino']})
matplotlib.rc('text', usetex=True)
# `text.latex.preamble` must be a single string; assigning a list is rejected
# (or stringified into invalid LaTeX) by current matplotlib releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
# Transpose the test dataset's (input, label) pairs into a tuple of inputs and
# a tuple of labels.
X_test, y_test = zip(*test_loader.dataset)

In [None]:
def categorical_probs(logits):
    """Softmax of the logits averaged along axis 0.

    Args:
        logits: rank-2 array whose last dimension is 10 (both checked). In this
            notebook axis 0 indexes the selected ensemble members.

    Returns:
        Array of shape (10,) with the class probabilities.
    """
    assert_rank(logits, 2)
    assert_shape([logits], (None, 10))
    mean_logits = logits.mean(axis=0)
    probs = nn.softmax(mean_logits)
    assert_shape([probs], (10,))
    return probs

def categorical_entropy(logits):
    """Entropy of the categorical distribution with `categorical_probs(logits)`."""
    return distrax.Categorical(probs=categorical_probs(logits)).entropy()

def categorical_nll(logits, y):
    """Negative log-probability of label `y` under `categorical_probs(logits)`.

    Probabilities are clipped from below at 1e-36 before the distribution is built.
    """
    clipped_probs = categorical_probs(logits).clip(min=1e-36)
    return -distrax.Categorical(probs=clipped_probs).log_prob(y)

def mse(x, y):
    """Mean of the squared element-wise difference of two shape-(10,) arrays.

    Both the equality of the two shapes and the (10,) shape of `x` are checked.
    """
    assert_equal_shape([x, y])
    assert_shape(x, (10,))
    squared_diff = (x - y) ** 2
    return squared_diff.mean()

def categorical_brier(logits, y):
    """`mse` between `categorical_probs(logits)` and the 10-way one-hot encoding of `y`."""
    target = jax.nn.one_hot(y, 10)
    return mse(categorical_probs(logits), target)

def categorical_err(logits, y):
    """True where `y` differs from the arg-max class of `categorical_probs(logits)`."""
    predicted = jnp.argmax(categorical_probs(logits), axis=0)
    return y != predicted

In [None]:
def multiply_no_nan(x, y):
    """Element-wise product that is forced to zero wherever `y` is zero.

    Modelled on TF's `multiply_no_nan`: positions where `y == 0` yield 0 even
    if `x` is NaN or infinite there; everywhere else the result is `x * y`.

    Args:
        x: First input.
        y: Second input; its zeros mask the product.

    Returns:
        The masked product, with the zero filled in using the dtype
        `jnp.result_type(x, y)`. No explicit shape check is performed.
    """
    zero = jnp.zeros((), dtype=jnp.result_type(x, y))
    product = x * y
    return jnp.where(y == 0, zero, product)

def ovr_prod_probs(logits):
    """Normalised product of hard (rounded) sigmoid outputs along axis 0.

    Each sigmoid is rounded to 0 or 1, the results are multiplied along axis 0,
    and the shape-(10,) product is divided by its sum plus 1e-36. When no class
    keeps a product of 1, the returned vector is therefore all zeros.

    Args:
        logits: rank-2 array whose last dimension is 10 (both checked).

    Returns:
        Array of shape (10,).
    """
    assert_rank(logits, 2)
    assert_shape([logits], (None, 10))
    hard_votes = nn.sigmoid(logits).round()
    support = hard_votes.prod(axis=0)
    assert_shape([support], (10,))
    return support / (support.sum() + 1e-36)

def ovr_entropy(logits):
    """Entropy of `ovr_prod_probs(logits)`, with zero-probability terms contributing 0."""
    probs = ovr_prod_probs(logits)
    # `multiply_no_nan` zeroes the log(0) * 0 terms instead of producing NaN.
    weighted_log_probs = multiply_no_nan(jnp.log(probs), probs)
    return -jnp.sum(weighted_log_probs, axis=-1)

def ovr_nll(logits, y):
    """Negative log of the `ovr_prod_probs(logits)` entry at index `y`.

    The result is infinite whenever that entry is exactly zero.
    """
    return -jnp.log(ovr_prod_probs(logits)[y])

def ovr_brier(logits, y):
    """`mse` between `ovr_prod_probs(logits)` and the 10-way one-hot encoding of `y`."""
    target = jax.nn.one_hot(y, 10)
    return mse(ovr_prod_probs(logits), target)

def ovr_err(logits, y):
    """True where `y` differs from the arg-max class of `ovr_prod_probs(logits)`."""
    # NOTE(review): when the probability vector is all zeros the arg-max is 0,
    # so such an input is scored as correct for y == 0 — confirm this is intended.
    predicted = jnp.argmax(ovr_prod_probs(logits), axis=0)
    return y != predicted

In [None]:
# Empty results table: one row per (model, member subset, seed) is added by the
# evaluation cell.
_RESULT_COLUMNS = ['model_name', 'n_members', 'random_seed', 'H', 'err', 'brier', 'nll']
results_df = pd.DataFrame(columns=_RESULT_COLUMNS)
results_df

In [None]:
# pred_fun = partial(
#     ens_model.apply,
#     {"params": restored_ens_models[i]['params'], **restored_ens_models[i]['model_state']},
#     train=False,
#     method=ens_model.ens_logits
# )
# logits = jax.vmap(
#     pred_fun, axis_name="batch"
# )(jnp.array(X_test))

In [None]:
# for idx, indices in enumerate(power_set):
#     print(idx)
#     logits_ = logits[:, indices, :]
    
#     for i in range(logits_.shape[0]):
#         l = logits_[i]
#         y = y_test[i]
#         nll = categorical_nll(l, y)
#         if jnp.isnan(nll) or jnp.isinf(nll):
#             print(nll, l, y)

In [None]:
# Every non-empty subset of ensemble members, as tuples of member indices.
s = set(range(config.model.size))
power_set = [subset for r in range(1, len(s) + 1) for subset in combinations(s, r)]

# Per-family metric functions, in the order (entropy, nll, brier, err).
_OVR_METRICS = (ovr_entropy, ovr_nll, ovr_brier, ovr_err)
_CATEGORICAL_METRICS = (categorical_entropy, categorical_nll, categorical_brier, categorical_err)

# (display name, model, restored per-seed states, metric functions)
model_families = [
    ('Prod', prod_model, restored_prod_models, _OVR_METRICS),
    ('Ens', ens_model, restored_ens_models, _CATEGORICAL_METRICS),
    ('StdProd', stdprod_model, restored_stdprod_models, _OVR_METRICS),
]

# Convert the test set once instead of on every loop iteration.
X_test_arr = jnp.array(X_test)
y_test_arr = jnp.array(y_test)

# Rows are collected in a list and turned into a DataFrame once at the end:
# this avoids quadratic row-by-row `pd.concat`, stores plain floats rather than
# 0-d JAX arrays, and makes the cell idempotent (re-running it no longer
# appends duplicate rows to an existing `results_df`).
records = []
for model_name, model, restored_states, (entropy_fn, nll_fn, brier_fn, err_fn) in model_families:
    for i, state in enumerate(restored_states):
        pred_fun = partial(
            model.apply,
            {"params": state['params'], **state['model_state']},
            train=False,
            method=model.ens_logits
        )
        # Per-example logits for all members; indexed below as
        # [example, member, class].
        logits = jax.vmap(
            pred_fun, axis_name="batch"
        )(X_test_arr)

        for indices in power_set:
            n_members = len(indices)
            logits_ = logits[:, indices, :]

            entropies = jax.vmap(entropy_fn)(logits_)
            nlls_ = jax.vmap(nll_fn)(logits_, y_test_arr)
            briers = jax.vmap(brier_fn)(logits_, y_test_arr)
            errs = jax.vmap(err_fn)(logits_, y_test_arr)

            # Infinite NLLs are excluded from the mean. Note that this makes the
            # reported NLL an average over the remaining examples only.
            infs = jnp.isinf(nlls_)
            n_infs = int(infs.sum())
            if n_infs > 0:
                print(f"dropping {n_infs} infs for {model_name} of {n_members}")
            nlls = nlls_[~infs]

            records.append({
                'model_name': model_name,
                'n_members': n_members,
                'random_seed': i,
                'H': float(entropies.mean()),
                'nll': float(nlls.mean()),
                'err': float(errs.mean()),
                'brier': float(briers.mean()),
            })

results_df = pd.DataFrame(
    records,
    columns=['model_name', 'n_members', 'random_seed', 'H', 'err', 'brier', 'nll'],
)

In [None]:
# Metrics of the full-size model (all members) per family and seed, renamed to
# `final_*` so they can be merged back as a reference.
is_full_size = results_df.n_members == config.model.size
final_metric_names = {'err': 'final_err', 'nll': 'final_nll', 'brier': 'final_brier'}
min_mse_df = (
    results_df[is_full_size][['model_name', 'random_seed', 'err', 'nll', 'brier']]
    .rename(columns=final_metric_names)
)
min_mse_df

In [None]:
# Attach the full-size reference metrics to every row and compute each
# metric's difference from that reference.
tmp_df = (
    results_df
    .merge(min_mse_df, on=['model_name', 'random_seed'], how='left')
    .assign(
        err_diff=lambda df: df['err'] - df['final_err'],
        nll_diff=lambda df: df['nll'] - df['final_nll'],
        brier_diff=lambda df: df['brier'] - df['final_brier'],
    )
)
tmp_df

In [None]:
# Mean, standard deviation and row count of every metric per
# (model family, number of members).
_AGG_METRICS = ['H', 'err_diff', 'err', 'nll_diff', 'nll', 'brier_diff', 'brier']
_AGG_STATS = ['mean', 'std', 'count']
agg_df = tmp_df.groupby(by=['model_name', 'n_members']).agg(
    {metric: list(_AGG_STATS) for metric in _AGG_METRICS}
)
agg_df

In [None]:
# Standard error of the mean for every metric: std / sqrt(count).
# (The previous version divided by `count` itself, which understates the
# error bars by a factor of sqrt(count).)
for metric in ['H', 'err_diff', 'err', 'nll_diff', 'nll', 'brier_diff', 'brier']:
    agg_df[(metric, 'std_err')] = agg_df[(metric, 'std')] / agg_df[(metric, 'count')] ** 0.5
agg_df

In [None]:
# matplotlib.style.use('default')
fig, axs = plt.subplots(1, 4, figsize=(text_width, text_width/4.), dpi=dpi, sharey=False, sharex=True, layout='tight')

models = ['Prod', 'Ens', 'StdProd']

# Legend label per model family.
names = {
    'Prod': 'Ours',
    'Ens': 'DE',
    'StdProd': 'Ours (Std Loss)'
}

linestyles = {
    'Prod': '-',
    'Ens': '--',
    'StdProd': ':',
}


def _plot_metric_panel(ax, metric, ylabel, lw=1.25):
    """Draw mean ± std_err of one aggregated metric against the member count.

    Args:
        ax: matplotlib axes to draw on.
        metric: top-level column of `agg_df` ('H', 'err', 'brier', 'nll', ...).
        ylabel: y-axis label (LaTeX allowed).
        lw: line width of the mean curves.
    """
    metric_df = agg_df[metric].reset_index()
    for i, model_name in enumerate(models):
        rows = metric_df[metric_df.model_name == model_name]
        x, y, y_err = rows['n_members'], rows['mean'], rows['std_err']

        ax.plot(x, y, linestyles[model_name], c=f'C{i}', lw=lw, alpha=0.5, label=names[model_name])
        ax.fill_between(x, y - y_err, y + y_err, color=f'C{i}', alpha=0.4, lw=0.1)

    # The first argument of `grid` is a visibility flag; the numeric values
    # passed previously (0.3 / 0.2) only acted as "True".
    # NOTE(review): if a grid transparency was intended, pass `alpha=` instead.
    ax.grid(True)
    ax.set_ylabel(ylabel)
    # Raw strings: '\#' and '\m' are invalid escape sequences in normal literals.
    ax.set_xlabel(r'$\#$ members')
    ax.set_xticks(range(1, config.model.size + 1))


_plot_metric_panel(axs[0], 'H', r'$\mathbb{H}$')
_plot_metric_panel(axs[1], 'err', 'Error')
_plot_metric_panel(axs[2], 'brier', 'Brier Score')
_plot_metric_panel(axs[3], 'nll', 'NLL', lw=1.)
axs[3].legend()

plt.savefig('mnist_entropy_err_nll_evolution.pdf', dpi=dpi, bbox_inches='tight')