# Preparation

In [None]:
import os 

In [None]:
# Input directory layout: everything lives under `input_path`; the
# `original` sub-folder is created eagerly (no error if it already exists).
input_path = 'spinn_hmi'
os.makedirs(os.path.join(input_path, 'original'), exist_ok=True)
input_original = os.path.join(input_path, 'original')

In [None]:
# Scalar run settings. `nz` is the number of points along z handed to the
# model init and to generate_train_data; the meaning of `bin`, `b_norm` and
# `spatial_norm` is defined by code outside this cell (presumably a binning
# factor and field / length normalisations — TODO confirm).
# NOTE(review): `bin` shadows the built-in bin(); kept as-is because other
# cells may read it.
bin, b_norm, spatial_norm, nz = 2, 2500, 160, 160

## SPINN

In [None]:
import pickle

In [None]:
# These only take effect if set before jax is imported (it is imported in the
# next cell): turn off XLA's up-front GPU memory preallocation and expose
# only the first GPU, enumerated in PCI-bus order.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
import jax 
import jax.numpy as jnp
from jax import jvp
import optax
from flax import linen as nn 

from typing import Sequence
from functools import partial

import numpy as np

import time
from tqdm import trange

In [None]:
# Same value as set in the first cell; recomputed so this cell can run alone.
input_original = os.path.join(input_path, 'original')

# Locations of the pickled boundary arrays (bottom, top, four lateral faces).
(
    b_bottom_path,
    bp_top_path,
    bp_lateral_1_path,
    bp_lateral_2_path,
    bp_lateral_3_path,
    bp_lateral_4_path,
) = (
    os.path.join(input_path, file_name)
    for file_name in (
        "b_bottom.pickle",
        "bp_top.pickle",
        "bp_lateral_1.pickle",
        "bp_lateral_2.pickle",
        "bp_lateral_3.pickle",
        "bp_lateral_4.pickle",
    )
)

In [None]:
def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    NOTE(review): pickle.load can execute arbitrary code from the file; only
    point this at trusted, locally produced pickles.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


b_bottom = _load_pickle(b_bottom_path)
bp_top = _load_pickle(bp_top_path)
bp_lateral_1 = _load_pickle(bp_lateral_1_path)
bp_lateral_2 = _load_pickle(bp_lateral_2_path)
bp_lateral_3 = _load_pickle(bp_lateral_3_path)
bp_lateral_4 = _load_pickle(bp_lateral_4_path)

In [None]:
# Deterministic root key; split immediately so `subkey` (consumed by the
# parameter init) and the carried-forward `key` are independent streams.
seed = 111
key, subkey = jax.random.split(jax.random.PRNGKey(seed))

In [None]:
# Network shape: `n_layers` hidden layers, each `features` wide.
features = 256
n_layers = 8
feat_sizes = (features,) * n_layers
r = 128          # passed to SPINN3d as `r` (presumably the separable rank — TODO confirm)
out_dim = 3      # model output width

# Optimisation / logging.
lr = 5e-4        # Adam learning rate
epochs = 2000
log_iter = 100   # print + checkpoint every `log_iter` epochs

In [None]:
from zpinn.spinn_cleanup_new import SPINN3d, generate_train_data, apply_model_spinn, update_model

In [None]:
# Inspect the bottom-boundary array's shape; its first two axes give nx, ny.
b_bottom.shape

(344, 224, 3)

In [None]:
# Grid sizes along x and y are taken from the bottom-boundary array. The
# 3-way unpack also fails loudly if the array is not 3-D. The trailing axis
# (size 3 in the recorded output) is discarded; presumably it holds the
# field components — TODO confirm.
nx, ny, _ = b_bottom.shape

In [None]:
# Build the separable PINN and initialise its parameters. Each axis gets a
# dummy (n, 1) column of ones sized to the grid.
model = SPINN3d(feat_sizes, r, out_dim, pos_enc=0, mlp='modified_mlp')
params = model.init(
    subkey,
    *(jnp.ones((n_points, 1)) for n_points in (nx, ny, nz)),
)
apply_fn = jax.jit(model.apply)

# Plain Adam on the freshly initialised parameters.
optim = optax.adam(learning_rate=lr)
state = optim.init(params)

In [None]:
# Sample training points (presumably collocation points — TODO confirm) with a
# freshly split `subkey`. `key` stays the parent for future splits. Passing
# `key` itself here, as before, would reuse a key that is split again later,
# correlating the two random streams.
key, subkey = jax.random.split(key, 2)
train_data = generate_train_data(nx, ny, nz, subkey)

In [None]:
# Boundary arrays bundled for apply_model_spinn: bottom, top, then the four
# lateral faces. Keep this order — presumably the loss indexes by position.
boundary_data = (
    b_bottom,
    bp_top,
    bp_lateral_1,
    bp_lateral_2,
    bp_lateral_3,
    bp_lateral_4,
)

In [None]:
# Single argument handed to apply_model_spinn: [interior training data,
# boundary tuple]. It holds references to these objects, so if train_data is
# resampled later this list must be rebuilt.
train_boundary_data = [train_data, boundary_data]

In [None]:
# One training step before the timed loop. NOTE(review): this is not a no-op
# warm-up. update_model's result is assigned back, so params/state advance one
# step beyond `epochs`. Presumably meant to trigger JIT compilation before
# timing starts — TODO confirm.
loss, gradient = apply_model_spinn(apply_fn, params, train_boundary_data)
params, state = update_model(optim, gradient, params, state)

In [None]:
# Checkpoints are written under the input folder. The path is derived from
# `input_path` rather than hard-coding 'spinn_hmi', so changing the input
# location moves the output with it.
result_path = os.path.join(input_path, 'output')
os.makedirs(result_path, exist_ok=True)

In [None]:
# Main training loop: `epochs` optimisation steps. Every `log_iter` epochs (and
# at epoch 1) it logs the loss and pickles the current params to result_path.
# The measured runtime includes those checkpoint writes.
start = time.time()
for e in trange(1, epochs + 1):

    # Optional periodic resampling of the training points (disabled).
    # train_boundary_data must be rebuilt as well when enabling this. It holds
    # a reference to the old train_data, so rebinding train_data alone would
    # change nothing.
    # if e % 300 == 0:
    #     key, subkey = jax.random.split(key, 2)
    #     train_data = generate_train_data(nx, ny, nz, subkey)
    #     train_boundary_data = [train_data, boundary_data]

    # The loss comes from the params passed in, i.e. before this step's update.
    loss, gradient = apply_model_spinn(apply_fn, params, train_boundary_data)
    params, state = update_model(optim, gradient, params, state)

    if e % log_iter == 0 or e == 1:
        print(f'Epoch: {e}/{epochs} --> total loss: {loss:.8f}')
        # Checkpoint the post-update parameters for this epoch.
        params_path = os.path.join(result_path, f"params_{e}.pickle")
        with open(params_path, "wb") as f:
            pickle.dump(params, f)

runtime = time.time() - start
# The loop ran exactly `epochs` iterations, so average over `epochs`. The
# previous `epochs - 1` overstated ms/iter and divided by zero for epochs == 1.
# max() guards the degenerate epochs == 0 case.
print(f'Runtime --> total: {runtime:.2f}sec ({(runtime / max(epochs, 1) * 1000):.2f}ms/iter.)')

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 1/2000 [00:00<03:43,  8.96it/s]

Epoch: 1/2000 --> total loss: 0.04490991


  5%|▌         | 103/2000 [00:09<02:44, 11.52it/s]

Epoch: 100/2000 --> total loss: 0.03350798


 10%|█         | 203/2000 [00:18<02:40, 11.23it/s]

Epoch: 200/2000 --> total loss: 0.03296580


 15%|█▌        | 303/2000 [00:26<02:31, 11.23it/s]

Epoch: 300/2000 --> total loss: 0.03160987


 20%|██        | 403/2000 [00:35<02:21, 11.25it/s]

Epoch: 400/2000 --> total loss: 0.02852959


 25%|██▌       | 503/2000 [00:44<02:12, 11.30it/s]

Epoch: 500/2000 --> total loss: 0.02684779


 30%|███       | 603/2000 [00:53<02:03, 11.28it/s]

Epoch: 600/2000 --> total loss: 0.02492625


 35%|███▌      | 703/2000 [01:02<01:54, 11.30it/s]

Epoch: 700/2000 --> total loss: 0.02416475


 40%|████      | 803/2000 [01:10<01:46, 11.29it/s]

Epoch: 800/2000 --> total loss: 0.02267363


 45%|████▌     | 903/2000 [01:19<01:38, 11.14it/s]

Epoch: 900/2000 --> total loss: 0.02149615


 50%|█████     | 1003/2000 [01:28<01:30, 11.01it/s]

Epoch: 1000/2000 --> total loss: 0.02314830


 55%|█████▌    | 1103/2000 [01:37<01:20, 11.16it/s]

Epoch: 1100/2000 --> total loss: 0.01667921


 60%|██████    | 1203/2000 [01:46<01:10, 11.24it/s]

Epoch: 1200/2000 --> total loss: 0.01552274


 65%|██████▌   | 1303/2000 [01:55<01:01, 11.25it/s]

Epoch: 1300/2000 --> total loss: 0.01517787


 70%|███████   | 1403/2000 [02:04<00:53, 11.16it/s]

Epoch: 1400/2000 --> total loss: 0.01443515


 75%|███████▌  | 1503/2000 [02:13<00:43, 11.33it/s]

Epoch: 1500/2000 --> total loss: 0.01397124


 80%|████████  | 1603/2000 [02:22<00:35, 11.16it/s]

Epoch: 1600/2000 --> total loss: 0.01359771


 85%|████████▌ | 1703/2000 [02:30<00:26, 11.10it/s]

Epoch: 1700/2000 --> total loss: 0.01375258


 90%|█████████ | 1803/2000 [02:39<00:17, 11.04it/s]

Epoch: 1800/2000 --> total loss: 0.01283587


 95%|█████████▌| 1903/2000 [02:48<00:08, 11.11it/s]

Epoch: 1900/2000 --> total loss: 0.01246905


100%|██████████| 2000/2000 [02:57<00:00, 11.27it/s]

Epoch: 2000/2000 --> total loss: 0.01246227
Runtime --> total: 177.40sec (88.75ms/iter.)



