# NGD‑T with layerwise K‑FAC (CIFAR‑10) — short demo

This notebook implements:
- patch‑based K‑FAC factor accumulation for `Conv2d` and `Linear` layers,
- per‑layer pseudoinverse via eigendecomposition with caching (recomputed every `KFAC_EIG_UPDATE` steps),
- NGD‑T thermodynamic regulator using global geometric norm \(\Delta_F\),
- hybrid updates (natural‑space + small nullspace Euclidean fallback),
- diagnostics logging and visualization: training loss, `eta_T`, predicted dissipation `Q_pred` (logged as `predicted_Q`).

Run cells in order. Adjust hyperparameters in the config cell.

In [None]:
# Setup and imports
import time, math
from collections import OrderedDict, defaultdict
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
# Config and model
# --- Data / schedule ---------------------------------------------------------
BATCH_SIZE = 256            # training mini-batch size
EPOCHS = 3                  # number of passes over the training set
SEED = 0                    # seed for the torch random number generators

# --- NGD-T step-size regulator -----------------------------------------------
LR_BASE = 1.0               # eta0: base scale of the regulated step size
Q_BUDGET = 1e-3             # budget q in eta_T = eta0 * q / Delta_F
ETA_MIN = 1e-6              # lower clamp applied to eta_T
ETA_MAX = 1.0               # upper clamp applied to eta_T
EPS = 1e-8                  # floor / denominator guard for Delta_F
ETA_NULL_RATIO = 0.01       # Euclidean fallback step as a fraction of eta_T

# --- K-FAC factors and their inversion ---------------------------------------
EMA_DECAY = 0.95            # weight of the previous factor average per update
KFAC_FACTOR_UPDATE = 1      # refresh the factor averages every N steps
KFAC_EIG_UPDATE = 20        # recompute the cached pseudoinverses every N steps
TOL = 1e-8                  # relative eigenvalue cutoff for the pseudoinverse
USE_DAMPING = False         # switch between damped inverse and pseudoinverse
DAMPING = 1e-3              # damping strength, only used when USE_DAMPING

# Seed the CPU generator and, when running on GPU, every CUDA generator.
torch.manual_seed(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

class SmallCNN(nn.Module):
    """Small conv net: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    Both convolutions use ``padding=1`` and the single pool halves the spatial
    size, so the ``64 * 16 * 16`` input width of ``fc1`` corresponds to
    32x32 input images.
    """

    def __init__(self, num_classes=10):
        """Create the layers; ``num_classes`` sets the width of the output layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, bias=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 16 * 16, out_features=256, bias=True)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        return self.fc2(hidden)

In [None]:
# Data loaders (torchvision)
# Per-channel mean / std handed to Normalize in both pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Loader options shared by the train and test loaders.
_LOADER_OPTS = {"num_workers": 4, "pin_memory": True}

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, **_LOADER_OPTS
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=256, shuffle=False, **_LOADER_OPTS
)


In [None]:
# K-FAC layer state and eig helpers
class KFACLayerState:
    """K-FAC bookkeeping for a single ``Conv2d`` or ``Linear`` module.

    Forward/backward hooks capture the layer input and the gradient with
    respect to the layer output; ``update_factors`` folds them into
    exponential moving averages of the two Kronecker factors:

    * ``A`` -- second moment of the layer inputs (unfolded patches for conv),
    * ``G`` -- second moment of the output gradients.

    The ``*_plus`` / ``*_eigvals`` / ``*_eigvecs`` / ``*_mask`` attributes are
    cache slots for the factor pseudoinverses.  This class only initialises
    them to ``None``; it never writes them afterwards.
    """

    def __init__(self, module, layer_type, ema_decay=0.95, device="cpu"):
        """Store the configuration and reset every factor and cache slot.

        Args:
            module: the wrapped layer.  Conv layers must expose
                ``kernel_size``, ``stride``, ``padding`` and ``dilation``.
            layer_type: ``"conv"`` or ``"linear"``; any other value makes
                ``update_factors`` a no-op.
            ema_decay: weight given to the previous average at each update.
            device: device the captured tensors are moved to.
        """
        self.module = module
        self.layer_type = layer_type
        self.ema_decay = ema_decay
        self.device = device
        # Factors as exposed to the inversion code: EMA plus optional damping.
        self.A = None
        self.G = None
        # Undamped running averages.  Kept separate so that damping is never
        # fed back into the next EMA step.
        self._A_ema = None
        self._G_ema = None
        self.A_plus = None
        self.G_plus = None
        self.A_eigvals = None
        self.A_eigvecs = None
        self.G_eigvals = None
        self.G_eigvecs = None
        self.A_mask = None
        self.G_mask = None
        self._acts_raw = None
        self._backprops_raw = None
        self.handle_forward = None
        self.handle_backward = None

    def register_hooks(self):
        """Attach the forward and full-backward hooks to ``self.module``."""
        module = self.module

        def forward_hook(mod, inp, out):
            # Passes run without autograd (e.g. evaluation under no_grad)
            # produce no matching backward signal, so do not retain them.
            if not torch.is_grad_enabled():
                return
            self._acts_raw = inp[0].detach().to(self.device)

        def backward_hook(mod, grad_in, grad_out):
            self._backprops_raw = grad_out[0].detach().to(self.device)

        self.handle_forward = module.register_forward_hook(forward_hook)
        self.handle_backward = module.register_full_backward_hook(backward_hook)

    def remove_hooks(self):
        """Detach whichever hooks are currently registered."""
        if self.handle_forward is not None:
            self.handle_forward.remove()
            self.handle_forward = None
        if self.handle_backward is not None:
            self.handle_backward.remove()
            self.handle_backward = None

    @staticmethod
    def _second_moment(rows):
        """Return ``rows^T rows / n_rows`` for a 2-D ``(n_rows, dim)`` tensor."""
        return (rows.t() @ rows) / float(rows.shape[0])

    def _ema(self, previous, batch_value):
        """Blend ``batch_value`` into ``previous`` (or start the average)."""
        batch_value = batch_value.detach()
        if previous is None:
            return batch_value.clone()
        return self.ema_decay * previous + (1.0 - self.ema_decay) * batch_value

    @staticmethod
    def _damped(factor, damping):
        """Return ``factor + damping * I``, or ``factor`` itself if damping <= 0."""
        if damping > 0.0:
            eye = torch.eye(factor.shape[0], device=factor.device, dtype=factor.dtype)
            return factor + damping * eye
        return factor

    def _conv_rows(self, x, gy):
        """Flatten conv inputs to patch rows and output grads to channel rows.

        NOTE(review): this treats the convolution as ungrouped and
        zero-padded; ``groups`` and ``padding_mode`` are not inspected.
        """
        mod = self.module
        patches = torch.nn.functional.unfold(
            x,
            kernel_size=mod.kernel_size,
            dilation=mod.dilation,
            padding=mod.padding,
            stride=mod.stride,
        )
        # unfold gives (B, K, L); one row per (sample, output location).
        act_rows = patches.permute(0, 2, 1).reshape(-1, patches.shape[1])
        # (B, C_out, H_out, W_out) -> (B * H_out * W_out, C_out)
        grad_rows = gy.permute(0, 2, 3, 1).reshape(-1, gy.shape[1])
        return act_rows, grad_rows

    @staticmethod
    def _linear_rows(x, gy):
        """Flatten any leading dimensions so each row is one feature vector."""
        return x.reshape(-1, x.shape[-1]), gy.reshape(-1, gy.shape[-1])

    def update_factors(self, damping=0.0):
        """Fold the most recently captured batch into ``A`` and ``G``.

        Does nothing unless both an activation and a back-propagated gradient
        have been captured since the last update, or when ``layer_type`` is
        neither ``"conv"`` nor ``"linear"``.

        Args:
            damping: when positive, ``A`` and ``G`` are exposed as
                ``EMA + damping * I``.  The damping is added on top of the
                undamped running average each time, so it does not compound
                over successive updates.
        """
        if self._acts_raw is None or self._backprops_raw is None:
            return
        if self.layer_type == "conv":
            act_rows, grad_rows = self._conv_rows(self._acts_raw, self._backprops_raw)
        elif self.layer_type == "linear":
            act_rows, grad_rows = self._linear_rows(self._acts_raw, self._backprops_raw)
        else:
            return

        self._A_ema = self._ema(self._A_ema, self._second_moment(act_rows))
        self._G_ema = self._ema(self._G_ema, self._second_moment(grad_rows))
        self.A = self._damped(self._A_ema, damping)
        self.G = self._damped(self._G_ema, damping)

        # Each captured batch is consumed exactly once.
        self._acts_raw = None
        self._backprops_raw = None

def build_kfac_state(model, device, ema_decay=0.95):
    """Create a hooked ``KFACLayerState`` for every Conv2d / Linear in ``model``.

    Modules are visited in ``model.named_modules()`` order and the returned
    ``OrderedDict`` maps each module name to its state object.  Modules of
    any other type are skipped.
    """
    tracked = OrderedDict()
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            kind = "conv"
        elif isinstance(module, nn.Linear):
            kind = "linear"
        else:
            continue
        state = KFACLayerState(module, kind, ema_decay=ema_decay, device=device)
        state.register_hooks()
        tracked[name] = state
    return tracked

def symmetric_eig_pseudoinverse_torch(mat, tol=1e-8, damping=None):
    """Invert a symmetric matrix through its eigendecomposition.

    The input is symmetrised first.  Eigenvalues and eigenvectors are
    returned in descending eigenvalue order.

    Args:
        mat: square 2-D tensor.
        tol: relative cutoff for the undamped pseudoinverse.  Eigenvalues not
            larger than ``max(lambda_max * tol, 1e-12)`` are treated as zero.
        damping: when positive, every eigen-direction is inverted as
            ``1 / (max(lambda, 0) + damping)`` instead of being truncated.

    Returns:
        ``(mat_plus, eigvecs, eigvals, retained_mask)``.  ``retained_mask``
        marks ``eigvals > 0`` in the damped case and ``eigvals > cutoff``
        otherwise.
    """
    mat = 0.5 * (mat + mat.t())
    eigvals, eigvecs = torch.linalg.eigh(mat)
    # eigh returns ascending order; callers get descending order.
    eigvals = eigvals.flip(0)
    eigvecs = eigvecs.flip(1)
    if damping is not None and damping > 0.0:
        # Round-off can push eigenvalues of a PSD matrix slightly below zero.
        # Clamping keeps an eigenvalue near -damping from producing a huge
        # or negative inverse.
        inv_diag = 1.0 / (eigvals.clamp(min=0.0) + damping)
        retained_mask = eigvals > 0.0
    else:
        sigma_max = max(float(eigvals.max().item()), 1e-12)
        tau = max(sigma_max * tol, 1e-12)
        inv_diag = torch.zeros_like(eigvals)
        retained_mask = eigvals > tau
        inv_diag[retained_mask] = 1.0 / eigvals[retained_mask]
    # Scale column i of the eigenvector matrix by inv_diag[i]: U diag(inv) U^T.
    mat_plus = (eigvecs * inv_diag.unsqueeze(0)) @ eigvecs.t()
    return mat_plus, eigvecs, eigvals, retained_mask


In [None]:
# Precondition, apply updates, and diagnostics
def _project_retained(mat, eigvecs, mask, right):
    """Project ``mat`` onto the span of the retained eigenvectors.

    With ``P = U diag(mask) U^T``, ``right=True`` returns ``mat @ P`` and
    ``right=False`` returns ``P @ mat``.  A missing basis/mask or an all-True
    mask leaves ``mat`` unchanged.
    """
    if eigvecs is None or mask is None or bool(mask.all()):
        return mat
    keep = mask.to(mat.dtype)
    if right:
        return ((mat @ eigvecs) * keep) @ eigvecs.t()
    return eigvecs @ ((eigvecs.t() @ mat) * keep.unsqueeze(1))


def precondition_and_apply_updates(model, kfac_state, q_budget, eta0, eta_min, eta_max, eps,
                                   tol, damping, use_damping, eta_null_ratio):
    """Apply one K-FAC preconditioned, budget-regulated update in place.

    For every tracked layer the gradient ``g`` is preconditioned as
    ``G_plus @ g @ A_plus`` (bias: ``G_plus @ g_bias``).  The global norm
    ``Delta_F = sum_layers <g, g_nat>`` sets a single step size
    ``eta_T = clamp(eta0 * q_budget / (Delta_F + eps), eta_min, eta_max)``.
    The part of ``g`` outside the retained eigenspaces of the factors, which
    the pseudoinverse maps to zero, receives a plain gradient step of size
    ``eta_null_ratio * eta_T``.

    Args:
        model: unused here; parameters are reached through ``kfac_state``.
        kfac_state: mapping ``name -> state`` with ``module``, ``layer_type``,
            ``A``/``G`` and the ``*_plus``/``*_eigvecs``/``*_eigvals``/
            ``*_mask`` cache slots.
        q_budget, eta0, eta_min, eta_max, eps: step-size regulator settings.
        tol, damping, use_damping: forwarded to
            ``symmetric_eig_pseudoinverse_torch`` when a layer has factors
            but no cached inverse yet.
        eta_null_ratio: relative size of the Euclidean fallback step.

    Returns:
        dict with ``total_delta_F``, ``eta_T`` and ``layers`` (number of
        layers updated).
    """
    layer_entries = []
    total_delta_F = 0.0
    for name, st in kfac_state.items():
        module = st.module
        weight = getattr(module, "weight", None)
        bias = getattr(module, "bias", None)
        params = [p for p in (weight, bias) if p is not None]
        if not params:
            continue
        grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
        g_flat = torch.cat([g.contiguous().view(-1) for g in grads]).detach()

        # Without factors (or for an unknown layer type) fall back to the raw
        # gradient, with nothing counted as null-space residual.
        g_nat_flat = g_flat.clone()
        g_range_flat = g_flat
        if st.A is not None and st.G is not None:
            if st.A_plus is None or st.G_plus is None:
                inv_damping = damping if use_damping else None
                st.A_plus, st.A_eigvecs, st.A_eigvals, st.A_mask = \
                    symmetric_eig_pseudoinverse_torch(st.A, tol=tol, damping=inv_damping)
                st.G_plus, st.G_eigvecs, st.G_eigvals, st.G_mask = \
                    symmetric_eig_pseudoinverse_torch(st.G, tol=tol, damping=inv_damping)
            if st.layer_type in ("linear", "conv") and weight is not None:
                n_weight = weight.numel()
                # Linear weights are (out, in); conv weights flatten to
                # (out, in * kh * kw).  Both match the layout of factor A.
                grad_mat = g_flat[:n_weight].view(weight.shape[0], -1)
                nat_parts = [(st.G_plus @ grad_mat @ st.A_plus).contiguous().view(-1)]
                if use_damping:
                    # A damped inverse acts on every direction: no null space.
                    range_mat = grad_mat
                else:
                    range_mat = _project_retained(grad_mat, st.A_eigvecs, st.A_mask, right=True)
                    range_mat = _project_retained(range_mat, st.G_eigvecs, st.G_mask, right=False)
                range_parts = [range_mat.contiguous().view(-1)]
                if bias is not None:
                    bias_grad = g_flat[n_weight:]
                    if bias_grad.numel() == st.G_plus.shape[0]:
                        nat_parts.append((st.G_plus @ bias_grad.view(-1, 1)).view(-1))
                        if use_damping:
                            range_parts.append(bias_grad)
                        else:
                            range_parts.append(
                                _project_retained(bias_grad.view(-1, 1), st.G_eigvecs,
                                                  st.G_mask, right=False).view(-1)
                            )
                    else:
                        # Shape mismatch: leave the bias gradient as is.
                        nat_parts.append(bias_grad)
                        range_parts.append(bias_grad)
                g_nat_flat = torch.cat(nat_parts)
                g_range_flat = torch.cat(range_parts)

        total_delta_F += float((g_flat @ g_nat_flat).item())
        # Component of g that the pseudoinverse maps to zero.
        residual = g_flat - g_range_flat
        layer_entries.append((params, g_nat_flat, residual))

    total_delta_F = max(total_delta_F, eps)
    eta_T = eta0 * (q_budget / (total_delta_F + eps))
    eta_T = max(min(eta_T, eta_max), eta_min)
    eta_null = eta_null_ratio * eta_T

    # Parameters are updated in place, outside of autograd tracking.
    with torch.no_grad():
        for params, g_nat_flat, residual in layer_entries:
            delta_flat = -eta_T * g_nat_flat - eta_null * residual
            offset = 0
            for p in params:
                n = p.numel()
                p.add_(delta_flat[offset:offset + n].view_as(p))
                offset += n

    diagnostics = {"total_delta_F": total_delta_F, "eta_T": eta_T, "layers": len(layer_entries)}
    return diagnostics


In [None]:
# Training loop with logging
def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    criterion = nn.CrossEntropyLoss(reduction="sum")
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss_sum += float(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100.0 * correct / total
    avg_loss = loss_sum / total
    return acc, avg_loss

model = SmallCNN(num_classes=10).to(device)
kfac_state = build_kfac_state(model, device=device, ema_decay=EMA_DECAY)
criterion = nn.CrossEntropyLoss()
logs = defaultdict(list)   # per-step diagnostics, keyed by quantity name
step = 0
start_time = time.time()

# Damping, when enabled, is applied once at inversion time.  The stored
# factors are accumulated undamped, so damping is neither applied twice nor
# carried into the running averages.
inversion_damping = DAMPING if USE_DAMPING else None

for epoch in range(EPOCHS):
    # The end-of-epoch evaluation calls model.eval(); make sure every epoch
    # starts in training mode.
    model.train()
    epoch_loss = 0.0
    t0 = time.time()
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        model.zero_grad()
        loss.backward()
        loss_value = float(loss.item())   # single host sync, reused below
        epoch_loss += loss_value * inputs.size(0)

        # Fold the activations / output gradients captured by the hooks into
        # the running Kronecker factors.
        if step % KFAC_FACTOR_UPDATE == 0:
            for st in kfac_state.values():
                st.update_factors(damping=0.0)

        # Periodically refresh the cached factor (pseudo)inverses.
        if step % KFAC_EIG_UPDATE == 0:
            for st in kfac_state.values():
                if st.A is None or st.G is None:
                    continue
                st.A_plus, st.A_eigvecs, st.A_eigvals, st.A_mask = symmetric_eig_pseudoinverse_torch(
                    st.A, tol=TOL, damping=inversion_damping)
                st.G_plus, st.G_eigvecs, st.G_eigvals, st.G_mask = symmetric_eig_pseudoinverse_torch(
                    st.G, tol=TOL, damping=inversion_damping)

        diagnostics = precondition_and_apply_updates(
            model, kfac_state,
            q_budget=Q_BUDGET,
            eta0=LR_BASE,
            eta_min=ETA_MIN,
            eta_max=ETA_MAX,
            eps=EPS,
            tol=TOL,
            damping=DAMPING,
            use_damping=USE_DAMPING,
            eta_null_ratio=ETA_NULL_RATIO
        )
        eta_T = diagnostics["eta_T"]
        total_delta_F = diagnostics["total_delta_F"]
        logs["loss"].append(loss_value)
        logs["eta_T"].append(eta_T)
        logs["predicted_Q"].append(0.5 * (eta_T**2) * total_delta_F)
        logs["total_delta_F"].append(total_delta_F)
        if step % 50 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss_value:.4f} eta_T {eta_T:.6f} total_delta_F {total_delta_F:.6e}")
        step += 1
    t1 = time.time()
    avg_loss = epoch_loss / len(trainset)
    print(f"Epoch {epoch} completed in {t1-t0:.1f}s, avg loss {avg_loss:.4f}")
    acc, test_loss = test(model, testloader)
    print(f"Test accuracy after epoch {epoch}: {acc:.2f}%  test_loss {test_loss:.4f}")

total_time = time.time() - start_time
print(f"Training finished in {total_time/60.0:.2f} minutes")


In [None]:
# Plot diagnostics
import matplotlib.pyplot as plt

# "seaborn-darkgrid" was renamed to "seaborn-v0_8-darkgrid" in Matplotlib 3.6
# and the old name was later removed.  Use whichever this installation
# provides and keep the default style if neither exists.
for _style_name in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid"):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

# One stacked panel per logged quantity, sharing the training-step axis.
steps = np.arange(len(logs["loss"]))
fig, axs = plt.subplots(3, 1, figsize=(10, 10), sharex=True)
axs[0].plot(steps, logs["loss"], label="train loss")
axs[0].set_ylabel("Loss")
axs[0].legend()
axs[1].plot(steps, logs["eta_T"], label="eta_T", color="C1")
axs[1].set_ylabel("eta_T")
axs[1].legend()
axs[2].plot(steps, logs["predicted_Q"], label="predicted_Q", color="C2")
axs[2].set_ylabel("predicted_Q")
axs[2].set_xlabel("training step")
axs[2].legend()
plt.tight_layout()
plt.show()


In [None]:
# Optional: save model checkpoint
# Persist only the parameter/buffer tensors, not the module object itself.
state = model.state_dict()
torch.save(state, "ngd_t_kfac_cifar_model.pth")
print("Model saved to ngd_t_kfac_cifar_model.pth")
