# Kolmogorov Arnold Networks

### Kolmogorov-Arnold Networks (KANs): Overview  
- **Theoretical Foundation**: Based on the Kolmogorov-Arnold representation theorem, which states that any continuous multivariate function on a bounded domain can be written as a finite composition of continuous univariate functions and addition, allowing for a systematic decomposition of complex functions.  
- **Architecture**: Decomposes high-dimensional mappings into a series of simpler one-dimensional functions, enabling efficient and accurate approximation of intricate dependencies in high-dimensional problems.  
- **Applications**: Particularly useful in fields requiring complex function approximations, such as machine learning, physics, and computational mathematics.
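
For reference, the theorem is usually stated as follows. The symbols $\Phi_q$ and $\phi_{q,p}$ are the standard notation for the continuous univariate outer and inner functions and are not used elsewhere in this notebook. For a continuous $f : [0,1]^n \to \mathbb{R}$:

$$
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
$$

Each inner function depends on a single input variable, so all multivariate structure comes from addition and composition. A KAN layer can be read as a learnable version of this pattern, with one univariate function per input-output pair.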

### KANs for Noise Removal in Partial Differential Equations (PDEs)  
- **Denoising Mechanism**: KANs excel at separating structured signals from stochastic noise by learning the underlying deterministic relationships within the data governed by PDEs.  
- **Training Process**: The network is trained to approximate noiseless PDE solutions, leveraging its bias towards representing smooth, well-structured functions.  
- **Advantage**: Provides robust and efficient recovery of clean PDE solutions from noisy datasets, combining theoretical rigor with practical performance.  
- **Practical Benefit**: Enhances the accuracy and reliability of numerical solutions for PDEs in scenarios where data is corrupted by noise.

## Implementation

This implementation of KAN has been taken from https://github.com/Blealtan/efficient-kan. 

This code defines a **Kolmogorov-Arnold Network (KAN) Linear Layer** in PyTorch. It is designed to efficiently model complex relationships by combining linear transformations and B-spline interpolation, as informed by the Kolmogorov-Arnold representation theorem.

### Components
1. **Initialization and Grid Setup**:
   - The grid represents the domain over which splines are interpolated. It is extended by `spline_order` knots on each side so that the B-spline basis is fully defined across the whole grid range.
   - Parameters for the layer include base weights (`base_weight`) for linear transformations and spline weights (`spline_weight`) for B-spline coefficients.

2. **B-Spline Basis Calculation**:
   - The `b_splines` function computes B-spline bases for given inputs, which are then used to interpolate values in the input space.

3. **Spline Coefficient Computation**:
   - The `curve2coeff` function fits spline coefficients to the data by solving a least-squares problem, ensuring that the spline interpolation matches the input-output relationship.

4. **Forward Pass**:
   - Combines the linear transformation of the input using `base_weight` and the spline interpolation using `spline_weight` to produce the output.

5. **Grid Adaptation**:
   - The `update_grid` method adjusts the spline grid based on input distribution, allowing the model to adapt dynamically to new data.

6. **Regularization**:
   - A custom loss (`regularization_loss`) penalizes the spline weights to control overfitting and ensure smooth approximations. This includes terms for weight sparsity and entropy regularization.

This implementation is designed to be flexible and efficient, enabling it to approximate and learn high-dimensional functions, especially useful in tasks like function approximation, noise filtering, or PDE solutions.
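
The basis computation in `b_splines` follows the Cox-de Boor recursion. Below is a minimal pure-Python sketch for a single scalar input, assuming the default `grid_size=5`, `spline_order=3`, and `grid_range=[-1, 1]`. The helper names `make_grid` and `b_spline_bases` are illustrative and not part of the implementation.

```python
def make_grid(grid_size=5, spline_order=3, grid_range=(-1.0, 1.0)):
    """Uniform knots, extended by spline_order knots on each side."""
    h = (grid_range[1] - grid_range[0]) / grid_size
    return [
        grid_range[0] + i * h
        for i in range(-spline_order, grid_size + spline_order + 1)
    ]


def b_spline_bases(x, grid, spline_order=3):
    """Evaluate all B-spline basis functions at the scalar x."""
    # Degree 0: indicator of the knot interval that contains x
    bases = [
        1.0 if grid[i] <= x < grid[i + 1] else 0.0
        for i in range(len(grid) - 1)
    ]
    # Raise the degree one step at a time (Cox-de Boor recursion)
    for k in range(1, spline_order + 1):
        bases = [
            (x - grid[i]) / (grid[i + k] - grid[i]) * bases[i]
            + (grid[i + k + 1] - x)
            / (grid[i + k + 1] - grid[i + 1])
            * bases[i + 1]
            for i in range(len(grid) - k - 1)
        ]
    return bases


grid = make_grid()
bases = b_spline_bases(0.3, grid)
print(len(bases), sum(bases))
```

For an input inside the grid range, the `grid_size + spline_order` basis values sum to one, and an input outside the extended grid yields all zeros.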



In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub

shusrith_burgers_noisy_path = kagglehub.dataset_download("shusrith/burgers-noisy")

print("Data source import complete.")

## Layers

In [None]:
import torch
import torch.nn.functional as F
import math
import torch.nn as nn


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(
            self.base_weight, a=math.sqrt(5) * self.scale_base
        )
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a simplified stand-in for the original L1 regularization described in the
        paper, since the original one requires computing absolute values and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want a memory-efficient implementation.

        The L1 regularization is now computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
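
The two terms of `regularization_loss` can be worked through by hand on a tiny weight tensor. The sketch below uses plain Python lists in place of tensors. The function name `regularization_terms` and the example weights are made up for illustration.

```python
import math


def regularization_terms(spline_weight):
    """spline_weight: nested list shaped (out_features, in_features, n_coeff)."""
    # Mean absolute coefficient per (output, input) edge
    l1_per_edge = [
        sum(abs(c) for c in coeffs) / len(coeffs)
        for row in spline_weight
        for coeffs in row
    ]
    # Activation term: total magnitude over all edges
    activation = sum(l1_per_edge)
    # Entropy term: computed on the normalized edge magnitudes
    probs = [v / activation for v in l1_per_edge]
    entropy = -sum(p * math.log(p) for p in probs)
    return activation, entropy


# One output, two inputs, two coefficients per edge (hypothetical values)
weights = [[[1.0, -1.0], [3.0, -3.0]]]
activation, entropy = regularization_terms(weights)
print(activation, entropy)
```

Here the edge magnitudes are 1 and 3, so the normalized values are 0.25 and 0.75. The entropy term is largest when all edges have equal magnitude, which is why penalizing it encourages a few dominant edges.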

## Complete Model

- Defines a complete Kolmogorov-Arnold Network (KAN) as a sequence of KANLinear layers for multivariate function approximation.  
- Accepts hyperparameters like grid size, spline order, scaling factors, activation functions, and grid range to configure the architecture.  
- Uses a `ModuleList` to stack multiple KANLinear layers based on the input hidden layer configuration (`layers_hidden`).  
- Implements a forward pass that processes input through each KANLinear layer, optionally updating the grid dynamically for adaptability.  
- Provides a method to compute regularization loss by aggregating the regularization terms from all layers to promote sparsity and smoothness.
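
The way `layers_hidden` is turned into layers, and how many trainable parameters each `KANLinear` layer carries, can be sketched with plain arithmetic. The helper names and the widths `[4, 8, 2]` are hypothetical. The counts follow the parameter shapes declared in `KANLinear.__init__` with the standalone spline scaler enabled.

```python
def layer_pairs(layers_hidden):
    """Consecutive (in_features, out_features) pairs, as built in KAN.__init__."""
    return list(zip(layers_hidden, layers_hidden[1:]))


def kan_linear_param_count(in_features, out_features,
                           grid_size=5, spline_order=3,
                           standalone_scale=True):
    """Number of trainable parameters in one KANLinear layer."""
    edges = in_features * out_features
    base = edges                                  # base_weight
    spline = edges * (grid_size + spline_order)   # spline_weight
    scaler = edges if standalone_scale else 0     # spline_scaler
    return base + spline + scaler


widths = [4, 8, 2]  # hypothetical layer widths
pairs = layer_pairs(widths)
counts = [kan_linear_param_count(i, o) for i, o in pairs]
print(pairs, counts, sum(counts))
```

With the default grid, every input-output edge carries ten parameters, so the parameter count grows with the product of adjacent layer widths.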

In [2]:
class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

### Regularisation

- Extends the KAN architecture by integrating dropout layers to improve regularization and prevent overfitting.  
- Uses a `ModuleList` to stack layers, where each layer includes a KAN module, batch normalization, SiLU activation, and dropout.  
- Configurable dropout probability allows control over the level of regularization applied during training.  
- Implements a forward pass that processes input through each stacked layer sequentially, applying all operations in the layer pipeline.

In [None]:
class KANWithDropout(torch.nn.Module):
    def __init__(self, layers_hidden, dropout_prob=0.3):
        super(KANWithDropout, self).__init__()
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                nn.Sequential(
                    # KAN expects a list of layer widths, not two separate integers
                    KAN([in_features, out_features]),
                    nn.BatchNorm1d(out_features),
                    nn.SiLU(),
                    nn.Dropout(dropout_prob),
                )
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

## Custom Loss Function

- A custom loss function combining Mean Squared Error (MSE) and L1 loss for balanced optimization between precision and robustness.  
- The weighting factor `alpha` controls the contribution of each loss term, allowing flexibility based on the problem's requirements.  
- Uses the `forward` method to compute the weighted sum of MSE and L1 losses between the model's output and the target values.  
- Provides a smooth and robust loss function useful for tasks where both small errors and outlier handling are important.
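
The weighted sum can be checked by hand on a small example. This sketch uses plain Python lists rather than tensors. The function name `combined_loss` and the sample values are illustrative only.

```python
def combined_loss(output, target, alpha=0.5):
    """alpha * MSE + (1 - alpha) * L1, both averaged over all elements."""
    n = len(output)
    mse = sum((o - t) ** 2 for o, t in zip(output, target)) / n
    l1 = sum(abs(o - t) for o, t in zip(output, target)) / n
    return alpha * mse + (1 - alpha) * l1


output = [1.0, 2.0, 4.0]  # hypothetical predictions
target = [1.0, 3.0, 2.0]  # hypothetical ground truth
# Errors are 0, -1, 2, so MSE = 5/3 and L1 = 1
print(combined_loss(output, target))
```

Setting `alpha=1.0` recovers pure MSE and `alpha=0.0` recovers pure L1, so the large error of 2 contributes quadratically in the first case and only linearly in the second.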

In [None]:
class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, output, target):
        return self.alpha * self.mse(output, target) + (1 - self.alpha) * self.l1(
            output, target
        )

Note: the dataset contains 1000 samples. The loading cell below reads only the first 100; increase the slice in its `for` loop as far as the available GPU memory (VRAM) allows.

## Data Loading

The clean and noisy solution fields are read from the HDF5 file and stacked into NumPy arrays for training.

In [4]:
import h5py
import numpy as np
with h5py.File("/kaggle/input/burgers-noisy/simulation_data.h5", "r") as f:
    a = list(f.keys())
    clean = []
    noisy = []
    for i in a[:100]:
        clean.append(f[i]["clean"][:])
        noisy.append(f[i]["noisy"][:])

clean = np.array(clean)
noisy = np.array(noisy)

## Data Preparation

This script processes noisy and clean data for training a model, preparing it for PyTorch's DataLoader.

- **Data Reshaping**:
  - The noisy and clean datasets are flattened from `(num_samples, height, width)` into `(num_samples, height × width)` for compatibility with PyTorch models.

- **Tensor Conversion**:
  - The flattened arrays are converted into PyTorch tensors (`X_tensor` for noisy data and `Y_tensor` for clean data) of type `float32`.

- **Dataset Creation**:
  - A `TensorDataset` pairs the noisy data (`X_tensor`) with the clean target data (`Y_tensor`) for supervised learning.

- **Train/Test Split**:
  - The dataset is split into training and testing subsets, with 80% for training and 20% for testing, using PyTorch's `random_split`.

- **DataLoaders**:
  - Training and testing datasets are wrapped into DataLoaders with a batch size of 2, enabling shuffled training and sequential testing.
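
The resulting sizes can be checked with plain arithmetic. The sketch assumes 100 loaded samples (the slice used in the loading cell), each of shape 201 × 1024 (the shape used when predictions are reshaped later in the notebook). The variable names are illustrative.

```python
import math

num_samples, height, width = 100, 201, 1024  # assumed shapes, see above
batch_size = 2

# Flattening turns each (height, width) field into one feature vector
features_per_sample = height * width

# 80/20 split, mirroring the int() truncation used for train_size
train_size = int(0.8 * num_samples)
test_size = num_samples - train_size

# Number of batches each DataLoader yields per epoch
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)

print(features_per_sample, train_size, test_size, train_batches, test_batches)
```

Under these assumptions the model's input and output width is 205824, which matches the tensor shape printed in the prediction section.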

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import h5py
import numpy as np


print("Noisy data shape:", noisy.shape)

num_samples, height, width = noisy.shape
noisy_flattened = noisy.reshape(num_samples, -1)
clean_flattened = clean.reshape(num_samples, -1)

X_tensor = torch.tensor(noisy_flattened, dtype=torch.float32)
Y_tensor = torch.tensor(clean_flattened, dtype=torch.float32)

dataset = TensorDataset(X_tensor, Y_tensor)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

trainloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

print("Trainloader size:", len(trainloader.dataset))
print("Testloader size:", len(testloader.dataset))

## Model Definition

- **Model**:
  - A `KANWithDropout` model is initialized with the input dimension (`input_dim`), three hidden layers of sizes 256, 64, and 256, and a final output layer of size `input_dim` to match the input/output shape.
  - The model is moved to the appropriate device (GPU if available, otherwise CPU).

- **Optimizer**:
  - Uses the `AdamW` optimizer, which combines the benefits of Adam with weight decay regularization to reduce overfitting.  
  - Learning rate is set to `1e-3`, and weight decay to `1e-4`.

- **Learning Rate Scheduler**:
  - A `ReduceLROnPlateau` scheduler reduces the learning rate by a factor of 0.5 when the monitored metric (here, the validation loss) has not improved for more than 4 consecutive epochs (`patience=4`).
  - Minimum learning rate (`min_lr`) is set to `1e-6`, and verbose mode is enabled for logging changes.

- **Loss Function**:
  - Combines MSE and L1 losses using the `CombinedLoss` class, balancing precision and robustness during optimization.
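
The effect of `factor`, `patience`, and `min_lr` can be illustrated with a simplified pure-Python simulation. This is not PyTorch's implementation: it treats any strict decrease as an improvement and ignores the scheduler's `threshold` and `cooldown` options. The function name and the loss sequence are hypothetical.

```python
def simulate_plateau_schedule(val_losses, lr=1e-3, factor=0.5,
                              patience=4, min_lr=1e-6):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            # Improvement: remember it and reset the counter
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            # More than `patience` stagnant epochs: shrink the rate
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history


# Two improving epochs followed by five stagnant ones (hypothetical)
losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]
print(simulate_plateau_schedule(losses))
```

In this simplified model, four stagnant epochs are tolerated and the rate is halved on the fifth. Repeated reductions stop once the rate reaches `min_lr`.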

In [None]:
input_dim = noisy_flattened.shape[1]  
model = KANWithDropout([input_dim, 256, 64, 256, input_dim])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=4,
    min_lr=1e-6,
    verbose=True,
)

criterion = CombinedLoss()

## Training Loop

This code trains the `KANWithDropout` model for 200 epochs, evaluates its validation loss, adjusts the learning rate using a scheduler, and saves the trained model.

- **Training Loop**:
  - Iterates over 200 epochs.
  - For each batch:
    - Moves data (`inputs` and `targets`) to the appropriate device.
    - Clears previous gradients using `optimizer.zero_grad()`.
    - Computes model outputs, loss, and gradients, and updates weights using `optimizer.step()`.
  - After each epoch, switches the model to evaluation mode and computes the average validation loss over the test loader.


- **Learning Rate Adjustment**:
  - The learning rate scheduler (`ReduceLROnPlateau`) adjusts the learning rate based on the validation loss, reducing it when the loss plateaus.

- **Model Saving**:
  - Saves the trained model's state dictionary (`model.state_dict()`) to a file named `"kan_model.pth"`, allowing for later use or deployment.


In [None]:
for epoch in range(200):
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            val_loss += criterion(output, targets).item()
    val_loss /= len(testloader)

    # Update learning rate
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Val Loss: {val_loss}")

# Save the model
torch.save(model.state_dict(), "kan_model.pth")

## Prediction

A single sample (key `"10"`) is loaded, flattened, and moved to the device so it can be fed to the model.

In [29]:
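with h5py.File("/kaggle/input/burgers-noisy/simulation_data.h5", "r") as f:
    a = f["10"]["noisy"][:]  # noisy solution field
    b = f["10"]["clean"][:]  # clean reference field
    x = f["coords"]["x-coordinates"][:]  # spatial grid
# The context manager closes the file, so no explicit f.close() is needed.
a = torch.tensor(a, dtype=torch.float32).to(device)
a = a.view(1, -1)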
a.shape

torch.Size([1, 205824])

### Run Predictions

Predictions are made and the output is reshaped for visualisation.

In [None]:
model.eval()  # disable dropout and use running batch-norm statistics
with torch.no_grad():
    u = model(a).cpu()  # `a` is already on the model's device

u = u.numpy().reshape((201, 1024))

## Visualisation

The `visualize_burgers` function generates an animated GIF of the Burgers equation's solution over time. It takes the spatial coordinates (`xcrd`) and the simulation data (`data`). The function iterates through the time steps of the solution, plots each one, and stores the frames for the animation. It then creates an animation using `matplotlib.animation.ArtistAnimation` and saves it as `burgerCombo.gif` at 15 frames per second.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from tqdm import tqdm


def visualize_burgers(xcrd, data):
    """
    Animate a Burgers equation solution and save it as a GIF.

    Args:
    xcrd : spatial coordinates (x-axis values)
    data : solution array with one row per time step
    """

    fig, ax = plt.subplots()

    ims = []

    # Draw one line artist per time step; each becomes one animation frame
    for i in tqdm(range(data.shape[0])):
        im = ax.plot(xcrd, data[i].squeeze(), animated=True, color="blue")
        ims.append([im[0]])

    # Animate the plot
    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)

    writer = animation.PillowWriter(fps=15, bitrate=1800)
    ani.save("burgerCombo.gif", writer=writer)
    plt.close(fig)

visualize_burgers(x[:], u)