In this notebook, we implement a Stochastic ResNet architecture for image classification. A Stochastic ResNet introduces randomness into the forward pass by injecting learned Gaussian noise into intermediate feature maps, so that uncertainty is propagated alongside the features. This noise injection (rather than stochastic depth, which randomly skips whole blocks) acts as a regularizer and can improve generalization.

In [None]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

DATASET = "CIFAR100"  # or "CIFAR10"

# Hyperparameters
BATCH_SIZE = 128
TEST_BATCH_SIZE = 256
NUM_EPOCHS = 300

LR = 1e-3

# Beta (KL weight) schedule
# Overall loss: loss = CE + beta * KL
BETA_MAX = 1e-4              # max beta
BETA_SCHEDULE = "linear_warmup"  # options: "constant", "linear_warmup"
BETA_WARMUP_FRAC = 0.3       # fraction of epochs used to warm up (for linear_warmup)

# MC sampling for evaluation & uncertainty
T_MC_EVAL = 20          # MC passes for test accuracy
T_MC_VIZ = 20           # MC passes for uncertainty viz

# Checkpointing (save to Google Drive)
# NOTE: mount Drive first (e.g. google.colab.drive.mount("/content/drive")).
# If it is not mounted, mkdir(parents=True) below silently creates an ordinary
# local directory at this path and checkpoints will not persist to Drive.
DRIVE_ROOT = Path("/content/drive/MyDrive")        # your Google Drive root
CHECKPOINT_DIR = DRIVE_ROOT / "stoch_resnet_ckpts" # folder inside Drive
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

SAVE_EVERY_EPOCHS = 25  # save every N epochs

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")


# Beta schedule helper
def beta_schedule(epoch, num_epochs, schedule_type="linear_warmup",
                  beta_max=1e-4, warmup_frac=0.3):
    """
    Return the KL weight beta for a given 0-based epoch index.

    Args:
        epoch: 0-based epoch index.
        num_epochs: total number of training epochs.
        schedule_type: one of
            - "constant"      : always beta_max
            - "linear_warmup" : with W = max(1, int(num_epochs * warmup_frac)),
                                epoch e < W gets beta_max * (e + 1) / W (so the
                                first epoch is beta_max / W, not 0); every
                                later epoch gets beta_max.
        beta_max: maximum (final) beta value.
        warmup_frac: fraction of num_epochs used for the linear warmup.

    Returns:
        The beta value (float) for this epoch.

    Raises:
        ValueError: if schedule_type is not one of the options above.
    """
    e = epoch
    T = num_epochs
    if schedule_type == "constant":
        return beta_max

    elif schedule_type == "linear_warmup":
        # At least one warmup epoch so the division below is always defined.
        warmup_epochs = max(1, int(T * warmup_frac))
        if e < warmup_epochs:
            return beta_max * ((e + 1) / warmup_epochs)
        else:
            return beta_max

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

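We now define a VIB (Variational Information Bottleneck) layer for our Stochastic ResNet. The layer works as follows:

- It predicts a per-element log-variance for its input features with a 1x1 convolution.
- It treats the input itself as the mean.
- It adds Gaussian noise with that variance during training, and optionally at evaluation time for Monte Carlo uncertainty estimation.

The layer returns the noisy features together with a KL-divergence penalty against a chosen prior. The noise injects uncertainty into the network, while the KL term (weighted by beta in the loss) regularizes how much noise the network learns to use.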

In [None]:
class VIBBlock(nn.Module):
    """Gaussian noise-injection (VIB-style) layer with a KL penalty.

    The input ``x`` is used as the mean; a 1x1 convolution predicts a
    per-element log-variance. When sampling, the output is
    ``z = x + eps * std`` with ``eps ~ N(0, I)``; otherwise ``z = x``.
    ``forward`` returns ``(z, kl)``.
    """

    _PRIORS = ("unit", "match")
    _REDUCTIONS = ("none", "batch_mean")

    def __init__(self, channels, prior="unit", reduce="none", sample_at_eval=False):
        """
        channels:
            number of input (and output) channels.
        prior:
            "unit"  -> KL(N(x, σ^2) || N(0, I))  (penalizes mean and variance)
            "match" -> KL(N(x, σ^2) || N(x, I))  (mean is unpenalized; penalizes
                       σ^2 deviating from 1, minimized at σ^2 = 1)
        reduce:
            "batch_mean" -> average over batch (scalar)
            "none"       -> per-sample KL [B]
        sample_at_eval:
            if True, still sample noise at eval (for MC uncertainty estimation)

        Raises:
            ValueError: if ``prior`` or ``reduce`` is not a supported option
                (previously an unknown ``reduce`` silently behaved like "none").
        """
        super().__init__()
        if prior not in self._PRIORS:
            raise ValueError(f"Unknown prior type: {prior}")
        if reduce not in self._REDUCTIONS:
            raise ValueError(f"Unknown reduce mode: {reduce}")

        self.enc_logvar = nn.Conv2d(channels, channels, kernel_size=1)
        # Zero weights + bias -5 => logvar = -5 everywhere at init,
        # i.e. std = exp(-2.5) ≈ 0.08: training starts almost noise-free.
        nn.init.constant_(self.enc_logvar.weight, 0.0)
        nn.init.constant_(self.enc_logvar.bias, -5.0)  # very low initial variance

        self.prior = prior
        self.reduce = reduce
        self.sample_at_eval = sample_at_eval

    def forward(self, x):
        # Predict log-variance; clamp keeps exp() in a numerically safe range.
        logvar = self.enc_logvar(x)
        logvar = torch.clamp(logvar, min=-10.0, max=10.0)
        std = torch.exp(0.5 * logvar)

        # Reparameterized sample in training (or at eval if requested).
        if self.training or self.sample_at_eval:
            eps = torch.randn_like(std)
            z = x + eps * std
        else:
            z = x

        # KL computation
        if self.prior == "unit":
            # KL(N(μ=x, σ^2) || N(0, I))
            kl_element = -0.5 * (1 + logvar - x.pow(2) - logvar.exp())
        elif self.prior == "match":
            # KL(N(μ=x, σ^2) || N(μ=x, I)) = 0.5 * (σ^2 - 1 - log σ^2)
            kl_element = 0.5 * (logvar.exp() - 1.0 - logvar)
        else:
            # Only reachable if self.prior was modified after construction.
            raise ValueError(f"Unknown prior type: {self.prior}")

        # Sum over C, H, W → per-sample KL [B]
        kl = kl_element.sum(dim=[1, 2, 3])  # [batch]

        if self.reduce == "batch_mean":
            kl = kl.mean()  # scalar

        return z, kl

Next, we define the Stochastic ResNet architecture that uses the VIB blocks. It is a ResNet-18 backbone whose stem is adapted for 32x32 CIFAR images. A VIB layer is inserted after each of the four residual stages to inject stochasticity and uncertainty into the feature representations. The injected noise provides additional regularization, rather than stochastic depth, which randomly skips blocks. Sampling that noise repeatedly at test time enables Monte Carlo uncertainty quantification while preserving strong feature learning.

In [None]:
class StochasticResNet(nn.Module):
    """ResNet-18 for 32x32 inputs with a VIBBlock after each residual stage.

    ``forward(x)`` returns ``(logits, [kl1, kl2, kl3, kl4])``, where each KL
    is the per-sample tensor of shape [B] from the corresponding VIBBlock
    (all four are built with ``reduce="none"``).
    """

    def __init__(self, num_classes=10, sample_at_eval=False, prior="match"):
        """
        num_classes: output size of the final linear classifier.
        sample_at_eval: forwarded to every VIBBlock; True keeps sampling noise
            in eval mode, which MC prediction relies on.
        prior: KL prior forwarded to every VIBBlock ("unit" or "match").
        """
        super().__init__()
        # Randomly initialized backbone. `weights=None` replaces the
        # `pretrained=False` flag, which torchvision deprecated in 0.13.
        self.backbone = torchvision.models.resnet18(weights=None)

        # Adapt stem for CIFAR images (32x32): 3x3 stride-1 conv and no
        # max-pool, so the spatial size is not reduced before layer1.
        self.backbone.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.backbone.maxpool = nn.Identity()

        # One VIB layer per residual stage; channel counts match the stage
        # outputs of ResNet-18. Names vib1..vib4 are kept so state_dict keys
        # of existing checkpoints still match.
        self.vib1 = VIBBlock(64,  prior=prior, reduce="none", sample_at_eval=sample_at_eval)
        self.vib2 = VIBBlock(128, prior=prior, reduce="none", sample_at_eval=sample_at_eval)
        self.vib3 = VIBBlock(256, prior=prior, reduce="none", sample_at_eval=sample_at_eval)
        self.vib4 = VIBBlock(512, prior=prior, reduce="none", sample_at_eval=sample_at_eval)

        self.backbone.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Residual stages, each followed by noise injection + KL term
        x = self.backbone.layer1(x)
        x, kl1 = self.vib1(x)

        x = self.backbone.layer2(x)
        x, kl2 = self.vib2(x)

        x = self.backbone.layer3(x)
        x, kl3 = self.vib3(x)

        x = self.backbone.layer4(x)
        x, kl4 = self.vib4(x)

        # Head
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.backbone.fc(x)

        return logits, [kl1, kl2, kl3, kl4]

In [None]:
def train_epoch(model, loader, optimizer, beta, device):
    """
    Run one optimization pass over ``loader``.

    The objective per batch is ``cross_entropy + beta * KL``. Here KL is the
    sum, over the KL tensors returned by the model, of each tensor's mean over
    the batch.

    Returns:
        (average total loss per batch, average KL per batch)
    """
    model.train()
    loss_sum, kl_sum = 0.0, 0.0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, kl_terms = model(inputs)  # each KL term is per-sample [B]

        ce_term = F.cross_entropy(logits, targets)
        # Batch-average every layer's KL, then add the layers together.
        kl_loss = sum(term.mean() for term in kl_terms)
        objective = ce_term + (beta * kl_loss)

        objective.backward()
        optimizer.step()

        loss_sum += objective.item()
        kl_sum += kl_loss.item()

    n_batches = len(loader)
    return loss_sum / n_batches, kl_sum / n_batches

def mc_predict(model, x, T=20):
    """
    Monte Carlo prediction from ``T`` forward passes.

    Puts the model in eval mode (and leaves it there). Passes differ only if
    the model still samples noise in eval mode.

    Args:
        x: [B, C, H, W] tensor already on the model's device.
        T: number of forward passes.

    Returns:
        mean_probs:    [B, num_classes] softmax averaged over the passes
        probs_stacked: [T, B, num_classes] softmax of every pass
    """
    model.eval()
    with torch.no_grad():
        per_pass = [model(x)[0].softmax(dim=-1) for _ in range(T)]
    stacked = torch.stack(per_pass, dim=0)
    return stacked.mean(dim=0), stacked


def evaluate_mc(model, loader, device, T=20):
    """
    Classification accuracy using the MC-averaged softmax.

    For every batch the prediction is the argmax of ``mc_predict``'s mean
    probabilities over ``T`` stochastic forward passes.

    Returns:
        Fraction of correctly classified samples (float in [0, 1]).
    """
    model.eval()
    n_seen, n_hit = 0, 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            avg_probs, _ = mc_predict(model, inputs, T=T)
            hits = avg_probs.argmax(dim=-1) == targets
            n_seen += targets.size(0)
            n_hit += hits.sum().item()

    return n_hit / n_seen

In [None]:
def collect_uncertainty(
    model, loader, device, T=20, max_batches=None
):
    """
    Collect per-image uncertainty statistics from T forward passes per batch.

    For every image this gathers:
      - MC predictive entropy of the mean softmax (output uncertainty)
      - per-layer KL vector [L], averaged over the same T passes
      - total KL (sum over layers)

    Args:
        model: network whose forward returns (logits, kl_list), each entry of
            kl_list being a per-sample KL tensor of shape [B].
        loader: iterable of (images, labels) batches.
        device: device the batches are moved to.
        T: number of forward passes per batch (must be >= 1).
        max_batches: if not None, stop after this many batches.

    Returns (all on CPU):
      images:     [N, C, H, W]
      labels:     [N]
      entropies:  [N]
      preds:      [N]
      layer_kls:  [N, L]  (L = number of KL terms returned by the model)
      total_kls:  [N]

    Raises:
        ValueError: if T < 1, or if no batch was collected (empty loader or
            max_batches <= 0).
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    model.eval()
    all_images = []
    all_labels = []
    all_entropy = []
    all_preds = []
    all_layer_kls = []
    all_total_kls = []

    with torch.no_grad():
        for b, (x, y) in enumerate(loader):
            if (max_batches is not None) and (b >= max_batches):
                break

            x = x.to(device)
            y = y.to(device)

            # Reuse the same T passes for both the entropy and the KLs. When
            # the model samples at eval time, deeper-layer KLs depend on noise
            # injected upstream. A separate extra forward pass would therefore
            # give a single-sample KL estimate drawn from different noise than
            # the entropy. Averaging over the T passes is lower-variance,
            # consistent, and saves one forward pass.
            prob_draws = []
            kl_draws = []
            for _ in range(T):
                logits, kl_list = model(x)
                prob_draws.append(logits.softmax(dim=-1))
                kl_draws.append(torch.stack(kl_list, dim=1))  # [B, L]
            mean_probs = torch.stack(prob_draws, dim=0).mean(dim=0)
            layer_kls_batch = torch.stack(kl_draws, dim=0).mean(dim=0)
            total_kls_batch = layer_kls_batch.sum(dim=1)

            # MC predictive entropy (clamp avoids log(0))
            probs = mean_probs.clamp(min=1e-8)
            entropy = -(probs * probs.log()).sum(dim=1)
            preds = probs.argmax(dim=1)

            # Collect
            all_images.append(x.cpu())
            all_labels.append(y.cpu())
            all_entropy.append(entropy.cpu())
            all_preds.append(preds.cpu())
            all_layer_kls.append(layer_kls_batch.cpu())
            all_total_kls.append(total_kls_batch.cpu())

    if not all_images:
        raise ValueError(
            "collect_uncertainty: no batches were collected "
            "(empty loader or max_batches <= 0)."
        )

    images = torch.cat(all_images, dim=0)
    labels = torch.cat(all_labels, dim=0)
    entropies = torch.cat(all_entropy, dim=0)
    preds = torch.cat(all_preds, dim=0)
    layer_kls = torch.cat(all_layer_kls, dim=0)
    total_kls = torch.cat(all_total_kls, dim=0)

    return images, labels, entropies, preds, layer_kls, total_kls


def compute_combined_uncertainty(entropies, total_kls, lambda_kl=0.5):
    """
    Blend output entropy with a min-max-rescaled total KL.

    Args:
        entropies: [N] predictive entropies.
        total_kls: [N] total KL per sample.
        lambda_kl: weight on the rescaled KL relative to the entropy.

    Returns:
        [N] tensor ``entropies + lambda_kl * (total_kls - min) / (max - min)``.
        The denominator is floored at 1e-8, so a constant KL vector
        contributes 0.
    """
    lo = total_kls.min()
    hi = total_kls.max()
    # Floor the span so identical KL values do not divide by zero.
    span = torch.clamp(hi - lo, min=1e-8)
    return entropies + lambda_kl * ((total_kls - lo) / span)


def visualize_uncertainty_groups(
    images,
    labels,
    entropies,
    preds,
    layer_kls,
    total_kls,
    num_per_group=16,
    dataset_name="CIFAR",
    lambda_kl=0.5,
):
    """
    Make 3 grids of images ranked by combined uncertainty
    (entropy + lambda_kl * min-max-normalized total KL, via
    compute_combined_uncertainty):
      - lowest combined (most certain)
      - medium
      - highest combined (most uncertain)
    Each image title shows:
      pred, entropy, total KL, and layer-wise KL summary.

    Args:
        images: [N, C, H, W] CPU tensor, shown with imshow after permuting
            to HWC; values are expected in [0, 1] (true for the notebook's
            ToTensor-only transforms; TODO confirm if normalization is added).
        labels: [N] ground-truth labels (currently not displayed).
        entropies, preds, total_kls: [N] tensors.
        layer_kls: [N, L] per-layer KLs.
        num_per_group: max images per grid; also capped at N // 3.
        dataset_name: prefix for each figure title.
        lambda_kl: KL weight passed to compute_combined_uncertainty.

    If there is not at least one image per group (N < 3 or
    num_per_group < 1), a message is printed and nothing is plotted.
    """
    N = layer_kls.shape[0]
    combined = compute_combined_uncertainty(entropies, total_kls, lambda_kl=lambda_kl)

    # sort ascending (low to high uncertainty)
    sorted_indices = torch.argsort(combined)

    k = min(num_per_group, N // 3)
    if k <= 0:
        # Previously k == 0 made sorted_indices[-0:] select every sample,
        # and building a 0-sized grid crashed.
        print(
            f"visualize_uncertainty_groups: nothing to plot "
            f"(N={N}, num_per_group={num_per_group}; need N >= 3 and num_per_group >= 1)."
        )
        return

    idx_certain = sorted_indices[:k]
    idx_uncertain = sorted_indices[N - k:]
    mid_start = (N - k) // 2
    idx_middle = sorted_indices[mid_start:mid_start + k]

    # Near-square grid that always holds all k images. The previous
    # int(sqrt(k)) x ceil(sqrt(k)) layout dropped images (e.g. k=7 -> 2x3).
    n_cols = int(np.ceil(np.sqrt(k)))
    n_rows = int(np.ceil(k / n_cols))

    def make_titles(idxs):
        """One title per index: prediction, entropy, total and per-layer KL."""
        titles = []
        for i in idxs:
            ent = float(entropies[i].item())
            tot_kl = float(total_kls[i].item())
            pred = int(preds[i].item())
            kl_vec = layer_kls[i].tolist()
            titles.append(
                f"pred={pred}, H={ent:.2f}, KLtot={tot_kl:.2f}, KL={['%.2f'%v for v in kl_vec]}"
            )
        return titles

    def show_image_grid(grid_images, titles, n_rows, n_cols, fig_title=None):
        """Draw up to n_rows * n_cols images with titles in one figure."""
        n_cells = n_rows * n_cols
        grid_images = grid_images[:n_cells]
        titles = titles[:n_cells]

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
        # plt.subplots would otherwise return a bare Axes (no .flatten()).
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False
        )

        for i, ax in enumerate(axes.flatten()):
            ax.axis("off")
            if i >= len(grid_images):
                continue
            ax.imshow(grid_images[i].permute(1, 2, 0).numpy())
            ax.set_title(titles[i], fontsize=7)

        if fig_title is not None:
            fig.suptitle(fig_title, fontsize=14)
        plt.tight_layout()
        plt.show()

    groups = (
        (idx_certain, "Most certain (low entropy & KL)"),
        (idx_middle, "Medium combined uncertainty"),
        (idx_uncertain, "Most uncertain (high entropy & KL)"),
    )
    for idxs, caption in groups:
        show_image_grid(
            images[idxs],
            make_titles(idxs),
            n_rows,
            n_cols,
            fig_title=f"{dataset_name}: {caption}",
        )


In [None]:
# Transforms (same for CIFAR-10 and CIFAR-100):
#   train -> random 32x32 crop of the 4-px padded image + random horizontal flip
#   test  -> tensor conversion only (ToTensor scales pixels to [0, 1])
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

# Dataset name -> (torchvision dataset class, number of classes).
_DATASET_REGISTRY = {
    "CIFAR10": (torchvision.datasets.CIFAR10, 10),
    "CIFAR100": (torchvision.datasets.CIFAR100, 100),
}
_dataset_entry = _DATASET_REGISTRY.get(DATASET.upper())
if _dataset_entry is None:
    raise ValueError("DATASET must be 'CIFAR10' or 'CIFAR100'")
_dataset_cls, num_classes = _dataset_entry

train_dataset = _dataset_cls(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = _dataset_cls(
    root="./data", train=False, download=True, transform=transform_test
)

# Shared DataLoader settings; only batch size and shuffling differ.
_loader_opts = {"num_workers": 4, "pin_memory": True}
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts
)
test_loader = DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, **_loader_opts
)

model = StochasticResNet(
    num_classes=num_classes,
    sample_at_eval=True,   # important for MC prediction
    prior="match",
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print(f"Using dataset: {DATASET} with {num_classes} classes.")
print("Model and optimizer initialized.")