In this notebook, we will implement a Stochastic ResNet architecture for image classification. Stochastic ResNets introduce randomness in the forward pass by propagating uncertainty alongside features, enabling better regularization and improved generalization through stochastic depth.

In [None]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

# Dataset selector.
# NOTE(review): presumably "CIFAR10" is the alternative value -- confirm
# against the data-loading code, which is not visible in this cell.
DATASET = "CIFAR100"

# Hyperparameters
BATCH_SIZE = 128
TEST_BATCH_SIZE = 256
NUM_EPOCHS = 300

LR = 1e-3

# Beta (KL weight) schedule
# Overall loss: loss = CE + beta * KL
# These three values mirror the defaults of beta_schedule() defined below.
BETA_MAX = 1e-4              # max beta
BETA_SCHEDULE = "linear_warmup"  # options: "constant", "linear_warmup"
BETA_WARMUP_FRAC = 0.3       # fraction of epochs used to warm up (for linear_warmup)

# MC sampling for evaluation & uncertainty
T_MC_EVAL = 20          # MC passes for test accuracy
T_MC_VIZ = 20           # MC passes for uncertainty viz

# Checkpointing (save to Google Drive)
# NOTE(review): the path looks like a Colab Drive mount point -- confirm that
# Drive is mounted before this cell runs; mkdir(parents=True) would otherwise
# create the whole path wherever "/content/drive" happens to resolve.
DRIVE_ROOT = Path("/content/drive/MyDrive")        # your Google Drive root
CHECKPOINT_DIR = DRIVE_ROOT / "stoch_resnet_ckpts" # folder inside Drive
# Runs at cell-execution time: creates the folder and any missing parents,
# and does not raise if the folder already exists.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

SAVE_EVERY_EPOCHS = 25  # save every N epochs

# Device: use CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")


# Beta schedule helper
def beta_schedule(epoch: int, num_epochs: int,
                  schedule_type: str = "linear_warmup",
                  beta_max: float = 1e-4,
                  warmup_frac: float = 0.3) -> float:
    """Return the KL weight (beta) to use at a given epoch.

    Args:
        epoch: 0-based epoch index.
        num_epochs: Total number of epochs; only used to size the warm-up.
        schedule_type: Name of the schedule. Supported values:
            - "constant": always ``beta_max``.
            - "linear_warmup": ramp linearly over the first
              ``int(num_epochs * warmup_frac)`` epochs (at least one),
              returning ``beta_max * (epoch + 1) / warmup_epochs``, then
              hold at ``beta_max``.
        beta_max: Maximum (final) value of beta.
        warmup_frac: Fraction of ``num_epochs`` spent warming up; ignored
            by the "constant" schedule.

    Returns:
        The beta value for ``epoch``.

    Raises:
        ValueError: If ``schedule_type`` is not one of the supported names.
    """
    if schedule_type == "constant":
        return beta_max

    if schedule_type == "linear_warmup":
        # Clamp to one epoch so a very small warmup_frac (or num_epochs)
        # cannot produce a zero-length warm-up and a division by zero.
        warmup_epochs = max(1, int(num_epochs * warmup_frac))
        if epoch >= warmup_epochs:
            return beta_max
        # (epoch + 1): the ramp is already above zero at epoch 0 and hits
        # beta_max exactly on the last warm-up epoch.
        return beta_max * ((epoch + 1) / warmup_epochs)

    raise ValueError(f"Unknown schedule_type: {schedule_type}")