In [1]:
# Adapted from https://github.com/ZiyaoLi/fast-kan/tree/master
#
# Benchmark setup: three 100 -> 100 KAN layers from different implementations.
#   - eklayer:  efficient_kan.KANLinear (PyTorch, moved to the GPU with .cuda())
#   - fklayer:  fastkan.FastKANLayer   (PyTorch, moved to the GPU with .cuda())
#   - rbflayer: kan_rbf.KANRBFLayer    (Flax NNX), plus an nnx.jit-wrapped copy

import torch
import jax
from flax import nnx
from efficient_kan import KANLinear
from fastkan import FastKANLayer
from kan_rbf import KANRBFLayer

import treescope

treescope.register_as_default()

# Seed PyTorch so any random initialization inside the torch layers, and the
# inputs drawn later with torch.randn, are reproducible on Restart & Run All.
# The Flax layer is seeded independently through nnx.Rngs(0).
torch.manual_seed(0)

eklayer = KANLinear(100, 100).cuda()
fklayer = FastKANLayer(100, 100).cuda()
rbflayer = KANRBFLayer(100, 100, use_layernorm=False, base_update_bias=True, rngs=nnx.Rngs(0))
# jit traces/compiles lazily on the first call, not here.
rbflayer_jit = nnx.jit(rbflayer)

nnx.display(rbflayer)

2024-09-23 20:45:37.952177: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.5.82. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Shared input: the same 8 x 100 batch for the torch layers (on the GPU) and
# for the JAX layer (copied to JAX's default device).
x = torch.randn(8, 100).cuda()
x_jax = jax.numpy.asarray(x.cpu().numpy())

# Warm up every contender once before timing so one-time costs (JIT
# compilation for rbflayer_jit, first-call CUDA/kernel setup for the torch
# layers, per-op compilation for eager JAX) do not land in the first %timeit
# run. Both frameworks dispatch GPU work asynchronously, so wait explicitly.
eklayer(x)
fklayer(x)
torch.cuda.synchronize()
jax.block_until_ready(rbflayer(x_jax))
jax.block_until_ready(rbflayer_jit(x_jax));

In [4]:
%timeit -r10 -n1000 y = eklayer(x)
%timeit -r10 -n1000 y = fklayer(x)
%timeit -r10 -n1000 y = rbflayer(x_jax)
%timeit -r10 -n1000 y = rbflayer_jit(x_jax)

1.33 ms ± 225 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
468 μs ± 113 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
2.08 ms ± 241 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
152 μs ± 11.1 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [5]:
%timeit -r10 -n1000 eklayer(x).sum().backward()
%timeit -r10 -n1000 fklayer(x).sum().backward()

2.15 ms ± 182 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
1.2 ms ± 140 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [6]:
from fastkan import FastKAN
from kan_rbf import KANRBF

# Matching 20-20-20 networks: FastKAN (PyTorch, on the GPU) vs. KANRBF
# (Flax NNX, seeded via nnx.Rngs), plus an nnx.jit-wrapped copy of the latter.
fastkan = FastKAN([20, 20, 20]).cuda()
rbfkan = KANRBF(
    [20, 20, 20],
    use_layernorm=True,
    base_update_bias=True,
    rngs=nnx.Rngs(0),
)
rbfkan_jit = nnx.jit(rbfkan)

In [7]:
x = torch.randn(32, 20).cuda()
x_jax = jax.numpy.array(x.cpu().numpy())

%timeit -r10 -n1000 y = fastkan(x)
%timeit -r10 -n1000 y = rbfkan(x_jax)
%timeit -r10 -n1000 y = rbfkan_jit(x_jax)

968 μs ± 284 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
10.2 ms ± 1.2 ms per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
The slowest run took 4.69 times longer than the fastest. This could mean that an intermediate result is being cached.
255 μs ± 170 μs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
