In [28]:
import sys
import torch
sys.path.append('../')

from model.kan.activations import calculate_B_spline_basis_functions

# Compare with pykan


In [29]:
# Problem sizes shared by the cells below.
num_splines, num_samples = 5, 10
num_knots = 7
degree = 3
k = degree  # short alias handed to the basis-function routines

# Random evaluation points, one row per spline (no seed is set in this cell).
x = torch.randn(num_splines, num_samples)
# Every spline gets the same integer knot vector 1, 2, ..., num_knots.
grid = torch.arange(1, num_knots + 1).unsqueeze(0).repeat(num_splines, 1)
# Random coefficients: num_knots + degree - 1 values per spline.
coef = torch.randn(num_splines, num_knots + degree - 1)


### Calculate B-spline basis functions

In [30]:
def calculate_B_spline_basis_functions(x, grid, k):
    """
    Evaluate all degree-k B-spline basis functions at the given points.

    Each spline's knot vector is first padded with k extra knots on either
    side, spaced by the mean spacing of the original knots. The basis is then
    built bottom-up with the Cox-de Boor recursion, starting from the
    degree-0 interval indicators.

    Args:
        x: torch.Tensor (num splines, num samples)
        grid: torch.Tensor (num splines, num knots)
        k: int (degree of the spline)
    Returns:
        torch.Tensor (num splines, num knots + k - 1, num samples).
        For k = 0 the boolean indicator tensor is returned unchanged.
    """
    # Mean spacing of the original knots, kept as a column: (num splines, 1)
    step = (grid[:, -1:] - grid[:, :1]) / (grid.shape[1] - 1)

    # Add one knot per side, k times over
    knots = grid
    for _ in range(k):
        knots = torch.cat((knots[:, :1] - step, knots, knots[:, -1:] + step), dim=1)

    # Shape both operands so they broadcast against each other
    pts = x.unsqueeze(dim=1)  # (num splines, 1, num samples)
    knots = knots.unsqueeze(dim=2)  # (num splines, num knots + 2k, 1)

    # Degree 0: indicator of the half-open interval [knot_i, knot_{i+1})
    basis = (knots[:, :-1] <= pts) * (pts < knots[:, 1:])

    # Each pass raises the degree by one and drops one basis function
    for p in range(1, k + 1):
        rising = (pts - knots[:, :-(p + 1)]) / (knots[:, p:-1] - knots[:, :-(p + 1)]) * basis[:, :-1]
        falling = (knots[:, (p + 1):] - pts) / (knots[:, (p + 1):] - knots[:, 1:-p]) * basis[:, 1:]
        basis = rising + falling  # (num splines, num knots + 2k - p - 1, num samples)

    return basis  # (num splines, num knots + k - 1, num samples)

def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate x on B-spline bases.

    The degree-0 indicator functions are computed first and the degree is then
    raised one step at a time (Cox-de Boor recursion, written as a loop).

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (number of splines, number of samples)
        grid : 2D torch.tensor
            grids, shape (number of splines, number of grid points)
        k : int
            the piecewise polynomial order of splines. Must be >= 0.
        extend : bool
            If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
        device : str
            device the grid and inputs are moved to before evaluation

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (number of splines, number of B-spline bases (coefficients), number of samples).
            The number of B-spline bases = number of grid points + k - 1 when extend is True,
            and number of grid points - k - 1 when extend is False.
            For k = 0 the result is the boolean interval-indicator tensor.

    Raises:
    -------
        ValueError
            if k is negative.

    Example
    -------
    >>> num_spline = 5
    >>> num_sample = 100
    >>> num_grid_interval = 10
    >>> k = 3
    >>> x = torch.normal(0,1,size=(num_spline, num_sample))
    >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
    >>> B_batch(x, grids, k=k).shape
    torch.Size([5, 13, 100])
    '''
    if k < 0:
        # A negative order previously recursed without bound; fail fast instead.
        raise ValueError(f"k must be a non-negative integer, got {k}")

    # x shape: (size, x); grid shape: (size, grid)
    if extend:
        # Pad k points to the left and right, spaced by the mean spacing of the
        # original grid (computed once, before any padding).
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        for _ in range(k):
            grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
            grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

    # Add singleton axes so grid and x broadcast against each other.
    grid = grid.unsqueeze(dim=2).to(device)  # (size, grid, 1)
    x = x.unsqueeze(dim=1).to(device)  # (size, 1, x)

    # Order 0: indicator of the half-open interval [grid_i, grid_{i+1}).
    value = (x >= grid[:, :-1]) * (x < grid[:, 1:])

    # Raise the order one step at a time; each step drops one basis function.
    for p in range(1, k + 1):
        value = (x - grid[:, :-(p + 1)]) / (grid[:, p:-1] - grid[:, :-(p + 1)]) * value[:, :-1] + (
                    grid[:, p + 1:] - x) / (grid[:, p + 1:] - grid[:, 1:(-p)]) * value[:, 1:]
    return value

In [31]:
# Result from B_batch, with the grid extended by k points on both ends, on CPU.
a = B_batch(x, grid, k, extend=True, device='cpu')
# Result from calculate_B_spline_basis_functions for the same x, grid and k.
b = calculate_B_spline_basis_functions(x, grid, k)

In [32]:
# Display the B_batch result.
a

tensor([[[6.6547e-01, 2.1843e-01, 4.4455e-01, 5.5538e-01, 6.4197e-01,
          5.6479e-01, 4.8754e-01, 5.6936e-01, 2.8722e-01, 6.2921e-01],
         [1.8470e-01, 6.5802e-01, 1.4753e-02, 4.1783e-02, 2.5993e-01,
          3.8253e-01, 2.2563e-02, 4.7336e-02, 6.2822e-01, 2.8539e-01],
         [7.0824e-06, 1.2341e-01, 0.0000e+00, 0.0000e+00, 7.3563e-04,
          7.2429e-03, 0.0000e+00, 0.0000e+00, 8.3083e-02, 1.4203e-03],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0

In [33]:
# Display the calculate_B_spline_basis_functions result.
b

tensor([[[6.6547e-01, 2.1843e-01, 4.4455e-01, 5.5538e-01, 6.4197e-01,
          5.6479e-01, 4.8754e-01, 5.6936e-01, 2.8722e-01, 6.2921e-01],
         [1.8470e-01, 6.5802e-01, 1.4753e-02, 4.1783e-02, 2.5993e-01,
          3.8253e-01, 2.2563e-02, 4.7336e-02, 6.2822e-01, 2.8539e-01],
         [7.0824e-06, 1.2341e-01, 0.0000e+00, 0.0000e+00, 7.3563e-04,
          7.2429e-03, 0.0000e+00, 0.0000e+00, 8.3083e-02, 1.4203e-03],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0

In [34]:
# Element-wise difference of the two results; all zeros means they agree exactly.
a - b

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0