<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/test_heat_kernel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The heat kernel of the sphere can be expressed in terms of theta functions.
The heat kernels are implemented in *heat_kernels.py*, currently located in `jax-rb/tests/utils`. This module is for testing only: we want to avoid a package dependency on mpmath, so we clone the project and then point to the module location explicitly.

In [2]:
!git clone https://github.com/dnguyend/jax-rb
!pip install mpmath

Cloning into 'jax-rb'...
remote: Enumerating objects: 254, done.[K
remote: Counting objects: 100% (254/254), done.[K
remote: Compressing objects: 100% (172/172), done.[K
remote: Total 254 (delta 126), reused 167 (delta 76), pack-reused 0[K
Receiving objects: 100% (254/254), 2.76 MiB | 17.79 MiB/s, done.
Resolving deltas: 100% (126/126), done.


In [3]:
import sys
sys.path.append("/content/jax-rb/")
sys.path.append("/content/jax-rb/tests/utils")

# !wget https://github.com/dnguyend/jax-rb/blob/main/tests/utils/heat_kernels.py

In [None]:
# !curl https://github.com/dnguyend/jax-rb/blob/main/tests/utils/heat_kernels.py

In [4]:

import jax
import jax.numpy as jnp
import jax.scipy.integrate as jsi
from jax import random

import numpy as np



import heat_kernels as hkm
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.global_manifold_integrator as mi

from jax_rb.manifolds.sphere import Sphere

jax.config.update("jax_enable_x64", True)


$\newcommand{\sfT}{\mathsf{T}}$
The heat equation $\partial_t p = D_0\Delta_{S^d}p$ on the sphere $S^d$ has a kernel $p$ of the form
\begin{equation}p(x, y, t) = K_{D_0t}^d(\varphi)\; \text{ with }\varphi=\cos^{-1}(x^{\sfT}y)
\end{equation}
satisfying the recursion [NSS]
\begin{equation}K^{d+2}_t(\varphi) = -\frac{e^{td}}{2\pi}(\sin\varphi)^{-1}\partial_{\varphi}K^d_t(\varphi), \quad\quad d\geq 1
\end{equation}
where $K^1_t(\varphi) = \frac{1}{2\pi}\theta_3(\frac{1}{2}\varphi, e^{-t})$ and $\theta_3$ is the third Jacobi theta function, computed using the package mpmath, which also provides the derivatives of $\theta_3$ in $\varphi$. $\theta_3$ satisfies a functional equation, from which we obtain
\begin{equation}
 \frac{1}{2\pi}\theta_3(\frac{1}{2}\varphi, e^{-D_0t}) =
 \frac{e^{-\frac{\varphi^2}{4tD_0}}}{(4D_0\pi t)^{\frac{1}{2}}}  \
   \theta_3(\frac{\sqrt{-1}}{2}\frac{\pi\varphi}{D_0t}, e^{-\frac{\pi^2}{D_0t}})
\end{equation}
The form on the right-hand side is preferable in practical simulations: for small $t$ it converges faster, while the left-hand side oscillates (although I have yet to encounter this problem in practice). Taking the derivative in $\varphi$ gives two expressions for $K^d_t$ for every odd $d$. For $d=2$, and hence for every even $d$, the left-hand side is a simpler sum, while the right-hand side is an integral of theta functions that is rather complicated to compute, although approximations are available.

$$K^1_t(\phi) = \frac{1}{2\pi} \theta_3(\tfrac{1}{2}\phi, e^{-t})$$
(in mpmath this is `jtheta(3, phi/2, exp(-t), 0)`, where the last argument is the derivative order).

Define the Legendre polynomials $p_i$ (evaluated at $\cos\phi$) recursively by $p_0=1$, $p_{1} = \cos\phi$ and,
for $i \geq 2$,

$$ p_i = \frac{1}{i}\left((2i-1)\cos\phi\, p_{i-1} - (i-1)p_{i-2}\right).
$$
Then
$$K^2_t(\phi) = \frac{1}{4\pi}\sum_{i=0}^{\infty}e^{-i(i+1)t}(2i+1)p_{i},
$$
and, for general $d$,
$$K^d_t(\phi) = c_d\int_{-1}^1K^{2d-1}_{t/4}\left(\arccos\left(v\cos\tfrac{\phi}{2}\right)\right)(1-v^2)^{\frac{d-3}{2}}\,dv.
$$
For $d = 2$, $c_2 = \frac{2}{\pi}$. Setting $v = \sin u$, the heat kernel at time $T$ for the sphere with radius $r$ and diffusion coefficient $d_0$ is
$$K^2_{\frac{Td_0}{r^2}}(\phi)  =  c_2\int_{-\pi/2}^{\pi/2}K^3_{\frac{Td_0}{4r^2}}\left(\arccos(\sin(u)\cos(\phi/2))\right) du.
$$
The expectation of $f(\phi)$ is
$$\int_0^{\pi} (\sin(\phi))^2\, 2\pi f(\phi)K^2_{\frac{Td_0}{r^2}}(\phi)\,d\phi  =  c_2\int_0^{\pi}\int_{-\pi/2}^{\pi/2}K^3_{\frac{Td_0}{4r^2}}\left(\arccos(\sin(u)\cos(\phi/2))\right)(\sin(\phi))^2\, 2\pi f(\phi)\,du\, d\phi,
$$
where $K^3$ is computed from $K^1$ via the recursion above.

Here are the tests for $1,2$ and $3$ dimensions:

In [5]:
def test_1d():
    """Compare estimates of E[f(phi)] for Brownian motion on a circle.

    The circle ``Sphere(2, r)`` has radius ``r``. The motion runs for time
    ``t_final`` with diffusion coefficient ``d0``, and ``f(phi) = phi**2``,
    where ``phi`` is the angle between the start and end points, folded
    into ``[0, pi]``.

    Prints six estimates:

    * ``heat_kernels``: trapezoid rule against the theta-function kernel
      ``hkm.thk1``;
    * ``sum_of_moves``: a sum of ``n_div`` Gaussian increments;
    * ``random walk``: a walk of ``n_div`` steps of fixed length;
    * ``geodesic`` / ``ito`` / ``strato``: ``sim.simulate`` driven by the
      corresponding ``mi`` move function.
    """
    _, sk = random.split(random.PRNGKey(0))

    d = 1  # intrinsic dimension of the circle; the ambient space is R^(d+1)
    r = 4.
    t_final = 1.1
    d0 = .4
    n_path = 500
    n_div = 200
    sph = Sphere(d+1, r)

    def fin(phi):
        return phi**2

    # Heat-kernel reference. On [0, 2*pi] the integrand is a function of the
    # folded angle min(a, 2*pi - a), hence symmetric about pi. So we sample
    # [0, pi] on a grid of spacing pi/n_path and double the integral by using
    # dx = 2*pi/n_path.
    # float() turns whatever scalar type hkm.thk1 returns (possibly an mpmath
    # number -- confirm) into a plain float, giving a float64 sample array.
    phis = np.arange(n_path+1)/n_path*np.pi
    sph_heat_kernel = jsi.trapezoid(
        np.array([float(fin(phi)*hkm.thk1(0, phi, t_final, d0/r**2)) for phi in phis]),
        dx=2*np.pi/n_path)

    # Per-step scale sqrt(2*d0*dt) with dt = t_final/n_div.
    step_scale = jnp.sqrt(t_final/(n_div)*2*d0)

    # Sum of Gaussian increments. The result is converted to an angle (/r)
    # and folded into [0, pi] by arccos(cos(.)).
    sph_sum = fin(jnp.arccos(jnp.cos(
        jnp.sum(random.normal(sk, (n_div, n_path)), axis=0)*step_scale/r)))

    # Random walk: dividing each draw by its norm over axis 0 keeps only its
    # direction (its sign, since d == 1), so every step has length step_scale.
    # NOTE(review): this reuses ``sk`` from the draw above, so the two
    # estimates are not independent -- confirm this coupling is intended.
    xtmp = random.normal(sk, (d, n_div, n_path))
    xw = xtmp/jnp.sqrt(jnp.sum(xtmp**2, axis=0))[None, :]*step_scale
    sph_walk = jnp.mean(fin(jnp.arccos(jnp.cos(jnp.sum(xw, axis=1)/r))))

    # Reference point (r, 0) for the distance observed by the simulations
    # (presumably the same point as sph.x0 -- confirm).
    x_0 = jnp.zeros(d+1).at[0].set(sph.r)

    def run(move):
        """Run sim.simulate from sph.x0 with ``move``, observing f(dist(x_0, x)/r)."""
        # All runs receive the same key ``sk``.
        return sim.simulate(
            sph.x0,
            move,
            None,
            lambda x: fin(sph.dist(x_0, x)/sph.r),
            (sk, t_final, n_path, n_div, d0, d+1))

    sph_sim_geo = run(
        lambda x, unit_move, scale: mi.geodesic_move(sph, x, unit_move, scale))
    sph_sim_ito = run(
        lambda x, unit_move, scale: mi.rbrownian_ito_move(sph, x, unit_move, scale))
    sph_sim_strato = run(
        lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(sph, x, unit_move, scale))

    print(f"heat_kernels={sph_heat_kernel}")
    print(f"sum_of_moves={jnp.mean(sph_sum)}")
    print(f"random walk={sph_walk}")

    print(f"geodesic={jnp.mean(sph_sim_geo[0])}")
    print(f"ito={jnp.mean(sph_sim_ito[0])}")
    print(f"strato={jnp.mean(sph_sim_strato[0])}")

test_1d()

heat_kernels=0.055
sum_of_moves=0.05249187557005862
random walk=0.051803399999999986
geodesic=0.05300901418527616
ito=0.053023444144691165
strato=0.053006278779519224


In [6]:

def test_2d():
    """Compare estimates of E[f(phi)] for Brownian motion on the 2-sphere.

    The sphere ``Sphere(3, r)`` has radius ``r``. The motion runs for time
    ``t_final`` with diffusion coefficient ``d0``. The statistic is
    ``f(phi) = phi**2.5`` with ``phi = arccos(x[0]/r)``, the angle of the end
    point ``x`` from the first coordinate axis.

    Prints:

    * two deterministic estimates: a trapezoid rule against ``hkm.k2`` and a
      ``dblquad`` of ``hkm.k3`` at a quarter of the time;
    * four ``sim.simulate`` estimates: geodesic, normalized geodesic, Ito and
      Stratonovich moves.
    """
    from scipy.integrate import dblquad

    _, sk = random.split(random.PRNGKey(0))

    d = 2  # intrinsic dimension; the ambient space is R^(d+1)
    r = 3
    t_final = 2.
    d0 = .4
    n_path = 1000
    n_div = 1000
    sph = Sphere(d+1, r)

    def fin(phi):
        return phi**2.5

    # Trapezoid rule on [0, pi]. 2*pi*sin(phi) is the area element of S^2
    # after integrating out the azimuth.
    phis = np.arange(n_path+1)/n_path*np.pi
    sph_heat_kernel = jsi.trapezoid(
        np.array([hkm.k2(phi, t_final*d0/r**2)*(np.sin(phi))*2*np.pi*fin(phi)
                  for phi in phis]),
        dx=np.pi/n_path)

    # Alternative: build the 2d kernel by integrating the 3d kernel at time/4.
    # dblquad calls func(u, phi) with u in [-pi/2, pi/2] (inner) and
    # phi in [0, pi] (outer).
    # NOTE(review): in the recorded output this estimate is about 10% below
    # the trapezoid estimate. It weights by (sin(phi))**2, whereas the
    # trapezoid rule above uses sin(phi) -- confirm the weight and the
    # constant c2.
    ss = dblquad(
        lambda u, phi: (hkm.k3(np.arccos(np.sin(u)*np.cos(phi/2)), t_final*d0/sph.r**2/4)
                        *(np.sin(phi))**2*2*np.pi*fin(phi)),
        0., np.pi, -np.pi/2., np.pi/2)
    c2 = 2/np.pi
    sph_heat_kernel_alt = ss[0]*c2

    def run(move):
        """Run sim.simulate from sph.x0 with ``move``, observing f(arccos(x[0]/r))."""
        # All runs receive the same key ``sk``.
        return sim.simulate(
            sph.x0,
            move,
            None,
            lambda x: fin(jnp.arccos(x[0]/sph.r)),
            (sk, t_final, n_path, n_div, d0, d+1))

    sph_sim_geo = run(
        lambda x, unit_move, scale: mi.geodesic_move(sph, x, unit_move, scale))
    # The normalized geodesic move is given scale*d, d being the intrinsic dimension.
    sph_sim_geo_norm = run(
        lambda x, unit_move, scale: mi.geodesic_move_normalized(sph, x, unit_move, scale*d))
    sph_sim_ito = run(
        lambda x, unit_move, scale: mi.rbrownian_ito_move(sph, x, unit_move, scale))
    sph_sim_strato = run(
        lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(sph, x, unit_move, scale))

    print(f"heat_kernels={sph_heat_kernel}")
    print(f"heat_kernels_alt={sph_heat_kernel_alt}")

    print(f"geodesic={jnp.mean(sph_sim_geo[0])}")
    # Distinct label: the original printed "geodesic=" twice.
    print(f"geodesic_normalized={jnp.mean(sph_sim_geo_norm[0])}")
    print(f"ito={jnp.mean(sph_sim_ito[0])}")
    print(f"strato={jnp.mean(sph_sim_strato[0])}")

test_2d()

heat_kernels=0.29940738111274223
heat_kernels_alt=0.2706534545899766
geodesic=0.2823254465952039
geodesic=0.2882006825595556
ito=0.28244562060757716
strato=0.28249267581460963


In [7]:

def test_3d():
    """Compare estimates of E[f(phi)] for Brownian motion on the 3-sphere.

    ``Sphere(n, r)`` with ``n = 4`` ambient coordinates and radius ``r``.
    The motion runs for time ``t_final`` with diffusion coefficient ``d0``.
    The statistic is ``f(phi) = phi**1.5 + phi**2.5`` with
    ``phi = arccos(x[0]/r)``.

    Prints a trapezoid-rule estimate against ``hkm.k3`` and four
    ``sim.simulate`` estimates: geodesic, normalized geodesic, Ito and
    Stratonovich moves.
    """
    _, sk = random.split(random.PRNGKey(0))

    n = 4  # ambient dimension; the sphere itself has dimension n-1 = 3
    r = 3
    t_final = 2.
    d0 = .4
    n_path = 1000
    n_div = 1000
    sph = Sphere(n, r)

    def fin(phi):
        return phi**1.5 + phi**2.5

    # Trapezoid rule on [0, pi]. 4*pi*sin(phi)**2 is the S^3 volume element
    # in the polar angle after integrating out the remaining angles.
    phis = np.arange(n_path+1)/n_path*np.pi
    sph_heat_kernel = jsi.trapezoid(
        np.array([hkm.k3(phi, t_final*d0/r**2)*(np.sin(phi))**2*4*np.pi*fin(phi)
                  for phi in phis]),
        dx=np.pi/n_path)

    def run(move):
        """Run sim.simulate from sph.x0 with ``move``, observing f(arccos(x[0]/r))."""
        # All runs receive the same key ``sk``.
        return sim.simulate(
            sph.x0,
            move,
            None,
            lambda x: fin(jnp.arccos(x[0]/sph.r)),
            (sk, t_final, n_path, n_div, d0, n))

    sph_sim_geo = run(
        lambda x, unit_move, scale: mi.geodesic_move(sph, x, unit_move, scale))
    # The normalized geodesic move is given scale*(n-1), the intrinsic dimension.
    sph_sim_geo_norm = run(
        lambda x, unit_move, scale: mi.geodesic_move_normalized(sph, x, unit_move, scale*(n-1)))
    sph_sim_ito = run(
        lambda x, unit_move, scale: mi.rbrownian_ito_move(sph, x, unit_move, scale))
    sph_sim_strato = run(
        lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(sph, x, unit_move, scale))

    print(f"heat_kernels={sph_heat_kernel}")

    print(f"geodesic={jnp.mean(sph_sim_geo[0])}")
    # Distinct label: the original printed "geodesic=" twice.
    print(f"geodesic_normalized={jnp.mean(sph_sim_geo_norm[0])}")
    print(f"ito={jnp.mean(sph_sim_ito[0])}")
    print(f"strato={jnp.mean(sph_sim_strato[0])}")


test_3d()


heat_kernels=1.023917151541904
geodesic=1.0502866903715162
geodesic=1.0664487173174357
ito=1.0508031482726825
strato=1.0507484072041133
