In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, VBox, interactive

In [2]:
def bspline_basis(t, k, i, x):
    """Evaluate the B-spline basis function B_{i,k} at one point (Cox-de Boor recursion).

    Parameters
    ----------
    t : sequence of float
        Knot vector; entries t[i] .. t[i+k+1] are read. Assumed non-decreasing.
    k : int
        Degree of the basis function (0 means piecewise constant).
    i : int
        Index of the basis function.
    x : float
        Point at which to evaluate.

    Returns
    -------
    float
        The value B_{i,k}(x).

    Notes
    -----
    Degree-0 functions are indicators of the half-open span [t[i], t[i+1]).
    The last non-empty span is additionally closed on the right, so that
    x == t[-1] belongs to exactly one span instead of to none; without this
    every basis function evaluates to 0 at the right end of the knot range.
    A term whose denominator is zero (repeated knots) contributes 0.
    """
    if k == 0:
        if t[i] <= x < t[i + 1]:
            return 1.0
        # Right end of the whole knot range: assign it to the final span of
        # positive length (spans of zero length from repeated knots stay 0).
        if x == t[-1] and t[i] < t[i + 1] == t[-1]:
            return 1.0
        return 0.0

    left_term = 0.0
    if t[i + k] != t[i]:
        left_term = (x - t[i]) / (t[i + k] - t[i]) * bspline_basis(t, k - 1, i, x)

    right_term = 0.0
    if t[i + k + 1] != t[i + 1]:
        right_term = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(t, k - 1, i + 1, x)

    return left_term + right_term

In [3]:
def plot_bspline_basis(n, k):
    """Draw the basis functions B_0 .. B_{n-1} of degree ``k`` on a uniform knot vector.

    The knot vector consists of ``n + k + 1`` evenly spaced values on [0, 1],
    and each basis function is sampled at 1000 evenly spaced points of the
    same interval.

    Parameters
    ----------
    n : int
        Number of basis functions to draw.
    k : int
        Degree passed to ``bspline_basis``.
    """
    knots = np.linspace(0, 1, n + k + 1)
    samples = np.linspace(0, 1, 1000)

    _, ax = plt.subplots(figsize=(10, 6))
    for i in range(n):
        # Point-by-point evaluation: bspline_basis takes one scalar x at a time.
        values = [bspline_basis(knots, k, i, s) for s in samples]
        ax.plot(samples, values, label=f'B_{i},{k}(x)')

    ax.set_title(f'B-spline Basis Functions of order {k}')
    ax.set_xlabel('x')
    ax.set_ylabel('B(x)')
    ax.legend()
    ax.grid(True)
    plt.show()

# Sliders for the number of basis functions (n) and the degree (k).
# The trailing semicolon stops the cell from echoing interact's return value
# (the repr of plot_bspline_basis) below the widget.
interact(plot_bspline_basis, n=IntSlider(min=1, max=10, step=1, value=5), k=IntSlider(min=0, max=5, step=1, value=2));

interactive(children=(IntSlider(value=5, description='n', max=10, min=1), IntSlider(value=2, description='k', …

<function __main__.plot_bspline_basis(n, k)>

In [4]:
def plot_bspline_basis_extended(n, k):
    """Draw ``n`` degree-``k`` B-spline basis functions on a knot vector with repeated end knots.

    The knot vector is ``k`` zeros, then ``n - k + 1`` evenly spaced values
    from 0 to 1, then ``k`` ones: ``n + k + 1`` knots in total, which is the
    count required for ``n`` basis functions of degree ``k``. When
    ``n >= k + 1`` the evenly spaced part contains both 0 and 1, so each end
    knot has multiplicity ``k + 1``.

    Parameters
    ----------
    n : int
        Number of basis functions to draw.
    k : int
        Degree passed to ``bspline_basis``.

    Raises
    ------
    ValueError
        From ``np.linspace`` when ``n - k + 1`` is negative.
    """
    # np.linspace already supplies one knot at 0 and one at 1, so only k extra
    # copies per end are needed. Padding with k + 1 copies gives n + k + 3
    # knots, which makes B_0 identically zero and pushes the last real basis
    # function out of range(n).
    t = np.concatenate(([0.0] * k, np.linspace(0, 1, n - k + 1), [1.0] * k))
    x = np.linspace(0, 1, 1000)  # Points to evaluate the basis functions

    plt.figure(figsize=(10, 6))
    for i in range(n):
        y = [bspline_basis(t, k, i, xi) for xi in x]
        plt.plot(x, y, label=f'B_{i},{k}(x)')
    plt.title(f'Extended B-spline Basis Functions of order {k}')
    plt.xlabel('x')
    plt.ylabel('B(x)')
    plt.legend()
    plt.grid(True)
    plt.show()

# Sliders for the number of basis functions (n) and the degree (k).
# The trailing semicolon stops the cell from echoing interact's return value
# (the repr of plot_bspline_basis_extended) below the widget.
interact(plot_bspline_basis_extended, n=IntSlider(min=1, max=10, step=1, value=5), k=IntSlider(min=0, max=5, step=1, value=2));

interactive(children=(IntSlider(value=5, description='n', max=10, min=1), IntSlider(value=2, description='k', …

<function __main__.plot_bspline_basis_extended(n, k)>

In [5]:
from kan.spline import B_batch
import torch

def plot_bspline(G, k):
    """Plot the basis functions that ``B_batch`` produces for a uniform grid on [-1, 1].

    Parameters
    ----------
    G : int
        Number of grid intervals; the grid has ``G + 1`` evenly spaced points.
    k : int
        Spline degree forwarded to ``B_batch``.

    Rows ``0 .. G + k - 1`` of the first batch entry are drawn.
    NOTE(review): presumably ``B_batch`` pads the grid by ``k`` points per side
    so that exactly ``G + k`` basis functions exist; confirm against kan.spline.
    """
    knots = torch.linspace(-1, 1, steps=G + 1)[None, :]
    points = torch.linspace(-1, 1, steps=1001)[None, :]
    basis = B_batch(points, knots, k=k)

    num_basis = G + k
    xs = points[0].detach().numpy()

    _, ax = plt.subplots(figsize=(10, 6))
    for i in range(num_basis):
        ax.plot(xs, basis[0, i, :].detach().numpy())

    ax.legend(['B_{}(x)'.format(i) for i in np.arange(num_basis)])
    ax.set_xlabel('x')
    # Put a tick on every grid point so the knot spans are visible.
    ax.set_xticks(knots[0].detach().numpy())
    ax.set_ylabel('B_i(x)')
    ax.set_title(f'B-spline basis functions with G={G}, k={k}')
    ax.grid(True)
    plt.show()

# Drive plot_bspline from two integer sliders; the trailing semicolon keeps
# the cell from echoing interact's return value.
interact(
    plot_bspline,
    G=IntSlider(min=1, max=10, step=1, value=3, description='G'),
    k=IntSlider(min=0, max=5, step=1, value=2, description='k'),
);

interactive(children=(IntSlider(value=3, description='G', max=10, min=1), IntSlider(value=2, description='k', …

# Combined B-Spline

In [6]:
def plot_combined_spline(G, k, **coeffs):
    """Plot each coefficient-scaled basis function together with their sum.

    Parameters
    ----------
    G : int
        Number of grid intervals; the grid has ``G + 1`` evenly spaced points on [-1, 1].
    k : int
        Spline degree forwarded to ``B_batch``.
    **coeffs
        Must contain one entry ``coeff_<i>`` for every ``i`` in ``range(G + k)``;
        entry ``i`` multiplies basis function ``i``.

    Raises
    ------
    KeyError
        If any required ``coeff_<i>`` keyword is missing.
    """
    knots = torch.linspace(-1, 1, steps=G + 1)[None, :]
    points = torch.linspace(-1, 1, steps=1001)[None, :]
    basis = B_batch(points, knots, k=k)

    num_basis = G + k
    # Look up every coefficient first so a missing key fails before any math.
    weights = [coeffs[f'coeff_{i}'] for i in range(num_basis)]
    weighted = [w * basis[0, i, :] for i, w in enumerate(weights)]
    total = sum(weighted)

    xs = points[0].detach().numpy()

    plt.figure(figsize=(10, 6))
    for curve in weighted:
        plt.plot(xs, curve.detach().numpy(), alpha=0.5)
    plt.plot(xs, total.detach().numpy(), 'k', linewidth=2, label='Combined Spline')

    basis_labels = ['B_{}(x)'.format(i) for i in np.arange(num_basis)]
    plt.legend(basis_labels + ['Combined Spline'])
    plt.xlabel('x')
    plt.xticks(knots[0].detach().numpy())
    plt.ylabel('B_i(x)')
    plt.title(f'B-spline basis functions and combined spline with G={G}, k={k}')
    plt.grid(True)
    plt.show()

def update_coeffs_sliders(change):
    """Observer callback that rebuilds the coefficient sliders and the plot.

    ``change`` is the notification payload supplied by ``observe``. It is not
    inspected: the current values are read from the module-level ``G_slider``
    and ``k_slider`` instead.
    """
    display_interactive_plot(G_slider=G_slider, k_slider=k_slider)

def display_interactive_plot(G_slider, k_slider):
    """Rebuild the coefficient sliders and plot for the current G and k, and show them in ``container``.

    One ``FloatSlider`` (range [-2, 2], step 0.1, initial value 1.0) is created
    per basis function, ``G_slider.value + k_slider.value`` in total. They are
    wired to ``plot_combined_spline`` through ``interactive`` together with the
    two shared sliders, and the resulting widget replaces the current children
    of the module-level ``container``.

    Parameters
    ----------
    G_slider, k_slider : IntSlider
        Shared sliders whose ``.value`` gives the number of grid intervals and
        the spline degree.
    """
    # Every `interactive` observes the widgets it is given. Because the G/k
    # sliders are shared across rebuilds, the widgets being replaced would
    # otherwise stay registered on them and be re-run (with an out-of-date
    # coefficient set) on every later slider change. Detach them first.
    for stale in container.children:
        stale_update = getattr(stale, 'update', None)
        if stale_update is None:
            continue
        for slider in (G_slider, k_slider):
            try:
                slider.unobserve(stale_update, names='value')
            except ValueError:
                # Handler was not registered on this slider; nothing to remove.
                pass

    num_basis = G_slider.value + k_slider.value

    coeff_dict = {
        f'coeff_{i}': FloatSlider(min=-2.0, max=2.0, step=0.1, value=1.0, description=f'Coeff {i}')
        for i in range(num_basis)
    }

    out = interactive(plot_combined_spline, G=G_slider, k=k_slider, **coeff_dict)
    container.children = [out]

# Create the container widget
# It starts empty; display_interactive_plot fills it by assigning
# container.children, so the displayed content can be swapped in place.
container = VBox()

# Shared sliders for the number of grid intervals (G) and the spline degree (k).
# update_coeffs_sliders reads these module-level names directly.
G_slider = IntSlider(min=1, max=10, step=1, value=3, description='G')
k_slider = IntSlider(min=0, max=5, step=1, value=2, description='k')

# Rebuild the coefficient sliders whenever G or k changes, since the number of
# coefficients is G + k.
G_slider.observe(update_coeffs_sliders, names='value')
k_slider.observe(update_coeffs_sliders, names='value')

# Initialize the display
# Show the (still empty) container first, then populate it for the initial
# slider values.
display(container)
display_interactive_plot(G_slider, k_slider)

VBox()