In [84]:
import jax

In [4]:
# --- sampling sizes ---
NC = 64 ** 3        # interior collocation points
NI = 64 ** 3        # initial-condition points on t = 0
NB = 64 ** 2        # boundary points per spatial face (4 faces)
NC_TEST = 100       # test-grid points per axis (NC_TEST**3 points in total)

# --- optimisation ---
SEED = 444
LR = 1e-3
EPOCHS = 50_000
LOG_ITER = 25_000   # evaluate and log every LOG_ITER epochs

# --- network ---
N_LAYERS = 5        # number of Dense layers, output layer included
FEATURES = 128      # hidden-layer width

In [5]:
# Root PRNG key from SEED, split once into `key` (kept for later draws) and `subkey`.
key, subkey = jax.random.split(jax.random.PRNGKey(SEED))

In [6]:
# N_LAYERS - 1 hidden layers of width FEATURES, followed by a scalar output layer.
feat_sizes = (FEATURES,) * (N_LAYERS - 1) + (1,)

In [7]:
# Display the layer widths passed to PINN: hidden widths first, output width 1 last.
feat_sizes

(128, 128, 128, 128, 1)

In [67]:
import jax.numpy as jnp
import optax
from flax import linen as nn
from typing import Sequence

In [59]:
class PINN(nn.Module):
    """Fully connected tanh network over three column inputs.

    Attributes:
        features: widths of the successive Dense layers. Every layer except
            the last is followed by tanh; the last width is the output size.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        # Inputs are joined along axis 1; assumes each is a column of shape
        # (n, 1) (the notebook calls this with t, x, y) -- TODO confirm.
        h = jnp.concatenate((x, y, z), axis=1)
        glorot = nn.initializers.glorot_normal()
        hidden_widths, out_width = self.features[:-1], self.features[-1]
        for width in hidden_widths:
            h = nn.activation.tanh(nn.Dense(width, kernel_init=glorot)(h))
        # Final layer is linear (no activation).
        return nn.Dense(out_width, kernel_init=glorot)(h)

In [64]:
model = PINN(feat_sizes)
# nn.Dense parameter shapes depend only on the input's last axis, so a single
# dummy row yields the same parameters as NC rows while avoiding a full forward
# pass over every collocation point just to infer shapes.
dummy = jnp.ones((1, 1))
params = model.init(subkey, dummy, dummy, dummy)

In [70]:
# Adam with a constant learning rate; `state` is the optimizer state for `params`.
optim = optax.adam(learning_rate=LR)
state = optim.init(params)

In [83]:
# Sanity check: peek at a few uniform draws in the time range used for sampling.
jax.random.uniform(key, shape=(5, 1), minval=0.0, maxval=10.0)

Array([[2.4968624 ],
       [0.32575607],
       [6.916964  ],
       [7.491734  ],
       [4.8239136 ]], dtype=float32)

In [105]:
def _klein_gordon3d_exact_u(t, x, y):
    return (x + y) * jnp.cos(2*t) + (x * y) * jnp.sin(2*t)

def _klein_gordon3d_source_term(t, x, y):
    """Forcing term f for u_tt - (u_xx + u_yy) + u^2 = f with the exact solution.

    For u = (x + y) cos(2t) + x y sin(2t): u_tt = -4u and u_xx = u_yy = 0,
    hence f = u^2 - 4u.
    """
    exact = _klein_gordon3d_exact_u(t, x, y)
    return exact**2 - 4 * exact

def pinn_train_generator_klein_gordon3d(nc, ni, nb, key):
    """Sample training points for the (2+1)-D Klein-Gordon problem.

    Domain: t in [0, 10], (x, y) in [-1, 1]^2.

    Args:
        nc: number of interior collocation points.
        ni: number of initial-condition points on the slice t = 0.
        nb: number of boundary points per spatial face; the four faces
            x = -1, x = 1, y = -1, y = 1 give 4 * nb boundary points.
        key: JAX PRNG key, split into 13 sub-keys.

    Returns:
        Tuple ``(tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub)`` of column
        arrays: collocation coordinates with the source term ``uc``, initial
        coordinates with exact values ``ui``, and boundary coordinates
        (faces concatenated in the order above) with exact values ``ub``.
    """
    sub = jax.random.split(key, 13)

    def sample(k, n, lo, hi):
        # n uniform draws in [lo, hi) as a column vector.
        return jax.random.uniform(k, (n, 1), minval=lo, maxval=hi)

    def fill(n, value):
        # n copies of a fixed face coordinate as a column vector.
        return jnp.array([[value]] * n)

    # Interior collocation points; the target is the PDE forcing term.
    tc = sample(sub[0], nc, 0., 10.)
    xc = sample(sub[1], nc, -1., 1.)
    yc = sample(sub[2], nc, -1., 1.)
    uc = _klein_gordon3d_source_term(tc, xc, yc)

    # Initial slice t = 0 with exact solution values.
    ti = jnp.zeros((ni, 1))
    xi = sample(sub[3], ni, -1., 1.)
    yi = sample(sub[4], ni, -1., 1.)
    ui = _klein_gordon3d_exact_u(ti, xi, yi)

    # Boundary faces in order x=-1, x=+1, y=-1, y=+1; t is free on each face.
    t_faces = [sample(sub[5 + i], nb, 0., 10.) for i in range(4)]
    x_faces = [fill(nb, -1.), fill(nb, 1.),
               sample(sub[9], nb, -1., 1.), sample(sub[10], nb, -1., 1.)]
    y_faces = [sample(sub[11], nb, -1., 1.), sample(sub[12], nb, -1., 1.),
               fill(nb, -1.), fill(nb, 1.)]
    u_faces = [_klein_gordon3d_exact_u(tf, xf, yf)
               for tf, xf, yf in zip(t_faces, x_faces, y_faces)]

    tb, xb, yb, ub = map(jnp.concatenate, (t_faces, x_faces, y_faces, u_faces))

    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub

def pinn_test_generator_klein_gordon3d(nc_test):
    """Build a dense evaluation grid with exact Klein-Gordon solution values.

    The grid is the tensor product of ``nc_test`` evenly spaced points
    (endpoints included) in t in [0, 10] and in each of x, y in [-1, 1],
    i.e. ``nc_test**3`` points ordered with t varying slowest.

    Args:
        nc_test: number of grid points per axis.

    Returns:
        Tuple ``(t, x, y, u_gt)``, each of shape ``(nc_test**3, 1)``.
    """
    # No stop_gradient: the grid is built from constants only (the sole input
    # is an integer count), so there is nothing to differentiate through.
    t_axis = jnp.linspace(0, 10, nc_test)
    x_axis = jnp.linspace(-1, 1, nc_test)
    y_axis = jnp.linspace(-1, 1, nc_test)
    tm, xm, ym = jnp.meshgrid(t_axis, x_axis, y_axis, indexing='ij')
    u_gt = _klein_gordon3d_exact_u(tm, xm, ym)
    # Flatten every grid to a column vector, matching the network's input layout.
    return tuple(grid.reshape(-1, 1) for grid in (tm, xm, ym, u_gt))

In [106]:
# `subkey` is already consumed by `model.init`; derive an independent key for
# data sampling instead of reusing it. fold_in is deterministic, so re-running
# this cell reproduces the same data.
data_key = jax.random.fold_in(subkey, 1)
train_data = pinn_train_generator_klein_gordon3d(NC, NI, NB, data_key)
t, x, y, u_gt = pinn_test_generator_klein_gordon3d(NC_TEST)

In [110]:
from jax import jvp, vjp, value_and_grad

In [111]:
def hvp_fwdrev(f, primals, tangents, return_primals=False):
    """Hessian-vector product of ``f`` via forward-over-reverse autodiff.

    The same vector ``tangents[0]`` is used both as the VJP cotangent and as
    the JVP tangent, so ``f``'s output must have the same shape as its input.
    With an all-ones vector and a row-wise ``f`` this gives per-row second
    derivatives.

    Args:
        f: single-argument differentiable function.
        primals: 1-tuple holding the point at which to differentiate.
        tangents: 1-tuple holding the vector v.
        return_primals: if True, also return the first-order term (v^T J).

    Returns:
        The forward tangent of ``p -> v^T J_f(p)`` along v, or the pair
        ``(v^T J_f(p), that tangent)`` when ``return_primals`` is set.
    """
    cotangent = tangents[0]

    def grad_along(point):
        _, pullback = vjp(f, point)
        return pullback(cotangent)[0]

    first_order, tangents_out = jvp(grad_along, primals, tangents)
    return (first_order, tangents_out) if return_primals else tangents_out

In [112]:
def pinn_loss_klein_gordon3d(apply_fn, *train_data):
    """Build the PINN loss for the (2+1)-D Klein-Gordon equation.

    PDE residual (Laplacian over x and y):
        u_tt - (u_xx + u_yy) + u^2 = f(t, x, y)

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y)`` returning the network
            prediction; assumed to take column inputs and act row-wise.
        *train_data: the 12 arrays ``(tc, xc, yc, uc, ti, xi, yi, ui,
            tb, xb, yb, ub)`` in this order; ``uc`` is the source term at the
            collocation points, ``ui``/``ub`` are initial/boundary targets.

    Returns:
        ``fn(params)`` giving the sum of the mean squared PDE residual, the
        initial-condition MSE and the boundary MSE.

    NOTE(review): only u(t=0) is enforced; the initial velocity u_t(t=0) is
    not part of the loss -- confirm this is intended for an equation that is
    second order in time.
    """
    def residual_loss(params, t, x, y, source_term):
        """Mean squared PDE residual at the collocation points."""
        u = apply_fn(params, t, x, y)
        v = jnp.ones(u.shape)
        # With an all-ones vector, hvp_fwdrev yields per-point second
        # derivatives, assuming apply_fn maps each input row independently.
        utt = hvp_fwdrev(lambda t: apply_fn(params, t, x, y), (t,), (v,))
        uxx = hvp_fwdrev(lambda x: apply_fn(params, t, x, y), (x,), (v,))
        uyy = hvp_fwdrev(lambda y: apply_fn(params, t, x, y), (y,), (v,))
        # Both spatial second derivatives enter with a minus sign (-Laplacian).
        return jnp.mean((utt - uxx - uyy + u**2 - source_term)**2)

    def initial_boundary_loss(params, t, x, y, u):
        """Mean squared error between the prediction and the targets u."""
        return jnp.mean((apply_fn(params, t, x, y) - u)**2)

    tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub = train_data

    def fn(params):
        return (residual_loss(params, tc, xc, yc, uc)
                + initial_boundary_loss(params, ti, xi, yi, ui)
                + initial_boundary_loss(params, tb, xb, yb, ub))

    return fn

In [114]:
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step.

    Args:
        optim: optax ``GradientTransformation``; static under jit, so it must
            be hashable and a different value triggers retracing.
        gradient: gradient pytree matching ``params``.
        params: current model parameters.
        state: optimizer state from ``optim.init`` or a previous step.

    Returns:
        ``(new_params, new_state)``.
    """
    # Pass params so transformations that need them (e.g. decoupled weight
    # decay in adamw) work; plain adam ignores this argument.
    updates, new_state = optim.update(gradient, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state

In [115]:
# Compile the Flax forward pass once; it is reused by the loss, the training
# step and evaluation. The loss closure captures the arrays in `train_data`.
apply_fn = jax.jit(model.apply)
loss_fn = pinn_loss_klein_gordon3d(apply_fn, *train_data)

In [117]:
@jax.jit
def train_one_step(params, state):
    """Evaluate the loss and its gradient, then apply one optimizer update.

    Reads the globals ``loss_fn`` and ``optim``; they are captured when the
    function is traced, so redefining them later does not retrace.

    Returns:
        ``(loss before the update, updated params, updated optimizer state)``.
    """
    loss_and_grad = value_and_grad(loss_fn)
    current_loss, grads = loss_and_grad(params)
    next_params, next_state = update_model(optim, grads, params, state)
    return current_loss, next_params, next_state

In [119]:
import time
from tqdm import trange

In [120]:
def relative_l2(u, u_gt):
    """Relative L2 error ||u - u_gt|| / ||u_gt|| (Frobenius norm for 2-D arrays)."""
    error_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return error_norm / reference_norm

In [121]:
# Train for EPOCHS steps, logging loss and test relative L2 error every
# LOG_ITER steps. NOTE: continues from the current `params`/`state`, so
# re-running this cell resumes training rather than restarting it.
start = time.perf_counter()  # monotonic clock intended for interval timing
progress = trange(1, EPOCHS+1)
for e in progress:
    loss, params, state = train_one_step(params, state)
    if e % LOG_ITER == 0:
        error = relative_l2(apply_fn(params, t, x, y), u_gt)
        # tqdm's write prints above the progress bar instead of breaking it.
        progress.write(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')

# JAX dispatches work asynchronously; wait for the final step to finish so
# the measured runtime covers all of the computation.
jax.block_until_ready(params)
end = time.perf_counter()
print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')

  0%|          | 0/50000 [00:00<?, ?it/s]

In [1]:
import matplotlib.pyplot as plt 

def plot_klein_gordon3d(t, x, y, u):
    """Scatter the field u over (t, x, y) in a 3-D plot, coloured by value.

    Args:
        t, x, y: point coordinates (one value per point).
        u: field value per point, mapped through the 'seismic' colormap.
    """
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'projection': '3d'})
    ax.scatter(t, x, y, c=u, s=0.5, cmap='seismic')
    ax.set_title('U(t, x, y)', fontsize=20, pad=-5)
    for set_label, text in ((ax.set_xlabel, 't'), (ax.set_ylabel, 'x'), (ax.set_zlabel, 'y')):
        set_label(text, fontsize=18, labelpad=10)
    plt.show()

In [None]:
# Evaluate the trained network on the test grid (t, x, y from the test
# generator) and visualise the predicted field.
u = apply_fn(params, t, x, y)

plot_klein_gordon3d(t, x, y, u)