# SPINN Navier Stokes
> Navier Stokes

The vorticity form of the (3+1)-d Navier-Stokes equations, with vorticity $\mathbf{w} = \nabla \times \mathbf{u}$, is given by:

$$
\frac{\partial \mathbf{w}}{\partial t} + (\mathbf{u} \cdot \nabla)\mathbf{w} = (\mathbf{w} \cdot \nabla)\mathbf{u} + \nu\nabla^2 \mathbf{w} + \mathbf{F}
$$

$$
\nabla \cdot \mathbf{u} = 0
$$

$$
\mathbf{w}(\mathbf{r}, 0) = \mathbf{w}_0 (\mathbf{r})
$$



$$
\mathbf{r} \in \Omega = [0, 2\pi]^3
$$

$$
t \in \Gamma = [0, 5]
$$

In [None]:
import os
import time
import argparse

import jax
from jax import jvp, vjp
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn

import matplotlib.pyplot as plt

from typing import Sequence
from tqdm import trange
from functools import partial

In [None]:
def _navier_stokes4d(apply_fn, params, test_data, result_dir, e):
    """Save a row of 3-D quiver plots of the predicted vorticity field.

    One subplot is drawn for each of t = 0, 1, ..., 5 and the figure is written
    to ``<result_dir>/vis/<e, zero-padded to 5 digits>/pred.png``.

    Args:
        apply_fn: network forward function, forwarded to ``vorx``/``vory``/``vorz``.
        params: network parameters.
        test_data: accepted for signature parity with the dispatcher; not read here.
        result_dir: directory under which the ``vis`` folder is created.
        e: epoch number, used for the output sub-folder name.
    """
    print("visualizing solution...")

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)

    # The spatial sampling does not depend on the time slice, so build it once:
    # a coarse x axis (4 points) and finer y/z axes (30 points each).
    xs = jnp.linspace(0, 2*jnp.pi, 4).reshape(-1, 1)
    ys = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
    zs = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
    gx, gy, gz = jnp.meshgrid(
        jnp.squeeze(xs), jnp.squeeze(ys), jnp.squeeze(zs), indexing='ij'
    )

    fig = plt.figure(figsize=(30, 5))
    for col, t_val in enumerate([0, 1, 2, 3, 4, 5]):
        t = jnp.array([[t_val]])
        wx = vorx(apply_fn, params, t, xs, ys, zs)
        wy = vory(apply_fn, params, t, xs, ys, zs)
        wz = vorz(apply_fn, params, t, xs, ys, zs)

        # Colour each arrow by the angle arctan2(w_y, w_z), rescaled to [0, 1].
        angle = jnp.arctan2(wy, wz)
        shade = (angle.ravel() - angle.min()) / jnp.ptp(angle)
        # The colour list is extended with every value repeated twice before
        # being handed to quiver (presumably for the arrow-head segments).
        shade = jnp.concatenate((shade, jnp.repeat(shade, 2)))
        rgba = plt.cm.plasma(shade)

        # 1 row x 6 columns; the new axes also becomes the current axes.
        ax = fig.add_subplot(1, 6, col + 1, projection='3d')
        ax.quiver(
            gx, gy, gz,
            jnp.squeeze(wx), jnp.squeeze(wy), jnp.squeeze(wz),
            length=0.1, colors=rgba, alpha=1, linewidth=0.7,
        )
        ax.set_title(f't={jnp.squeeze(t)}')

    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()

In [None]:
def show_solution(args, apply_fn, params, test_data, result_dir, e, resol=50):
    """Dispatch to the plotting routine that matches ``args.equation``.

    Args:
        args: configuration object; only ``args.equation`` is read here, the
            object itself is forwarded to some of the plotters.
        apply_fn: network forward function.
        params: network parameters.
        test_data: test data, forwarded to the plotters that take it.
        result_dir: output directory for the figures.
        e: epoch number.
        resol: plotting resolution, forwarded to the plotters that take it.

    Raises:
        NotImplementedError: if no plotter is registered for the equation.
    """
    eqn = args.equation

    if eqn == 'diffusion3d':
        _diffusion3d(args, apply_fn, params, test_data, result_dir, e, resol)
        return
    if eqn == 'helmholtz3d':
        _helmholtz3d(args, apply_fn, params, result_dir, e, resol)
        return
    if eqn == 'klein_gordon3d':
        _klein_gordon3d(args, apply_fn, params, result_dir, e, resol)
        return
    if eqn == 'navier_stokes3d':
        _navier_stokes3d(apply_fn, params, test_data, result_dir, e)
        return
    if eqn == 'navier_stokes4d':
        _navier_stokes4d(apply_fn, params, test_data, result_dir, e)
        return

    raise NotImplementedError

In [None]:
def relative_l2(u, u_gt):
    """Return ``||u - u_gt|| / ||u_gt||`` using ``jnp.linalg.norm`` defaults."""
    residual_norm = jnp.linalg.norm(u-u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return residual_norm / reference_norm

def mse(u, u_gt):
    """Return the mean of the element-wise squared difference ``(u - u_gt)**2``."""
    diff = u-u_gt
    return jnp.mean(diff**2)

In [None]:
# forward over forward
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-forward AD.

    Args:
        f: function of a single array argument.
        primals: 1-tuple holding the evaluation point, as passed to ``jax.jvp``.
        tangents: 1-tuple holding the direction, as passed to ``jax.jvp``.
        return_primals: when True, also return the first directional derivative.

    Returns:
        The second directional derivative along ``tangents``; or the pair
        ``(first_derivative, second_derivative)`` when ``return_primals`` is True.
    """
    def first_directional(x):
        # tangent output of the inner jvp: directional derivative of f at x
        return jvp(f, (x,), tangents)[1]

    first, second = jvp(first_directional, primals, tangents)
    return (first, second) if return_primals else second


# reverse over reverse
def hvp_revrev(f, primals, tangents, return_primals=False):
    """Hessian-vector style product of ``f`` via reverse-over-reverse AD.

    Note the calling convention: ``primals`` is handed to ``jax.vjp`` as a
    single positional argument and ``tangents`` is used directly as the
    cotangent, so both are passed as they are (not wrapped in 1-tuples by this
    function).

    Args:
        f: function of a single argument.
        primals: evaluation point, forwarded as-is to ``jax.vjp``.
        tangents: cotangent/direction, used for both reverse passes.
        return_primals: when True, also return the output of the inner pullback
            (a tuple, as produced by the ``vjp`` function).

    Returns:
        The result of the outer pullback; or ``(inner_output, result)`` when
        ``return_primals`` is True.
    """
    def pullback(x):
        _, f_vjp = vjp(f, x)
        return f_vjp(tangents)

    first, pullback_vjp = vjp(pullback, primals)
    # pullback returns a tuple, so its cotangent must be a tuple as well
    second = pullback_vjp((tangents,))[0]
    return (first, second) if return_primals else second


# forward over reverse
def hvp_fwdrev(f, primals, tangents, return_primals=False):
    """Hessian-vector style product of ``f`` via forward-over-reverse AD.

    Args:
        f: function of a single array argument.
        primals: 1-tuple holding the evaluation point, as passed to ``jax.jvp``.
        tangents: 1-tuple holding the direction; ``tangents[0]`` is also used as
            the cotangent of the inner reverse pass.
        return_primals: when True, also return the output of the inner reverse
            pass.

    Returns:
        The tangent output of the outer ``jvp``; or ``(inner_output, result)``
        when ``return_primals`` is True.
    """
    def pulled_back(x):
        _, f_vjp = vjp(f, x)
        return f_vjp(tangents[0])[0]

    first, second = jvp(pulled_back, primals, tangents)
    return (first, second) if return_primals else second


# reverse over forward
def hvp_revfwd(f, primals, tangents, return_primals=False):
    """Hessian-vector style product of ``f`` via reverse-over-forward AD.

    Args:
        f: function of a single array argument.
        primals: 1-tuple holding the evaluation point; it is passed to the outer
            ``jax.vjp`` as one (tuple-valued) argument and to the inner
            ``jax.jvp`` as its primals.
        tangents: 1-tuple holding the direction; ``tangents[0]`` is also used as
            the cotangent of the outer reverse pass.
        return_primals: when True, also return the inner directional derivative.

    Returns:
        The pulled-back cotangent for the evaluation point; or
        ``(inner_output, result)`` when ``return_primals`` is True.
    """
    def directional(point):
        return jvp(f, point, tangents)[1]

    first, backward = vjp(directional, primals)
    # outer [0]: cotangent for the single argument; inner [0]: unwrap the 1-tuple
    second = backward(tangents[0])[0][0]
    return (first, second) if return_primals else second

In [None]:
def vorx(apply_fn, params, t, x, y, z):
    """x-component of the vorticity, ``w_x = d(u_z)/dy - d(u_y)/dz``.

    Both partial derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent of the differentiated input's shape.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``[0]``, ``[1]``, ``[2]`` for (u_x, u_y, u_z).
        params: network parameters, forwarded to ``apply_fn``.
        t, x, y, z: coordinate inputs, forwarded to ``apply_fn``.
    """
    def u_y_along_z(z_in):
        return apply_fn(params, t, x, y, z_in)[1]

    def u_z_along_y(y_in):
        return apply_fn(params, t, x, y_in, z)[2]

    _, duy_dz = jvp(u_y_along_z, (z,), (jnp.ones(z.shape),))
    _, duz_dy = jvp(u_z_along_y, (y,), (jnp.ones(y.shape),))
    return duz_dy - duy_dz


def vory(apply_fn, params, t, x, y, z):
    """y-component of the vorticity, ``w_y = d(u_x)/dz - d(u_z)/dx``.

    Both partial derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent of the differentiated input's shape.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``[0]``, ``[1]``, ``[2]`` for (u_x, u_y, u_z).
        params: network parameters, forwarded to ``apply_fn``.
        t, x, y, z: coordinate inputs, forwarded to ``apply_fn``.
    """
    def u_x_along_z(z_in):
        return apply_fn(params, t, x, y, z_in)[0]

    def u_z_along_x(x_in):
        return apply_fn(params, t, x_in, y, z)[2]

    _, dux_dz = jvp(u_x_along_z, (z,), (jnp.ones(z.shape),))
    _, duz_dx = jvp(u_z_along_x, (x,), (jnp.ones(x.shape),))
    return dux_dz - duz_dx


def vorz(apply_fn, params, t, x, y, z):
    """z-component of the vorticity, ``w_z = d(u_y)/dx - d(u_x)/dy``.

    Both partial derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent of the differentiated input's shape.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``[0]``, ``[1]``, ``[2]`` for (u_x, u_y, u_z).
        params: network parameters, forwarded to ``apply_fn``.
        t, x, y, z: coordinate inputs, forwarded to ``apply_fn``.
    """
    def u_x_along_y(y_in):
        return apply_fn(params, t, x, y_in, z)[0]

    def u_y_along_x(x_in):
        return apply_fn(params, t, x_in, y, z)[1]

    _, dux_dy = jvp(u_x_along_y, (y,), (jnp.ones(y.shape),))
    _, duy_dx = jvp(u_y_along_x, (x,), (jnp.ones(x.shape),))
    return duy_dx - dux_dy

In [None]:
@partial(jax.jit, static_argnums=(0,))
def _eval_ns4d(apply_fn, params, *test_data):
    """Mean relative L2 error of the three predicted vorticity components.

    Args:
        apply_fn: network forward function (static under ``jit``).
        params: network parameters.
        *test_data: ``(t, x, y, z, w_gt)`` where ``w_gt[0..2]`` are the
            reference vorticity components.

    Returns:
        The average of the three per-component relative L2 errors.
    """
    t, x, y, z, w_gt = test_data
    per_component = [
        relative_l2(component(apply_fn, params, t, x, y, z), w_gt[axis])
        for axis, component in enumerate((vorx, vory, vorz))
    ]
    total = per_component[0] + per_component[1] + per_component[2]
    return total / 3

In [None]:

def setup_eval_function(model, equation):
    """Select the evaluation function for a model / equation pair.

    The problem dimensionality is taken from the last two characters of
    ``equation`` (``'2d'``, ``'3d'``, ``'4d'`` or ``'nd'``).

    Args:
        model: model name, ``'pinn'`` or ``'spinn'``.
        equation: equation name ending in the dimensionality tag.

    Returns:
        The evaluation function to call on the test data.

    Raises:
        NotImplementedError: for an unknown dimensionality tag, or for an
            ``'nd'`` equation with a model other than ``'spinn'``.
    """
    dim = equation[-2:]
    if dim == '2d':
        fn = _eval2d
    elif dim == '3d':
        if model == 'pinn' and equation == 'navier_stokes3d':
            fn = _eval3d_ns_pinn
        elif model == 'spinn' and equation == 'navier_stokes3d':
            fn = _eval3d_ns_spinn
        else:
            fn = _eval3d
    elif dim == '4d':
        # These three cases are mutually exclusive. With two independent ``if``
        # statements the PINN choice was always overwritten by the trailing
        # ``else`` branch.
        if model == 'pinn':
            fn = _batch_eval4d
        elif model == 'spinn' and equation == 'navier_stokes4d':
            fn = _eval_ns4d
        else:
            fn = _eval4d
    elif dim == 'nd':
        # Only SPINN has an n-d evaluator; fail explicitly instead of leaving
        # ``fn`` unbound.
        if model != 'spinn':
            raise NotImplementedError
        fn = _evalnd
    else:
        raise NotImplementedError
    return fn

In [None]:

#----------------------- Navier-Stokes equation 4-d -------------------------#
@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_navier_stokes4d(model, nc_test, nu):
    """Build the test grid and reference vorticity for the (3+1)-d problem.

    A uniform grid with ``nc_test`` points per axis is created on
    t in [0, 5] and x, y, z in [0, 2*pi]; the reference vorticity is evaluated
    on the full tensor-product grid with ``navier_stokes4d_exact_w``.

    Args:
        model: ``'pinn'`` or anything else (treated as SPINN). Static under jit.
        nc_test: number of grid points per axis. Static under jit.
        nu: viscosity, forwarded to the exact solution.

    Returns:
        ``(t, x, y, z, w_gt)``. For ``'pinn'`` the coordinates are the flattened
        mesh, each of shape ``(nc_test**4, 1)``, and ``w_gt`` is a tuple of
        three arrays of the same shape. Otherwise the coordinates are the
        per-axis vectors of shape ``(nc_test, 1)`` and ``w_gt`` is the tuple of
        three grid-shaped arrays.
    """
    t = jnp.linspace(0, 5, nc_test)
    x = jnp.linspace(0, 2*jnp.pi, nc_test)
    y = jnp.linspace(0, 2*jnp.pi, nc_test)
    z = jnp.linspace(0, 2*jnp.pi, nc_test)
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)
    z = jax.lax.stop_gradient(z)
    tm, xm, ym, zm = jnp.meshgrid(
        t, x, y, z, indexing='ij'
    )
    w_gt = navier_stokes4d_exact_w(tm, xm, ym, zm, nu)
    if model == 'pinn':
        t = tm.reshape(-1, 1)
        x = xm.reshape(-1, 1)
        y = ym.reshape(-1, 1)
        z = zm.reshape(-1, 1)
        # w_gt is a tuple (w_x, w_y, w_z), so each component has to be
        # flattened on its own; calling .reshape on the tuple would fail.
        w_gt = tuple(w.reshape(-1, 1) for w in w_gt)
    else:
        t = t.reshape(-1, 1)
        x = x.reshape(-1, 1)
        y = y.reshape(-1, 1)
        z = z.reshape(-1, 1)
    return t, x, y, z, w_gt


In [None]:

def generate_test_data(args, result_dir):
    """Build the test data set for the equation named in ``args.equation``.

    Args:
        args: configuration object; the attributes read depend on the equation
            (e.g. ``model``, ``nc_test``, ``nu`` for ``'navier_stokes4d'``).
        result_dir: output directory, only forwarded for ``'navier_stokes3d'``.

    Returns:
        Whatever the equation-specific generator returns.

    Raises:
        NotImplementedError: if the equation has no test generator.
    """
    eqn = args.equation

    if eqn == 'diffusion3d':
        return _test_generator_diffusion3d(args.model, args.data_dir)
    if eqn == 'helmholtz3d':
        return _test_generator_helmholtz3d(
            args.model, args.a1, args.a2, args.a3, args.nc_test
        )
    if eqn == 'klein_gordon3d':
        return _test_generator_klein_gordon3d(args.model, args.nc_test, args.k)
    if eqn == 'klein_gordon4d':
        return _test_generator_klein_gordon4d(args.model, args.nc_test, args.k)
    if eqn == 'navier_stokes3d':
        return _test_generator_navier_stokes3d(
            args.model, args.data_dir, result_dir, args.marching_steps, args.step_idx
        )
    if eqn == 'navier_stokes4d':
        return _test_generator_navier_stokes4d(args.model, args.nc_test, args.nu)

    raise NotImplementedError

In [None]:
# 3d time-dependent navier-stokes forcing term
def navier_stokes4d_forcing_term(t, x, y, z, nu):
    """Forcing term ``(f_x, f_y, f_z)`` of the vorticity equation.

    Each component is written in its reduced double-angle form, e.g.
    ``-24*sin(2y)*cos(2y)*sin(z)*cos(z) == -6*sin(4y)*sin(2z)``.

    Args:
        t, x, y, z: coordinates (scalars or broadcast-compatible arrays).
        nu: viscosity.

    Returns:
        Tuple ``(f_x, f_y, f_z)``.
    """
    # common temporal decay factor
    decay = jnp.exp(-18*nu*t)
    sin_4x = jnp.sin(4*x)
    sin_4y = jnp.sin(4*y)
    sin_2z = jnp.sin(2*z)

    f_x = -6*decay*sin_4y*sin_2z
    f_y = -6*decay*sin_4x*sin_2z
    f_z = 6*decay*sin_4x*sin_4y
    return f_x, f_y, f_z

# 3d time-dependent navier-stokes exact vorticity
def navier_stokes4d_exact_w(t, x, y, z, nu):
    """Analytic vorticity ``(w_x, w_y, w_z)`` of the reference solution.

    Args:
        t, x, y, z: coordinates (scalars or broadcast-compatible arrays).
        nu: viscosity.

    Returns:
        Tuple ``(w_x, w_y, w_z)``.
    """
    # common temporal decay factor
    decay = jnp.exp(-9*nu*t)
    sin_2x, cos_2x = jnp.sin(2*x), jnp.cos(2*x)
    sin_2y, cos_2y = jnp.sin(2*y), jnp.cos(2*y)
    sin_z, cos_z = jnp.sin(z), jnp.cos(z)

    w_x = -3*decay*sin_2x*cos_2y*cos_z
    w_y = 6*decay*cos_2x*sin_2y*cos_z
    w_z = -6*decay*cos_2x*cos_2y*sin_z
    return w_x, w_y, w_z


# 3d time-dependent navier-stokes exact velocity
def navier_stokes4d_exact_u(t, x, y, z, nu):
    """Analytic velocity ``(u_x, u_y, u_z)`` of the reference solution.

    Args:
        t, x, y, z: coordinates (scalars or broadcast-compatible arrays).
        nu: viscosity.

    Returns:
        Tuple ``(u_x, u_y, u_z)``.
    """
    # common temporal decay factor
    decay = jnp.exp(-9*nu*t)
    sin_2x, cos_2x = jnp.sin(2*x), jnp.cos(2*x)
    sin_2y, cos_2y = jnp.sin(2*y), jnp.cos(2*y)
    sin_z, cos_z = jnp.sin(z), jnp.cos(z)

    u_x = 2*decay*cos_2x*sin_2y*sin_z
    u_y = -1*decay*sin_2x*cos_2y*sin_z
    u_z = -2*decay*sin_2x*sin_2y*cos_z
    return u_x, u_y, u_z

In [None]:

#======================== Navier-Stokes equation 4-d ========================#
#---------------------------------- SPINN -----------------------------------#
@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_navier_stokes4d(nc, nu, key):
    """Sample SPINN training data for the (3+1)-d Navier-Stokes problem.

    Per-axis collocation coordinates are drawn uniformly from t in [0, 5] and
    x, y, z in [0, 2*pi]; targets are evaluated on the tensor-product grid of
    those axes.

    Args:
        nc: number of collocation points per axis (static under jit).
        nu: viscosity.
        key: JAX PRNG key; split into one sub-key per axis.

    Returns:
        A 16-tuple
        ``(tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb)``:
        collocation axes and forcing term; initial-time axes with exact
        vorticity and velocity; and, for each of the six faces of the spatial
        cube, lists of axes and the exact vorticity on that face.
    """
    keys = jax.random.split(key, 4)
    # collocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=5.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=0., maxval=2.*jnp.pi)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=0., maxval=2.*jnp.pi)
    zc = jax.random.uniform(keys[3], (nc, 1), minval=0., maxval=2.*jnp.pi)

    tcm, xcm, ycm, zcm = jnp.meshgrid(
        tc.ravel(), xc.ravel(), yc.ravel(), zc.ravel(), indexing='ij'
    )
    fc = navier_stokes4d_forcing_term(tcm, xcm, ycm, zcm, nu)

    # initial points: t = 0 with the same spatial samples
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    zi = zc
    tim, xim, yim, zim = jnp.meshgrid(
        ti.ravel(), xi.ravel(), yi.ravel(), zi.ravel(), indexing='ij'
    )
    wi = navier_stokes4d_exact_w(tim, xim, yim, zim, nu)
    ui = navier_stokes4d_exact_u(tim, xim, yim, zim, nu)

    # boundary points: the six faces of the spatial domain [0, 2*pi]^3.
    # The faces previously sat at -1 and 1, which are not the boundary of the
    # domain sampled above (-1 is outside it, 1 is an interior plane).
    lower = jnp.array([[0.]])
    upper = jnp.array([[2.*jnp.pi]])
    tb = [tc, tc, tc, tc, tc, tc]
    xb = [lower, upper, xc, xc, xc, xc]
    yb = [yc, yc, lower, upper, yc, yc]
    zb = [zc, zc, zc, zc, lower, upper]
    wb = []
    for t_face, x_face, y_face, z_face in zip(tb, xb, yb, zb):
        tbm, xbm, ybm, zbm = jnp.meshgrid(
            t_face.ravel(), x_face.ravel(), y_face.ravel(), z_face.ravel(), indexing='ij'
        )
        wb.append(navier_stokes4d_exact_w(tbm, xbm, ybm, zbm, nu))
    return tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb

In [None]:

def generate_train_data(args, key, result_dir=None):
    """Build the training data set for ``args.model`` / ``args.equation``.

    Args:
        args: configuration object; the attributes read depend on the model and
            equation (e.g. ``nc`` and ``nu`` for SPINN ``'navier_stokes4d'``).
        key: JAX PRNG key forwarded to the selected generator.
        result_dir: output directory, only forwarded for SPINN
            ``'navier_stokes3d'``.

    Returns:
        Whatever the selected generator returns.

    Raises:
        NotImplementedError: for an unknown model, or an equation the chosen
            model has no generator for.
    """
    eqn = args.equation
    model = args.model

    if model == 'pinn':
        if eqn == 'diffusion3d':
            return _pinn_train_generator_diffusion3d(args.nc, key)
        if eqn == 'helmholtz3d':
            return _pinn_train_generator_helmholtz3d(
                args.a1, args.a2, args.a3, args.nc, key
            )
        if eqn == 'klein_gordon3d':
            return _pinn_train_generator_klein_gordon3d(args.nc, args.k, key)
        if eqn == 'klein_gordon4d':
            return _pinn_train_generator_klein_gordon4d(args.nc, args.k, key)
        raise NotImplementedError

    if model == 'spinn':
        if eqn == 'diffusion3d':
            return _spinn_train_generator_diffusion3d(args.nc, key)
        if eqn == 'helmholtz3d':
            return _spinn_train_generator_helmholtz3d(
                args.a1, args.a2, args.a3, args.nc, key
            )
        if eqn == 'klein_gordon3d':
            return _spinn_train_generator_klein_gordon3d(args.nc, args.k, key)
        if eqn == 'klein_gordon4d':
            return _spinn_train_generator_klein_gordon4d(args.nc, args.k, key)
        if eqn == 'navier_stokes3d':
            return _spinn_train_generator_navier_stokes3d(
                args.nt, args.nxy, args.data_dir, result_dir,
                args.marching_steps, args.step_idx, args.offset_num, key
            )
        if eqn == 'navier_stokes4d':
            return _spinn_train_generator_navier_stokes4d(args.nc, args.nu, key)
        raise NotImplementedError

    raise NotImplementedError

In [None]:
def setup_networks(args, key):
    """Construct the network and initialise its parameters.

    The dimensionality is taken from the last two characters of
    ``args.equation``. Parameters are initialised with dummy all-ones inputs,
    one ``(n, 1)`` column per input axis.

    Args:
        args: configuration object (reads ``equation``, ``model``,
            ``features``, ``n_layers``, ``out_dim`` and, depending on the
            branch, ``r``, ``mlp``, ``pos_enc``, ``nc``, ``nt``, ``nxy``).
        key: JAX PRNG key used for parameter initialisation.

    Returns:
        ``(apply_fn, params)`` where ``apply_fn`` is the jitted ``model.apply``.

    Raises:
        NotImplementedError: for an unsupported dimensionality.
    """
    dim = args.equation[-2:]

    # build network
    if args.model == 'pinn':
        # hidden widths followed by the output width
        feat_sizes = tuple([args.features] * (args.n_layers - 1) + [args.out_dim])
        if dim == '2d':
            model = PINN2d(feat_sizes)
        elif dim == '3d':
            model = PINN3d(feat_sizes, args.out_dim, args.pos_enc)
        elif dim == '4d':
            model = PINN4d(feat_sizes)
        else:
            raise NotImplementedError
    else:  # SPINN
        feat_sizes = tuple([args.features] * args.n_layers)
        if dim == '2d':
            model = SPINN2d(feat_sizes, args.r, args.mlp)
        elif dim == '3d':
            model = SPINN3d(feat_sizes, args.r, args.out_dim, args.pos_enc, args.mlp)
        elif dim == '4d':
            model = SPINN4d(feat_sizes, args.r, args.out_dim, args.mlp)
        else:
            raise NotImplementedError

    # number of dummy samples along each input axis
    if dim == '3d' and args.equation == 'navier_stokes3d':
        axis_sizes = (args.nt, args.nxy, args.nxy)
    elif dim == '2d':
        axis_sizes = (args.nc,) * 2
    elif dim == '3d':
        axis_sizes = (args.nc,) * 3
    elif dim == '4d':
        axis_sizes = (args.nc,) * 4
    else:
        raise NotImplementedError

    # initialize params (dummy inputs must be given)
    dummy_inputs = [jnp.ones((n, 1)) for n in axis_sizes]
    params = model.init(key, *dummy_inputs)

    return jax.jit(model.apply), params


def name_model(args):
    """Build the run name (used as the result sub-directory) from the config.

    The name is a ``'_'``-joined list of tags: sampling sizes first, then the
    network/optimiser settings, then equation-specific tags, and finally the
    MLP type.

    Args:
        args: configuration object; which attributes are read depends on
            ``args.model`` and ``args.equation``.

    Returns:
        The run name as a single string.
    """
    eqn = args.equation

    core = [
        f'nl{args.n_layers}',
        f'fs{args.features}',
        f'lr{args.lr}',
        f's{args.seed}',
        f'r{args.r}',
    ]
    if args.model != 'spinn':
        # the rank tag only applies to SPINN
        core.pop()

    prefix, suffix = [], []
    if eqn != 'navier_stokes3d':
        prefix.append(f'nc{args.nc}')
    if eqn == 'navier_stokes3d':
        nxy_tag = f'nxy{args.nxy}'
        nt_tag = f'nt{args.nt}'
        prefix = [nt_tag, nxy_tag] + prefix
        suffix += [
            f'on{args.offset_num}',
            f'oi{args.offset_iter}',
            f'lc{args.lbda_c}',
            f'lic{args.lbda_ic}',
        ]
    if eqn == 'navier_stokes4d':
        suffix += [f'lc{args.lbda_c}', f'li{args.lbda_ic}']
    if eqn == 'helmholtz3d':
        suffix.append(f'a{args.a1}{args.a2}{args.a3}')
    if eqn == 'klein_gordon3d':
        suffix.append(f'k{args.k}')

    suffix.append(f'{args.mlp}')

    return '_'.join(prefix + core + suffix)


def save_config(args, result_dir):
    """Write every attribute of ``args`` to ``<result_dir>/configs.txt``.

    One ``name: value`` line is written per attribute, in the order given by
    ``vars(args)``. An existing file is overwritten.
    """
    config_path = os.path.join(result_dir, 'configs.txt')
    with open(config_path, 'w') as f:
        f.writelines(f'{arg}: {getattr(args, arg)}\n' for arg in vars(args))


# single update function
#@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimiser step.

    Args:
        optim: optimiser object exposing ``update(gradient, state)``.
        gradient: gradient of the loss w.r.t. ``params``.
        params: current parameters.
        state: current optimiser state.

    Returns:
        ``(new_params, new_state)``.
    """
    updates, new_state = optim.update(gradient, state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state


# save next initial condition for time-marching
def save_next_IC(root_dir, name, apply_fn, params, test_data, step_idx, e):
    """Save the prediction at the last test time as the next initial condition.

    The vorticity and the two velocity components predicted at the final time
    in ``test_data[0]`` are written to
    ``<root_dir>/<name>/IC_pred/w0_<step_idx+1>.mat`` under the keys ``'w0'``,
    ``'u0'``, ``'v0'`` and ``'t'``.

    Args:
        root_dir: root result directory.
        name: run name (sub-directory of ``root_dir``).
        apply_fn: network forward function returning two velocity components.
        params: network parameters.
        test_data: sequence whose first three entries are the t, x and y axes.
        step_idx: index of the current time-marching step.
        e: epoch number; not used here.
    """
    # Imported locally: the visible top-of-file import block does not bring
    # scipy into scope, so the bare ``scipy.io`` reference raised NameError.
    import scipy.io

    os.makedirs(os.path.join(root_dir, name, 'IC_pred'), exist_ok=True)

    # last time sample, expanded along axis 1 as the network input expects
    t_last = jnp.expand_dims(test_data[0][-1], axis=1)
    x_axis, y_axis = test_data[1], test_data[2]

    w_pred = velocity_to_vorticity_fwd(apply_fn, params, t_last, x_axis, y_axis)
    w_pred = w_pred.reshape(-1, x_axis.shape[0], y_axis.shape[0])[0]
    u0_pred, v0_pred = apply_fn(params, t_last, x_axis, y_axis)
    u0_pred, v0_pred = jnp.squeeze(u0_pred), jnp.squeeze(v0_pred)

    scipy.io.savemat(
        os.path.join(root_dir, name, f'IC_pred/w0_{step_idx+1}.mat'),
        mdict={'w0': w_pred, 'u0': u0_pred, 'v0': v0_pred, 't': t_last},
    )

In [None]:
class SPINN4d(nn.Module):
    """Separable network with one tanh MLP per input axis (t, x, y, z).

    Each axis network ends in a Dense layer of width ``r * out_dim``. For each
    output component, the matching ``r`` feature rows of the four axis networks
    are combined by successive outer products and summed over the feature
    index, yielding a ``(t, x, y, z)`` grid.
    """
    # widths of the Dense layers; the last entry is not used for a hidden layer
    features: Sequence[int]
    # number of features per output component
    r: int
    # number of output components
    out_dim: int
    # MLP type tag; not read inside this module
    mlp: str

    @nn.compact
    def __call__(self, t, x, y, z):
        kernel_init = nn.initializers.glorot_normal()
        last_width = self.r*self.out_dim

        # One body network per axis. The transpose puts the feature index
        # first so the einsums below can slice feature rows.
        factors = []
        for coord in (t, x, y, z):
            h = coord
            for width in self.features[:-1]:
                h = nn.Dense(width, kernel_init=kernel_init)(h)
                h = nn.activation.tanh(h)
            h = nn.Dense(last_width, kernel_init=kernel_init)(h)
            factors.append(jnp.transpose(h, (1, 0)))
        feat_t, feat_x, feat_y, feat_z = factors

        pred = []
        for k in range(self.out_dim):
            # feature rows belonging to output component k
            rows = slice(self.r*k, self.r*(k+1))
            tx = jnp.einsum('ft, fx->ftx', feat_t[rows], feat_x[rows])
            txy = jnp.einsum('ftx, fy->ftxy', tx, feat_y[rows])
            pred.append(jnp.einsum('ftxy, fz->txyz', txy, feat_z[rows]))

        # a single output component is returned bare, several as a list
        return pred[0] if len(pred) == 1 else pred

In [None]:
@partial(jax.jit, static_argnums=(0,))
def apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data):
    """Compute the total training loss and its gradient w.r.t. ``params``.

    The loss is the sum of
      * the residual of the vorticity equation (three components) plus
        ``lbda_c`` times the incompressibility residual,
      * ``lbda_ic`` times the initial-condition loss on vorticity and velocity,
      * the boundary loss on vorticity, averaged over the six faces.

    Args:
        apply_fn: network forward function returning ``(u_x, u_y, u_z)``
            (static under jit).
        params: network parameters.
        nu: viscosity.
        lbda_c: weight of the incompressibility term.
        lbda_ic: weight of the initial-condition term.
        *train_data: the 16 entries
            ``(tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb)``.

    Returns:
        ``(loss, gradient)`` as produced by ``jax.value_and_grad``.
    """
    def residual_loss(params, t, x, y, z, f):
        # calculate u
        ux, uy, uz = apply_fn(params, t, x, y, z)
        # calculate w (3D vorticity vector)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)
        # tangent vector dx/dx
        # The same all-ones tangent is used for every axis, so t, x, y and z
        # must all have the same shape (very important).
        vec = jnp.ones(t.shape)

        # x-component
        wx_t = jvp(lambda t: vorx(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wx_x, wx_xx = hvp_fwdfwd(lambda x: vorx(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wx_y, wx_yy = hvp_fwdfwd(lambda y: vorx(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wx_z, wx_zz = hvp_fwdfwd(lambda z: vorx(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        ux_x = jvp(lambda x: apply_fn(params, t, x, y, z)[0], (x,), (vec,))[1]
        ux_y = jvp(lambda y: apply_fn(params, t, x, y, z)[0], (y,), (vec,))[1]
        ux_z = jvp(lambda z: apply_fn(params, t, x, y, z)[0], (z,), (vec,))[1]

        # w_t + (u . grad) w - (w . grad) u - nu * laplacian(w) - F = 0
        loss_x = jnp.mean((
            wx_t + ux*wx_x + uy*wx_y + uz*wx_z
            - (wx*ux_x + wy*ux_y + wz*ux_z)
            - nu*(wx_xx + wx_yy + wx_zz)
            - f[0]
        )**2)

        # y-component
        wy_t = jvp(lambda t: vory(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wy_x, wy_xx = hvp_fwdfwd(lambda x: vory(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wy_y, wy_yy = hvp_fwdfwd(lambda y: vory(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wy_z, wy_zz = hvp_fwdfwd(lambda z: vory(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        uy_x = jvp(lambda x: apply_fn(params, t, x, y, z)[1], (x,), (vec,))[1]
        uy_y = jvp(lambda y: apply_fn(params, t, x, y, z)[1], (y,), (vec,))[1]
        uy_z = jvp(lambda z: apply_fn(params, t, x, y, z)[1], (z,), (vec,))[1]

        loss_y = jnp.mean((
            wy_t + ux*wy_x + uy*wy_y + uz*wy_z
            - (wx*uy_x + wy*uy_y + wz*uy_z)
            - nu*(wy_xx + wy_yy + wy_zz)
            - f[1]
        )**2)

        # z-component
        wz_t = jvp(lambda t: vorz(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wz_x, wz_xx = hvp_fwdfwd(lambda x: vorz(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wz_y, wz_yy = hvp_fwdfwd(lambda y: vorz(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wz_z, wz_zz = hvp_fwdfwd(lambda z: vorz(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        uz_x = jvp(lambda x: apply_fn(params, t, x, y, z)[2], (x,), (vec,))[1]
        uz_y = jvp(lambda y: apply_fn(params, t, x, y, z)[2], (y,), (vec,))[1]
        uz_z = jvp(lambda z: apply_fn(params, t, x, y, z)[2], (z,), (vec,))[1]

        loss_z = jnp.mean((
            wz_t + ux*wz_x + uy*wz_y + uz*wz_z
            - (wx*uz_x + wy*uz_y + wz*uz_z)
            - nu*(wz_xx + wz_yy + wz_zz)
            - f[2]
        )**2)

        # incompressibility: div u = 0
        loss_c = jnp.mean((ux_x + uy_y + uz_z)**2)

        return loss_x + loss_y + loss_z + lbda_c*loss_c

    def initial_loss(params, t, x, y, z, w, u):
        ux, uy, uz = apply_fn(params, t, x, y, z)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)
        loss = jnp.mean((wx - w[0])**2) + jnp.mean((wy - w[1])**2) + jnp.mean((wz - w[2])**2)
        loss += jnp.mean((ux - u[0])**2) + jnp.mean((uy - u[1])**2) + jnp.mean((uz - u[2])**2)
        return loss

    def boundary_loss(params, t, x, y, z, w):
        loss = 0.
        for i in range(6):
            wx = vorx(apply_fn, params, t[i], x[i], y[i], z[i])
            wy = vory(apply_fn, params, t[i], x[i], y[i], z[i])
            wz = vorz(apply_fn, params, t[i], x[i], y[i], z[i])
            # The 1/6 face average must cover all three components; without
            # the parentheses it only scaled the x-component term.
            loss += (1/6.) * (
                jnp.mean((wx - w[i][0])**2)
                + jnp.mean((wy - w[i][1])**2)
                + jnp.mean((wz - w[i][2])**2)
            )
        return loss

    # unpack data
    tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb = train_data

    # isolate loss func from redundant arguments
    def loss_fn(params):
        return (
            residual_loss(params, tc, xc, yc, zc, fc)
            + lbda_ic*initial_loss(params, ti, xi, yi, zi, wi, ui)
            + boundary_loss(params, tb, xb, yb, zb, wb)
        )

    loss, gradient = jax.value_and_grad(loss_fn)(params)

    return loss, gradient

In [None]:
# config
parser = argparse.ArgumentParser(description='Training configurations')

# model and equation
parser.add_argument('--model', type=str, default='spinn', choices=['spinn', 'pinn'], help='model name (pinn; spinn)')
parser.add_argument('--debug', type=str, default='true', help='debugging purpose')
parser.add_argument('--equation', type=str, default='navier_stokes4d', help='equation to solve')

# pde settings
parser.add_argument('--nc', type=int, default=32, help='the number of collocation points per axis')
parser.add_argument('--nc_test', type=int, default=20, help='the number of test points per axis')

# training settings
parser.add_argument('--seed', type=int, default=111, help='random seed')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--epochs', type=int, default=1000, help='training epochs')
parser.add_argument('--mlp', type=str, default='modified_mlp', help='type of mlp')
parser.add_argument('--n_layers', type=int, default=5, help='the number of layer')
parser.add_argument('--features', type=int, default=64, help='feature size of each layer')
parser.add_argument('--r', type=int, default=128, help='rank of the approximated tensor')
parser.add_argument('--out_dim', type=int, default=3, help='size of model output')
parser.add_argument('--nu', type=float, default=0.05, help='viscosity')
parser.add_argument('--lbda_c', type=int, default=100, help='weight of the incompressibility loss')
parser.add_argument('--lbda_ic', type=int, default=10, help='weight of the initial-condition loss')

# log settings
parser.add_argument('--log_iter', type=int, default=100, help='print log every...')
parser.add_argument('--plot_iter', type=int, default=100, help='plot result every...')

# An empty argument list makes argparse use the defaults above instead of
# reading sys.argv.
args = parser.parse_args(args=[])

# random key
key = jax.random.PRNGKey(args.seed)

# make & init model forward function
key, subkey = jax.random.split(key, 2)
apply_fn, params = setup_networks(args, subkey)

# count total params
args.total_params = sum(x.size for x in jax.tree_util.tree_leaves(params))

# name model
name = name_model(args)

# result dir
root_dir = os.path.join(os.getcwd(), 'results', args.equation, args.model)
result_dir = os.path.join(root_dir, name)

# make dir
os.makedirs(result_dir, exist_ok=True)

# optimizer
optim = optax.adam(learning_rate=args.lr)
state = optim.init(params)

# dataset
key, subkey = jax.random.split(key, 2)
train_data = generate_train_data(args, subkey)
test_data = generate_test_data(args, result_dir)

# evaluation function
eval_fn = setup_eval_function(args.model, args.equation)

# save training configuration
save_config(args, result_dir)

# log: the training loop appends to these files, so stale copies from a
# previous run with the same name are removed first
logs = []
if os.path.exists(os.path.join(result_dir, 'log (loss, error).csv')):
    os.remove(os.path.join(result_dir, 'log (loss, error).csv'))
if os.path.exists(os.path.join(result_dir, 'best_error.csv')):
    os.remove(os.path.join(result_dir, 'best_error.csv'))
# best (lowest) training loss seen so far; starts at a large sentinel
best = 100000.

In [None]:
print("compiling...")

# Defaults so the bookkeeping after and inside the loop is always defined:
# `start` is only reset at epoch 2 (missing when epochs < 2), and `best_error`
# is only assigned once the loss drops below `best` at a multiple of 10 epochs.
start = time.time()
best_error = float('nan')

# start training
for e in trange(1, args.epochs + 1):
    if e == 2:
        # exclude compiling time
        start = time.time()
    if e % 100 == 0:
        # sample new input data
        key, subkey = jax.random.split(key, 2)
        train_data = generate_train_data(args, subkey)

    loss, gradient = apply_model_spinn(apply_fn, params, args.nu, args.lbda_c, args.lbda_ic, *train_data)
    params, state = update_model(optim, gradient, params, state)

    # every 10 epochs: record the test error whenever the training loss
    # reaches a new minimum
    if e % 10 == 0:
        if loss < best:
            best = loss
            best_error = eval_fn(apply_fn, params, *test_data)

    # log
    if e % args.log_iter == 0:
        error = eval_fn(apply_fn, params, *test_data)
        print(f'Epoch: {e}/{args.epochs} --> total loss: {loss:.8f}, error: {error:.8f}, best error {best_error:.8f}')
        with open(os.path.join(result_dir, 'log (loss, error).csv'), 'a') as f:
            f.write(f'{loss}, {error}, {best_error}\n')

    # visualization
    if e % args.plot_iter == 0:
        show_solution(args, apply_fn, params, test_data, result_dir, e)

# training done
runtime = time.time() - start
# the first (compiling) epoch is excluded from the timing; guard the divisor
# so a single-epoch run does not divide by zero
timed_iters = max(args.epochs - 1, 1)
print(f'Runtime --> total: {runtime:.2f}sec ({(runtime/timed_iters*1000):.2f}ms/iter.)')
jnp.save(os.path.join(result_dir, 'params.npy'), params)

# save runtime
runtime = np.array([runtime])
np.savetxt(os.path.join(result_dir, 'total runtime (sec).csv'), runtime, delimiter=',')

# save total error
with open(os.path.join(result_dir, 'best_error.csv'), 'a') as f:
    f.write(f'best error: {best_error}\n')

compiling...


 10%|██████████                                                                                            | 99/1000 [00:20<03:06,  4.83it/s]

Epoch: 100/1000 --> total loss: 64.57202911, error: 0.82605666, best error 0.87254614
visualizing solution...


 20%|████████████████████                                                                                 | 199/1000 [00:42<02:33,  5.23it/s]

Epoch: 200/1000 --> total loss: 59.46010971, error: 0.71152961, best error 0.67018068
visualizing solution...


 30%|██████████████████████████████▏                                                                      | 299/1000 [01:02<02:14,  5.19it/s]

Epoch: 300/1000 --> total loss: 5.82178497, error: 0.33523923, best error 0.36991227
visualizing solution...


 40%|████████████████████████████████████████▎                                                            | 399/1000 [01:24<02:02,  4.91it/s]

Epoch: 400/1000 --> total loss: 1.71244895, error: 0.20455322, best error 0.20763847
visualizing solution...


 50%|██████████████████████████████████████████████████▍                                                  | 499/1000 [01:46<01:37,  5.15it/s]

Epoch: 500/1000 --> total loss: 2.40122461, error: 0.17550710, best error 0.16460443
visualizing solution...


 53%|█████████████████████████████████████████████████████▍                                               | 529/1000 [01:54<01:41,  4.62it/s]


KeyboardInterrupt: 