In [1]:
import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(
            self.base_weight, a=math.sqrt(5) * self.scale_base
        )
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from tqdm import tqdm


def visualize_burgers(xcrd, data, path):
    """
    This function animates the Burgers equation

    Args:
    path : path to the desired file
    param: PDE parameter of the data shard to be visualized
    """
    fig, ax = plt.subplots()
    ims = []

    for i in tqdm(range(data.shape[0])):
        if i == 0:
            im = ax.plot(xcrd, data[i].squeeze(), animated=True, color="blue")
        else:
            im = ax.plot(xcrd, data[i].squeeze(), animated=True, color="blue")
        ims.append([im[0]])

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)

    writer = animation.PillowWriter(fps=15, bitrate=1800)
    ani.save(path, writer=writer)
    plt.close(fig)

In [3]:
import torch
import h5py
import numpy as np
from tqdm import tqdm

# Adjust the input dimensions to match the data
model = KAN([201, 512, 512, 1024, 512, 512, 201]).to("cuda")

with h5py.File("/home/pes1ug22am100/Documents/Research and Experimentation/NoisyICML/pinns-inverse/WaveEquation/wave_solutions_new.h5", "r") as f:
    l = list(f.keys())
    d = []
    for i in l:
        if i != "coords":
            d.append([f[i]["clean"][:], f[i]["noisy"][:]])
    d = np.array(d)
    f.close()

clean = torch.Tensor(d[:, 0, :, :])
train = torch.Tensor(d[:, 1, :, :])

clean = clean.squeeze(1)  # Shape will be [1000, 256, 500]
train = train.squeeze(1)  # Shape will be [1000, 256, 500]

# clean, train = clean.permute(0, 2, 1), train.permute(0, 2, 1)
clean = clean.reshape(-1, 201).to("cuda")
train = train.reshape(-1, 201).to("cuda")

loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(train, clean)
loader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10
)
lr = optimizer.param_groups[0]["lr"]
for i in range(60):
    for j in tqdm(loader):
        inp, out = j
        optimizer.zero_grad()
        a = model(inp)
        l = loss(a, out)
        l.backward()
        optimizer.step()
    scheduler.step(l)
    if lr != optimizer.param_groups[0]["lr"]:
        lr = optimizer.param_groups[0]["lr"]
        print("Learning rate changed to", lr)
    print(i, l.item())
torch.save(model.state_dict(), "/home/pes1ug22am100/Documents/Research and Experimentation/NoisyICML/kan/mightDelete/model.pth")
a = train[:1024].view(-1, 201)
print(a.shape)
a = model(a)

print(torch.mean((clean[:1024] - a) ** 2))
print(torch.mean((train[:1024] - clean[:1024]) ** 2))
visualize_burgers([i for i in range(1024)], a.cpu().detach().T, "test.gif")

100%|██████████| 1924/1924 [04:37<00:00,  6.94it/s]


0 0.014138316735625267


100%|██████████| 1924/1924 [04:44<00:00,  6.76it/s]


1 0.014184702187776566


100%|██████████| 1924/1924 [02:24<00:00, 13.27it/s]


2 0.013600184582173824


100%|██████████| 1924/1924 [02:25<00:00, 13.25it/s]


3 0.014096741564571857


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


4 0.013004212640225887


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


5 0.011590304784476757


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


6 0.011693017557263374


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


7 0.010753949172794819


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


8 0.011074596084654331


100%|██████████| 1924/1924 [02:25<00:00, 13.22it/s]


9 0.00962010957300663


100%|██████████| 1924/1924 [02:25<00:00, 13.22it/s]


10 0.008078177459537983


100%|██████████| 1924/1924 [02:25<00:00, 13.23it/s]


11 0.00831447634845972


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


12 0.007196575868874788


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


13 0.006040601991117001


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


14 0.00482280133292079


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


15 0.004131890367716551


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


16 0.0034927872475236654


100%|██████████| 1924/1924 [02:25<00:00, 13.24it/s]


17 0.005280336365103722


100%|██████████| 1924/1924 [02:38<00:00, 12.16it/s]


18 0.003863360034301877


100%|██████████| 1924/1924 [03:26<00:00,  9.32it/s]


19 0.004547491203993559


100%|██████████| 1924/1924 [02:36<00:00, 12.32it/s]


20 0.0029113583732396364


100%|██████████| 1924/1924 [02:24<00:00, 13.29it/s]


21 0.0032580639235675335


100%|██████████| 1924/1924 [02:24<00:00, 13.27it/s]


22 0.0030350841116160154


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


23 0.003528860630467534


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


24 0.0025681606493890285


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


25 0.003951083403080702


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


26 0.003114726860076189


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


27 0.013377025723457336


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


28 0.012607906013727188


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


29 0.011678586713969707


100%|██████████| 1924/1924 [02:27<00:00, 13.09it/s]


30 0.010905196890234947


100%|██████████| 1924/1924 [02:59<00:00, 10.71it/s]


31 0.010032049380242825


100%|██████████| 1924/1924 [02:44<00:00, 11.66it/s]


32 0.008314906619489193


100%|██████████| 1924/1924 [02:24<00:00, 13.29it/s]


33 0.007064716424793005


100%|██████████| 1924/1924 [02:24<00:00, 13.27it/s]


34 0.0077056847512722015


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


Learning rate changed to 0.0001
35 0.006596013437956572


100%|██████████| 1924/1924 [02:25<00:00, 13.27it/s]


36 0.0046366057358682156


100%|██████████| 1924/1924 [02:25<00:00, 13.27it/s]


37 0.0037952547427266836


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


38 0.0033758075442165136


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


39 0.0034683283884078264


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


40 0.003132096491754055


100%|██████████| 1924/1924 [02:26<00:00, 13.16it/s]


41 0.003151949029415846


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


42 0.0025077257305383682


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


43 0.0029688102658838034


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


44 0.0025208669248968363


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


45 0.0026845294050872326


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


46 0.0024569907691329718


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


47 0.0023416313342750072


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


48 0.002220829948782921


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


49 0.0018498350400477648


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


50 0.001486238557845354


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


51 0.0015586229274049401


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


52 0.0020558235701173544


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


53 0.0017984716687351465


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


54 0.0015311883762478828


100%|██████████| 1924/1924 [02:25<00:00, 13.25it/s]


55 0.0014333759900182486


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


56 0.0021670758724212646


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


57 0.0014184147585183382


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


58 0.0011167527409270406


100%|██████████| 1924/1924 [02:25<00:00, 13.26it/s]


59 0.0012214204762130976
torch.Size([1024, 201])
tensor(0.0025, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.4166, device='cuda:0')


100%|██████████| 201/201 [00:01<00:00, 190.84it/s]


In [4]:
torch.cuda.empty_cache()