In [2]:
import matplotlib.pyplot as plt
import torch
import numpy as np

import numpy as np
import torch
import torch.nn as nn


class ReLUKANLayer(nn.Module):
    def __init__(self, input_size: int, g: int, k: int, output_size: int, train_ab: bool = True):
        super().__init__()
        self.g, self.k, self.r = g, k, 4*g*g / ((k+1)*(k+1))
        self.input_size, self.output_size = input_size, output_size
        phase_low = np.arange(-k, g) / g
        phase_height = phase_low + (k+1) / g
        self.phase_low = nn.Parameter(torch.Tensor(np.array([phase_low for i in range(input_size)])),
                                      requires_grad=train_ab)
        self.phase_height = nn.Parameter(torch.Tensor(np.array([phase_height for i in range(input_size)])),
                                         requires_grad=train_ab)
        self.equal_size_conv = nn.Conv2d(1, output_size, (g+k, input_size))
    def forward(self, x):
        print(x.shape, self.phase_low.shape)
        x1 = torch.relu(x - self.phase_low)
        x2 = torch.relu(self.phase_height - x)
        x = x1 * x2 * self.r
        x = x * x
        x = x.reshape((len(x), 1, self.g + self.k, self.input_size))
        x = self.equal_size_conv(x)
        x = x.reshape((len(x), self.output_size, 1))
        return x


class ReLUKAN(nn.Module):
    def __init__(self, width, grid, k):
        super().__init__()
        self.width = width
        self.grid = grid
        self.k = k
        self.rk_layers = []
        for i in range(len(width) - 1):
            self.rk_layers.append(ReLUKANLayer(width[i], grid, k, width[i+1]))
            # if len(width) - i > 2:
            #     self.rk_layers.append()
        self.rk_layers = nn.ModuleList(self.rk_layers)

    def forward(self, x):
        for rk_layer in self.rk_layers:
            x = rk_layer(x)
        # x = x.reshape((len(x), self.width[-1]))
        return x

    
x = torch.Tensor([np.arange(0, 1024) / 1024]).repeat(8,1).T
shape = x.shape[:-1]
x.view(*shape, -1)
y = torch.sin(5*torch.pi*x)
print(x.shape)
rk = ReLUKANLayer(1, 5, 3, 2)
y = rk(x)
print(y.shape)

torch.Size([1024, 8])
torch.Size([1024, 8]) torch.Size([1, 8])
torch.Size([1024, 2, 1])


In [6]:
    def rbf_kernel(X, Y, gamma=-1, ad=1):
        # X and Y should be tensors with shape (batch_size, num_channels, height, width)
        # gamma is a hyperparameter controlling the width of the RBF kernel

        # Reshape X and Y to have shape (batch_size, num_channels*height*width)
        X_flat = X.view(X.size(0), -1)
        Y_flat = Y.view(Y.size(0), -1)

        # Compute the pairwise squared Euclidean distances between the samples
        with torch.cuda.amp.autocast():
            dists = torch.cdist(X_flat, Y_flat, p=2)**2
        print(dists.shape)
        if gamma <0: # use median trick
            gamma = torch.median(dists)
            gamma = torch.sqrt(0.5 * gamma / np.log(dists.size(0) + 1))
            gamma = 1 / (2 * gamma**2)
            # print(gamma)

        gamma = gamma * ad 
        # gamma = torch.max(gamma, torch.tensor(1e-3))
        # Compute the RBF kernel using the squared distances and gamma
        K = torch.exp(-gamma * dists)
        dK = -2 * gamma * K.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * (X.unsqueeze(1) - Y.unsqueeze(0))
        dK_dX = torch.sum(dK, dim=1)

        return K, dK_dX
    
a = torch.ones(100,2)
b = torch.ones(100,2)
K , dk = rbf_kernel(a,b)
print(K.shape, dk.shape)

torch.Size([100, 100]) torch.Size([100, 100, 100, 2])
