In [18]:
# KAN PyTorch + régularisation de lissage des splines (et option sparsity)
# Idée: pénaliser la "courbure discrète" des coefficients spline:
#   smooth = Σ (c_{k+1} - 2 c_k + c_{k-1})^2
# Option: L1 sur les amplitudes (encourage sparsité des arêtes)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


# -----------------------------
# Utils: B-splines 1D (Cox–de Boor)
# -----------------------------
def make_open_uniform_knots(n_basis: int, degree: int, xmin: float, xmax: float, device):
    assert n_basis >= degree + 1
    n_knots = n_basis + degree + 1
    n_inner = n_knots - 2 * (degree + 1)
    if n_inner > 0:
        inner = torch.linspace(xmin, xmax, steps=n_inner + 2, device=device)[1:-1]
        knots = torch.cat([
            torch.full((degree + 1,), xmin, device=device),
            inner,
            torch.full((degree + 1,), xmax, device=device),
        ])
    else:
        knots = torch.cat([
            torch.full((degree + 1,), xmin, device=device),
            torch.full((degree + 1,), xmax, device=device),
        ])
    return knots


def bspline_basis(x: torch.Tensor, knots: torch.Tensor, degree: int, n_basis: int):
    x = x.reshape(-1)  # [B]
    B = x.shape[0]
    t = knots
    device = x.device

    # degree 0
    N = []
    for i in range(n_basis):
        left, right = t[i], t[i + 1]
        cond = (x >= left) & (x < right)
        N.append(cond.to(x.dtype))
    N = torch.stack(N, dim=1)  # [B, n_basis]

    xmax = t[-1]
    at_right = (x == xmax)
    if at_right.any():
        N[at_right, :] = 0
        N[at_right, -1] = 1

    # recursion
    for d in range(1, degree + 1):
        N_new = torch.zeros((B, n_basis), device=device, dtype=x.dtype)
        for i in range(n_basis):
            denom1 = t[i + d] - t[i]
            if denom1 != 0:
                a = (x - t[i]) / denom1
                left_term = a * N[:, i]
            else:
                left_term = 0.0

            denom2 = t[i + d + 1] - t[i + 1]
            if denom2 != 0 and i + 1 < n_basis:
                b = (t[i + d + 1] - x) / denom2
                right_term = b * N[:, i + 1]
            else:
                right_term = 0.0

            N_new[:, i] = left_term + right_term
        N = N_new
    return N


# -----------------------------
# KAN Layer
# -----------------------------
class KANSplineLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int,
                 n_basis: int = 16, degree: int = 3,
                 xmin: float = -1.0, xmax: float = 1.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_basis = n_basis
        self.degree = degree
        self.xmin = float(xmin)
        self.xmax = float(xmax)

        self.coeff = nn.Parameter(0.01 * torch.randn(out_dim, in_dim, n_basis))
        self.bias = nn.Parameter(torch.zeros(out_dim))

        knots = make_open_uniform_knots(n_basis, degree, self.xmin, self.xmax, device=torch.device("cpu"))
        self.register_buffer("knots", knots)

    def forward(self, x: torch.Tensor):
        B, D = x.shape
        assert D == self.in_dim
        x = torch.clamp(x, self.xmin, self.xmax)

        out = torch.zeros((B, self.out_dim), device=x.device, dtype=x.dtype)
        for j in range(self.in_dim):
            basis = bspline_basis(x[:, j], self.knots, self.degree, self.n_basis)  # [B, K]
            out += basis @ self.coeff[:, j, :].transpose(0, 1)  # [B, out]
        return out + self.bias


class KAN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 n_basis: int = 16, degree: int = 3,
                 xmin: float = -1.0, xmax: float = 1.0):
        super().__init__()
        self.l1 = KANSplineLayer(in_dim, hidden_dim, n_basis, degree, xmin, xmax)
        self.l2 = KANSplineLayer(hidden_dim, out_dim, n_basis, degree, xmin, xmax)

    def forward(self, x):
        h = torch.tanh(self.l1(x))
        return self.l2(h)

    def spline_smoothness_penalty(self):
        """
        Penalise discrete second derivative of spline coefficients.
        Returns scalar tensor.
        """
        penalty = 0.0
        for layer in (self.l1, self.l2):
            c = layer.coeff  # [out, in, K]
            # second difference along K: c[k+1] - 2c[k] + c[k-1]
            d2 = c[..., 2:] - 2.0 * c[..., 1:-1] + c[..., :-2]  # [out, in, K-2]
            penalty = penalty + (d2 ** 2).mean()
        return penalty

    def l1_amplitude_penalty(self):
        """
        Optional: encourage sparse/small functions on edges (like "edge pruning").
        """
        pen = 0.0
        for layer in (self.l1, self.l2):
            pen = pen + layer.coeff.abs().mean()
        return pen


# -----------------------------
# Dataset multivarié synthétique
# -----------------------------
def make_dataset(n=30000, in_dim=6, noise=0.05, seed=42, device="cpu"):
    torch.manual_seed(seed)

    X = (torch.rand((n, in_dim), device=device) * 2 - 1)  # [-1,1]

    y = (
        torch.sin(math.pi * X[:, 0]) +
        0.5 * X[:, 1] ** 2 -
        X[:, 2] * X[:, 3] +
        torch.exp(0.7 * X[:, 4]) -
        0.3 * torch.cos(3.0 * X[:, 5])
    ).unsqueeze(1)

    y = y + noise * torch.randn(y.shape, device=device)
    return X, y



# -----------------------------
# Train (avec lissage + clipping)
# -----------------------------
def train_kan(device="cuda" if torch.cuda.is_available() else "cpu"):
    in_dim = 6
    X, y = make_dataset(n=30000, in_dim=in_dim, noise=0.05, seed=42, device=device)

    n_train = int(0.85 * X.shape[0])
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=512, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=2048, shuffle=False)

    model = KAN(in_dim=in_dim, hidden_dim=32, out_dim=1, n_basis=24, degree=3, xmin=-1.0, xmax=1.0).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-4)

    # Hyperparams régularisation
    lambda_smooth = 1e-2   # lissage (augmente si ça "vibre", diminue si sous-fit)
    lambda_l1 = 0.0        # mets 1e-4 ou 1e-5 si tu veux sparsifier

    def eval_mse():
        model.eval()
        mse_sum = 0.0
        n = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                pred = model(xb)
                mse_sum += F.mse_loss(pred, yb, reduction="sum").item()
                n += yb.numel()
        return mse_sum / n
    
    best_val = float("inf")


    for epoch in range(1, 31):
        model.train()
        train_mse_sum = 0.0
        n_seen = 0

        for xb, yb in train_loader:
            opt.zero_grad(set_to_none=True)

            pred = model(xb)
            mse = F.mse_loss(pred, yb)

            smooth = model.spline_smoothness_penalty()
            l1 = model.l1_amplitude_penalty() if lambda_l1 > 0 else 0.0

            loss = mse + lambda_smooth * smooth + (lambda_l1 * l1 if lambda_l1 > 0 else 0.0)
            loss.backward()

            # Stabilise gradients (très utile sur KAN)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            opt.step()

            train_mse_sum += mse.item() * yb.shape[0]
            n_seen += yb.shape[0]

        train_mse = train_mse_sum / n_seen
        val_mse = eval_mse()

        if epoch % 5 == 0 or epoch == 1:
            s = float(model.spline_smoothness_penalty().detach().cpu())
            print(f"epoch {epoch:02d} | train MSE={train_mse:.6f} | val MSE={val_mse:.6f} | smooth={s:.6f}")
            
       
        
          # Sauvegarde si meilleur modèle
        if val_mse < best_val:
            best_val = val_mse
            torch.save(model.state_dict(), "kan_model.pt")

        print(f"epoch {epoch} | val MSE={val_mse:.6f} | best={best_val:.6f}")

    return model


In [67]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


# -----------------------------
# 1) Comptage de paramètres
# -----------------------------
def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# -----------------------------
# 2) MLP "flexible" (N hidden layers)
# -----------------------------
class FlexibleMLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dims: list[int], out_dim: int, act: str = "tanh"):
        super().__init__()
        acts = {
            "tanh": nn.Tanh(),
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "silu": nn.SiLU(),
        }
        if act not in acts:
            raise ValueError(f"Unknown act={act}. Choose from {list(acts.keys())}")

        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(acts[act])
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# -----------------------------
# 3) Estimation du nb de params d'un MLP
# -----------------------------
def mlp_param_count(in_dim: int, hidden_dims: list[int], out_dim: int) -> int:
    # Chaque Linear: weights + bias = (prev*h) + h
    total = 0
    prev = in_dim
    for h in hidden_dims:
        total += prev * h + h
        prev = h
    total += prev * out_dim + out_dim
    return total


# -----------------------------
# 4) Trouver une architecture MLP proche d'un target_params
#    - mode "1 hidden layer" ou "2 hidden layers" (recommandé)
# -----------------------------
def find_mlp_dims_close_to_target(
    target_params: int,
    in_dim: int,
    out_dim: int,
    n_hidden_layers: int = 2,
    h_min: int = 4,
    h_max: int = 4096,
) -> list[int]:
    """
    Retourne hidden_dims (liste) qui minimise |params - target_params|.
    - n_hidden_layers=1 : résout quasi-analytique
    - n_hidden_layers=2 : petite recherche grille intelligente
    """
    if n_hidden_layers == 1:
        # params = (in*h + h) + (h*out + out) = h*(in+1+out) + out
        denom = in_dim + 1 + out_dim
        h = max(h_min, int(round((target_params - out_dim) / denom)))
        h = min(h_max, h)
        return [h]

    if n_hidden_layers == 2:
        # Recherche simple : h1 dans une plage, h2 déduit approximativement
        best = None
        best_dims = None

        # On balaye h1 sur une grille log-ish (plus large)
        # Tu peux densifier si tu veux.
        for h1 in [int(x) for x in torch.unique(torch.logspace(math.log10(h_min), math.log10(min(2048, h_max)), steps=60)).tolist()]:
            # params = (in*h1+h1) + (h1*h2+h2) + (h2*out+out)
            #       = h1*(in+1) + h2*(h1+1+out) + out
            a = h1 * (in_dim + 1) + out_dim
            denom = (h1 + 1 + out_dim)
            # h2 approx
            h2 = int(round((target_params - a) / max(1, denom)))
            if h2 < h_min or h2 > h_max:
                continue

            # On regarde autour de h2 pour ajuster finement
            for h2_try in [h2 - 2, h2 - 1, h2, h2 + 1, h2 + 2]:
                if h2_try < h_min or h2_try > h_max:
                    continue
                p = mlp_param_count(in_dim, [h1, h2_try], out_dim)
                err = abs(p - target_params)
                if (best is None) or (err < best):
                    best = err
                    best_dims = [h1, h2_try]

        # Fallback si la recherche n'a rien trouvé
        if best_dims is None:
            # fallback 1 couche
            return find_mlp_dims_close_to_target(target_params, in_dim, out_dim, n_hidden_layers=1, h_min=h_min, h_max=h_max)

        return best_dims

    raise ValueError("n_hidden_layers must be 1 or 2")


# -----------------------------
# 5) Entraînement MLP avec target param count
#    - Tu fournis ton KAN déjà créé (pour récupérer son nb de params)
# -----------------------------
def train_mlp_matched_to_kan(
    kan_model: nn.Module,
    device="cuda" if torch.cuda.is_available() else "cpu",
    act: str = "tanh",
    n_hidden_layers: int = 2,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    epochs: int = 30,
):
    # Tu réutilises TON make_dataset existant
    in_dim = 6
    out_dim = 1

    target_params = count_params(kan_model)
    hidden_dims = find_mlp_dims_close_to_target(
        target_params=target_params,
        in_dim=in_dim,
        out_dim=out_dim,
        n_hidden_layers=n_hidden_layers,
        h_min=8,
        h_max=4096,
    )

    mlp = FlexibleMLP(in_dim=in_dim, hidden_dims=hidden_dims, out_dim=out_dim, act=act).to(device)
    mlp_params = count_params(mlp)

    print(f"[MATCH] KAN params={target_params} | MLP hidden_dims={hidden_dims} | MLP params={mlp_params} | diff={mlp_params-target_params}")

    X, y = make_dataset(n=30000, in_dim=in_dim, noise=0.05, seed=42, device=device)

    n_train = int(0.85 * X.shape[0])
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=512, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=2048, shuffle=False)

    opt = torch.optim.AdamW(mlp.parameters(), lr=lr, weight_decay=weight_decay)
    
    scheduler = torch.optim.lr_scheduler.StepLR(
        opt,
        step_size=60,   # baisse toutes les 60 epochs
        gamma=0.3       # LR = LR × 0.3
    )



    def eval_mse():
        mlp.eval()
        mse_sum, n = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                pred = mlp(xb)
                mse_sum += F.mse_loss(pred, yb, reduction="sum").item()
                n += yb.numel()
        return mse_sum / n
    
    best_val = float("inf")

    for epoch in range(1, epochs + 1):
        mlp.train()
        train_mse_sum, n_seen = 0.0, 0

        for xb, yb in train_loader:
            opt.zero_grad(set_to_none=True)
            pred = mlp(xb)
            loss = F.mse_loss(pred, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mlp.parameters(), 1.0)
            opt.step()

            train_mse_sum += loss.item() * yb.shape[0]
            n_seen += yb.shape[0]

        train_mse = train_mse_sum / n_seen
        val_mse = eval_mse()
        scheduler.step()

        if epoch % 5 == 0 or epoch == 1:
            print(f"[MLP-MATCH] epoch {epoch:02d} | train MSE={train_mse:.6f} | val MSE={val_mse:.6f}")
            
        # Sauvegarde si meilleur modèle
        if val_mse < best_val:
            best_val = val_mse
            torch.save(mlp.state_dict(), "best_model.pt")

        print(f"epoch {epoch} | val MSE={val_mse:.6f} | best={best_val:.6f}")

    return mlp


In [20]:
import time, os
import torch

def now_sync(device):
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

def timed(fn, device, *args, **kwargs):
    now_sync(device)
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    now_sync(device)
    t1 = time.perf_counter()
    return out, (t1 - t0)

def save_best_state_dict(model, path):
    torch.save(model.state_dict(), path)

def load_state_dict(model, path, device):
    sd = torch.load(path, map_location=device)
    model.load_state_dict(sd)
    return model

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# KAN
kan, t_kan = timed(lambda: train_kan(device=device), device)

print(f"KAN train time: {t_kan:.2f}s | params: {count_params(kan)}")


epoch 01 | train MSE=1.158183 | val MSE=0.523067 | smooth=0.000695
epoch 1 | val MSE=0.523067 | best=0.523067
epoch 2 | val MSE=0.127265 | best=0.127265
epoch 3 | val MSE=0.114624 | best=0.114624
epoch 4 | val MSE=0.079720 | best=0.079720
epoch 05 | train MSE=0.062709 | val MSE=0.049747 | smooth=0.001243
epoch 5 | val MSE=0.049747 | best=0.049747
epoch 6 | val MSE=0.028451 | best=0.028451
epoch 7 | val MSE=0.014021 | best=0.014021
epoch 8 | val MSE=0.007070 | best=0.007070
epoch 9 | val MSE=0.004226 | best=0.004226
epoch 10 | train MSE=0.003570 | val MSE=0.003346 | smooth=0.001872
epoch 10 | val MSE=0.003346 | best=0.003346
epoch 11 | val MSE=0.003007 | best=0.003007
epoch 12 | val MSE=0.002910 | best=0.002910
epoch 13 | val MSE=0.002849 | best=0.002849
epoch 14 | val MSE=0.002852 | best=0.002849
epoch 15 | train MSE=0.002686 | val MSE=0.002781 | smooth=0.001859
epoch 15 | val MSE=0.002781 | best=0.002781
epoch 16 | val MSE=0.002802 | best=0.002781
epoch 17 | val MSE=0.002878 | best=0.

In [28]:
mlp, t_mlp = timed(lambda: train_mlp_matched_to_kan(
    kan_model=kan,
    device=device,
    n_hidden_layers=2,
    act="tanh",
    epochs=500,
    lr=1e-3,
    weight_decay=1e-4
), device)


print(f"MLP train time: {t_mlp:.2f}s | params: {count_params(mlp)}")


[MATCH] KAN params=5409 | MLP hidden_dims=[9, 486] | MLP params=5410 | diff=1




[MLP-MATCH] epoch 01 | train MSE=0.781585 | val MSE=0.439276
epoch 1 | val MSE=0.439276 | best=0.439276
epoch 2 | val MSE=0.362606 | best=0.362606
epoch 3 | val MSE=0.351370 | best=0.351370
epoch 4 | val MSE=0.335211 | best=0.335211
[MLP-MATCH] epoch 05 | train MSE=0.331982 | val MSE=0.316070
epoch 5 | val MSE=0.316070 | best=0.316070
epoch 6 | val MSE=0.298737 | best=0.298737
epoch 7 | val MSE=0.285698 | best=0.285698
epoch 8 | val MSE=0.279792 | best=0.279792
epoch 9 | val MSE=0.278115 | best=0.278115
[MLP-MATCH] epoch 10 | train MSE=0.282475 | val MSE=0.276568
epoch 10 | val MSE=0.276568 | best=0.276568
epoch 11 | val MSE=0.274124 | best=0.274124
epoch 12 | val MSE=0.274269 | best=0.274124
epoch 13 | val MSE=0.272034 | best=0.272034
epoch 14 | val MSE=0.269454 | best=0.269454
[MLP-MATCH] epoch 15 | train MSE=0.270758 | val MSE=0.267036
epoch 15 | val MSE=0.267036 | best=0.267036
epoch 16 | val MSE=0.266829 | best=0.266829
epoch 17 | val MSE=0.264764 | best=0.264764
epoch 18 | val MS

epoch 146 | val MSE=0.003997 | best=0.003997
epoch 147 | val MSE=0.004207 | best=0.003997
epoch 148 | val MSE=0.004118 | best=0.003997
epoch 149 | val MSE=0.004038 | best=0.003997
[MLP-MATCH] epoch 150 | train MSE=0.004305 | val MSE=0.004101
epoch 150 | val MSE=0.004101 | best=0.003997
epoch 151 | val MSE=0.004141 | best=0.003997
epoch 152 | val MSE=0.004751 | best=0.003997
epoch 153 | val MSE=0.004192 | best=0.003997
epoch 154 | val MSE=0.003841 | best=0.003841
[MLP-MATCH] epoch 155 | train MSE=0.004203 | val MSE=0.003877
epoch 155 | val MSE=0.003877 | best=0.003841
epoch 156 | val MSE=0.003790 | best=0.003790
epoch 157 | val MSE=0.004164 | best=0.003790
epoch 158 | val MSE=0.003715 | best=0.003715
epoch 159 | val MSE=0.003962 | best=0.003715
[MLP-MATCH] epoch 160 | train MSE=0.003860 | val MSE=0.003721
epoch 160 | val MSE=0.003721 | best=0.003715
epoch 161 | val MSE=0.004084 | best=0.003715
epoch 162 | val MSE=0.003668 | best=0.003668
epoch 163 | val MSE=0.003676 | best=0.003668
epoc

[MLP-MATCH] epoch 290 | train MSE=0.003152 | val MSE=0.003111
epoch 290 | val MSE=0.003111 | best=0.003091
epoch 291 | val MSE=0.003068 | best=0.003068
epoch 292 | val MSE=0.003186 | best=0.003068
epoch 293 | val MSE=0.003264 | best=0.003068
epoch 294 | val MSE=0.003197 | best=0.003068
[MLP-MATCH] epoch 295 | train MSE=0.003144 | val MSE=0.003100
epoch 295 | val MSE=0.003100 | best=0.003068
epoch 296 | val MSE=0.003214 | best=0.003068
epoch 297 | val MSE=0.003128 | best=0.003068
epoch 298 | val MSE=0.003077 | best=0.003068
epoch 299 | val MSE=0.003113 | best=0.003068
[MLP-MATCH] epoch 300 | train MSE=0.003165 | val MSE=0.003199
epoch 300 | val MSE=0.003199 | best=0.003068
epoch 301 | val MSE=0.003104 | best=0.003068
epoch 302 | val MSE=0.003130 | best=0.003068
epoch 303 | val MSE=0.003079 | best=0.003068
epoch 304 | val MSE=0.003115 | best=0.003068
[MLP-MATCH] epoch 305 | train MSE=0.003114 | val MSE=0.003077
epoch 305 | val MSE=0.003077 | best=0.003068
epoch 306 | val MSE=0.003130 | b

epoch 433 | val MSE=0.003031 | best=0.003030
epoch 434 | val MSE=0.003031 | best=0.003030
[MLP-MATCH] epoch 435 | train MSE=0.003047 | val MSE=0.003031
epoch 435 | val MSE=0.003031 | best=0.003030
epoch 436 | val MSE=0.003031 | best=0.003030
epoch 437 | val MSE=0.003031 | best=0.003030
epoch 438 | val MSE=0.003031 | best=0.003030
epoch 439 | val MSE=0.003031 | best=0.003030
[MLP-MATCH] epoch 440 | train MSE=0.003047 | val MSE=0.003031
epoch 440 | val MSE=0.003031 | best=0.003030
epoch 441 | val MSE=0.003031 | best=0.003030
epoch 442 | val MSE=0.003031 | best=0.003030
epoch 443 | val MSE=0.003031 | best=0.003030
epoch 444 | val MSE=0.003030 | best=0.003030
[MLP-MATCH] epoch 445 | train MSE=0.003047 | val MSE=0.003031
epoch 445 | val MSE=0.003031 | best=0.003030
epoch 446 | val MSE=0.003030 | best=0.003030
epoch 447 | val MSE=0.003031 | best=0.003030
epoch 448 | val MSE=0.003031 | best=0.003030
epoch 449 | val MSE=0.003031 | best=0.003030
[MLP-MATCH] epoch 450 | train MSE=0.003047 | val 

In [68]:
import torch

@torch.no_grad()
def find_span(x, knots, n_basis, degree):
    # span s.t. t[span] <= x < t[span+1]
    span = torch.bucketize(x, knots, right=True) - 1
    return span.clamp(min=degree, max=n_basis - 1)


def bspline_basis_funs_sparse(x, knots, n_basis: int, degree: int):
    """
    Vectorized computation of the (degree+1) non-zero B-spline basis functions at x.
    Returns:
      span: [B, D] spans
      N:    [B, D, degree+1] basis values for indices (span-degree ... span)
    x: [B, D]
    """
    device = x.device
    B, D = x.shape
    p = degree

    span = find_span(x, knots, n_basis, p)  # [B, D]

    # Algorithm A2.2 (The NURBS Book) vectorized
    N = torch.zeros((B, D, p + 1), device=device, dtype=x.dtype)
    left = torch.zeros((B, D, p + 1), device=device, dtype=x.dtype)
    right = torch.zeros((B, D, p + 1), device=device, dtype=x.dtype)

    N[..., 0] = 1.0

    # We need knot values t[span + j] and t[span + 1 - j] etc.
    # We'll gather knots by index.
    for j in range(1, p + 1):
        # left[j] = x - t[span + 1 - j]
        idx_left = (span + 1 - j).clamp(min=0, max=knots.numel() - 1)
        t_left = knots[idx_left]  # [B, D]
        left[..., j] = x - t_left

        # right[j] = t[span + j] - x
        idx_right = (span + j).clamp(min=0, max=knots.numel() - 1)
        t_right = knots[idx_right]  # [B, D]
        right[..., j] = t_right - x

        saved = torch.zeros((B, D), device=device, dtype=x.dtype)
        for r in range(0, j):
            denom = right[..., r + 1] + left[..., j - r]
            # safe divide
            denom = torch.where(denom == 0, torch.ones_like(denom), denom)
            temp = N[..., r] / denom
            N[..., r] = saved + right[..., r + 1] * temp
            saved = left[..., j - r] * temp
        N[..., j] = saved

    return span, N


In [69]:
import torch
import torch.nn as nn

class KANSplineLayerFast(nn.Module):
    def __init__(self, in_dim: int, out_dim: int,
                 n_basis: int = 24, degree: int = 3,
                 xmin: float = -1.0, xmax: float = 1.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_basis = n_basis
        self.degree = degree
        self.xmin = float(xmin)
        self.xmax = float(xmax)

        # [out, in, K]
        self.coeff = nn.Parameter(0.01 * torch.randn(out_dim, in_dim, n_basis))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.lin_weight = nn.Parameter(torch.empty(out_dim, in_dim))
        nn.init.kaiming_uniform_(self.lin_weight, a=math.sqrt(5))  # init standard


        # Open-uniform knots (buffer)
        knots = self._make_open_uniform_knots(n_basis, degree, self.xmin, self.xmax)
        self.register_buffer("knots", knots)

    @staticmethod
    def _make_open_uniform_knots(n_basis, degree, xmin, xmax):
        device = torch.device("cpu")
        n_knots = n_basis + degree + 1
        n_inner = n_knots - 2 * (degree + 1)
        if n_inner > 0:
            inner = torch.linspace(xmin, xmax, steps=n_inner + 2, device=device)[1:-1]
            knots = torch.cat([
                torch.full((degree + 1,), xmin, device=device),
                inner,
                torch.full((degree + 1,), xmax, device=device),
            ])
        else:
            knots = torch.cat([
                torch.full((degree + 1,), xmin, device=device),
                torch.full((degree + 1,), xmax, device=device),
            ])
        return knots

    def forward(self, x: torch.Tensor):
        eps = 1e-7
        x = torch.clamp(x, self.xmin, self.xmax - eps)
        B, D = x.shape
        p = self.degree

        with torch.no_grad():
            span, N = bspline_basis_funs_sparse(x, self.knots, self.n_basis, p)
            k = torch.arange(p + 1, device=x.device).view(1, 1, p + 1)
            idx = (span.unsqueeze(-1) - p + k).clamp(0, self.n_basis - 1)

        coeff_exp = self.coeff.unsqueeze(0).repeat(B, 1, 1, 1).contiguous()  # [B,out,in,K]
        idx_exp = idx.unsqueeze(1).expand(B, self.out_dim, self.in_dim, p + 1)
        gathered = torch.gather(coeff_exp, dim=3, index=idx_exp)              # [B,out,in,p+1]

        out_spline = (gathered * N.unsqueeze(1)).sum(dim=(2, 3))              # [B,out]
        out_lin = torch.einsum("bi,oi->bo", x, self.lin_weight)               # [B,out]

        return out_lin + out_spline + self.bias


    
    
class KANFast(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, n_basis=24, degree=3, xmin=-1.0, xmax=1.0):
        super().__init__()
        self.l1 = KANSplineLayerFast(in_dim, hidden_dim, n_basis, degree, xmin, xmax)
        self.l2 = KANSplineLayerFast(hidden_dim, out_dim, n_basis, degree, xmin, xmax)

    def forward(self, x):
        h = self.l1(x)
        return self.l2(h)


    def spline_smoothness_penalty(self):
        penalty = 0.0
        for layer in (self.l1, self.l2):
            c = layer.coeff
            d2 = c[..., 2:] - 2.0 * c[..., 1:-1] + c[..., :-2]
            penalty = penalty + (d2 ** 2).mean()
        return penalty



In [81]:
import os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


def train_kan_fast(
    device="cuda" if torch.cuda.is_available() else "cpu",
    n=30000,
    in_dim=6,
    hidden_dim=32,
    out_dim=1,
    n_basis=24,
    degree=3,
    xmin=-1.0,
    xmax=1.0,
    noise=0.05,
    seed=42,
    epochs=30,
    batch_size=4096,
    lr=8e-4,
    weight_decay=1e-4,
    lambda_smooth=1e-2,
    grad_clip=1.0,
    use_amp=True,
    num_workers=4,
    pin_memory=True,
    
    scheduler_type="step",     # "step" ou "plateau" ou None
    step_size=80,
    gamma=0.5,
    plateau_factor=0.5,
    plateau_patience=10,
    plateau_min_lr=1e-6,
    best_path="best_kan_fast.pt",
):
    """
    Entraîne KANFast (doit être défini dans ton fichier) + régularisation de lissage.
    Suppose que make_dataset(...) existe déjà et renvoie (X, y) sur device.
    Sauvegarde le meilleur modèle (val MSE) dans best_path.
    Retourne le modèle (chargé avec le meilleur state_dict si best_path existe).
    """
    assert "KANFast" in globals(), "KANFast n'est pas défini dans ce notebook/fichier."
    assert "make_dataset" in globals(), "make_dataset n'est pas défini dans ce notebook/fichier."

    # Data
    X, y = make_dataset(n=n, in_dim=in_dim, noise=noise, seed=seed, device=device)
    n_train = int(0.85 * X.shape[0])
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    # Note: tes tensors sont déjà sur GPU (device), donc num_workers/pin_memory ne changent pas grand-chose.
    # Si tu veux profiter pleinement du DataLoader, génère le dataset sur CPU puis transfère dans la boucle.
    train_loader = DataLoader(
        TensorDataset(X_train, y_train),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0 if device.startswith("cuda") else num_workers,
        pin_memory=False if device.startswith("cuda") else pin_memory,
        drop_last=False,
    )
    val_loader = DataLoader(
        TensorDataset(X_val, y_val),
        batch_size=max(batch_size, 4096),
        shuffle=False,
        num_workers=0 if device.startswith("cuda") else num_workers,
        pin_memory=False if device.startswith("cuda") else pin_memory,
        drop_last=False,
    )

    # Model
    model = KANFast(
        in_dim=in_dim,
        hidden_dim=hidden_dim,
        out_dim=out_dim,
        n_basis=n_basis,
        degree=degree,
        xmin=xmin,
        xmax=xmax,
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = None
    if scheduler_type == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
    elif scheduler_type == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode="min", factor=plateau_factor, patience=plateau_patience,
            min_lr=plateau_min_lr, verbose=True
        )
    elif scheduler_type is None or scheduler_type == "none":
        scheduler = None
    else:
        raise ValueError(f"Unknown scheduler_type={scheduler_type}")

    scaler = torch.amp.GradScaler("cuda", enabled=(use_amp and device.startswith("cuda")))


    @torch.no_grad()
    def eval_mse():
        model.eval()
        mse_sum, n_el = 0.0, 0
        for xb, yb in val_loader:
            pred = model(xb)
            mse_sum += F.mse_loss(pred, yb, reduction="sum").item()
            n_el += yb.numel()
        return mse_sum / n_el

    best_val = float("inf")

    for epoch in range(1, epochs + 1):
        model.train()
        mse_sum, n_seen = 0.0, 0

        for xb, yb in train_loader:
            opt.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=(use_amp and device.startswith("cuda"))):
                pred = model(xb)
                mse = F.mse_loss(pred, yb)
                smooth = model.spline_smoothness_penalty()
                loss = mse + lambda_smooth * smooth

            scaler.scale(loss).backward()

            if grad_clip is not None and grad_clip > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(opt)
            scaler.update()

            mse_sum += mse.item() * yb.shape[0]
            n_seen += yb.shape[0]

        train_mse = mse_sum / n_seen
        val_mse = eval_mse()
        
        smooth_val = float(model.spline_smoothness_penalty().detach().cpu())
        if scheduler is not None:
            if scheduler_type == "plateau":
                scheduler.step(val_mse)
            else:  # "step"
                scheduler.step()

        
        if epoch % 5 == 0 or epoch == 1:
            current_lr = opt.param_groups[0]["lr"]
            print(f"... | lr={current_lr:.2e}")

        if val_mse < best_val:
            best_val = val_mse
            torch.save(model.state_dict(), best_path)

        if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
            print(f"[KANFast] epoch {epoch:03d} | train MSE={train_mse:.6f} | val MSE={val_mse:.6f} | best={best_val:.6f} | smooth={smooth_val:.6f}")

    # reload best
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()
    return model


# ---------- LANCEMENT (avec timing) ----------
def run_kanfast_benchmark():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    kan_fast = train_kan_fast(
        device=device,
        epochs=30,
        batch_size=8192,     # tu peux tester 4096 / 8192 / 16384 selon VRAM
        lambda_smooth=1e-2,
        use_amp=True,
        best_path="best_kan_fast.pt",
    )

    if device.startswith("cuda"):
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    n_params = sum(p.numel() for p in kan_fast.parameters() if p.requires_grad)
    print(f"KANFast train time: {t1 - t0:.2f}s | params: {n_params}")
    return kan_fast


# Exemple d'appel:
# kan_fast = run_kanfast_benchmark()


In [82]:
kan_fast = train_kan_fast(
    hidden_dim=64,
    batch_size=8192,
    use_amp=True,
    lr=1e-3,
    lambda_smooth=1e-6,
    epochs=300,
    scheduler_type="step",
    step_size=80,
    gamma=0.5,
)


... | lr=1.00e-03
[KANFast] epoch 001 | train MSE=1.990232 | val MSE=1.394223 | best=1.394223 | smooth=0.001190
... | lr=1.00e-03
[KANFast] epoch 005 | train MSE=0.470598 | val MSE=0.424796 | best=0.424796 | smooth=0.001194
... | lr=1.00e-03
[KANFast] epoch 010 | train MSE=0.212119 | val MSE=0.191644 | best=0.191644 | smooth=0.001219
... | lr=1.00e-03
[KANFast] epoch 015 | train MSE=0.106663 | val MSE=0.095650 | best=0.095650 | smooth=0.001256
... | lr=1.00e-03
[KANFast] epoch 020 | train MSE=0.052422 | val MSE=0.047749 | best=0.047749 | smooth=0.001281
... | lr=1.00e-03
[KANFast] epoch 025 | train MSE=0.029109 | val MSE=0.027610 | best=0.027610 | smooth=0.001282
... | lr=1.00e-03
[KANFast] epoch 030 | train MSE=0.019510 | val MSE=0.018973 | best=0.018973 | smooth=0.001277
... | lr=1.00e-03
[KANFast] epoch 035 | train MSE=0.014049 | val MSE=0.013891 | best=0.013891 | smooth=0.001278
... | lr=1.00e-03
[KANFast] epoch 040 | train MSE=0.010790 | val MSE=0.010881 | best=0.010881 | smooth=0

In [79]:
kan_fast = train_kan_fast(
    batch_size=8192,
    use_amp=False,
    lr=1e-3,
    lambda_smooth=0.0,
    hidden_dim=64,
    epochs=120
)


... | lr=1.00e-03
[KANFast] epoch 001 | train MSE=1.990242 | val MSE=1.394222 | best=1.394222 | smooth=0.001190
... | lr=1.00e-03
[KANFast] epoch 005 | train MSE=0.470602 | val MSE=0.424794 | best=0.424794 | smooth=0.001194
... | lr=1.00e-03
[KANFast] epoch 010 | train MSE=0.212114 | val MSE=0.191641 | best=0.191641 | smooth=0.001220
... | lr=1.00e-03
[KANFast] epoch 015 | train MSE=0.106657 | val MSE=0.095646 | best=0.095646 | smooth=0.001257
... | lr=1.00e-03
[KANFast] epoch 020 | train MSE=0.052417 | val MSE=0.047745 | best=0.047745 | smooth=0.001283
... | lr=1.00e-03
[KANFast] epoch 025 | train MSE=0.029106 | val MSE=0.027608 | best=0.027608 | smooth=0.001284
... | lr=1.00e-03
[KANFast] epoch 030 | train MSE=0.019507 | val MSE=0.018971 | best=0.018971 | smooth=0.001280
... | lr=1.00e-03
[KANFast] epoch 035 | train MSE=0.014047 | val MSE=0.013890 | best=0.013890 | smooth=0.001282
... | lr=1.00e-03
[KANFast] epoch 040 | train MSE=0.010788 | val MSE=0.010880 | best=0.010880 | smooth=0

In [None]:
kan = KAN(...)
kan.load_state_dict(torch.load("kan_model.pt"))

mlp = FlexibleMLP(...)
mlp.load_state_dict(torch.load("mlp_model.pt"))
