# Batch Ensemble Networks

This notebook follows the same structure as the MIMO notebook but implements a BatchEnsemble MLP instead.
### Introduction
 

## 1. Quick Setup

In [1]:
import random
import time

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, log_loss
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: seed the Python, NumPy and PyTorch random number generators.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Device:", device)

Device: cpu


## 2. Toy dataset (classification variant)

We use the simple two-moons dataset from sklearn so that training is fast. The data is split into 80% training data and 20% evaluation data and served in batches of 128 for both training and evaluation.

In [2]:
from sklearn.datasets import make_moons

# Toy two-class dataset; cast to the dtypes used for features (float32) and labels (int64).
X, y = make_moons(n_samples=30000, noise=0.2, random_state=seed)
X, y = X.astype("float32"), y.astype("int64")

# Positional split: the first 80% of the rows are training data, the rest validation data.
split = int(len(X) * 0.8)
X_train, y_train = X[:split], y[:split]
X_val, y_val = X[split:], y[split:]

# Wrap each split as a TensorDataset of (features, labels) pairs.
train_ds, val_ds = (
    TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

# Only the training loader shuffles; validation keeps the original order.
batch_size = 128
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

print("Train size:", len(train_ds), "Val size:", len(val_ds))

Train size: 24000 Val size: 6000


## 3. Baseline model (standard MLP)

We create a small MLP to use as the base model for the ensemble so that training stays fast.
 

In [3]:
class MLP(nn.Module):
    """Plain feed-forward network: two ReLU hidden layers followed by a linear output layer."""

    def __init__(self, in_dim: int = 2, hidden: int = 128, out_dim: int = 2) -> None:
        """Build the layer stack.

        Args:
            in_dim: Number of input features.
            hidden: Width of both hidden layers.
            out_dim: Number of outputs produced by the last linear layer.
        """
        super().__init__()
        widths = (in_dim, hidden, hidden, out_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Drop the trailing ReLU: the output layer is not followed by an activation.
        self.net = nn.Sequential(*layers[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return its output."""
        return self.net(x)


# Build the baseline with its default sizes, move it to the selected device and show its layers.
m = MLP()
m = m.to(device)
print(m)

MLP(
  (net): Sequential(
    (0): Linear(in_features=2, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)


## 4. Define Batch Ensemble Linear Layer

In [10]:
class BatchEnsembleLinear(nn.Module):
    """Linear layer shared by ``M`` ensemble members, each with a rank-1 "fast weight".

    Member ``m`` combines the shared weight ``W`` and bias ``b`` with its own
    additive rank-1 term, which is the same as using the effective weight
    ``W + outer(s[m], r[m])``::

        out[:, m] = x_m @ W.T + b + (x_m . r[m]) * s[m]

    NOTE(review): the BatchEnsemble paper (Wen et al., 2020) modulates the
    shared weight multiplicatively (``W * outer(s[m], r[m])``); this layer adds
    the rank-1 term instead. Confirm that the additive variant is intended.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        M: Number of ensemble members.
    """

    def __init__(self, in_dim: int, out_dim: int, M: int) -> None:
        super().__init__()
        self.M = M

        # Shared base weight & bias
        self.W = nn.Parameter(torch.randn(out_dim, in_dim) * 0.1)
        self.b = nn.Parameter(torch.zeros(out_dim))

        # Fast weights (rank-1 factors), one row per ensemble member
        self.r = nn.Parameter(torch.randn(M, in_dim))   # input direction
        self.s = nn.Parameter(torch.randn(M, out_dim))  # output direction

    def extra_repr(self) -> str:
        """Report the layer sizes so ``print(model)`` is informative."""
        out_dim, in_dim = self.W.shape
        return f"in_dim={in_dim}, out_dim={out_dim}, M={self.M}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer for every ensemble member.

        Args:
            x: Either ``[B, in_dim]`` (one input shared by all members) or
                ``[B, M, in_dim]`` (one input per member).

        Returns:
            Tensor of shape ``[B, M, out_dim]``.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # First layer: every member sees the same input.
            fast_in = torch.einsum("bi,mi->bm", x, self.r)       # [B, M] projections x . r[m]
            fast = fast_in.unsqueeze(-1) * self.s.unsqueeze(0)   # [B, M, out_dim]
            shared = x @ self.W.t() + self.b                     # [B, out_dim]
            shared = shared.unsqueeze(1)                         # [B, 1, out_dim], broadcast over M
            out = shared + fast
        elif x.dim() == 3:
            # Later layers: each member has its own input.
            fast_in = (x * self.r.unsqueeze(0)).sum(-1)          # [B, M]
            fast = fast_in.unsqueeze(-1) * self.s.unsqueeze(0)   # [B, M, out_dim]
            shared = x @ self.W.t() + self.b                     # [B, M, out_dim]
            out = shared + fast
        else:
            raise ValueError(f"Unexpected input shape {x.shape}")
        return out


## 5. Batch Ensemble wrapper: make Batch Ensemble version of the MLP


In [15]:
class BatchEnsemble(nn.Module):
    """MLP with one hidden layer whose two linear layers are BatchEnsembleLinear layers.

    Args:
        base_hidden: Width of the hidden layer.
        in_dim: Number of input features.
        out_dim: Number of outputs per ensemble member.
        M: Number of ensemble members.
    """

    def __init__(self, base_hidden=128, in_dim=2, out_dim=2, M=3):
        super().__init__()
        self.M = M

        # Input layer, activation, output layer (created in this order).
        self.input_layer = BatchEnsembleLinear(in_dim, base_hidden, M)
        self.body = nn.Sequential(nn.ReLU())
        self.fc = BatchEnsembleLinear(base_hidden, out_dim, M)

    def forward(self, x):
        """Compute one output per ensemble member.

        x: [batch, in_dim]
        Returns: [batch, M, out_dim]
        """
        hidden = self.body(self.input_layer(x))  # [B, M, H] after the activation
        return self.fc(hidden)                   # [B, M, O]


# Build a three-member BatchEnsemble, move it to the selected device and show its layers.
be_model = BatchEnsemble(M=3)
be_model = be_model.to(device)
print(be_model)

BatchEnsemble(
  (input_layer): BatchEnsembleLinear()
  (body): Sequential(
    (0): ReLU()
  )
  (fc): BatchEnsembleLinear()
)


## 5. Comparison between Ensemble and Batch Ensemble

We train a deep ensemble and a BatchEnsemble and compare accuracy, loss, ECE, mutual information (disagreement) and the number of forward calls.
Note: the code below uses three members, not the four subnetworks recommended for MIMO in [TRAINING INDEPENDENT SUBNETWORKS FOR ROBUST PREDICTION](https://openreview.net/pdf?id=OGg9XnKxFAH).
 
Setup:

Dataset with 24k data points for training and 6k data points for evaluation <br>
K members in the ensemble = 3 <br>
M members in the BatchEnsemble = 3 <br>
Epochs = 10 <br>
lr = 1e-3 <br>

In [16]:
# Small experiment: train a BatchEnsemble and a deep ensemble with comparable capacity.
k: int = 3        # number of independently trained MLPs in the deep ensemble
epochs: int = 10  # passes over the training data per model
lr: float = 1e-3  # learning rate handed to the Adam optimizers


def softmax_np(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax along ``axis``; the per-slice maximum is subtracted before exponentiating."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    normaliser = np.sum(exp_shifted, axis=axis, keepdims=True)
    return exp_shifted / normaliser


def entropy_np(probs: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Entropy (natural log) along ``axis``; probabilities are clipped to ``[eps, 1]`` first."""
    clipped = np.clip(probs, eps, 1.0)
    weighted_log = clipped * np.log(clipped)
    return -weighted_log.sum(axis=axis)


def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    """Expected calibration error with ``n_bins`` equal-width confidence bins on [0, 1].

    Args:
        probs: ``(N, C)`` predicted class probabilities.
        labels: ``(N,)`` integer class labels.
        n_bins: Number of confidence bins.

    Returns:
        The sum over non-empty bins of ``(bin size / N) * |mean confidence - accuracy|``.
    """
    confs = probs.max(axis=1)
    correct = probs.argmax(axis=1) == labels
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    n = len(labels)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        if i == n_bins - 1:
            # The last bin is closed on the right; otherwise predictions with
            # confidence exactly 1.0 would fall into no bin and be ignored.
            mask = (confs >= lo) & (confs <= hi)
        else:
            mask = (confs >= lo) & (confs < hi)
        count = mask.sum()
        if count == 0:
            continue
        gap = abs(confs[mask].mean() - correct[mask].mean())
        ece += (count / n) * gap
    return float(ece)


def reliability_diagram(
    probs: np.ndarray,
    labels: np.ndarray,
    n_bins: int = 15,
    ax: "plt.Axes | None" = None,
) -> "plt.Axes":
    """Plot per-bin accuracy against confidence, with the diagonal as reference.

    Args:
        probs: ``(N, C)`` predicted class probabilities.
        labels: ``(N,)`` integer class labels.
        n_bins: Number of equal-width confidence bins on [0, 1].
        ax: Axes to draw on; a new 5x5 inch figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    confs = probs.max(axis=1)
    correct = probs.argmax(axis=1) == labels
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])

    accs = []
    for i in range(n_bins):
        in_bin = confs >= edges[i]
        if i == n_bins - 1:
            # The last bin is closed on the right so that confidence 1.0 is included.
            in_bin &= confs <= edges[i + 1]
        else:
            in_bin &= confs < edges[i + 1]
        # Empty bins are plotted as NaN, i.e. left as gaps in the curve.
        accs.append(correct[in_bin].mean() if in_bin.any() else np.nan)

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))
    ax.plot(bin_centers, accs, marker="o", label="accuracy per bin")
    ax.plot([0, 1], [0, 1], "--", color="gray")
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Accuracy")
    ax.set_title("Reliability diagram")
    ax.legend()  # without this the curve's label is never shown
    return ax


# Train a normal ensemble of K independently initialized MLPs
def train_ensemble(
    base_cls: type[nn.Module],
    k: int,
    train_loader: DataLoader,
    epochs: int = epochs,
    lr: float = 1e-3,
) -> list[nn.Module]:
    """Train ``k`` separately initialised models one after another.

    Args:
        base_cls: Model class; it is called without arguments to create each member.
        k: Number of ensemble members.
        train_loader: Yields ``(inputs, labels)`` mini-batches.
        epochs: Passes over ``train_loader`` per member.
        lr: Learning rate of each member's Adam optimizer.

    Returns:
        The list of trained models, in training order.
    """
    models = []
    ensemble_forward_calls = 0
    lossfn = nn.CrossEntropyLoss()  # holds no state, so one instance serves all members
    for member_idx in range(k):
        print(f"\nTraining ensemble member {member_idx + 1}/{k}")
        m_k = base_cls().to(device)
        opt = optim.Adam(m_k.parameters(), lr=lr)
        for epoch in range(epochs):
            m_k.train()
            last_loss = None
            for xb, yb in train_loader:
                x = xb.to(device).float()
                y = yb.to(device).long()
                opt.zero_grad()
                ensemble_forward_calls += 1
                out = m_k(x)
                last_loss = lossfn(out, y)
                last_loss.backward()
                opt.step()
            # The reported value is the loss of the epoch's final mini-batch, not an epoch
            # average. An empty loader reports NaN instead of failing on an unbound name.
            shown = float("nan") if last_loss is None else last_loss.item()
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {shown:.4f}")
        models.append(m_k)

    print(f"Ensemble total forward calls: {ensemble_forward_calls}")
    return models


def train_batchensemble(
    be_model: nn.Module,
    train_loader: DataLoader,
    epochs: int = epochs,
    lr: float = lr,
) -> nn.Module:
    """Train a BatchEnsemble model in place.

    Args:
        be_model: BatchEnsemble instance whose forward returns ``[batch, M, out_dim]``.
        train_loader: Yields ``(inputs, labels)`` mini-batches.
        epochs: Passes over ``train_loader``.
        lr: Learning rate of the Adam optimizer.

    Returns:
        The same ``be_model`` object, after training.
    """
    opt = torch.optim.Adam(be_model.parameters(), lr=lr)
    lossfn = torch.nn.CrossEntropyLoss()
    forward_calls = 0

    for epoch in range(epochs):
        be_model.train()
        last_loss = None
        for xb, yb in train_loader:
            x = xb.to(device).float()
            y = yb.to(device).long()
            opt.zero_grad()

            forward_calls += 1
            out = be_model(x)  # [batch, M, out_dim]

            # Mean over ensemble members for loss.
            # NOTE(review): this averages the members' logits and applies one loss to the
            # mean, whereas eval_batchensemble_model averages per-member probabilities.
            # Confirm that training on the mean logit (not a per-member loss) is intended.
            out_mean = out.mean(dim=1)  # [batch, out_dim]
            last_loss = lossfn(out_mean, y)
            last_loss.backward()
            opt.step()

        # The reported value is the loss of the epoch's final mini-batch, not an epoch
        # average. An empty loader reports NaN instead of failing on an unbound name.
        shown = float("nan") if last_loss is None else last_loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {shown:.4f}")

    print(f"BatchEnsemble total forward calls: {forward_calls}")
    return be_model


# Evaluation helpers
def eval_ensemble_models(models: list, x_np: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate every ensemble member on the same inputs.

    Each model is switched to eval mode and stays in it afterwards.

    Returns:
        mean_probs: (N, C), the average of the members' probabilities
        member_probs: (N, K, C), the per-member softmax probabilities
    """
    inputs = torch.from_numpy(x_np).to(device).float()
    per_member = []
    with torch.no_grad():
        for model in models:
            model.eval()
            member_logits = model(inputs).cpu().numpy()
            per_member.append(softmax_np(member_logits))
    member_probs = np.stack(per_member, axis=1)  # (N, K, C)
    return member_probs.mean(axis=1), member_probs


def eval_batchensemble_model(
    be_model: nn.Module,
    x_np: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run a trained BatchEnsemble model on ``x_np`` in a single forward pass.

    The model is switched to eval mode and stays in it afterwards.

    Returns:
        mean_probs: [N, C], the average of the members' probabilities
        member_probs: [N, M, C], the per-member softmax probabilities
    """
    be_model.eval()
    inputs = torch.from_numpy(x_np).to(device).float()
    with torch.no_grad():
        member_logits = be_model(inputs).cpu().numpy()  # [N, M, C]
    member_probs = softmax_np(member_logits, axis=-1)
    return member_probs.mean(axis=1), member_probs

### Training

We train both models, track the time each needs to train, and display the results in different plots.


In [None]:
# Train the deep ensemble. Durations use time.perf_counter(), a monotonic clock meant for
# measuring elapsed time (time.time() can jump if the system clock is adjusted).
start = time.perf_counter()
ensemble_models = train_ensemble(MLP, k, train_loader, epochs=epochs, lr=lr)
ensemble_time = time.perf_counter() - start
print(f"Ensemble training time: {ensemble_time:.2f}s")

# Train the BatchEnsemble with the same epochs and learning rate for a fair comparison.
# NOTE: be_model is updated in place, so re-running this cell continues training the
# already-trained weights; re-create be_model first for a fresh run.
start = time.perf_counter()
train_batchensemble(be_model, train_loader, epochs=epochs, lr=lr)
be_time = time.perf_counter() - start
print(f"BatchEnsemble training time: {be_time:.2f}s")

# Eval on validation set
X_val_np = X_val.astype("float32")
y_val_np = y_val.astype("int64")

be_mean, be_members = eval_batchensemble_model(be_model, X_val_np)  # [N, C], [N, M, C]
ens_mean, ens_members = eval_ensemble_models(ensemble_models, X_val_np)


Training ensemble member 1/3
Epoch 1/10, Loss: 0.2620
Epoch 2/10, Loss: 0.0527
Epoch 3/10, Loss: 0.1919
Epoch 4/10, Loss: 0.0583
Epoch 5/10, Loss: 0.0346
Epoch 6/10, Loss: 0.0740
Epoch 7/10, Loss: 0.0430
Epoch 8/10, Loss: 0.0722
Epoch 9/10, Loss: 0.0916
Epoch 10/10, Loss: 0.0698

Training ensemble member 2/3
Epoch 1/10, Loss: 0.1035
Epoch 2/10, Loss: 0.1127
Epoch 3/10, Loss: 0.1962
Epoch 4/10, Loss: 0.0441
Epoch 5/10, Loss: 0.0823
Epoch 6/10, Loss: 0.0956
Epoch 7/10, Loss: 0.0649
Epoch 8/10, Loss: 0.1649
Epoch 9/10, Loss: 0.0854
Epoch 10/10, Loss: 0.1114

Training ensemble member 3/3
Epoch 1/10, Loss: 0.0909
Epoch 2/10, Loss: 0.0889
Epoch 3/10, Loss: 0.1157
Epoch 4/10, Loss: 0.1122
Epoch 5/10, Loss: 0.1417
Epoch 6/10, Loss: 0.0681
Epoch 7/10, Loss: 0.0268
Epoch 8/10, Loss: 0.0825
Epoch 9/10, Loss: 0.0891
Epoch 10/10, Loss: 0.0462
Ensemble total forward calls: 5640
Epoch 1/10, Loss: 0.2637
Epoch 2/10, Loss: 0.1423
Epoch 3/10, Loss: 0.2917
Epoch 4/10, Loss: 0.1985
Epoch 5/10, Loss: 0.17