Tools

In [11]:
import sys  
sys.path.append('/home/josorior/kinetick/source')

In [12]:
# Standard tool packages
import jax
from jax import grad, jacfwd, hessian
import jax.numpy as jnp
import matplotlib.pyplot as plt
# My tools
from models import *
from mykernels import get_gaussianRBF
# Settings
#plt.style.use('seaborn-v0_8')
plt.rcParams["figure.figsize"] = (5, 5)

## PDE

Let $M(x) = \exp\left(-0.5(x_1^2 + x_2^2)\right)$. The Fokker-Planck equation in 2d is given by
$$
\begin{align*}

\partial_t \rho (t,x) &= \text{div}\left(M(x)\nabla\left(\frac{\rho(t,x)}{M(x)}\right)\right) & \text{for } t >0, x \in \Omega = (-3,3)^2 \\
\rho(0,x) &= \rho^0(x), &\text{for } x \in \Omega \\
\nabla\left(\frac{\rho(t,x)}{M(x)} \cdot \mathbf{n}\right) &= 0, & \text{for } t >0, x \in \partial\Omega = (-3,3)^2
\end{align*}
$$
whose analytical solution is known to be
$$
\rho(t, x)=\frac{1}{2 \pi\left(1-e^{-2 t}\right)} e^{-\frac{x_1^2+x_2^2}{2\left(1-e^{-2 t}\right)}}.
$$

We will use the kernel method in closed form since the PDE operator is linear.


## Kernel solution

Choose a kernel

In [13]:
k = get_gaussianRBF(0.2)

Note: This is 2d time dependent problem meaning that the kernel actually is of the form
$$
K(s,w)
$$
where $s = (t,x_1,x_2)$ and $w = (\tau,y_1,y_2)$

First we choose collocation pointsets $\Omega^*\subset\Omega$ and $\Omega^*\subset\partial\Omega$

In [14]:
# \Omega
M_int = int(jnp.sqrt(500))
x_1 = jnp.linspace(-3,3,M_int)[1:]
x_2 = x_1
X_1, X_2 = jnp.meshgrid(x_1, x_2)
pairs_int = jnp.vstack([X_1.ravel(),X_2.ravel()]).T
M_int = pairs_int.shape[0]
# \partial\Omega
M_bdry = int(25*4)
bottom = jnp.vstack((jnp.linspace(-3,3,M_bdry//4), jnp.tile(-3,M_bdry//4))).T
top = jnp.vstack((jnp.linspace(-3,3,M_bdry//4), jnp.tile(3,M_bdry//4))).T
left = jnp.vstack((jnp.tile(-3,M_bdry//4),jnp.linspace(-3,3,M_bdry//4))).T[1:-1]
right = jnp.vstack((jnp.tile(3,M_bdry//4),jnp.linspace(-3,3,M_bdry//4))).T[1:-1]
pairs_bdry = jnp.vstack((left,bottom,right,top))
M_bdry = pairs_bdry.shape[0]

Build the kernel matrix $K(\phi,\phi)$ where $\phi = [\phi_\Omega, \phi_{\partial\Omega}]$ and
$$
\phi_\Omega \left(\square\right) = \delta_{x}\left(\square\right) - h\left(\delta_{x} \circ \text{div}\left(M(x)\nabla\left(\frac{\square}{M(x)}\right)\right)\right) \qquad \text{for } x\in \Omega
$$
$$
\phi_{\partial\Omega} \left(\square\right) = \delta_{x} \circ \nabla\left(\frac{\square}{M(x)}\right) \cdot \mathbf{n} = \qquad \text{for } x\in \partial \Omega
$$
Thus, the matrix has the form
$$
K(\phi,\phi) = 
\begin{pmatrix}
\phi_{\Omega}\left(\phi_{\Omega}\left(K(x,y)\right)\right) & \phi_{\partial\Omega}\left(\phi_{\Omega}\left(K(x,y)\right)\right) \\
\phi_{\Omega}\left(\phi_{\partial\Omega}\left(K(x,y)\right)\right) & \phi_{\partial\Omega}\left(\phi_{\partial\Omega}\left(K(x,y)\right)\right) \\
\end{pmatrix}
$$
where for instance the block $\phi_{\Omega}\left(\phi_{\partial\Omega}\left(K(x,y)\right)\right)$ is evaluated as:

1. Apply $\phi_{\partial\Omega}$ to the function $K(x,y)$ as a function of $x$ only. 

2. Apply $\phi_\Omega$ to $\phi_{\partial\Omega}\left(K(x,y)\right)$ as a function of $y$ only.

Let's compute the blocks.

$$\phi_{\Omega}\left(\phi_{\Omega}\left(K(x,y)\right)\right)$$

In [15]:
def M(x,y):
    del y
    return jnp.exp(-(jnp.sum(x**2))/ 2.)

In [16]:
def vectorize_kfunc(k):
    return jax.vmap(jax.vmap(k, in_axes=(None,0)), in_axes=(0,None))

def op_k_apply(k,L_op,R_op):
    return R_op(L_op(k,0),1)

def make_block(k,L_op,R_op):
    return vectorize_kfunc(op_k_apply(k,L_op,R_op))

In [17]:
def get_div(f,g):
    def f_op_g(x,y):
        return f(x,y) / g(x,y)
    return f_op_g

# Multiply two scalar functions
def get_mult(f,g):
    def f_op_g(x,y):
        return f(x,y) * g(x,y)
    return f_op_g

def get_sum(f,g):
    def f_op_g(x,y):
        return f(x,y) + g(x,y)
    return f_op_g

def get_subs(f,g):
    def f_op_g(x,y):
        return f(x,y) - g(x,y)
    return f_op_g

In [18]:
def f(x,y):
    return jnp.sum(x**2) - 3*jnp.sum(y**3)

def V(x,y):
    return jnp.array([jnp.sum(x**2) , 3*jnp.sum(y**3)])

In [19]:
r = jax.jacfwd(V,0)
a, b = jnp.array([1.,2.]), jnp.array([8.,9.])
r(a,b)

Array([[2., 4.],
       [0., 0.]], dtype=float32)

In [20]:
g = get_mult(f,V)

In [21]:
a, b = jnp.array([1.,2.]), jnp.array([8.,9.])
r = jnp.trace(jax.jacobian(g)(a,b))

In [25]:
def laplacian_k(k,index):
    def lapk(*args):
        return jnp.trace(jax.hessian(k,index)(*args))
    return lapk

def get_selected_grad(k,index,selected_index):
    gradf = grad(k,index)
    def selgrad(*args):
        return gradf(*args)[selected_index]
    return selgrad

def dx_k(k,index):
    return get_selected_grad(k,index,0)

# Option 1
def divergence_k(k,index):
    def divk(*args):
        return jnp.trace(jax.jacfwd(k,index)(*args))
    return divk
# Option 2
def get_selected_divergence(k,index,selected_index):
    divergencef = jax.jit(lambda x: jnp.trace(jax.jacfwd(f,index)(x)))
    def seldiv(*args):
        return divergencef(*args)[selected_index]
    return seldiv

def divergence(k,index):
    return get_selected_divergence(k,index,0)

def eval_k(k,index):
    return k

def phi_omega(k,index):
    return divergence(get_mult(M,dx_k(get_div(k,M),index)),index)

In [31]:
q = get_div(k,M)

In [30]:
q(a,a)

Array(12.182494, dtype=float32)

In [26]:
f = get_mult(M,dx_k(get_div(k,M),0))

In [27]:
b11 = make_block(k,phi_omega,phi_omega)(pairs_int,pairs_int)

TypeError: get_selected_divergence.<locals>.<lambda>() takes 1 positional argument but 2 were given