# SPINN Navier Stokes cleanup
> SPINN navier stokes cleanup

In [None]:
import jax

In [None]:
seed = 111
key = jax.random.PRNGKey(seed)
key

Array([  0, 111], dtype=uint32)

In [None]:
key, subkey

(Array([2623418170, 2720574883], dtype=uint32),
 Array([2399137871, 1034659656], dtype=uint32))

In [None]:
from typing import Sequence
from flax import linen as nn

In [None]:
class SPINN4d(nn.Module):
    features: Sequence[int]
    r: int
    out_dim: int
    mlp: str

    @nn.compact
    def __call__(self, t, x, y, z):
        inputs, outputs, tx, txy, pred = [t, x, y, z], [], [], [], []
        # inputs, outputs = [t, x, y, z], []
        init = nn.initializers.glorot_normal()
        for X in inputs:
            for fs in self.features[:-1]:
                X = nn.Dense(fs, kernel_init=init)(X)
                X = nn.activation.tanh(X)
            X = nn.Dense(self.r*self.out_dim, kernel_init=init)(X)
            outputs += [jnp.transpose(X, (1, 0))]

        for i in range(self.out_dim):
            tx += [jnp.einsum('ft, fx->ftx', 
            outputs[0][self.r*i:self.r*(i+1)], 
            outputs[1][self.r*i:self.r*(i+1)])]

            txy += [jnp.einsum('ftx, fy->ftxy', 
            tx[i], 
            outputs[2][self.r*i:self.r*(i+1)])]

            pred += [jnp.einsum('ftxy, fz->txyz', 
            txy[i], 
            outputs[3][self.r*i:self.r*(i+1)])]

        if len(pred) == 1:
            # 1-dimensional output
            return pred[0]
        else:
            # n-dimensional output
            return pred

In [None]:
# feature size of each layer
features = 64

# the number of layer
n_layers = 5

# feature sizes
feat_sizes = tuple([features for _ in range(n_layers)])

feat_sizes

(64, 64, 64, 64, 64)

In [None]:
# rank of a approximated tensor
r = 128

# sizee of model output
out_dim = 3

# type of mlp
mlp = 'modified_mlp'

In [None]:
model = SPINN4d(feat_sizes, r, out_dim, mlp)
model

SPINN4d(
    # attributes
    features = (64, 64, 64, 64, 64)
    r = 128
    out_dim = 3
    mlp = 'modified_mlp'
)

In [None]:
import jax.numpy as jnp

In [None]:
# the number of collocation points
nc = 32

In [None]:
params = model.init(
            key,
            jnp.ones((nc, 1)),
            jnp.ones((nc, 1)),
            jnp.ones((nc, 1)),
            jnp.ones((nc, 1))
        )

In [None]:
apply_fn = jax.jit(model.apply)
apply_fn

<PjitFunction of <bound method Module.apply of SPINN4d(
    # attributes
    features = (64, 64, 64, 64, 64)
    r = 128
    out_dim = 3
    mlp = 'modified_mlp'
)>>

In [None]:
# count total params
total_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
total_params

150272

In [None]:
# learning rate
lr = 1e-3

# lbda_c
lbda_c = 100

# lbda_ic
lbda_ic = 10

In [None]:
# name model
name = [
    f'nc{nc}',
    f'nl{n_layers}',
    f'fs{features}',
    f'lr{lr}',
    f's{seed}',
    f'r{r}',
    f'lc{lbda_c}',
    f'li{lbda_ic}',
    f'{mlp}'
]
name = '_'.join(name)
name

'nl5_fs64_lr0.001_s111_r128_lc100_li10_modified_mlp'

In [None]:
import os

In [None]:
model_type = 'spinn'
equation = 'navier_stokes4d'
root_dir = os.path.join(os.getcwd(), 'results', equation, model_type)
result_dir = os.path.join(root_dir, name)
result_dir

'/home/tensor/workspace/Python_stuff/zpinn/nbs/results/navier_stokes4d/spinn/nl5_fs64_lr0.001_s111_r128_lc100_li10_modified_mlp'

In [None]:
os.makedirs(result_dir, exist_ok=True)

In [None]:
import optax

In [None]:
optim = optax.adam(learning_rate=lr)
state = optim.init(params)

In [None]:
key, subkey = jax.random.split(key, 2)
key, subkey

(Array([ 246877504, 3749814588], dtype=uint32),
 Array([4235481140,  694578851], dtype=uint32))

In [None]:
nu = 0.05

In [None]:
from functools import partial

In [None]:
# 3d time-dependent navier-stokes forcing term
def navier_stokes4d_forcing_term(t, x, y, z, nu):
    # forcing terms in the PDE
    # f_x = -24*jnp.exp(-18*nu*t)*jnp.sin(2*y)*jnp.cos(2*y)*jnp.sin(z)*jnp.cos(z)
    f_x = -6*jnp.exp(-18*nu*t)*jnp.sin(4*y)*jnp.sin(2*z)
    # f_y = -24*jnp.exp(-18*nu*t)*jnp.sin(2*x)*jnp.cos(2*x)*jnp.sin(z)*jnp.cos(z)
    f_y = -6*jnp.exp(-18*nu*t)*jnp.sin(4*x)*jnp.sin(2*z)
    # f_z = 24*jnp.exp(-18*nu*t)*jnp.sin(2*x)*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(2*y)
    f_z = 6*jnp.exp(-18*nu*t)*jnp.sin(4*x)*jnp.sin(4*y)
    return f_x, f_y, f_z

# 3d time-dependent navier-stokes exact vorticity
def navier_stokes4d_exact_w(t, x, y, z, nu):
    # analytic form of vorticity
    w_x = -3*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.cos(2*y)*jnp.cos(z)
    w_y = 6*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.sin(2*y)*jnp.cos(z)
    w_z = -6*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.cos(2*y)*jnp.sin(z)
    return w_x, w_y, w_z


# 3d time-dependent navier-stokes exact velocity
def navier_stokes4d_exact_u(t, x, y, z, nu):
    # analytic form of velocity
    u_x = 2*jnp.exp(-9*nu*t)*jnp.cos(2*x)*jnp.sin(2*y)*jnp.sin(z)
    u_y = -1*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.cos(2*y)*jnp.sin(z)
    u_z = -2*jnp.exp(-9*nu*t)*jnp.sin(2*x)*jnp.sin(2*y)*jnp.cos(z)
    return u_x, u_y, u_z

In [None]:
#======================== Navier-Stokes equation 4-d ========================#
#---------------------------------- SPINN -----------------------------------#
@partial(jax.jit, static_argnums=(0,))
def _spinn_train_generator_navier_stokes4d(nc, nu, key):
    keys = jax.random.split(key, 4)
    
    # collocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=5.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=0., maxval=2.*jnp.pi)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=0., maxval=2.*jnp.pi)
    zc = jax.random.uniform(keys[3], (nc, 1), minval=0., maxval=2.*jnp.pi)

    tcm, xcm, ycm, zcm = jnp.meshgrid(
        tc.ravel(), xc.ravel(), yc.ravel(), zc.ravel(), indexing='ij'
    )
    fc = navier_stokes4d_forcing_term(tcm, xcm, ycm, zcm, nu)

    # initial points
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    zi = zc
    tim, xim, yim, zim = jnp.meshgrid(
        ti.ravel(), xi.ravel(), yi.ravel(), zi.ravel(), indexing='ij'
    )
    wi = navier_stokes4d_exact_w(tim, xim, yim, zim, nu)
    ui = navier_stokes4d_exact_u(tim, xim, yim, zim, nu)
    
    # boundary points (hard-coded)
    tb = [tc, tc, tc, tc, tc, tc]
    xb = [jnp.array([[-1.]]), jnp.array([[1.]]), xc, xc, xc, xc]
    yb = [yc, yc, jnp.array([[-1.]]), jnp.array([[1.]]), yc, yc]
    zb = [zc, zc, zc, zc, jnp.array([[-1.]]), jnp.array([[1.]])]
    wb = []
    for i in range(6):
        tbm, xbm, ybm, zbm = jnp.meshgrid(
            tb[i].ravel(), xb[i].ravel(), yb[i].ravel(), zb[i].ravel(), indexing='ij'
        )
        wb += [navier_stokes4d_exact_w(tbm, xbm, ybm, zbm, nu)]
    return tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb

In [None]:
train_data = _spinn_train_generator_navier_stokes4d(nc, nu, subkey)

In [None]:
# the number of collocation points
nc_test = 20

In [None]:
#----------------------- Navier-Stokes equation 4-d -------------------------#
@partial(jax.jit, static_argnums=(0, 1,))
def _test_generator_navier_stokes4d(nc_test, nu):
    t = jnp.linspace(0, 5, nc_test)
    x = jnp.linspace(0, 2*jnp.pi, nc_test)
    y = jnp.linspace(0, 2*jnp.pi, nc_test)
    z = jnp.linspace(0, 2*jnp.pi, nc_test)
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)
    z = jax.lax.stop_gradient(z)
    tm, xm, ym, zm = jnp.meshgrid(
        t, x, y, z, indexing='ij'
    )
    w_gt = navier_stokes4d_exact_w(tm, xm, ym, zm, nu)
    t = t.reshape(-1, 1)
    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)
    z = z.reshape(-1, 1)
    return t, x, y, z, w_gt

In [None]:
test_data = _test_generator_navier_stokes4d(nc_test, nu)
test_data[1].shape

(20, 1)

In [None]:
def relative_l2(u, u_gt):
    return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)

def vorx(apply_fn, params, t, x, y, z):
    # vorticity vector w/ forward-mode AD
    # w_x = uz_y - uy_z
    vec_z = jnp.ones(z.shape)
    vec_y = jnp.ones(y.shape)
    uy_z = jvp(lambda z: apply_fn(params, t, x, y, z)[1], (z,), (vec_z,))[1]
    uz_y = jvp(lambda y: apply_fn(params, t, x, y, z)[2], (y,), (vec_y,))[1]
    wx = uz_y - uy_z
    return wx


def vory(apply_fn, params, t, x, y, z):
    # vorticity vector w/ forward-mode AD
    # w_y = ux_z - uz_x
    vec_z = jnp.ones(z.shape)
    vec_x = jnp.ones(x.shape)
    ux_z = jvp(lambda z: apply_fn(params, t, x, y, z)[0], (z,), (vec_z,))[1]
    uz_x = jvp(lambda x: apply_fn(params, t, x, y, z)[2], (x,), (vec_x,))[1]
    wy = ux_z - uz_x
    return wy

def vorz(apply_fn, params, t, x, y, z):
    # vorticity vector w/ forward-mode AD
    # w_z = uy_x - ux_y
    vec_y = jnp.ones(y.shape)
    vec_x = jnp.ones(x.shape)
    ux_y = jvp(lambda y: apply_fn(params, t, x, y, z)[0], (y,), (vec_y,))[1]
    uy_x = jvp(lambda x: apply_fn(params, t, x, y, z)[1], (x,), (vec_x,))[1]
    wz = uy_x - ux_y
    return wz
    
@partial(jax.jit, static_argnums=(0,))
def _eval_ns4d(apply_fn, params, *test_data):
    t, x, y, z, w_gt = test_data
    error = 0
    wx = vorx(apply_fn, params, t, x, y, z)
    wy = vory(apply_fn, params, t, x, y, z)
    wz = vorz(apply_fn, params, t, x, y, z)
    error = relative_l2(wx, w_gt[0]) + relative_l2(wy, w_gt[1]) + relative_l2(wz, w_gt[2])
    return error / 3

In [None]:
# evaluation function
eval_fn = _eval_ns4d

In [None]:
from jax import jvp

def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    g = lambda primals: jvp(f, (primals,), tangents)[1]
    primals_out, tangents_out = jvp(g, primals, tangents)
    if return_primals:
        return primals_out, tangents_out
    else:
        return tangents_out

In [None]:
@partial(jax.jit, static_argnums=(0,))
def apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data):
    def residual_loss(params, t, x, y, z, f):
        # calculate u
        ux, uy, uz = apply_fn(params, t, x, y, z)
        
        # calculate w (3D vorticity vector)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)
        
        # tangent vector dx/dx
        # assumes t, x, y, z have same shape (very important)
        vec = jnp.ones(t.shape)

        # x-component
        wx_t = jvp(lambda t: vorx(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wx_x, wx_xx = hvp_fwdfwd(lambda x: vorx(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wx_y, wx_yy = hvp_fwdfwd(lambda y: vorx(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wx_z, wx_zz = hvp_fwdfwd(lambda z: vorx(apply_fn, params, t, x, y, z), (z,), (vec,), True)
        
        ux_x = jvp(lambda x: apply_fn(params, t, x, y, z)[0], (x,), (vec,))[1]
        ux_y = jvp(lambda y: apply_fn(params, t, x, y, z)[0], (y,), (vec,))[1]
        ux_z = jvp(lambda z: apply_fn(params, t, x, y, z)[0], (z,), (vec,))[1]

        loss_x = jnp.mean((wx_t + ux*wx_x + uy*wx_y + uz*wx_z - \
             (wx*ux_x + wy*ux_y + wz*ux_z) - \
                nu*(wx_xx + wx_yy + wx_zz) - \
                    f[0])**2)

        # y-component
        wy_t = jvp(lambda t: vory(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wy_x, wy_xx = hvp_fwdfwd(lambda x: vory(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wy_y, wy_yy = hvp_fwdfwd(lambda y: vory(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wy_z, wy_zz = hvp_fwdfwd(lambda z: vory(apply_fn, params, t, x, y, z), (z,), (vec,), True)
        
        uy_x = jvp(lambda x: apply_fn(params, t, x, y, z)[1], (x,), (vec,))[1]
        uy_y = jvp(lambda y: apply_fn(params, t, x, y, z)[1], (y,), (vec,))[1]
        uy_z = jvp(lambda z: apply_fn(params, t, x, y, z)[1], (z,), (vec,))[1]

        loss_y = jnp.mean((wy_t + ux*wy_x + uy*wy_y + uz*wy_z - \
             (wx*uy_x + wy*uy_y + wz*uy_z) - \
                nu*(wy_xx + wy_yy + wy_zz) - \
                    f[1])**2)

        # z-component
        wz_t = jvp(lambda t: vorz(apply_fn, params, t, x, y, z), (t,), (vec,))[1]
        wz_x, wz_xx = hvp_fwdfwd(lambda x: vorz(apply_fn, params, t, x, y, z), (x,), (vec,), True)
        wz_y, wz_yy = hvp_fwdfwd(lambda y: vorz(apply_fn, params, t, x, y, z), (y,), (vec,), True)
        wz_z, wz_zz = hvp_fwdfwd(lambda z: vorz(apply_fn, params, t, x, y, z), (z,), (vec,), True)
        
        uz_x = jvp(lambda x: apply_fn(params, t, x, y, z)[2], (x,), (vec,))[1]
        uz_y = jvp(lambda y: apply_fn(params, t, x, y, z)[2], (y,), (vec,))[1]
        uz_z = jvp(lambda z: apply_fn(params, t, x, y, z)[2], (z,), (vec,))[1]

        loss_z = jnp.mean((wz_t + ux*wz_x + uy*wz_y + uz*wz_z - \
             (wx*uz_x + wy*uz_y + wz*uz_z) - \
                nu*(wz_xx + wz_yy + wz_zz) - \
                    f[2])**2)

        loss_c = jnp.mean((ux_x + uy_y + uz_z)**2)

        return loss_x + loss_y + loss_z + lbda_c*loss_c

    def initial_loss(params, t, x, y, z, w, u):
        ux, uy, uz = apply_fn(params, t, x, y, z)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)
        loss = jnp.mean((wx - w[0])**2) + jnp.mean((wy - w[1])**2) + jnp.mean((wz - w[2])**2)
        loss += jnp.mean((ux - u[0])**2) + jnp.mean((uy - u[1])**2) + jnp.mean((uz - u[2])**2)
        return loss

    def boundary_loss(params, t, x, y, z, w):
        loss = 0.
        for i in range(6):
            wx = vorx(apply_fn, params, t[i], x[i], y[i], z[i])
            wy = vory(apply_fn, params, t[i], x[i], y[i], z[i])
            wz = vorz(apply_fn, params, t[i], x[i], y[i], z[i])
            loss += (1/6.) * jnp.mean((wx - w[i][0])**2) + jnp.mean((wy - w[i][1])**2) + jnp.mean((wz - w[i][2])**2)
        return loss

    # unpack data
    tc, xc, yc, zc, fc, ti, xi, yi, zi, wi, ui, tb, xb, yb, zb, wb = train_data

    # isolate loss func from redundant arguments
    loss_fn = lambda params: residual_loss(params, tc, xc, yc, zc, fc) + \
                        lbda_ic*initial_loss(params, ti, xi, yi, zi, wi, ui) + \
                        boundary_loss(params, tb, xb, yb, zb, wb)

    loss, gradient = jax.value_and_grad(loss_fn)(params)

    return loss, gradient

In [None]:
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    updates, state = optim.update(gradient, state)
    params = optax.apply_updates(params, updates)
    return params, state

In [None]:
loss, gradient = apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data)

In [None]:
params, state = update_model(optim, gradient, params, state)

In [None]:
error = eval_fn(apply_fn, params, *test_data)

In [None]:
print(f'total loss: {loss:.8f}, error: {error:.8f}')

total loss: 134.97689819, error: 1.00003636


In [None]:
import matplotlib.pyplot as plt

In [None]:
def _navier_stokes4d(apply_fn, params, test_data, result_dir, e):
    print("visualizing solution...")

    os.makedirs(os.path.join(result_dir, f'vis/{e:05d}'), exist_ok=True)

    fig = plt.figure(figsize=(30, 5))
    for t, sub in zip([0, 1, 2, 3, 4, 5], [161, 162, 163, 164, 165, 166]):
        t = jnp.array([[t]])
        x = jnp.linspace(0, 2*jnp.pi, 4).reshape(-1, 1)
        y = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
        z = jnp.linspace(0, 2*jnp.pi, 30).reshape(-1, 1)
        wx = vorx(apply_fn, params, t, x, y, z)
        wy = vory(apply_fn, params, t, x, y, z)
        wz = vorz(apply_fn, params, t, x, y, z)

        # c = jnp.sqrt(u_x**2 + u_y**2 + u_z**2)   # magnitude
        c = jnp.arctan2(wy, wz)    # zenith angle
        c = (c.ravel() - c.min()) / c.ptp()
        c = jnp.concatenate((c, jnp.repeat(c, 2)))
        c = plt.cm.plasma(c)

        x, y, z = jnp.meshgrid(jnp.squeeze(x), jnp.squeeze(y), jnp.squeeze(z), indexing='ij')

        ax = fig.add_subplot(sub, projection='3d')
        ax.quiver(x, y, z, jnp.squeeze(wx), jnp.squeeze(wy), jnp.squeeze(wz), length=0.1, colors=c, alpha=1, linewidth=0.7)
        plt.title(f't={jnp.squeeze(t)}')
    
    plt.savefig(os.path.join(result_dir, f'vis/{e:05d}/pred.png'))
    plt.close()

In [None]:
_navier_stokes4d(apply_fn, params, test_data, result_dir, e=1)

visualizing solution...


In [None]:
epochs = 1000

In [None]:
import time
from tqdm import trange

In [None]:
log_iter = 100
plot_iter = 100
best = 100000. # best error

In [None]:
for e in trange(1, epochs + 1):
    if e == 2:
        # exclude compiling time
        start = time.time()
    if e % 100 == 0:
        # sample new input data
        key, subkey = jax.random.split(key, 2)
        train_data = _spinn_train_generator_navier_stokes4d(nc, nu, subkey)

    loss, gradient = apply_model_spinn(apply_fn, params, nu, lbda_c, lbda_ic, *train_data)
    params, state = update_model(optim, gradient, params, state)

    if e % 10 == 0:
        if loss < best:
            best = loss
            best_error = eval_fn(apply_fn, params, *test_data)

    # log
    if e % log_iter == 0:
        error = eval_fn(apply_fn, params, *test_data)
        print(f'Epoch: {e}/{epochs} --> total loss: {loss:.8f}, error: {error:.8f}, best error {best_error:.8f}')
        with open(os.path.join(result_dir, 'log (loss, error).csv'), 'a') as f:
            f.write(f'{loss}, {error}, {best_error}\n')

    # visualization
    if e % plot_iter == 0:
        _navier_stokes4d(apply_fn, params, test_data, result_dir, e)

 10%|▉         | 99/1000 [00:10<01:31,  9.89it/s]

Epoch: 100/1000 --> total loss: 95.65531158, error: 0.90177166, best error 0.91455835
visualizing solution...


 20%|█▉        | 199/1000 [00:23<01:20,  9.98it/s]

Epoch: 200/1000 --> total loss: 41.20937347, error: 0.70664871, best error 0.74622047
visualizing solution...


 30%|██▉       | 299/1000 [00:36<01:10,  9.92it/s]

Epoch: 300/1000 --> total loss: 26.06181145, error: 0.51984829, best error 0.50875151
visualizing solution...


 40%|███▉      | 399/1000 [00:49<01:00,  9.93it/s]

Epoch: 400/1000 --> total loss: 3.93560553, error: 0.26011056, best error 0.26011056
visualizing solution...


 50%|████▉     | 499/1000 [01:02<00:53,  9.30it/s]

Epoch: 500/1000 --> total loss: 1.66154373, error: 0.17931287, best error 0.19282588
visualizing solution...


 60%|█████▉    | 598/1000 [01:14<00:42,  9.53it/s]

Epoch: 600/1000 --> total loss: 2.47873926, error: 0.17698327, best error 0.19282588
visualizing solution...


 70%|██████▉   | 699/1000 [01:27<00:31,  9.62it/s]

Epoch: 700/1000 --> total loss: 1.57854688, error: 0.13691574, best error 0.19282588
visualizing solution...


 80%|███████▉  | 799/1000 [01:40<00:20,  9.97it/s]

Epoch: 800/1000 --> total loss: 1.02764773, error: 0.13225016, best error 0.13567518
visualizing solution...


 90%|████████▉ | 899/1000 [01:53<00:10,  9.98it/s]

Epoch: 900/1000 --> total loss: 0.87361717, error: 0.11081986, best error 0.11430246
visualizing solution...


100%|█████████▉| 999/1000 [02:05<00:00, 10.06it/s]

Epoch: 1000/1000 --> total loss: 0.47723839, error: 0.09555665, best error 0.09825733
visualizing solution...


100%|██████████| 1000/1000 [02:07<00:00,  7.81it/s]
