In [None]:
import sys
sys.path.append("..")

import jax
import jax.numpy as jnp

from src.kernel_functions import *

In [None]:
import matplotlib.pyplot as plt

# Smoothing length passed to every kernel evaluation below.
h = 1.0
n_samples = 1000

# Sample positions along the x-axis (y = 0), sweeping 1.25 * h on either side of
# the origin. Presumably h is the kernel support radius, so the extra 25% makes
# the behaviour just past |x| = h visible — NOTE(review): confirm against
# src.kernel_functions.
# Built with vectorised jnp ops instead of a Python loop over a JAX array, which
# would pull each of the n_samples elements back as a separate scalar.
x_positions = jnp.linspace(-1.25 * h, 1.25 * h, n_samples)
arr = jnp.stack([x_positions, jnp.zeros_like(x_positions)], axis=1)  # (n_samples, 2)

# Every sample is paired with a test point at the origin. zeros_like keeps the
# same shape and float dtype as `arr` (integer literals would produce int32).
test_point = jnp.zeros_like(arr)

# JIT-compile each kernel once; the first call below triggers compilation.
kernel_function_poly6_jit = jax.jit(kernel_function_poly6)
kernel_function_gradient_spiky_jit = jax.jit(kernel_function_gradient_spiky)
kernel_function_viscosity_laplacian_jit = jax.jit(kernel_function_viscosity_laplacian)

kernel_values = kernel_function_poly6_jit(arr, test_point, h)
print(kernel_values.shape)
kernel_gradient_values = kernel_function_gradient_spiky_jit(arr, test_point, h)
print(kernel_gradient_values.shape)
kernel_viscosity_laplacian_values = kernel_function_viscosity_laplacian_jit(arr, test_point, h)
print(kernel_viscosity_laplacian_values.shape)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].set_title("Kernel Density Function")
axs[0].plot(arr[:, 0], kernel_values, label="Kernel Function", color="blue")
axs[1].set_title("Kernel Gradient Function")
# Plot the x-component of the gradient. The [:, 0] indexing assumes one vector
# per sample, i.e. shape (n_samples, 2) — TODO confirm from the printed shape.
axs[1].plot(arr[:, 0], kernel_gradient_values[:, 0], label="Kernel Gradient", color="orange")
axs[2].set_title("Kernel Viscosity Laplacian Function")
axs[2].plot(arr[:, 0], kernel_viscosity_laplacian_values, label="Kernel Viscosity Laplacian", color="green")

for ax in axs:
    ax.set_xlabel("x")
    # Without legend() the `label=` strings above are never displayed.
    ax.legend()

fig.tight_layout()
plt.show()