In [1]:
# Adapted from https://github.com/ZiyaoLi/fast-kan/tree/master

import torch
import jax
from flax import nnx
from efficient_kan import KANLinear
from fastkan import FastKANLayer
from kan_rbf import KANRBFLayer

import treescope

# Register treescope as the default renderer before the layer is displayed below.
treescope.register_as_default()

# Every layer in this comparison is constructed with the same two size
# arguments (100, 100).
LAYER_DIM = 100

# PyTorch layers, moved onto the CUDA device.
eklayer = KANLinear(LAYER_DIM, LAYER_DIM).to("cuda")
fklayer = FastKANLayer(LAYER_DIM, LAYER_DIM).to("cuda")

# Flax NNX layer, seeded with Rngs(0); `rbflayer_jit` is the same module
# wrapped in nnx.jit so the eager and jitted paths can be timed side by side.
rbflayer = KANRBFLayer(
    LAYER_DIM,
    LAYER_DIM,
    use_layernorm=False,
    base_update_bias=True,
    rngs=nnx.Rngs(0),
)
rbflayer_jit = nnx.jit(rbflayer)

nnx.display(rbflayer)

2024-09-23 16:27:21.752719: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.5.82. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Benchmark input: an (8, 100) standard-normal tensor moved to the CUDA device
# for the PyTorch layers.
x = torch.randn(8, 100).cuda()
# The same values handed to JAX through a host-side NumPy copy, so every layer
# is timed on identical data.
x_jax = jax.numpy.array(x.cpu().numpy())

# Warm up every code path once so one-time costs (JIT tracing/compilation,
# first-call kernel loading) stay out of the timed loops in the next cell.
# The result is bound to a regular name rather than `_`, because assigning to
# `_` makes IPython stop updating its `_` / `__` / `___` output history.
# block_until_ready waits for the asynchronously dispatched computation to
# actually finish before the cell returns.
warmup_out = jax.block_until_ready(rbflayer_jit(x_jax))
with torch.no_grad():
    eklayer(x)
    fklayer(x, use_layernorm=False)
# CUDA kernels are launched asynchronously; wait until the warm-up is done.
torch.cuda.synchronize()

In [3]:
%timeit -r10 -n1000 y = eklayer(x)
%timeit -r10 -n1000 y = fklayer(x, use_layernorm=False)
%timeit -r10 -n1000 y = rbflayer(x_jax)
%timeit -r10 -n1000 y = rbflayer_jit(x_jax)

The slowest run took 5.97 times longer than the fastest. This could mean that an intermediate result is being cached.
1.97 ms ± 1.64 ms per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [6]:
%timeit -r10 -n1000 eklayer(x).sum().backward()
%timeit -r10 -n1000 fklayer(x).sum().backward()

2.49 ms ± 944 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
1.26 ms ± 83.8 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [7]:
from kan_rbf import KANRBF
from fastkan import FastKAN

# Full networks for the end-to-end comparison; both receive the same width
# list. Each constructor gets its own copy of the list.
layer_widths = [20, 20, 20]

# PyTorch network, moved onto the CUDA device.
fastkan = FastKAN(list(layer_widths)).to("cuda")

# Flax NNX network seeded with Rngs(0), plus an nnx.jit-wrapped handle to it.
rbfkan = KANRBF(
    list(layer_widths),
    use_layernorm=True,
    base_update_bias=True,
    rngs=nnx.Rngs(0),
)
rbfkan_jit = nnx.jit(rbfkan)

In [8]:
x = torch.randn(32, 20).cuda()
x_jax = jax.numpy.array(x.cpu().numpy())

%timeit -r10 -n1000 y = fastkan(x)
%timeit -r10 -n1000 y = rbfkan(x_jax)
%timeit -r10 -n1000 y = rbfkan_jit(x_jax)

796 μs ± 74.5 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
8.73 ms ± 527 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
187 μs ± 102 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
