In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, VBox, interactive

# B-Spline Basis Functions

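The first visualization shows the B-spline basis functions of a given order $k$, computed with the Cox–de Boor recursion. These basis functions are the building blocks of B-spline curves. Use the sliders to change the number of basis functions $n$ and their order $k$. Here $k$ is the recursion depth, i.e. the polynomial degree: $k = 0$ gives piecewise-constant functions, $k = 1$ piecewise-linear "hat" functions, and so on. The knot vector is generated automatically as $n + k + 1$ equally spaced knots on $[0, 1]$.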

In [2]:
def bspline_basis(t, k, i, x):
    """Evaluate the i-th B-spline basis function B_{i,k} at ``x`` with the Cox-de Boor recursion.

    Parameters
    ----------
    t : sequence of float
        Non-decreasing knot vector. Must satisfy ``i + k + 1 < len(t)``.
    k : int
        Recursion depth. ``k == 0`` gives piecewise-constant (box) functions,
        so this is the polynomial degree, even though the notebook calls it the "order".
    i : int
        Index of the basis function.
    x : float
        Evaluation point.

    Returns
    -------
    float
        Value of B_{i,k}(x). A recursion term whose knot span has zero width
        (which would be 0/0) contributes 0.
    """
    if k == 0:
        # Half-open spans [t[i], t[i+1]) make each interior knot belong to exactly one span...
        if t[i] <= x < t[i+1]:
            return 1.0
        # ...but then no span contains the final knot, and every basis function would
        # vanish at x == t[-1]. Close the last non-empty span on the right to cover it.
        if x == t[-1] and t[i] < t[i+1] == t[-1]:
            return 1.0
        return 0.0
    # Weights of the two lower-degree terms; coincident knots make the weight 0.
    coeff1 = (x - t[i]) / (t[i+k] - t[i]) if t[i+k] != t[i] else 0
    coeff2 = (t[i+k+1] - x) / (t[i+k+1] - t[i+1]) if t[i+k+1] != t[i+1] else 0
    return coeff1 * bspline_basis(t, k-1, i, x) + coeff2 * bspline_basis(t, k-1, i+1, x)

In [3]:
def plot_bspline_basis(n, k):
    """Plot the ``n`` basis functions B_{i,k}, i = 0..n-1, on [0, 1].

    The knot vector consists of ``n + k + 1`` equally spaced knots on [0, 1].
    Each basis function is evaluated pointwise with ``bspline_basis`` at 1000
    evenly spaced sample points.

    Parameters
    ----------
    n : int
        Number of basis functions to draw.
    k : int
        Spline "order" as used in this notebook. It is passed straight through
        as the recursion depth of ``bspline_basis``, so ``k == 0`` gives
        piecewise-constant functions.
    """
    knots = np.linspace(0, 1, n + k + 1)  # uniform knot vector
    xs = np.linspace(0, 1, 1000)          # sample points for evaluation
    fig, ax = plt.subplots(figsize=(10, 6))
    for i in range(n):
        ys = [bspline_basis(knots, k, i, xv) for xv in xs]
        ax.plot(xs, ys, label=f'B_{i},{k}(x)')
    ax.set_title(f'B-spline Basis Functions of order {k}')
    ax.set_xlabel('x')
    ax.set_ylabel('B(x)')
    ax.legend()
    ax.grid(True)
    plt.show()

# The trailing ';' suppresses the function repr that `interact` returns (as in the KAN cell below).
interact(plot_bspline_basis, n=IntSlider(min=1, max=10, step=1, value=5), k=IntSlider(min=0, max=5, step=1, value=2));

interactive(children=(IntSlider(value=5, description='n', max=10, min=1), IntSlider(value=2, description='k', …

<function __main__.plot_bspline_basis(n, k)>

# B-Spline Basis Functions used in KANs

The second visualization displays the B-spline basis functions for a given number of grid intervals $G$ and spline order $k$. It computes them with the `B_batch` function from the [pykan](https://github.com/KindXiaoming/pykan) package.
The grid consists of $G + 1$ evenly spaced points on $[-1, 1]$, and the plot shows the resulting $G + k$ basis functions. This lets you see how the basis functions are arranged over a KAN grid and how they span the whole domain.

In [4]:
from kan.spline import B_batch
import torch

def plot_bspline(G, k):
    """Draw the G + k KAN B-spline basis functions produced by pykan's ``B_batch``.

    The grid has G + 1 evenly spaced points on [-1, 1]. The basis functions are
    sampled at 1001 evenly spaced points on the same interval, and the grid
    points are used as the x-axis ticks.

    Parameters
    ----------
    G : int
        Number of grid intervals.
    k : int
        Spline order, forwarded unchanged to ``B_batch``.
    """
    grid_pts = torch.linspace(-1, 1, steps=G + 1)
    samples = torch.linspace(-1, 1, steps=1001)
    # NOTE(review): assumes the pykan B_batch variant that takes (1, n) tensors for
    # x and grid and returns (1, G + k, n_samples). Confirm against the installed pykan.
    basis = B_batch(samples[None, :], grid_pts[None, :], k=k)

    xs = samples.detach().numpy()
    count = G + k
    fig, ax = plt.subplots(figsize=(10, 6))
    for idx in range(count):
        ax.plot(xs, basis[0, idx, :].detach().numpy())

    ax.legend([f'B_{idx}(x)' for idx in range(count)])
    ax.set_xlabel('x')
    ax.set_xticks(grid_pts.detach().numpy())
    ax.set_ylabel('B_i(x)')
    ax.set_title(f'B-spline basis functions with G={G}, k={k}')
    ax.grid(True)
    plt.show()

interact(
    plot_bspline,
    G=IntSlider(min=1, max=10, step=1, value=3, description='G'),
    k=IntSlider(min=0, max=5, step=1, value=2, description='k'),
);

interactive(children=(IntSlider(value=3, description='G', max=10, min=1), IntSlider(value=2, description='k', …

# Combined B-Spline

The third visualization shows the combined B-spline $S(x) = \sum_i c_i B_i(x)$, formed as the weighted sum of the basis functions. You can adjust the coefficient $c_i$ of each basis function (between $-2$ and $2$) to see how it influences the combined spline. Each scaled basis function $c_i B_i(x)$ is drawn as a faint line, with the combined spline drawn on top as a thick black line. Changing $G$ or $k$ rebuilds the coefficient sliders and resets every coefficient to $1$.

In [8]:
def plot_combined_spline(G, k, **coeffs):
    """Plot every coefficient-scaled KAN basis function together with their weighted sum.

    Parameters
    ----------
    G : int
        Number of grid intervals on [-1, 1].
    k : int
        Spline order, forwarded to ``B_batch``.
    **coeffs : float
        One weight per basis function, keyed ``'Coeff 0'`` ... ``'Coeff {G+k-1}'``.
        A missing key raises ``KeyError``. Extra keys are ignored.
    """
    n_basis = G + k
    grid = torch.linspace(-1, 1, steps=G + 1)[None, :]
    x = torch.linspace(-1, 1, steps=1001)[None, :]
    basis = B_batch(x, grid, k=k)

    weights = [coeffs[f'Coeff {j}'] for j in range(n_basis)]
    terms = []
    for j, w in enumerate(weights):
        terms.append(w * basis[0, j, :])
    combined_spline = sum(terms)

    xs = x[0].detach().numpy()
    plt.figure(figsize=(10, 6))
    # Individual weighted terms are drawn faintly; the sum is drawn bold on top.
    for term in terms:
        plt.plot(xs, term.detach().numpy(), alpha=0.3)
    plt.plot(xs, combined_spline.detach().numpy(), 'k', linewidth=3, label='Combined Spline')

    term_labels = [f'{w:.1f} * B_{j}(x)' for j, w in enumerate(weights)]
    plt.legend(term_labels + ['Combined Spline'])
    plt.xlabel('x')
    plt.xticks(grid[0].detach().numpy())
    plt.title(f'B-spline basis functions and combined spline with G={G}, k={k}')
    plt.grid(True)
    plt.show()

def update_coeffs_sliders(change):
    """Observer callback that rebuilds the coefficient panel after G or k changes.

    The traitlets ``change`` notification is not inspected. The rebuild reads
    the current values directly from the sliders.
    """
    return display_interactive_plot()

def display_interactive_plot():
    """(Re)build the coefficient sliders for the current G and k and show them in ``container``.

    Reads the module-level ``G_slider`` / ``k_slider`` values and creates one
    FloatSlider per basis function (G + k of them, each reset to 1.0). It then
    replaces the contents of ``container`` with the two shared sliders followed
    by a fresh ``interactive`` panel that drives ``plot_combined_spline``.
    """
    from ipywidgets import fixed  # only needed here; keeps this cell self-contained

    G = G_slider.value
    k = k_slider.value
    num_basis = G + k

    coeffs = {f'Coeff {i}': FloatSlider(min=-2.0, max=2.0, step=0.1, value=1.0) for i in range(num_basis)}

    # Pass G and k as fixed values rather than as the shared sliders. `interactive`
    # observes every widget it is given, so handing it G_slider/k_slider left each
    # replaced (hidden) panel subscribed to them: every later change re-ran all stale
    # panels, and a stale panel with too few 'Coeff i' sliders raised KeyError.
    out = interactive(plot_combined_spline, G=fixed(G), k=fixed(k), **coeffs)

    previous = list(container.children)
    container.children = [G_slider, k_slider, out]

    # Release the widgets of the panel being replaced, but never the shared G/k sliders.
    for old in previous:
        if isinstance(old, interactive):
            for child in old.children:
                if child is not G_slider and child is not k_slider:
                    child.close()
            old.close()

# Shared controls for the grid size and spline order.
G_slider = IntSlider(value=3, min=1, max=10, step=1, description='G')
k_slider = IntSlider(value=2, min=0, max=5, step=1, description='k')

# Rebuild the coefficient panel whenever either shared control changes.
G_slider.observe(update_coeffs_sliders, names='value')
k_slider.observe(update_coeffs_sliders, names='value')

# The container is shown once; display_interactive_plot swaps its children in place.
container = VBox()
display(container)
display_interactive_plot()

VBox()