In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import copy

from kan import KAN as pyKan
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: cuda


In [1]:
# kan_neff_prune.py
import math
import numpy as np
import torch
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# ---- KAN imports (official library) ----
#   Docs: https://kindxiaoming.github.io/pykan/
from kan import KAN
from kan.utils import create_dataset_from_data

# -------------------------------
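# Utility: top-k selection sized by Neff, the effective number of
# significant entries (1 / sum of squared normalized scores) in a score vector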
# -------------------------------
def neff_topk_indices(scores: torch.Tensor, beta: float = 1.0, min_keep: int = 1):
    """
    Select the indices of the ``floor(beta * Neff)`` largest-magnitude scores.

    Neff is the participation ratio of the score magnitudes,
    ``(sum |s|)**2 / sum(s**2)``, which is the same quantity as
    ``1 / sum(p**2)`` with ``p = |s| / sum |s|``.  It is 1 when a single score
    dominates and ``scores.numel()`` when all magnitudes are equal.

    Args:
        scores: node importances for one layer.  Expected to be 1D; any other
            shape is flattened.  Only magnitudes matter, so negative values
            are ranked by absolute value (the same quantity Neff is built on).
        beta: multiplier applied to Neff before flooring.
        min_keep: lower bound on the number of indices returned.  It is itself
            capped at ``scores.numel()`` so the request can always be met.

    Returns:
        1D int64 tensor of indices on ``scores``' device, ordered by
        decreasing magnitude.  Empty input yields an empty tensor; all-zero
        input yields ``min(min_keep, numel)`` indices.

    Raises:
        ValueError: if ``scores`` contains NaN or infinite values.
    """
    flat = scores.detach().reshape(-1)
    n = flat.numel()
    if n == 0:
        return torch.empty(0, dtype=torch.long, device=flat.device)

    mag = flat.abs()
    lower_bound = max(int(min_keep), 0)

    # Neff is a single scalar: compute it in float64 on the CPU so that the
    # floor below is not decided by float32 rounding noise (and so devices
    # without float64 support still work).
    mag64 = mag.to(device="cpu", dtype=torch.float64)
    if not bool(torch.isfinite(mag64).all()):
        raise ValueError("scores must be finite to compute Neff")

    total = float(mag64.sum())
    sq_total = float((mag64 * mag64).sum())
    if total == 0.0 or sq_total == 0.0:
        # No usable signal: fall back to keeping the minimum.
        k = min(lower_bound, n)
    else:
        neff = (total * total) / sq_total
        # The tiny tolerance keeps exact-integer cases (e.g. all-equal scores,
        # where Neff == n) from flooring one short because of rounding.
        k = math.floor(beta * neff + 1e-9)
        # Clamp to [lower_bound, n]; applying the upper cap last guarantees
        # topk is never asked for more elements than exist.
        k = min(n, max(lower_bound, k))

    # Rank by magnitude so the selection agrees with how Neff was computed.
    return torch.topk(mag, k, largest=True).indices

# -----------------------------------------
# Compute incoming/outgoing node scores
# -----------------------------------------
@torch.no_grad()
def node_scores_in_out(model: KAN, device: torch.device):
    """
    Score each hidden neuron by the weaker of its strongest incoming and
    strongest outgoing edge attribution.

    Calls ``model.attribute(plot=False)`` and then reads ``model.edge_scores``
    and ``model.width``.  The indexing below treats ``edge_scores[j]`` as a 2D
    array laid out as [neurons of layer j+1, neurons of layer j]
    (assumption about pykan's layout -- confirm against the installed version).

    Returns a list with one ``(scores, width)`` tuple per hidden layer, where
    ``scores[i] = min(incoming[i], outgoing[i])`` for neuron i and ``width``
    is ``model.width[layer]`` passed through unchanged.
    """
    # Refresh the attribution results before edge_scores is read.
    model.attribute(plot=False)

    n_layers = len(model.width)

    # Bring every edge-score matrix onto `device` as float32 up front.
    per_mapping = []
    for raw in model.edge_scores:
        per_mapping.append(torch.as_tensor(raw, device=device, dtype=torch.float32))

    results = []
    # Layers 1 .. n_layers-2 are the hidden ones (0 is input, the last is output).
    for layer in range(1, n_layers - 1):
        # Mapping (layer-1 -> layer): one row per neuron of this layer, so the
        # row-wise max is the strongest edge arriving at each neuron.
        into_layer = per_mapping[layer - 1].abs()
        strongest_in, _ = into_layer.max(dim=1)

        # Mapping (layer -> layer+1): one column per neuron of this layer, so
        # the column-wise max is the strongest edge leaving each neuron.
        out_of_layer = per_mapping[layer].abs()
        strongest_out, _ = out_of_layer.max(dim=0)

        # A neuron only matters if it is both fed and used: take the min.
        combined = torch.minimum(strongest_in, strongest_out)
        results.append((combined, model.width[layer]))
    return results

# -----------------------------------------------------------
# Manual Neff-pruning of nodes using KAN's prune_node(..., mode="manual")
# -----------------------------------------------------------
@torch.no_grad()
def neff_prune_nodes(model: KAN, x_for_attr: torch.Tensor, beta: float = 1.0, device: str = "cpu"):
    """
    Prune hidden neurons, keeping the top floor(beta * Neff) per hidden layer.

    Steps: run one forward pass on ``x_for_attr``, score hidden neurons with
    ``node_scores_in_out``, pick the neurons to keep in each layer with
    ``neff_topk_indices`` (at least one per layer), and hand the selection to
    ``model.prune_node(mode="manual", active_neurons_id=...)``.

    ``device`` is only where the score tensors are placed while selecting.

    Returns whatever ``prune_node`` returns (the pruned model).
    """
    score_device = torch.device(device)

    # Forward pass so the model's stored activations reflect x_for_attr
    # before attribution is computed from them.
    model(x_for_attr)
    layer_scores = node_scores_in_out(model, device=score_device)

    # One id list per hidden layer, in layer order.
    keep_ids = [
        neff_topk_indices(layer_score, beta=beta, min_keep=1).tolist()
        for layer_score, _ in layer_scores
    ]

    return model.prune_node(mode="manual", active_neurons_id=keep_ids)

# -------------------------------
# Simple evaluation (MSE)
# -------------------------------
@torch.no_grad()
def mse(model: KAN, x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the mean squared error of ``model(x)`` against ``y`` as a Python float."""
    residual = model(x) - y
    return residual.pow(2).mean().item()

# -------------------------------
# Train + prune + report
# -------------------------------
def run_diabetes(beta=1.0, seed=0, device="cuda", steps=300, grid=5, k=3, hidden=16, lamb=1e-3):
    """
    Train a one-hidden-layer KAN on scikit-learn's diabetes regression data,
    prune its hidden layer with ``neff_prune_nodes``, and report the effect.

    Features and targets are standardized with scalers fitted on the training
    split only, so every MSE in the report is in standardized-target units.

    Args:
        beta: Neff multiplier forwarded to ``neff_prune_nodes``.
        seed: seeds torch and numpy, the train/test split, and the KAN.
        device: where tensors and the model are placed.
        steps, lamb: forwarded to ``model.fit`` (LBFGS optimizer).
        grid, k: forwarded to the ``KAN`` constructor.
        hidden: width of the single hidden layer.

    Returns:
        dict with keys ``beta``, ``width_before``, ``width_after``,
        ``train_mse_before``, ``train_mse_after``, ``test_mse_before``,
        ``test_mse_after``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # ---- Data: 80/20 split, then standardize with training statistics ----
    features, targets = load_diabetes(return_X_y=True)
    targets = targets.reshape(-1, 1)  # column vector to match the model output

    feat_train, feat_test, targ_train, targ_test = train_test_split(
        features, targets, test_size=0.2, random_state=seed
    )
    feat_scaler = StandardScaler().fit(feat_train)
    targ_scaler = StandardScaler().fit(targ_train)

    def as_device_tensor(scaler, array):
        # scale -> float32 -> torch tensor on the requested device
        scaled = scaler.transform(array).astype(np.float32)
        return torch.from_numpy(scaled).to(device)

    xtr_t = as_device_tensor(feat_scaler, feat_train)
    xte_t = as_device_tensor(feat_scaler, feat_test)
    ytr_t = as_device_tensor(targ_scaler, targ_train)
    yte_t = as_device_tensor(targ_scaler, targ_test)

    dataset = {
        "train_input": xtr_t,
        "train_label": ytr_t,
        "test_input": xte_t,
        "test_label": yte_t,
    }

    # ---- Model: inputs -> one hidden layer -> scalar output ----
    layer_widths = [xtr_t.shape[1], hidden, 1]
    model = KAN(width=layer_widths, grid=grid, k=k, seed=seed, device=device)

    # ---- Train; lamb is passed to fit as the regularization strength ----
    model.fit(dataset, opt="LBFGS", steps=steps, lamb=lamb)

    # ---- Metrics on the unpruned model ----
    train_mse_before = mse(model, xtr_t, ytr_t)
    test_mse_before = mse(model, xte_t, yte_t)
    widths_before = model.width.copy()

    # ---- Neff pruning; attribution is computed on the training inputs ----
    pruned = neff_prune_nodes(model, x_for_attr=xtr_t, beta=beta, device=device)

    # ---- Metrics on the pruned model (no fine-tuning in between) ----
    train_mse_after = mse(pruned, xtr_t, ytr_t)
    test_mse_after = mse(pruned, xte_t, yte_t)
    widths_after = pruned.width.copy()

    return {
        "beta": beta,
        "width_before": widths_before,
        "width_after": widths_after,
        "train_mse_before": train_mse_before,
        "train_mse_after": train_mse_after,
        "test_mse_before": test_mse_before,
        "test_mse_after": test_mse_after,
    }

# Entry point: one reference run at beta=1.0, then a sweep over several betas.
# NOTE(review): every name bound below (device, rep, betas, ...) becomes a
# notebook-level global when this runs as a cell.
if __name__ == "__main__":
    # Pick the GPU when available so run_diabetes' device argument is always valid here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    # Single run: train, prune at beta=1.0, and print every field of the report.
    rep = run_diabetes(beta=1.0, device=device)
    print("=== Single-run report (beta=1.0) ===")
    for k, v in rep.items():
        print(f"{k}: {v}")

    # Optional: sweep beta (values below 1 keep fewer than Neff neurons, above 1 keep more).
    # NOTE(review): run_diabetes trains a fresh model on every call, so this
    # loop retrains once per beta even though training does not depend on beta;
    # beta=1.0 also repeats the single run above.
    betas = [0.5, 2/3, 0.75, 0.8, 1.0, 5/4, 4/3, 1.5, 2.0]
    print("\n=== Beta sweep ===")
    for b in betas:
        r = run_diabetes(beta=b, device=device)
        # One summary line per beta: layer widths and test MSE before -> after pruning.
        print(f"beta={b:.3f} | pre_width={r['width_before']} -> post_width={r['width_after']} | "
              f"test MSE {r['test_mse_before']:.4f} -> {r['test_mse_after']:.4f}")


Device: cuda
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:57<00:00,  5.20


saving model version 0.1
saving model version 0.2
=== Single-run report (beta=1.0) ===
beta: 1.0
width_before: [[10, 0], [16, 0], [1, 0]]
width_after: [[10, 0], [10, 0], [1, 0]]
train_mse_before: 9.751252946443856e-05
train_mse_after: 0.107933409512043
test_mse_before: 1.139122724533081
test_mse_after: 1.2132433652877808

=== Beta sweep ===
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.33


saving model version 0.1
saving model version 0.2
beta=0.500 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [5, 0], [1, 0]] | test MSE 1.1391 -> 1.3921
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.31


saving model version 0.1
saving model version 0.2
beta=0.667 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [6, 0], [1, 0]] | test MSE 1.1391 -> 1.3682
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.34


saving model version 0.1
saving model version 0.2
beta=0.750 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [7, 0], [1, 0]] | test MSE 1.1391 -> 1.8089
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.35


saving model version 0.1
saving model version 0.2
beta=0.800 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [8, 0], [1, 0]] | test MSE 1.1391 -> 1.5436
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.33


saving model version 0.1
saving model version 0.2
beta=1.000 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [10, 0], [1, 0]] | test MSE 1.1391 -> 1.2132
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 9.87e-03 | test_loss: 1.07e+00 | reg: 2.60e+01 | : 100%|█| 300/300 [00:56<00:00,  5.31


saving model version 0.1
saving model version 0.2
beta=1.250 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [12, 0], [1, 0]] | test MSE 1.1391 -> 1.1409
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.54e-03 | test_loss: 1.15e+00 | reg: 2.56e+01 | : 100%|█| 300/300 [00:58<00:00,  5.13


saving model version 0.1
saving model version 0.2
beta=1.333 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [12, 0], [1, 0]] | test MSE 1.3215 -> 1.3377
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.54e-03 | test_loss: 1.15e+00 | reg: 2.56e+01 | : 100%|█| 300/300 [00:57<00:00,  5.21


saving model version 0.1
saving model version 0.2
beta=1.500 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [13, 0], [1, 0]] | test MSE 1.3215 -> 1.3224
checkpoint directory created: ./model
saving model version 0.0


| train_loss: 8.54e-03 | test_loss: 1.15e+00 | reg: 2.56e+01 | : 100%|█| 300/300 [00:55<00:00,  5.42

saving model version 0.1
saving model version 0.2
beta=2.000 | pre_width=[[10, 0], [16, 0], [1, 0]] -> post_width=[[10, 0], [16, 0], [1, 0]] | test MSE 1.3215 -> 1.3215



