In [5]:
from tinygrad import Device
print (Device.DEFAULT)
from typing import *
from tinygrad import TinyJit

METAL


In [71]:
from tinygrad import Tensor, nn
import tinygrad.function as F
import numpy as np

class SplineLinearFunction:
    def __init__(
        self,
        in_features: int, out_features: int, init_scale: float = 0.1
    ):
        self.init_scale = init_scale
        self.linear_function = nn.Linear(in_features, out_features, bias=False)
    
    def __call__(self, x:Tensor) -> Tensor:
        return self.linear_function(x)

class RadialBasisFunction:
    def __init__(
        self,
        grid_min = -2.,
        grid_max = 2.,
        num_grids = 8,
        denominator = None
    ):
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        # You don't need a special Parameter initialization here.
        # You can initialize a Tensor later with this
        self.grid = Tensor(np.linspace(grid_min, grid_max, num_grids, dtype=np.float32), requires_grad=True)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def __call__(self, x:Tensor) -> Tensor:
        return (-(((x[..., None] - self.grid) / self.denominator).pow(2))).exp()

In [72]:
rbf = RadialBasisFunction()
print(rbf(Tensor([0., 1., 3., 10.])))

slf = SplineLinearFunction(1, 3, 1)
slf(Tensor([10]))

<Tensor <LB METAL (4, 8) float (<UnaryOps.EXP2: 1>, None)> on METAL with grad None>


<Tensor <LB METAL (3,) float ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),))> on METAL with grad None>

In [73]:
class FastKANLayer:
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = True,
        base_activation = Tensor.silu,
        spline_weight_init_scale: float = 0.1
    ) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim
        # normally you'd init layernorm here.
        # but because layernorm *isn't* a layer in tinygrad,
        # it's a function, I'm gonna hold off until the call
        self.layernorm = None
        if use_layernorm:
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinearFunction(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def __call__(
        self, x: Tensor, use_layernorm=True
    ) -> Tensor:
        if self.layernorm is not None and use_layernorm:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        spline_basis_view = spline_basis.view(*spline_basis.shape[:-2], -1)
        ret = self.spline_linear(spline_basis_view)
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''this function returns the learned curves in a FastKANLayer.
        input_index: the selected index of the input, in [0, input_dim) .
        output_index: the selected index of the output, in [0, output_dim) .
        num_pts: num of points sampled for the curve.
        num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve 
            will be calculate in the range of [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        w = self.spline_linear.linear_function.weight[
            output_index, input_index * ng : (input_index + 1) * ng
        ]   # num_grids,
        x = Tensor(np.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        ))   # num_pts, num_grids
        Tensor.no_grad = True
        y = (w * self.rbf(x)).sum(-1)
        Tensor.no_grad = False
        return x, y

In [74]:
fastKANLayer = FastKANLayer(2, 2)
fastKANLayer(Tensor([[1, 2], [1, 2]]))

<Tensor <LB METAL (2, 2) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>

In [78]:
from tinygrad import dtypes
d_in = 2
d_out = 3

layer = FastKANLayer(
    d_in, d_out,
    use_base_update=False,
    use_layernorm=False
)

x, y = layer.plot_curve(0, 1, num_pts=1000, num_extrapolate_bins=3)
x.shape, y.shape

import matplotlib.pyplot as plt

for i in range(d_in):
    for j in range(d_out):
        x, y = layer.plot_curve(i, j, 200, num_extrapolate_bins=3)
        plt.plot(x.numpy(), y.numpy(), label=r"$\phi_{" + f"{i},{j}" + r"}$")
plt.xlabel("$x$")
plt.ylabel("$\phi_{p,q}(x)$")
plt.legend(loc="upper right")

  plt.ylabel("$\phi_{p,q}(x)$")


CompileError: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:3:28: error: 'double' is not supported in Metal
kernel void E_2_4n2(device double* data0, const device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
                           ^
program_source:7:22: error: 'double' is not supported in Metal
  *(data0+alu0+1) = (double)(((0.5f*val0.y)+(-0.25f)));
                     ^
program_source:8:22: error: 'double' is not supported in Metal
  *(data0+alu0+2) = (double)(((0.5f*val0.z)+(-0.25f)));
                     ^
program_source:9:22: error: 'double' is not supported in Metal
  *(data0+alu0+3) = (double)(((0.5f*val0.w)+(-0.25f)));
                     ^
program_source:10:20: error: 'double' is not supported in Metal
  *(data0+alu0) = (double)(((0.5f*val0.x)+(-0.25f)));
                   ^
" UserInfo={NSLocalizedDescription=program_source:3:28: error: 'double' is not supported in Metal
kernel void E_2_4n2(device double* data0, const device float* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
                           ^
program_source:7:22: error: 'double' is not supported in Metal
  *(data0+alu0+1) = (double)(((0.5f*val0.y)+(-0.25f)));
                     ^
program_source:8:22: error: 'double' is not supported in Metal
  *(data0+alu0+2) = (double)(((0.5f*val0.z)+(-0.25f)));
                     ^
program_source:9:22: error: 'double' is not supported in Metal
  *(data0+alu0+3) = (double)(((0.5f*val0.w)+(-0.25f)));
                     ^
program_source:10:20: error: 'double' is not supported in Metal
  *(data0+alu0) = (double)(((0.5f*val0.x)+(-0.25f)));
                   ^
}

In [79]:
class FastKAN:
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = Tensor.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        self.layers = [
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ]

    def __call__(self, x:Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

In [80]:
fastKAN = FastKAN([2, 3, 1])
fastKAN(Tensor([[2, 1], [1, 2], [5, 9]]))

<Tensor <LB METAL (3, 1) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>

In [85]:
class AttentionWithFastKANTransform:
    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = FastKANLayer(q_dim, total_dim)
        self.linear_k = FastKANLayer(k_dim, total_dim)
        self.linear_v = FastKANLayer(v_dim, total_dim)
        self.linear_o = FastKANLayer(total_dim, q_dim)
        self.linear_g = None
        if self.gating:
            self.linear_g = FastKANLayer(q_dim, total_dim)
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def __call__(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        bias: Tensor = None,      # additive attention bias
    ) -> Tensor:

        wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm     # *q1hc
        wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1)    # *1khc
        att = (wq * wk).sum(-1).softmax(-2)     # *qkh
        del wq, wk
        if bias is not None:
            att = att + bias[..., None]

        wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1)     # *1khc
        o = (att[..., None] * wv).sum(-3)        # *qhc
        del att, wv

        o = o.view(*o.shape[:-2], -1)           # *q(hc)

        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)
            o = Tensor.sigmoid(g) * o

        # merge heads
        o = self.linear_o(o)
        return o

In [89]:
batch_shape = (1,)
num_q = 12
num_kv = 24
q_dim = k_dim = v_dim = 32
head_dim = 8
num_heads = 2
q = Tensor.randn(*batch_shape, num_q, q_dim)
k = Tensor.randn(*batch_shape, num_kv, k_dim)
v = Tensor.randn(*batch_shape, num_kv, v_dim)

fast_kan_att = AttentionWithFastKANTransform(q_dim, k_dim, v_dim, head_dim, num_heads, gating=True)
out = fast_kan_att(q, k, v, bias=None)
assert out.shape == q.shape, out.shape

bias = Tensor.rand(*batch_shape, num_q, num_kv)
out = fast_kan_att(q, k, v, bias=bias)
assert out.shape == q.shape, out.shape

print("test attention: attention with fast kan transform got correct shapes.")

test attention: attention with fast kan transform got correct shapes.


In [29]:
#Okay, we've ascended the hill!
# Time to train MNIST

from tinygrad.nn.datasets import mnist
X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)
print(X_test.shape, X_test.dtype, Y_test.shape, Y_test.dtype)
# (60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar

(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar
(10000, 1, 28, 28) dtypes.uchar (10000,) dtypes.uchar


In [30]:
model = FastKAN([28 * 28, 64, 10])

In [31]:
print(model(X_test.view(-1, 28 * 28)).argmax(axis=1))
print(Y_test.shape)

acc = (model(X_test.view(-1, 28 * 28)).argmax(axis=1) == Y_test).mean()
# NOTE: tinygrad is lazy, and hasn't actually run anything by this point
print(acc.item())  # ~10% accuracy, as expected from a random model

<Tensor <LB METAL (10000,) int (<BinaryOps.ADD: 1>, None)> on METAL with grad None>
(10000,)
0.08649999648332596


In [34]:
print(nn.state.get_parameters(model))
optim = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-3, weight_decay=1e-4)
batch_size = 128
def step():
  Tensor.training = True  # makes dropout work
  samples = Tensor.randint(batch_size, high=X_train.shape[0])
  X, Y = X_train[samples], Y_train[samples]
  optim.zero_grad()
  loss = (model(X.view(-1, 28 * 28)) + 1e-8).sparse_categorical_crossentropy(Y).backward()
  optim.step()
  return loss

[<Tensor <LB METAL (784,) float ShapeTracker(views=(View(shape=(784,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (784,) float ShapeTracker(views=(View(shape=(784,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (8,) float (<MetaOps.COPY: 3>, <buf real:True device:METAL size:8 dtype:dtypes.float offset:0>)> on METAL with grad <LB METAL (8,) float (<BinaryOps.ADD: 1>, None)>>, <Tensor <LB METAL (64, 6272) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:401408 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64, 784) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:50176 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64,) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:64 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64,) float ShapeTracker(views=(View(shape

In [35]:
import timeit
timeit.repeat(step, repeat=5, number=1)

[0.6120352919679135,
 0.22639587498269975,
 0.026706916047260165,
 0.03579529095441103,
 0.02406612504273653]

In [36]:
from tinygrad import TinyJit
jit_step = TinyJit(step)

In [37]:
import timeit
timeit.repeat(jit_step, repeat=5, number=1)

[0.08452891698107123,
 0.028836792102083564,
 0.025331582874059677,
 0.026197542203590274,
 0.02121345791965723]

In [70]:
from tinygrad import Context
with Context(BEAM=2):
  for step in range(7000):
    loss = jit_step()
    if step%100 == 0:
      Tensor.training = False
      acc = (model(X_test.view(-1, 28 * 28)).argmax(axis=1) == Y_test).mean().item()
      print(f"step {step:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")


step    0, loss 0.08, acc 96.13%
step  100, loss 0.05, acc 96.18%
step  200, loss 0.08, acc 96.10%
step  300, loss 0.05, acc 96.36%
step  400, loss 0.04, acc 96.64%
step  500, loss 0.10, acc 96.41%
step  600, loss 0.11, acc 96.29%
step  700, loss 0.10, acc 96.42%
step  800, loss 0.03, acc 96.57%
step  900, loss 0.04, acc 96.51%
step 1000, loss 0.03, acc 96.54%
step 1100, loss 0.09, acc 96.51%
step 1200, loss 0.02, acc 96.45%
step 1300, loss 0.04, acc 96.68%
step 1400, loss 0.11, acc 96.94%
step 1500, loss 0.06, acc 96.70%
step 1600, loss 0.03, acc 96.51%
step 1700, loss 0.08, acc 96.71%
step 1800, loss 0.02, acc 96.37%
step 1900, loss 0.04, acc 96.81%


KeyboardInterrupt: 

In [69]:
from tinygrad import TinyJit
with Context(BEAM=2):
    fklayer = FastKANLayer(100, 100)
    x = Tensor.randn(8, 100)
    rbf = RadialBasisFunction()
    def fklayer_step():
        fklayer(x, use_layernorm=False)
    def fklayer_sum_backwards_test():
        fklayer(x).sum().backward()
    def rbf_test():
        rbf(x)
    %timeit -r10 -n1000 TinyJit(fklayer_step)
    %timeit -r10 -n1000 TinyJit(fklayer_sum_backwards_test)
    %timeit -r10 -n1000 TinyJit(rbf_test)

345 ns ± 157 ns per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
304 ns ± 24.4 ns per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
334 ns ± 43 ns per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
