In [None]:
import torch
from scipy.special import spherical_jn

In [None]:
def _poly_cutoff(x: torch.Tensor, r_max: float, p: float = 6.0) -> torch.Tensor:
    factor = 1.0 / float(r_max)
    x = x * factor

    out = 1.0
    out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p))
    out = out + (p * (p + 2.0) * torch.pow(x, p + 1.0))
    out = out - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0))

    return out * (x < 1.0)

def _bessel(x: torch.Tensor, r_max: float, num_basis: int = 8):
    prefactor = 2.0 / r_max

    bessel_weights = torch.linspace(start=1.0, end=num_basis, steps=num_basis) * torch.pi

    numerator = torch.sin(bessel_weights * x.unsqueeze(-1) / r_max)

    return prefactor * (numerator / x.unsqueeze(-1))

def _bessel_new(x: torch.Tensor, num_basis: int = 8):
    bessels = []
    for n in range(num_basis):
        bessels.append(spherical_jn(n, x))

    return torch.stack(bessels, dim=-1)

In [None]:
import torch

from torch import nn

class BesselBasis(nn.Module):
    """Tabulated spherical Bessel basis j_0 .. j_{num_basis-1} on [0, r_max].

    Values of ``scipy.special.spherical_jn`` are precomputed once on a uniform
    grid. ``forward`` looks up one tabulated row per input value and scales it
    by a per-order weight vector.

    Parameters
    ----------
    r_max : float
        Upper end of the tabulated grid.
    num_basis : int
        Number of orders, n = 0 .. num_basis - 1.
    accuracy : float
        Requested grid spacing. The grid has ``max(int(r_max / accuracy), 2)``
        points spanning [0, r_max] inclusive, so the realized spacing is
        ``r_max / (num_points - 1)``, slightly coarser than ``accuracy``.
    trainable : bool
        If True the per-order weights (initialised to ones) are an
        ``nn.Parameter``; otherwise they are a buffer.
    """

    def __init__(self, r_max, num_basis=8, accuracy=0.1, trainable=True):
        super().__init__()

        self.trainable = trainable
        self.num_basis = num_basis
        # At least two points so the grid and the lookup are well defined.
        self.num_points = max(int(r_max / accuracy), 2)

        r_values = torch.linspace(0., r_max, self.num_points)
        # A buffer (not a plain attribute) so .to(device) moves it along with
        # bessel_values; non-persistent keeps the state_dict keys unchanged.
        self.register_buffer("r_values", r_values, persistent=False)

        # spherical_jn broadcasts orders (num_basis,) against radii
        # (num_points, 1), giving the (num_points, num_basis) table directly.
        orders = list(range(num_basis))
        table = spherical_jn(orders, r_values.numpy()[:, None])
        self.register_buffer("bessel_values", torch.from_numpy(table).float())

        bessel_weights = (
            torch.ones(self.num_basis, dtype=torch.get_default_dtype())
        )
        if self.trainable:
            self.bessel_weights = nn.Parameter(bessel_weights)
        else:
            self.register_buffer("bessel_weights", bessel_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate Bessel Basis for input x.

        Parameters
        ----------
        x : torch.Tensor
            Input radii, any shape.

        Returns
        -------
        torch.Tensor
            Shape ``(*x.shape, num_basis)``. Each input uses the tabulated row
            at the first grid point >= x; inputs beyond r_max use the last row.
            Rows are scaled by ``bessel_weights``.
        """
        # searchsorted (side='left') returns the first index with
        # r_values[i] >= x. Values above r_max would map to num_points, which
        # is out of range, so clamp to the last grid row.
        idcs = torch.searchsorted(self.r_values, x.to(self.r_values.dtype))
        idcs = idcs.clamp(max=self.r_values.numel() - 1)
        # Broadcasting multiply == einsum("i,ji->ji") for 1-D x, and also
        # supports inputs of any rank.
        return self.bessel_values[idcs] * self.bessel_weights

# Smoke test: five radii inside the cutoff map to a (5, num_basis) feature matrix.
b = BesselBasis(r_max=6.0)
x = torch.linspace(start=2.0, end=3.0, steps=5)
b(x).shape

In [None]:
# Tabulate j_0 .. j_4 on 100 points in [0, 10]; column n holds j_n(x).
x = torch.linspace(start=0.0, end=10.0, steps=100)
_bessel_new(x, num_basis=5)

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Compare the polynomial cutoff envelope with the first few spherical Bessel functions.
r_max = 6.

x = np.linspace(0., 10., 100)
x_torch = torch.from_numpy(x)
cutoff_values = _poly_cutoff(x_torch, r_max, p=12).numpy()
jn_values = _bessel_new(x_torch, num_basis=5).numpy()

fig, ax = plt.subplots()
ax.plot(x, cutoff_values, "k--", label=f"polynomial cutoff (p=12, r_max={r_max:g})")
for n in range(jn_values.shape[-1]):
    ax.plot(x, jn_values[:, n], label=rf"$j_{n}$")
ax.set_title("Polynomial cutoff and spherical Bessel functions")
ax.set_xlabel("r")
ax.set_ylabel("value")
ax.legend()
plt.show()

In [None]:
import torch
from scipy.special import spherical_jn
import matplotlib.pyplot as plt

# Define function to generate spherical Bessel functions
def spherical_bessel_function(n, kr):
    return spherical_jn(n, kr)

# Generate sample values for r
r_values = torch.linspace(0, 10, 100)

# Define parameters for multiple basis functions
k_values = [1.0, 2.0, 3.0]  # radial frequency scales; add more values as needed
n_max = 2  # Maximum order of spherical Bessel functions

# Plot j_n(k r) for every (k, n) pair.
fig, ax = plt.subplots(figsize=(8, 6))
for k in k_values:
    # Scale the argument by k so each k yields a distinct curve j_n(k r).
    kr_values = (k * r_values).numpy()
    for n in range(n_max + 1):
        basis_values = spherical_bessel_function(n, kr_values)
        ax.plot(r_values.numpy(), basis_values, label=f'n={n}, k={k}')

ax.set_title('Multiple Spherical Bessel Functions')
ax.set_xlabel('r')
ax.set_ylabel('Basis Function Value')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
import torch
from scipy.special import spherical_jn
import numpy as np

# Define the function to be decomposed
def f(r, theta, phi):
    """Test function r^2 * sin(theta) * cos(phi) in spherical coordinates.

    The inputs are torch tensors and broadcast against each other, so passing
    grids shaped (R, 1, 1), (1, T, 1), (1, 1, P) yields an (R, T, P) result.
    """
    radial = r ** 2
    return radial * torch.sin(theta) * torch.cos(phi)

# Sample grids, 50 points each: r in [0, 10], theta in [0, pi], phi in [0, 2*pi].
r_values = torch.linspace(start=0, end=10, steps=50)
theta_values = torch.linspace(start=0, end=np.pi, steps=50)
phi_values = torch.linspace(start=0, end=2 * np.pi, steps=50)

# Compute the spherical Bessel functions up to a certain order
def compute_basis_functions(r, theta, phi, n_max, k_max):
    """Evaluate radial spherical Bessel functions j_n(k * r) on the sample radii ``r``.

    Parameters
    ----------
    r : array-like or torch.Tensor
        Sample radii.
    theta, phi :
        Accepted for interface symmetry but unused; every basis function
        depends on ``r`` only.
    n_max : int
        Highest Bessel order (orders 0 .. n_max).
    k_max : int
        Highest integer frequency scale (scales 1 .. k_max).

    Returns
    -------
    list
        ``(n_max + 1) * k_max`` arrays, ordered with n as the outer index and
        k as the inner index.
    """
    return [
        spherical_jn(order, scale * r)
        for order in range(n_max + 1)
        for scale in range(1, k_max + 1)
    ]

# Compute the coefficients using inner products
def compute_coefficients(f, basis_functions, r_values, theta_values, phi_values):
    """Trapezoid-rule inner products <f, b_i> on the (r, theta, phi) sample grid.

    Parameters
    ----------
    f : callable
        ``f(r, theta, phi)``; it must broadcast grids shaped (R, 1, 1),
        (1, T, 1), (1, 1, P) to an (R, T, P) tensor.
    basis_functions : sequence
        Each entry is either a 1-D radial function sampled on ``r_values``
        (length R), or an array/tensor that broadcasts against (R, T, P).
    r_values, theta_values, phi_values : torch.Tensor
        1-D sample grids of lengths R, T, P.

    Returns
    -------
    torch.Tensor
        One raw inner product per basis function (default float dtype). These
        equal expansion coefficients only if the basis is orthonormal under
        this inner product.

    NOTE(review): no r^2 sin(theta) volume element is included, so this
    integrates over the (r, theta, phi) coordinate box rather than over the
    ball in R^3. Confirm that this is intended.
    """
    # f does not depend on the basis index, so sample it once.
    f_grid = f(r_values[:, None, None], theta_values[None, :, None], phi_values[None, None, :])
    coefficients = torch.zeros(len(basis_functions))
    for i, basis_func in enumerate(basis_functions):
        basis = torch.as_tensor(basis_func)
        if basis.ndim == 1:
            # A radial basis varies along axis 0 (r). Without this reshape it
            # would broadcast against the last axis (phi) instead.
            basis = basis.reshape(-1, 1, 1)
        integrand = f_grid * basis
        over_phi = torch.trapezoid(integrand, phi_values, dim=2)
        over_theta = torch.trapezoid(over_phi, theta_values, dim=1)
        coefficients[i] = torch.trapezoid(over_theta, r_values, dim=0)
    return coefficients

# Expansion size: Bessel orders n = 0..n_max and frequency scales k = 1..k_max.
n_max = 5
k_max = 5

basis_functions = compute_basis_functions(
    r_values, theta_values, phi_values, n_max, k_max
)
coefficients = compute_coefficients(
    f, basis_functions, r_values, theta_values, phi_values
)

print("Coefficients:", coefficients)

In [None]:
import matplotlib.pyplot as plt

# Evaluate the approximated function using the computed coefficients
def approximated_function(coefficients, basis_functions, theta, phi):
    """Reconstruct sum_i c_i * b_i * sin(theta) * cos(phi), evaluated point-wise.

    The angular factor sin(theta) * cos(phi) is hard-coded. All inputs are
    combined element-wise, so ``basis_functions[i]``, ``theta`` and ``phi``
    must broadcast together. Pairs are taken with ``zip``, so surplus
    coefficients or basis functions are ignored.
    """
    result = torch.zeros_like(basis_functions[0])
    for weight, radial in zip(coefficients, basis_functions):
        term = weight * radial * torch.sin(theta) * torch.cos(phi)
        result += term
    return result

# Compute the approximated function.
# Both curves are evaluated point-wise along the sampled path
# (r_values[i], theta_values[i], phi_values[i]): r, theta and phi all advance
# together, and phi is used only to parameterize the horizontal axis.
approximated = approximated_function(coefficients, basis_functions, theta_values, phi_values)
f_on_path = f(r_values, theta_values, phi_values)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(phi_values.numpy(), f_on_path.numpy(), label='Original Function')
ax.plot(phi_values.numpy(), approximated.numpy(), label='Approximated Function')
ax.set_title('Original and Approximated Functions')
ax.set_xlabel('Phi (r and theta advance together along the sampled path)')
ax.set_ylabel('Function Value')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import spherical_jn

# Reference plot of the first four spherical Bessel functions of the first kind.
# The imports above make this cell self-contained instead of relying on names
# left over from earlier cells.
x = np.arange(0.0, 10.0, 0.01)
fig, ax = plt.subplots()
ax.set_ylim(-0.5, 1.5)
ax.set_title(r'Spherical Bessel functions $j_n$')
for n in range(4):
    ax.plot(x, spherical_jn(n, x), label=rf'$j_{n}$')
ax.set_xlabel('x')
ax.set_ylabel(r'$j_n(x)$')
ax.legend(loc='best')
plt.show()