# Discrete-KAN (Kolmogorov-Arnold Networks)
_Jan Pfeifer_, July 2024

The Kolmogorov–Arnold Networks paper [1] presented an alternative to classical neural networks by replacing a generic multi-variate function approximation (linear transformations followed by non-linearities in NN) by a sum of univariate functions, as depicted in the figure:

<div style="text-align: center;">
<img width="1163" alt="mlp_kan_compare" style="width:80%" src="https://github.com/KindXiaoming/pykan/assets/23551623/695adc2d-0d0b-4e4b-bcff-db2c8070f841"/>
</div>


The original work in [1] used [B-splines (wikipedia)](https://en.wikipedia.org/wiki/B-spline) as the generic univariate functions. Soon after, many proposals and follow-up works explored different univariate functions [2].

This notebook presents results with yet another class of univariate functions: **Piecewise Constant Functions** (**PCF** for short), which can also be described as "staircase functions". It also describes the input perturbation trick (from [3]) used to create a differentiable `SoftPCF` to train these functions, since their gradient with respect to the input is 0 or undefined and plain backpropagation won't work.

- [1] [KAN: Kolmogorov-Arnold Networks (arxiv)](https://arxiv.org/abs/2404.19756) - _Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljačić, Thomas Y. Hou, Max Tegmark_, 2024
  - See also [github.com/KindXiaoming/pykan](https://github.com/KindXiaoming/pykan) with the paper's code, and links to many tutorials.
- [2] ["Awesome KAN" - github/mintisan/awesome-kan](https://github.com/mintisan/awesome-kan) links various derived work.
- [3] [Learning Representations for Axis-Aligned Decision Forests through Input Perturbation (arxiv)](https://arxiv.org/abs/2007.14761) -
  _Sebastian Bruch, Jan Pfeifer, Mathieu Guillame-Bert_, 2021


## Motivation

PCFs, once trained, **require no multiplications for inference**. To calculate a **PCF** one only needs to find a value in a table and reduce-sum those values, which can be done using _ints_ (no floating point calculations needed).

This would replace multiplications by small table lookups.

Probably not a winning trade-off on modern CPUs, but still an interesting question.
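
To make the lookup-and-sum idea concrete, here is a minimal sketch in plain Go. It is not the GoMLX implementation: the `intPCF` type and the `nodeOutput` function are hypothetical names, and it assumes inputs, split points and control points were already quantized to integers by some earlier step that is not shown.

```go
package main

import (
	"fmt"
	"sort"
)

// intPCF is a hypothetical integer-only piecewise constant function.
// splits must be sorted ascending and len(values) == len(splits)+1.
type intPCF struct {
	splits []int32
	values []int32
}

// eval locates the bucket of x using comparisons only (binary search)
// and returns the value stored for that bucket. No multiplication is used.
func (f intPCF) eval(x int32) int32 {
	idx := sort.Search(len(f.splits), func(i int) bool { return f.splits[i] > x })
	return f.values[idx]
}

// nodeOutput computes one output node of a KAN-style layer: the sum of one
// univariate function per input.
func nodeOutput(fns []intPCF, inputs []int32) int32 {
	var sum int32
	for i, f := range fns {
		sum += f.eval(inputs[i])
	}
	return sum
}

func main() {
	fns := []intPCF{
		{splits: []int32{0, 10}, values: []int32{-3, 5, 7}},
		{splits: []int32{-5, 5}, values: []int32{1, 2, 4}},
	}
	fmt.Println(nodeOutput(fns, []int32{3, 9})) // 5 + 4 = 9
}
```

The only operations on the inference path are comparisons, table reads and integer additions.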

## Source Code

The code used for this notebook is in [github.com/gomlx/gomlx/pkg/ml/layers/kan](https://github.com/gomlx/gomlx/tree/pjrt/ml/layers/kan) 

It is written in [GoMLX](https://github.com/gomlx/gomlx), a machine-learning framework for [Go](https://go.dev/) using [OpenXLA](https://openxla.org), and the notebook uses [`gonb`](https://github.com/janpfeifer/gonb), a Jupyter notebook kernel for [Go](https://go.dev/).

For development, the following cell lets the notebook use a local copy of the libraries (as opposed to an official release from GitHub):

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb" "${HOME}/Projects/gopjrt"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".


## Piecewise Constant Functions (PCF)

A piecewise constant function (PCF) is defined by a set of $N$ control points, $O = \{o_0, ..., o_{N-1}\}$,
and a set of $N-1$ ordered split points, $S = \{s_0, ..., s_{N-2}\}$, such that $s_0 < s_1 < ... < s_{N-2}$.

The control points determine the output values of the function, and the split points determine where the function changes its output.

Notice that the first and last values of the PCF extend to infinity. The full definition of $PCF(x)$ can be given by:

$$
PCF(x) = 
\begin{cases}
o_0, & \text{if } x < s_0 \\\\
o_{i+1}, & \text{if } s_i \le x < s_{i+1} \text{ for } 0 \le i \le N-3 \\\\
o_{N-1}, & \text{if } x \ge s_{N-2} 
\end{cases}
$$
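
The definition above can be sketched in plain Go as follows. This is only an illustration, not the GoMLX `kan.PiecewiseConstantFunction` code: the function name `pcf` is made up here, and the example values are the ones used by `PCF1` in this notebook. The key observation is that the index of the selected control point equals the number of split points $s_i \le x$.

```go
package main

import (
	"fmt"
	"sort"
)

// pcf evaluates a piecewise constant function at x.
// splits must be strictly increasing and len(controls) == len(splits)+1.
func pcf(x float64, controls, splits []float64) float64 {
	// sort.Search returns the first index whose split point is > x, which is
	// also the count of split points <= x: 0 selects o_0, 1 selects o_1, and
	// len(splits) selects the last control point o_{N-1}.
	idx := sort.Search(len(splits), func(i int) bool { return splits[i] > x })
	return controls[idx]
}

func main() {
	controls := []float64{0.5, 1, 0, 0.2, 0, 1, 0.5}
	splits := []float64{0, 0.2, 0.4, 0.6, 0.8, 1}
	for _, x := range []float64{-1, 0, 0.5, 1, 2} {
		fmt.Printf("PCF(%g) = %g\n", x, pcf(x, controls, splits))
	}
}
```

Both tails are covered without special cases: any $x < s_0$ yields index 0 and any $x \ge s_{N-2}$ yields index $N-1$.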

Consider the following examples:

#### `PCF1`: a simple example with 7 control points

In [2]:
import (
    . "github.com/gomlx/gomlx/pkg/core/graph"
    "github.com/gomlx/gomlx/pkg/ml/layers/kan"
    dkan "github.com/gomlx/gomlx/examples/discretekan"
)

func PCF1(x *Node) *Node {
    g := x.Graph()
    controlPoints := Const(g, []float64{0.5, 1, 0, 0.2, 0, 1, 0.5})
    splitPoints := Const(g, []float64{0, 0.2, 0.4, 0.6, 0.8, 1})
    return kan.PiecewiseConstantFunction(x, controlPoints, splitPoints)
}

%%
dkan.Plot("PCF([0.5, 1, 0, 0.2, 0, 1, 0.5])", PCF1)

#### `PCFSin`: A discretized sine curve

Notice `dkan.Univariate` is defined as:

```type Univariate func(x *Node) *Node```

In [3]:
// PCFFromFunc converts the univariate function fn to a discretized PCF version of it.
func PCFFromFunc(fn dkan.Univariate, numControlPoints int) dkan.Univariate {
    return func (x *Node) *Node {
        g := x.Graph()
        splitPoints := Iota(g, shapes.Make(dtypes.Float64, numControlPoints-1), 0)
        splitPoints = MulScalar(splitPoints, 1/float64(numControlPoints-2)) // = [0, 1/(N-2), 2/(N-2), ..., 1.0]
        // Take control points at the middle of the split-points.
        left := GrowLeft(splitPoints, 0, 1, 0.0) // Prepend a 0.0 to splitPoints.
        right := GrowRight(splitPoints, 0, 1, 1.0) // Append a 1.0 to splitPoints.
        controlX := MulScalar(Add(right, left), 0.5)  // Take the mean of left and right.
        controlPoints := fn(controlX)
        return kan.PiecewiseConstantFunction(x, controlPoints, splitPoints)
    }
}

func Sin2πx(x *Node) *Node {
    return Sin(MulScalar(x, 2.0*math.Pi))
}

var PCFSin = PCFFromFunc(Sin2πx, 10)

%%
dkan.Plot("PCF of Sin(2πx);PCFSin(x);Sin(2πx)", PCFSin, Sin2πx)
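
To see which values `PCFFromFunc` ends up with, the same arithmetic can be reproduced in plain Go, without GoMLX. The helper name `uniformGrid` is hypothetical; it only mirrors the `Iota`/`GrowLeft`/`GrowRight` steps of the cell above for a given number of control points.

```go
package main

import "fmt"

// uniformGrid returns n-1 split points evenly spaced in [0, 1] and the n
// positions where the control points are sampled: the mean of each split
// point and its left neighbor, with 0 and 1 acting as the outer neighbors.
// It requires n >= 3.
func uniformGrid(n int) (splits, controlX []float64) {
	splits = make([]float64, n-1)
	for i := range splits {
		splits[i] = float64(i) / float64(n-2)
	}
	controlX = make([]float64, n)
	for i := range controlX {
		left, right := 0.0, 1.0
		if i > 0 {
			left = splits[i-1]
		}
		if i < n-1 {
			right = splits[i]
		}
		controlX[i] = (left + right) / 2
	}
	return splits, controlX
}

func main() {
	splits, controlX := uniformGrid(10)
	fmt.Println("split points:  ", splits)
	fmt.Println("control x's:   ", controlX)
}
```

Note that the first and last sampling positions are exactly 0 and 1, so for `Sin2πx` both tails of `PCFSin` sit at (approximately) zero.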


## Soft Piecewise Constant Functions by Input Perturbation

While we want to use PCF in our model so we can do inference without multiplications (see [Motivation](#motivation) section), they are not differentiable with respect to the inputs.

So we use a technique from [3] (Section 4.2): a controlled softening of the PCF that makes it differentiable, so that we can train PCFs with gradient descent. We can schedule the amount of smoothing during training to make it gradually converge back to a proper PCF.

The core idea is to replace $x$ with a perturbed distribution $\mathcal{D}(x)$ around it, with $\int_{-\infty}^{+\infty}{\mathcal{D}(x) \, dx} = 1$. [3] suggests a Gaussian distribution, but we propose a triangular distribution, since in practice it works as well (we experimented with it in [3]), and it is simpler to code and faster to run.

For an arbitrary $\mathcal{D}(x)$, given a $PCF(x)$ we define:

$$
\begin{aligned}
\text{SoftPCF}(x)_{\mathcal{D}} &= \mathbb{E}_{x \sim \mathcal{D}(x)}[\text{PCF}(x)] \\
&= \int_{-\infty}^{\infty} \text{PCF}(x) \cdot \mathcal{D}(x) \, dx \\
&= \int_{-\infty}^{s_0} \text{PCF}(x) \cdot \mathcal{D}(x) \, dx + \sum_{i=0}^{N-3} \int_{s_i}^{s_{i+1}} \text{PCF}(x) \cdot \mathcal{D}(x) \, dx + \int_{s_{N-2}}^{\infty} \text{PCF}(x) \cdot \mathcal{D}(x) \, dx \\
&= o_0 \int_{-\infty}^{s_0}  \mathcal{D}(x) \, dx + \sum_{i=0}^{N-3} o_{i+1} \int_{s_i}^{s_{i+1}} \mathcal{D}(x) \, dx + o_{N-1} \int_{s_{N-2}}^{\infty} \mathcal{D}(x) \, dx \\
&= o_0 \cdot p_{-1} + \sum_{i=0}^{N-3} o_{i+1} \cdot p_i + o_{N-1} \cdot p_{N-2}
\end{aligned}
$$

Here $p_{-1}$, $p_i$ and $p_{N-2}$ denote the probability mass of $\mathcal{D}(x)$ in $(-\infty, s_0)$, $[s_i, s_{i+1})$ and $[s_{N-2}, \infty)$, respectively.

Now let's expand this assuming $\mathcal{D}(x)$ is a triangular distribution with a base of length $L$ (and height $2/L$). We can define the probability density function (PDF) centered at 0 as:

$$
f(x) = 
\begin{cases}
0, & x < -\frac{L}{2} \\
\frac{2}{L}\left(1 + \frac{2x}{L}\right), & -\frac{L}{2} \le x < 0 \\
\frac{2}{L}\left(1 - \frac{2x}{L}\right), & 0 \le x \le \frac{L}{2} \\
0, & x > \frac{L}{2}
\end{cases}
$$

And the cumulative distribution function (CDF):

$$
F(x) = 
\begin{cases}
0, & x < -\frac{L}{2} \\
\frac{1}{2}\left(1 + \frac{2x}{L}\right)^2, & -\frac{L}{2} \le x < 0 \\
1 - \frac{1}{2}\left(1 - \frac{2x}{L}\right)^2, & 0 \le x \le \frac{L}{2} \\
1, & x > \frac{L}{2}
\end{cases}
$$

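With that we can further derive our $\text{SoftPCF}$ function. Since $F$ is centered at 0 while the perturbation is centered at the input $x$, the CDF is evaluated at $s_i - x$:

$$
\begin{aligned}
\text{SoftPCF}(x)_{\mathcal{D}} = & o_0 \cdot [F(s_0 - x) - F(-\infty)]  \\
&+ \sum_{i=0}^{N-3} o_{i+1} \cdot [F(s_{i+1} - x) - F(s_i - x)] \\
&+ o_{N-1} \cdot [F(\infty) - F(s_{N-2} - x)] 
\end{aligned}
$$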

As an extra benefit from this formulation, $\text{SoftPCF}(x)$ is also differentiable with respect to the split points $S = \{s_0, ..., s_{N-2}\}$, meaning they can be freed and trained along with the model (with some proper regularization to keep them ordered).
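
The formula can be sketched in plain Go as below. This is an illustration under assumptions, not the `kan.PiecewiseConstantFunctionWithInputPerturbation` implementation: `softPCF` and `triCDF` are made-up names, the perturbation is taken to be triangular and centered at the input, and `l` is the base length $L$ from the text, which may not map one-to-one to the library's softness parameter.

```go
package main

import "fmt"

// triCDF is the CDF of a triangular distribution with base l centered at 0.
func triCDF(x, l float64) float64 {
	switch {
	case x < -l/2:
		return 0
	case x < 0:
		u := 1 + 2*x/l
		return 0.5 * u * u
	case x <= l/2:
		u := 1 - 2*x/l
		return 1 - 0.5*u*u
	default:
		return 1
	}
}

// softPCF returns the expected value of the PCF when its input is perturbed
// by triangular noise of base l centered at x. Each control point is weighted
// by the probability mass that falls in its interval.
// splits must be strictly increasing and len(controls) == len(splits)+1.
func softPCF(x float64, controls, splits []float64, l float64) float64 {
	prev := 0.0 // F(-infinity)
	sum := 0.0
	for i, s := range splits {
		cur := triCDF(s-x, l)
		sum += controls[i] * (cur - prev)
		prev = cur
	}
	// Last interval: from the last split point to +infinity, where F = 1.
	sum += controls[len(controls)-1] * (1 - prev)
	return sum
}

func main() {
	controls := []float64{0.5, 1, 0, 0.2, 0, 1, 0.5}
	splits := []float64{0, 0.2, 0.4, 0.6, 0.8, 1}
	for _, l := range []float64{0.05, 0.1, 0.2} {
		fmt.Printf("L=%.2f:", l)
		for _, x := range []float64{0.1, 0.18, 0.2, 0.22} {
			fmt.Printf("  SoftPCF(%.2f)=%.4f", x, softPCF(x, controls, splits, l))
		}
		fmt.Println()
	}
}
```

Far from any split point (more than $L/2$ away) the result equals the plain PCF value, and exactly at a split point it is the average of the two neighboring control points, provided no other split point lies within $L/2$.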

Some examples follow.



#### `SoftPCF1`

Same example function as `PCF1` but smoothed at different levels:

In [4]:
func BuildSoftPCF1(softness float64) dkan.Univariate {
    return func(x *Node) *Node {
        g := x.Graph()
        controlPoints := Const(g, []float64{0.5, 1, 0, 0.2, 0, 1, 0.5})
        splitPoints := Const(g, []float64{0, 0.2, 0.4, 0.6, 0.8, 1})
        s := Const(g, softness)
        return kan.PiecewiseConstantFunctionWithInputPerturbation(x, controlPoints, splitPoints, kan.PerturbationNormal, s)
    }
}

%%
dkan.Plot("SoftPCF1 with various softness-values;PCF1;SoftPCF1(0.05);SoftPCF1(0.10);SoftPCF1(0.20)", 
          PCF1, BuildSoftPCF1(0.05), BuildSoftPCF1(0.10), BuildSoftPCF1(0.20))


#### `SoftPCFSin`

In this example we soften the original PCFSin curve.

In [5]:
func SoftPCFFromFunc(fn dkan.Univariate, numControlPoints int, softness float64) dkan.Univariate {
    return func (x *Node) *Node {
        g := x.Graph()
        splitPoints := Iota(g, shapes.Make(dtypes.Float64, numControlPoints-1), 0)
        splitPoints = MulScalar(splitPoints, 1/float64(numControlPoints-2)) // = [0, 1/(N-2), 2/(N-2), ..., 1.0]
        // Take control points at the middle of the split-points.
        left := GrowLeft(splitPoints, 0, 1, 0.0) // Prepend a 0.0 to splitPoints.
        right := GrowRight(splitPoints, 0, 1, 1.0) // Append a 1.0 to splitPoints.
        controlX := MulScalar(Add(right, left), 0.5)  // Take the mean of left and right.
        controlPoints := fn(controlX)
        softnessConst := Const(g, softness)
        return kan.PiecewiseConstantFunctionWithInputPerturbation(x, controlPoints, splitPoints, kan.PerturbationNormal, softnessConst)
    }
}

%%
dkan.Plot("SoftPCF of Sin(2πx);PCFSin(x);SoftPCFSin(2πx, 0.03);SoftPCFSin(2πx, 0.1);SoftPCFSin(2πx, 0.2)", 
          PCFSin, 
          SoftPCFFromFunc(Sin2πx, 10, 0.03), 
          SoftPCFFromFunc(Sin2πx, 10, 0.1),
          SoftPCFFromFunc(Sin2πx, 10, 0.2),
)

## Experimental Results

### Training UCI-Adult with Discrete-KAN

It trains more slowly than a regular FNN or KAN, but it reaches approximately the same results. See the code in [GoMLX's examples/adult/demo/main.go](https://github.com/gomlx/gomlx/blob/pjrt/examples/adult/demo/main.go):

```shell
$ go run . -set 'kan=true;kan_discrete=true;kan_num_points=10;kan_num_hidden_layers=3;kan_num_hidden_nodes=16;l2_regularization=1e-3;l1_regularization=0;activation=relu;plots=false;train_steps=20000;kan_discrete_softness=0.1'
Training (20000 steps):  100% [========================================] (521 steps/s)
        ╭────────────────────────────────────┬───────────────╮
        │                        Global Step │ 19999 / 20000 │
        │          Batch Loss+Regularization │ 0.215         │
        │ Moving Average Loss+Regularization │ 0.298         │
        │                Moving Average Loss │ 0.281         │
        │            Moving Average Accuracy │ 86.99%        │
        ╰────────────────────────────────────┴───────────────╯

        [Step 20000] median train step: 1698 microseconds

Results on batched train:
        Mean Loss+Regularization (#loss+): 0.294
        Mean Loss (#loss): 0.278
        Mean Accuracy (#acc): 87.12%
Results on test:
        Mean Loss+Regularization (#loss+): 0.314
        Mean Loss (#loss): 0.298
        Mean Accuracy (#acc): 86.95%
```
