# Practice: Poisson equation

"L"-shaped region $\Omega = (-1, 1)^2 \setminus [0, 1)^2$ 위에서 정의된 Poisson equation
$$
-\Delta u = 1 \quad \text{in } \Omega,
$$
을 고려하겠습니다.
Boundary condition은 homogeneous (zero) Dirichlet condition, 즉 $\partial\Omega$ 위에서 $u = 0$ 입니다.

In [None]:
from collections.abc import Callable, Sequence

import jax
import jax.numpy as jnp
import jax.random as jr
# double precision
jax.config.update("jax_enable_x64", True)

Multi-layer perceptron을 정의합니다.

In [None]:
def MLP(layers: Sequence[int] = (1, 64, 1), activation: Callable = jnp.tanh):
    """Build a fully connected network as an ``(init_params, apply)`` pair.

    Args:
        layers: Layer widths ``[d_in, hidden..., d_out]``; at least two entries
            are required. The default is a tuple so it can never be shared and
            mutated across calls.
        activation: Elementwise nonlinearity applied after every layer except
            the last one (the output layer is purely affine).

    Returns:
        ``init_params(key) -> params``, where ``params`` is a list of ``[W, b]``
        pairs (one per layer), and ``apply(params, inputs) -> outputs``.

    Raises:
        ValueError: if fewer than two layer widths are given.
    """
    layers = tuple(layers)
    if len(layers) < 2:
        raise ValueError(
            f"MLP needs at least an input and an output width, got {layers!r}"
        )

    def init_params(key):
        def _init(key, d_in, d_out):
            # Glorot (Xavier) normal: std = sqrt(2 / (fan_in + fan_out)).
            w = jr.normal(key, shape=(d_in, d_out)) * jnp.sqrt(2 / (d_in + d_out))
            b = jnp.zeros((d_out,))
            return [w, b]

        # One independent PRNG key per layer.
        keys = jr.split(key, len(layers) - 1)
        return [
            _init(k, d_in, d_out)
            for k, d_in, d_out in zip(keys, layers[:-1], layers[1:])
        ]

    def apply(params, inputs):
        # Hidden layers: affine map followed by the activation.
        for W, b in params[:-1]:
            inputs = activation(inputs @ W + b)
        # Output layer: affine only, no activation.
        W, b = params[-1]
        return inputs @ W + b

    return init_params, apply

PINN loss function을 계산하기 위해 domain 내부(interior)와 boundary에서 각각 collocation points를 sampling 합니다.
점들이 domain 전체에 고르게 분포하도록 low-discrepancy (quasi-random) sequence인 Sobol sequence를 사용합니다.

In [None]:
from scipy.stats.qmc import Sobol


def sampling_interior(m: int = 9):
    """Return 3 * 2**m collocation points inside the L-shaped domain.

    One Sobol block of 2**m points in [0, 1)^2 (SciPy's default scrambling)
    is translated onto each of the three unit squares forming the L, in the
    order (-1, 0) x (0, 1), then (-1, 0)^2, then (0, 1) x (-1, 0).
    """
    # Re-drawing after reset() would replay the very same sequence, so one
    # draw is shared by all three squares.
    unit_block = Sobol(d=2).random_base2(m)
    shifts = ((-1.0, 0.0), (-1.0, -1.0), (0.0, -1.0))
    return jnp.concatenate([unit_block + shift for shift in shifts])

def sampling_boundary(m: int = 8):
    """Return 6 * 2**m points on the boundary of the L-shaped domain.

    Each of the six edges receives the same 2**m-point 1-D Sobol pattern in
    [0, 1), scaled/shifted onto that edge. Edge order in the result:
    y = -1, y = 1, y = 0 (x >= 0), x = -1, x = 0 (y >= 0), x = 1.

    Args:
        m: log2 of the number of points per edge (m = 0 gives one per edge).
    """
    # A reset() followed by a fresh draw replays the same sequence, so a single
    # draw is reused for every edge. Index with [:, 0] instead of .squeeze():
    # for m = 0 the draw has shape (1, 1) and squeeze() would yield a 0-d
    # array that cannot be stacked with the length-1 constant columns.
    s = Sobol(d=1).random_base2(m)[:, 0]
    n = s.shape[0]
    ones, zeros = jnp.ones((n,)), jnp.zeros((n,))
    edges = [
        (s * 2 - 1, -ones),  # bottom:        y = -1, x in [-1, 1)
        (s - 1, ones),       # top of left:   y = 1,  x in [-1, 0)
        (s, zeros),          # inner edge:    y = 0,  x in [0, 1)
        (-ones, s * 2 - 1),  # left:          x = -1, y in [-1, 1)
        (zeros, s),          # inner edge:    x = 0,  y in [0, 1)
        (ones, s - 1),       # right of bottom: x = 1, y in [-1, 0)
    ]
    return jnp.concatenate([jnp.stack(edge, 1) for edge in edges])

# Fixed collocation sets used throughout training, with the sampler defaults:
# 3 * 2**9 interior points and 6 * 2**8 boundary points.
xy_in = sampling_interior()
xy_b = sampling_boundary()

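Neural network $u_\theta(x, y)$를 정의하고, PINN loss의 PDE residual 계산에 필요한 spatial derivatives $u_x, u_{xx}, u_y, u_{yy}$를 forward-mode automatic differentiation (`jax.jacfwd`)으로 계산하는 함수들을 정의합니다.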

In [None]:
# Network u_theta: (x, y) -> scalar, with four tanh hidden layers of width 50.
init, apply = MLP([2, 50, 50, 50, 50, 1], jnp.tanh)

def pinn(params, x, y):
    """Evaluate the network u_theta at one point (x, y), returned as a scalar."""
    point = jnp.stack([x, y])
    return apply(params, point).squeeze()

def pinn_x(params, x, y):
    """First derivative u_x of the network, via forward-mode autodiff in x."""
    d_dx = jax.jacfwd(pinn, argnums=1)  # argument 1 of pinn is x
    return d_dx(params, x, y)

def pinn_xx(params, x, y):
    """Second derivative u_xx, obtained by differentiating pinn_x again in x."""
    d_dx = jax.jacfwd(pinn_x, argnums=1)  # argument 1 of pinn_x is x
    return d_dx(params, x, y)

def pinn_y(params, x, y):
    """First derivative u_y of the network, via forward-mode autodiff in y."""
    d_dy = jax.jacfwd(pinn, argnums=2)  # argument 2 of pinn is y
    return d_dy(params, x, y)

def pinn_yy(params, x, y):
    """Second derivative u_yy, obtained by differentiating pinn_y again in y."""
    d_dy = jax.jacfwd(pinn_y, argnums=2)  # argument 2 of pinn_y is y
    return d_dy(params, x, y)


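PINN loss의 각 component(PDE residual과 boundary residual)를 계산하는 함수들과, 이를 합친 PINN loss
$$
\mathcal{L}_\mathrm{PINN} = \frac{1}{N_r}\sum_{i=1}^{N_r} \left(\Delta u_\theta(x_i, y_i) + 1\right)^2 + 100 \cdot \frac{1}{N_b}\sum_{j=1}^{N_b} u_\theta(x^b_j, y^b_j)^2
$$
를 정의합니다.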

In [None]:
def pde(params, xy_in):
    """Residual of -Δu = 1 at a single point, i.e. u_xx + u_yy + 1.

    ``xy_in`` holds one (x, y) pair; the residual is zero for an exact solution.
    """
    x, y = xy_in
    laplacian = pinn_xx(params, x, y) + pinn_yy(params, x, y)
    return laplacian + 1


def bc(params, xy_b):
    """Boundary residual at one (x, y) pair: the network value, since u = 0 there."""
    x_b, y_b = xy_b
    return pinn(params, x_b, y_b)


def loss(params, xy_in, xy_b, bc_weight=1e2):
    """PINN objective: mean squared PDE residual plus weighted mean squared BC residual.

    Args:
        params: network parameters.
        xy_in: interior collocation points, one (x, y) pair per row.
        xy_b: boundary points, one (x, y) pair per row.
        bc_weight: penalty weight on the boundary term (default 1e2).

    Returns:
        ``(total, (pde_loss, bc_loss))``. The auxiliary pair lets an optimizer
        built with ``has_aux=True`` report both components separately.
    """
    # Map the single-point residuals over the rows of each point set;
    # params are shared (in_axes=None).
    batched_pde = jax.vmap(pde, in_axes=(None, 0))
    batched_bc = jax.vmap(bc, in_axes=(None, 0))
    pde_loss = jnp.mean(batched_pde(params, xy_in) ** 2)
    bc_loss = jnp.mean(batched_bc(params, xy_b) ** 2)
    return pde_loss + bc_weight * bc_loss, (pde_loss, bc_loss)

Optimization을 준비하는 코드입니다.

In [None]:
import jaxopt

nIter = 10000  # total number of optimizer steps taken by the training loop

# L-BFGS on the PINN loss; has_aux=True because `loss` returns
# (value, (pde_loss, bc_loss)) rather than a bare scalar.
opt = jaxopt.LBFGS(loss, has_aux=True)

# initialize: network parameters from a fixed PRNG key (reproducible), then
# the solver state at those parameters on the fixed collocation sets.
params = init(jr.PRNGKey(0))
state = opt.init_state(params, xy_in, xy_b)


@jax.jit
def step(params, state, xy_in=xy_in, xy_b=xy_b):
    """Run one compiled L-BFGS iteration and return the new (params, state).

    The point sets default to the module-level collocation arrays, so callers
    can simply write ``step(params, state)``.
    """
    result = opt.update(params, state, xy_in, xy_b)
    return result.params, result.state

Optimization을 실행하는 코드입니다.

In [6]:
import time


# Run nIter optimizer steps, logging the loss components every 100 steps.
loss_total, loss_pde, loss_bc = [], [], []
print("Solving...")
# perf_counter is monotonic and high-resolution, unlike the wall clock time().
tic = time.perf_counter()
for it in range(1, 1 + nIter):
    params, state = step(params, state)
    if it % 100 == 0:
        # The solver state carries the objective value and the aux output of
        # `loss`; store them as plain Python floats for plotting.
        total_loss = float(state.value)
        pde_loss, bc_loss = (float(v) for v in state.aux)
        loss_total.append(total_loss)
        loss_pde.append(pde_loss)
        loss_bc.append(bc_loss)
        print(f"it: {it}, loss: {total_loss:.3e}")
toc = time.perf_counter()
print(f"Done! Elapsed time: {toc - tic:.2f}")

결과를 시각화하는 코드입니다.

In [None]:
import matplotlib.pyplot as plt


import os

import matplotlib.tri as mtri

# Left panel: logged loss history on a log scale.
_, (ax0, ax1) = plt.subplots(ncols=2, figsize=(8, 4))
ax0.semilogy(loss_total, label=r"$\mathcal{L}_\mathrm{PINN}$")
ax0.semilogy(loss_pde, "--", label=r"$\mathcal{L}_\mathrm{pde}$")
ax0.semilogy(loss_bc, ":", label=r"$\mathcal{L}_\mathrm{bc}$")
ax0.legend()
ax0.set_title("PINN")

# Right panel: trained network evaluated at the interior collocation points.
x, y = xy_in[:, 0], xy_in[:, 1]
u_pred = jax.vmap(pinn, (None, 0, 0))(params, x, y)

# A Delaunay triangulation of scattered points fills their convex hull, which
# includes the removed square (0, 1)^2. Mask triangles whose centroid lies
# there so the contour plot respects the L-shaped domain.
triang = mtri.Triangulation(x, y)
cx = triang.x[triang.triangles].mean(axis=1)
cy = triang.y[triang.triangles].mean(axis=1)
triang.set_mask((cx > 0) & (cy > 0))

contour = ax1.tricontourf(triang, u_pred, cmap="jet")
plt.colorbar(contour, ax=ax1)
ax1.set_title(r"$u_\theta$")
ax1.set_xlabel(r"$x$")
ax1.set_ylabel(r"$y$")

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/poisson2d", dpi=300)