# SPINN NF2
> NF2 + SPINN

In [None]:
import os

In [None]:
# Number GPUs in PCI-bus order (matching nvidia-smi) and expose only device 0.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
# Ask XLA to allocate GPU memory on demand instead of reserving most of it up front.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

## Input data

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from zpinn.pinn_nf2_cleanup import create_coordinates

In [None]:
import json

In [None]:
import pickle

In [None]:
# Grid resolution (nx, ny, nz) and the field normalization scale come from
# config.json. JSON is UTF-8 by specification, so the encoding is given
# explicitly rather than relying on the platform's locale default.
with open('config.json', encoding='utf-8') as config:
    info = json.load(config)

nx, ny, nz = info['nx'], info['ny'], info['nz']
b_norm = info['b_norm']  # normalization scale for the magnetic field

In [None]:
# import pickle
# with open("bv.pickle","rb") as f:
#     bv = pickle.load(f)

# with open("b_bottom.pickle","rb") as f:
#     b_bottom = pickle.load(f)

# with open("bp_top.pickle","rb") as f:
#     bp_top = pickle.load(f)

# with open("bp_lateral_1.pickle","rb") as f:
#     bp_lateral_1 = pickle.load(f)

# with open("bp_lateral_2.pickle","rb") as f:
#     bp_lateral_2 = pickle.load(f)

# with open("bp_lateral_3.pickle","rb") as f:
#     bp_lateral_3 = pickle.load(f)

# with open("bp_lateral_4.pickle","rb") as f:
#     bp_lateral_4 = pickle.load(f)

In [None]:
from zpinn.lowloumag import LowLouMag

In [None]:
# Reference field from LowLouMag (presumably the Low & Lou analytic force-free
# solution -- confirm) on the (nx, ny, nz) grid. The bottom boundary is divided
# by b_norm to form the training target; the volume bv is kept unscaled.
b = LowLouMag(resolutions=[nx, ny, nz])
b.calculate()
b_bottom = np.asarray(b.b_bottom) / b_norm
bv = b.grid['B']

In [None]:
# Directory that receives the generated reference/boundary data.
res_path = 'spinn'
os.makedirs(name=res_path, exist_ok=True)

In [None]:
# Pickle destinations inside res_path: the LowLouMag object, the reference
# volume, and the normalized bottom boundary.
b_path, bv_path, b_bottom_path = (
    os.path.join(res_path, "b.pickle"),
    os.path.join(res_path, "bv.pickle"),
    os.path.join(res_path, "b_bottom.pickle"),
)

In [None]:
# Persist the normalized bottom boundary as a plain NumPy array.
with open(b_bottom_path, "wb") as fh:
    pickle.dump(np.array(b_bottom), fh)

In [None]:
# Persist the whole LowLouMag object (unpickling it requires its class to be importable).
with open(b_path, "wb") as fh:
    pickle.dump(b, fh)

In [None]:
# Persist the reference volume field as a plain NumPy array.
with open(bv_path, "wb") as fh:
    pickle.dump(np.array(bv), fh)

In [None]:
import pickle

# Read back the arrays saved above from res_path (bv_path / b_bottom_path).
# Reading bare "bv.pickle" / "b_bottom.pickle" from the working directory would
# replace the freshly computed boundary with whatever stale files happen to be
# there (possibly at a different resolution).
with open(bv_path, "rb") as f:
    bv = pickle.load(f)

with open(b_bottom_path, "rb") as f:
    b_bottom = pickle.load(f)


def _load_optional_pickle(path):
    """Unpickle ``path``; return None if the file does not exist.

    pickle.load can execute arbitrary code -- only use it on trusted files.
    """
    try:
        with open(path, "rb") as fh:
            return pickle.load(fh)
    except FileNotFoundError:
        return None


# Potential-field boundary data. In this notebook it is only referenced by the
# disabled alternative boundary loss, so missing files yield None instead of
# aborting the run.
bp_top = _load_optional_pickle("bp_top.pickle")
bp_lateral_1 = _load_optional_pickle("bp_lateral_1.pickle")
bp_lateral_2 = _load_optional_pickle("bp_lateral_2.pickle")
bp_lateral_3 = _load_optional_pickle("bp_lateral_3.pickle")
bp_lateral_4 = _load_optional_pickle("bp_lateral_4.pickle")

# Model

In [None]:
import jax 
import jax.numpy as jnp
from jax import jvp
import optax
from flax import linen as nn 

from typing import Sequence
from functools import partial

import time
from tqdm import trange

In [None]:
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-forward AD.

    Args:
        f: single-argument function.
        primals: 1-tuple ``(x,)`` holding the evaluation point.
        tangents: 1-tuple ``(v,)`` holding the direction (same structure as
            ``primals``).
        return_primals: if True, also return the first directional derivative.

    Returns:
        ``D^2 f(x)[v, v]``; with ``return_primals`` the tuple
        ``(D f(x)[v], D^2 f(x)[v, v])``.
    """
    def directional(point):
        _, df = jvp(f, (point,), tangents)
        return df

    first, second = jvp(directional, primals, tangents)
    if not return_primals:
        return second
    return first, second
    
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step.

    Args:
        optim: optax ``GradientTransformation``; static under jit, so it must
            be hashable and a different optimizer object triggers a retrace.
        gradient: gradient pytree with the same structure as ``params``.
        params: current model parameters.
        state: optimizer state from ``optim.init`` or a previous step.

    Returns:
        ``(new_params, new_state)``.
    """
    # Pass params too: transformations such as weight decay (e.g. optax.adamw)
    # require them, while optimizers like adam simply ignore the argument.
    updates, state = optim.update(gradient, state, params)
    params = optax.apply_updates(params, updates)
    return params, state

In [None]:
class SPINN3d(nn.Module):
    """Separable PINN (SPINN) on a factorized 3-D grid.

    One body network per axis maps that axis' coordinates to ``r * out_dim``
    features. The three feature sets are combined by a rank-``r`` outer-product
    sum, so a single call predicts every point of the
    ``len(x) x len(y) x len(z)`` grid.

    Attributes:
        features: layer widths of each body network; ``features[:-1]`` are the
            hidden layers. For ``'modified_mlp'`` every hidden width must equal
            ``features[0]`` (the gate mixes U/V of that width).
        r: rank of the tensor decomposition.
        out_dim: number of output components (3 for a vector field).
        pos_enc: number of sinusoidal frequencies for positional encoding;
            0 disables it.
        mlp: body-network type, ``'mlp'`` or ``'modified_mlp'``.
    """

    features: Sequence[int]
    r: int
    out_dim: int
    pos_enc: int
    mlp: str

    @nn.compact
    def __call__(self, x, y, z):
        """Predict the field on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: per-axis coordinates, each shaped ``(n_axis, 1)``.

        Returns:
            For ``out_dim == 1`` a single ``(len(x), len(y), len(z))`` array;
            otherwise a list of ``out_dim`` such arrays (e.g. ``[Bx, By, Bz]``).

        Raises:
            ValueError: if ``mlp`` is not a supported body-network type.
        """
        if self.mlp not in ('mlp', 'modified_mlp'):
            raise ValueError(
                f"unknown mlp type {self.mlp!r}; expected 'mlp' or 'modified_mlp'")

        if self.pos_enc != 0:
            # x, y and z are all spatial coordinates in this problem, so every
            # axis is encoded as [1, sin(k*c), cos(k*c)], k = 1..pos_enc. (An
            # earlier version encoded only y and z, treating x as a time axis.)
            freq = jnp.expand_dims(jnp.arange(1, self.pos_enc + 1, 1), 0)

            def encode(c):
                return jnp.concatenate(
                    (jnp.ones((c.shape[0], 1)), jnp.sin(c @ freq), jnp.cos(c @ freq)), 1)

            x, y, z = encode(x), encode(y), encode(z)

        inputs, outputs, xy, pred = [x, y, z], [], [], []
        init = nn.initializers.glorot_normal()

        if self.mlp == 'mlp':
            for X in inputs:
                for fs in self.features[:-1]:
                    X = nn.Dense(fs, kernel_init=init)(X)
                    X = nn.activation.tanh(X)
                X = nn.Dense(self.r * self.out_dim, kernel_init=init)(X)
                # Feature-major (r*out_dim, n_axis) layout for the einsum merge.
                outputs += [jnp.transpose(X, (1, 0))]
        else:  # 'modified_mlp' (validated above)
            for X in inputs:
                # Two input embeddings U, V gate every hidden layer.
                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=init)(X))
                for fs in self.features[:-1]:
                    Z = nn.Dense(fs, kernel_init=init)(H)
                    Z = nn.activation.tanh(Z)
                    H = (jnp.ones_like(Z) - Z) * U + Z * V
                H = nn.Dense(self.r * self.out_dim, kernel_init=init)(H)
                outputs += [jnp.transpose(H, (1, 0))]

        for i in range(self.out_dim):
            # Rank-r slice of the features belonging to output component i.
            sl = slice(self.r * i, self.r * (i + 1))
            xy += [jnp.einsum('fx, fy->fxy', outputs[0][sl], outputs[1][sl])]
            pred += [jnp.einsum('fxy, fz->xyz', xy[i], outputs[-1][sl])]

        if len(pred) == 1:
            # 1-dimensional output
            return pred[0]
        # n-dimensional output
        return pred

In [None]:
# Fixed seed for reproducible initialization; subkey is used for model.init.
seed = 111
key, subkey = jax.random.split(jax.random.PRNGKey(seed), 2)

2023-07-05 07:10:32.320651: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [None]:
# Network size
features = 256  # feature size of each layer
n_layers = 8  # number of layers
feat_sizes = (features,) * n_layers  # per-layer feature sizes
r = 128  # rank of the approximated tensor
out_dim = 3  # number of model outputs (Bx, By, Bz)

# Optimization
lr = 5e-4  # learning rate

epochs = 2000  # previously 10000
log_iter = 100  # previously 1000

# Loss weights (currently unused)
# lbda_c = 100
# lbda_ic = 10

In [None]:
@jax.jit
def generate_train_data():
    """Build the factorized training coordinates on the cube [0, 2]^3.

    Returns:
        xc, yc, zc: collocation axes -- ``nx``/``ny``/``nz`` evenly spaced
            points on [0, 2], each shaped ``(n, 1)``.
        xb, yb, zb: length-6 lists of per-face axes for the boundary loss, in
            the order bottom (z=0), top (z=2), lateral_1 (x=0), lateral_2 (x=2),
            lateral_3 (y=0), lateral_4 (y=2). The fixed coordinate of each face
            is a ``(1, 1)`` array.

    ``nx``, ``ny`` and ``nz`` are module globals read when the function is
    first traced.
    """
    def axis(n):
        return jnp.linspace(0, 2, n).reshape(n, 1)

    def const(v):
        return jnp.array([[v]])

    xc, yc, zc = axis(nx), axis(ny), axis(nz)

    xb = [axis(nx), axis(nx), const(0.), const(2.), axis(nx), axis(nx)]
    yb = [axis(ny), axis(ny), axis(ny), axis(ny), const(0.), const(2.)]
    zb = [const(0.), const(2.), axis(nz), axis(nz), axis(nz), axis(nz)]

    return xc, yc, zc, xb, yb, zb

In [None]:
def curlx(apply_fn, params, x, y, z):
    """x-component of curl(B), dBz/dy - dBy/dz, via forward-mode AD.

    ``apply_fn(params, x, y, z)`` must return ``[Bx, By, Bz]`` on the grid
    spanned by the per-axis coordinates. An all-ones tangent per axis gives the
    pointwise partial derivative when output (i, j, k) depends only on
    x[i], y[j], z[k] (as for the factorized SPINN model).
    """
    dBy_dz = jvp(lambda zv: apply_fn(params, x, y, zv)[1], (z,), (jnp.ones(z.shape),))[1]
    dBz_dy = jvp(lambda yv: apply_fn(params, x, yv, z)[2], (y,), (jnp.ones(y.shape),))[1]
    return dBz_dy - dBy_dz


def curly(apply_fn, params, x, y, z):
    """y-component of curl(B), dBx/dz - dBz/dx, via forward-mode AD.

    ``apply_fn(params, x, y, z)`` must return ``[Bx, By, Bz]`` on the grid
    spanned by the per-axis coordinates. An all-ones tangent per axis gives the
    pointwise partial derivative when output (i, j, k) depends only on
    x[i], y[j], z[k] (as for the factorized SPINN model).
    """
    dBx_dz = jvp(lambda zv: apply_fn(params, x, y, zv)[0], (z,), (jnp.ones(z.shape),))[1]
    dBz_dx = jvp(lambda xv: apply_fn(params, xv, y, z)[2], (x,), (jnp.ones(x.shape),))[1]
    return dBx_dz - dBz_dx

def curlz(apply_fn, params, x, y, z):
    """z-component of curl(B), dBy/dx - dBx/dy, via forward-mode AD.

    ``apply_fn(params, x, y, z)`` must return ``[Bx, By, Bz]`` on the grid
    spanned by the per-axis coordinates. An all-ones tangent per axis gives the
    pointwise partial derivative when output (i, j, k) depends only on
    x[i], y[j], z[k] (as for the factorized SPINN model).
    """
    dBx_dy = jvp(lambda yv: apply_fn(params, x, yv, z)[0], (y,), (jnp.ones(y.shape),))[1]
    dBy_dx = jvp(lambda xv: apply_fn(params, xv, y, z)[1], (x,), (jnp.ones(x.shape),))[1]
    return dBy_dx - dBx_dy


In [None]:
@partial(jax.jit, static_argnums=(0,))
def apply_model_spinn(apply_fn, params, *train_data):
    """Compute the total PINN loss and its gradient with respect to ``params``.

    Args:
        apply_fn: model apply function ``apply_fn(params, x, y, z)`` returning
            ``[Bx, By, Bz]`` on the grid spanned by the per-axis coordinates
            (static under jit, so it must be hashable).
        params: model parameters.
        *train_data: ``(xc, yc, zc, xb, yb, zb)``: collocation axes plus
            length-6 lists of per-face boundary axes ordered bottom (z=0),
            top (z=2), x=0, x=2, y=0, y=2.

    Returns:
        ``(loss, gradient)``: scalar force-free + divergence + boundary loss
        and its gradient pytree.

    The module-level array ``b_bottom`` (indexed ``[:, :, component]``) is the
    z=0 boundary target.
    """

    def residual_loss(params, x, y, z):
        # Three forward-mode passes, one per axis, give all nine partial
        # derivatives dB_d<axis>[component]. An all-ones tangent yields the
        # pointwise partial because output (i, j, k) of the factorized model
        # depends only on x[i], y[j], z[k]. Each jvp also returns the primal
        # field, so no separate forward pass is needed.
        B_out, dB_dx = jvp(lambda xv: apply_fn(params, xv, y, z), (x,), (jnp.ones(x.shape),))
        _, dB_dy = jvp(lambda yv: apply_fn(params, x, yv, z), (y,), (jnp.ones(y.shape),))
        _, dB_dz = jvp(lambda zv: apply_fn(params, x, y, zv), (z,), (jnp.ones(z.shape),))

        B = jnp.stack(B_out, axis=-1)

        # J = curl B
        Jx = dB_dy[2] - dB_dz[1]
        Jy = dB_dz[0] - dB_dx[2]
        Jz = dB_dx[1] - dB_dy[0]
        J = jnp.stack([Jx, Jy, Jz], axis=-1)

        # Force-free term |J x B|^2 / |B|^2; the epsilon avoids division by
        # zero where B vanishes.
        JxB = jnp.cross(J, B, axis=-1)
        loss_ff = jnp.mean(jnp.sum(JxB**2, axis=-1) / (jnp.sum(B**2, axis=-1) + 1e-7))

        # Divergence-free term.
        divB = dB_dx[0] + dB_dy[1] + dB_dz[2]
        loss_div = jnp.mean(divB**2)

        return loss_ff + loss_div

    def boundary_loss(params, x, y, z):
        def face(i):
            # B on boundary face i, with the singleton (fixed) axis squeezed away.
            return [jnp.squeeze(c) for c in apply_fn(params, x[i], y[i], z[i])]

        loss = 0.

        # 0: z=0 bottom -> match all three components of b_bottom.
        Bx, By, Bz = face(0)
        loss += (jnp.mean((Bx - b_bottom[:, :, 0])**2)
                 + jnp.mean((By - b_bottom[:, :, 1])**2)
                 + jnp.mean((Bz - b_bottom[:, :, 2])**2))

        # 1: z=2 top -> only the normal component Bz is free; Bx = By = 0.
        Bx, By, _ = face(1)
        loss += jnp.mean(Bx**2) + jnp.mean(By**2)

        # 2, 3: x=0 / x=2 lateral faces -> only tangential (By, Bz); Bx = 0.
        for i in (2, 3):
            Bx, _, _ = face(i)
            loss += jnp.mean(Bx**2)

        # 4, 5: y=0 / y=2 lateral faces -> only tangential (Bx, Bz); By = 0.
        for i in (4, 5):
            _, By, _ = face(i)
            loss += jnp.mean(By**2)

        # Alternative (disabled): on faces 1-5 match all three components
        # against potential-field data, e.g. for face 1
        #   loss += mean((Bx - bp_top[:, :, 0])**2) + ... (By, Bz likewise),
        # and bp_lateral_1..bp_lateral_4 on faces 2-5.
        return loss

    xc, yc, zc, xb, yb, zb = train_data

    def loss_fn(p):
        return residual_loss(p, xc, yc, zc) + boundary_loss(p, xb, yb, zb)

    loss, gradient = jax.value_and_grad(loss_fn)(params)
    return loss, gradient

In [None]:
# Build the model. The all-ones (n, 1) dummy inputs only fix each axis'
# input width for parameter initialization; their values are irrelevant.
model = SPINN3d(feat_sizes, r, out_dim, pos_enc=0, mlp='modified_mlp')
params = model.init(subkey, *(jnp.ones((n, 1)) for n in (nx, ny, nz)))
apply_fn = jax.jit(model.apply)

optim = optax.adam(learning_rate=lr)
state = optim.init(params)

In [None]:
train_data = generate_train_data()
# The training grid above is deterministic and does not consume subkey; the
# split only advances the PRNG stream (a random-resampling variant would use it).
key, subkey = jax.random.split(key, 2)

Compile (the first call JIT-compiles the loss and update functions, so it takes a while)

In [None]:
# Warm-up call: the first invocation JIT-compiles apply_model_spinn and
# update_model. It is also a real optimizer step, so together with the training
# loop the parameters receive epochs + 1 updates.
loss, gradient = apply_model_spinn(apply_fn, params, *train_data)
params, state = update_model(optim, gradient, params, state)

In [None]:
start = time.time()
for e in trange(1, epochs + 1):
    # The collocation grid is fixed. To resample periodically, split `key`
    # here and rebuild train_data (generate_train_data currently takes no key).
    loss, gradient = apply_model_spinn(apply_fn, params, *train_data)
    params, state = update_model(optim, gradient, params, state)

    # `loss` is evaluated at the parameters from *before* this step's update.
    if e % log_iter == 0 or e == 1:
        print(f'Epoch: {e}/{epochs} --> total loss: {loss:.8f}')

# JAX dispatches work asynchronously: wait for the final update so the
# wall-clock time covers everything that was queued.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), params)
runtime = time.time() - start
# Compilation happened in the warm-up cell, so the timed loop ran exactly
# `epochs` iterations (max() guards epochs == 0).
print(f'Runtime --> total: {runtime:.2f}sec ({(runtime/max(epochs, 1)*1000):.2f}ms/iter.)')

  0%|          | 1/2000 [00:00<16:40,  2.00it/s]

Epoch: 1/2000 --> total loss: 0.22356929


  2%|▏         | 30/2000 [00:08<09:26,  3.48it/s]


KeyboardInterrupt: 

# Visualization

In [None]:
# Evaluate the trained model on the full training grid and undo the b_norm
# scaling applied to the training targets.
xx, yy, zz = (
    jax.lax.stop_gradient(jnp.linspace(0, 2, n).reshape(-1, 1)) for n in (nx, ny, nz)
)
Bxx, Byy, Bzz = (component * b_norm for component in apply_fn(params, xx, yy, zz))

In [None]:
# Combine the rescaled components into one (nx, ny, nz, 3) array.
Bb = jnp.stack((Bxx, Byy, Bzz), -1)

In [None]:
np.shape(Bb)  # expect (nx, ny, nz, 3)

(256, 256, 256, 3)

In [None]:
# Current density J = curl(B) on the evaluation grid (forward-mode AD).
# NOTE(review): unlike Bxx/Byy/Bzz, J comes from the normalized model output
# and is not multiplied by b_norm -- confirm the units expected for J.pickle.
Jxx, Jyy, Jzz = (curl(apply_fn, params, xx, yy, zz) for curl in (curlx, curly, curlz))
Jj = jnp.stack((Jxx, Jyy, Jzz), axis=-1)

In [None]:
np.shape(Jj)  # expect (nx, ny, nz, 3)

(256, 256, 256, 3)

In [None]:
import pickle

# Export the evaluation grid, the rescaled field B and the current J as NumPy
# arrays. NOTE(review): these are written to the working directory, not res_path.
for out_name, value in (
    ("x.pickle", xx),
    ("y.pickle", yy),
    ("z.pickle", zz),
    ("B.pickle", Bb),
    ("J.pickle", Jj),
):
    with open(out_name, "wb") as f:
        pickle.dump(np.array(value), f)