# SPINN - (3+1)-d Navier-Stokes
> cleanup

The vorticity ($\mathbf{w} = \nabla \times \mathbf{u}$) form of the (3+1)-d Navier-Stokes equations is given by:

$$
\frac{\partial \mathbf{w}}{\partial t} + (\mathbf{u} \cdot \nabla)\mathbf{w} = (\mathbf{w} \cdot \nabla)\mathbf{u} + \nu\nabla^2 \mathbf{w} + \mathbf{F}
$$

$$
\nabla \cdot \mathbf{u} = 0
$$

$$
\mathbf{w}(\mathbf{r}, 0) = \mathbf{w}_0 (\mathbf{r})
$$



$$
\mathbf{r} \in \Omega = [0, 2\pi]^3
$$

$$
t \in \Gamma = [0, 5]
$$

In [None]:
import jax 
import jax.numpy as jnp
from jax import jvp
import optax
from flax import linen as nn 

from typing import Sequence
from functools import partial

import time
from tqdm import trange

In [None]:
class SPINN4d(nn.Module):
    """Separable network over the four coordinates (t, x, y, z).

    Each coordinate gets its own MLP that maps its points to
    ``r * out_dim`` features. Output component ``k`` is built from the
    ``k``-th band of ``r`` features of every coordinate: the bands are
    combined by outer products over the four axes and summed over the
    rank index, giving an array indexed as ``[t, x, y, z]``.
    """
    # Hidden layer widths. __call__ iterates ``features[:-1]``, so the last
    # entry is not used; the final layer always has ``r * out_dim`` units.
    features: Sequence[int]
    # Number of rank-one terms summed for each output component.
    r: int
    # Number of output components.
    out_dim: int

    @nn.compact
    def __call__(self, t, x, y, z):
        """Evaluate the network on the grid spanned by the four inputs.

        Args:
            t, x, y, z: per-axis coordinate arrays. The Dense output is
                transposed with axes ``(1, 0)``, so each input is assumed
                to be 2-D, e.g. ``(n_points, 1)`` as used in this file.

        Returns:
            A single ``[t, x, y, z]`` array when ``out_dim == 1``,
            otherwise a list with one such array per output component.
        """
        kernel_init = nn.initializers.glorot_normal()

        # One MLP per coordinate. Layers are created in the same order as
        # the inputs so the automatically assigned Dense names are stable.
        axis_feats = []
        for coord in (t, x, y, z):
            hidden = coord
            for width in self.features[:-1]:
                hidden = nn.Dense(width, kernel_init=kernel_init)(hidden)
                hidden = nn.activation.tanh(hidden)
            hidden = nn.Dense(self.r*self.out_dim, kernel_init=kernel_init)(hidden)
            # feature axis first, so a band of r features is a row slice
            axis_feats.append(jnp.transpose(hidden, (1, 0)))

        components = []
        for k in range(self.out_dim):
            band = slice(self.r*k, self.r*(k + 1))
            ft, fx, fy, fz = (feats[band] for feats in axis_feats)

            # grow the outer product one axis at a time; the rank index f
            # is only summed out in the last contraction
            tx = jnp.einsum('ft, fx->ftx', ft, fx)
            txy = jnp.einsum('ftx, fy->ftxy', tx, fy)
            components.append(jnp.einsum('ftxy, fz->txyz', txy, fz))

        # 1-dimensional output is returned bare, n-dimensional as a list
        return components[0] if len(components) == 1 else components

In [None]:
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-forward AD.

    Args:
        f: function of a single array argument.
        primals: 1-tuple holding the point at which to differentiate.
        tangents: 1-tuple holding the direction, used for both passes.
        return_primals: if True, also return the first directional
            derivative (the primal output of the outer ``jvp``).

    Returns:
        The second directional derivative, or the pair
        ``(first_derivative, second_derivative)`` when ``return_primals``
        is True.
    """
    def first_derivative(point):
        # inner pass: directional derivative of f along `tangents`
        _, df = jvp(f, (point,), tangents)
        return df

    # outer pass: differentiate the first derivative along the same direction
    df_out, d2f_out = jvp(first_derivative, primals, tangents)
    return (df_out, d2f_out) if return_primals else d2f_out
    
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step.

    Args:
        optim: optax gradient transformation (static under jit, so it must
            be hashable).
        gradient: gradient pytree with the same structure as ``params``.
        params: current model parameters.
        state: current optimizer state.

    Returns:
        ``(params, state)`` after the update.
    """
    # Pass the current params as well: transformations that need them
    # (e.g. weight decay) fail without it, and those that do not need them
    # simply ignore the argument.
    updates, state = optim.update(gradient, state, params)
    params = optax.apply_updates(params, updates)
    return params, state

In [None]:
# Fixed seed so every run draws the same random numbers.
seed = 111
# `key` is kept for later use; `subkey` seeds the first training batch.
key, subkey = jax.random.split(jax.random.PRNGKey(seed))

In [None]:
# --- data / PDE ---
nc = 32      # number of collocation points sampled per axis
nu = 0.05    # PDE parameter (coefficient of the Laplacian term)

# --- network ---
features = 64                        # width of each layer
n_layers = 5                         # number of layers
feat_sizes = (features,) * n_layers  # per-layer widths
r = 128                              # rank of the approximated tensor
out_dim = 3                          # number of model outputs

# --- optimization ---
lr = 1e-3       # learning rate
epochs = 1000   # number of training iterations
log_iter = 100  # print the loss every `log_iter` iterations

# --- loss weights ---
lbda_c = 100   # weight of the divergence-free (continuity) term
lbda_ic = 10   # weight of the initial-condition term

In [None]:
def vorx(apply_fn, params, t, x, y, z):
    """x-component of the vorticity, ``w_x = d(u_z)/dy - d(u_y)/dz``.

    Both derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``(u_x, u_y, u_z)``.
        params: parameters forwarded to ``apply_fn``.
        t, x, y, z: coordinate arrays.

    Returns:
        ``w_x`` evaluated at the given coordinates.
    """
    def u_y_along_z(z_var):
        return apply_fn(params, t, x, y, z_var)[1]

    def u_z_along_y(y_var):
        return apply_fn(params, t, x, y_var, z)[2]

    _, duy_dz = jvp(u_y_along_z, (z,), (jnp.ones(z.shape),))
    _, duz_dy = jvp(u_z_along_y, (y,), (jnp.ones(y.shape),))
    return duz_dy - duy_dz


def vory(apply_fn, params, t, x, y, z):
    """y-component of the vorticity, ``w_y = d(u_x)/dz - d(u_z)/dx``.

    Both derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``(u_x, u_y, u_z)``.
        params: parameters forwarded to ``apply_fn``.
        t, x, y, z: coordinate arrays.

    Returns:
        ``w_y`` evaluated at the given coordinates.
    """
    def u_x_along_z(z_var):
        return apply_fn(params, t, x, y, z_var)[0]

    def u_z_along_x(x_var):
        return apply_fn(params, t, x_var, y, z)[2]

    _, dux_dz = jvp(u_x_along_z, (z,), (jnp.ones(z.shape),))
    _, duz_dx = jvp(u_z_along_x, (x,), (jnp.ones(x.shape),))
    return dux_dz - duz_dx

def vorz(apply_fn, params, t, x, y, z):
    """z-component of the vorticity, ``w_z = d(u_y)/dx - d(u_x)/dy``.

    Both derivatives are taken with forward-mode AD (``jvp``) using an
    all-ones tangent.

    Args:
        apply_fn: callable ``apply_fn(params, t, x, y, z)`` whose result is
            indexed as ``(u_x, u_y, u_z)``.
        params: parameters forwarded to ``apply_fn``.
        t, x, y, z: coordinate arrays.

    Returns:
        ``w_z`` evaluated at the given coordinates.
    """
    def u_x_along_y(y_var):
        return apply_fn(params, t, x, y_var, z)[0]

    def u_y_along_x(x_var):
        return apply_fn(params, t, x_var, y, z)[1]

    _, dux_dy = jvp(u_x_along_y, (y,), (jnp.ones(y.shape),))
    _, duy_dx = jvp(u_y_along_x, (x,), (jnp.ones(x.shape),))
    return duy_dx - dux_dy

# 3d time-dependent navier-stokes forcing term
def navier_stokes4d_forcing_term(t, x, y, z, nu):
    """Forcing term ``F`` of the vorticity equation.

    Args:
        t, x, y, z: coordinates (scalars or mutually broadcastable arrays).
        nu: PDE parameter controlling the temporal decay rate.

    Returns:
        Tuple ``(f_x, f_y, f_z)`` of the three forcing components.
    """
    # every component shares the same temporal decay factor
    decay = jnp.exp(-18*nu*t)
    f_x = -6*decay*jnp.sin(4*y)*jnp.sin(2*z)
    f_y = -6*decay*jnp.sin(4*x)*jnp.sin(2*z)
    f_z = 6*decay*jnp.sin(4*x)*jnp.sin(4*y)
    return f_x, f_y, f_z

# 3d time-dependent navier-stokes exact vorticity
def navier_stokes4d_exact_w(t, x, y, z, nu):
    """Analytic vorticity of the reference solution.

    Args:
        t, x, y, z: coordinates (scalars or mutually broadcastable arrays).
        nu: PDE parameter controlling the temporal decay rate.

    Returns:
        Tuple ``(w_x, w_y, w_z)`` of the three vorticity components.
    """
    # every component shares the same temporal decay factor
    decay = jnp.exp(-9*nu*t)
    w_x = -3*decay*jnp.sin(2*x)*jnp.cos(2*y)*jnp.cos(z)
    w_y = 6*decay*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(z)
    w_z = -6*decay*jnp.cos(2*x)*jnp.cos(2*y)*jnp.sin(z)
    return w_x, w_y, w_z


# 3d time-dependent navier-stokes exact velocity
def navier_stokes4d_exact_u(t, x, y, z, nu):
    """Analytic velocity of the reference solution.

    Args:
        t, x, y, z: coordinates (scalars or mutually broadcastable arrays).
        nu: PDE parameter controlling the temporal decay rate.

    Returns:
        Tuple ``(u_x, u_y, u_z)`` of the three velocity components.
    """
    # every component shares the same temporal decay factor
    decay = jnp.exp(-9*nu*t)
    u_x = 2*decay*jnp.cos(2*x)*jnp.sin(2*y)*jnp.sin(z)
    u_y = -1*decay*jnp.sin(2*x)*jnp.cos(2*y)*jnp.sin(z)
    u_z = -2*decay*jnp.sin(2*x)*jnp.sin(2*y)*jnp.cos(z)
    return u_x, u_y, u_z

In [None]:
@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_navier_stokes4d(nc, nu, key):
    """Sample one batch of training data for the (3+1)-d Navier-Stokes problem.

    Points are drawn independently per axis (``t`` in [0, 5]; ``x``, ``y``,
    ``z`` in [0, 2*pi]). Target values are evaluated on the tensor-product
    grid of those per-axis points.

    Args:
        nc: number of points drawn per axis (static under jit).
        nu: PDE parameter forwarded to the forcing term and the analytic
            solutions.
        key: JAX PRNG key; split into one sub-key per axis.

    Returns:
        A 16-tuple
        ``(tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb)``:

        * ``tc, xc, yc, zc``: per-axis collocation points, each ``(nc, 1)``.
        * ``fc``: forcing term ``(f_x, f_y, f_z)`` on the collocation grid.
        * ``ti, xi, yi, zi``: initial-condition points (``ti`` is a single
          zero; the spatial points are the collocation points).
        * ``wi``, ``ui``: analytic vorticity and velocity on the initial grid.
        * ``tb, xb, yb, zb``: lists of six entries, one per face of the
          spatial cube, in the order x-min, x-max, y-min, y-max, z-min, z-max.
        * ``wb``: list of six analytic vorticity tuples, one per face.
    """
    domain_min, domain_max = 0., 2.*jnp.pi
    keys = jax.random.split(key, 4)

    # collocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=5.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=domain_min, maxval=domain_max)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=domain_min, maxval=domain_max)
    zc = jax.random.uniform(keys[3], (nc, 1), minval=domain_min, maxval=domain_max)

    tcm, xcm, ycm, zcm = jnp.meshgrid(
        tc.ravel(), xc.ravel(), yc.ravel(), zc.ravel(), indexing='ij'
    )
    fc = navier_stokes4d_forcing_term(tcm, xcm, ycm, zcm, nu)

    # initial points: t = 0 on the same spatial samples
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    zi = zc
    tim, xim, yim, zim = jnp.meshgrid(
        ti.ravel(), xi.ravel(), yi.ravel(), zi.ravel(), indexing='ij'
    )
    wi = navier_stokes4d_exact_w(tim, xim, yim, zim, nu)
    ui = navier_stokes4d_exact_u(tim, xim, yim, zim, nu)

    # boundary points: each face pins one spatial coordinate to an end of
    # the domain. These were previously -1 and 1, which are not the faces
    # of the [0, 2*pi]^3 cube sampled above.
    face_lo = jnp.array([[domain_min]])
    face_hi = jnp.array([[domain_max]])
    tb = [tc, tc, tc, tc, tc, tc]
    xb = [face_lo, face_hi, xc, xc, xc, xc]
    yb = [yc, yc, face_lo, face_hi, yc, yc]
    zb = [zc, zc, zc, zc, face_lo, face_hi]
    wb = []
    for tb_i, xb_i, yb_i, zb_i in zip(tb, xb, yb, zb):
        tbm, xbm, ybm, zbm = jnp.meshgrid(
            tb_i.ravel(), xb_i.ravel(), yb_i.ravel(), zb_i.ravel(), indexing='ij'
        )
        wb.append(navier_stokes4d_exact_w(tbm, xbm, ybm, zbm, nu))
    return tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb

In [None]:
@partial(jax.jit, static_argnums=(0,))
def apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data):
    """Compute the total training loss and its gradient w.r.t. ``params``.

    The total loss is
    ``residual + lbda_ic * initial + boundary``, where the residual term
    already includes ``lbda_c`` times the divergence-free penalty.

    Args:
        apply_fn: model forward function ``apply_fn(params, t, x, y, z)``
            returning ``(u_x, u_y, u_z)`` (static under jit, so it must be
            hashable).
        params: model parameters to differentiate with respect to.
        nu: PDE parameter multiplying the Laplacian of the vorticity.
        lbda_c: weight of the divergence-free penalty.
        lbda_ic: weight of the initial-condition loss.
        *train_data: the 16 items
            ``tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb``
            in that order.

    Returns:
        ``(loss, gradient)`` where ``gradient`` has the structure of
        ``params``.
    """
    def residual_loss(params, t, x, y, z, f):
        # Mean squared residual of the three vorticity-equation components
        # plus the weighted divergence-free penalty.

        # calculate u
        ux, uy, uz = apply_fn(params, t, x, y, z)

        # calculate w (3D vorticity vector)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)

        # tangent vector dx/dx
        # assumes t, x, y, z have same shape (very important)
        vec = jnp.ones(t.shape)

        # x-component
        wx_t = jvp(lambda t: vorx(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wx_x, wx_xx = hvp_fwdfwd(lambda x: vorx(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wx_y, wx_yy = hvp_fwdfwd(lambda y: vorx(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wx_z, wx_zz = hvp_fwdfwd(lambda z: vorx(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        ux_x = jvp(lambda x: apply_fn(params, t, x, y, z)[0], (x,), (vec,))[1]
        ux_y = jvp(lambda y: apply_fn(params, t, x, y, z)[0], (y,), (vec,))[1]
        ux_z = jvp(lambda z: apply_fn(params, t, x, y, z)[0], (z,), (vec,))[1]

        # dw/dt + (u . grad) w - (w . grad) u - nu * lap(w) - F, x-component
        loss_x = jnp.mean((
            wx_t + ux*wx_x + uy*wx_y + uz*wx_z
            - (wx*ux_x + wy*ux_y + wz*ux_z)
            - nu*(wx_xx + wx_yy + wx_zz)
            - f[0]
        )**2)

        # y-component
        wy_t = jvp(lambda t: vory(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wy_x, wy_xx = hvp_fwdfwd(lambda x: vory(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wy_y, wy_yy = hvp_fwdfwd(lambda y: vory(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wy_z, wy_zz = hvp_fwdfwd(lambda z: vory(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        uy_x = jvp(lambda x: apply_fn(params, t, x, y, z)[1], (x,), (vec,))[1]
        uy_y = jvp(lambda y: apply_fn(params, t, x, y, z)[1], (y,), (vec,))[1]
        uy_z = jvp(lambda z: apply_fn(params, t, x, y, z)[1], (z,), (vec,))[1]

        loss_y = jnp.mean((
            wy_t + ux*wy_x + uy*wy_y + uz*wy_z
            - (wx*uy_x + wy*uy_y + wz*uy_z)
            - nu*(wy_xx + wy_yy + wy_zz)
            - f[1]
        )**2)

        # z-component
        wz_t = jvp(lambda t: vorz(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wz_x, wz_xx = hvp_fwdfwd(lambda x: vorz(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wz_y, wz_yy = hvp_fwdfwd(lambda y: vorz(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wz_z, wz_zz = hvp_fwdfwd(lambda z: vorz(apply_fn, params, t, x, y, z), (z,), (vec,), True)

        uz_x = jvp(lambda x: apply_fn(params, t, x, y, z)[2], (x,), (vec,))[1]
        uz_y = jvp(lambda y: apply_fn(params, t, x, y, z)[2], (y,), (vec,))[1]
        uz_z = jvp(lambda z: apply_fn(params, t, x, y, z)[2], (z,), (vec,))[1]

        loss_z = jnp.mean((
            wz_t + ux*wz_x + uy*wz_y + uz*wz_z
            - (wx*uz_x + wy*uz_y + wz*uz_z)
            - nu*(wz_xx + wz_yy + wz_zz)
            - f[2]
        )**2)

        # incompressibility: div(u) = 0
        loss_c = jnp.mean((ux_x + uy_y + uz_z)**2)

        return loss_x + loss_y + loss_z + lbda_c*loss_c

    def initial_loss(params, t, x, y, z, w, u):
        # Match both the vorticity and the velocity to their targets.
        ux, uy, uz = apply_fn(params, t, x, y, z)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)
        loss = jnp.mean((wx - w[0])**2) + jnp.mean((wy - w[1])**2) + jnp.mean((wz - w[2])**2)
        loss += jnp.mean((ux - u[0])**2) + jnp.mean((uy - u[1])**2) + jnp.mean((uz - u[2])**2)
        return loss

    def boundary_loss(params, t, x, y, z, w):
        # Average, over the faces, of the summed per-component vorticity MSE.
        n_faces = len(t)
        loss = 0.
        for t_f, x_f, y_f, z_f, w_f in zip(t, x, y, z, w):
            wx = vorx(apply_fn, params, t_f, x_f, y_f, z_f)
            wy = vory(apply_fn, params, t_f, x_f, y_f, z_f)
            wz = vorz(apply_fn, params, t_f, x_f, y_f, z_f)
            # The 1/n_faces factor must scale all three components; it
            # previously multiplied only the x-component term.
            face_loss = (
                jnp.mean((wx - w_f[0])**2)
                + jnp.mean((wy - w_f[1])**2)
                + jnp.mean((wz - w_f[2])**2)
            )
            loss += face_loss / n_faces
        return loss

    # unpack data
    tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb = train_data

    # isolate loss func from redundant arguments
    def loss_fn(params):
        return (
            residual_loss(params, tc, xc, yc, zc, fc)
            + lbda_ic*initial_loss(params, ti, xi, yi, zi, wi, ui)
            + boundary_loss(params, tb, xb, yb, zb, wb)
        )

    loss, gradient = jax.value_and_grad(loss_fn)(params)

    return loss, gradient

In [None]:
# Build the network and initialize its parameters from one dummy
# (nc, 1) input per coordinate (t, x, y, z).
model = SPINN4d(feat_sizes, r, out_dim)
params = model.init(key, *(jnp.ones((nc, 1)) for _ in range(4)))

# jit-compiled forward pass; this is the `apply_fn` handed to the loss code
apply_fn = jax.jit(model.apply)

# Adam optimizer and its initial state
optim = optax.adam(learning_rate=lr)
state = optim.init(params)

In [None]:
# Draw the first batch of training data (collocation, initial and boundary
# points with their targets) using the sub-key split off from the seed.
train_data = _spinn_train_generator_navier_stokes4d(nc, nu, subkey)

In [None]:
# compile: run one full training step before the loop so the first call to
# the jitted functions happens here rather than inside the training loop.
# Note that this step also applies one optimizer update to `params`.
loss, gradient = apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data)
params, state = update_model(optim, gradient, params, state)

In [None]:
# Start the clock once, before the loop. It used to be reset at the top of
# every iteration, so the reported "total" only covered the last step.
start = time.time()

for e in trange(1, epochs + 1):
    if e % 100 == 0:
        # sample new input data
        key, subkey = jax.random.split(key, 2)
        train_data = _spinn_train_generator_navier_stokes4d(nc, nu, subkey)

    loss, gradient = apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data)
    params, state = update_model(optim, gradient, params, state)

    if e % log_iter == 0:
        print(f'Epoch: {e}/{epochs} --> total loss: {loss:.8f}')

# JAX dispatches work asynchronously; wait for the final update to finish
# before stopping the clock.
jax.block_until_ready(params)
runtime = time.time() - start
# all `epochs` iterations are timed, so average over `epochs`
print(f'Runtime --> total: {runtime:.2f}sec ({(runtime/epochs*1000):.2f}ms/iter.)')

 10%|█         | 101/1000 [00:10<01:38,  9.11it/s]

Epoch: 100/1000 --> total loss: 93.84567261


 20%|██        | 200/1000 [00:21<01:26,  9.20it/s]

Epoch: 200/1000 --> total loss: 20.04462051


 30%|███       | 301/1000 [00:31<01:16,  9.18it/s]

Epoch: 300/1000 --> total loss: 17.55069923


 40%|████      | 400/1000 [00:42<01:06,  9.05it/s]

Epoch: 400/1000 --> total loss: 8.64530373


 50%|█████     | 500/1000 [00:52<00:54,  9.11it/s]

Epoch: 500/1000 --> total loss: 1.58680403


 60%|██████    | 601/1000 [01:03<00:44,  9.07it/s]

Epoch: 600/1000 --> total loss: 0.90065020


 70%|███████   | 700/1000 [01:13<00:33,  8.91it/s]

Epoch: 700/1000 --> total loss: 1.93615973


 80%|████████  | 801/1000 [01:23<00:22,  9.00it/s]

Epoch: 800/1000 --> total loss: 1.39351535


 90%|█████████ | 902/1000 [01:34<00:09,  9.83it/s]

Epoch: 900/1000 --> total loss: 0.90451860


100%|██████████| 1000/1000 [01:44<00:00,  9.56it/s]

Epoch: 1000/1000 --> total loss: 0.56178963
Runtime --> total: 0.11sec (0.11ms/iter.)





# Save model

In [None]:
# Evaluate the trained model on the current collocation points and show one
# entry of the first output component, as a reference value to compare with
# the prediction from the restored checkpoint further down.
u_test, _, _ = apply_fn(params, train_data[0], train_data[1], train_data[2], train_data[3])
u_test[0][0][0][0]

Array(-0.82522225, dtype=float32)

In [None]:
import orbax

In [None]:
# Orbax checkpointer used below to save and restore the checkpoint pytree.
checkpoint = orbax.checkpoint.PyTreeCheckpointer()

In [None]:
# Bundle the trained parameters with the training batch they were last used on.
ckpt = {'params':params, 'train_data':train_data}

In [None]:
# Write the bundle to the 'ckpt' path.
# NOTE(review): force=True presumably overwrites an existing checkpoint at
# that path - confirm against the installed orbax version.
checkpoint.save('ckpt', ckpt, force=True)

In [None]:
# Load the checkpoint back from the 'ckpt' path and pull out its two entries.
ckptt = checkpoint.restore('ckpt')
ppa = ckptt['params']      # restored parameters
ttd = ckptt['train_data']  # restored training batch

In [None]:
# Round-trip check: the same evaluation as before saving, but using the
# restored parameters and inputs. The displayed entry should match the
# value printed before the checkpoint was written.
u_test, _, _ = apply_fn(ppa, ttd[0], ttd[1], ttd[2], ttd[3])
u_test[0][0][0][0]

Array(-0.82522225, dtype=float32)