In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

from rich.pretty import pprint

is_dark = False
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [2]:
%cd ~/avid

/home/nmiklaucic/avid


In [3]:
dataset_splits = (1,2,3,4,5,6,7)
batch_size = 512

df = pd.read_feather('data/mpc_full_feats_scaled_split.feather')
df = df[df['dataset_split'].isin(dataset_splits)]
is_valid = df['Xshift_umap']
print(np.sum(is_valid), np.sum(~is_valid))
df = df.select_dtypes('number').drop(columns=['TSNE_x', 'TSNE_y', 'umap_x', 'umap_y', 'dataset_split'])
df

5637 35800


Unnamed: 0,0-norm,2-norm,3-norm,5-norm,7-norm,10-norm,minimum Number,maximum Number,range Number,mean Number,...,avg s valence electrons,avg p valence electrons,avg d valence electrons,avg f valence electrons,compound possible,max ionic char,avg ionic char,magmom_pa,bandgap,delta_e
0,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,2.864847,0.010989,-1.760193,1.580818,...,0.737291,-0.899290,2.895815,-0.492369,1.125155,-1.720739,-1.587831,-0.353321,0.0000,0.003319
2,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,-0.670320,-2.027018,-1.760193,-1.302375,...,0.737291,-1.778907,-0.937909,-0.492369,1.125155,-1.720739,-1.587831,-0.353340,0.0000,0.108143
3,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,4.671710,1.052637,-1.760193,3.054449,...,0.737291,-1.778907,-0.171164,5.574266,1.125155,-1.720739,-1.587831,-0.353355,0.0000,0.071216
4,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,0.193832,-1.528839,-1.760193,-0.597594,...,0.737291,0.859944,-0.937909,-0.492369,1.125155,-1.720739,-1.587831,1.805382,2.0113,3.509988
10,-2.544139,4.304959,4.099305,3.846319,3.738150,3.663704,0.743747,-1.211815,-1.760193,-0.149098,...,0.737291,-1.778907,-0.171164,-0.492369,1.125155,-1.720739,-1.587831,-0.353329,0.1069,0.114489
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84182,-0.302404,1.073151,1.155693,1.189103,1.189594,1.184777,-0.277523,-0.441901,-0.302464,-0.524370,...,0.416977,1.362582,-0.855758,-0.492369,1.125155,1.435748,1.399402,-0.353365,7.1220,-4.235145
84184,-0.302404,-0.229039,-0.384105,-0.537700,-0.599189,-0.629952,-0.748879,1.143215,1.689766,-0.751365,...,-1.056469,-0.019673,-0.784560,0.114294,1.125155,0.961428,1.501256,-0.353302,3.6820,-2.416089
84186,-0.302404,-0.032556,0.033257,0.130844,0.170592,0.193219,-0.356083,1.052637,1.349629,0.158443,...,0.737291,0.332174,-0.784560,0.720958,1.125155,0.989041,1.397058,-0.353266,3.8500,-3.617353
84187,-0.302404,-0.032556,0.033257,0.130844,0.170592,0.193219,-0.356083,1.097926,1.398220,0.158443,...,-0.159589,0.332174,-0.707885,0.720958,1.125155,1.100757,1.370412,-0.080313,0.0000,1.118394


In [4]:
from typing import Callable, Sequence
import jax
import jax.numpy as jnp
import jax.random as jr
import flax.linen as nn
from flax import struct
from jaxtyping import Float, Array
from eins import EinsOp
import functools as ft

from avid.layers import Identity
from avid.utils import debug_structure, flax_summary

jax.config.update('jax_debug_nans', False)

eps = 1e-12

@ft.partial(jax.jit, static_argnames=('k', 'extend'))
def eval_spline(x: Float[Array, "splines"], grid: Float[Array, "splines grid"], k: int = 0, extend: bool = True) -> Float[Array, "splines coefs=grid+k-1"]:
    """Evaluate x on B-spline bases."""
    if extend:
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        pad_start = jnp.tile(grid[:, [0]], (1, k)) - h
        pad_end = jnp.tile(grid[:, [-1]], (1, k)) + h
        grid = jnp.concat([pad_start, grid, pad_end], axis=1)    

    if x.ndim == 1:
        x = x[..., None]
        
    # debug_structure(x=x, grid=grid)

    if k == 0:
        value = ((x >= grid[:, :-1]) * (x < grid[:, 1:])).astype(x.dtype)
    else:        
        B_km1 = eval_spline(x, grid=grid, k=k - 1, extend=False)        
        value = (x - grid[:, :-(k + 1)])        
        value = value / (grid[:, k:-1] - grid[:, :-(k + 1)] + 1e-12)        
        value = value * B_km1[:, :-1]
        value = value + (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)] + eps) * B_km1[:, 1:]
    return value

@ft.partial(jax.jit, static_argnames='k')
def coef2curve(x_eval: Float[Array, "splines"], grid: Float[Array, "splines grid"], coef: Float[Array, "splines coefs"], k: int) -> Float[Array, "splines"]:
    """converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing
    up B_batch results over B-spline basis)."""    
    return jnp.einsum('sc,sc->s', coef, eval_spline(x_eval, grid, k))    


@ft.partial(jax.jit, static_argnames='k')
def curve2coef(x_eval: Float[Array, "samples splines"], y_eval: Float[Array, "samples splines"], grid: Float[Array, "splines grid"], k: int) -> Float[Array, "splines coefs"]:
    '''
    converting B-spline curves to B-spline coefficients using least squares.
    '''
    # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
    # debug_structure(b=eval_spline(x_eval, grid, k), x=x_eval, grid=grid)
    mat = jnp.permute_dims(jax.vmap(eval_spline, in_axes=(0, None, None))(x_eval, grid, k), (1, 0, 2))
    y = jnp.expand_dims(y_eval.T, -1)
    # debug_structure(mat=mat, x=x_eval, y=y)
    # m n, m k -> n k
    coef, _resid, _rank, _s = jax.vmap(jnp.linalg.lstsq, in_axes=0)(mat, y_eval.T)
    return coef


# x = np.random.randn(16, 5)
# y = np.random.randn(16, 5)
# grid = jnp.tile(jnp.linspace(-1, 1, 6)[None, :], (x.shape[1], 1))
# # y_eval = eval_spline(x, grid, k=3, extend=True)
# print(x.shape, y.shape, grid.shape)
# curve2coef(x, y, grid, k=3).shape


@struct.dataclass
class KANLayerOutput:
    y: Float[Array, "out_dim"]
    postacts: Float[Array, "out_dim in_dim"]
    postspline: Float[Array, "out_dim in_dim"]

# @struct.dataclass
class KANLayer(nn.Module):
    in_dim: int
    out_dim: int
    n_grid: int = 5
    order: int = 3
    kernel_init: Callable = nn.initializers.normal(stddev=0.1)
    resid_scale_trainable: bool = False
    resid_scale_init: Callable = nn.initializers.ones
    spline_scale_trainable: bool = False
    spline_scale_init: Callable = nn.initializers.ones
    base_act: Callable = nn.gelu
    grid_range: tuple[float, float] = (-1, 1)

    def setup(self):
        self.size = self.in_dim * self.out_dim
        self.grid = jnp.einsum('i,j->ij', jnp.ones(self.size), jnp.linspace(*self.grid_range, self.n_grid + 1))
        def spline_init(*args, **kwargs):
            noise = self.kernel_init(*args, **kwargs)          
            # debug_structure(grid=self.grid, noise=noise)  
            return curve2coef(self.grid.T, noise.T, self.grid, self.order)
        
        self.coef = self.param('coef', spline_init, (self.size, self.n_grid + 1))
        if self.resid_scale_trainable:
            self.resid_scale = self.param('resid_scale', self.resid_scale_init, (self.size,))
        else:
            self.resid_scale = 1

        if self.spline_scale_trainable:
            self.spline_scale = self.param('spline_scale', self.spline_scale_init, (self.size,))
        else:
            self.spline_scale = 1

    def full_output(self, x: Float[Array, "in_dim"]) -> KANLayerOutput:
        # splines: (out_dim in_dim)
        x = jnp.tile(x[None, ...], (self.out_dim, 1)).reshape(-1)
        y = coef2curve(x, self.grid, self.coef, self.order)
        postspline = y.reshape(self.out_dim, self.in_dim)

        y = self.resid_scale * self.base_act(x) + self.spline_scale * y
        postacts = y.reshape(self.out_dim, self.in_dim)
        
        y = EinsOp('(out in) -> out', reduce='mean', symbol_values={'out': self.out_dim})(y)

        return KANLayerOutput(y=y, postacts=postacts, postspline=postspline)
    
    def __call__(self, x: Float[Array, "batch in_dim"]) -> Float[Array, "batch out_dim"]:        
        out = lambda x: self.full_output(x).y
        return jax.vmap(out)(x)

    

# @struct.dataclass
class KAN(nn.Module):
    in_dim: int
    out_dim: int
    inner_dims: Sequence[int]

    use_layernorm: bool = True
    layer_templ: KANLayer = KANLayer(in_dim=1, out_dim=1)
    final_act: Callable = Identity()

    def setup(self):        
        norms = []
        layers = []
        in_dim = self.in_dim
        out_dims = tuple(self.inner_dims) + (self.out_dim,)        
        for out_dim in out_dims:
            layers.append(self.layer_templ.copy(in_dim=in_dim, out_dim=out_dim))
            norms.append(nn.LayerNorm() if self.use_layernorm else Identity())
            in_dim = out_dim

        self.norms = norms
        self.layers = layers
        self.network = nn.Sequential(self.layers)

    def full_outputs(self, x: Float[Array, "in_dim"]) -> tuple[Float[Array, "out_dim"], Sequence[KANLayerOutput]]:
        outputs = []
        curr_x = x
        for layer, norm in zip(self.layers, self.norms):
            curr_x = norm(curr_x)
            outputs.append(layer.full_output(curr_x))
            curr_x = outputs[-1].y
        y = self.final_act(curr_x)
        return y, outputs

    def __call__(self, x: Float[Array, "in_dim"], training: bool = False) -> Float[Array, "out_dim"]:
        out = lambda b: self.full_outputs(b)[0]
        return jax.vmap(out)(x)


in_dim = 4
rng = jr.key(0)
xtest = jr.normal(rng, (14, in_dim))

kan = KAN(in_dim=in_dim, out_dim=6, inner_dims=[5])
print(kan.tabulate(rng, xtest))


[3m                                  KAN Summary                                   [0m
┏━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath     [0m[1m [0m┃[1m [0m[1mmodule        [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams        [0m[1m [0m┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│           │ KAN            │ [2mfloat32[0m[14,4] │ [2mfloat32[0m[14,6]  │                │
├───────────┼────────────────┼───────────────┼────────────────┼────────────────┤
│ norms_0   │ LayerNorm      │ [2mfloat32[0m[4]    │ [2mfloat32[0m[4]     │ bias:          │
│           │                │               │                │ [2mfloat32[0m[4]     │
│           │                │               │                │ scale:         │
│           │                │               │                │ [2mfloat32[0m[4]     │
│           │

In [10]:
from jaxtyping import Bool

dtype = jnp.float32

@struct.dataclass
class TrainBatch:
    X: Float[Array, "batch in_dim"]
    delta_e: Float[Array, "batch"]
    mask: Bool[Array, "batch"]


datasets = []
for sub in df[is_valid], df[~is_valid]:
    Xy = jnp.array(sub.values, dtype=dtype)
    num_pad = -Xy.shape[0] % batch_size
    mask = jnp.concat([jnp.ones(Xy.shape[0]), jnp.zeros(num_pad)]).astype(jnp.bool)
    Xy = jnp.concat([Xy, Xy[:num_pad]])
    datasets.append((Xy, mask))

datasets = {'train': datasets[1], 'valid': datasets[0]}

steps_in_epoch = Xy.shape[0] // batch_size

def data_loader(split='train', infinite=False):
    data, mas = datasets[split]
    inds = np.arange(data.shape[0])

    perm = np.random.permutation(inds)

    data = jnp.array(data[perm])
    mas = jnp.array(mas[perm])
        
    first_time = True
    while first_time or infinite:
        first_time = False
        
        for i in range(0, data.shape[0], batch_size):
            yield TrainBatch(X=data[i:i+batch_size, :-1], delta_e=data[i:i+batch_size, -1] / 4, mask=mas[i:i+batch_size])

sample_batch = next(data_loader())
# debug_structure(sample_batch)


from avid.layers import LazyInMLP

sample_batch = next(data_loader())
kan = KAN(in_dim=sample_batch.X.shape[-1], out_dim=1, inner_dims=[64, 32], use_layernorm=True)
mlp = LazyInMLP(out_dim=1, inner_dims=[1024, 1024])

kwargs = {'console_kwargs': {'width': 200}, 'compute_flops': True, 'compute_vjp_flops': True}
flax_summary(kan, x=sample_batch.X, training=False, **kwargs)
flax_summary(mlp, x=sample_batch.X, training=False, **kwargs)


[3m                                                                                              KAN Summary                                                                                               [0m
                                                                                                                                                                                                        
 [1m [0m[1m          path[0m[1m [0m [1m [0m[1m                           module[0m[1m [0m [1m [0m[1m                       inputs[0m[1m [0m [1m [0m[1m                                 outputs[0m[1m [0m [1m [0m[1m           flops[0m[1m [0m [1m [0m[1m       vjp_flops[0m[1m [0m [1m [0m[1m                        params[0m[1m [0m 
 ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── 
                   

In [11]:
from tqdm import tqdm


import optax
from ml_collections import ConfigDict
from flax.training import train_state

start_frac = 0.1
end_frac = 0.3
base_lr = 5e-3
warmup = 10
n_epochs = 250
warmup_steps = steps_in_epoch * min(warmup, n_epochs // 4)
sched = optax.warmup_cosine_decay_schedule(
    init_value=start_frac * base_lr,
    peak_value=base_lr,
    warmup_steps=warmup_steps,
    decay_steps=steps_in_epoch * n_epochs,
    end_value=end_frac*base_lr
)

def create_train_state(model, rng):
    params = model.init(rng, sample_batch.X, training=True)['params']
    tx = optax.adamw(sched)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

steps_per_log = steps_in_epoch

@ft.partial(jax.jit, static_argnames='training')
def apply_model(state, batch: TrainBatch, training: bool):
    def loss_fn(params):
        yhat = state.apply_fn({'params': params}, batch.X, training=training)
        err = jnp.abs(jnp.squeeze(yhat) - batch.delta_e) * batch.mask
        return jnp.mean(err)
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    return grad, loss


def train_model(model: nn.Module):
    state = create_train_state(model, jr.key(np.random.randint(0, 1000)))
    print(model.__class__.__name__)
    epochs = []

    with tqdm(np.arange(n_epochs)) as bar:
        for epoch_i in bar:
            losses = []
            for batch in data_loader():
                grad, loss = apply_model(state, batch, training=True) 
                losses.append(loss)
                state = state.apply_gradients(grads=grad)

            train_loss = np.mean(losses)

            losses = []
            for batch in data_loader(split='valid'):
                grad, loss = apply_model(state, batch, training=False)            
                losses.append(loss)
                state = state.apply_gradients(grads=grad)

            valid_loss = np.mean(losses)
            epochs.append({'train': train_loss, 'valid': valid_loss})
            bar.set_description(f'Train: {train_loss:.03f}\tValid: {valid_loss:.03f}')

    return state, epochs

mlp_state, mlp_epochs = train_model(mlp)
kan_state, kan_epochs = train_model(kan)

LazyInMLP


  0%|          | 0/250 [00:00<?, ?it/s]

Train: 0.009	Valid: 0.012: 100%|██████████| 250/250 [04:24<00:00,  1.06s/it]


KAN


Train: 0.007	Valid: 0.010:  80%|████████  | 200/250 [04:11<01:01,  1.23s/it]

In [None]:
hist = []
for name, epochs in zip(('kan', 'mlp'), (kan_epochs, mlp_epochs)):
    epochs = pd.DataFrame(epochs)
    epochs['model'] = name    
    hist.append(epochs.reset_index().rename(columns={'index': 'epoch'}))

hist = pd.concat(hist).reset_index(drop=True)
hist = hist.melt(var_name='dataset', value_name='loss', id_vars=['epoch', 'model'])
hist

fig, ax = plt.subplots(figsize=(12, 5))
sns.lineplot(hist, x='epoch', y='loss', hue=hist[['model', 'dataset']].apply(' '.join, axis=1))
plt.ylim(0, hist.query('epoch > 50')['loss'].max())
rp.line_labels()

In [None]:
from avid.utils import debug_stat

coefs = kan_state.params['layers_0']['coef'].reshape(147, -1, 10)
plt.plot(coefs[-1, :6, :].T);

Some things to note from my first efforts training these KANs:
- I get better results adding LayerNorms in between the layers like you would in an MLP.
- I don't train the residual or spline scale factors after doing so, which reduces parameter count
  significantly and seems to have negligible impact on accuracy.
- I haven't implemented the learnable grid that the pykan repo has: it doesn't seem super important
  after normalization.
- The splines have a lot of hyperparameters to worry about, and I have to really sit down and write
  an accelerated version of B-splines so this runs faster. There are lots of other ways to
  parameterize flexible, smooth functions, maybe one of those would be better.
  
I'm quite impressed given that there's a lot of low-hanging fruit and it's already competitive with
MLPs, which have tons of work behind them.