<a href="https://colab.research.google.com/github/burner129/mbfkan/blob/main/mbfkan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title KAN install and import for google colab
# Purpose: make the `kan` package (pykan) importable in this runtime, then run a
# tiny B-spline smoke test to confirm that the installed build is usable.


# NOTE(review): `textwrap` is imported but not used anywhere in this cell.
import sys, os, subprocess, textwrap

# Remove any previously installed `kan` / `pykan` distributions first; stdout is
# captured into a pipe (and discarded via `_`) so the cell output stays quiet.
_ = subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "kan", "pykan"], stdout=subprocess.PIPE)


print("Installing pykan from GitHub...")
# check_call raises CalledProcessError if pip exits with a non-zero status.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "git+https://github.com/KindXiaoming/pykan.git"])

try:
    import kan
    print("Imported 'kan' from pip:", getattr(kan, "__file__", "<no file>"))
except Exception as e:
    # Fallback path: the pip-installed package could not be imported, so clone
    # the repository and import it straight from the checkout instead.
    print("Pip install import failed:", e)
    print("Falling back to source checkout + sys.path import...")

    # Clone target; the clone is skipped when the directory already exists.
    repo_dir = "/content/pykan"
    if not os.path.isdir(repo_dir):
        subprocess.check_call(["git", "clone", "https://github.com/KindXiaoming/pykan.git", repo_dir])

    # Appended (not prepended), so any other `kan` found earlier on sys.path
    # would still take precedence over the checkout.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

    import kan
    print("Imported 'kan' from source path:", getattr(kan, "__file__", "<no file>"))

# Smoke test: evaluate pykan's B-spline basis on random 2-feature input with an
# 11-knot grid per feature and report the resulting shape.
from kan.spline import B_batch
import torch
x = torch.rand(10, 2)
grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
y = B_batch(x, grid, k=3)
print("B_batch OK, shape:", tuple(y.shape))


Installing pykan from GitHub...
Imported 'kan' from pip: /usr/local/lib/python3.12/dist-packages/kan/__init__.py
B_batch OK, shape: (10, 2, 7)


In [None]:
!pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Downloading torchviz-0.0.3-py3-none-any.whl (5.7 kB)
Installing collected packages: torchviz
Successfully installed torchviz-0.0.3


In [None]:
#@title Import statements
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from typing import *
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
import kan
from kan.spline import *
from kan.utils import sparse_mask
from kan.Symbolic_KANLayer import Symbolic_KANLayer
import random
import torch.nn.functional as F



B-spline KAN layer implementation taken from https://github.com/Blealtan/efficient-kan/tree/master


In [None]:
#@title B-spline KAN layer
class SPLKANLinear(torch.nn.Module):
    """KAN linear layer whose edge functions are B-splines plus a base activation.

    The layer output is ``base_activation(x) @ base_weight.T`` plus a learned
    linear combination of B-spline basis functions evaluated on every input
    feature.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order ``k`` of the B-splines; the grid is extended by
            ``k`` knots on each side.
        scale_noise: Magnitude of the random curve used to initialise the
            spline coefficients.
        scale_base: Scales the ``a`` argument of the Kaiming initialisation of
            ``base_weight``.
        scale_spline: Scale of the spline branch (applied to the coefficients,
            or to the initialisation of ``spline_scaler`` when the standalone
            scaler is enabled).
        enable_standalone_scale_spline: When true, a learnable per-edge
            ``spline_scaler`` multiplies the spline coefficients.
        base_activation: Module class instantiated (without arguments) for the
            base branch.
        grid_eps: Blend factor between a uniform grid (``grid_eps``) and a
            data-adaptive grid (``1 - grid_eps``) in :meth:`update_grid`.
        grid_range: Two-element sequence ``(low, high)`` covered by the
            initial grid. The default is an immutable tuple so it cannot be
            mutated and shared between instances.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, extended by `spline_order` knots on both sides
        # and replicated for every input feature: (in_features, grid_size + 2k + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # torch.empty states the intent (uninitialised storage); the values are
        # filled in by reset_parameters() below.
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the base weights, spline coefficients and optional scaler.

        The spline coefficients are obtained by least-squares fitting a small
        random curve sampled at the interior grid points.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Random target values at the grid_size + 1 interior knots.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step combines neighbouring
        # lower-order bases and shortens the basis axis by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # One least-squares problem per input feature (batched over dim 0).
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients multiplied by ``spline_scaler`` when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x`` of shape ``(..., in_features)``.

        Returns:
            torch.Tensor: Tensor of shape ``(..., out_features)``.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        # Flatten all leading dimensions into a single batch dimension.
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # The spline branch is a single matmul over the flattened
        # (in_features * n_coeff) axis.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Refit the knot grid to the distribution of ``x`` and re-solve the coefficients.

        The spline outputs on ``x`` are recorded with the current grid, the
        grid is replaced by a blend of a uniform and a quantile-based grid,
        and the coefficients are refitted by least squares to those outputs.

        Args:
            x (torch.Tensor): Samples of shape (batch_size, in_features).
            margin (float): Padding added on both sides of the observed data
                range when building the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # grid_size + 1 evenly spaced order statistics per channel.
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` uniformly spaced knots on each side.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # NOTE(review): an edge whose coefficients are all exactly zero gives
        # p == 0 and therefore 0 * log(0) = NaN in the entropy term.
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


In [None]:
#@title SPLKAN class
class SPLKAN(torch.nn.Module):
    """Stack of ``SPLKANLinear`` layers forming a B-spline KAN.

    Args:
        layers_hidden: Sequence of layer widths; consecutive pairs define the
            ``(in_features, out_features)`` of each layer.
        grid_size, spline_order, scale_noise, scale_base, scale_spline,
        base_activation, grid_eps, grid_range: Forwarded unchanged to every
            ``SPLKANLinear``. ``grid_range`` defaults to an immutable tuple so
            the default cannot be mutated and shared between models.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        # zip stops at the shorter sequence, so this walks consecutive pairs.
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                SPLKANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run ``x`` through every layer in order.

        Args:
            x: Input tensor fed to the first layer.
            update_grid: When true, each layer refits its grid on its own
                input (``layer.update_grid(x)``) before being applied.
        """
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return the sum of the per-layer regularization losses (0 with no layers)."""
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

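SineKAN layer, adapted from the SineKAN reference implementation (presumably https://github.com/ereinha/SineKAN — the original note omitted the link; please confirm).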

In [None]:
#@title SineKANLayer
def forward_step(i_n, grid_size, A, K, C):
    """Scale ``i_n`` by the factor ``A * grid_size ** (-K) + C`` and return the result."""
    return (A * grid_size ** (-K) + C) * i_n

class SineKANLayer(torch.nn.Module):
    """KAN layer whose edge functions are sums of phase-shifted sinusoids.

    For an input ``x`` of shape ``(..., input_dim)`` the layer computes
    ``sum_{k, l} amplitudes[j, k, l] * sin(x[k] * freq[l] + phase[k, l])``
    for every output ``j`` and optionally adds a learnable bias.

    Args:
        input_dim: Size of the last input dimension.
        output_dim: Size of the last output dimension.
        device: Device on which the parameters and the ``phase`` buffer are
            created. If a CUDA device is requested but CUDA is unavailable, the
            layer warns and falls back to ``'cpu'`` instead of raising, so a
            model built with the default can still be moved with ``.to(...)``.
        grid_size: Number of sinusoids (frequencies) per edge.
        is_first: Selects the amplitude initialisation (normal instead of
            uniform) and disables the frequency normalisation exponent.
        add_bias: Whether to add a learnable bias of shape ``(1, output_dim)``.
        norm_freq: Whether the initial frequencies are divided by
            ``(grid_size + 1) ** (1 - is_first)``.
    """

    def __init__(self, input_dim, output_dim, device='cuda', grid_size=5, is_first=False, add_bias=True, norm_freq=True):
        super().__init__()

        # The default is 'cuda'; honour it when possible, but do not crash on
        # CPU-only hosts (the training code in this file selects CPU there).
        if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
            import warnings

            warnings.warn(
                "SineKANLayer: CUDA requested but not available; creating the layer on CPU.",
                RuntimeWarning,
            )
            device = 'cpu'

        self.grid_size = grid_size
        self.device = device
        self.is_first = is_first
        self.add_bias = add_bias
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Constants of the per-step phase rescaling applied by forward_step.
        self.A, self.K, self.C = 0.9724108095811765, 0.9884401790754128, 0.999449553483052

        # 1, 2, ..., grid_size: the l-th sinusoid's amplitude is divided by l.
        self.grid_norm_factor = (torch.arange(grid_size) + 1)
        self.grid_norm_factor = self.grid_norm_factor.reshape(1, 1, grid_size)

        # Amplitudes are sampled on CPU (so the random stream is unchanged) and
        # then placed on the same device as the phase buffer.
        if is_first:
            amplitudes = torch.empty(output_dim, input_dim, 1).normal_(0, .4) / output_dim / self.grid_norm_factor
        else:
            amplitudes = torch.empty(output_dim, input_dim, 1).uniform_(-1, 1) / output_dim / self.grid_norm_factor
        self.amplitudes = torch.nn.Parameter(amplitudes.to(device))

        grid_phase = torch.arange(1, grid_size + 1).reshape(1, 1, 1, grid_size) / (grid_size + 1)
        self.input_phase = torch.linspace(0, math.pi, input_dim).reshape(1, 1, input_dim, 1).to(device)
        phase = grid_phase.to(device) + self.input_phase

        freq = torch.arange(1, grid_size + 1).float().reshape(1, 1, 1, grid_size)
        if norm_freq:
            # The exponent is 0 for the first layer (no normalisation), 1 otherwise.
            freq = freq / (grid_size + 1)**(1 - is_first)
        self.freq = torch.nn.Parameter(freq.to(device))

        # Rescale the phases once per grid step 1 .. grid_size - 1.
        for i in range(1, self.grid_size):
            phase = forward_step(phase, i, self.A, self.K, self.C)
        # Fixed (non-trainable) phases; a buffer follows the module in .to().
        self.register_buffer('phase', phase)

        if self.add_bias:
            self.bias = torch.nn.Parameter((torch.ones(1, output_dim) / output_dim).to(device))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(..., input_dim)``.

        Returns:
            torch.Tensor: Tensor of shape ``(..., output_dim)``.
        """
        x_shape = x.shape
        output_shape = x_shape[0:-1] + (self.output_dim,)
        # Flatten leading dims, then add singleton output and grid axes so the
        # input broadcasts against freq and phase: (batch, 1, input_dim, 1).
        x = torch.reshape(x, (-1, self.input_dim))
        x_reshaped = torch.reshape(x, (x.shape[0], 1, x.shape[1], 1))
        s = torch.sin(x_reshaped * self.freq + self.phase)
        # Contract over the input (k) and grid (l) axes; the singleton j axis
        # of `s` broadcasts against the output axis of the amplitudes.
        y = torch.einsum('ijkl,jkl->ij', s, self.amplitudes)
        if self.add_bias:
            y = y + self.bias
        y = torch.reshape(y, output_shape)
        return y

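WavKAN layer, adapted from the Wav-KAN reference implementation (presumably https://github.com/zavareh1/Wav-KAN — the original note omitted the link; please confirm).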

In [None]:
#@title WavKANLayer
class WavKANLinear(nn.Module):
    """KAN layer whose edge functions are scaled and translated mother wavelets.

    Every edge ``(out, in)`` has its own learnable ``scale``, ``translation``
    and ``wavelet_weights`` entry; the output is the weighted sum of the
    wavelet responses over the input features.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        wavelet_type: One of ``'mexican_hat'``, ``'morlet'``, ``'dog'``,
            ``'meyer'`` or ``'shannon'``. An unsupported value raises
            ``ValueError`` when the layer is evaluated.
    """

    def __init__(self, in_features, out_features, wavelet_type='dog'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.wavelet_type = wavelet_type

        self.scale = nn.Parameter(torch.ones(out_features, in_features))
        self.translation = nn.Parameter(torch.zeros(out_features, in_features))

        # torch.empty states the intent (uninitialised storage); both tensors
        # are filled by the Kaiming initialisation right below.
        self.weight1 = nn.Parameter(torch.empty(out_features, in_features))
        self.wavelet_weights = nn.Parameter(torch.empty(out_features, in_features))

        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))

        # `weight1` and `base_activation` belong to a base (linear) branch that
        # is disabled in this experiment; they are kept so parameter counts and
        # state dicts stay unchanged.
        self.base_activation = nn.SiLU()

    def _mother_wavelet(self, x_scaled):
        """Evaluate the configured mother wavelet element-wise on ``x_scaled``."""
        kind = self.wavelet_type
        if kind == 'mexican_hat':
            term1 = ((x_scaled ** 2) - 1)
            term2 = torch.exp(-0.5 * x_scaled ** 2)
            return (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2
        if kind == 'morlet':
            omega0 = 5.0  # Central frequency
            real = torch.cos(omega0 * x_scaled)
            envelope = torch.exp(-0.5 * x_scaled ** 2)
            return envelope * real
        if kind == 'dog':
            # Derivative of Gaussian.
            return -x_scaled * torch.exp(-0.5 * x_scaled ** 2)
        if kind == 'meyer':
            v = torch.abs(x_scaled)
            pi = math.pi

            def nu(t):
                return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)

            def meyer_aux(v):
                # 1 on [0, 1/2], 0 on [1, inf), smooth cosine roll-off between.
                return torch.where(v <= 1/2, torch.ones_like(v), torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))))

            return torch.sin(pi * v) * meyer_aux(v)
        if kind == 'shannon':
            pi = math.pi
            sinc = torch.sinc(x_scaled / pi)  # sinc(x) = sin(pi*x) / (pi*x)
            # Hamming window across the input-feature axis (last dimension).
            window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)
            return sinc * window
        raise ValueError("Unsupported wavelet type")

    def wavelet_transform(self, x):
        """Return the weighted wavelet response of ``x``.

        Args:
            x: Tensor of shape ``(batch, in_features)``; a tensor with a
                different number of dimensions is used as-is and must
                broadcast against ``(batch, out_features, in_features)``.

        Returns:
            torch.Tensor: Tensor of shape ``(batch, out_features)`` for 2-D input.

        Raises:
            ValueError: If ``wavelet_type`` is not supported.
        """
        if x.dim() == 2:
            x_expanded = x.unsqueeze(1)  # (batch, 1, in_features)
        else:
            x_expanded = x

        translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)
        scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)
        x_scaled = (x_expanded - translation_expanded) / scale_expanded

        wavelet = self._mother_wavelet(x_scaled)
        # Weight every edge response and sum over the input features.
        wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
        return wavelet_weighted.sum(dim=2)

    def forward(self, x):
        """Return the wavelet branch output for ``x``.

        The base (linear) branch is intentionally not part of the output; it
        used to be computed and immediately discarded, which only cost time.
        """
        return self.wavelet_transform(x)


In [None]:
#@title SineKAN model class
class SineKAN(torch.nn.Module):
    """Feed-forward stack of ``SineKANLayer`` modules.

    Consecutive entries of ``layers_hidden`` give the input and output width of
    each layer; only the first layer is built with ``is_first=True``.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        grid_size: int = 8,
        device: str = 'cuda',
    ) -> None:
        super().__init__()

        stack = []
        widths = zip(layers_hidden[:-1], layers_hidden[1:])
        for position, (n_in, n_out) in enumerate(widths):
            stack.append(
                SineKANLayer(
                    n_in,
                    n_out,
                    device,
                    grid_size=grid_size,
                    is_first=(position == 0),
                )
            )
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through the layers in order and return the final output."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

In [None]:
#@title WavKAN model class
class WavKAN(nn.Module):
    """Plain stack of ``WavKANLinear`` layers (no normalisation in between).

    Consecutive entries of ``layers_hidden`` give the input and output width of
    each layer; ``wavelet_type`` is passed to every layer.
    """

    def __init__(self, layers_hidden, wavelet_type='dog'):
        super().__init__()
        self.layers = nn.ModuleList()
        fan_ins, fan_outs = layers_hidden[:-1], layers_hidden[1:]
        for n_in, n_out in zip(fan_ins, fan_outs):
            wav_layer = WavKANLinear(n_in, n_out, wavelet_type)
            self.layers.append(wav_layer)

    def forward(self, x):
        """Feed ``x`` through the layers in order and return the final output."""
        out = x
        for wav_layer in self.layers:
            out = wav_layer(out)
        return out

In [None]:

#@title WavKAN model class with LayerNorm added
class WavKAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, each followed by a ``LayerNorm``.

    Consecutive entries of ``layers_hidden`` give the input and output width of
    each wavelet layer; the matching norm operates on the layer's output width.
    """

    def __init__(self, layers_hidden, wavelet_type='dog'):
        super().__init__()
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        fan_ins, fan_outs = layers_hidden[:-1], layers_hidden[1:]
        for n_in, n_out in zip(fan_ins, fan_outs):
            wav_layer = WavKANLinear(n_in, n_out, wavelet_type)
            self.layers.append(wav_layer)
            # Normalise the activations produced by the wavelet layer above.
            self.norms.append(nn.LayerNorm(n_out))

    def forward(self, x):
        """Apply every (wavelet layer, LayerNorm) pair in order."""
        out = x
        for wav_layer, layer_norm in zip(self.layers, self.norms):
            out = layer_norm(wav_layer(out))
        return out


In [None]:
#@title ChebyKanLayer

class ChebyKANLayer(nn.Module):
    """KAN layer whose edge functions are Chebyshev expansions up to ``degree``.

    The input is squashed with ``tanh`` and the basis values are evaluated
    through the identity ``T_d(t) = cos(d * acos(t))``.
    """

    def __init__(self, input_dim, output_dim, degree):
        super().__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        n_terms = degree + 1
        # One coefficient per (input, output, polynomial order).
        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, n_terms))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * n_terms))
        # Polynomial orders 0..degree; a buffer so it follows the module's device.
        self.register_buffer("arange", torch.arange(0, n_terms, 1))

    def forward(self, x):
        """Map ``x`` (reshaped to ``(-1, input_dim)``) to shape ``(-1, output_dim)``."""
        squashed = torch.tanh(x).view((-1, self.inputdim, 1))
        # acos is element-wise, so it is applied before broadcasting over the
        # polynomial orders instead of after expanding.
        angles = squashed.acos() * self.arange
        basis = angles.cos()

        mixed = torch.einsum("bid,iod->bo", basis, self.cheby_coeffs)
        return mixed.view(-1, self.outdim)

In [None]:
#@title ChebyKAN model class
class MNISTChebyKAN(nn.Module):
    """ChebyKAN classifier for flattened 28x28 inputs."""

    def __init__(self, dims, degree=4):
        """
        dims: list of layer sizes, e.g. [784, 32, 32, 32, 10]
        degree: polynomial degree for ChebyKAN layers
        """
        super().__init__()
        final_index = len(dims) - 2
        stack = []
        for position, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
            stack.append(ChebyKANLayer(n_in, n_out, degree))
            # Every layer except the last one is followed by a LayerNorm.
            if position != final_index:
                stack.append(nn.LayerNorm(n_out))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, 784)`` and run it through the network."""
        flat = x.view(-1, 28*28)
        return self.net(flat)




MNISTChebyKAN(
  (net): Sequential(
    (0): ChebyKANLayer()
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ChebyKANLayer()
    (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (4): ChebyKANLayer()
    (5): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (6): ChebyKANLayer()
  )
)


In [None]:
#@title HybridKAN
class HybridKAN(nn.Module):
    """Sequential network that mixes different KAN layer types and plain linear layers."""

    def __init__(self, input_dim, layer_specs, cheby_degree=4, wavelet_type="morlet"):
        """
        input_dim: int, number of input features (e.g. 28*28 for MNIST)
        layer_specs: list of (width, type) tuples, e.g. [(32, 'wav'), (32, 'cheby'), (10, 'linear')]
        cheby_degree: degree of Chebyshev polynomials (default=4)
        wavelet_type: which wavelet to use for WavKANLinear
        """
        super().__init__()

        # One constructor per supported layer type, keyed by its spec name.
        factories = {
            "cheby": lambda n_in, n_out: ChebyKANLayer(n_in, n_out, degree=cheby_degree),
            "sine": lambda n_in, n_out: SineKANLayer(n_in, n_out, grid_size=100),
            "wav": lambda n_in, n_out: WavKANLinear(n_in, n_out, wavelet_type=wavelet_type),
            "linear": lambda n_in, n_out: nn.Linear(n_in, n_out),
        }
        supported = ("cheby", "sine", "wav", "linear")

        stack = []
        fan_in = input_dim
        for width, ltype in layer_specs:
            if ltype not in supported:
                raise ValueError(f"Unknown layer type: {ltype}")
            stack.append(factories[ltype](fan_in, width))

            # Every KAN-style layer is followed by a LayerNorm; linear is not.
            if ltype != "linear":
                stack.append(nn.LayerNorm(width))

            fan_in = width

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten everything but the batch dimension and apply the stack."""
        flat = x.view(x.size(0), -1)
        return self.model(flat)


In [None]:
#@title Vertical MBFKANLayer
class MBFKANLayer(nn.Module):
    """Layer in which every edge ``(out, in)`` uses one basis function from a fixed cycle.

    Edge ``(j, i)`` is assigned ``basis_cycle[(i + j) % len(basis_cycle)]``. The
    output is ``out[b, j] = sum_i weight[j, i] * basis_{j,i}(x[b, i])``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        basis_cycle: Non-empty sequence of basis names, each one of
            ``"cheby"``, ``"sine"`` or ``"wav"``.
        degree: Highest Chebyshev order included in the ``"cheby"`` basis.

    Raises:
        ValueError: If ``basis_cycle`` is empty.
        KeyError: If ``basis_cycle`` contains an unknown basis name.
    """

    # Integer code stored in the `edge_types` buffer for each basis name.
    _BASIS_CODES = {"cheby": 0, "sine": 1, "wav": 2}

    def __init__(self, in_dim, out_dim, basis_cycle=("cheby", "sine", "wav"), degree=4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.degree = degree

        funcs = list(basis_cycle)
        if not funcs:
            raise ValueError("basis_cycle must contain at least one basis name")
        codes = torch.tensor([self._BASIS_CODES[name] for name in funcs])

        # Position in the cycle for edge (j, i) is (i + j) mod len(cycle); this
        # is built with broadcasting instead of an out_dim * in_dim Python loop.
        cycle_pos = (
            torch.arange(out_dim).unsqueeze(1) + torch.arange(in_dim).unsqueeze(0)
        ) % len(funcs)
        self.register_buffer("edge_types", codes[cycle_pos])  # (out_dim, in_dim)

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))

    def forward(self, x):
        """
        x: [batch, in_dim]
        returns: [batch, out_dim]
        """
        # Each basis depends only on the input value, not on the output unit,
        # so it is evaluated once on (batch, in_dim) and broadcast over out_dim
        # rather than on an out_dim-times expanded copy of the input.
        if self.degree > 1:
            cheby_terms = [torch.ones_like(x), x]  # T0, T1
            for _ in range(2, self.degree + 1):
                # Chebyshev recurrence: T_k = 2 x T_{k-1} - T_{k-2}.
                cheby_terms.append(2 * x * cheby_terms[-1] - cheby_terms[-2])
            # The basis is the plain sum T0 + T1 + ... + T_degree.
            cheby_basis = torch.stack(cheby_terms, dim=-1).sum(dim=-1)
        else:
            # NOTE(review): for degree <= 1 the basis is x alone (no T0 term),
            # unlike the degree > 1 branch; kept as in the original.
            cheby_basis = x

        sine_basis = torch.sin(x)
        wav_basis = torch.sin(5 * x) * torch.exp(-x**2)

        mask_cheby = (self.edge_types == 0).float()
        mask_sine = (self.edge_types == 1).float()
        mask_wav = (self.edge_types == 2).float()

        # (batch, 1, in_dim) * (1, out_dim, in_dim) -> (batch, out_dim, in_dim)
        out = (
            cheby_basis[:, None, :] * mask_cheby[None, :, :] +
            sine_basis[:, None, :] * mask_sine[None, :, :] +
            wav_basis[:, None, :] * mask_wav[None, :, :]
        )

        out = out * self.weight[None, :, :]
        return out.sum(dim=-1)



In [None]:
#@title Vertical MBFKAN class
class MBFKAN(nn.Module):
    """Stack of ``MBFKANLayer`` modules, each followed by a ``LayerNorm``."""

    def __init__(self, dims, degree=4):
        """
        dims: list like [784, 32, 32, 10]
        """
        super().__init__()
        stack = []
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            # Mixed-basis layer followed by a normalisation of its output.
            stack.extend((MBFKANLayer(n_in, n_out, degree=degree), nn.LayerNorm(n_out)))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole stack."""
        out = self.layers(x)
        return out


In [None]:
# Notebook-level registry of trained models; train_HybridKAN appends the model
# it has just trained to this list.
models = []

In [None]:
#@title Train HybridKAN function (train_HybridKAN)
def train_HybridKAN(dimensions, lr, epochs=30):
    """Train a ``HybridKAN`` classifier on MNIST and store it in the global ``models`` list.

    MNIST is downloaded to ``./data`` if necessary. The model is optimised with
    Adam and an exponentially decaying learning rate (gamma 0.9 per epoch), and
    train/validation loss and accuracy are printed after every epoch.

    Args:
        dimensions: ``layer_specs`` passed to ``HybridKAN`` — a list of
            ``(width, type)`` tuples.
        lr: Initial learning rate for Adam.
        epochs: Number of training epochs (default 30, the previously
            hard-coded value).

    Returns:
        None. The trained model is appended to ``models``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = HybridKAN(input_dim=28*28, layer_specs=dimensions).to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    def run_epoch(loader, training):
        """Run one pass over ``loader``; return (mean loss, accuracy) per sample."""
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0
        batches = tqdm(loader, leave=False) if training else loader
        with torch.set_grad_enabled(training):
            for x, y in batches:
                x, y = x.view(-1, 28*28).to(device), y.to(device)
                if training:
                    optimizer.zero_grad()
                out = model(x)
                loss = criterion(out, y)
                if training:
                    loss.backward()
                    optimizer.step()

                # Weight by batch size: the last batch is usually smaller, so
                # averaging per-batch means would bias the reported metrics.
                batch_size = y.size(0)
                loss_sum += loss.item() * batch_size
                correct += (out.argmax(dim=1) == y).sum().item()
                seen += batch_size
        seen = max(seen, 1)  # avoid dividing by zero on an empty loader
        return loss_sum / seen, correct / seen

    for epoch in range(epochs):
        train_loss, train_acc = run_epoch(train_loader, training=True)
        val_loss, val_acc = run_epoch(test_loader, training=False)
        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.4f} | Val Loss {val_loss:.4f}, Acc {val_acc:.4f}")

    print("✅ Finished training HybridKAN on MNIST")
    models.append(model)

In [None]:
#@title train Vertical MBFKAN function (train_MBFKAN)
def train_MBFKAN(dimensions, lr, epochs=150):
    """Train an ``MBFKAN`` classifier on MNIST and print per-epoch metrics.

    MNIST is downloaded to ``./data`` if necessary. The model is optimised with
    Adam and an exponentially decaying learning rate (gamma 0.9 per epoch).

    Args:
        dimensions: ``dims`` passed to ``MBFKAN`` — a list of layer widths,
            starting with 784 for flattened MNIST images.
        lr: Initial learning rate for Adam.
        epochs: Number of training epochs (default 150, the previously
            hard-coded value).

    Returns:
        None. NOTE(review): the trained model is neither returned nor stored,
        so it is discarded when the function exits.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MBFKAN(dims=dimensions).to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    def run_epoch(loader, training):
        """Run one pass over ``loader``; return (mean loss, accuracy) per sample."""
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0
        batches = tqdm(loader, leave=False) if training else loader
        with torch.set_grad_enabled(training):
            for x, y in batches:
                x, y = x.view(-1, 28*28).to(device), y.to(device)
                if training:
                    optimizer.zero_grad()
                out = model(x)
                loss = criterion(out, y)
                if training:
                    loss.backward()
                    optimizer.step()

                # Weight by batch size: the last batch is usually smaller, so
                # averaging per-batch means would bias the reported metrics.
                batch_size = y.size(0)
                loss_sum += loss.item() * batch_size
                correct += (out.argmax(dim=1) == y).sum().item()
                seen += batch_size
        seen = max(seen, 1)  # avoid dividing by zero on an empty loader
        return loss_sum / seen, correct / seen

    for epoch in range(epochs):
        train_loss, train_acc = run_epoch(train_loader, training=True)
        val_loss, val_acc = run_epoch(test_loader, training=False)
        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.4f} | Val Loss {val_loss:.4f}, Acc {val_acc:.4f}")

    print("✅ Finished training MBFKAN on MNIST")

In [None]:
#@title train_ChebyKAN
def train_ChebyKAN(dimensions, lr, epochs=30):
    """Train an ``MNISTChebyKAN`` classifier on MNIST and print per-epoch metrics.

    MNIST is read from ``./data`` (downloaded when missing), normalized with
    ``Normalize((0.1307,), (0.3081,))`` and every image is flattened to a
    784-element vector before it is fed to the model. Optimization uses Adam
    with an exponential learning-rate decay (gamma 0.9) applied once per epoch.

    Args:
        dimensions: Layer specification forwarded unchanged to
            ``MNISTChebyKAN``.
        lr: Initial learning rate for Adam.
        epochs: Number of passes over the training set. Defaults to 30, the
            value that was previously hard-coded.

    Returns:
        None. Progress is reported through ``print``.

    Note:
        Reported loss and accuracy are per-sample averages over the whole
        split. Averaging the per-batch means instead would over-weight the
        final, smaller batch (60000 is not a multiple of 64, and 10000 is not
        a multiple of 1024).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MNISTChebyKAN(dimensions).to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    def train_epoch(model, loader, optimizer, criterion, device):
        """Run one optimization pass; return (mean loss, accuracy) per sample."""
        model.train()

        loss_sum, correct, seen = 0.0, 0, 0
        for x, y in tqdm(loader, leave=False):
            x, y = x.view(-1, 28*28).to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so scale by the batch size to
            # accumulate a per-sample total.
            batch = y.size(0)
            loss_sum += loss.item() * batch
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    def test_epoch(model, loader, criterion, device):
        """Evaluate without gradients; return (mean loss, accuracy) per sample."""
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.view(-1, 28*28).to(device), y.to(device)
                out = model(x)
                batch = y.size(0)
                loss_sum += criterion(out, y).item() * batch
                correct += (out.argmax(dim=1) == y).sum().item()
                seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = test_epoch(model, test_loader, criterion, device)
        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.4f} | Val Loss {val_loss:.4f}, Acc {val_acc:.4f}")

    print("✅ Finished training ChebyKAN on MNIST")

Train and Benchmark SineKAN

In [None]:
#@title train_SineKAN
def train_SineKAN(dimensions, lr, epochs=30):
    """Train a ``SineKAN`` classifier on MNIST and print per-epoch metrics.

    MNIST is read from ``./data`` (downloaded when missing), normalized with
    ``Normalize((0.5,), (0.5,))`` and every image is flattened to a
    784-element vector before it is fed to the model. Optimization uses AdamW
    with an exponential learning-rate decay (gamma 0.9) applied once per epoch.

    Args:
        dimensions: Layer specification forwarded unchanged to ``SineKAN``
            (which is built with ``grid_size=8``).
        lr: Initial learning rate for AdamW.
        epochs: Number of passes over the training set. Defaults to 30, the
            value that was previously hard-coded.

    Returns:
        None. Progress is reported through ``print``.

    Note:
        Reported loss and accuracy are per-sample averages over the whole
        split. Averaging the per-batch means instead would over-weight the
        final, smaller batch (neither 60000 nor 10000 is a multiple of 128).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SineKAN(dimensions, grid_size=8).to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(total_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    def train_epoch(model, loader, optimizer, criterion, device):
        """Run one optimization pass; return (mean loss, accuracy) per sample."""
        model.train()

        loss_sum, correct, seen = 0.0, 0, 0
        for x, y in tqdm(loader, leave=False):
            x, y = x.view(-1, 28*28).to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so scale by the batch size to
            # accumulate a per-sample total.
            batch = y.size(0)
            loss_sum += loss.item() * batch
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    def test_epoch(model, loader, criterion, device):
        """Evaluate without gradients; return (mean loss, accuracy) per sample."""
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.view(-1, 28*28).to(device), y.to(device)
                out = model(x)
                batch = y.size(0)
                loss_sum += criterion(out, y).item() * batch
                correct += (out.argmax(dim=1) == y).sum().item()
                seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = test_epoch(model, test_loader, criterion, device)
        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.4f} | Val Loss {val_loss:.4f}, Acc {val_acc:.4f}")

    print("✅ Finished training SineKAN on MNIST")

Train and Benchmark Wav-KAN

In [None]:
#@title train_WavKAN
def train_WavKAN(dimensions, lr, epochs=30):
    """Train a ``WavKAN`` classifier on MNIST and print per-epoch metrics.

    MNIST is read from ``./data`` (downloaded when missing), normalized with
    ``Normalize((0.5,), (0.5,))`` and every image is flattened to a
    784-element vector before it is fed to the model. Optimization uses AdamW
    (``weight_decay=0.01``) with an exponential learning-rate decay
    (gamma 0.9) applied once per epoch.

    Args:
        dimensions: Layer specification forwarded unchanged to ``WavKAN``
            (which is built with ``wavelet_type='dog'``).
        lr: Initial learning rate for AdamW.
        epochs: Number of passes over the training set. Defaults to 30, the
            value that was previously hard-coded.

    Returns:
        None. Progress is reported through ``print``.

    Note:
        Reported loss and accuracy are per-sample averages over the whole
        split. Averaging the per-batch means instead would over-weight the
        final, smaller batch (neither 60000 nor 10000 is a multiple of 128).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = WavKAN(dimensions, wavelet_type='dog').to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(total_params)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    def train_epoch(model, loader, optimizer, criterion, device):
        """Run one optimization pass; return (mean loss, accuracy) per sample."""
        model.train()

        loss_sum, correct, seen = 0.0, 0, 0
        for x, y in tqdm(loader, leave=False):
            x, y = x.view(-1, 28*28).to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so scale by the batch size to
            # accumulate a per-sample total.
            batch = y.size(0)
            loss_sum += loss.item() * batch
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    def test_epoch(model, loader, criterion, device):
        """Evaluate without gradients; return (mean loss, accuracy) per sample."""
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.view(-1, 28*28).to(device), y.to(device)
                out = model(x)
                batch = y.size(0)
                loss_sum += criterion(out, y).item() * batch
                correct += (out.argmax(dim=1) == y).sum().item()
                seen += batch
        seen = max(seen, 1)  # avoid ZeroDivisionError on an empty loader
        return loss_sum / seen, correct / seen

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = test_epoch(model, test_loader, criterion, device)
        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f}, Acc {train_acc:.4f} | Val Loss {val_loss:.4f}, Acc {val_acc:.4f}")

    print("✅ Finished training WavKAN on MNIST")

Train and Benchmark SPLKAN

In [None]:
# Shared hidden-layer width used by the benchmark calls in the cells below.
# NOTE(review): the recorded outputs of the later 3-hidden-layer runs report
# larger parameter counts than the deeper runs, so they appear to have been
# produced with a different value of this variable — confirm before comparing.
hidden_layer_width = 8

In [None]:
# Interleaved hybrid stack: the (wav, sine, cheby) triple repeated three times
# at hidden_layer_width, closed by a (10, "linear") entry; learning rate 4e-4.
train_HybridKAN(
    [(hidden_layer_width, basis) for basis in ("wav", "sine", "cheby") * 3]
    + [(10, "linear")],
    4e-4,
)

Total trainable parameters: 46318




Epoch 1/30 | Train Loss 2.0511, Acc 0.1902 | Val Loss 1.7513, Acc 0.3176




Epoch 2/30 | Train Loss 1.3897, Acc 0.4970 | Val Loss 1.0601, Acc 0.6096




Epoch 3/30 | Train Loss 0.9617, Acc 0.6648 | Val Loss 0.8682, Acc 0.7250




Epoch 4/30 | Train Loss 0.7958, Acc 0.7392 | Val Loss 0.7392, Acc 0.7758




Epoch 5/30 | Train Loss 0.6760, Acc 0.7871 | Val Loss 0.6250, Acc 0.8027




Epoch 6/30 | Train Loss 0.5943, Acc 0.8183 | Val Loss 0.5888, Acc 0.8193




Epoch 7/30 | Train Loss 0.5298, Acc 0.8434 | Val Loss 0.5015, Acc 0.8513




Epoch 8/30 | Train Loss 0.4810, Acc 0.8609 | Val Loss 0.4582, Acc 0.8663




Epoch 9/30 | Train Loss 0.4415, Acc 0.8735 | Val Loss 0.4428, Acc 0.8728




Epoch 10/30 | Train Loss 0.4139, Acc 0.8816 | Val Loss 0.4275, Acc 0.8741




Epoch 11/30 | Train Loss 0.3941, Acc 0.8867 | Val Loss 0.4531, Acc 0.8683




Epoch 12/30 | Train Loss 0.3816, Acc 0.8918 | Val Loss 0.4108, Acc 0.8835




Epoch 13/30 | Train Loss 0.3664, Acc 0.8956 | Val Loss 0.4157, Acc 0.8796




Epoch 14/30 | Train Loss 0.3546, Acc 0.8995 | Val Loss 0.4159, Acc 0.8764




Epoch 15/30 | Train Loss 0.3478, Acc 0.9004 | Val Loss 0.3870, Acc 0.8874




Epoch 16/30 | Train Loss 0.3398, Acc 0.9031 | Val Loss 0.4085, Acc 0.8827




Epoch 17/30 | Train Loss 0.3329, Acc 0.9060 | Val Loss 0.3944, Acc 0.8847




Epoch 18/30 | Train Loss 0.3260, Acc 0.9074 | Val Loss 0.3866, Acc 0.8871




Epoch 19/30 | Train Loss 0.3217, Acc 0.9085 | Val Loss 0.3866, Acc 0.8881




Epoch 20/30 | Train Loss 0.3164, Acc 0.9102 | Val Loss 0.3838, Acc 0.8893




Epoch 21/30 | Train Loss 0.3119, Acc 0.9123 | Val Loss 0.3897, Acc 0.8908




Epoch 22/30 | Train Loss 0.3081, Acc 0.9132 | Val Loss 0.3909, Acc 0.8889




Epoch 23/30 | Train Loss 0.3052, Acc 0.9131 | Val Loss 0.3784, Acc 0.8906




Epoch 24/30 | Train Loss 0.3017, Acc 0.9141 | Val Loss 0.3784, Acc 0.8932




Epoch 25/30 | Train Loss 0.2992, Acc 0.9153 | Val Loss 0.3787, Acc 0.8922




Epoch 26/30 | Train Loss 0.2970, Acc 0.9166 | Val Loss 0.3772, Acc 0.8927




Epoch 27/30 | Train Loss 0.2944, Acc 0.9168 | Val Loss 0.3728, Acc 0.8943




Epoch 28/30 | Train Loss 0.2918, Acc 0.9176 | Val Loss 0.3772, Acc 0.8928




Epoch 29/30 | Train Loss 0.2906, Acc 0.9177 | Val Loss 0.3765, Acc 0.8936




Epoch 30/30 | Train Loss 0.2888, Acc 0.9189 | Val Loss 0.3799, Acc 0.8919
✅ Finished training ChebyKAN on MNIST


In [None]:
# Grouped hybrid stack: three wav entries, then three sine, then three cheby,
# all at hidden_layer_width, closed by a (10, "linear") entry; lr 4e-4.
train_HybridKAN(
    [
        (hidden_layer_width, basis)
        for basis in ("wav", "sine", "cheby")
        for _ in range(3)
    ]
    + [(10, "linear")],
    4e-4,
)

Total trainable parameters: 46318




Epoch 1/30 | Train Loss 2.1899, Acc 0.1600 | Val Loss 1.7451, Acc 0.3505




Epoch 2/30 | Train Loss 1.7254, Acc 0.3721 | Val Loss 1.3665, Acc 0.5386




Epoch 3/30 | Train Loss 1.2376, Acc 0.5765 | Val Loss 1.1927, Acc 0.5990




Epoch 4/30 | Train Loss 1.0651, Acc 0.6408 | Val Loss 0.9547, Acc 0.6710




Epoch 5/30 | Train Loss 0.9326, Acc 0.6918 | Val Loss 0.8570, Acc 0.7085




Epoch 6/30 | Train Loss 0.8180, Acc 0.7343 | Val Loss 0.7585, Acc 0.7582




Epoch 7/30 | Train Loss 0.7596, Acc 0.7562 | Val Loss 0.7547, Acc 0.7557




Epoch 8/30 | Train Loss 0.7087, Acc 0.7722 | Val Loss 0.6704, Acc 0.7861




Epoch 9/30 | Train Loss 0.6732, Acc 0.7832 | Val Loss 0.6474, Acc 0.7914




Epoch 10/30 | Train Loss 0.6458, Acc 0.7963 | Val Loss 0.6487, Acc 0.7955




Epoch 11/30 | Train Loss 0.6229, Acc 0.8058 | Val Loss 0.6279, Acc 0.8081




Epoch 12/30 | Train Loss 0.5991, Acc 0.8149 | Val Loss 0.6330, Acc 0.8012




Epoch 13/30 | Train Loss 0.5868, Acc 0.8205 | Val Loss 0.6158, Acc 0.8124




Epoch 14/30 | Train Loss 0.5690, Acc 0.8270 | Val Loss 0.6256, Acc 0.8135




Epoch 15/30 | Train Loss 0.5549, Acc 0.8320 | Val Loss 0.5616, Acc 0.8318




Epoch 16/30 | Train Loss 0.5405, Acc 0.8381 | Val Loss 0.5462, Acc 0.8367




Epoch 17/30 | Train Loss 0.5312, Acc 0.8403 | Val Loss 0.5427, Acc 0.8363




Epoch 18/30 | Train Loss 0.5206, Acc 0.8448 | Val Loss 0.5329, Acc 0.8410




Epoch 19/30 | Train Loss 0.5076, Acc 0.8499 | Val Loss 0.5663, Acc 0.8307




Epoch 20/30 | Train Loss 0.4999, Acc 0.8507 | Val Loss 0.5146, Acc 0.8481




Epoch 21/30 | Train Loss 0.4901, Acc 0.8552 | Val Loss 0.5355, Acc 0.8408




Epoch 22/30 | Train Loss 0.4828, Acc 0.8573 | Val Loss 0.5134, Acc 0.8473




Epoch 23/30 | Train Loss 0.4782, Acc 0.8586 | Val Loss 0.5087, Acc 0.8487




Epoch 24/30 | Train Loss 0.4714, Acc 0.8610 | Val Loss 0.5028, Acc 0.8540




Epoch 25/30 | Train Loss 0.4669, Acc 0.8626 | Val Loss 0.4988, Acc 0.8534




Epoch 26/30 | Train Loss 0.4616, Acc 0.8651 | Val Loss 0.4989, Acc 0.8549




Epoch 27/30 | Train Loss 0.4588, Acc 0.8656 | Val Loss 0.4986, Acc 0.8550




Epoch 28/30 | Train Loss 0.4551, Acc 0.8657 | Val Loss 0.4929, Acc 0.8588




Epoch 29/30 | Train Loss 0.4508, Acc 0.8682 | Val Loss 0.4926, Acc 0.8563




Epoch 30/30 | Train Loss 0.4487, Acc 0.8681 | Val Loss 0.4967, Acc 0.8557
✅ Finished training ChebyKAN on MNIST


In [None]:
# Grouped hybrid stack in the opposite order: three cheby entries, then three
# sine, then three wav, all at hidden_layer_width, closed by a (10, "linear")
# entry; lr 4e-4.
train_HybridKAN(
    [
        (hidden_layer_width, basis)
        for basis in ("cheby", "sine", "wav")
        for _ in range(3)
    ]
    + [(10, "linear")],
    4e-4,
)

Total trainable parameters: 52526




Epoch 1/30 | Train Loss 2.3334, Acc 0.1035 | Val Loss 2.3039, Acc 0.1136




Epoch 2/30 | Train Loss 2.2629, Acc 0.1287 | Val Loss 1.9167, Acc 0.2908




Epoch 3/30 | Train Loss 2.1894, Acc 0.1651 | Val Loss 2.3014, Acc 0.1136




Epoch 4/30 | Train Loss 2.3022, Acc 0.1106 | Val Loss 2.3010, Acc 0.1136




Epoch 5/30 | Train Loss 1.9691, Acc 0.2640 | Val Loss 1.7092, Acc 0.3870




Epoch 6/30 | Train Loss 1.3598, Acc 0.5040 | Val Loss 1.2635, Acc 0.5030




Epoch 7/30 | Train Loss 1.1787, Acc 0.5602 | Val Loss 1.1281, Acc 0.5702




Epoch 8/30 | Train Loss 1.1391, Acc 0.5976 | Val Loss 1.0663, Acc 0.6753




Epoch 9/30 | Train Loss 0.9242, Acc 0.7165 | Val Loss 0.8329, Acc 0.7635




Epoch 10/30 | Train Loss 0.7329, Acc 0.8005 | Val Loss 0.6451, Acc 0.8277




Epoch 11/30 | Train Loss 0.6528, Acc 0.8205 | Val Loss 0.5916, Acc 0.8413




Epoch 12/30 | Train Loss 0.5992, Acc 0.8342 | Val Loss 0.6012, Acc 0.8329




Epoch 13/30 | Train Loss 0.5631, Acc 0.8457 | Val Loss 0.6052, Acc 0.8227




Epoch 14/30 | Train Loss 0.5318, Acc 0.8547 | Val Loss 0.5820, Acc 0.8415




Epoch 15/30 | Train Loss 0.5038, Acc 0.8634 | Val Loss 0.4970, Acc 0.8688




Epoch 16/30 | Train Loss 0.4880, Acc 0.8676 | Val Loss 0.4993, Acc 0.8709




Epoch 17/30 | Train Loss 0.4740, Acc 0.8708 | Val Loss 0.4688, Acc 0.8748




Epoch 18/30 | Train Loss 0.4493, Acc 0.8795 | Val Loss 0.4763, Acc 0.8737




Epoch 19/30 | Train Loss 0.4433, Acc 0.8813 | Val Loss 0.4687, Acc 0.8735




Epoch 20/30 | Train Loss 0.4397, Acc 0.8810 | Val Loss 0.4402, Acc 0.8813




Epoch 21/30 | Train Loss 0.4254, Acc 0.8853 | Val Loss 0.4603, Acc 0.8729




Epoch 22/30 | Train Loss 0.4364, Acc 0.8819 | Val Loss 0.4323, Acc 0.8827




Epoch 23/30 | Train Loss 0.4128, Acc 0.8896 | Val Loss 0.4212, Acc 0.8875




Epoch 24/30 | Train Loss 0.3991, Acc 0.8930 | Val Loss 0.4286, Acc 0.8867




Epoch 25/30 | Train Loss 0.3960, Acc 0.8936 | Val Loss 0.4320, Acc 0.8833




Epoch 26/30 | Train Loss 0.3901, Acc 0.8952 | Val Loss 0.4258, Acc 0.8862




Epoch 27/30 | Train Loss 0.3990, Acc 0.8921 | Val Loss 0.4415, Acc 0.8841




Epoch 28/30 | Train Loss 0.3934, Acc 0.8941 | Val Loss 0.4136, Acc 0.8892




Epoch 29/30 | Train Loss 0.3993, Acc 0.8920 | Val Loss 0.3991, Acc 0.8939




Epoch 30/30 | Train Loss 0.3799, Acc 0.8981 | Val Loss 0.4150, Acc 0.8886
✅ Finished training ChebyKAN on MNIST


In [None]:
# SineKAN run: dimension list is 784, six entries of hidden_layer_width, 10.
train_SineKAN([784] + [hidden_layer_width] * 6 + [10], 4e-4)

53490




Epoch 1/30 | Train Loss 2.3024, Acc 0.1101 | Val Loss 2.3021, Acc 0.1136




Epoch 2/30 | Train Loss 2.3021, Acc 0.1111 | Val Loss 2.3022, Acc 0.1136




Epoch 3/30 | Train Loss 2.3020, Acc 0.1116 | Val Loss 2.3013, Acc 0.1136




Epoch 4/30 | Train Loss 2.3020, Acc 0.1098 | Val Loss 2.3022, Acc 0.1136




Epoch 5/30 | Train Loss 2.3019, Acc 0.1102 | Val Loss 2.3019, Acc 0.1136




Epoch 6/30 | Train Loss 2.3017, Acc 0.1101 | Val Loss 2.3019, Acc 0.1024




Epoch 7/30 | Train Loss 2.3018, Acc 0.1122 | Val Loss 2.3022, Acc 0.1136




Epoch 8/30 | Train Loss 2.3017, Acc 0.1123 | Val Loss 2.3012, Acc 0.1136




Epoch 9/30 | Train Loss 2.3017, Acc 0.1123 | Val Loss 2.3012, Acc 0.1136




Epoch 10/30 | Train Loss 2.3015, Acc 0.1119 | Val Loss 2.3013, Acc 0.1136




Epoch 11/30 | Train Loss 2.3016, Acc 0.1124 | Val Loss 2.3014, Acc 0.1136




Epoch 12/30 | Train Loss 2.3016, Acc 0.1124 | Val Loss 2.3013, Acc 0.1136




Epoch 13/30 | Train Loss 2.3015, Acc 0.1124 | Val Loss 2.3012, Acc 0.1136




Epoch 14/30 | Train Loss 2.3015, Acc 0.1124 | Val Loss 2.3014, Acc 0.1136




Epoch 15/30 | Train Loss 2.3015, Acc 0.1124 | Val Loss 2.3013, Acc 0.1136




Epoch 16/30 | Train Loss 2.3014, Acc 0.1123 | Val Loss 2.3012, Acc 0.1136




Epoch 17/30 | Train Loss 2.3014, Acc 0.1123 | Val Loss 2.3012, Acc 0.1136




Epoch 18/30 | Train Loss 2.3014, Acc 0.1124 | Val Loss 2.3013, Acc 0.1136




Epoch 19/30 | Train Loss 2.3014, Acc 0.1124 | Val Loss 2.3012, Acc 0.1136




Epoch 20/30 | Train Loss 2.3014, Acc 0.1124 | Val Loss 2.3012, Acc 0.1136




Epoch 21/30 | Train Loss 2.3014, Acc 0.1123 | Val Loss 2.3011, Acc 0.1136




Epoch 22/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 23/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3010, Acc 0.1136




Epoch 24/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 25/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 26/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 27/30 | Train Loss 2.3013, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 28/30 | Train Loss 2.3012, Acc 0.1124 | Val Loss 2.3010, Acc 0.1136




Epoch 29/30 | Train Loss 2.3012, Acc 0.1124 | Val Loss 2.3011, Acc 0.1136




Epoch 30/30 | Train Loss 2.3012, Acc 0.1124 | Val Loss 2.3010, Acc 0.1136
✅ Finished training SineKAN on MNIST


In [None]:
# ChebyKAN run: dimension list is 784, nine entries of hidden_layer_width, 10.
train_ChebyKAN([784] + [hidden_layer_width] * 9 + [10], 4e-4)

Total trainable parameters: 34464




Epoch 1/30 | Train Loss 1.9191, Acc 0.2675 | Val Loss 1.4983, Acc 0.4161




Epoch 2/30 | Train Loss 1.2723, Acc 0.5388 | Val Loss 1.0916, Acc 0.5807




Epoch 3/30 | Train Loss 1.0260, Acc 0.6224 | Val Loss 0.9501, Acc 0.6701




Epoch 4/30 | Train Loss 0.9951, Acc 0.6538 | Val Loss 1.2512, Acc 0.5663




Epoch 5/30 | Train Loss 0.8999, Acc 0.7022 | Val Loss 0.7119, Acc 0.7704




Epoch 6/30 | Train Loss 0.6732, Acc 0.7817 | Val Loss 0.5939, Acc 0.8022




Epoch 7/30 | Train Loss 0.5926, Acc 0.8062 | Val Loss 0.5645, Acc 0.8126




Epoch 8/30 | Train Loss 0.5433, Acc 0.8230 | Val Loss 0.5335, Acc 0.8287




Epoch 9/30 | Train Loss 0.4963, Acc 0.8451 | Val Loss 0.4589, Acc 0.8660




Epoch 10/30 | Train Loss 0.4534, Acc 0.8664 | Val Loss 0.5009, Acc 0.8624




Epoch 11/30 | Train Loss 0.4213, Acc 0.8802 | Val Loss 0.4324, Acc 0.8780




Epoch 12/30 | Train Loss 0.3861, Acc 0.8920 | Val Loss 0.3730, Acc 0.9011




Epoch 13/30 | Train Loss 0.3503, Acc 0.9047 | Val Loss 0.4407, Acc 0.8794




Epoch 14/30 | Train Loss 0.3342, Acc 0.9084 | Val Loss 0.3514, Acc 0.9044




Epoch 15/30 | Train Loss 0.3160, Acc 0.9143 | Val Loss 0.3470, Acc 0.9045




Epoch 16/30 | Train Loss 0.3065, Acc 0.9163 | Val Loss 0.3262, Acc 0.9128




Epoch 17/30 | Train Loss 0.2889, Acc 0.9218 | Val Loss 0.3361, Acc 0.9122




Epoch 18/30 | Train Loss 0.2829, Acc 0.9227 | Val Loss 0.3349, Acc 0.9083




Epoch 19/30 | Train Loss 0.2720, Acc 0.9256 | Val Loss 0.3132, Acc 0.9171




Epoch 20/30 | Train Loss 0.2639, Acc 0.9291 | Val Loss 0.3293, Acc 0.9108




Epoch 21/30 | Train Loss 0.2551, Acc 0.9306 | Val Loss 0.3103, Acc 0.9156




Epoch 22/30 | Train Loss 0.2478, Acc 0.9326 | Val Loss 0.3306, Acc 0.9100




Epoch 23/30 | Train Loss 0.2442, Acc 0.9330 | Val Loss 0.2997, Acc 0.9209




Epoch 24/30 | Train Loss 0.2395, Acc 0.9347 | Val Loss 0.3186, Acc 0.9161




Epoch 25/30 | Train Loss 0.2344, Acc 0.9371 | Val Loss 0.3063, Acc 0.9182




Epoch 26/30 | Train Loss 0.2292, Acc 0.9385 | Val Loss 0.3179, Acc 0.9157




Epoch 27/30 | Train Loss 0.2274, Acc 0.9389 | Val Loss 0.3118, Acc 0.9171




Epoch 28/30 | Train Loss 0.2236, Acc 0.9389 | Val Loss 0.2972, Acc 0.9199




Epoch 29/30 | Train Loss 0.2198, Acc 0.9410 | Val Loss 0.2989, Acc 0.9201




Epoch 30/30 | Train Loss 0.2177, Acc 0.9414 | Val Loss 0.3004, Acc 0.9195
✅ Finished training ChebyKAN on MNIST


In [None]:
# WavKAN run: dimension list is 784, nine entries of hidden_layer_width, 10.
train_WavKAN([784] + [hidden_layer_width] * 9 + [10], 4e-4)

27620




Epoch 1/30 | Train Loss 1.8806, Acc 0.3140 | Val Loss 1.3708, Acc 0.5620




Epoch 2/30 | Train Loss 1.1129, Acc 0.6560 | Val Loss 0.8749, Acc 0.7524




Epoch 3/30 | Train Loss 0.8188, Acc 0.7637 | Val Loss 0.7267, Acc 0.7872




Epoch 4/30 | Train Loss 0.7104, Acc 0.7953 | Val Loss 0.6753, Acc 0.8045




Epoch 5/30 | Train Loss 0.6600, Acc 0.8108 | Val Loss 0.6771, Acc 0.8041




Epoch 6/30 | Train Loss 0.6234, Acc 0.8222 | Val Loss 0.6239, Acc 0.8242




Epoch 7/30 | Train Loss 0.5862, Acc 0.8353 | Val Loss 0.5975, Acc 0.8304




Epoch 8/30 | Train Loss 0.5587, Acc 0.8445 | Val Loss 0.5622, Acc 0.8437




Epoch 9/30 | Train Loss 0.5330, Acc 0.8528 | Val Loss 0.5326, Acc 0.8572




Epoch 10/30 | Train Loss 0.5181, Acc 0.8573 | Val Loss 0.5189, Acc 0.8596




Epoch 11/30 | Train Loss 0.5019, Acc 0.8618 | Val Loss 0.5074, Acc 0.8613




Epoch 12/30 | Train Loss 0.4884, Acc 0.8651 | Val Loss 0.4889, Acc 0.8676




Epoch 13/30 | Train Loss 0.4720, Acc 0.8710 | Val Loss 0.4864, Acc 0.8698




Epoch 14/30 | Train Loss 0.4618, Acc 0.8742 | Val Loss 0.4825, Acc 0.8679




Epoch 15/30 | Train Loss 0.4513, Acc 0.8773 | Val Loss 0.4723, Acc 0.8733




Epoch 16/30 | Train Loss 0.4455, Acc 0.8790 | Val Loss 0.4671, Acc 0.8716




Epoch 17/30 | Train Loss 0.4375, Acc 0.8821 | Val Loss 0.4658, Acc 0.8739




Epoch 18/30 | Train Loss 0.4300, Acc 0.8836 | Val Loss 0.4557, Acc 0.8759




Epoch 19/30 | Train Loss 0.4249, Acc 0.8845 | Val Loss 0.4571, Acc 0.8733




Epoch 20/30 | Train Loss 0.4193, Acc 0.8872 | Val Loss 0.4547, Acc 0.8786




Epoch 21/30 | Train Loss 0.4157, Acc 0.8878 | Val Loss 0.4461, Acc 0.8771




Epoch 22/30 | Train Loss 0.4098, Acc 0.8891 | Val Loss 0.4436, Acc 0.8787




Epoch 23/30 | Train Loss 0.4063, Acc 0.8902 | Val Loss 0.4376, Acc 0.8836




Epoch 24/30 | Train Loss 0.4024, Acc 0.8917 | Val Loss 0.4327, Acc 0.8847




Epoch 25/30 | Train Loss 0.4002, Acc 0.8923 | Val Loss 0.4363, Acc 0.8828




Epoch 26/30 | Train Loss 0.3972, Acc 0.8928 | Val Loss 0.4332, Acc 0.8824




Epoch 27/30 | Train Loss 0.3952, Acc 0.8934 | Val Loss 0.4320, Acc 0.8831




Epoch 28/30 | Train Loss 0.3929, Acc 0.8946 | Val Loss 0.4304, Acc 0.8841




Epoch 29/30 | Train Loss 0.3915, Acc 0.8949 | Val Loss 0.4291, Acc 0.8847




Epoch 30/30 | Train Loss 0.3887, Acc 0.8958 | Val Loss 0.4283, Acc 0.8867
✅ Finished training WavKAN on MNIST


In [None]:
# SineKAN run: dimension list is 784, three entries of hidden_layer_width, 10.
train_SineKAN([784] + [hidden_layer_width] * 3 + [10], 4e-4)

In [None]:
# ChebyKAN run: dimension list is 784, three entries of hidden_layer_width, 10.
train_ChebyKAN([784] + [hidden_layer_width] * 3 + [10], 4e-4)

100%|██████████| 9.91M/9.91M [00:01<00:00, 6.81MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 161kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.52MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.58MB/s]


Total trainable parameters: 66176




Epoch 1/30 | Train Loss 0.9371, Acc 0.7402 | Val Loss 0.4324, Acc 0.8856




Epoch 2/30 | Train Loss 0.3274, Acc 0.9099 | Val Loss 0.2726, Acc 0.9203




Epoch 3/30 | Train Loss 0.2318, Acc 0.9340 | Val Loss 0.2061, Acc 0.9414




Epoch 4/30 | Train Loss 0.1929, Acc 0.9433 | Val Loss 0.2075, Acc 0.9394




Epoch 5/30 | Train Loss 0.1628, Acc 0.9522 | Val Loss 0.1849, Acc 0.9444




Epoch 6/30 | Train Loss 0.1480, Acc 0.9560 | Val Loss 0.1523, Acc 0.9559




Epoch 7/30 | Train Loss 0.1329, Acc 0.9609 | Val Loss 0.1560, Acc 0.9549




Epoch 8/30 | Train Loss 0.1203, Acc 0.9641 | Val Loss 0.1449, Acc 0.9575




Epoch 9/30 | Train Loss 0.1085, Acc 0.9683 | Val Loss 0.1522, Acc 0.9521




KeyboardInterrupt: 

In [None]:
# WavKAN run: dimension list is 784, three entries of hidden_layer_width, 10.
train_WavKAN([784] + [hidden_layer_width] * 3 + [10], 4e-4)

110036




Epoch 1/30 | Train Loss 0.6898, Acc 0.8530 | Val Loss 0.4309, Acc 0.9243




Epoch 2/30 | Train Loss 0.3663, Acc 0.9349 | Val Loss 0.3200, Acc 0.9406




Epoch 3/30 | Train Loss 0.2775, Acc 0.9491 | Val Loss 0.2535, Acc 0.9509




Epoch 4/30 | Train Loss 0.2269, Acc 0.9562 | Val Loss 0.2208, Acc 0.9562




Epoch 5/30 | Train Loss 0.1912, Acc 0.9627 | Val Loss 0.1881, Acc 0.9596




Epoch 6/30 | Train Loss 0.1659, Acc 0.9673 | Val Loss 0.1797, Acc 0.9604




Epoch 7/30 | Train Loss 0.1476, Acc 0.9704 | Val Loss 0.1743, Acc 0.9598




Epoch 8/30 | Train Loss 0.1313, Acc 0.9735 | Val Loss 0.1485, Acc 0.9663




Epoch 9/30 | Train Loss 0.1198, Acc 0.9753 | Val Loss 0.1568, Acc 0.9630




Epoch 10/30 | Train Loss 0.1091, Acc 0.9779 | Val Loss 0.1424, Acc 0.9656




Epoch 11/30 | Train Loss 0.1023, Acc 0.9786 | Val Loss 0.1317, Acc 0.9688




Epoch 12/30 | Train Loss 0.0930, Acc 0.9809 | Val Loss 0.1291, Acc 0.9669




Epoch 13/30 | Train Loss 0.0872, Acc 0.9826 | Val Loss 0.1304, Acc 0.9662




Epoch 14/30 | Train Loss 0.0816, Acc 0.9833 | Val Loss 0.1256, Acc 0.9679




Epoch 15/30 | Train Loss 0.0765, Acc 0.9846 | Val Loss 0.1216, Acc 0.9694




Epoch 16/30 | Train Loss 0.0716, Acc 0.9859 | Val Loss 0.1250, Acc 0.9686




Epoch 17/30 | Train Loss 0.0682, Acc 0.9869 | Val Loss 0.1210, Acc 0.9693




Epoch 18/30 | Train Loss 0.0655, Acc 0.9873 | Val Loss 0.1185, Acc 0.9703




Epoch 19/30 | Train Loss 0.0623, Acc 0.9883 | Val Loss 0.1183, Acc 0.9695




Epoch 20/30 | Train Loss 0.0594, Acc 0.9893 | Val Loss 0.1215, Acc 0.9688




Epoch 21/30 | Train Loss 0.0570, Acc 0.9898 | Val Loss 0.1173, Acc 0.9697




Epoch 22/30 | Train Loss 0.0553, Acc 0.9903 | Val Loss 0.1169, Acc 0.9701




Epoch 23/30 | Train Loss 0.0534, Acc 0.9908 | Val Loss 0.1164, Acc 0.9694




Epoch 24/30 | Train Loss 0.0519, Acc 0.9910 | Val Loss 0.1180, Acc 0.9695




Epoch 25/30 | Train Loss 0.0503, Acc 0.9915 | Val Loss 0.1156, Acc 0.9702




Epoch 26/30 | Train Loss 0.0490, Acc 0.9920 | Val Loss 0.1161, Acc 0.9710




Epoch 27/30 | Train Loss 0.0478, Acc 0.9924 | Val Loss 0.1153, Acc 0.9703




Epoch 28/30 | Train Loss 0.0469, Acc 0.9923 | Val Loss 0.1145, Acc 0.9704




Epoch 29/30 | Train Loss 0.0460, Acc 0.9927 | Val Loss 0.1160, Acc 0.9697




Epoch 30/30 | Train Loss 0.0451, Acc 0.9931 | Val Loss 0.1150, Acc 0.9708
✅ Finished training WavKAN on MNIST


In [None]:
# Short hybrid stack: wav, cheby, sine at hidden_layer_width, closed by a
# (10, "linear") entry; lr 4e-4.
train_HybridKAN(
    [(hidden_layer_width, basis) for basis in ("wav", "cheby", "sine")]
    + [(10, "linear")],
    4e-4,
)


Total trainable parameters: 208526




Epoch 1/30 | Train Loss 0.6324, Acc 0.8119 | Val Loss 0.3165, Acc 0.9090




Epoch 2/30 | Train Loss 0.2823, Acc 0.9145 | Val Loss 0.2346, Acc 0.9292




Epoch 3/30 | Train Loss 0.2137, Acc 0.9352 | Val Loss 0.2417, Acc 0.9280




Epoch 4/30 | Train Loss 0.1717, Acc 0.9474 | Val Loss 0.1955, Acc 0.9391




Epoch 5/30 | Train Loss 0.1461, Acc 0.9550 | Val Loss 0.1721, Acc 0.9484




Epoch 6/30 | Train Loss 0.1266, Acc 0.9605 | Val Loss 0.1950, Acc 0.9420




Epoch 7/30 | Train Loss 0.1067, Acc 0.9670 | Val Loss 0.1817, Acc 0.9456




Epoch 8/30 | Train Loss 0.0939, Acc 0.9711 | Val Loss 0.1643, Acc 0.9508




Epoch 9/30 | Train Loss 0.0818, Acc 0.9747 | Val Loss 0.1876, Acc 0.9477




Epoch 10/30 | Train Loss 0.0701, Acc 0.9782 | Val Loss 0.1751, Acc 0.9513




Epoch 11/30 | Train Loss 0.0614, Acc 0.9815 | Val Loss 0.1831, Acc 0.9513




Epoch 12/30 | Train Loss 0.0544, Acc 0.9838 | Val Loss 0.1831, Acc 0.9503




Epoch 13/30 | Train Loss 0.0459, Acc 0.9865 | Val Loss 0.1881, Acc 0.9502




Epoch 14/30 | Train Loss 0.0404, Acc 0.9880 | Val Loss 0.1935, Acc 0.9509




Epoch 15/30 | Train Loss 0.0351, Acc 0.9897 | Val Loss 0.1911, Acc 0.9520




Epoch 16/30 | Train Loss 0.0299, Acc 0.9920 | Val Loss 0.2031, Acc 0.9489




Epoch 17/30 | Train Loss 0.0261, Acc 0.9934 | Val Loss 0.2000, Acc 0.9524




Epoch 18/30 | Train Loss 0.0227, Acc 0.9948 | Val Loss 0.2125, Acc 0.9485




Epoch 19/30 | Train Loss 0.0199, Acc 0.9955 | Val Loss 0.2109, Acc 0.9511




Epoch 20/30 | Train Loss 0.0169, Acc 0.9966 | Val Loss 0.2191, Acc 0.9517




Epoch 21/30 | Train Loss 0.0149, Acc 0.9974 | Val Loss 0.2233, Acc 0.9496




Epoch 22/30 | Train Loss 0.0128, Acc 0.9980 | Val Loss 0.2290, Acc 0.9490




Epoch 23/30 | Train Loss 0.0114, Acc 0.9986 | Val Loss 0.2304, Acc 0.9499




Epoch 24/30 | Train Loss 0.0107, Acc 0.9985 | Val Loss 0.2328, Acc 0.9490




Epoch 25/30 | Train Loss 0.0093, Acc 0.9989 | Val Loss 0.2399, Acc 0.9501




Epoch 26/30 | Train Loss 0.0084, Acc 0.9992 | Val Loss 0.2430, Acc 0.9494




Epoch 27/30 | Train Loss 0.0077, Acc 0.9993 | Val Loss 0.2458, Acc 0.9501




Epoch 28/30 | Train Loss 0.0068, Acc 0.9993 | Val Loss 0.2508, Acc 0.9491




Epoch 29/30 | Train Loss 0.0062, Acc 0.9994 | Val Loss 0.2529, Acc 0.9486




Epoch 30/30 | Train Loss 0.0059, Acc 0.9995 | Val Loss 0.2564, Acc 0.9500
✅ Finished training ChebyKAN on MNIST


In [None]:
# Short hybrid stack: sine, cheby, wav at hidden_layer_width, closed by a
# (10, "linear") entry; lr 4e-4.
train_HybridKAN(
    [(hidden_layer_width, basis) for basis in ("sine", "cheby", "wav")]
    + [(10, "linear")],
    4e-4,
)

Total trainable parameters: 2518670




Epoch 1/30 | Train Loss 0.6877, Acc 0.8020 | Val Loss 0.3403, Acc 0.8997




Epoch 2/30 | Train Loss 0.2460, Acc 0.9282 | Val Loss 0.1940, Acc 0.9475




Epoch 3/30 | Train Loss 0.1784, Acc 0.9465 | Val Loss 0.1762, Acc 0.9462




Epoch 4/30 | Train Loss 0.1418, Acc 0.9576 | Val Loss 0.1657, Acc 0.9497




Epoch 5/30 | Train Loss 0.1160, Acc 0.9653 | Val Loss 0.1403, Acc 0.9566




Epoch 6/30 | Train Loss 0.1010, Acc 0.9694 | Val Loss 0.1188, Acc 0.9627




Epoch 7/30 | Train Loss 0.0866, Acc 0.9730 | Val Loss 0.1141, Acc 0.9634




Epoch 8/30 | Train Loss 0.0769, Acc 0.9756 | Val Loss 0.1324, Acc 0.9596




Epoch 9/30 | Train Loss 0.0659, Acc 0.9795 | Val Loss 0.0970, Acc 0.9716




Epoch 10/30 | Train Loss 0.0582, Acc 0.9821 | Val Loss 0.0946, Acc 0.9711




Epoch 11/30 | Train Loss 0.0495, Acc 0.9847 | Val Loss 0.1074, Acc 0.9680




Epoch 12/30 | Train Loss 0.0410, Acc 0.9871 | Val Loss 0.1044, Acc 0.9699




Epoch 13/30 | Train Loss 0.0367, Acc 0.9884 | Val Loss 0.1014, Acc 0.9699




Epoch 14/30 | Train Loss 0.0306, Acc 0.9908 | Val Loss 0.1114, Acc 0.9644




Epoch 15/30 | Train Loss 0.0258, Acc 0.9926 | Val Loss 0.0944, Acc 0.9731




Epoch 16/30 | Train Loss 0.0213, Acc 0.9939 | Val Loss 0.0983, Acc 0.9740




Epoch 17/30 | Train Loss 0.0196, Acc 0.9942 | Val Loss 0.0958, Acc 0.9747




Epoch 18/30 | Train Loss 0.0158, Acc 0.9960 | Val Loss 0.0930, Acc 0.9743




Epoch 19/30 | Train Loss 0.0130, Acc 0.9969 | Val Loss 0.0994, Acc 0.9740




Epoch 20/30 | Train Loss 0.0117, Acc 0.9972 | Val Loss 0.0953, Acc 0.9756




Epoch 21/30 | Train Loss 0.0093, Acc 0.9980 | Val Loss 0.1004, Acc 0.9743




Epoch 22/30 | Train Loss 0.0082, Acc 0.9984 | Val Loss 0.0960, Acc 0.9748




Epoch 23/30 | Train Loss 0.0062, Acc 0.9992 | Val Loss 0.1001, Acc 0.9754




Epoch 24/30 | Train Loss 0.0055, Acc 0.9993 | Val Loss 0.1001, Acc 0.9746




Epoch 25/30 | Train Loss 0.0047, Acc 0.9997 | Val Loss 0.0996, Acc 0.9755




Epoch 26/30 | Train Loss 0.0039, Acc 0.9997 | Val Loss 0.1044, Acc 0.9745




Epoch 27/30 | Train Loss 0.0034, Acc 0.9998 | Val Loss 0.1029, Acc 0.9760




Epoch 28/30 | Train Loss 0.0030, Acc 0.9998 | Val Loss 0.1028, Acc 0.9766




Epoch 29/30 | Train Loss 0.0025, Acc 0.9999 | Val Loss 0.1046, Acc 0.9756




Epoch 30/30 | Train Loss 0.0023, Acc 0.9999 | Val Loss 0.1110, Acc 0.9741
✅ Finished training ChebyKAN on MNIST


In [None]:
# Hybrid run, basis order: sine -> wav -> cheby, then a width-10 "linear" layer.
# Each spec entry is passed as a (width, basis-name) tuple.
# NOTE(review): 4e-4 is presumably the learning rate — confirm against train_HybridKAN.
train_HybridKAN(
    [
        (hidden_layer_width, "sine"),
        (hidden_layer_width, "wav"),
        (hidden_layer_width, "cheby"),
        (10, "linear"),
    ],
    4e-4,
)

Total trainable parameters: 2518670




Epoch 1/30 | Train Loss 0.4656, Acc 0.8745 | Val Loss 0.2648, Acc 0.9244




Epoch 2/30 | Train Loss 0.2225, Acc 0.9352 | Val Loss 0.2131, Acc 0.9341




Epoch 3/30 | Train Loss 0.1729, Acc 0.9490 | Val Loss 0.1655, Acc 0.9528




Epoch 4/30 | Train Loss 0.1420, Acc 0.9578 | Val Loss 0.1705, Acc 0.9509




Epoch 5/30 | Train Loss 0.1237, Acc 0.9625 | Val Loss 0.1438, Acc 0.9558




Epoch 6/30 | Train Loss 0.1076, Acc 0.9673 | Val Loss 0.1227, Acc 0.9638




Epoch 7/30 | Train Loss 0.0970, Acc 0.9704 | Val Loss 0.1305, Acc 0.9613




Epoch 8/30 | Train Loss 0.0859, Acc 0.9740 | Val Loss 0.1303, Acc 0.9621




Epoch 9/30 | Train Loss 0.0761, Acc 0.9771 | Val Loss 0.1152, Acc 0.9678




Epoch 10/30 | Train Loss 0.0671, Acc 0.9790 | Val Loss 0.1150, Acc 0.9667




Epoch 11/30 | Train Loss 0.0565, Acc 0.9824 | Val Loss 0.1188, Acc 0.9666




Epoch 12/30 | Train Loss 0.0516, Acc 0.9841 | Val Loss 0.1188, Acc 0.9670




Epoch 13/30 | Train Loss 0.0451, Acc 0.9859 | Val Loss 0.1067, Acc 0.9713




Epoch 14/30 | Train Loss 0.0379, Acc 0.9886 | Val Loss 0.1123, Acc 0.9700




Epoch 15/30 | Train Loss 0.0348, Acc 0.9897 | Val Loss 0.1111, Acc 0.9715




Epoch 16/30 | Train Loss 0.0290, Acc 0.9912 | Val Loss 0.1086, Acc 0.9713




Epoch 17/30 | Train Loss 0.0243, Acc 0.9938 | Val Loss 0.1086, Acc 0.9712




Epoch 18/30 | Train Loss 0.0215, Acc 0.9941 | Val Loss 0.1064, Acc 0.9735




Epoch 19/30 | Train Loss 0.0198, Acc 0.9950 | Val Loss 0.1216, Acc 0.9708




Epoch 20/30 | Train Loss 0.0157, Acc 0.9964 | Val Loss 0.1115, Acc 0.9731




Epoch 21/30 | Train Loss 0.0138, Acc 0.9971 | Val Loss 0.1164, Acc 0.9714




Epoch 22/30 | Train Loss 0.0129, Acc 0.9972 | Val Loss 0.1129, Acc 0.9724




Epoch 23/30 | Train Loss 0.0107, Acc 0.9981 | Val Loss 0.1190, Acc 0.9707




Epoch 24/30 | Train Loss 0.0096, Acc 0.9984 | Val Loss 0.1193, Acc 0.9715




Epoch 25/30 | Train Loss 0.0081, Acc 0.9991 | Val Loss 0.1197, Acc 0.9724




Epoch 26/30 | Train Loss 0.0073, Acc 0.9992 | Val Loss 0.1166, Acc 0.9725




Epoch 27/30 | Train Loss 0.0066, Acc 0.9992 | Val Loss 0.1201, Acc 0.9732




Epoch 28/30 | Train Loss 0.0057, Acc 0.9994 | Val Loss 0.1183, Acc 0.9731




Epoch 29/30 | Train Loss 0.0052, Acc 0.9995 | Val Loss 0.1264, Acc 0.9702




Epoch 30/30 | Train Loss 0.0047, Acc 0.9996 | Val Loss 0.1240, Acc 0.9725
✅ Finished training ChebyKAN on MNIST


In [None]:
# Hybrid run, basis order: cheby -> sine -> wav, then a width-10 "linear" layer.
# Each spec entry is passed as a (width, basis-name) tuple.
# NOTE(review): 4e-4 is presumably the learning rate — confirm against train_HybridKAN.
train_HybridKAN(
    [
        (hidden_layer_width, "cheby"),
        (hidden_layer_width, "sine"),
        (hidden_layer_width, "wav"),
        (10, "linear"),
    ],
    4e-4,
)

Total trainable parameters: 232590




Epoch 1/30 | Train Loss 0.6363, Acc 0.8062 | Val Loss 0.2461, Acc 0.9315




Epoch 2/30 | Train Loss 0.1906, Acc 0.9447 | Val Loss 0.1474, Acc 0.9572




Epoch 3/30 | Train Loss 0.1384, Acc 0.9586 | Val Loss 0.1355, Acc 0.9598




Epoch 4/30 | Train Loss 0.1067, Acc 0.9681 | Val Loss 0.1180, Acc 0.9643




Epoch 5/30 | Train Loss 0.0899, Acc 0.9725 | Val Loss 0.1171, Acc 0.9650




Epoch 6/30 | Train Loss 0.0746, Acc 0.9776 | Val Loss 0.1194, Acc 0.9640




Epoch 7/30 | Train Loss 0.0614, Acc 0.9811 | Val Loss 0.1273, Acc 0.9607




Epoch 8/30 | Train Loss 0.0528, Acc 0.9829 | Val Loss 0.1019, Acc 0.9715




Epoch 9/30 | Train Loss 0.0435, Acc 0.9865 | Val Loss 0.1055, Acc 0.9699




Epoch 10/30 | Train Loss 0.0369, Acc 0.9888 | Val Loss 0.0919, Acc 0.9737




Epoch 11/30 | Train Loss 0.0303, Acc 0.9910 | Val Loss 0.1006, Acc 0.9710




Epoch 12/30 | Train Loss 0.0265, Acc 0.9915 | Val Loss 0.0974, Acc 0.9719




Epoch 13/30 | Train Loss 0.0214, Acc 0.9935 | Val Loss 0.0968, Acc 0.9716




Epoch 14/30 | Train Loss 0.0159, Acc 0.9957 | Val Loss 0.1053, Acc 0.9726




Epoch 15/30 | Train Loss 0.0138, Acc 0.9964 | Val Loss 0.0971, Acc 0.9746




Epoch 16/30 | Train Loss 0.0112, Acc 0.9971 | Val Loss 0.1005, Acc 0.9737




Epoch 17/30 | Train Loss 0.0087, Acc 0.9983 | Val Loss 0.0947, Acc 0.9756




Epoch 18/30 | Train Loss 0.0073, Acc 0.9987 | Val Loss 0.1080, Acc 0.9720




Epoch 19/30 | Train Loss 0.0062, Acc 0.9989 | Val Loss 0.1025, Acc 0.9749




Epoch 20/30 | Train Loss 0.0045, Acc 0.9995 | Val Loss 0.1041, Acc 0.9750




Epoch 21/30 | Train Loss 0.0032, Acc 0.9998 | Val Loss 0.1129, Acc 0.9728




Epoch 22/30 | Train Loss 0.0036, Acc 0.9995 | Val Loss 0.1006, Acc 0.9747




Epoch 23/30 | Train Loss 0.0023, Acc 0.9999 | Val Loss 0.1151, Acc 0.9737




Epoch 24/30 | Train Loss 0.0022, Acc 0.9998 | Val Loss 0.1038, Acc 0.9764




Epoch 25/30 | Train Loss 0.0018, Acc 0.9999 | Val Loss 0.1065, Acc 0.9764




Epoch 26/30 | Train Loss 0.0013, Acc 1.0000 | Val Loss 0.1100, Acc 0.9759




Epoch 27/30 | Train Loss 0.0016, Acc 0.9998 | Val Loss 0.1103, Acc 0.9759




Epoch 28/30 | Train Loss 0.0009, Acc 1.0000 | Val Loss 0.1092, Acc 0.9768




Epoch 29/30 | Train Loss 0.0009, Acc 1.0000 | Val Loss 0.1139, Acc 0.9761




Epoch 30/30 | Train Loss 0.0008, Acc 1.0000 | Val Loss 0.1119, Acc 0.9766
✅ Finished training ChebyKAN on MNIST


In [None]:
# Hybrid run, basis order: cheby -> wav -> sine, then a width-10 "linear" layer.
# Each spec entry is passed as a (width, basis-name) tuple.
# NOTE(review): 4e-4 is presumably the learning rate — confirm against train_HybridKAN.
train_HybridKAN(
    [
        (hidden_layer_width, "cheby"),
        (hidden_layer_width, "wav"),
        (hidden_layer_width, "sine"),
        (10, "linear"),
    ],
    4e-4,
)

Total trainable parameters: 232590




Epoch 1/30 | Train Loss 0.4819, Acc 0.8712 | Val Loss 0.2303, Acc 0.9356




Epoch 2/30 | Train Loss 0.1842, Acc 0.9469 | Val Loss 0.1579, Acc 0.9562




Epoch 3/30 | Train Loss 0.1346, Acc 0.9615 | Val Loss 0.1590, Acc 0.9537




Epoch 4/30 | Train Loss 0.1088, Acc 0.9678 | Val Loss 0.1367, Acc 0.9575




Epoch 5/30 | Train Loss 0.0911, Acc 0.9727 | Val Loss 0.1232, Acc 0.9643




Epoch 6/30 | Train Loss 0.0738, Acc 0.9776 | Val Loss 0.1270, Acc 0.9647




Epoch 7/30 | Train Loss 0.0636, Acc 0.9806 | Val Loss 0.1173, Acc 0.9655




Epoch 8/30 | Train Loss 0.0544, Acc 0.9838 | Val Loss 0.1188, Acc 0.9654




Epoch 9/30 | Train Loss 0.0436, Acc 0.9865 | Val Loss 0.1291, Acc 0.9636




Epoch 10/30 | Train Loss 0.0377, Acc 0.9889 | Val Loss 0.1039, Acc 0.9711




Epoch 11/30 | Train Loss 0.0310, Acc 0.9910 | Val Loss 0.1208, Acc 0.9675




Epoch 12/30 | Train Loss 0.0247, Acc 0.9928 | Val Loss 0.1152, Acc 0.9689




Epoch 13/30 | Train Loss 0.0222, Acc 0.9938 | Val Loss 0.1151, Acc 0.9688




Epoch 14/30 | Train Loss 0.0179, Acc 0.9951 | Val Loss 0.1159, Acc 0.9699




Epoch 15/30 | Train Loss 0.0140, Acc 0.9964 | Val Loss 0.1135, Acc 0.9720




Epoch 16/30 | Train Loss 0.0109, Acc 0.9977 | Val Loss 0.1108, Acc 0.9732




Epoch 17/30 | Train Loss 0.0106, Acc 0.9974 | Val Loss 0.1154, Acc 0.9721




Epoch 18/30 | Train Loss 0.0077, Acc 0.9985 | Val Loss 0.1431, Acc 0.9671




Epoch 19/30 | Train Loss 0.0061, Acc 0.9990 | Val Loss 0.1226, Acc 0.9706




Epoch 20/30 | Train Loss 0.0056, Acc 0.9991 | Val Loss 0.1505, Acc 0.9664




Epoch 21/30 | Train Loss 0.0049, Acc 0.9992 | Val Loss 0.1227, Acc 0.9726




Epoch 22/30 | Train Loss 0.0035, Acc 0.9996 | Val Loss 0.1221, Acc 0.9718




Epoch 23/30 | Train Loss 0.0030, Acc 0.9997 | Val Loss 0.1217, Acc 0.9732




Epoch 24/30 | Train Loss 0.0030, Acc 0.9996 | Val Loss 0.1358, Acc 0.9709




Epoch 25/30 | Train Loss 0.0021, Acc 0.9999 | Val Loss 0.1235, Acc 0.9734




Epoch 26/30 | Train Loss 0.0019, Acc 0.9999 | Val Loss 0.1335, Acc 0.9727




Epoch 27/30 | Train Loss 0.0017, Acc 0.9999 | Val Loss 0.1319, Acc 0.9715




Epoch 28/30 | Train Loss 0.0018, Acc 0.9999 | Val Loss 0.1329, Acc 0.9729




Epoch 29/30 | Train Loss 0.0012, Acc 1.0000 | Val Loss 0.1351, Acc 0.9719




Epoch 30/30 | Train Loss 0.0011, Acc 1.0000 | Val Loss 0.1355, Acc 0.9725
✅ Finished training ChebyKAN on MNIST


In [None]:
# Hybrid run, basis order: wav -> sine -> cheby, then a width-10 "linear" layer.
# Each spec entry is passed as a (width, basis-name) tuple.
# NOTE(review): 4e-4 is presumably the learning rate — confirm against train_HybridKAN.
train_HybridKAN(
    [
        (hidden_layer_width, "wav"),
        (hidden_layer_width, "sine"),
        (hidden_layer_width, "cheby"),
        (10, "linear"),
    ],
    4e-4,
)

Total trainable parameters: 208526




Epoch 1/30 | Train Loss 0.6770, Acc 0.7868 | Val Loss 0.3351, Acc 0.8971




Epoch 2/30 | Train Loss 0.2813, Acc 0.9147 | Val Loss 0.2520, Acc 0.9254




Epoch 3/30 | Train Loss 0.2087, Acc 0.9359 | Val Loss 0.1939, Acc 0.9428




Epoch 4/30 | Train Loss 0.1716, Acc 0.9473 | Val Loss 0.2159, Acc 0.9335




Epoch 5/30 | Train Loss 0.1423, Acc 0.9557 | Val Loss 0.1738, Acc 0.9446




Epoch 6/30 | Train Loss 0.1236, Acc 0.9622 | Val Loss 0.2033, Acc 0.9387




Epoch 7/30 | Train Loss 0.1063, Acc 0.9673 | Val Loss 0.1826, Acc 0.9463




Epoch 8/30 | Train Loss 0.0920, Acc 0.9719 | Val Loss 0.1919, Acc 0.9433




Epoch 9/30 | Train Loss 0.0798, Acc 0.9760 | Val Loss 0.1622, Acc 0.9556




Epoch 10/30 | Train Loss 0.0688, Acc 0.9798 | Val Loss 0.1562, Acc 0.9553




Epoch 11/30 | Train Loss 0.0601, Acc 0.9821 | Val Loss 0.1618, Acc 0.9577




Epoch 12/30 | Train Loss 0.0517, Acc 0.9851 | Val Loss 0.1595, Acc 0.9562




Epoch 13/30 | Train Loss 0.0463, Acc 0.9863 | Val Loss 0.1618, Acc 0.9562




Epoch 14/30 | Train Loss 0.0391, Acc 0.9892 | Val Loss 0.1729, Acc 0.9541




Epoch 15/30 | Train Loss 0.0338, Acc 0.9906 | Val Loss 0.1645, Acc 0.9577




Epoch 16/30 | Train Loss 0.0281, Acc 0.9928 | Val Loss 0.1699, Acc 0.9568




Epoch 17/30 | Train Loss 0.0244, Acc 0.9940 | Val Loss 0.1772, Acc 0.9558




Epoch 18/30 | Train Loss 0.0204, Acc 0.9953 | Val Loss 0.1869, Acc 0.9557




Epoch 19/30 | Train Loss 0.0179, Acc 0.9961 | Val Loss 0.1900, Acc 0.9539




Epoch 20/30 | Train Loss 0.0155, Acc 0.9968 | Val Loss 0.1824, Acc 0.9572




Epoch 21/30 | Train Loss 0.0133, Acc 0.9976 | Val Loss 0.1931, Acc 0.9563




Epoch 22/30 | Train Loss 0.0112, Acc 0.9980 | Val Loss 0.1973, Acc 0.9555




Epoch 23/30 | Train Loss 0.0096, Acc 0.9987 | Val Loss 0.2048, Acc 0.9546




Epoch 24/30 | Train Loss 0.0084, Acc 0.9991 | Val Loss 0.2031, Acc 0.9559




Epoch 25/30 | Train Loss 0.0073, Acc 0.9992 | Val Loss 0.2121, Acc 0.9550




Epoch 26/30 | Train Loss 0.0064, Acc 0.9994 | Val Loss 0.2117, Acc 0.9560




Epoch 27/30 | Train Loss 0.0056, Acc 0.9996 | Val Loss 0.2148, Acc 0.9545




Epoch 28/30 | Train Loss 0.0050, Acc 0.9997 | Val Loss 0.2175, Acc 0.9545




Epoch 29/30 | Train Loss 0.0044, Acc 0.9998 | Val Loss 0.2199, Acc 0.9547




Epoch 30/30 | Train Loss 0.0042, Acc 0.9998 | Val Loss 0.2246, Acc 0.9541
✅ Finished training ChebyKAN on MNIST
