# Hello, KAN!

### Kolmogorov-Arnold representation theorem

The Kolmogorov-Arnold representation theorem states that if $f$ is a multivariate continuous function
on a bounded domain, then it can be written as a finite composition of continuous functions of a
single variable and the binary operation of addition. More specifically, for a continuous $f : [0,1]^n \to \mathbb{R}$,


$$f(x) = f(x_1,...,x_n)=\sum_{q=1}^{2n+1}\Phi_q(\sum_{p=1}^n \phi_{q,p}(x_p))$$

where $\phi_{q,p}:[0,1]\to\mathbb{R}$ and $\Phi_q:\mathbb{R}\to\mathbb{R}$. In a sense, Kolmogorov and Arnold showed that the only true multivariate function is addition, since every other continuous function can be written using univariate functions and sums. However, the univariate functions in this 2-layer, width-$(2n+1)$ Kolmogorov-Arnold representation may not be smooth, a consequence of its limited expressive power. We augment its expressive power by generalizing it to arbitrary depths and widths.
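
As a small illustration of the formula above (our own example, not part of the theorem's proof), the product of two variables can be written with only two of the $2n+1=5$ outer terms, taking the remaining $\Phi_q$ to be zero. Choose $\phi_{1,1}(x)=x$, $\phi_{1,2}(y)=y$, $\phi_{2,1}(x)=x$, $\phi_{2,2}(y)=-y$, $\Phi_1(u)=u^2/4$ and $\Phi_2(u)=-u^2/4$. Then

$$xy=\Phi_1(\phi_{1,1}(x)+\phi_{1,2}(y))+\Phi_2(\phi_{2,1}(x)+\phi_{2,2}(y))=\frac{(x+y)^2}{4}-\frac{(x-y)^2}{4}$$

so multiplication reduces to univariate functions and addition.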

### Kolmogorov-Arnold Network (KAN)

The Kolmogorov-Arnold representation can be written in matrix form

$$f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}$$

where 

$${\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot)\end{pmatrix}$$

We notice that both ${\bf \Phi}_{\rm in}$ and ${\bf \Phi}_{\rm out}$ are special cases of the following function matrix ${\bf \Phi}$ (with $n_{\rm in}$ inputs and $n_{\rm out}$ outputs), which we call a Kolmogorov-Arnold layer:

$${\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}$$

${\bf \Phi}_{\rm in}$ corresponds to $n_{\rm in}=n, n_{\rm out}=2n+1$, and ${\bf \Phi}_{\rm out}$ corresponds to $n_{\rm in}=2n+1, n_{\rm out}=1$.

After defining the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Let's say we have $L$ layers, with the $l^{\rm th}$ layer ${\bf \Phi}_l$ having shape $(n_{l+1}, n_{l})$. Then the whole network is

$${\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}$$
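
To make the layer and the stacking concrete, here is a minimal plain-Python sketch. It is not how pykan implements layers (pykan uses learnable splines); the helper names `kan_layer` and `kan_forward` and the hand-picked functions are assumptions made for this illustration. Each layer is an $n_{\rm out}\times n_{\rm in}$ nested list of univariate callables, and output $j$ is the sum over $i$ of $\phi_{j,i}(x_i)$.

```python
import math


def kan_layer(phi, x):
    """Apply one Kolmogorov-Arnold layer.

    phi is an n_out x n_in nested list of univariate callables;
    output j is the sum over inputs i of phi[j][i](x[i]).
    """
    return [sum(f(x_i) for f, x_i in zip(row, x)) for row in phi]


def kan_forward(layers, x):
    """Compose layers in order: Phi_{L-1} o ... o Phi_1 o Phi_0 o x."""
    for phi in layers:
        x = kan_layer(phi, x)
    return x


# Hypothetical hand-built network of shape [2, 1, 1] that represents
# f(x, y) = exp(sin(pi*x) + y^2) exactly.
phi_0 = [[lambda x: math.sin(math.pi * x), lambda y: y ** 2]]  # shape (1, 2)
phi_1 = [[math.exp]]                                            # shape (1, 1)

out = kan_forward([phi_0, phi_1], [0.5, 1.0])
print(out)
```

The hand-built network uses one hidden node; a trained KAN has to discover comparable univariate functions from data.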

In contrast, a Multi-Layer Perceptron interleaves linear layers ${\bf W}_l$ with nonlinearities $\sigma$:

$${\rm MLP}({\bf x})={\bf W}_{L-1}\circ\sigma\circ\cdots\circ {\bf W}_1\circ\sigma\circ {\bf W}_0\circ {\bf x}$$

A KAN can be easily visualized. (1) A KAN is simply a stack of KAN layers. (2) Each KAN layer can be visualized as a fully-connected layer, with a 1D function placed on each edge. Let's see an example below.

### Get started with KANs

Initialize KAN

In [None]:
import os

import torch
from kan import *
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)

Create dataset

In [None]:
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape

Plot KAN at initialization

In [None]:
# plot KAN at initialization
model(dataset['train_input']);
model.plot(beta=100)

In [None]:
# params save stem
dir_stem = 'params'
os.makedirs(dir_stem, exist_ok=True)

In [None]:
# Save the initial parameters
theta0 = model.state_dict().copy()
torch.save(theta0, os.path.join(dir_stem, 'theta0.pth'))
print('Saved initial parameters to', os.path.join(dir_stem, 'theta0.pth'))

Train KAN with sparsity regularization

In [None]:
# train the model
model.train_model(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
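
For intuition about `lamb` and `lamb_entropy`: the KAN paper describes the training objective as the prediction loss plus an overall weight times a weighted sum of an L1 term and an entropy term over per-edge activation magnitudes. The sketch below is a plain-Python illustration of that formula only, not pykan's implementation. The function names, the `lamb_l1` weight, and the toy magnitudes are assumptions made for this example.

```python
import math


def l1_and_entropy(edge_norms):
    """Return (L1 sum, entropy) for nonnegative per-edge magnitudes of one layer."""
    total = sum(edge_norms)
    if total == 0:
        return 0.0, 0.0
    probs = [m / total for m in edge_norms]
    # Zero-magnitude edges contribute nothing to the entropy.
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return total, entropy


def regularized_loss(pred_loss, layers, lamb=0.01, lamb_l1=1.0, lamb_entropy=10.0):
    """Prediction loss plus lamb * (lamb_l1 * L1 + lamb_entropy * entropy), summed over layers."""
    reg = 0.0
    for edge_norms in layers:
        l1, ent = l1_and_entropy(edge_norms)
        reg += lamb_l1 * l1 + lamb_entropy * ent
    return pred_loss + lamb * reg


# Same total magnitude, spread evenly versus concentrated on one edge.
dense = [[1.0, 1.0, 1.0, 1.0]]
sparse = [[4.0, 0.0, 0.0, 0.0]]
print(regularized_loss(0.5, dense), regularized_loss(0.5, sparse))
```

With equal L1 mass, the entropy term is what favors the layer whose activity sits on a single edge, which is the behavior pruning relies on afterwards.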

Plot trained KAN

In [None]:
model.plot()

In [None]:
# Save the final (trained) parameters
thetaN = model.state_dict().copy()
torch.save(thetaN, os.path.join(dir_stem, 'thetaN.pth'))
print('Saved final parameters to', os.path.join(dir_stem, 'thetaN.pth'))

Prune KAN and replot (keep the original shape)

In [None]:
model.prune()
model.plot(mask=True)

Prune KAN and replot (get a smaller shape)

In [None]:
model = model.prune()
model(dataset['train_input'])
model.plot()

Continue training and replot

In [None]:
model.train_model(dataset, opt="LBFGS", steps=50);

In [None]:
model.plot()

Automatically or manually set activation functions to be symbolic

In [None]:
mode = "auto" # "manual"

if mode == "manual":
    # manual mode
    model.fix_symbolic(0,0,0,'sin');
    model.fix_symbolic(0,1,0,'x^2');
    model.fix_symbolic(1,0,0,'exp');
elif mode == "auto":
    # automatic mode
    lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
    model.auto_symbolic(lib=lib)
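
The idea behind automatic mode can be sketched as scoring each candidate in a library by how well it explains a learned 1D activation. The code below is a simplified, hypothetical stand-in and not pykan's routine: it fits only an outer scale and offset, $y \approx c\,g(x) + d$, by least squares and picks the candidate with the highest $R^2$. The names `r2_after_linear_fit`, `best_symbolic`, and `LIBRARY` are made up for this example.

```python
import math


def r2_after_linear_fit(g, xs, ys):
    """Fit ys ~ c * g(x) + d by least squares and return the R^2 of the fit."""
    gs = [g(x) for x in xs]
    n = len(xs)
    g_mean = sum(gs) / n
    y_mean = sum(ys) / n
    var_g = sum((gi - g_mean) ** 2 for gi in gs)
    ss_tot = sum((yi - y_mean) ** 2 for yi in ys)
    if var_g == 0 or ss_tot == 0:
        return 0.0
    c = sum((gi - g_mean) * (yi - y_mean) for gi, yi in zip(gs, ys)) / var_g
    d = y_mean - c * g_mean
    ss_res = sum((yi - (c * gi + d)) ** 2 for gi, yi in zip(gs, ys))
    return 1.0 - ss_res / ss_tot


# A small illustrative library; the notebook's `lib` list is longer.
LIBRARY = {
    'x': lambda x: x,
    'x^2': lambda x: x ** 2,
    'exp': math.exp,
    'sin': math.sin,
}


def best_symbolic(xs, ys, library=LIBRARY):
    """Return the best-scoring candidate name and all R^2 scores."""
    scores = {name: r2_after_linear_fit(g, xs, ys) for name, g in library.items()}
    return max(scores, key=scores.get), scores


# Samples of a 1D "activation" that is secretly 2*sin(x) + 1.
xs = [-3 + 6 * i / 99 for i in range(100)]
ys = [2 * math.sin(x) + 1 for x in xs]
name, scores = best_symbolic(xs, ys)
print(name, scores)
```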

Continue training to almost machine precision

In [None]:
model.train_model(dataset, opt="LBFGS", steps=50);

Obtain the symbolic formula

In [None]:
model.symbolic_formula()[0][0]

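## Is the KAN loss convex?
No... :( We check this by evaluating the loss along the straight line through the initial parameters `theta0` and the trained parameters `thetaN`; a convex loss would have to be convex along every such line.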

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from kan import *

In [None]:
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape

In [None]:
# Load initial and final parameters
dir_stem = 'params'
theta0 = torch.load(os.path.join(dir_stem, 'theta0.pth'))
thetaN = torch.load(os.path.join(dir_stem, 'thetaN.pth'))

# Linearly interpolate/extrapolate the parameters between theta0 and thetaN
alphas = np.linspace(-1.5, 2.0, 50)
thetas = [{name: (1-alpha)*theta0[name] + alpha*thetaN[name] for name in theta0} for alpha in alphas]

In [None]:
# Set up the model and define the loss
model = KAN(width=[2,5,1], grid=5, k=3)
criterion = nn.MSELoss()

In [None]:
n_datasets = 10

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
for i in range(n_datasets):
    # Create a new dataset
    dataset = create_dataset(f, n_var=2, seed=i)
    data = dataset['train_input']
    targets = dataset['train_label']

    # Compute the loss for each theta
    losses = []
    with torch.no_grad():  # No need to compute gradients for evaluation
        for theta in thetas:
            model.load_state_dict(theta)
            outputs = model(data)
            loss = criterion(outputs, targets)
            losses.append(loss.item())

    # Plot the loss as a function of alpha
    ax.plot(alphas, losses, label=f'Dataset {i}')

# add legend for each line
ax.legend()

plt.xlabel('alpha')
plt.ylabel('MSE loss')
plt.title('Interpolation between theta0 and thetaN')
plt.show()
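
The curves can also be checked numerically. A convex loss restricted to a line is a convex function of `alpha`, so every interior point must lie on or below the chord through its two neighbours. The helper below is a sketch under that reasoning; the name `convexity_violations` and the tolerance are assumptions for this example. It takes one `alphas`/`losses` pair as produced by the loop above.

```python
def convexity_violations(alphas, losses, tol=1e-12):
    """Return indices i where losses[i] lies above the chord through its neighbours.

    A non-empty result shows the sampled curve is not convex.
    An empty result does not prove convexity of the underlying loss.
    """
    bad = []
    for i in range(1, len(alphas) - 1):
        a0, a1, a2 = alphas[i - 1], alphas[i], alphas[i + 1]
        t = (a1 - a0) / (a2 - a0)  # position of a1 between its neighbours
        chord = (1 - t) * losses[i - 1] + t * losses[i + 1]
        if losses[i] > chord + tol:
            bad.append(i)
    return bad


# Toy curves: a parabola (convex) and a zig-zag (not convex).
toy_alphas = [-1.0, 0.0, 1.0, 2.0]
print(convexity_violations(toy_alphas, [1.0, 0.0, 1.0, 4.0]))
print(convexity_violations(toy_alphas, [0.0, 1.0, 0.0, 1.0]))
```

A single violation on any of the ten curves is enough to rule out convexity of the loss in the parameters.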

### Spectral Density

In [None]:
from kan import *
from pyhessian import hessian, get_esd_plot # Hessian computation
from matplotlib.backends.backend_pdf import PdfPages

In [None]:
# Set up the model and define the loss
model = KAN(width=[2,5,1], grid=5, k=3)
criterion = nn.MSELoss()

In [None]:
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)

In [None]:
# Dataset
inputs = dataset['train_input']
targets = dataset['train_label']

In [None]:
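# PdfPages does not create missing directories, so make sure the target exists
os.makedirs('figures/spec_density', exist_ok=True)
pdf = PdfPages('figures/spec_density/kan.pdf')
# `thetas` and `alphas` are the interpolated parameters from the convexity section above
for i, theta in enumerate(thetas):
    model.load_state_dict(theta)
    hessian_comp = hessian(model, criterion, data=(inputs,targets), cuda=False)
    density_eigen, density_weight = hessian_comp.density()
    fig, ax = get_esd_plot(density_eigen, density_weight)
    ax.set_title(f'Spectrum of Hessian for alpha={alphas[i]:.2f}')
    pdf.savefig(fig, bbox_inches = 'tight')
    plt.close(fig)
pdf.close()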

### Eigenvalues

In [None]:
hessian_comp = hessian(model, criterion, data=(inputs,targets), cuda=False)

In [None]:
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues()
print("The top Hessian eigenvalue of this model is %.4f"%top_eigenvalues[-1])

In [None]:
# Now let's compute the top 2 eigenvalues and eigenvectors of the Hessian
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues(top_n=2)
print("The top two eigenvalues of this model are: %.4f %.4f"% (top_eigenvalues[-1],top_eigenvalues[-2]))

In [None]:
trace = hessian_comp.trace()
print("The trace of the Hessian is: %.4f"%(np.mean(trace)))
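
`hessian_comp.trace()` returns a list of stochastic estimates, which is why the cell above averages them. As we understand it, PyHessian estimates the trace with Hutchinson's method: for random sign vectors $v$, the expectation of $v^\top H v$ equals ${\rm tr}(H)$, and only Hessian-vector products are needed. The sketch below illustrates the estimator on a small explicit matrix in plain Python. The function name, the matrix, and the sample count are assumptions for this example, not PyHessian's code.

```python
import random


def hutchinson_trace(matvec, dim, n_samples=100, seed=0):
    """Estimate tr(A) as the average of v^T A v over Rademacher vectors v."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        Av = matvec(v)
        total += sum(vi * avi for vi, avi in zip(v, Av))
    return total / n_samples


# Small symmetric matrix standing in for a Hessian; its exact trace is 4.
A = [[2.0, 0.5, 0.0],
     [0.5, -1.0, 0.3],
     [0.0, 0.3, 3.0]]


def matvec(v):
    """Matrix-vector product, the only access to A the estimator needs."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


print(hutchinson_trace(matvec, dim=3, n_samples=1000))
```

The off-diagonal entries add zero-mean noise to each sample, so the estimate approaches the exact trace as the number of samples grows.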

### MLP

In [1]:
from models.mlp import MLP
from utils import get_final_dataset
import torch
import numpy as np

In [2]:
data_folder = './data'
id = '41166'

In [3]:
train_dataset, test_dataset, input_size, y_train = get_final_dataset(data_folder, id)

n_classes = len(np.unique(y_train))

output_size = n_classes
hidden_layers = [512] * 1

model = MLP(input_size, output_size, hidden_layers)
loss_fn = torch.nn.CrossEntropyLoss()

print("Input size: ", input_size)
print("Output size: ", output_size)
print("Hidden layers: ", hidden_layers)

Input size:  147
Output size:  10
Hidden layers:  [512]


In [4]:
for param in model.named_parameters():
    print(param)

('hidden_layers.0.weight', Parameter containing:
tensor([[-0.0625, -0.0451, -0.0572,  ...,  0.0195, -0.0308,  0.0798],
        [ 0.0273,  0.0594, -0.0430,  ..., -0.0543, -0.0601, -0.0516],
        [ 0.0570, -0.0605, -0.0404,  ..., -0.0569,  0.0216, -0.0678],
        ...,
        [ 0.0643, -0.0566, -0.0621,  ..., -0.0296,  0.0460, -0.0806],
        [ 0.0040,  0.0394, -0.0167,  ...,  0.0114,  0.0195,  0.0399],
        [ 0.0396,  0.0744, -0.0017,  ...,  0.0508, -0.0338, -0.0091]],
       requires_grad=True))
('hidden_layers.0.bias', Parameter containing:
tensor([-5.6769e-02, -1.2258e-02, -7.9065e-02, -4.4657e-02,  1.6983e-02,
        -1.4582e-02, -1.2332e-02,  6.4829e-03,  5.6622e-02, -2.0920e-02,
        -1.7363e-02,  6.1459e-02,  4.1294e-02,  6.0514e-02,  5.2103e-02,
        -2.3026e-02, -7.2150e-02,  2.1650e-02, -7.5682e-02, -7.3647e-02,
        -4.6706e-02, -5.5821e-02,  5.0276e-02,  7.3004e-02, -6.4359e-02,
         5.6887e-02, -3.2981e-02,  1.6293e-02,  1.6006e-02, -9.9190e-03,
    

In [5]:
from SFN import SFN
opt = SFN(model.parameters(), lr=0.01)

In [6]:
i = 0
for group in opt.param_groups:
    print(type(group['params']))
    for p in group['params']:
        if i == 0:
            print(p)
            print('------------------')
            print(p.grad)
            print('------------------')
        i += 1


<class 'list'>
Parameter containing:
tensor([[-0.0625, -0.0451, -0.0572,  ...,  0.0195, -0.0308,  0.0798],
        [ 0.0273,  0.0594, -0.0430,  ..., -0.0543, -0.0601, -0.0516],
        [ 0.0570, -0.0605, -0.0404,  ..., -0.0569,  0.0216, -0.0678],
        ...,
        [ 0.0643, -0.0566, -0.0621,  ..., -0.0296,  0.0460, -0.0806],
        [ 0.0040,  0.0394, -0.0167,  ...,  0.0114,  0.0195,  0.0399],
        [ 0.0396,  0.0744, -0.0017,  ...,  0.0508, -0.0338, -0.0091]],
       requires_grad=True)
------------------
None
------------------


In [7]:
torch.cat([p.view(-1) for group in opt.param_groups for p in group['params']])

tensor([-0.0625, -0.0451, -0.0572,  ..., -0.0327, -0.0039, -0.0164],
       grad_fn=<CatBackward0>)

In [8]:
from torch.utils.data import DataLoader

In [9]:
n_train = len(train_dataset)
bh = int(n_train ** 0.5)
bsz = 128

train_dataloader = DataLoader(train_dataset, batch_size=bsz, shuffle=True)
train_dataloader_hess = DataLoader(train_dataset, batch_size=bh, shuffle=True) # Used for updating the preconditioner
train_dataloader2 = DataLoader(train_dataset, batch_size=4096, shuffle=False) # Used for computing metrics on the training set
test_dataloader = DataLoader(test_dataset, batch_size=4096, shuffle=False)

In [10]:
for x, y in train_dataloader:
    opt.zero_grad()
    y_hat = model(x)
    loss = loss_fn(y_hat, y)
    break

# Backward pass to compute gradients
loss.backward()
# grad_tuple = torch.autograd.grad(loss, model.parameters(), create_graph=True)

# Extract gradients
params = []
for group in opt.param_groups:
    for param in group['params']:
        params.append(param)



In [11]:
for x_h, y_h in train_dataloader_hess:
    y_h_hat = model(x_h)
    l_h = loss_fn(y_h_hat, y_h)
    break
grad_tuple = torch.autograd.grad(l_h, model.parameters(), create_graph=True)

In [12]:
grad_params = []
for gradient in grad_tuple:
    if gradient is not None:
        grad_params.append(gradient)
# grad_params = torch.cat([gradient.view(-1) for gradient in grad_tuple if gradient is not None])

In [13]:
v = [
    torch.randn_like(p)
    for p in grad_params
]
# augment the vector with a scalar
v.append(torch.randint(2, (1,)))

### _fvp
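
Reading the cells in this section, the product being assembled appears to be that of an augmented matrix with an augmented vector. Write $H$ for the Hessian of the mini-batch loss, $g$ for its gradient (`grad_params`), $v$ for `v[:-1]`, $s$ for the trailing scalar `v[-1]`, and $\delta$ for `opt.delta`. The symbol $F$ is our notation, not the notebook's. Then

$$F\begin{pmatrix} v \\ s \end{pmatrix}=\begin{pmatrix} H & g \\ g^\top & -\delta \end{pmatrix}\begin{pmatrix} v \\ s \end{pmatrix}=\begin{pmatrix} Hv+s\,g \\ g^\top v-\delta s \end{pmatrix}$$

The top block corresponds to `group_add(Hv, tg)` and the bottom entry to `gTv - opt.delta * v[-1]`.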

In [14]:
from pyhessian.utils import group_product, group_add, normalization, get_params_grad, orthnormal
from opt_utils import group_scalar

In [15]:
# Compute Hessian-vector product
Hv = opt._hvp(grad_params, params, v[:-1])

# Multiply the last element of v with grad_params
tg = group_scalar(grad_params, v[-1])
tg = [tgi.detach() for tgi in tg]

# Compute gTv
gTv = group_product(grad_params, v[:-1]).detach()

output = group_add(Hv, tg)
output.append(gTv - opt.delta * v[-1])

print(type(output))

<class 'list'>


### appx_min_eigvec

In [16]:
params = []
for group in opt.param_groups:
    for param in group['params']:
        params.append(param)

gradsH = []
for gradient in grad_tuple:
    if gradient is not None:
        gradsH.append(gradient)

device = params[0].device

In [17]:
a = torch.randint(2, (1,))
a.shape

torch.Size([1])

In [18]:
v = [
    torch.randint_like(p, high=2, device=device)
    for p in params
]
# generate Rademacher random variables
for v_i in v:
    v_i[v_i == 0] = -1
# augment the vector with a scalar
v.append(torch.randint(2, (1,), device=device))

In [19]:
w = [v_i.reshape(-1) for v_i in v]
w = torch.cat(w)
w = w/torch.norm(w)

In [20]:
v = normalization(v)

In [21]:
vp = group_scalar(v, torch.tensor([10.0]))
print(vp[2])    

tensor([[ 0.0352,  0.0352,  0.0352,  ...,  0.0352, -0.0352, -0.0352],
        [ 0.0352,  0.0352,  0.0352,  ...,  0.0352, -0.0352, -0.0352],
        [ 0.0352, -0.0352,  0.0352,  ...,  0.0352, -0.0352,  0.0352],
        ...,
        [-0.0352,  0.0352,  0.0352,  ..., -0.0352,  0.0352,  0.0352],
        [ 0.0352, -0.0352,  0.0352,  ...,  0.0352,  0.0352, -0.0352],
        [ 0.0352,  0.0352, -0.0352,  ..., -0.0352, -0.0352, -0.0352]])


In [22]:
w_prime = [torch.zeros(p.size()).to(device) for p in params]
w_prime.append(torch.zeros(1).to(device))   # add a scalar

In [23]:
w_prime = opt._fvp(gradsH, params, v)
print(w_prime[-1])

tensor([0.0137])


In [24]:
iter = 100

In [31]:
opt.verbose = True
min_val, vec = opt.appx_min_eigvec(gradsH, params, iter=100)
print("Min eigvec:", vec)
print("Min eigvec shape:", vec.shape)

Approximate minimum eigenvalue = -1.946365237236023
Min eigvec: tensor([-2.0281e-05,  3.2317e-06,  1.9437e-05,  ...,  1.2507e-03,
        -6.6984e-04, -6.7602e-02])
Min eigvec shape: torch.Size([80907])


In [29]:
# standard Lanczos algorithm initialization
v_list = [v]
w_list = []
alpha_list = []
beta_list = []
############### Lanczos
for i in range(iter):
    opt.zero_grad()
    Fv = [torch.zeros(p.size()).to(device) for p in params]
    Fv.append(torch.zeros(1).to(device))   # add a scalar
    if i == 0:
        Fv = opt._fvp(gradsH, params, v)
        alpha = group_product(Fv, v)
        alpha_list.append(alpha.cpu().item())
        w = group_add(Fv, v, alpha=-alpha)
        w_list.append(w)
    else:
        beta = torch.sqrt(group_product(w, w))
        beta_list.append(beta.cpu().item())
        if beta_list[-1] != 0.:
            # We should re-orth it
            v = orthnormal(w, v_list)
            v_list.append(v)
        else:
            # generate a new vector
            w = [torch.randn(p.size()).to(device) for p in params]
            w.append(torch.randn(1).to(device))
            v = orthnormal(w, v_list)
            v_list.append(v)
        Fv = opt._fvp(gradsH, params, v)
        alpha = group_product(Fv, v)
        alpha_list.append(alpha.cpu().item())
        w_tmp = group_add(Fv, v, alpha=-alpha)
        w = group_add(w_tmp, v_list[-2], alpha=-beta)

T = torch.zeros(iter, iter).to(device)
for i in range(len(alpha_list)):
    T[i, i] = alpha_list[i]
    if i < len(alpha_list) - 1:
        T[i + 1, i] = beta_list[i]
        T[i, i + 1] = beta_list[i]
eigvals, eigvecs_T = torch.linalg.eigh(T)
V = torch.stack([torch.cat([v_i.reshape(-1) for v_i in v]) for v in v_list])
min_eigvec = torch.mv(V.t(), eigvecs_T[:, 0])

In [30]:
print("Min eigval:", eigvals[0])
print("Min eigvec:",min_eigvec)
print("Min eigvec shape:", min_eigvec.shape)

Min eigval: tensor(-1.9462)
Min eigvec: tensor([-1.1366e-05,  9.1092e-06,  9.8557e-06,  ...,  1.3235e-03,
        -7.0513e-04, -7.2199e-02])
Min eigvec shape: torch.Size([80907])
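
Written out, the loop above follows the standard Lanczos three-term recurrence for the operator applied by `opt._fvp`, denoted $F$ here (our notation). Starting from a unit vector $v_0$ and dropping the $\beta_0 v_{-1}$ term at the first step:

$$\alpha_j=v_j^\top F v_j,\qquad w_j=F v_j-\alpha_j v_j-\beta_j v_{j-1},\qquad \beta_{j+1}=\|w_j\|,\qquad v_{j+1}=w_j/\beta_{j+1}$$

In the code, $v_{j+1}$ comes from `orthnormal(w, v_list)`, which also re-orthogonalizes against all earlier vectors. After $m$ steps the coefficients fill a symmetric tridiagonal matrix

$$T=\begin{pmatrix}\alpha_0 & \beta_1 & & \\ \beta_1 & \alpha_1 & \ddots & \\ & \ddots & \ddots & \beta_{m-1} \\ & & \beta_{m-1} & \alpha_{m-1}\end{pmatrix}$$

If $y$ is the eigenvector of $T$ for its smallest eigenvalue, that eigenvalue approximates the smallest eigenvalue of $F$, and $\sum_j y_j v_j$ approximates the corresponding eigenvector. This is what `torch.mv(V.t(), eigvecs_T[:, 0])` computes.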


In [53]:
from pyhessian import hessian, get_esd_plot # Hessian computation
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages



In [36]:
for x, y in train_dataloader:
    hessian_comp = hessian(model, loss_fn, data=(x,y), cuda=False)
    break



In [38]:
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues()
print("The top Hessian eigenvalue of this model is %.4f"%top_eigenvalues[-1])

The top Hessian eigenvalue of this model is 4.9525


In [54]:
pdf = PdfPages('figures/spec_density/mlp.pdf')
density_eigen, density_weight = hessian_comp.density()
fig, ax = get_esd_plot(density_eigen, density_weight)
ax.set_title('Spectrum of Hessian')
pdf.savefig(fig, bbox_inches = 'tight')
plt.close(fig)
pdf.close()

