# SPINN Structure
> New!

In [None]:
import os
import pdb
import scipy.io

import jax
import jax.numpy as jnp
from jax import jvp, vjp
from flax import linen as nn
import optax

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from typing import Sequence
from functools import partial

In [None]:
# forward over forward
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` along ``v`` via forward-over-forward AD.

    Args:
        f: function of a single array.
        primals: 1-tuple ``(x,)`` -- the evaluation point.
        tangents: 1-tuple ``(v,)`` -- the direction (same shape as ``x``).
        return_primals: if True, also return the first directional derivative.

    Returns:
        ``d/ds [J_f(x + s v) v]`` at ``s = 0``, or ``(J_f(x) v, that)`` when
        ``return_primals`` is set.
    """
    def directional(point):
        # J_f(point) @ v
        return jvp(f, (point,), tangents)[1]

    first, second = jvp(directional, primals, tangents)
    if not return_primals:
        return second
    return first, second


# reverse over reverse
def hvp_revrev(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` along ``v`` via reverse-over-reverse AD.

    Uses the same calling convention as ``hvp_fwdfwd`` / ``hvp_fwdrev`` /
    ``hvp_revfwd``: ``f`` maps a single array to an array with the shape of
    ``v``.

    Args:
        f: function of a single array.
        primals: 1-tuple ``(x,)``.
        tangents: 1-tuple ``(v,)``; ``v`` is used as the cotangent in both
            reverse passes, so ``f(x)`` must have the shape of ``v``.
        return_primals: if True, also return the first-order product.

    Returns:
        The second-order product (an array shaped like ``x``), or
        ``(v^T J_f(x), that)`` when ``return_primals`` is set.
    """
    v = tangents[0]

    def vjp_with_v(*args):
        # v^T J_f(args) -- gradient-like quantity shaped like x. Unpacking
        # the primals means f receives an array, not the wrapping tuple.
        return vjp(f, *args)[1](v)[0]

    primals_out, vjp_fn = vjp(vjp_with_v, *primals)
    tangents_out = vjp_fn(v)[0]
    if return_primals:
        return primals_out, tangents_out
    return tangents_out


# forward over reverse
def hvp_fwdrev(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-reverse AD.

    Args:
        f: function of a single array whose output has the shape of ``v``
            (``v`` is used as the cotangent of the inner reverse pass).
        primals: 1-tuple ``(x,)``.
        tangents: 1-tuple ``(v,)``.
        return_primals: if True, also return ``v^T J_f(x)``.

    Returns:
        The forward derivative of ``x -> v^T J_f(x)`` along ``v``, optionally
        preceded by ``v^T J_f(x)``.
    """
    cotangent = tangents[0]

    def pulled_back(point):
        _, pullback = vjp(f, point)
        return pullback(cotangent)[0]

    first, second = jvp(pulled_back, primals, tangents)
    return (first, second) if return_primals else second


# reverse over forward
def hvp_revfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via reverse-over-forward AD.

    Args:
        f: function of a single array.
        primals: 1-tuple ``(x,)``.
        tangents: 1-tuple ``(v,)``; ``v`` is also used as the cotangent of
            the outer reverse pass.
        return_primals: if True, also return ``J_f(x) v``.

    Returns:
        The reverse derivative of ``x -> J_f(x) v`` pulled back by ``v``,
        optionally preceded by ``J_f(x) v``.
    """
    def pushed_forward(point_tuple):
        # point_tuple is the whole ``primals`` tuple handed over by vjp
        return jvp(f, point_tuple, tangents)[1]

    first, pullback = vjp(pushed_forward, primals)
    # pullback returns (cotangent for the tuple argument,) -> unwrap both levels
    second = pullback(tangents[0])[0][0]
    return (first, second) if return_primals else second

In [None]:
# 3d time-independent helmholtz exact u
#@partial(jax.jit, static_argnums=(0, 1, 2,))
def helmholtz3d_exact_u(a1, a2, a3, x, y, z):
    """Exact Helmholtz solution u = sin(a1*pi*x) * sin(a2*pi*y) * sin(a3*pi*z)."""
    sin_x = jnp.sin(a1*jnp.pi*x)
    sin_y = jnp.sin(a2*jnp.pi*y)
    sin_z = jnp.sin(a3*jnp.pi*z)
    return sin_x * sin_y * sin_z


# 3d time-independent helmholtz source term
#@partial(jax.jit, static_argnums=(0, 1, 2,))
def helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.):
    """Source term u_xx + u_yy + u_zz + lda * u evaluated for the exact solution.

    For the separable sine solution each second derivative is
    ``-(a_i * pi)**2 * u``, so no autodiff is needed.
    """
    u = helmholtz3d_exact_u(a1, a2, a3, x, y, z)
    d2x, d2y, d2z = (-(a*jnp.pi)**2 * u for a in (a1, a2, a3))
    return d2x + d2y + d2z + lda*u


# 2d time-dependent klein-gordon exact u
def klein_gordon3d_exact_u(t, x, y, k):
    """Exact (t, x, y) Klein-Gordon solution u = (x + y) cos(k t) + x y sin(k t)."""
    phase = k * t
    return (x + y) * jnp.cos(phase) + (x * y) * jnp.sin(phase)


# 2d time-dependent klein-gordon source term
def klein_gordon3d_source_term(t, x, y, k):
    """Forcing ``u**2 - k**2 * u`` evaluated at the exact (t, x, y) solution."""
    u_exact = klein_gordon3d_exact_u(t, x, y, k)
    nonlinear = u_exact**2
    return nonlinear - (k**2)*u_exact


# 3d time-dependent klein-gordon exact u
def klein_gordon4d_exact_u(t, x, y, z, k):
    """Exact (t, x, y, z) Klein-Gordon solution u = (x+y+z) cos(k t) + x y z sin(k t)."""
    linear_part = x + y + z
    product_part = x * y * z
    return linear_part * jnp.cos(k*t) + product_part * jnp.sin(k*t)


# 3d time-dependent klein-gordon source term
def klein_gordon4d_source_term(t, x, y, z, k):
    """Forcing ``u**2 - k**2 * u`` evaluated at the exact (t, x, y, z) solution."""
    u_exact = klein_gordon4d_exact_u(t, x, y, z, k)
    nonlinear = u_exact**2
    return nonlinear - (k**2)*u_exact


# 3d time-dependent navier-stokes forcing term
def navier_stokes4d_forcing_term(t, x, y, z, nu):
    """Analytic forcing (f_x, f_y, f_z) of the 3-D time-dependent Navier-Stokes problem.

    Written with the double-angle identity; e.g. the x component
    -24 e^{-18 nu t} sin(2y)cos(2y) sin(z)cos(z) equals
    -6 e^{-18 nu t} sin(4y) sin(2z).
    """
    decay = jnp.exp(-18*nu*t)
    f_x = -6*decay*jnp.sin(4*y)*jnp.sin(2*z)
    f_y = -6*decay*jnp.sin(4*x)*jnp.sin(2*z)
    f_z = 6*decay*jnp.sin(4*x)*jnp.sin(4*y)
    return f_x, f_y, f_z


# 3d time-dependent navier-stokes exact vorticity
def navier_stokes4d_exact_w(t, x, y, z, nu):
    """Analytic vorticity (w_x, w_y, w_z), the curl of ``navier_stokes4d_exact_u``."""
    decay = jnp.exp(-9*nu*t)
    w_x = -3*decay*jnp.sin(2*x)*jnp.cos(2*y)*jnp.cos(z)
    w_y = 6*decay*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(z)
    w_z = -6*decay*jnp.cos(2*x)*jnp.cos(2*y)*jnp.sin(z)
    return w_x, w_y, w_z


# 3d time-dependent navier-stokes exact velocity
def navier_stokes4d_exact_u(t, x, y, z, nu):
    """Analytic velocity (u_x, u_y, u_z) decaying as exp(-9 nu t)."""
    decay = jnp.exp(-9*nu*t)
    u_x = 2*decay*jnp.cos(2*x)*jnp.sin(2*y)*jnp.sin(z)
    u_y = -1*decay*jnp.sin(2*x)*jnp.cos(2*y)*jnp.sin(z)
    u_z = -2*decay*jnp.sin(2*x)*jnp.sin(2*y)*jnp.cos(z)
    return u_x, u_y, u_z

In [None]:



#========================== diffusion equation 3-d =========================#
#---------------------------------- PINN -----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _pinn_train_generator_diffusion3d(nc, key):
    keys = jax.random.split(key, 13)
    ni, nb = nc**2, nc**2

    # colocation points
    tc = jax.random.uniform(keys[0], (nc**3, 1), minval=0., maxval=1.)
    xc = jax.random.uniform(keys[1], (nc**3, 1), minval=-1., maxval=1.)
    yc = jax.random.uniform(keys[2], (nc**3, 1), minval=-1., maxval=1.)
    # initial points
    ti = jnp.zeros((ni, 1))
    xi = jax.random.uniform(keys[3], (ni, 1), minval=-1., maxval=1.)
    yi = jax.random.uniform(keys[4], (ni, 1), minval=-1., maxval=1.)
    ui = 0.25 * jnp.exp(-((xi - 0.3)**2 + (yi - 0.2)**2) / 0.1) + \
         0.4 * jnp.exp(-((xi + 0.5)**2 + (yi + 0.1)**2) * 15) + \
         0.3 * jnp.exp(-(xi**2 + (yi + 0.5)**2) * 20)
    # boundary points (hard-coded)
    tb = [
        jax.random.uniform(keys[5], (nb, 1), minval=0., maxval=1.),
        jax.random.uniform(keys[6], (nb, 1), minval=0., maxval=1.),
        jax.random.uniform(keys[7], (nb, 1), minval=0., maxval=1.),
        jax.random.uniform(keys[8], (nb, 1), minval=0., maxval=1.)
    ]
    xb = [
        jnp.array([[-1.]]*nb),
        jnp.array([[1.]]*nb),
        jax.random.uniform(keys[9], (nb, 1), minval=-1., maxval=1.),
        jax.random.uniform(keys[10], (nb, 1), minval=-1., maxval=1.)
    ]
    yb = [
        jax.random.uniform(keys[11], (nb, 1), minval=-1., maxval=1.),
        jax.random.uniform(keys[12], (nb, 1), minval=-1., maxval=1.),
        jnp.array([[-1.]]*nb),
        jnp.array([[1.]]*nb)
    ]
    tb = jnp.concatenate(tb)
    xb = jnp.concatenate(xb)
    yb = jnp.concatenate(yb)
    return tc, xc, yc, ti, xi, yi, ui, tb, xb, yb


#---------------------------------- SPINN ----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_diffusion3d(nc, key):
    keys = jax.random.split(key, 3)
    # colocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=1.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=-1., maxval=1.)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=-1., maxval=1.)
    # initial points
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    xi_mesh, yi_mesh = jnp.meshgrid(xi.ravel(), yi.ravel(), indexing='ij')
    ui = 0.25 * jnp.exp(-((xi_mesh - 0.3)**2 + (yi_mesh - 0.2)**2) / 0.1) + \
         0.4 * jnp.exp(-((xi_mesh + 0.5)**2 + (yi_mesh + 0.1)**2) * 15) + \
         0.3 * jnp.exp(-(xi_mesh**2 + (yi_mesh + 0.5)**2) * 20)
    # boundary points (hard-coded)
    tb = [tc, tc, tc, tc]
    xb = [jnp.array([[-1.]]), jnp.array([[1.]]), xc, xc]
    yb = [yc, yc, jnp.array([[-1.]]), jnp.array([[1.]])]
    return tc, xc, yc, ti, xi, yi, ui, tb, xb, yb


#========================== Helmholtz equation 3-d =========================#
#---------------------------------- PINN -----------------------------------#
#@partial(jax.jit, static_argnums=(0, 1, 2, 3,))
def _pinn_train_generator_helmholtz3d(a1, a2, a3, nc, key):
    """Sample random PINN training points for the 3-D Helmholtz problem on [-1, 1]^3.

    Args:
        a1, a2, a3: frequencies of the exact solution.
        nc: ``nc**3`` collocation points and ``nc**2`` points per face.
        key: JAX PRNG key (split into 15 subkeys).

    Returns:
        ``(xc, yc, zc, uc, xb, yb, zb)``; ``uc`` is the source term at the
        collocation points and the boundary arrays stack the faces
        x=+1, x=-1, y=+1, y=-1, z=+1, z=-1 in that order.
    """
    sub = jax.random.split(key, 15)
    n_col, n_bnd = nc**3, nc**2

    def sample(i, n):
        return jax.random.uniform(sub[i], (n, 1), minval=-1., maxval=1.)

    def face(value):
        return jnp.array([[value]] * n_bnd)

    xc, yc, zc = sample(0, n_col), sample(1, n_col), sample(2, n_col)
    uc = helmholtz3d_source_term(a1, a2, a3, xc, yc, zc)

    xb = jnp.concatenate([
        face(1.), face(-1.),
        sample(3, n_bnd), sample(4, n_bnd), sample(5, n_bnd), sample(6, n_bnd),
    ])
    yb = jnp.concatenate([
        sample(7, n_bnd), sample(8, n_bnd),
        face(1.), face(-1.),
        sample(9, n_bnd), sample(10, n_bnd),
    ])
    zb = jnp.concatenate([
        sample(11, n_bnd), sample(12, n_bnd), sample(13, n_bnd), sample(14, n_bnd),
        face(1.), face(-1.),
    ])
    return xc, yc, zc, uc, xb, yb, zb



#---------------------------------- SPINN ----------------------------------#
#@partial(jax.jit, static_argnums=(0, 1, 2, 3,))
def _spinn_train_generator_helmholtz3d(a1, a2, a3, nc, key):
    """Sample per-axis SPINN training coordinates for the 3-D Helmholtz problem.

    Returns:
        ``(xc, yc, zc, uc, xb, yb, zb)``: ``(nc, 1)`` axes, the source term
        ``uc`` on their ``(nc, nc, nc)`` grid, and six boundary faces
        (x=+1, x=-1, y=+1, y=-1, z=+1, z=-1) as lists of per-axis arrays.
    """
    key_x, key_y, key_z = jax.random.split(key, 3)
    axes = [
        jax.random.uniform(sk, (nc,), minval=-1., maxval=1.)
        for sk in (key_x, key_y, key_z)
    ]
    uc = helmholtz3d_source_term(a1, a2, a3, *jnp.meshgrid(*axes, indexing='ij'))
    xc, yc, zc = (ax.reshape(-1, 1) for ax in axes)

    plus, minus = jnp.array([[1.]]), jnp.array([[-1.]])
    xb = [plus, minus, xc, xc, xc, xc]
    yb = [yc, yc, plus, minus, yc, yc]
    zb = [zc, zc, zc, zc, plus, minus]
    return xc, yc, zc, uc, xb, yb, zb


#======================== Klein-Gordon equation 3-d ========================#
#---------------------------------- PINN -----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _pinn_train_generator_klein_gordon3d(nc, k, key):
    """Sample random PINN training points for the (t, x, y) Klein-Gordon problem.

    Domain: t in [0, 10], (x, y) in [-1, 1]^2.

    Args:
        nc: ``nc**3`` collocation points, ``nc**2`` initial and per-edge points.
        k: temporal frequency of the exact solution.
        key: JAX PRNG key (split into 13 subkeys).

    Returns:
        ``(tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub)`` of ``(n, 1)``
        arrays; boundary arrays stack edges x=-1, x=1, y=-1, y=1.
    """
    n_col, n_init, n_bnd = nc**3, nc**2, nc**2
    sub = jax.random.split(key, 13)

    def sample(i, n, lo, hi):
        return jax.random.uniform(sub[i], (n, 1), minval=lo, maxval=hi)

    def edge(value):
        return jnp.array([[value]] * n_bnd)

    tc = sample(0, n_col, 0., 10.)
    xc = sample(1, n_col, -1., 1.)
    yc = sample(2, n_col, -1., 1.)
    uc = klein_gordon3d_source_term(tc, xc, yc, k)

    ti = jnp.zeros((n_init, 1))
    xi = sample(3, n_init, -1., 1.)
    yi = sample(4, n_init, -1., 1.)
    ui = klein_gordon3d_exact_u(ti, xi, yi, k)

    tb = [sample(i, n_bnd, 0., 10.) for i in (5, 6, 7, 8)]
    xb = [edge(-1.), edge(1.), sample(9, n_bnd, -1., 1.), sample(10, n_bnd, -1., 1.)]
    yb = [sample(11, n_bnd, -1., 1.), sample(12, n_bnd, -1., 1.), edge(-1.), edge(1.)]
    ub = [klein_gordon3d_exact_u(tt, xx, yy, k) for tt, xx, yy in zip(tb, xb, yb)]
    tb, xb, yb, ub = (jnp.concatenate(parts) for parts in (tb, xb, yb, ub))
    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub


#---------------------------------- SPINN ----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_klein_gordon3d(nc, k, key):
    """Sample per-axis SPINN training coordinates for the (t, x, y) Klein-Gordon problem.

    Returns:
        ``(tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub)``: ``(nc, 1)``
        axes with the source term ``uc`` on their grid, the t = 0 slice with
        exact ``ui``, and four boundary edges (x=-1, x=1, y=-1, y=1) as lists
        with exact values ``ub`` on each edge's grid.
    """
    key_t, key_x, key_y = jax.random.split(key, 3)
    tc = jax.random.uniform(key_t, (nc, 1), minval=0., maxval=10.)
    xc = jax.random.uniform(key_x, (nc, 1), minval=-1., maxval=1.)
    yc = jax.random.uniform(key_y, (nc, 1), minval=-1., maxval=1.)

    def on_grid(fn, *cols):
        # evaluate fn on the tensor-product grid spanned by the column vectors
        return fn(*jnp.meshgrid(*(c.ravel() for c in cols), indexing='ij'), k)

    uc = on_grid(klein_gordon3d_source_term, tc, xc, yc)

    ti = jnp.zeros((1, 1))
    xi, yi = xc, yc
    ui = on_grid(klein_gordon3d_exact_u, ti, xi, yi)

    lo, hi = jnp.array([[-1.]]), jnp.array([[1.]])
    tb = [tc] * 4
    xb = [lo, hi, xc, xc]
    yb = [yc, yc, lo, hi]
    ub = [on_grid(klein_gordon3d_exact_u, *edge) for edge in zip(tb, xb, yb)]
    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub


#======================== Klein-Gordon equation 4-d ========================#
#---------------------------------- PINN -----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _pinn_train_generator_klein_gordon4d(nc, k, key):
    """Sample random PINN training points for the (t, x, y, z) Klein-Gordon problem.

    Domain: t in [0, 10], (x, y, z) in [-1, 1]^3.

    Every random array is drawn from its own PRNG subkey. (Sharing a subkey
    between two draws makes them deterministic functions of each other --
    e.g. a t-sample in [0, 10] would be exactly 5 * (z-sample + 1).)

    Args:
        nc: ``nc**4`` collocation points, ``nc**3`` initial points and
            ``nc**3`` points per boundary face.
        k: temporal frequency of the exact solution.
        key: JAX PRNG key.

    Returns:
        ``(tc, xc, yc, zc, uc, ti, xi, yi, zi, ui, tb, xb, yb, zb, ub)`` of
        ``(n, 1)`` arrays; boundary arrays stack the faces
        x=-1, x=1, y=-1, y=1, z=-1, z=1 in that order.
    """
    ni, nb = nc**3, nc**3
    # 4 collocation + 3 initial + 6 (t) + 4 (x) + 4 (y) + 4 (z) boundary draws
    subkeys = iter(jax.random.split(key, 25))

    def draw(n, lo, hi):
        # consumes a fresh subkey on every call, so no two draws are correlated
        return jax.random.uniform(next(subkeys), (n, 1), minval=lo, maxval=hi)

    # collocation points
    tc = draw(nc**4, 0., 10.)
    xc = draw(nc**4, -1., 1.)
    yc = draw(nc**4, -1., 1.)
    zc = draw(nc**4, -1., 1.)
    uc = klein_gordon4d_source_term(tc, xc, yc, zc, k)

    # initial points
    ti = jnp.zeros((ni, 1))
    xi = draw(ni, -1., 1.)
    yi = draw(ni, -1., 1.)
    zi = draw(ni, -1., 1.)
    ui = klein_gordon4d_exact_u(ti, xi, yi, zi, k)

    # boundary points (faces x=-1, x=1, y=-1, y=1, z=-1, z=1)
    lo = jnp.array([[-1.]]*nb)
    hi = jnp.array([[1.]]*nb)
    tb = [draw(nb, 0., 10.) for _ in range(6)]
    xb = [lo, hi] + [draw(nb, -1., 1.) for _ in range(4)]
    yb = [draw(nb, -1., 1.), draw(nb, -1., 1.), lo, hi,
          draw(nb, -1., 1.), draw(nb, -1., 1.)]
    zb = [draw(nb, -1., 1.) for _ in range(4)] + [lo, hi]
    ub = [klein_gordon4d_exact_u(tb[i], xb[i], yb[i], zb[i], k) for i in range(6)]

    tb = jnp.concatenate(tb)
    xb = jnp.concatenate(xb)
    yb = jnp.concatenate(yb)
    zb = jnp.concatenate(zb)
    ub = jnp.concatenate(ub)
    return tc, xc, yc, zc, uc, ti, xi, yi, zi, ui, tb, xb, yb, zb, ub


#---------------------------------- SPINN ----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_klein_gordon4d(nc, k, key):
    """Sample per-axis SPINN training coordinates for the (t, x, y, z) Klein-Gordon problem.

    Returns:
        ``(tc, xc, yc, zc, uc, ti, xi, yi, zi, ui, tb, xb, yb, zb, ub)``:
        ``(nc, 1)`` axes with the source term ``uc`` on their grid, the t = 0
        slice with exact ``ui``, and six boundary faces
        (x=-1, x=1, y=-1, y=1, z=-1, z=1) as lists with exact values ``ub``.
    """
    key_t, key_x, key_y, key_z = jax.random.split(key, 4)
    tc = jax.random.uniform(key_t, (nc, 1), minval=0., maxval=10.)
    xc, yc, zc = (
        jax.random.uniform(sk, (nc, 1), minval=-1., maxval=1.)
        for sk in (key_x, key_y, key_z)
    )

    def on_grid(fn, *cols):
        # evaluate fn on the tensor-product grid spanned by the column vectors
        return fn(*jnp.meshgrid(*(c.ravel() for c in cols), indexing='ij'), k)

    uc = on_grid(klein_gordon4d_source_term, tc, xc, yc, zc)

    ti = jnp.zeros((1, 1))
    xi, yi, zi = xc, yc, zc
    ui = on_grid(klein_gordon4d_exact_u, ti, xi, yi, zi)

    lo, hi = jnp.array([[-1.]]), jnp.array([[1.]])
    tb = [tc] * 6
    xb = [lo, hi, xc, xc, xc, xc]
    yb = [yc, yc, lo, hi, yc, yc]
    zb = [zc, zc, zc, zc, lo, hi]
    ub = [on_grid(klein_gordon4d_exact_u, *face) for face in zip(tb, xb, yb, zb)]
    return tc, xc, yc, zc, uc, ti, xi, yi, zi, ui, tb, xb, yb, zb, ub


#======================== Navier-Stokes equation 3-d ========================#
#---------------------------------- SPINN -----------------------------------#
def _spinn_train_generator_navier_stokes3d(nt, nxy, data_dir, result_dir, marching_steps, step_idx, offset_num, key):
    keys = jax.random.split(key, 2)
    gt_data = scipy.io.loadmat(os.path.join(data_dir, 'w_data.mat'))
    t = gt_data['t']

    # initial points
    ti = jnp.zeros((1, 1))
    xi = gt_data['x']
    yi = gt_data['y']
    if step_idx == 0:
        # get data from ground truth
        w0 = gt_data
    else:
        # get data from previous time window prediction
        w0 = scipy.io.loadmat(os.path.join(result_dir, '..', f'IC_pred/w0_{step_idx}.mat'))
        ti = w0['t']

    # collocation points
    tc = jnp.expand_dims(jnp.linspace(start=0., stop=t[0][-1], num=nt, endpoint=False), axis=1)
    xc = jnp.expand_dims(jnp.linspace(start=0., stop=2.*jnp.pi, num=nxy, endpoint=False), axis=1)
    yc = jnp.expand_dims(jnp.linspace(start=0., stop=2.*jnp.pi, num=nxy, endpoint=False), axis=1)

    if marching_steps != 0:
        # when using time marching
        Dt = t[0][-1] / marching_steps  # interval of a single time window
        # generate temporal coordinates within current time window
        if step_idx == 0:
            tc = jnp.expand_dims(jnp.linspace(start=0., stop=Dt*(step_idx+1), num=nt, endpoint=False), axis=1)
        else:
            tc = jnp.expand_dims(jnp.linspace(start=w0['t'][0][0], stop=Dt*(step_idx+1), num=nt, endpoint=False), axis=1)

    # for stacking multi-input grid
    tc_mult = jnp.expand_dims(tc, axis=0)
    xc_mult = jnp.expand_dims(xc, axis=0)
    yc_mult = jnp.expand_dims(yc, axis=0)

    # maximum value of offsets
    dt = tc[1][0] - tc[0][0]
    dxy = xc[1][0] - xc[0][0]

    # create offset values (zero is included by default)
    offset_t = jax.random.uniform(keys[0], (offset_num-1,), minval=0., maxval=dt)
    offset_xy = jax.random.uniform(keys[1], (offset_num-1,), minval=0., maxval=dxy)

    # make multi-grid
    for i in range(offset_num-1):
        tc_mult = jnp.concatenate((tc_mult, jnp.expand_dims(tc + offset_t[i], axis= 0)), axis=0)
        xc_mult = jnp.concatenate((xc_mult, jnp.expand_dims(xc + offset_xy[i], axis=0)), axis=0)
        yc_mult = jnp.concatenate((yc_mult, jnp.expand_dims(yc + offset_xy[i], axis=0)), axis=0)

    return tc_mult, xc_mult, yc_mult, ti, xi, yi, w0['w0'], w0['u0'], w0['v0']


#======================== Navier-Stokes equation 4-d ========================#
#---------------------------------- SPINN -----------------------------------#
#@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_navier_stokes4d(nc, nu, key):
    """Sample per-axis SPINN training coordinates for the (t, x, y, z) Navier-Stokes problem.

    Domain: t in [0, 5], (x, y, z) in [0, 2*pi]^3.

    Args:
        nc: number of points per axis.
        nu: viscosity forwarded to the analytic forcing / exact-solution helpers.
        key: JAX PRNG key (split into 4 subkeys).

    Returns:
        ``(tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb)``:
        ``(nc, 1)`` axes, forcing components ``fc`` on their grid, the t = 0
        slice with exact vorticity ``wi`` and velocity ``ui``, and six
        boundary faces (x=0, x=2pi, y=0, y=2pi, z=0, z=2pi) as lists with
        exact vorticity ``wb`` on each face's grid.
    """
    keys = jax.random.split(key, 4)
    # collocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=5.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=0., maxval=2.*jnp.pi)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=0., maxval=2.*jnp.pi)
    zc = jax.random.uniform(keys[3], (nc, 1), minval=0., maxval=2.*jnp.pi)

    tcm, xcm, ycm, zcm = jnp.meshgrid(
        tc.ravel(), xc.ravel(), yc.ravel(), zc.ravel(), indexing='ij'
    )
    fc = navier_stokes4d_forcing_term(tcm, xcm, ycm, zcm, nu)

    # initial points
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    zi = zc
    tim, xim, yim, zim = jnp.meshgrid(
        ti.ravel(), xi.ravel(), yi.ravel(), zi.ravel(), indexing='ij'
    )
    wi = navier_stokes4d_exact_w(tim, xim, yim, zim, nu)
    ui = navier_stokes4d_exact_u(tim, xim, yim, zim, nu)

    # boundary faces at the edges of the sampled spatial domain [0, 2*pi]
    # (a +-1 face would lie outside / strictly inside this domain)
    lo = jnp.array([[0.]])
    hi = jnp.array([[2.*jnp.pi]])
    tb = [tc, tc, tc, tc, tc, tc]
    xb = [lo, hi, xc, xc, xc, xc]
    yb = [yc, yc, lo, hi, yc, yc]
    zb = [zc, zc, zc, zc, lo, hi]
    wb = []
    for i in range(6):
        tbm, xbm, ybm, zbm = jnp.meshgrid(
            tb[i].ravel(), xb[i].ravel(), yb[i].ravel(), zb[i].ravel(), indexing='ij'
        )
        wb += [navier_stokes4d_exact_w(tbm, xbm, ybm, zbm, nu)]
    return tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb


def generate_train_data(args, key, result_dir=None):
    """Dispatch to the training-data generator selected by ``args.model`` and ``args.equation``.

    Args:
        args: namespace; ``model`` ('pinn' or 'spinn') and ``equation``
            choose the generator, which reads the remaining attributes
            it needs (``nc``, ``k``, ``a1``..``a3``, ``nu``, ...).
        key: JAX PRNG key.
        result_dir: only used by the SPINN 'navier_stokes3d' generator.

    Raises:
        NotImplementedError: unsupported model/equation combination.
    """
    eqn = args.equation
    builders = {
        'pinn': {
            'diffusion3d': lambda: _pinn_train_generator_diffusion3d(args.nc, key),
            'helmholtz3d': lambda: _pinn_train_generator_helmholtz3d(
                args.a1, args.a2, args.a3, args.nc, key),
            'klein_gordon3d': lambda: _pinn_train_generator_klein_gordon3d(
                args.nc, args.k, key),
            'klein_gordon4d': lambda: _pinn_train_generator_klein_gordon4d(
                args.nc, args.k, key),
        },
        'spinn': {
            'diffusion3d': lambda: _spinn_train_generator_diffusion3d(args.nc, key),
            'helmholtz3d': lambda: _spinn_train_generator_helmholtz3d(
                args.a1, args.a2, args.a3, args.nc, key),
            'klein_gordon3d': lambda: _spinn_train_generator_klein_gordon3d(
                args.nc, args.k, key),
            'klein_gordon4d': lambda: _spinn_train_generator_klein_gordon4d(
                args.nc, args.k, key),
            'navier_stokes3d': lambda: _spinn_train_generator_navier_stokes3d(
                args.nt, args.nxy, args.data_dir, result_dir, args.marching_steps,
                args.step_idx, args.offset_num, key),
            'navier_stokes4d': lambda: _spinn_train_generator_navier_stokes4d(
                args.nc, args.nu, key),
        },
    }
    build = builders.get(args.model, {}).get(eqn)
    if build is None:
        raise NotImplementedError
    return build()


#============================== test dataset ===============================#
#------------------------- diffusion equation 3-d --------------------------#
#@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_diffusion3d(model, data_dir):
    u_gt, tt = [], 0.
    for _ in range(101):
        u_gt += [jnp.load(os.path.join(data_dir, f'heat_gaussian_{tt:.2f}.npy'))]
        tt += 0.01
    u_gt = jnp.stack(u_gt)
    t = jnp.linspace(0., 1., u_gt.shape[0])
    x = jnp.linspace(-1., 1., u_gt.shape[1])
    y = jnp.linspace(-1., 1., u_gt.shape[2])
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)
    tm, xm, ym = jnp.meshgrid(t, x, y, indexing='ij')
    if model == 'pinn':
        t = tm.reshape(-1, 1)
        x = xm.reshape(-1, 1)
        y = ym.reshape(-1, 1)
        u_gt = u_gt.reshape(-1, 1)
    else:
        t = t.reshape(-1, 1)
        x = x.reshape(-1, 1)
        y = y.reshape(-1, 1)
    return t, x, y, u_gt


#------------------------- Helmholtz equation 3-d --------------------------#
#@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4,))
def _test_generator_helmholtz3d(model, a1, a2, a3, nc_test):
    """Regular ``nc_test``-per-axis test grid on [-1, 1]^3 with the exact Helmholtz solution.

    Returns:
        ``(x, y, z, u_gt)``; flattened ``(N, 1)`` columns for 'pinn',
        otherwise per-axis columns plus the ``u_gt`` grid.
    """
    axes = [jax.lax.stop_gradient(jnp.linspace(-1., 1., nc_test)) for _ in range(3)]
    grids = jnp.meshgrid(*axes, indexing='ij')
    u_gt = helmholtz3d_exact_u(a1, a2, a3, *grids)
    if model == 'pinn':
        return (*(g.reshape(-1, 1) for g in grids), u_gt.reshape(-1, 1))
    return (*(a.reshape(-1, 1) for a in axes), u_gt)


#----------------------- Klein-Gordon equation 3-d -------------------------#
#@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_klein_gordon3d(model, nc_test, k):
    """Regular test grid on [0, 10] x [-1, 1]^2 with the exact (t, x, y) Klein-Gordon solution.

    Returns:
        ``(t, x, y, u_gt)``; flattened ``(N, 1)`` columns for 'pinn',
        otherwise per-axis columns plus the ``u_gt`` grid.
    """
    axes = [jnp.linspace(0, 10, nc_test), jnp.linspace(-1, 1, nc_test), jnp.linspace(-1, 1, nc_test)]
    axes = [jax.lax.stop_gradient(a) for a in axes]
    grids = jnp.meshgrid(*axes, indexing='ij')
    u_gt = klein_gordon3d_exact_u(*grids, k)
    if model == 'pinn':
        return (*(g.reshape(-1, 1) for g in grids), u_gt.reshape(-1, 1))
    return (*(a.reshape(-1, 1) for a in axes), u_gt)


#----------------------- Klein-Gordon equation 4-d -------------------------#
#@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_klein_gordon4d(model, nc_test, k):
    """Regular test grid on [0, 10] x [-1, 1]^3 with the exact (t, x, y, z) Klein-Gordon solution.

    Returns:
        ``(t, x, y, z, u_gt)``; flattened ``(N, 1)`` columns for 'pinn',
        otherwise per-axis columns plus the ``u_gt`` grid.
    """
    axes = [jnp.linspace(0, 10, nc_test)] + [jnp.linspace(-1, 1, nc_test) for _ in range(3)]
    axes = [jax.lax.stop_gradient(a) for a in axes]
    grids = jnp.meshgrid(*axes, indexing='ij')
    u_gt = klein_gordon4d_exact_u(*grids, k)
    if model == 'pinn':
        return (*(g.reshape(-1, 1) for g in grids), u_gt.reshape(-1, 1))
    return (*(a.reshape(-1, 1) for a in axes), u_gt)


#----------------------- Navier-Stokes equation 3-d -------------------------#
def _test_generator_navier_stokes3d(model, data_dir, result_dir, marching_steps, step_idx):
    ns_data = scipy.io.loadmat(os.path.join(data_dir, 'w_data.mat'))
    t = ns_data['t'].reshape(-1, 1)
    x = ns_data['x'].reshape(-1, 1)
    y = ns_data['y'].reshape(-1, 1)
    t = jnp.insert(t, 0, jnp.array([0.]), axis=0)
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)

    gt = ns_data['w']   # without t=0
    gt = jnp.insert(gt, 0, ns_data['w0'], axis=0)

    # get data within current time window
    if marching_steps != 0:
        Dt = t[-1][0] / marching_steps  # interval of time window
        i = 0
        while Dt*(step_idx+1) > t[i][0]:
            i+=1
        t = t[:i]
        gt = gt[:i]

    # get data within current time window
    if step_idx > 0:
        w0_pred = scipy.io.loadmat(os.path.join(result_dir, '..', f'IC_pred/w0_{step_idx}.mat'))
        i = 0
        while t[i] != w0_pred['t'][0][0]:
            i+=1
        t = t[i:]
        gt = gt[i:]

    if model == 'pinn':
        tm, xm, ym = jnp.meshgrid(t.ravel(), x.ravel(), y.ravel(), indexing='ij')
        t = tm.reshape(-1, 1)
        x = xm.reshape(-1, 1)
        y = ym.reshape(-1, 1)
        gt = gt.reshape(-1, 1)
    
    return t, x, y, gt


#----------------------- Navier-Stokes equation 4-d -------------------------#
#@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_navier_stokes4d(model, nc_test, nu):
    """Regular test grid on [0, 5] x [0, 2*pi]^3 with the exact Navier-Stokes vorticity.

    Returns:
        ``(t, x, y, z, w_gt)`` where ``w_gt`` is the tuple
        ``(w_x, w_y, w_z)``. For 'pinn' the coordinates and every vorticity
        component are flattened to ``(N, 1)`` columns; otherwise the
        coordinates are per-axis columns and the components keep grid shape.
    """
    t = jnp.linspace(0, 5, nc_test)
    x = jnp.linspace(0, 2*jnp.pi, nc_test)
    y = jnp.linspace(0, 2*jnp.pi, nc_test)
    z = jnp.linspace(0, 2*jnp.pi, nc_test)
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)
    z = jax.lax.stop_gradient(z)
    tm, xm, ym, zm = jnp.meshgrid(
        t, x, y, z, indexing='ij'
    )
    w_gt = navier_stokes4d_exact_w(tm, xm, ym, zm, nu)
    if model == 'pinn':
        t = tm.reshape(-1, 1)
        x = xm.reshape(-1, 1)
        y = ym.reshape(-1, 1)
        z = zm.reshape(-1, 1)
        # the exact vorticity is a tuple of components; a tuple has no
        # .reshape, so flatten each component separately
        w_gt = tuple(w.reshape(-1, 1) for w in w_gt)
    else:
        t = t.reshape(-1, 1)
        x = x.reshape(-1, 1)
        y = y.reshape(-1, 1)
        z = z.reshape(-1, 1)
    return t, x, y, z, w_gt


def generate_test_data(args, result_dir):
    """Dispatch to the test-data generator selected by ``args.equation``.

    Args:
        args: namespace; ``equation`` picks the generator, which reads
            ``model`` and whatever other attributes it needs.
        result_dir: only used by the 'navier_stokes3d' generator.

    Raises:
        NotImplementedError: unknown equation.
    """
    builders = {
        'diffusion3d': lambda: _test_generator_diffusion3d(args.model, args.data_dir),
        'helmholtz3d': lambda: _test_generator_helmholtz3d(
            args.model, args.a1, args.a2, args.a3, args.nc_test),
        'klein_gordon3d': lambda: _test_generator_klein_gordon3d(
            args.model, args.nc_test, args.k),
        'klein_gordon4d': lambda: _test_generator_klein_gordon4d(
            args.model, args.nc_test, args.k),
        'navier_stokes3d': lambda: _test_generator_navier_stokes3d(
            args.model, args.data_dir, result_dir, args.marching_steps, args.step_idx),
        'navier_stokes4d': lambda: _test_generator_navier_stokes4d(
            args.model, args.nc_test, args.nu),
    }
    build = builders.get(args.equation)
    if build is None:
        raise NotImplementedError
    return build()

In [None]:



def velocity_to_vorticity_fwd(apply_fn, params, t, x, y):
    """Scalar 2-D vorticity ``w = v_x - u_y`` using forward-mode AD.

    Args:
        apply_fn: model called as ``apply_fn(params, t, x, y)``; its output
            is indexed with ``[0]`` for u and ``[1]`` for v.
        params: model parameters.
        t, x, y: model inputs; x and y may have different shapes.

    Returns:
        ``dv/dx - du/dy`` as produced by ``jvp`` with all-ones tangents.
    """
    # one tangent per differentiated input: jvp requires the tangent to
    # match its primal's shape, and x / y need not share a shape
    vec_x = jnp.ones(x.shape)
    vec_y = jnp.ones(y.shape)
    # w = v_x - u_y
    v_x = jvp(lambda x: apply_fn(params, t, x, y)[1], (x,), (vec_x,))[1]
    u_y = jvp(lambda y: apply_fn(params, t, x, y)[0], (y,), (vec_y,))[1]

    return v_x - u_y


def velocity_to_vorticity_rev(apply_fn, params, t, x, y):
    """Scalar 2-D vorticity ``w = v_x - u_y`` using reverse-mode AD.

    Args:
        apply_fn: model called as ``apply_fn(params, t, x, y)``; the last
            axis of its output holds ``(u, v)``.
        params: model parameters.
        t, x, y: model inputs.

    Returns:
        ``v_x - u_y`` obtained by pulling back an all-ones cotangent.

    NOTE: pulling back ones gives ``sum_j d out_j / d in_i``; this equals
    the pointwise derivative only when each output depends on a single
    input point (pointwise / PINN-style evaluation).
    """
    # w = v_x - u_y
    v, vjp_fn = vjp(lambda x_: apply_fn(params, t, x_, y)[..., 1], x)
    v_x = vjp_fn(jnp.ones(v.shape))[0]
    # u must be differentiated with respect to y, so the lambda takes y
    u, vjp_fn = vjp(lambda y_: apply_fn(params, t, x, y_)[..., 0], y)
    u_y = vjp_fn(jnp.ones(u.shape))[0]

    return v_x - u_y


def vorx(apply_fn, params, t, x, y, z):
    """x-component of vorticity, ``w_x = d(u_z)/dy - d(u_y)/dz``, via forward-mode AD.

    ``apply_fn(params, t, x, y, z)`` is indexed with ``[1]`` for u_y and
    ``[2]`` for u_z; tangents are all ones.
    """
    def u_z_along_y(y_var):
        return apply_fn(params, t, x, y_var, z)[2]

    def u_y_along_z(z_var):
        return apply_fn(params, t, x, y, z_var)[1]

    duz_dy = jvp(u_z_along_y, (y,), (jnp.ones(y.shape),))[1]
    duy_dz = jvp(u_y_along_z, (z,), (jnp.ones(z.shape),))[1]
    return duz_dy - duy_dz


def vory(apply_fn, params, t, x, y, z):
    """y-component of vorticity, ``w_y = d(u_x)/dz - d(u_z)/dx``, via forward-mode AD.

    ``apply_fn(params, t, x, y, z)`` is indexed with ``[0]`` for u_x and
    ``[2]`` for u_z; tangents are all ones.
    """
    def u_x_along_z(z_var):
        return apply_fn(params, t, x, y, z_var)[0]

    def u_z_along_x(x_var):
        return apply_fn(params, t, x_var, y, z)[2]

    dux_dz = jvp(u_x_along_z, (z,), (jnp.ones(z.shape),))[1]
    duz_dx = jvp(u_z_along_x, (x,), (jnp.ones(x.shape),))[1]
    return dux_dz - duz_dx


def vorz(apply_fn, params, t, x, y, z):
    """z-component of vorticity, ``w_z = d(u_y)/dx - d(u_x)/dy``, via forward-mode AD.

    ``apply_fn(params, t, x, y, z)`` is indexed with ``[0]`` for u_x and
    ``[1]`` for u_y; tangents are all ones.
    """
    def u_y_along_x(x_var):
        return apply_fn(params, t, x_var, y, z)[1]

    def u_x_along_y(y_var):
        return apply_fn(params, t, x, y_var, z)[0]

    duy_dx = jvp(u_y_along_x, (x,), (jnp.ones(x.shape),))[1]
    dux_dy = jvp(u_x_along_y, (y,), (jnp.ones(y.shape),))[1]
    return duy_dx - dux_dy

In [None]:



def _diffusion3d(args, apply_fn, params, test_data, result_dir, e, resol):
    """Save reference vs. predicted surface plots for diffusion3d.

    The model is evaluated at ``nt = 11`` evenly spaced times in [0, 1] on a
    ``resol x resol`` grid over [-1, 1]^2. One figure per time is written to
    ``<result_dir>/vis/<e:05d>/pred_<t>.png``.

    Args:
        args: run configuration; ``args.model == 'pinn'`` feeds every grid
            point explicitly, otherwise the 1-D axes are fed.
        apply_fn: model apply function ``apply_fn(params, t, x, y)``.
        params: model parameters.
        test_data: test tuple whose last entry is the reference solution.
        result_dir: root output directory.
        e: epoch index used in the output path.
        resol: number of grid points along x and y.
    """
    print("visualizing solution...")

    nt = 11  # number of time steps to visualize
    t = jnp.linspace(0., 1., nt)
    x = jnp.linspace(-1., 1., resol)
    y = jnp.linspace(-1., 1., resol)
    xd, yd = jnp.meshgrid(x, y, indexing='ij')  # for 3-d surface plot
    tm, xm, ym = jnp.meshgrid(t, x, y, indexing='ij')
    if args.model == 'pinn':
        t = tm.reshape(-1, 1)
        x = xm.reshape(-1, 1)
        y = ym.reshape(-1, 1)
    else:
        t = t.reshape(-1, 1)
        x = x.reshape(-1, 1)
        y = y.reshape(-1, 1)

    u_ref = test_data[-1]

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)
    u = apply_fn(params, t, x, y)
    if args.model == 'pinn':
        u = u.reshape(nt, resol, resol)
        u_ref = u_ref.reshape(-1, resol, resol)

    # Reference frame shown for plotted time index tt.  A fixed stride of 10
    # only matches nt == 11 with a 101-frame reference; derive the stride from
    # the reference length instead.  Assumes the reference frames are evenly
    # spaced over the same [0, 1] interval (TODO confirm for new datasets).
    ref_stride = (u_ref.shape[0] - 1) // (nt - 1)
    z_min, z_max = jnp.min(u_ref), jnp.max(u_ref)

    for tt in range(nt):
        t_val = tt*(1/(nt-1))
        fig = plt.figure(figsize=(12, 6))

        # reference solution
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.plot_surface(xd, yd, u_ref[tt * ref_stride], cmap='jet', linewidth=0, antialiased=False)
        ax1.set_xlabel('x')
        ax1.set_ylabel('y')
        ax1.set_zlabel('u')
        ax1.set_title(f'Reference $u(x, y)$ at $t={t_val:.1f}$', fontsize=15)
        ax1.set_zlim(z_min, z_max)

        # predicted solution (same z-range as the reference)
        ax2 = fig.add_subplot(122, projection='3d')
        ax2.plot_surface(xd, yd, u[tt], cmap='jet', linewidth=0, antialiased=False)
        ax2.set_xlabel('x')
        ax2.set_ylabel('y')
        ax2.set_zlabel('u')
        ax2.set_title(f'Predicted $u(x, y)$ at $t={t_val:.1f}$', fontsize=15)
        ax2.set_zlim(z_min, z_max)

        plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred_{t_val:.1f}.png'))
        plt.close()


def _helmholtz3d(args, apply_fn, params, result_dir, e, resol):
    """Save a reference / prediction / absolute-error scatter figure for helmholtz3d.

    The model is evaluated on a ``resol**3`` grid over [-1, 1]^3 and compared
    with ``helmholtz3d_exact_u(args.a1, args.a2, args.a3, ...)``.  The figure
    is written to ``<result_dir>/vis/<e:05d>/pred.png``.
    """
    print("visualizing solution...")

    axes_1d = [jnp.linspace(-1., 1., resol) for _ in range(3)]
    xm, ym, zm = jnp.meshgrid(*axes_1d, indexing='ij')
    if args.model == 'pinn':
        # PINN inputs: every grid point, as (N, 1) columns.
        model_inputs = [grid.reshape(-1, 1) for grid in (xm, ym, zm)]
    else:
        # Non-PINN inputs: the 1-D axes, as (resol, 1) columns.
        model_inputs = [axis.reshape(-1, 1) for axis in axes_1d]

    u_ref = helmholtz3d_exact_u(args.a1, args.a2, args.a3, xm, ym, zm)

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)
    u_pred = apply_fn(params, *model_inputs)
    if args.model == 'pinn':
        u_pred = u_pred.reshape(resol, resol, resol)
        u_ref = u_ref.reshape(resol, resol, resol)

    lo, hi = jnp.min(u_ref), jnp.max(u_ref)
    panels = (
        (131, u_ref, 'Reference $u(x, y, z)$', {}),
        (132, u_pred, 'Predicted $u(x, y, z)$', {'vmin': lo, 'vmax': hi}),
        (133, jnp.abs(u_ref-u_pred), 'Absolute error', {'vmin': lo, 'vmax': hi}),
    )

    fig = plt.figure(figsize=(14, 5))
    for position, values, title, color_limits in panels:
        ax = fig.add_subplot(position, projection='3d')
        im = ax.scatter(xm, ym, zm, c=values, cmap='seismic', s=0.5, **color_limits)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        ax.set_title(title, fontsize=15)

    # Colour bar tied to the last (absolute-error) scatter.
    cbar_ax = fig.add_axes([0.95, 0.3, 0.01, 0.4])
    fig.colorbar(im, cax=cbar_ax)

    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()


def _klein_gordon3d(args, apply_fn, params, result_dir, e, resol):
    """Save a reference / prediction / absolute-error scatter figure for klein_gordon3d.

    The model is evaluated on a ``resol**3`` grid over t in [0, 10] and
    x, y in [-1, 1] and compared with ``klein_gordon3d_exact_u(..., args.k)``.
    The figure is written to ``<result_dir>/vis/<e:05d>/pred.png``.
    """
    print("visualizing solution...")

    axes_1d = [
        jnp.linspace(0., 10., resol),
        jnp.linspace(-1., 1., resol),
        jnp.linspace(-1., 1., resol),
    ]
    tm, xm, ym = jnp.meshgrid(*axes_1d, indexing='ij')
    if args.model == 'pinn':
        # PINN inputs: every grid point, as (N, 1) columns.
        model_inputs = [grid.reshape(-1, 1) for grid in (tm, xm, ym)]
    else:
        # Non-PINN inputs: the 1-D axes, as (resol, 1) columns.
        model_inputs = [axis.reshape(-1, 1) for axis in axes_1d]

    u_ref = klein_gordon3d_exact_u(tm, xm, ym, args.k)

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)
    u_pred = apply_fn(params, *model_inputs)
    if args.model == 'pinn':
        u_pred = u_pred.reshape(resol, resol, resol)
        u_ref = u_ref.reshape(resol, resol, resol)

    lo, hi = jnp.min(u_ref), jnp.max(u_ref)
    panels = (
        (131, u_ref, 'Reference $u(t, x, y)$', {}),
        (132, u_pred, 'Predicted $u(t, x, y)$', {'vmin': lo, 'vmax': hi}),
        (133, jnp.abs(u_ref-u_pred), 'Absolute error', {'vmin': lo, 'vmax': hi}),
    )

    fig = plt.figure(figsize=(14, 5))
    for position, values, title, color_limits in panels:
        ax = fig.add_subplot(position, projection='3d')
        im = ax.scatter(tm, xm, ym, c=values, cmap='seismic', s=0.5, **color_limits)
        ax.set_xlabel('t')
        ax.set_ylabel('x')
        ax.set_zlabel('y')
        ax.set_title(title, fontsize=15)

    # Colour bar tied to the last (absolute-error) scatter.
    cbar_ax = fig.add_axes([0.95, 0.3, 0.01, 0.4])
    fig.colorbar(im, cax=cbar_ax)

    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()


def _navier_stokes3d(apply_fn, params, test_data, result_dir, e):
    """Save reference / predicted / absolute-error vorticity images for navier_stokes3d.

    Positional layout of ``test_data`` used here: ``[0]`` time samples,
    ``[1]`` x samples, ``[2]`` y samples and ``[-1]`` reference vorticity,
    of which only the last time entry is plotted.  The figure is written to
    ``<result_dir>/vis/<e:05d>/pred.png``.
    """
    print("visualizing solution...")

    nx, ny = test_data[1].shape[0], test_data[2].shape[0]

    # Evaluate only at the last time sample.
    t = jnp.expand_dims(test_data[0][-1], axis=1)
    t_label = f'{jnp.round(t[0][0], 1):.2f}'

    w_pred = velocity_to_vorticity_fwd(apply_fn, params, t, test_data[1], test_data[2])
    w_pred = w_pred.reshape(-1, nx, ny)[0]
    w_ref = test_data[-1][-1]

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)

    extent = [0, 2*jnp.pi, 0, 2*jnp.pi]
    vmin, vmax = jnp.min(w_ref), jnp.max(w_ref)
    # Raw f-strings: '\omega' is a LaTeX command, not a Python escape
    # (a plain '\o' is an invalid escape sequence and warns on Python 3.12+).
    panels = (
        (131, w_ref, rf'Reference $\omega(t={t_label}, x, y)$'),
        (132, w_pred, rf'Predicted $\omega(t={t_label}, x, y)$'),
        (133, jnp.abs(w_ref - w_pred), 'Absolute error'),
    )

    fig = plt.figure(figsize=(14, 5))
    for position, field, title in panels:
        ax = fig.add_subplot(position)
        im = ax.imshow(field, cmap='jet', extent=extent, vmin=vmin, vmax=vmax)
        ax.set_xlabel('$x$')
        ax.set_ylabel('$y$')
        ax.set_title(title, fontsize=15)

    # Colour bar tied to the last (absolute-error) image.
    cbar_ax = fig.add_axes([0.95, 0.3, 0.01, 0.4])
    fig.colorbar(im, cax=cbar_ax)
    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()


def _navier_stokes4d(apply_fn, params, test_data, result_dir, e):
    """Save 3-D quiver plots of the predicted vorticity at t = 0, 1, ..., 5.

    ``test_data`` is accepted for signature compatibility with the other
    visualisers but is not used.  The figure is written to
    ``<result_dir>/vis/<e:05d>/pred.png``.
    """
    print("visualizing solution...")

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)

    # Spatial sampling grid on [0, 2*pi]^3 (4 x 30 x 30 points). It does not
    # depend on time, so it is built once.
    x = jnp.linspace(0, 2*jnp.pi, 4).reshape(-1, 1)
    y = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
    z = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
    xg, yg, zg = jnp.meshgrid(jnp.squeeze(x), jnp.squeeze(y), jnp.squeeze(z), indexing='ij')

    fig = plt.figure(figsize=(30, 5))
    for t_step, sub in zip([0, 1, 2, 3, 4, 5], [161, 162, 163, 164, 165, 166]):
        t = jnp.array([[t_step]])
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)

        # Colour by the angle arctan2(w_y, w_z) of the vorticity in the y-z
        # plane, min-max scaled to [0, 1].  A constant field has zero range;
        # divide by 1 in that case instead of producing NaN colours.
        angle = jnp.arctan2(wy, wz)
        span = jnp.ptp(angle)
        c = (angle.ravel() - angle.min()) / jnp.where(span > 0, span, 1.)
        # One colour per arrow followed by two more per arrow -- presumably
        # for the two arrow-head segments drawn by mplot3d's quiver (verify).
        c = jnp.concatenate((c, jnp.repeat(c, 2)))
        c = plt.cm.plasma(c)

        ax = fig.add_subplot(sub, projection='3d')
        ax.quiver(xg, yg, zg, jnp.squeeze(wx), jnp.squeeze(wy), jnp.squeeze(wz), length=0.1, colors=c, alpha=1, linewidth=0.7)
        plt.title(f't={jnp.squeeze(t)}')

    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()


def show_solution(args, apply_fn, params, test_data, result_dir, e, resol=50):
    """Run the visualiser registered for ``args.equation``.

    Raises:
        NotImplementedError: if no visualiser exists for the equation.
    """
    # Lambdas defer the call (and the name lookup) until a handler is chosen.
    handlers = {
        'diffusion3d': lambda: _diffusion3d(args, apply_fn, params, test_data, result_dir, e, resol),
        'helmholtz3d': lambda: _helmholtz3d(args, apply_fn, params, result_dir, e, resol),
        'klein_gordon3d': lambda: _klein_gordon3d(args, apply_fn, params, result_dir, e, resol),
        'navier_stokes3d': lambda: _navier_stokes3d(apply_fn, params, test_data, result_dir, e),
        'navier_stokes4d': lambda: _navier_stokes4d(apply_fn, params, test_data, result_dir, e),
    }
    handler = handlers.get(args.equation)
    if handler is None:
        raise NotImplementedError
    handler()

In [None]:


def _navier_stokes4d_exact_w(t, x, y, z, nu):
    # analytic form of vortcity
    w_x = -3*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.cos(2*y)*jnp.cos(z)
    w_y = 6*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(z)
    w_z = -6*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.cos(2*y)*jnp.sin(z)
    return w_x, w_y, w_z

def _navier_stokes4d_exact_u(t, x, y, z, nu=0.05):
    # analytic form of velocity
    u_x = 2*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.sin(2*y)*jnp.sin(z)
    u_y = -1*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.cos(2*y)*jnp.sin(z)
    u_z = -2*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.sin(2*y)*jnp.cos(z)
    return u_x, u_y, u_z


class NS_exact(nn.Module):
    """Parameter-free module returning the analytic navier_stokes4d velocity.

    Inputs with more than one dimension are squeezed along axis 1, the four
    axes are expanded with ``meshgrid(indexing='ij')``, and the tuple
    ``(u_x, u_y, u_z)`` from ``_navier_stokes4d_exact_u`` (default ``nu``) is
    returned.
    """
    @nn.compact
    def __call__(self, t, x, y, z):
        axes = []
        for coord in (t, x, y, z):
            axes.append(jnp.squeeze(coord, axis=1) if jnp.ndim(coord) > 1 else coord)
        grids = jnp.meshgrid(*axes, indexing='ij')
        return _navier_stokes4d_exact_u(*grids)



class PINN2d(nn.Module):
    """Fully-connected tanh PINN applied to the concatenation of ``(x, y)``."""
    # Widths of the Dense layers; the last entry is the output width.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y):
        kernel_init = nn.initializers.glorot_normal()
        h = jnp.concatenate([x, y], axis=1)
        for width in self.features[:-1]:
            h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
        return nn.Dense(self.features[-1], kernel_init=kernel_init)(h)
    

class PINN3d(nn.Module):
    """Fully-connected tanh PINN applied to the concatenation of ``(x, y, z)``.

    With ``pos_enc != 0`` each coordinate ``c`` is first replaced by
    ``[sin(c @ freq), cos(c @ freq)]`` where ``freq = 2**k`` for ``k`` in
    ``range(int(-(pos_enc-1)/2), int((pos_enc+1)/2))``.
    """
    # Widths of the Dense layers; the last entry is the output width.
    features: Sequence[int]
    # Stored on the module but not read by __call__.
    out_dim: int
    # Number of positional-encoding frequencies; 0 disables the encoding.
    pos_enc: int

    @nn.compact
    def __call__(self, x, y, z):
        if self.pos_enc != 0:
            k_lo = int(-(self.pos_enc-1)/2)
            k_hi = int((self.pos_enc+1)/2)
            freq = jnp.array([[2**k for k in range(k_lo, k_hi)]])

            def encode(coord):
                proj = coord @ freq
                return jnp.concatenate((jnp.sin(proj), jnp.cos(proj)), 1)

            x, y, z = encode(x), encode(y), encode(z)

        h = jnp.concatenate([x, y, z], axis=1)
        kernel_init = nn.initializers.glorot_normal()
        for width in self.features[:-1]:
            h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
        return nn.Dense(self.features[-1], kernel_init=kernel_init)(h)


class PINN4d(nn.Module):
    """Fully-connected tanh PINN applied to the concatenation of ``(t, x, y, z)``."""
    # Widths of the Dense layers; the last entry is the output width.
    features: Sequence[int]

    @nn.compact
    def __call__(self, t, x, y, z):
        kernel_init = nn.initializers.glorot_normal()
        h = jnp.concatenate([t, x, y, z], axis=1)
        for width in self.features[:-1]:
            h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
        return nn.Dense(self.features[-1], kernel_init=kernel_init)(h)


class SPINN2d(nn.Module):
    """Separable PINN (SPINN) for two coordinate inputs ``(x, y)``.

    Each coordinate array goes through its own body network ending in ``r``
    features per point, and the prediction is the rank-``r`` contraction
    ``outputs[0] @ outputs[-1].T``.  Assuming (n, 1) column inputs (TODO
    confirm against callers), the result is a (len(x), len(y)) grid.
    """
    # Layer widths of each body network: hidden layers use features[:-1];
    # the modified-MLP gates U, V, H use features[0].
    features: Sequence[int]
    # Rank of the separable factorisation (output width of every body).
    r: int
    # 'mlp' -> plain tanh MLP bodies; any other value -> modified-MLP bodies.
    mlp: str

    @nn.compact
    def __call__(self, x, y):
        inputs, outputs = [x, y], []
        init = nn.initializers.glorot_normal()
        # NOTE(review): flax names compact submodules by creation order
        # (Dense_0, Dense_1, ...), so reordering the Dense calls below would
        # rename the parameter tree -- confirm before refactoring.
        if self.mlp == 'mlp':
            # Plain tanh MLP, one independent body per coordinate.
            for X in inputs:
                for fs in self.features[:-1]:
                    X = nn.Dense(fs, kernel_init=init)(X)
                    X = nn.activation.tanh(X)
                X = nn.Dense(self.r, kernel_init=init)(X)
                outputs += [X]
        else:
            # Modified MLP: at every layer the hidden state H is a blend of
            # the two encoder gates U and V, weighted by Z.
            # NOTE(review): the blend appears to require each fs to be
            # broadcast-compatible with features[0] -- confirm features is
            # constant-width when this branch is used.
            for X in inputs:
                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                for fs in self.features[:-1]:
                    Z = nn.Dense(fs, kernel_init=init)(H)
                    Z = nn.activation.tanh(Z)
                    H = (jnp.ones_like(Z)-Z)*U + Z*V
                H = nn.Dense(self.r, kernel_init=init)(H)
                outputs += [H]

        # Contract the rank axis: outputs[0] is (..., r), outputs[-1].T is (r, ...).
        return jnp.dot(outputs[0], outputs[-1].T)


class SPINN3d(nn.Module):
    """Separable PINN (SPINN) for three coordinate inputs ``(x, y, z)``.

    Each coordinate passes through its own body network producing
    ``r * out_dim`` features per point.  For output component ``i`` the
    prediction is ``pred_i[a, b, c] = sum_f X[f, a] * Y[f, b] * Z[f, c]``,
    summed over the feature rows ``r*i : r*(i+1)`` of the transposed body
    outputs ``X``, ``Y``, ``Z``.
    """
    # Layer widths of each body network.
    features: Sequence[int]
    # Rank of the separable factorisation per output component.
    r: int
    # Number of output components.
    out_dim: int
    # Number of integer sin/cos frequencies applied to y and z; 0 disables it.
    pos_enc: int
    # Body type: 'mlp' or 'modified_mlp'.
    mlp: str

    @nn.compact
    def __call__(self, x, y, z):
        '''
        inputs: input factorized coordinates
        outputs: feature output of each body network
        xy: intermediate tensor for feature merge btw. x and y axis
        pred: final model prediction (e.g. for 2d output, pred=[u, v])

        Returns a single array when out_dim == 1, otherwise a list with one
        array per output component.
        '''
        if self.pos_enc != 0:
            # positional encoding only to spatial coordinates (y and z; x is
            # left unencoded). Each becomes [1, sin(k*c), cos(k*c)] for
            # k = 1..pos_enc.
            freq = jnp.expand_dims(jnp.arange(1, self.pos_enc+1, 1), 0)
            y = jnp.concatenate((jnp.ones((y.shape[0], 1)), jnp.sin(y@freq), jnp.cos(y@freq)), 1)
            z = jnp.concatenate((jnp.ones((z.shape[0], 1)), jnp.sin(z@freq), jnp.cos(z@freq)), 1)

            # causal PINN version (also on time axis)
            #  freq_x = jnp.expand_dims(jnp.power(10.0, jnp.arange(0, 3)), 0)
            # x = x@freq_x
            
        inputs, outputs, xy, pred = [x, y, z], [], [], []
        init = nn.initializers.glorot_normal()

        # NOTE(review): flax names compact submodules by creation order, so
        # reordering the Dense calls below would rename the parameter tree.
        if self.mlp == 'mlp':
            # Plain tanh MLP per coordinate; output transposed to (r*out_dim, n).
            for X in inputs:
                for fs in self.features[:-1]:
                    X = nn.Dense(fs, kernel_init=init)(X)
                    X = nn.activation.tanh(X)
                X = nn.Dense(self.r*self.out_dim, kernel_init=init)(X)
                outputs += [jnp.transpose(X, (1, 0))]

        elif self.mlp == 'modified_mlp':
            # Modified MLP: H is blended between encoder gates U and V by Z.
            for X in inputs:
                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                for fs in self.features[:-1]:
                    Z = nn.Dense(fs, kernel_init=init)(H)
                    Z = nn.activation.tanh(Z)
                    H = (jnp.ones_like(Z)-Z)*U + Z*V
                H = nn.Dense(self.r*self.out_dim, kernel_init=init)(H)
                outputs += [jnp.transpose(H, (1, 0))]
        # NOTE(review): any other `mlp` value leaves `outputs` empty and the
        # indexing below fails with IndexError.
        
        for i in range(self.out_dim):
            # Rows r*i:r*(i+1) hold the rank-r factors of component i: merge
            # x and y keeping f, then contract f against z.
            xy += [jnp.einsum('fx, fy->fxy', outputs[0][self.r*i:self.r*(i+1)], outputs[1][self.r*i:self.r*(i+1)])]
            pred += [jnp.einsum('fxy, fz->xyz', xy[i], outputs[-1][self.r*i:self.r*(i+1)])]

        if len(pred) == 1:
            # 1-dimensional output
            return pred[0]
        else:
            # n-dimensional output
            return pred


class SPINN4d(nn.Module):
    """Separable PINN (SPINN) for four coordinate inputs ``(t, x, y, z)``.

    Each coordinate passes through its own tanh MLP body producing
    ``r * out_dim`` features per point.  For output component ``i`` the
    prediction is
    ``pred_i[t, x, y, z] = sum_f T[f, t] * X[f, x] * Y[f, y] * Z[f, z]``,
    summed over feature rows ``r*i : r*(i+1)``.  Returns a single array when
    ``out_dim == 1``, otherwise a list with one array per component.
    """
    # Layer widths of each body network.
    features: Sequence[int]
    # Rank of the separable factorisation per output component.
    r: int
    # Number of output components.
    out_dim: int
    # NOTE(review): stored but never read by __call__ -- the bodies are
    # always plain MLPs, even when 'modified_mlp' is configured.
    mlp: str

    @nn.compact
    def __call__(self, t, x, y, z):
        inputs, outputs, tx, txy, pred = [t, x, y, z], [], [], [], []
        # inputs, outputs = [t, x, y, z], []
        init = nn.initializers.glorot_normal()
        # Plain tanh MLP per coordinate; output transposed to (r*out_dim, n).
        for X in inputs:
            for fs in self.features[:-1]:
                X = nn.Dense(fs, kernel_init=init)(X)
                X = nn.activation.tanh(X)
            X = nn.Dense(self.r*self.out_dim, kernel_init=init)(X)
            outputs += [jnp.transpose(X, (1, 0))]

        for i in range(self.out_dim):
            # Merge t and x, keeping the rank axis f.
            tx += [jnp.einsum('ft, fx->ftx', 
            outputs[0][self.r*i:self.r*(i+1)], 
            outputs[1][self.r*i:self.r*(i+1)])]

            # Merge in y, still keeping f.
            txy += [jnp.einsum('ftx, fy->ftxy', 
            tx[i], 
            outputs[2][self.r*i:self.r*(i+1)])]

            # Merge in z and sum over f.
            pred += [jnp.einsum('ftxy, fz->txyz', 
            txy[i], 
            outputs[3][self.r*i:self.r*(i+1)])]


        if len(pred) == 1:
            # 1-dimensional output
            return pred[0]
        else:
            # n-dimensional output
            return pred

class SPINNnd(nn.Module):
    """Separable PINN for an arbitrary number of coordinate inputs.

    Each input passes through its own tanh MLP body ending in ``r`` features.
    The prediction is the rank-``r`` product
    ``pred[a, b, ...] = sum_f out_0[f, a] * out_1[f, b] * ...``,
    with one output axis per input.
    """
    # Layer widths of each body network.
    features: Sequence[int]
    # Rank of the separable factorisation.
    r: int

    @nn.compact
    def __call__(self, t, *x):
        inputs = [t, *x]
        dim = len(inputs)
        # 'z' labels the rank axis in the einsum subscripts, so the remaining
        # 25 lowercase letters bound the number of inputs.
        letters = 'abcdefghijklmnopqrstuvwxy'
        if dim > len(letters):
            raise NotImplementedError(
                f'SPINNnd supports at most {len(letters)} inputs, got {dim}')

        outputs = []
        init = nn.initializers.glorot_normal()
        for X in inputs:
            for fs in self.features[:-1]:
                X = nn.Dense(fs, kernel_init=init)(X)
                X = nn.activation.tanh(X)
            X = nn.Dense(self.r, kernel_init=init)(X)
            outputs += [jnp.transpose(X, (1, 0))]

        if dim == 1:
            # A single factor: contracting the rank axis is a plain sum.
            return jnp.sum(outputs[0], axis=0)

        # Contract pairwise, keeping the rank axis 'z' until the last factor,
        # where it is summed out.  For dim == 2 that happens on the very first
        # (and only) contraction.
        sub = 'z' + letters[0]
        pred = outputs[0]
        for i in range(1, dim):
            out_sub = sub + letters[i]
            if i == dim - 1:
                out_sub = out_sub[1:]
            pred = jnp.einsum(f'{sub},z{letters[i]}->{out_sub}', pred, outputs[i])
            sub = out_sub

        return pred

In [None]:



def setup_networks(args, key):
    """Build the model selected by ``args`` and initialise its parameters.

    The last two characters of ``args.equation`` ('2d', '3d', '4d') select the
    input dimensionality; ``args.model == 'pinn'`` selects the PINN family,
    anything else the SPINN family.

    Args:
        args: run configuration (``equation``, ``model``, ``features``,
            ``n_layers``, ``out_dim`` and the size/architecture fields read
            below).
        key: PRNG key passed to ``model.init``.

    Returns:
        ``(jax.jit(model.apply), params)``.

    Raises:
        NotImplementedError: for an unsupported dimensionality.
    """
    # build network
    dim = args.equation[-2:]
    if args.model == 'pinn':
        # feature sizes
        feat_sizes = tuple([args.features for _ in range(args.n_layers - 1)] + [args.out_dim])
        if dim == '2d':
            model = PINN2d(feat_sizes)
        elif dim == '3d':
            model = PINN3d(feat_sizes, args.out_dim, args.pos_enc)
        elif dim == '4d':
            model = PINN4d(feat_sizes)
        else:
            raise NotImplementedError
    else: # SPINN
        # feature sizes
        feat_sizes = tuple([args.features for _ in range(args.n_layers)])
        if dim == '2d':
            model = SPINN2d(feat_sizes, args.r, args.mlp)
        elif dim == '3d':
            model = SPINN3d(feat_sizes, args.r, args.out_dim, args.pos_enc, args.mlp)
        elif dim == '4d':
            model = SPINN4d(feat_sizes, args.r, args.out_dim, args.mlp)
        else:
            raise NotImplementedError

    # initialize params; dummy (n, 1) inputs must be given
    if dim == '2d':
        batch_sizes = [args.nc] * 2
    elif dim == '3d':
        if args.equation == 'navier_stokes3d':
            if args.model == 'pinn':
                # PINN3d concatenates its inputs along axis 1, so the dummy
                # inputs need a common batch size (nt != nxy would crash).
                batch_sizes = [args.nxy] * 3
            else:
                batch_sizes = [args.nt, args.nxy, args.nxy]
        else:
            batch_sizes = [args.nc] * 3
    elif dim == '4d':
        batch_sizes = [args.nc] * 4
    else:
        raise NotImplementedError

    params = model.init(key, *[jnp.ones((n, 1)) for n in batch_sizes])
    return jax.jit(model.apply), params


def name_model(args):
    """Build the run name (used for result directories) from ``args``.

    The name is the ``'_'``-join of:
      * size tags: ``nt``/``nxy`` for navier_stokes3d, otherwise ``nc``;
      * ``nl``, ``fs``, ``lr``, ``s`` and, for SPINN only, ``r``;
      * equation-specific tags (offsets/loss weights, ``a``, ``k``);
      * the ``mlp`` type.
    """
    common = [f'nl{args.n_layers}', f'fs{args.features}', f'lr{args.lr}', f's{args.seed}']
    rank_tag = f'r{args.r}'  # evaluated for every model type; kept only for SPINN
    if args.model == 'spinn':
        common.append(rank_tag)

    if args.equation == 'navier_stokes3d':
        size = [f'nt{args.nt}', f'nxy{args.nxy}']
    else:
        size = [f'nc{args.nc}']

    # Lambdas so that only the selected equation's attributes are read.
    extra_builders = {
        'navier_stokes3d': lambda: [f'on{args.offset_num}', f'oi{args.offset_iter}',
                                    f'lc{args.lbda_c}', f'lic{args.lbda_ic}'],
        'navier_stokes4d': lambda: [f'lc{args.lbda_c}', f'li{args.lbda_ic}'],
        'helmholtz3d': lambda: [f'a{args.a1}{args.a2}{args.a3}'],
        'klein_gordon3d': lambda: [f'k{args.k}'],
    }
    extra = extra_builders.get(args.equation, lambda: [])()

    return '_'.join(size + common + extra + [f'{args.mlp}'])


def save_config(args, result_dir):
    """Write every attribute of ``args`` to ``<result_dir>/configs.txt``.

    One ``name: value`` line per attribute, in ``vars(args)`` order; an
    existing file is overwritten.
    """
    config_path = os.path.join(result_dir, 'configs.txt')
    with open(config_path, 'w') as fh:
        fh.writelines(f'{name}: {getattr(args, name)}\n' for name in vars(args))


# single update function
#@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step.

    Args:
        optim: optimizer exposing optax's ``update(updates, state, params)``
            interface.
        gradient: gradient pytree matching ``params``.
        params: current parameters.
        state: current optimizer state.

    Returns:
        ``(new_params, new_state)``.
    """
    # Pass the current params so transformations that need them (e.g.
    # decoupled weight decay) work; ones that ignore params are unaffected.
    updates, state = optim.update(gradient, state, params)
    params = optax.apply_updates(params, updates)
    return params, state


# save next initial condition for time-marching
def save_next_IC(root_dir, name, apply_fn, params, test_data, step_idx, e):
    """Save the model's prediction at the last test time as the next initial condition.

    Writes ``<root_dir>/<name>/IC_pred/w0_<step_idx+1>.mat`` with keys
    ``'w0'`` (vorticity), ``'u0'``/``'v0'`` (the two outputs of ``apply_fn``)
    and ``'t'``.  ``test_data[0]``, ``[1]`` and ``[2]`` are the time, x and y
    samples.  ``e`` is accepted but unused.
    """
    os.makedirs(os.path.join(root_dir, name, 'IC_pred'), exist_ok=True)

    xs, ys = test_data[1], test_data[2]
    t_last = jnp.expand_dims(test_data[0][-1], axis=1)

    vorticity = velocity_to_vorticity_fwd(apply_fn, params, t_last, xs, ys)
    vorticity = vorticity.reshape(-1, xs.shape[0], ys.shape[0])[0]
    u_last, v_last = apply_fn(params, t_last, xs, ys)
    u_last, v_last = jnp.squeeze(u_last), jnp.squeeze(v_last)

    out_path = os.path.join(root_dir, name, f'IC_pred/w0_{step_idx+1}.mat')
    scipy.io.savemat(out_path, mdict={'w0': vorticity, 'u0': u_last, 'v0': v_last, 't': t_last})

In [None]:



def relative_l2(u, u_gt):
    """Relative error ``||u - u_gt|| / ||u_gt||`` using ``jnp.linalg.norm``."""
    residual_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return residual_norm / reference_norm

def mse(u, u_gt):
    """Mean squared error between ``u`` and ``u_gt``."""
    squared = jnp.square(u - u_gt)
    return jnp.mean(squared)

#@partial(jax.jit, static_argnums=(0,))
def _eval2d(apply_fn, params, *test_data):
    """Relative L2 error of ``apply_fn(params, x, y)``; ``test_data`` is ``(x, y, u_gt)``."""
    x, y, u_gt = test_data
    prediction = apply_fn(params, x, y)
    return relative_l2(prediction, u_gt)


#@partial(jax.jit, static_argnums=(0,))
def _eval3d(apply_fn, params, *test_data):
    """Relative L2 error of ``apply_fn(params, x, y, z)``; ``test_data`` is ``(x, y, z, u_gt)``."""
    x, y, z, u_gt = test_data
    return relative_l2(apply_fn(params, x, y, z), u_gt)


#@partial(jax.jit, static_argnums=(0,))
def _eval3d_ns_pinn(apply_fn, params, *test_data):
    """Relative L2 vorticity error for a PINN on navier_stokes3d.

    ``test_data`` is ``(t, x, y, w_gt)``; the vorticity is computed with
    reverse-mode AD via ``velocity_to_vorticity_rev``.
    """
    t, x, y, w_gt = test_data
    w_pred = velocity_to_vorticity_rev(apply_fn, params, t, x, y)
    return relative_l2(w_pred, w_gt)


#@partial(jax.jit, static_argnums=(0,))
def _eval3d_ns_spinn(apply_fn, params, *test_data):
    """Relative L2 vorticity error for a SPINN on navier_stokes3d.

    ``test_data`` is ``(t, x, y, w_gt)``; the vorticity is computed with
    forward-mode AD via ``velocity_to_vorticity_fwd``.
    """
    t, x, y, w_gt = test_data
    w_pred = velocity_to_vorticity_fwd(apply_fn, params, t, x, y)
    return relative_l2(w_pred, w_gt)


#@partial(jax.jit, static_argnums=(0,))
def _eval4d(apply_fn, params, *test_data):
    """Relative L2 error of ``apply_fn(params, t, x, y, z)``; ``test_data`` is ``(t, x, y, z, u_gt)``."""
    t, x, y, z, u_gt = test_data
    prediction = apply_fn(params, t, x, y, z)
    return relative_l2(prediction, u_gt)

#@partial(jax.jit, static_argnums=(0,))
def _eval_ns4d(apply_fn, params, *test_data):
    """Mean relative L2 error over the three vorticity components (navier_stokes4d).

    ``test_data`` is ``(t, x, y, z, w_gt)`` where ``w_gt[0]``, ``w_gt[1]`` and
    ``w_gt[2]`` are compared with ``vorx``, ``vory`` and ``vorz`` respectively.
    """
    t, x, y, z, w_gt = test_data
    component_fns = (vorx, vory, vorz)
    errors = [relative_l2(fn(apply_fn, params, t, x, y, z), w_gt[k])
              for k, fn in enumerate(component_fns)]
    return sum(errors) / 3


# temporary code
def _batch_eval4d(apply_fn, params, *test_data):
    t, x, y, z, u_gt = test_data
    error, batch_size = 0., 100000
    n_iters = len(u_gt) // batch_size
    for i in range(n_iters):
        begin, end = i*batch_size, (i+1)*batch_size
        u = apply_fn(params, t[begin:end], x[begin:end], y[begin:end], z[begin:end])
        error += jnp.sum((u - u_gt[begin:end])**2)
    error = jnp.sqrt(error) / jnp.linalg.norm(u_gt)
    return error

#@partial(jax.jit, static_argnums=(0,))
def _evalnd(apply_fn, params, *test_data):
    """Relative L2 error of an n-D model; ``test_data`` is ``(t, x_list, u_gt)``."""
    t, x_list, u_gt = test_data
    prediction = apply_fn(params, t, *x_list)
    return relative_l2(prediction, u_gt)


def setup_eval_function(model, equation):
    """Return the evaluation function for a ``(model, equation)`` pair.

    The last two characters of ``equation`` ('2d', '3d', '4d', 'nd') select the
    family.  Every returned function has the signature
    ``fn(apply_fn, params, *test_data)``.

    Raises:
        NotImplementedError: for an unsupported dimensionality, or 'nd' with a
            non-SPINN model (this previously surfaced as UnboundLocalError).
    """
    dim = equation[-2:]
    if dim == '2d':
        return _eval2d
    if dim == '3d':
        if equation == 'navier_stokes3d':
            if model == 'pinn':
                return _eval3d_ns_pinn
            if model == 'spinn':
                return _eval3d_ns_spinn
        return _eval3d
    if dim == '4d':
        if model == 'spinn' and equation == 'navier_stokes4d':
            return _eval_ns4d
        # NOTE(review): PINN used to be assigned _batch_eval4d here, but a
        # following `if` (not `elif`) always overwrote it with _eval4d. That
        # effective behaviour is kept; switch PINN to _batch_eval4d only once
        # its batching is confirmed to cover every test point.
        return _eval4d
    if dim == 'nd' and model == 'spinn':
        return _evalnd
    raise NotImplementedError