# XAI CODE DEMO

## Explainable AI Specialization on Coursera

If you experience high latency while running this notebook, you can open it in Google Colab:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/explainable-machine-learning/interpretable-ml/blob/main/kan_interpretability.ipynb)

# Kolmogorov-Arnold Network (KAN)

* Paper: [Liu et al., 2024](https://arxiv.org/pdf/2404.19756)
* KANs have no linear weights at all: every weight parameter is replaced by a univariate function parametrized as a spline
* KANs can be intuitively visualized and can easily interact with human users
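
For reference, here is the construction in the paper's notation (Liu et al., 2024). It starts from the Kolmogorov-Arnold representation of a continuous function f on [0,1]^n. It then generalizes this to layers in which output node j sums learnable univariate functions of its inputs, where an MLP would use scalar weights w_{j,i}. Each activation combines a SiLU base function with a B-spline whose coefficients c_i are trained:

$$
f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
$$

$$
x_{l+1,j} = \sum_{i=1}^{n_l} \phi_{l,j,i}(x_{l,i}), \qquad j = 1, \dots, n_{l+1}
$$

$$
\phi(x) = w_b\,\mathrm{silu}(x) + w_s \sum_{i} c_i B_i(x), \qquad \mathrm{silu}(x) = \frac{x}{1 + e^{-x}}
$$

A `[n, 2n+1, 1]` KAN has exactly the shape of the representation theorem. Deeper KANs stack more such layers.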

#### Training a KAN:
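1. Randomly initialize the B-spline coefficients of every activation function in every layer
2. Pass the inputs X forward through the network
3. Compute the loss with respect to the ground truth
4. Backpropagate the loss to get the gradients
5. Update the B-spline coefficients
6. Repeat steps 2 to 5 for a fixed number of steps or until the loss stops improving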
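
The steps above can be sketched for a single edge. Here one univariate function φ(x) = Σ c_i B_i(x) is fit to sin(πx) on [-1, 1]. This is a minimal sketch under several assumptions:

* uniform knots extended by k on each side of the grid,
* a mean-squared-error loss, and
* plain gradient descent with the gradient written out by hand, which is exact here because φ is linear in its coefficients.

pykan itself relies on PyTorch autograd and optimizers such as LBFGS, and it adds a SiLU residual branch. The names `make_knots`, `bspline_basis`, and `train_spline` are hypothetical.

```python
import math
import random


def make_knots(grid=5, k=3, lo=-1.0, hi=1.0):
    """Uniform knots: `grid` intervals on [lo, hi], extended by k knots on each side."""
    h = (hi - lo) / grid
    return [lo + (i - k) * h for i in range(grid + 2 * k + 1)]


def bspline_basis(x, knots, k):
    """Evaluate every degree-k B-spline basis function at x (Cox-de Boor recursion)."""
    # Degree 0: indicator function of each knot interval
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for p in range(1, k + 1):
        basis = [
            (x - knots[i]) / (knots[i + p] - knots[i]) * basis[i]
            + (knots[i + p + 1] - x) / (knots[i + p + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(basis) - 1)
        ]
    return basis  # length: len(knots) - k - 1 == grid + k


def train_spline(target, grid=5, k=3, n_points=50, lr=1.5, steps=2000, seed=0):
    knots = make_knots(grid, k)
    xs = [-1.0 + 2.0 * s / (n_points - 1) for s in range(n_points)]
    ys = [target(x) for x in xs]
    basis = [bspline_basis(x, knots, k) for x in xs]  # B_i(x) per sample, computed once

    # 1. Randomly initialize the spline coefficients
    rng = random.Random(seed)
    coef = [rng.gauss(0.0, 0.1) for _ in range(grid + k)]

    history = []
    for _ in range(steps):  # 6. Repeat
        # 2. Forward pass: phi(x) = sum_i c_i * B_i(x)
        preds = [sum(c * b for c, b in zip(coef, row)) for row in basis]
        # 3. Mean squared error against the ground truth
        errors = [p - y for p, y in zip(preds, ys)]
        history.append(sum(e * e for e in errors) / n_points)
        # 4. Gradient of the MSE with respect to each coefficient
        grads = [2.0 / n_points * sum(e * row[i] for e, row in zip(errors, basis))
                 for i in range(len(coef))]
        # 5. Gradient-descent update of the coefficients
        coef = [c - lr * g for c, g in zip(coef, grads)]
    return coef, history


coef, history = train_spline(lambda x: math.sin(math.pi * x))
print(f"initial MSE {history[0]:.4f}, final MSE {history[-1]:.6f}")
```

With `grid=5` and `k=3` the edge has 8 coefficients, matching the `grid + k` count used in the paper.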

#### Code Demo
The original paper authors released [pykan](https://kindxiaoming.github.io/pykan/index.html), a Python library for KANs. This demo is based on the documentation provided with the library, specifically ["Getting Started with KANs"](https://kindxiaoming.github.io/pykan/intro.html#get-started-with-kans).

---

In [None]:
from kan import *
import torch

#### Initialize a KAN
* width = [2, 5, 1] - 2D input, 5 hidden neurons, 1D output
* k=3 - use cubic splines
* grid=5 - use 5 grid intervals
* seed=0 - fix the random seed so the initialization is reproducible


In [None]:
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
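
To make the configuration concrete, the sketch below counts the learnable spline coefficients implied by `width`, `grid`, and `k`. It assumes that each edge carries one B-spline with `grid + k` basis coefficients, the count used in the paper. It ignores the extra per-edge scale parameters that pykan also stores. `count_spline_coefficients` is a hypothetical helper, not part of pykan.

```python
def count_spline_coefficients(width, grid, k):
    """Count edges and B-spline coefficients in a fully connected KAN."""
    # Every pair of neurons in adjacent layers is joined by one learnable univariate function
    edges = sum(n_in * n_out for n_in, n_out in zip(width[:-1], width[1:]))
    # A degree-k spline on `grid` intervals has grid + k basis functions
    per_edge = grid + k
    return edges, per_edge, edges * per_edge


edges, per_edge, total = count_spline_coefficients([2, 5, 1], grid=5, k=3)
print(edges, per_edge, total)  # 15 8 120
```

For comparison, an MLP with the same `[2, 5, 1]` width has 15 weights plus 6 biases. The KAN trades each scalar weight for a small spline.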

#### Create a dataset

We are going to create a dataset for this code demonstration. Its labels are computed from the function **f(x,y)** below, which gives the KAN a known target to learn during training.

**f(x,y) = exp(sin(πx) + y²)**

In [None]:
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
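
Below is a plain-Python stand-in for this cell. It assumes that `create_dataset` samples each input uniformly from [-1, 1] (which we believe is pykan's default range) and evaluates `f` on every row. The sample sizes, the seed, and the name `make_dataset` are our own choices. The `[:,[0]]` indexing in the original keeps a column dimension, which is why each label below is a one-element list. The last line spot-checks `f` at points where the answer is known by hand.

```python
import math
import random


def f(x, y):
    """Scalar version of the target above: exp(sin(pi*x) + y^2)."""
    return math.exp(math.sin(math.pi * x) + y ** 2)


def make_dataset(fn, n_train=1000, n_test=1000, lo=-1.0, hi=1.0, seed=0):
    """Hypothetical stand-in for create_dataset: uniform inputs in [lo, hi]^2, labels fn(x, y)."""
    rng = random.Random(seed)

    def split(n):
        inputs = [[rng.uniform(lo, hi), rng.uniform(lo, hi)] for _ in range(n)]
        labels = [[fn(x, y)] for x, y in inputs]  # one label per row, like shape (n, 1)
        return inputs, labels

    train_input, train_label = split(n_train)
    test_input, test_label = split(n_test)
    return {"train_input": train_input, "train_label": train_label,
            "test_input": test_input, "test_label": test_label}


data = make_dataset(f)
print(len(data["train_input"]), len(data["train_input"][0]), len(data["train_label"][0]))
print(f(0.0, 0.0), f(0.5, 0.0))  # exp(0), which is 1, and exp(1), which is e
```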

#### Plot KAN at initialization
This is what our KAN looks like before training. The 2D input is at the bottom, the 5 hidden neurons are in the middle, and the 1D output is at the top. The forward pass in the first line caches the activations that `plot` draws.

In [None]:
model(dataset['train_input']);
model.plot(beta=100)

#### Train KAN

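Train for 20 steps with the LBFGS optimizer. `lamb` sets the overall strength of the sparsity regularization and `lamb_entropy` weights its entropy term. Together they push the network toward a few dominant activations, which makes the pruning step below effective.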


In [None]:
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
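
In the paper, the sparsity penalty for a layer Φ is λ(μ₁|Φ|₁ + μ₂S(Φ)). The terms are defined as follows:

* |φ|₁ is the mean absolute output of one activation over the batch.
* |Φ|₁ sums those norms over the layer.
* S(Φ) is the entropy of the normalized norms.

The total loss adds this penalty, summed over layers, to the prediction loss. Here `lamb` appears to play the role of λ and `lamb_entropy` the role of μ₂. The sketch follows the paper's definitions on hypothetical activation values. pykan's internal implementation may normalize the terms differently, and the helper names are ours.

```python
import math


def l1_norm(samples):
    """|phi|_1: mean absolute value of one activation over the input samples."""
    return sum(abs(v) for v in samples) / len(samples)


def layer_penalty(edge_outputs, mu_l1=1.0, mu_entropy=1.0):
    """mu_1 * |Phi|_1 + mu_2 * S(Phi) for one layer, given each edge's outputs on a batch."""
    norms = [l1_norm(samples) for samples in edge_outputs]
    total = sum(norms)  # |Phi|_1: sum of the edge norms
    # S(Phi): entropy of the normalized norms, low when a few edges dominate
    entropy = -sum(n / total * math.log(n / total) for n in norms if n > 0)
    return mu_l1 * total + mu_entropy * entropy, total, entropy


# Hypothetical layers with four edges each, evaluated on a batch of three inputs
balanced = [[0.5, -0.5, 0.5]] * 4
dominated = [[2.0, -2.0, 2.0], [0.01, 0.0, -0.01], [0.0, 0.0, 0.0], [0.01, 0.01, 0.0]]

for name, layer in [("balanced", balanced), ("dominated", dominated)]:
    # mu_2 = 10 mirrors lamb_entropy=10. and lamb=0.01 scales the whole penalty
    penalty, l1, ent = layer_penalty(layer, mu_entropy=10.0)
    print(f"{name}: |Phi|_1={l1:.3f}, S={ent:.3f}, lamb*penalty={0.01 * penalty:.4f}")
```

The entropy term is what rewards concentrating activity on a few edges. The balanced layer has the maximum entropy log 4, while the dominated one is close to zero.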

In [None]:
model.plot()

#### Prune KAN and replot (keep original shape)

In [None]:
model.prune()
model.plot()

#### Prune KAN and replot (get a smaller shape)

In [None]:
model = model.prune()
model(dataset['train_input']);  # forward pass to refresh the activations that plot() draws
model.plot()
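
The paper decides which hidden nodes to keep by giving each node two scores:

* an incoming score from its strongest incoming edge, I = max_k |φ_{l-1,i,k}|₁, and
* an outgoing score from its strongest outgoing edge, O = max_j |φ_{l+1,j,i}|₁.

A node is kept only if both scores exceed a threshold θ, which is 10⁻² by default in the paper. Below is a sketch of that rule for one hidden layer with made-up edge norms. `important_hidden_nodes` is a hypothetical helper, and pykan's `prune` has its own options that may differ.

```python
def important_hidden_nodes(in_norms, out_norms, theta=1e-2):
    """Indices of hidden nodes whose incoming AND outgoing scores exceed theta.

    in_norms[k][i]:  |phi|_1 of the edge from input k to hidden node i
    out_norms[i][j]: |phi|_1 of the edge from hidden node i to output j
    """
    kept = []
    for i in range(len(out_norms)):
        incoming = max(row[i] for row in in_norms)  # strongest incoming edge
        outgoing = max(out_norms[i])  # strongest outgoing edge
        if incoming > theta and outgoing > theta:
            kept.append(i)
    return kept


in_norms = [  # input k -> hidden i (hypothetical values)
    [0.9, 0.001, 0.3, 0.002, 0.004],
    [0.8, 0.003, 0.001, 0.5, 0.002],
]
out_norms = [[1.2], [0.5], [0.004], [0.7], [0.001]]  # hidden i -> output 0
kept = important_hidden_nodes(in_norms, out_norms)
print(kept)  # [0, 3]
```

Node 2 is dropped even though one of its inputs is strong, because nothing downstream uses it.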

#### Continue training
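Remember that a KAN can keep training from its current parameters! Here we run 50 more LBFGS steps on the pruned model. This call does not set `lamb`, whose pykan default is 0, so the sparsity penalty is off and the remaining splines can focus on fitting the data.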

*Question: How have the splines changed after further training?*

In [None]:
model.fit(dataset, opt="LBFGS", steps=50);

In [None]:
model.plot()

#### Set activation functions to be symbolic

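We can do this in one of two ways:

* manually, by naming the function for each activation, or
* automatically, by letting pykan pick the best-fitting function from a library of candidates.

*Try it yourself: switch to manual mode. How does the output differ from the automatic mode?*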

In [None]:
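mode = "auto"  # change to "manual" to choose each function yourself

if mode == "manual":
    # fix_symbolic(l, i, j, name): set the activation from neuron i in layer l
    # to neuron j in layer l+1; indices assume pruning left one hidden neuron
    model.fix_symbolic(0,0,0,'sin');  # input x -> hidden: sin
    model.fix_symbolic(0,1,0,'x^2');  # input y -> hidden: x^2
    model.fix_symbolic(1,0,0,'exp');  # hidden -> output: exp
elif mode == "auto":
    # let pykan pick, for every activation, the best-fitting function in `lib`
    lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
    model.auto_symbolic(lib=lib)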
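
Roughly speaking, symbolic fitting asks which candidate f makes a learned activation look like c·f(a·x + b) + d. The sketch below is a simplified, hypothetical version of that search with three parts:

* a coarse grid over (a, b),
* closed-form least squares for (c, d), and
* R² to rank the candidates.

The names, the grid, and the candidate subset are ours, and pykan's `auto_symbolic` may search differently. The sample activation mimics the edge from input y, which should learn a y² shape.

```python
import math

# Hypothetical subset of the `lib` names above (functions defined on all reals)
CANDIDATES = {
    "x": lambda u: u,
    "x^2": lambda u: u ** 2,
    "x^3": lambda u: u ** 3,
    "exp": math.exp,
    "tanh": math.tanh,
    "sin": math.sin,
    "abs": abs,
}


def fit_affine(g_vals, ys):
    """Least-squares c, d for ys ~ c * g + d; returns (c, d, r2)."""
    n = len(ys)
    mg, my = sum(g_vals) / n, sum(ys) / n
    var_g = sum((g - mg) ** 2 for g in g_vals)
    if var_g < 1e-12:  # constant feature: nothing to fit
        return 0.0, my, 0.0
    c = sum((g - mg) * (y - my) for g, y in zip(g_vals, ys)) / var_g
    d = my - c * mg
    ss_res = sum((y - (c * g + d)) ** 2 for g, y in zip(g_vals, ys))
    ss_tot = sum((y - my) ** 2 for y in ys)
    if ss_tot == 0.0:  # constant target: R^2 undefined
        return c, d, 0.0
    return c, d, 1.0 - ss_res / ss_tot


def suggest_symbolic(xs, ys, a_grid=(-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0),
                     b_grid=(-1.0, -0.5, 0.0, 0.5, 1.0)):
    """Return (name, r2, (a, b, c, d)) of the best c*f(a*x + b) + d over the grid."""
    best = (None, -math.inf, None)
    for name, fn in CANDIDATES.items():
        for a in a_grid:
            for b in b_grid:
                c, d, r2 = fit_affine([fn(a * x + b) for x in xs], ys)
                if r2 > best[1]:
                    best = (name, r2, (a, b, c, d))
    return best


xs = [-1.0 + 2.0 * s / 100 for s in range(101)]
ys = [2.0 * x ** 2 + 1.0 for x in xs]  # hypothetical samples of a learned activation
name, r2, params = suggest_symbolic(xs, ys)
print(name, round(r2, 6), params)
```

Several (a, b) pairs fit `x^2` exactly here, so the reported parameters depend on search order. This is one reason the affine parameters are trained further afterwards.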

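#### Keep training!

With the activations now fixed to symbolic functions, further training mainly fine-tunes the affine parameters of each one, the a, b, c, d in c·f(a·x + b) + d.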

In [None]:
model.fit(dataset, opt="LBFGS", steps=50);

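#### Get the symbolic formula

`symbolic_formula()` returns the fitted formulas together with their input variables. `[0][0]` selects the formula for the first (here, the only) output.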

In [None]:
model.symbolic_formula()[0][0]
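
One way to sanity-check the result is to evaluate the formula against the true f on fresh points. In practice, the SymPy expression could be turned into a Python function with `sympy.lambdify`. To stay dependency-free, the sketch below types in a hypothetical formula by hand, with π rounded to 3.1416 to mimic the rounded constants a fitted formula may show. It is not actual pykan output, and the helper names are ours.

```python
import math
import random


def target(x1, x2):
    """Ground truth from the dataset cell: exp(sin(pi*x1) + x2^2)."""
    return math.exp(math.sin(math.pi * x1) + x2 ** 2)


def recovered(x1, x2):
    """Hypothetical recovered formula with a rounded constant (not real pykan output)."""
    return math.exp(x2 ** 2 + math.sin(3.1416 * x1))


def max_abs_error(f, g, n=1000, seed=0):
    """Largest |f - g| over n points sampled uniformly from [-1, 1]^2."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(n):
        x1, x2 = rng.uniform(-1, 1), rng.uniform(-1, 1)
        worst = max(worst, abs(f(x1, x2) - g(x1, x2)))
    return worst


err = max_abs_error(target, recovered)
print(f"max |f - formula| on [-1, 1]^2: {err:.2e}")
```

A small maximum error suggests that the symbolic formula, and not only the spline network, has captured f on the training range.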