In [None]:
# kan_neff_pruning.py
import torch
import torch.nn as nn
import copy

@torch.no_grad()
def kan_mask_per_edge_basis(module: nn.Module,
                            basis_scale: torch.Tensor,
                            beta: float = 1.0,
                            coeff_attr: str = "coeff",
                            renormalize: bool = False):
    """
    Prune basis coefficients independently on every edge (i, j) with a Neff rule.

    The saliency of coefficient (i, j, k) is |coeff[i, j, k]| * basis_scale[j, k].
    Saliencies are normalised over k into shares p_k, the effective count is
    Neff = 1 / sum_k p_k^2, and the floor(beta * Neff) most salient entries
    (clamped to [1, K]) are kept on that edge. The coefficient tensor stored on
    `module` is overwritten in place with the pruned values.

    Args:
      module: layer holding a coefficient tensor of shape [out, in, K].
      basis_scale: tensor with in*K elements, viewed as [in, K]; per the
        original contract this is ||phi_{j,k}(x_j)||_2 over calibration data.
      beta: multiplicative factor applied to Neff before flooring.
      coeff_attr: attribute name of the coefficient tensor on `module`.
      renormalize: if True, rescale the survivors so each edge keeps its
        pre-pruning L1 norm over k.

    Returns:
      mask: bool tensor [out, in, K], True = kept.
      neff_ij: floor(Neff) per edge as a float tensor [out, in].
    """
    coeffs = getattr(module, coeff_attr).data              # [out, in, K]
    n_out, n_in, n_basis = coeffs.shape

    # Saliency = coefficient magnitude weighted by the basis-function scale.
    scale = basis_scale.to(coeffs.device).view(1, n_in, n_basis)
    saliency = coeffs.abs() * scale                         # [out, in, K]

    # Shares over k for each edge; the epsilon avoids 0/0 on all-zero edges.
    share = saliency / (saliency.sum(dim=2, keepdim=True) + 1e-12)
    neff = 1.0 / (share.pow(2).sum(dim=2))                  # [out, in]
    n_keep = torch.floor(beta * neff).clamp_(min=1, max=n_basis).to(torch.long)

    # One row per edge: rank the entries, flag the leading n_keep ranks, then
    # scatter those flags back to the original k positions.
    share_rows = share.reshape(-1, n_basis)                 # [(out*in), K]
    order = torch.sort(share_rows, dim=1, descending=True)[1]
    rank = torch.arange(n_basis, device=coeffs.device).view(1, -1).expand(share_rows.size(0), -1)
    in_top = rank < n_keep.view(-1, 1)
    mask_rows = torch.zeros_like(share_rows, dtype=torch.bool)
    mask_rows.scatter_(1, order, in_top)
    mask = mask_rows.view(n_out, n_in, n_basis)

    pruned = coeffs * mask
    if renormalize:
        l1_before = coeffs.abs().sum(dim=2, keepdim=True)
        l1_after = pruned.abs().sum(dim=2, keepdim=True)
        pruned = pruned * (l1_before / (l1_after + 1e-12))

    # Write the result back into the module's own storage.
    getattr(module, coeff_attr).data.copy_(pruned)
    return mask, torch.floor(neff)


@torch.no_grad()
def kan_mask_per_output_inputs(module: nn.Module,
                               basis_scale: torch.Tensor,
                               beta: float = 1.0,
                               coeff_attr: str = "coeff",
                               renormalize: bool = False):
    """
    Neff on input edges j per output i (collapse basis first).

    For every output row i the per-input importance is
    E[i, j] = sum_k |coeff[i, j, k]| * basis_scale[j, k]. Importances are
    normalised over j, Neff_i = 1 / sum_j p_j^2 is computed, and the
    floor(beta * Neff_i) most important inputs (clamped to [1, in]) are kept;
    all K coefficients of every other input are zeroed. The coefficient tensor
    stored on `module` is overwritten in place.

    Args:
      module: KAN-like layer with parameter `coeff` of shape [out, in, K].
      basis_scale: tensor [in, K] with ||phi_{j,k}(x_j)||_2.
      beta: multiplicative factor for Neff.
      coeff_attr: attribute name of coefficients.
      renormalize: if True, preserves L1 per row (i) across inputs j (over all K).

    Returns:
      mask_ijk: bool mask [out, in, K] (True=keep).
      neff_i: floor(Neff) per output i, tensor [out].
    """
    C = getattr(module, coeff_attr).data      # [out, in, K]
    out, din, K = C.shape
    S = C.abs() * basis_scale.to(C.device).view(1, din, K)   # [out,in,K]
    E = S.sum(dim=2)                           # collapse basis => [out, in]

    # Normalize per row i across inputs j
    Esum = E.sum(dim=1, keepdim=True) + 1e-12
    P = E / Esum                               # [out, in]
    neff = 1.0 / (P.pow(2).sum(dim=1))         # [out]
    r = torch.floor(beta * neff).clamp_(min=1, max=din).to(torch.long)

    # top-r per row: `in_top_r` is expressed in sorted-rank order, so it must be
    # scattered through the sort indices to land on the original input columns.
    # (Comparing the raw column index against r would keep the first r inputs
    # regardless of their importance.)
    _, idx = torch.sort(P, dim=1, descending=True)
    range_j = torch.arange(din, device=C.device).view(1, -1).expand(out, -1)
    in_top_r = range_j < r.unsqueeze(1)        # [out, in], rank order
    keep_inputs = torch.zeros_like(P, dtype=torch.bool)
    keep_inputs.scatter_(1, idx, in_top_r)     # [out, in], column order

    # expand to (i,j,k)
    mask = keep_inputs.unsqueeze(-1).expand(out, din, K)

    C_masked = C * mask
    if renormalize:
        pre = C.abs().sum(dim=(1,2), keepdim=True)    # L1 per row across j,k
        post = C_masked.abs().sum(dim=(1,2), keepdim=True)
        C_masked = C_masked * (pre / (post + 1e-12))

    getattr(module, coeff_attr).data.copy_(C_masked)
    return mask, torch.floor(neff)


def model_kan_prune(model: nn.Module,
                    basis_scales_by_name: dict,
                    beta: float = 1.0,
                    mode: str = "per_edge_basis",
                    coeff_attr: str = "coeff",
                    layer_matcher=lambda m: hasattr(m, "coeff"),
                    renormalize: bool = False):
    """
    Apply KAN Neff pruning to all matched layers of a deep copy of `model`.

    The input model is left untouched; the returned copy has its matched
    layers' coefficients pruned in place.

    Args:
      model: model to prune (not modified).
      basis_scales_by_name: maps a module name, as yielded by
        `named_modules()` on the model, to its basis-scale tensor [in, K].
        Every matched layer must have an entry (a missing one raises KeyError).
      beta: multiplicative factor for Neff, forwarded to the layer pruner.
      mode: "per_edge_basis" or "per_output_inputs".
      coeff_attr: attribute name of the coefficient tensor on matched layers.
      layer_matcher: predicate selecting the modules to prune. NOTE: the
        default tests for an attribute literally named "coeff", independent of
        `coeff_attr`; pass a custom matcher when using another attribute name.
      renormalize: forwarded to the layer pruner.

    Returns:
      The pruned deep copy of `model`.

    Raises:
      ValueError: if `mode` is not one of the supported values.
    """
    # Validate up front: otherwise a bad mode is only noticed once a layer
    # matches (silently accepted if none does) and only after the deep copy.
    if mode not in ("per_edge_basis", "per_output_inputs"):
        raise ValueError("Unknown mode")

    pruned = copy.deepcopy(model)
    for name, m in pruned.named_modules():
        if not layer_matcher(m):
            continue
        basis_scale = basis_scales_by_name[name]
        if mode == "per_edge_basis":
            kan_mask_per_edge_basis(m, basis_scale, beta, coeff_attr, renormalize)
        else:
            kan_mask_per_output_inputs(m, basis_scale, beta, coeff_attr, renormalize)
    return pruned


In [1]:
# kan_neff_demo.py
# ------------------------------------------------------------
# KAN training + Neff pruning + evaluation + comparison
# ------------------------------------------------------------
import math
import copy
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


# --------------------------
# Utilities
# --------------------------
def set_seed(seed: int = 1234):
    """Seed Python's `random`, torch's default generator and all CUDA generators."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


def device_of(x):
    """Return the device of `x` when it is a tensor, otherwise the CPU device."""
    if isinstance(x, torch.Tensor):
        return x.device
    return torch.device("cpu")


def count_nonzero(t: torch.Tensor):
    """Return, as a Python number, how many elements of `t` differ from zero."""
    nonzero_flags = torch.ne(t, 0)
    return nonzero_flags.sum().item()


# --------------------------
# Synthetic regression data
# --------------------------
def make_regression(n_train=20000, n_test=2000, in_dim=4, noise=0.05, seed=0):
    """
    Build a synthetic regression problem with a fixed nonlinear target.

    Inputs are drawn uniformly from [-1, 1) and the target is
        y = sin(2*pi*x1) + 0.5*x2^2 - x3*x4 + 0.3*x1*x2 + noise * N(0, 1),
    which uses only the first four input columns; any extra columns are
    pure distractors.

    Args:
      n_train: number of training rows.
      n_test: number of test rows.
      in_dim: number of input columns; must be at least 4.
      noise: standard deviation of the additive Gaussian noise.
      seed: seed of the private generator used for every draw, so the same
        arguments always produce the same data.

    Returns:
      ((Xtr, ytr), (Xte, yte)) with X of shape [n, in_dim] and y of shape [n, 1].

    Raises:
      ValueError: if `in_dim` is smaller than 4.
    """
    # The target indexes columns 0..3; fail clearly instead of with an
    # IndexError from inside the target function.
    if in_dim < 4:
        raise ValueError(f"in_dim must be >= 4 (the target uses x1..x4), got {in_dim}")

    gen = torch.Generator().manual_seed(seed)
    # Draw order (Xtr, Xte, train noise, test noise) fixes the generated data.
    Xtr = 2 * torch.rand(n_train, in_dim, generator=gen) - 1.0
    Xte = 2 * torch.rand(n_test, in_dim, generator=gen) - 1.0

    def f(X):
        x1, x2, x3, x4 = X[:, 0], X[:, 1], X[:, 2], X[:, 3]
        y = torch.sin(2 * math.pi * x1) + 0.5 * x2**2 - x3 * x4 + 0.3 * x1 * x2
        return y.unsqueeze(1)

    ytr = f(Xtr) + noise * torch.randn(n_train, 1, generator=gen)
    yte = f(Xte) + noise * torch.randn(n_test, 1, generator=gen)
    return (Xtr, ytr), (Xte, yte)


# --------------------------
# KAN layer (fixed RBF basis)
# y_i = sum_{j=1}^{in} sum_{k=1}^{K} C[i,j,k] * phi_k(x_j) + b_i
# --------------------------
class KANLayer(nn.Module):
    """
    KAN-style layer on a fixed Gaussian RBF basis.

    Output i is sum_j sum_k coeff[i, j, k] * phi_k(x_j) + bias[i], where the K
    basis functions share evenly spaced centers on [-1, 1] and one width sigma.
    An optional `mask` buffer (same shape as `coeff`) is multiplied into the
    coefficients in the forward pass when it is set.
    """

    def __init__(self, in_features, out_features, K=16, sigma=0.25):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.K = K

        # Non-trainable basis definition: K centers on [-1, 1] and a scalar width.
        self.register_buffer("centers", torch.linspace(-1.0, 1.0, K))   # [K]
        self.register_buffer("sigma", torch.tensor(float(sigma)))

        # Trainable coefficients [out, in, K] and per-output bias.
        self.coeff = nn.Parameter(torch.empty(out_features, in_features, K))
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Pruning mask slot: empty until a pruner assigns one; kept out of state_dict.
        self.register_buffer("mask", None, persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Zero the bias and draw the coefficients from N(0, 0.1^2)."""
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.coeff, mean=0.0, std=0.1)

    def phi(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate every basis function on every input feature.

        x: [B, in_features]
        returns: [B, in_features, K] with entries exp(-0.5 * ((x_j - c_k) / sigma)^2)
        """
        offsets = x.unsqueeze(-1) - self.centers.view(1, 1, -1)   # broadcast over K
        scaled = offsets / self.sigma
        return torch.exp(-0.5 * scaled.pow(2))

    def forward(self, x):
        """Map x [B, in_features] to [B, out_features]."""
        basis = self.phi(x)                       # [B, in, K]
        weights = self.coeff                      # [out, in, K]
        if self.mask is not None:
            weights = weights * self.mask
        # Contract the input and basis axes against the coefficients.
        return torch.einsum("bik,oik->bo", basis, weights) + self.bias


class KANRegressor(nn.Module):
    """
    Small regression network: KANLayer -> tanh -> linear head with one output.
    """

    def __init__(self, in_dim=4, hidden=64, K=16, sigma=0.25):
        super().__init__()
        # in_dim features are expanded to `hidden` KAN outputs on a K-term basis.
        self.kan = KANLayer(in_dim, hidden, K=K, sigma=sigma)
        self.act = nn.Tanh()
        # Scalar prediction from the hidden representation.
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):
        """Map x [B, in_dim] to predictions [B, 1]."""
        return self.head(self.act(self.kan(x)))


# --------------------------
# Neff pruning for KAN (exactly your rule)
# --------------------------
@torch.no_grad()
def neff_prune_kan_per_edge_basis(layer: KANLayer, beta: float = 1.0, renormalize: bool = False):
    """
    Prune `layer.coeff` edge by edge with the plain Neff rule.

    For each edge (i, j), with w_k = |C[i, j, k]|:
        p_k  = w_k / (sum_k w_k + 1e-12)
        Neff = 1 / sum_k p_k^2
        r    = floor(beta * Neff), clamped to [1, K]
    The r largest-magnitude coefficients of the edge are kept and the rest are
    zeroed. Only coefficient magnitudes are used (no activation or basis
    scaling).

    An all-zero edge gives p = 0 and therefore Neff = inf; for positive beta the
    clamp then yields r = K, i.e. every (already zero) entry stays unmasked.

    Side effects: `layer.coeff` is overwritten in place and `layer.mask` is set
    to the keep-mask cast to the coefficient dtype.

    Args:
      layer: layer whose `coeff` has shape [out, in, K].
      beta: multiplicative factor applied to Neff before flooring.
      renormalize: if True, rescale survivors so each edge keeps its
        pre-pruning L1 norm over k.

    Returns:
      dict with "neff" (unfloored Neff, [out, in]), "mask" (bool [out, in, K]),
      "kept", "total" and "sparsity" (= 1 - kept / total).
    """
    coeffs = layer.coeff.data                               # [out, in, K]
    n_out, n_in, n_basis = coeffs.shape

    magnitude = coeffs.abs()
    # Shares over k; the epsilon avoids 0/0 on all-zero edges.
    share = magnitude / (magnitude.sum(dim=2, keepdim=True) + 1e-12)
    neff = 1.0 / (share.pow(2).sum(dim=2))                  # [out, in]
    n_keep = torch.floor(beta * neff).clamp_(min=1, max=n_basis).long()

    # One row per edge: rank by magnitude, flag the leading n_keep ranks, then
    # scatter those flags back to the original k positions.
    rows = magnitude.view(-1, n_basis)                      # [(out*in), K]
    order = torch.sort(rows, dim=1, descending=True)[1]
    rank = torch.arange(n_basis, device=coeffs.device).view(1, -1).expand(rows.size(0), -1)
    in_top = rank < n_keep.view(-1, 1)
    mask_rows = torch.zeros_like(rows, dtype=torch.bool)
    mask_rows.scatter_(1, order, in_top)
    mask = mask_rows.view(n_out, n_in, n_basis)             # bool

    pruned = coeffs * mask
    if renormalize:
        l1_before = coeffs.abs().sum(dim=2, keepdim=True)   # L1 per (i,j)
        l1_after = pruned.abs().sum(dim=2, keepdim=True)
        pruned = pruned * (l1_before / (l1_after + 1e-12))

    layer.coeff.data.copy_(pruned)
    # Float copy of the mask: multiplied into coeff in forward() and usable for
    # masking gradients during finetuning.
    layer.mask = mask.to(coeffs.dtype)

    kept = mask.sum().item()
    total = mask.numel()
    return {
        "neff": neff,
        "mask": mask,
        "kept": kept,
        "total": total,
        "sparsity": 1.0 - kept / total,
    }


def attach_grad_mask(model: nn.Module):
    """
    Make sure pruned parameters stay zero during finetuning.

    For every KANLayer in `model` that has a mask, registers a gradient hook on
    `coeff` that multiplies the incoming gradient by the mask captured at
    attach time. The hook handle is stored on the layer as `_coeff_mask_hook`
    so it can be removed later.

    Safe to call repeatedly: a hook left by an earlier call is removed before
    the new one is registered, so hooks never stack and no handle is lost.
    """
    for m in model.modules():
        if not (isinstance(m, KANLayer) and m.mask is not None):
            continue

        # Without this, a second call would overwrite the stored handle and
        # leave the first hook attached with no way to remove it.
        previous = getattr(m, "_coeff_mask_hook", None)
        if previous is not None:
            previous.remove()

        # Bind the current mask as a default argument so the hook keeps using
        # this tensor even if `m.mask` is reassigned afterwards.
        m._coeff_mask_hook = m.coeff.register_hook(lambda g, msk=m.mask: g * msk)


def remove_grad_mask(model: nn.Module):
    """Detach and forget the `_coeff_mask_hook` gradient hook on every KANLayer that has one."""
    for m in model.modules():
        if not isinstance(m, KANLayer):
            continue
        if not hasattr(m, "_coeff_mask_hook"):
            continue
        m._coeff_mask_hook.remove()
        del m._coeff_mask_hook


# --------------------------
# Training / evaluation
# --------------------------
@dataclass
class TrainConfig:
    """Hyperparameters for one training run."""
    # Number of full passes over the training loader.
    epochs: int = 15
    # NOTE(review): not read by train() in this file; the DataLoaders in main()
    # are built with their own batch_size — confirm whether this should drive them.
    batch_size: int = 512
    # Learning rate handed to the AdamW optimizer in train().
    lr: float = 3e-3
    # Weight decay handed to the AdamW optimizer in train().
    weight_decay: float = 0.0


def train(model, loader, cfg: TrainConfig, device):
    """
    Train `model` on `loader` with AdamW and MSE loss for `cfg.epochs` epochs.

    KANLayers that carry a mask are kept sparse: their coefficient gradients
    are multiplied by the mask before the optimizer step, and the coefficients
    themselves are multiplied by the mask again after it. The sample-weighted
    mean training MSE is printed once per epoch.

    Args:
      model: network to optimise (modified in place).
      loader: iterable of (inputs, targets) batches.
      cfg: supplies epochs, lr and weight_decay.
      device: device each batch is moved to.

    Returns:
      The same `model` object.
    """
    def masked_layers():
        # Re-evaluated on every use so the current masks are always honoured.
        return [m for m in model.modules()
                if isinstance(m, KANLayer) and m.mask is not None]

    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    for epoch in range(cfg.epochs):
        total = 0.0
        n = 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)

            opt.zero_grad(set_to_none=True)
            loss = F.mse_loss(model(xb), yb)
            loss.backward()

            # Pruned entries must not receive gradient.
            for layer in masked_layers():
                if layer.coeff.grad is not None:
                    layer.coeff.grad.mul_(layer.mask)

            opt.step()

            # Re-apply the mask to the weights after the optimizer update.
            with torch.no_grad():
                for layer in masked_layers():
                    layer.coeff.mul_(layer.mask)

            batch = xb.size(0)
            total += loss.item() * batch
            n += batch
        print(f"  epoch {epoch+1:02d}/{cfg.epochs} | train MSE: {total/n:.6f}")

    return model


@torch.no_grad()
def evaluate(model, loader, device):
    """
    Return the sample-weighted mean MSE of `model` over `loader`.

    The model is switched to eval mode. Each batch's mean MSE is weighted by
    its number of rows, so a smaller final batch does not skew the average.
    """
    model.eval()
    weighted_loss = 0.0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_mse = F.mse_loss(model(inputs), targets)
        rows = inputs.size(0)
        weighted_loss += batch_mse.item() * rows
        seen += rows
    return weighted_loss / seen


# --------------------------
# Main experiment
# --------------------------
def main():
    """
    Run the full experiment: train a KAN regressor, prune a copy with the
    per-edge Neff rule, evaluate it one-shot, finetune it with the mask
    enforced, and print a summary comparing test MSE and sparsity.
    """
    # Seed first so model initialisation and loader shuffling are repeatable.
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Data (generated with its own seeded generator; tensors are moved to the
    # device batch by batch inside train/evaluate)
    (Xtr, ytr), (Xte, yte) = make_regression(in_dim=4, n_train=20000, n_test=2000, noise=0.05, seed=2024)
    train_loader = DataLoader(TensorDataset(Xtr, ytr), batch_size=512, shuffle=True, drop_last=False)
    test_loader = DataLoader(TensorDataset(Xte, yte), batch_size=2048, shuffle=False)

    # Model
    model_orig = KANRegressor(in_dim=4, hidden=64, K=16, sigma=0.25).to(device)
    print("\n=== Train ORIGINAL KAN ===")
    # NOTE(review): cfg.batch_size is not consumed by train(); the effective
    # batch size is the one given to train_loader above.
    cfg = TrainConfig(epochs=15, batch_size=512, lr=3e-3, weight_decay=0.0)
    model_orig = train(model_orig, train_loader, cfg, device)

    print("\nEvaluate ORIGINAL:")
    mse_orig = evaluate(model_orig, test_loader, device)
    print(f"Test MSE (original): {mse_orig:.6f}")

    # Clone and prune with Neff (the original model stays dense for comparison)
    print("\n=== Neff PRUNE trained KAN (per-edge basis) ===")
    model_pruned = copy.deepcopy(model_orig).to(device)
    pr_stats = neff_prune_kan_per_edge_basis(model_pruned.kan, beta=1.0, renormalize=False)
    print(f"Sparsity on KAN.coeff: {pr_stats['sparsity']*100:.2f}% "
          f"({pr_stats['kept']}/{pr_stats['total']} nonzeros)")

    print("\nEvaluate PRUNED (one-shot, no finetune):")
    mse_pruned_oneshot = evaluate(model_pruned, test_loader, device)
    print(f"Test MSE (pruned, one-shot): {mse_pruned_oneshot:.6f}")

    # Finetune pruned (zeros stay zeros): the gradient hook is attached only
    # for the duration of finetuning and removed afterwards.
    print("\n=== Finetune PRUNED KAN (masked) ===")
    attach_grad_mask(model_pruned)
    cfg_ft = TrainConfig(epochs=10, batch_size=512, lr=2e-3, weight_decay=0.0)
    model_pruned = train(model_pruned, train_loader, cfg_ft, device)
    remove_grad_mask(model_pruned)

    print("\nEvaluate PRUNED after finetune:")
    mse_pruned_ft = evaluate(model_pruned, test_loader, device)
    print(f"Test MSE (pruned, finetune): {mse_pruned_ft:.6f}")

    # Summary: the kept count is measured on the finetuned weights themselves
    # (exact nonzeros), not taken from the pruning statistics.
    total_coeff = model_orig.kan.coeff.numel()
    kept_coeff = count_nonzero(model_pruned.kan.coeff.data)
    sparsity_pct = 100.0 * (1.0 - kept_coeff / total_coeff)
    print("\n=== SUMMARY ===")
    print(f"Original params in KAN.coeff: {total_coeff}")
    print(f"Kept after Neff pruning:      {kept_coeff}  (sparsity = {sparsity_pct:.2f}%)")
    print(f"Test MSE original:            {mse_orig:.6f}")
    print(f"Test MSE pruned (one-shot):   {mse_pruned_oneshot:.6f}")
    print(f"Test MSE pruned (finetuned):  {mse_pruned_ft:.6f}")

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


Device: cuda

=== Train ORIGINAL KAN ===
  epoch 01/15 | train MSE: 0.296288
  epoch 02/15 | train MSE: 0.118026
  epoch 03/15 | train MSE: 0.038095
  epoch 04/15 | train MSE: 0.022506
  epoch 05/15 | train MSE: 0.016441
  epoch 06/15 | train MSE: 0.013015
  epoch 07/15 | train MSE: 0.010353
  epoch 08/15 | train MSE: 0.008936
  epoch 09/15 | train MSE: 0.007627
  epoch 10/15 | train MSE: 0.006760
  epoch 11/15 | train MSE: 0.006020
  epoch 12/15 | train MSE: 0.005554
  epoch 13/15 | train MSE: 0.005169
  epoch 14/15 | train MSE: 0.005272
  epoch 15/15 | train MSE: 0.004794

Evaluate ORIGINAL:
Test MSE (original): 0.004648

=== Neff PRUNE trained KAN (per-edge basis) ===
Sparsity on KAN.coeff: 38.82% (2506/4096 nonzeros)

Evaluate PRUNED (one-shot, no finetune):
Test MSE (pruned, one-shot): 0.010057

=== Finetune PRUNED KAN (masked) ===
  epoch 01/10 | train MSE: 0.005856
  epoch 02/10 | train MSE: 0.004529
  epoch 03/10 | train MSE: 0.004515
  epoch 04/10 | train MSE: 0.004306
  epoch