In [None]:
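# Download the reference NLFFF cube loaded below (11158_20110213_120000.nc); uncomment on first run:
# !wget https://hinode.isee.nagoya-u.ac.jp/nlfff_database/v12/11158/20110213/11158_20110213_120000.nc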

In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"]= "0"

import jax 
import jax.numpy as jnp
import optax
import time 
import pickle
from tqdm import trange

from utils.spinn import SPINN3d, generate_train_data, apply_model_spinn, update_model
from utils.data_load import load_nc

In [None]:
# ---- output location -------------------------------------------------------
base_path = 'spinn'                                    # directory for trained weights
save_path = os.path.join(base_path, 'params.pickle')   # params are pickled here after training
os.makedirs(base_path, exist_ok=True)

# ---- SPINN3d architecture (passed positionally to SPINN3d) -----------------
features, n_layers = 256, 8   # feat_sizes = (features,) * n_layers is handed to SPINN3d
r = 256                       # SPINN3d's `r` argument
out_dim = 3                   # number of network outputs (the boundary data has 3 components)
pos_enc = 0                   # SPINN3d's `pos_enc` argument
mlp = 'modified_mlp'          # SPINN3d's `mlp` variant name

# ---- domain / data scaling -------------------------------------------------
b_norm = 2500                 # boundary field is divided by this before training
height = 257                  # number of grid points along z (used as nz)

# ---- training points drawn per axis ---------------------------------------
ncx, ncy, ncz = 32, 32, 32

# ---- optimization schedule -------------------------------------------------
iterations = 100_000          # total training steps
log_iter = 10_000             # print the loss every `log_iter` steps
random_iter = 100             # resample training points every `random_iter` steps
lr = 1e-4                     # Adam learning rate

In [None]:
# Build the SPINN3d model, initialise its parameters, and set up Adam.
key = jax.random.PRNGKey(0)

# Split off a fresh subkey for parameter initialisation and keep `key` only for
# further splitting. Feeding `key` itself to `model.init` would consume the same
# key that is split again later for the training data (JAX keys must not be reused).
key, subkey = jax.random.split(key, 2)

# One entry per layer: (features, features, ..., features), n_layers long.
feat_sizes = (features,) * n_layers

model = SPINN3d(feat_sizes, r, out_dim, pos_enc, mlp)

# Initialise with dummy per-axis inputs of shape (n_points, 1).
# NOTE(review): presumably these match the per-axis shapes generate_train_data returns — confirm.
params = model.init(
    subkey,
    jnp.ones((ncx, 1)),
    jnp.ones((ncy, 1)),
    jnp.ones((ncz, 1)),
)

# jit-wrapped forward pass; passed to apply_model_spinn on every training step.
apply_fn = jax.jit(model.apply)

# optimizer
optim = optax.adam(learning_rate=lr)
state = optim.init(params)

# count total params (shown as the cell output)
total_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
total_params

In [None]:
# Load the reference field cube. The shapes displayed below suggest an
# (x, y, z, component) layout — TODO confirm against load_nc.
b_true = load_nc('11158_20110213_120000.nc')

# Bottom boundary: the z-index-0 layer of every (x, y) column, all components.
z_bottom = 0
b_bottom = b_true[:, :, z_bottom, :]

(b_true.shape, b_bottom.shape)

((513, 257, 257, 3), (513, 257, 3))

In [None]:
# Training-domain resolution: the z extent comes from the configured `height`,
# while the x/y extents come from the boundary map (its last axis is discarded).
nz = height
nx, ny, _ = b_bottom.shape

(nx, ny, nz)

(513, 257, 257)

In [None]:
# Divide the bottom-boundary field by the constant b_norm from the config cell;
# the training loop consumes this scaled copy rather than b_bottom.
# NOTE(review): presumably chosen to bring the field to order-unity values — confirm units of b_true.
b_bottom_normalized = b_bottom / b_norm

In [None]:
# dataset
key, subkey = jax.random.split(key, 2)
xc, yc, zc, xb, yb, zb = generate_train_data(ncx, ncy, ncz, nx, ny, nz, subkey)

# start training
for e in trange(1, iterations + 1):
    if e == 2:
        # exclude compiling time
        start = time.time()

    if e % random_iter == 0:
        key, subkey = jax.random.split(key, 2)
        xc, yc, zc, xb, yb, zb = generate_train_data(ncx, ncy, ncz, nx, ny, nz, subkey)

    loss, gradient = apply_model_spinn(apply_fn, params, xc, yc, zc, xb, yb, zb, b_bottom_normalized)
    params, state = update_model(optim, gradient, params, state)

    # log
    if e % log_iter == 0:
        print(f'Iteration: {e}/{iterations} --> total loss: {loss:.8f}')

# training done
runtime = time.time() - start
print(f'Runtime --> total: {runtime:.2f}sec ({(runtime/(iterations-1)*1000):.2f}ms/iter.)')
with open(save_path, "wb") as f:
    pickle.dump(params, f)

 10%|█         | 10015/100000 [02:01<12:05, 124.09it/s]

Iteration: 10000/100000 --> total loss: 0.01025009


 20%|██        | 20023/100000 [03:22<10:46, 123.80it/s]

Iteration: 20000/100000 --> total loss: 0.06579896


 30%|███       | 30024/100000 [04:43<09:25, 123.73it/s]

Iteration: 30000/100000 --> total loss: 0.01897338


 40%|████      | 40023/100000 [06:03<08:04, 123.69it/s]

Iteration: 40000/100000 --> total loss: 0.00432796


 50%|█████     | 50021/100000 [07:24<06:43, 123.85it/s]

Iteration: 50000/100000 --> total loss: 0.00745599


 60%|██████    | 60020/100000 [08:44<05:23, 123.73it/s]

Iteration: 60000/100000 --> total loss: 0.00411018


 70%|███████   | 70020/100000 [10:05<04:01, 124.03it/s]

Iteration: 70000/100000 --> total loss: 0.07731351


 80%|████████  | 80020/100000 [11:25<02:41, 123.87it/s]

Iteration: 80000/100000 --> total loss: 0.06275281


 90%|█████████ | 90019/100000 [12:46<01:20, 124.03it/s]

Iteration: 90000/100000 --> total loss: 0.00272743


100%|██████████| 100000/100000 [14:06<00:00, 118.11it/s]


Iteration: 100000/100000 --> total loss: 0.03043940
Runtime --> total: 805.34sec (8.05ms/iter.)
