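In this notebook, we will implement a Stochastic ResNet architecture for image classification. The network introduces randomness in the forward pass by injecting learned Gaussian noise into its intermediate feature maps, so that uncertainty is propagated alongside the features; together with a KL penalty on that noise, this acts as a regularizer intended to improve generalization. (Note that this is noise injected into activations, not random dropping of layers as in the "stochastic depth" technique.)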

In [None]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

# Notebook-wide configuration. Running this cell also creates the checkpoint
# directory and prints the selected device and checkpoint location.

# Name of the dataset to use; the cell that loads it is not part of this section.
DATASET = "CIFAR100"  # or "CIFAR10"

# Hyperparameters
BATCH_SIZE = 128
TEST_BATCH_SIZE = 256
NUM_EPOCHS = 300

LR = 1e-3  # presumably the optimizer learning rate — confirm where the optimizer is built

# Beta (KL weight) schedule
# Overall loss: loss = CE + beta * KL
# The three values below line up with the schedule_type / beta_max / warmup_frac
# parameters of beta_schedule(); the call site is not in this cell — TODO confirm.
BETA_MAX = 1e-4              # max beta
BETA_SCHEDULE = "linear_warmup"  # options: "constant", "linear_warmup"
BETA_WARMUP_FRAC = 0.3       # fraction of epochs used to warm up (for linear_warmup)

# MC sampling for evaluation & uncertainty
T_MC_EVAL = 20          # MC passes for test accuracy
T_MC_VIZ = 20           # MC passes for uncertainty viz

# Checkpointing (save to Google Drive)
DRIVE_ROOT = Path("/content/drive/MyDrive")        # your Google Drive root
CHECKPOINT_DIR = DRIVE_ROOT / "stoch_resnet_ckpts" # folder inside Drive
# Create the folder (and any missing parents) now; an existing folder is not an error.
# NOTE(review): nothing in this cell checks that Drive is mounted, and parents=True
# creates the path regardless — confirm Drive is mounted before running.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

SAVE_EVERY_EPOCHS = 25  # save every N epochs

# Device
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")


# Beta schedule helper
def beta_schedule(epoch, num_epochs, schedule_type="linear_warmup",
                  beta_max=1e-4, warmup_frac=0.3):
    """
    Compute the KL weight (beta) to use for a given epoch.

    Args:
        epoch: 0-based epoch index.
        num_epochs: total number of epochs in the run.
        schedule_type: which schedule to apply:
            - "constant"      : beta_max at every epoch.
            - "linear_warmup" : ramp linearly up to beta_max over the first
                                int(num_epochs * warmup_frac) epochs (never
                                fewer than one), then stay at beta_max. The
                                ramp is computed from (epoch + 1), so epoch 0
                                already yields beta_max / warmup_epochs
                                rather than 0.
        beta_max: largest value beta can take.
        warmup_frac: fraction of num_epochs spent ramping up
            (only used by "linear_warmup").

    Returns:
        The beta value for this epoch.

    Raises:
        ValueError: if schedule_type is not one of the names listed above.
    """
    if schedule_type == "constant":
        return beta_max

    if schedule_type != "linear_warmup":
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    # Clamp to at least one epoch so the division below is always defined.
    ramp_length = max(1, int(num_epochs * warmup_frac))
    if epoch < ramp_length:
        return beta_max * ((epoch + 1) / ramp_length)
    return beta_max

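We now define a VIB (Variational Information Bottleneck) layer that will be used in our Stochastic ResNet. The layer treats its input features as the mean of a Gaussian, predicts a per-element log-variance from them with a 1×1 convolution, and adds correspondingly scaled Gaussian noise to the features. It returns the (possibly noised) features together with a KL-divergence term that enters the loss as `beta * KL`. This lets the network carry uncertainty through its representations while maintaining good feature representation, learning robust features through variational inference.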

In [None]:
class VIBBlock(nn.Module):
    """
    Noise-injecting bottleneck layer.

    The input x serves as the mean of a Gaussian and a 1x1 convolution of x
    predicts the per-element log-variance. forward() returns the (optionally
    noised) features together with a KL term against the chosen prior.
    """

    def __init__(self, channels, prior="unit", reduce="none", sample_at_eval=False):
        """
        channels:
            number of input (and output) channels of the log-variance conv
        prior:
            "unit"  -> KL(N(x, σ^2) || N(0, I))  (penalizes mean and variance)
            "match" -> KL(N(x, σ^2) || N(x, I))  (penalizes variance only)
        reduce:
            "batch_mean" -> KL averaged over the batch (scalar)
            anything else (default "none") -> per-sample KL [B]
        sample_at_eval:
            if True, noise is also sampled in eval mode
            (for MC uncertainty estimation)
        """
        super().__init__()
        self.prior = prior
        self.reduce = reduce
        self.sample_at_eval = sample_at_eval

        # Zero weights + bias of -5 make the initial log-variance -5
        # everywhere, i.e. the layer starts out with very little noise.
        self.enc_logvar = nn.Conv2d(channels, channels, kernel_size=1)
        nn.init.zeros_(self.enc_logvar.weight)
        nn.init.constant_(self.enc_logvar.bias, -5.0)

    def _elementwise_kl(self, mean, logvar):
        """Per-element KL between N(mean, exp(logvar)) and the configured prior."""
        if self.prior == "unit":
            # KL(N(μ, σ^2) || N(0, I))
            return -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
        if self.prior == "match":
            # KL(N(μ, σ^2) || N(μ, I)) = 0.5 * (σ^2 - 1 - log σ^2)
            return 0.5 * (logvar.exp() - 1.0 - logvar)
        raise ValueError(f"Unknown prior type: {self.prior}")

    def forward(self, x):
        """Return (z, kl): the features after optional noise, and the KL term."""
        # Predicted log-variance, clamped to [-10, 10]
        log_sigma2 = torch.clamp(self.enc_logvar(x), min=-10.0, max=10.0)
        sigma = torch.exp(0.5 * log_sigma2)

        # Noise is added in training mode, or in eval mode when requested
        add_noise = self.training or self.sample_at_eval
        z = (x + torch.randn_like(sigma) * sigma) if add_noise else x

        # Collapse C, H, W so that one KL value remains per sample: [B]
        per_sample_kl = self._elementwise_kl(x, log_sigma2).sum(dim=(1, 2, 3))

        if self.reduce == "batch_mean":
            return z, per_sample_kl.mean()  # scalar
        return z, per_sample_kl

Now, we define the Stochastic ResNet architecture that uses the VIB blocks. It is built on a ResNet-18 backbone whose stem is adapted to 32×32 CIFAR images, with a VIB layer inserted after each of the four residual stages. These layers introduce stochasticity and uncertainty into the feature representations, enabling uncertainty quantification while maintaining strong feature learning capabilities. The forward pass returns the class logits together with the KL term produced by each VIB layer, which is used to regularize training.

In [None]:
class StochasticResNet(nn.Module):
    """
    ResNet-18 with a VIBBlock after each of its four residual stages.

    The torchvision ResNet-18 is used as the backbone with a CIFAR-style stem
    (3x3 stride-1 first convolution, no max-pool) and a fresh classification
    head. forward() returns the logits plus one KL term per VIB block.
    """

    def __init__(self, num_classes=10, sample_at_eval=False, prior="match"):
        """
        num_classes:
            size of the final linear classification layer
        sample_at_eval:
            forwarded to every VIBBlock; if True they keep sampling noise
            in eval mode
        prior:
            forwarded to every VIBBlock ("unit" or "match")
        """
        super().__init__()
        # Randomly initialised backbone. The `pretrained` keyword is deprecated
        # in torchvision >= 0.13 (it warns and is slated for removal); calling
        # with no weight argument builds the same untrained network on both
        # old and new torchvision releases.
        self.backbone = torchvision.models.resnet18()

        # Adapt stem for CIFAR images (32x32): small stride-1 first conv and
        # no max-pool, so the input is not downsampled before layer1.
        self.backbone.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.backbone.maxpool = nn.Identity()

        # One VIB block per residual stage, sized to that stage's output
        # channels. reduce="none" keeps the KL per sample. The attribute names
        # are part of the state_dict keys, so they are kept as vib1..vib4.
        vib_kwargs = dict(prior=prior, reduce="none", sample_at_eval=sample_at_eval)
        self.vib1 = VIBBlock(64, **vib_kwargs)
        self.vib2 = VIBBlock(128, **vib_kwargs)
        self.vib3 = VIBBlock(256, **vib_kwargs)
        self.vib4 = VIBBlock(512, **vib_kwargs)

        self.backbone.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """
        Run the network on a batch of images.

        Returns:
            (logits, kls) where kls is the list [kl1, kl2, kl3, kl4] of KL
            terms returned by the four VIB blocks, in stage order.
        """
        # Stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Each residual stage is followed by its VIB block
        stages = (
            (self.backbone.layer1, self.vib1),
            (self.backbone.layer2, self.vib2),
            (self.backbone.layer3, self.vib3),
            (self.backbone.layer4, self.vib4),
        )
        kls = []
        for stage, vib in stages:
            x = stage(x)
            x, kl = vib(x)
            kls.append(kl)

        # Classification head
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.backbone.fc(x)

        return logits, kls