In [1]:
import torch
import torch.nn.functional as F
import math

"""credit to opensource https://github.com/Blealtan/efficient-kan"""

class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        
        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

In [2]:
# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 938/938 [00:06<00:00, 136.33it/s, accuracy=0.875, loss=0.314, lr=0.001] 


Epoch 1, Val Loss: 0.23603566991058506, Val Accuracy: 0.9290406050955414


100%|██████████| 938/938 [00:06<00:00, 136.44it/s, accuracy=1, loss=0.0354, lr=0.0008]    


Epoch 2, Val Loss: 0.15943998358194617, Val Accuracy: 0.9548168789808917


100%|██████████| 938/938 [00:06<00:00, 139.90it/s, accuracy=0.969, loss=0.0508, lr=0.00064]


Epoch 3, Val Loss: 0.14269541657179785, Val Accuracy: 0.9600915605095541


100%|██████████| 938/938 [00:07<00:00, 118.11it/s, accuracy=0.969, loss=0.134, lr=0.000512] 


Epoch 4, Val Loss: 0.11162164308735804, Val Accuracy: 0.9666600318471338


100%|██████████| 938/938 [00:09<00:00, 101.68it/s, accuracy=1, loss=0.00867, lr=0.00041]   


Epoch 5, Val Loss: 0.10873399607119429, Val Accuracy: 0.9674562101910829


100%|██████████| 938/938 [00:07<00:00, 125.10it/s, accuracy=1, loss=0.00895, lr=0.000328]   


Epoch 6, Val Loss: 0.09907284306709292, Val Accuracy: 0.9704418789808917


100%|██████████| 938/938 [00:07<00:00, 126.58it/s, accuracy=1, loss=0.0279, lr=0.000262]    


Epoch 7, Val Loss: 0.09754395741557073, Val Accuracy: 0.9692476114649682


100%|██████████| 938/938 [00:08<00:00, 109.79it/s, accuracy=0.969, loss=0.0769, lr=0.00021]


Epoch 8, Val Loss: 0.09209574455362715, Val Accuracy: 0.9721337579617835


100%|██████████| 938/938 [00:08<00:00, 105.37it/s, accuracy=1, loss=0.0155, lr=0.000168]    


Epoch 9, Val Loss: 0.0904715131185237, Val Accuracy: 0.9710390127388535


100%|██████████| 938/938 [00:07<00:00, 120.90it/s, accuracy=0.969, loss=0.078, lr=0.000134] 


Epoch 10, Val Loss: 0.08993846593294175, Val Accuracy: 0.9722332802547771


In [3]:
@torch.no_grad()
def eval_model(model, loader, device):
    model.eval()
    ce = nn.CrossEntropyLoss()
    tot_loss, tot_corr, tot_n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.view(x.size(0), -1).to(device), y.to(device)
        logits = model(x)
        loss = ce(logits, y)
        tot_loss += loss.item() * x.size(0)
        tot_corr += (logits.argmax(1) == y).sum().item()
        tot_n += x.size(0)
    return tot_loss / tot_n, tot_corr / tot_n


@torch.no_grad()
def kan_node_scores(model: KAN):
    scores = []
    L = len(model.layers)  # number of KANLinear layers
    # per-layer |phi|_1 matrix (out, in)
    phi_l1 = []
    for layer in model.layers:
        # use scaled spline weights (includes spline_scaler if present)
        sw = layer.scaled_spline_weight  # (out, in, coeff)
        m = sw.abs().mean(dim=-1)        # (out, in)
        phi_l1.append(m)

    for l in range(1, L):
        prev = model.layers[l-1]             # maps n_{l-1} -> n_l
        prev_m = phi_l1[l-1]                 # shape (n_l, n_{l-1})
        next_m = phi_l1[l] if l < L else None

        I = prev_m.max(dim=1).values  # (n_l,)
        O = phi_l1[l].max(dim=0).values if l < L else torch.zeros_like(I)
        score = torch.maximum(I, O)
        scores.append(score.cpu())
    return scores  # length = #hidden layers


def EMP_kan(score_1d: torch.Tensor):
    s = score_1d.clone().float()
    s[s < 0] = 0
    if s.sum() <= 0:
        s = torch.ones_like(s)
    w = s / s.sum()
    w_sorted, idx_sorted = torch.sort(w, descending=True)
    neff = int(torch.floor(1.0 / torch.sum(w**2)).item())
    neff = max(1, min(neff, w.numel()))
    keep_idx = idx_sorted[:neff]
    keep_idx, _ = torch.sort(keep_idx)  # stable order
    return keep_idx.tolist(), neff


def _new_kan_like(old: KAN, layers_hidden):
    return KAN(
        layers_hidden=layers_hidden,
        grid_size=old.grid_size,
        spline_order=old.spline_order,
        scale_noise=old.layers[0].scale_noise,
        scale_base=old.layers[0].scale_base,
        scale_spline=old.layers[0].scale_spline,
        base_activation=type(old.layers[0].base_activation),
        grid_eps=old.layers[0].grid_eps,
        grid_range=[old.layers[0].grid[0,0].item(), old.layers[0].grid[0,-1].item()],
    )

@torch.no_grad()
def _slice_layer(old_layer: KANLinear, keep_in, keep_out):
    new_layer = KANLinear(
        in_features=len(keep_in),
        out_features=len(keep_out),
        grid_size=old_layer.grid_size,
        spline_order=old_layer.spline_order,
        scale_noise=old_layer.scale_noise,
        scale_base=old_layer.scale_base,
        scale_spline=old_layer.scale_spline,
        enable_standalone_scale_spline=old_layer.enable_standalone_scale_spline,
        base_activation=type(old_layer.base_activation),
        grid_eps=old_layer.grid_eps,
        grid_range=[old_layer.grid[0,0].item(), old_layer.grid[0,-1].item()],
    )
    device = old_layer.base_weight.device
    new_layer = new_layer.to(device)

    ki = torch.as_tensor(keep_in, dtype=torch.long, device=device)
    ko = torch.as_tensor(keep_out, dtype=torch.long, device=device)

    # base weights
    new_layer.base_weight.data.copy_(old_layer.base_weight.data.index_select(0, ko).index_select(1, ki))
    # spline weights
    new_layer.spline_weight.data.copy_(old_layer.spline_weight.data.index_select(0, ko).index_select(1, ki))
    # spline scaler (if exists)
    if old_layer.enable_standalone_scale_spline:
        new_layer.spline_scaler.data.copy_(old_layer.spline_scaler.data.index_select(0, ko).index_select(1, ki))
    # copy the selected input grids
    new_layer.grid.data.copy_(old_layer.grid.index_select(0, ki).contiguous())
    return new_layer

@torch.no_grad()
def prune_kan_by_neff(model: KAN, device):
    model.eval()
    # 1) node scores per hidden layer
    scores = kan_node_scores(model)  # list of tensors
    keep_lists = []
    for s in scores:
        keep, neff = EMP_kan(s)
        keep_lists.append(keep)

    # 2) new widths
    widths_old = [model.layers[0].in_features] + [ly.out_features for ly in model.layers]
    hidden_kept = [len(k) for k in keep_lists]
    widths_new = [widths_old[0]] + hidden_kept + [widths_old[-1]]

    # 3) construct new model and copy slices
    pruned = _new_kan_like(model, widths_new).to(device)
    # layer 0: keep_out = keep_lists[0], keep_in = all inputs
    all_in0 = list(range(widths_old[0]))
    pruned.layers[0] = _slice_layer(model.layers[0], keep_in=all_in0, keep_out=keep_lists[0])

    # middle layers (if any)
    for l in range(1, len(model.layers)-1):
        keep_in = keep_lists[l-1]
        keep_out = keep_lists[l]
        pruned.layers[l] = _slice_layer(model.layers[l], keep_in=keep_in, keep_out=keep_out)

    # last layer: keep_in = last hidden keep, keep_out = all outputs
    last = len(model.layers)-1
    keep_in_last = keep_lists[-1]
    all_out_last = list(range(model.layers[last].out_features))
    pruned.layers[last] = _slice_layer(model.layers[last], keep_in=keep_in_last, keep_out=all_out_last)

    return pruned, keep_lists

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
orig_val_loss, orig_val_acc = eval_model(model, valloader, device)
print(f"Original  - val_loss={orig_val_loss:.4f}, val_acc={orig_val_acc:.4f}")

pruned_model, kept = prune_kan_by_neff(model, device)
pruned_val_loss, pruned_val_acc = eval_model(pruned_model, valloader, device)
print("Kept node indices per hidden layer:", kept)
print(f"Pruned    - val_loss={pruned_val_loss:.4f}, val_acc={pruned_val_acc:.4f}")


Original  - val_loss=0.0903, val_acc=0.9721
Kept node indices per hidden layer: [[0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 18, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 35, 36, 37, 38, 39, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62]]
Pruned    - val_loss=0.1350, val_acc=0.9581
