# SPINN (series)

In [None]:
#| default_exp spinn_jax_trainer

In [None]:
from setproctitle import setproctitle
# Give the notebook's kernel process a recognisable title so this long-running
# job can be told apart from other processes in tools such as `ps`/`top`.
setproctitle("SPINN (series)")

In [None]:
#| export
import os

In [None]:
# GPU runtime configuration. These variables are set before the `import jax`
# cell runs.
os.environ.update({
    # Ask JAX/XLA not to preallocate GPU memory up front.
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    # Number CUDA devices by PCI bus id.
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    # Expose only device "2" to this process.
    "CUDA_VISIBLE_DEVICES": "2",
})

In [None]:
#| export
import jax 
import jax.numpy as jnp
import optax
import numpy as np 
import matplotlib.pyplot as plt
import time 
import pickle
from tqdm import trange

In [None]:
#| export
from cmspinn.spinn_jax import SPINN3d, generate_train_data, apply_model_spinn, update_model

In [None]:
#| export
import glob
import time
from pathlib import Path

In [None]:
# Directory holding the bottom-boundary input files (*.npy), and the sorted
# list of those files.
b_bottom_paths = os.path.expanduser('~/workspace/_data/NOAA12673/b_bottom')
b_bottom_list = sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy')))

# Training results are written to a sibling "SPINN" directory, created here
# (together with any missing parents) if it does not exist yet.
spinn_output_path = str(Path(b_bottom_paths).parent / 'SPINN')
os.makedirs(spinn_output_path, exist_ok=True)

In [None]:
# Directory with the boundary-condition pickles, a sibling of the b_bottom
# input directory.
spinn_BC_path = str(Path(b_bottom_paths).parent / 'SPINN_BC')

In [None]:
#| export
class SPINN_Trainer:
    """Set up and train a ``SPINN3d`` model for a single input frame.

    The constructor builds the network, the Adam optimizer and the training
    data, and stores the hyper-parameters in ``<output_path>/parameters.pickle``.
    ``train`` then runs the optimisation loop and writes checkpoints, the
    final parameters and the loss history into ``output_path``.
    """

    def __init__(self, output_path, BC_path, b_bottom, Nz, b_norm, transfer_learning_path=None):
        """Build the model, optimizer state and training data.

        Args:
            output_path: Directory for all files written by this trainer;
                created if it does not exist.
            BC_path: Path of a pickle file with the boundary data.
            b_bottom: Array with a three-element ``shape``; only its first two
                extents are used (as ``Nx`` and ``Ny``).
            Nz: Number of grid points along the third axis.
            b_norm: Stored in ``parameters.pickle``; not otherwise used here.
            transfer_learning_path: Optional path of a pickle file with
                previously trained parameters. When given, training starts
                from those parameters instead of a fresh initialisation.

        Note:
            ``pickle.load`` is used on ``BC_path`` and
            ``transfer_learning_path``; only pass files from a trusted source.
        """
        os.makedirs(output_path, exist_ok=True)

        Nx, Ny, _ = b_bottom.shape

        # Network hyper-parameters.
        features = 128
        n_layers = 3
        feat_sizes = (features,) * n_layers
        r = 128
        out_dim = 3

        lr = 5e-4

        pos_enc = 0
        mlp = 'modified_mlp'

        # Upper coordinate bounds passed to generate_train_data: each extent
        # is expressed relative to Nz, so n_max_z is always 2.0.
        n_max_x = 2*(Nx/Nz)
        n_max_y = 2*(Ny/Nz)
        n_max_z = 2*(Nz/Nz)

        parameters = {'feat_sizes' : feat_sizes,
                      'r' : r,
                      'out_dim' : out_dim,
                      'Nx' : Nx,
                      'Ny' : Ny,
                      'Nz' : Nz,
                      'b_norm' : b_norm,
                      'pos_enc' : pos_enc,
                      'mlp' : mlp,
                      'lr': lr,
                      'n_max_x': n_max_x,
                      'n_max_y': n_max_y,
                      'n_max_z': n_max_z,}

        parameters_path = os.path.join(output_path, "parameters.pickle")
        with open(parameters_path, "wb") as f:
            pickle.dump(parameters, f)

        seed = 111
        key = jax.random.PRNGKey(seed)
        key, subkey = jax.random.split(key, 2)

        model = SPINN3d(feat_sizes, r, out_dim, pos_enc=pos_enc, mlp=mlp)
        # The model is initialised in both modes, exactly as before; in the
        # transfer-learning case the fresh parameters are replaced right away.
        # NOTE(review): the init call could presumably be skipped when
        # transfer_learning_path is given — confirm it has no side effects.
        params = model.init(
                    subkey,
                    jnp.ones((Nx, 1)),
                    jnp.ones((Ny, 1)),
                    jnp.ones((Nz, 1))
                )
        if transfer_learning_path is not None:
            with open(transfer_learning_path, 'rb') as f:
                params = pickle.load(f)

        apply_fn = jax.jit(model.apply)
        optim = optax.adam(learning_rate=lr)
        state = optim.init(params)

        with open(BC_path, 'rb') as f:
            boundary_data = pickle.load(f)

        # NOTE(review): the same subkey is used for model.init above and for
        # the training data here; kept as-is so results stay reproducible.
        train_data = generate_train_data(subkey, Nx, Ny, Nz, n_max_x, n_max_y, n_max_z)
        train_boundary_data = [train_data, boundary_data]

        self.apply_fn = apply_fn
        self.params = params
        self.train_boundary_data = train_boundary_data
        self.optim = optim
        self.state = state
        self.output_path = output_path

    @staticmethod
    def _ms_per_iteration(runtime, iterations):
        """Return the mean time per iteration in milliseconds.

        Args:
            runtime: Elapsed wall-clock time in seconds.
            iterations: Number of iterations executed in that time.

        Returns:
            ``runtime * 1000 / iterations``, or ``0.0`` when no iteration ran.
        """
        if iterations <= 0:
            return 0.0
        return runtime * 1000.0 / iterations

    def train(self, total_iterations: int, log_iterations: int, loss_threshold: float = 0.001):
        """Run the training loop and write the results to ``self.output_path``.

        One untimed warm-up step (loss, gradient, parameter update) is run
        first and reported as the compile time. The main loop then runs for at
        most ``total_iterations`` iterations and stops early at the first
        iteration whose loss is below ``loss_threshold``; that iteration's
        update is not applied.

        Files written: ``params_<e>.pickle`` every ``log_iterations``
        iterations, ``final_params.pickle`` and ``losses.npy``.

        ``self.params`` and ``self.state`` are not modified; the trained
        parameters are only available from the files.

        Args:
            total_iterations: Maximum number of iterations of the main loop.
            log_iterations: Interval for progress messages and checkpoints.
            loss_threshold: Loss value below which training stops early.
        """
        params = self.params
        state = self.state

        losses = []
        print('Complie Start')
        start = time.time()
        loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
        losses.append(loss.item())
        params, state = update_model(self.optim, gradient, params, state)
        runtime = time.time() - start
        print(f'Complie End --> total: {runtime:.2f}sec')

        start = time.time()
        # Number of main-loop iterations actually executed; this differs from
        # total_iterations whenever the loss threshold ends training early.
        iterations_run = 0
        for e in trange(1, total_iterations + 1):
            iterations_run = e

            loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
            loss_value = loss.item()
            losses.append(loss_value)
            if loss_value < loss_threshold:
                print(f'Epoch: {e}/{total_iterations} --> loss: {loss:.8f} < {loss_threshold}')
                break

            params, state = update_model(self.optim, gradient, params, state)

            if e % log_iterations == 0:
                print(f'Epoch: {e}/{total_iterations} --> total loss: {loss:.8f}')
                params_path = os.path.join(self.output_path, f"params_{e}.pickle")
                with open(params_path, "wb") as f:
                    pickle.dump(params, f)

        final_params_path = os.path.join(self.output_path, "final_params.pickle")
        with open(final_params_path, "wb") as f:
            pickle.dump(params, f)

        np.save(os.path.join(self.output_path, 'losses.npy'), losses)

        # The measured time includes checkpoint and result writing.
        runtime = time.time() - start
        ms_per_iter = self._ms_per_iteration(runtime, iterations_run)
        print(f'Runtime --> total: {runtime:.2f}sec ({ms_per_iter:.2f}ms/iter.)')

In [None]:
# Grid size along the third axis, passed to SPINN_Trainer as Nz.
Nz = 160
# Passed to SPINN_Trainer, which stores it in parameters.pickle.
# NOTE(review): presumably the field-strength normalisation — confirm units.
b_norm = 2500

# Iteration budget and logging/checkpoint interval used when no previously
# trained parameters are available (training from scratch).
epochs = 50000
log_iter = 1000

# Iteration budget and logging/checkpoint interval used when training starts
# from the parameters of an earlier frame.
series_epochs = 2000
series_log_iter = 100

# Training of a frame stops early once its loss falls below this value.
loss_threshold = 0.01

In [None]:
# Train one model per input frame, in sorted filename order. The first frame
# without existing results is trained from scratch with the long schedule;
# every later frame starts from the most recent final parameters and uses the
# short schedule. Frames that already have final_params.pickle are skipped, so
# an interrupted run can be resumed.
transfer_learning_path = None

start_time = time.time()

for b_bottom_path in sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy'))):

    # NOTE(review): the [9:-4] slice presumably drops a 9-character filename
    # prefix and the '.npy' suffix — confirm against the file naming scheme.
    b_bottom_date = os.path.basename(b_bottom_path)[9:-4]

    output_path = os.path.join(spinn_output_path, b_bottom_date)
    BC_path = os.path.join(spinn_BC_path, f'b_BC_{b_bottom_date}.pickle')
    final_params_path = os.path.join(output_path, 'final_params.pickle')

    # Already trained: remember its parameters as the starting point for the
    # next frame and move on without reading the input array.
    if os.path.exists(final_params_path):
        transfer_learning_path = final_params_path
        continue

    with open(b_bottom_path, 'rb') as f:
        b_bottom = np.load(f)

    # Long schedule when starting from scratch, short schedule otherwise.
    if transfer_learning_path is None:
        n_iterations, n_log_iterations = epochs, log_iter
    else:
        n_iterations, n_log_iterations = series_epochs, series_log_iter

    tranier = SPINN_Trainer(output_path, BC_path, b_bottom, Nz, b_norm,
                            transfer_learning_path=transfer_learning_path)
    start = time.time()
    tranier.train(n_iterations, n_log_iterations, loss_threshold=loss_threshold)
    runtime = time.time() - start
    print(f'Runtime: {runtime:.2f} sec')

    # The parameters just written seed the next frame.
    transfer_learning_path = final_params_path

print(f'Total Runtime: {time.time() - start_time:.2f} sec')

Complie Start
Complie End --> total: 18.21sec


  2%|▏         | 1003/50000 [00:32<26:33, 30.74it/s]

Epoch: 1000/50000 --> total loss: 0.02508836


  4%|▍         | 2003/50000 [01:04<25:51, 30.93it/s]

Epoch: 2000/50000 --> total loss: 0.01509597


  6%|▌         | 3003/50000 [01:36<25:24, 30.83it/s]

Epoch: 3000/50000 --> total loss: 0.01353082


  8%|▊         | 4003/50000 [02:08<24:47, 30.92it/s]

Epoch: 4000/50000 --> total loss: 0.01216163


 10%|█         | 5003/50000 [02:40<24:22, 30.76it/s]

Epoch: 5000/50000 --> total loss: 0.01112062


 12%|█▏        | 6003/50000 [03:13<23:44, 30.89it/s]

Epoch: 6000/50000 --> total loss: 0.01043476


 13%|█▎        | 6539/50000 [03:30<23:18, 31.08it/s]


Epoch: 6540/50000 --> loss: 0.00999985 < 0.01
Runtime --> total: 210.42sec (4.21ms/iter.)
Runtime: 228.64 sec
Complie Start
Complie End --> total: 17.00sec


  5%|▌         | 104/2000 [00:03<01:01, 30.59it/s]

Epoch: 100/2000 --> total loss: 0.01025323


 10%|█         | 204/2000 [00:06<00:58, 30.80it/s]

Epoch: 200/2000 --> total loss: 0.01003259


 12%|█▏        | 247/2000 [00:08<00:56, 30.85it/s]


Epoch: 248/2000 --> loss: 0.00999980 < 0.01
Runtime --> total: 8.01sec (4.01ms/iter.)
Runtime: 25.02 sec
Complie Start
Complie End --> total: 17.37sec


  5%|▌         | 104/2000 [00:03<01:01, 30.78it/s]

Epoch: 100/2000 --> total loss: 0.01049765


 10%|█         | 204/2000 [00:06<00:58, 30.64it/s]

Epoch: 200/2000 --> total loss: 0.01038593


 15%|█▌        | 304/2000 [00:09<00:54, 30.91it/s]

Epoch: 300/2000 --> total loss: 0.01031721


 20%|██        | 404/2000 [00:13<00:51, 30.94it/s]

Epoch: 400/2000 --> total loss: 0.01023902


 25%|██▌       | 504/2000 [00:16<00:48, 30.92it/s]

Epoch: 500/2000 --> total loss: 0.01034957


 30%|███       | 604/2000 [00:19<00:45, 30.83it/s]

Epoch: 600/2000 --> total loss: 0.01029011


 35%|███▌      | 704/2000 [00:22<00:42, 30.59it/s]

Epoch: 700/2000 --> total loss: 0.01002608


 38%|███▊      | 752/2000 [00:24<00:40, 30.90it/s]


Epoch: 753/2000 --> loss: 0.00999780 < 0.01
Runtime --> total: 24.34sec (12.18ms/iter.)
Runtime: 41.72 sec
Complie Start
Complie End --> total: 17.52sec


  5%|▌         | 104/2000 [00:03<01:01, 30.75it/s]

Epoch: 100/2000 --> total loss: 0.01074501


 10%|█         | 204/2000 [00:06<00:58, 30.50it/s]

Epoch: 200/2000 --> total loss: 0.01056560


 15%|█▌        | 304/2000 [00:09<00:55, 30.79it/s]

Epoch: 300/2000 --> total loss: 0.01048623


 20%|██        | 404/2000 [00:13<00:52, 30.61it/s]

Epoch: 400/2000 --> total loss: 0.01048418


 25%|██▌       | 504/2000 [00:16<00:48, 30.93it/s]

Epoch: 500/2000 --> total loss: 0.01050294


 30%|███       | 604/2000 [00:19<00:45, 30.43it/s]

Epoch: 600/2000 --> total loss: 0.01034918


 35%|███▌      | 704/2000 [00:22<00:41, 30.93it/s]

Epoch: 700/2000 --> total loss: 0.01023860


 40%|████      | 804/2000 [00:26<00:39, 30.50it/s]

Epoch: 800/2000 --> total loss: 0.01017779


 45%|████▌     | 904/2000 [00:29<00:35, 30.92it/s]

Epoch: 900/2000 --> total loss: 0.01064013


 50%|█████     | 1004/2000 [00:32<00:32, 30.62it/s]

Epoch: 1000/2000 --> total loss: 0.01005933


 55%|█████▌    | 1100/2000 [00:35<00:29, 30.85it/s]


Epoch: 1100/2000 --> total loss: 0.01002633
Epoch: 1101/2000 --> loss: 0.00999764 < 0.01
Runtime --> total: 35.66sec (17.84ms/iter.)
Runtime: 53.18 sec
Complie Start
Complie End --> total: 18.06sec


  3%|▎         | 60/2000 [00:01<01:04, 30.25it/s]


Epoch: 61/2000 --> loss: 0.00998374 < 0.01
Runtime --> total: 1.99sec (1.00ms/iter.)
Runtime: 20.05 sec
Complie Start
Complie End --> total: 18.29sec


  2%|▏         | 43/2000 [00:01<01:04, 30.13it/s]


Epoch: 44/2000 --> loss: 0.00998162 < 0.01
Runtime --> total: 1.43sec (0.72ms/iter.)
Runtime: 19.72 sec
Complie Start
Complie End --> total: 18.28sec


  2%|▏         | 43/2000 [00:01<01:05, 30.05it/s]


Epoch: 44/2000 --> loss: 0.00999071 < 0.01
Runtime --> total: 1.44sec (0.72ms/iter.)
Runtime: 19.72 sec
Complie Start
Complie End --> total: 18.29sec


  2%|▏         | 41/2000 [00:01<01:05, 29.89it/s]


Epoch: 42/2000 --> loss: 0.00998686 < 0.01
Runtime --> total: 1.38sec (0.69ms/iter.)
Runtime: 19.67 sec
Complie Start
Complie End --> total: 18.06sec


  2%|▏         | 35/2000 [00:01<01:05, 29.94it/s]


Epoch: 36/2000 --> loss: 0.00995998 < 0.01
Runtime --> total: 1.18sec (0.59ms/iter.)
Runtime: 19.24 sec
Complie Start
Complie End --> total: 18.91sec


  2%|▏         | 37/2000 [00:01<01:05, 29.83it/s]


Epoch: 38/2000 --> loss: 0.00998883 < 0.01
Runtime --> total: 1.25sec (0.62ms/iter.)
Runtime: 20.16 sec
Complie Start
Complie End --> total: 18.74sec


  2%|▏         | 46/2000 [00:01<01:06, 29.57it/s]


Epoch: 47/2000 --> loss: 0.00998967 < 0.01
Runtime --> total: 1.57sec (0.78ms/iter.)
Runtime: 20.31 sec
Complie Start
Complie End --> total: 19.18sec


  3%|▎         | 56/2000 [00:01<01:04, 30.14it/s]


Epoch: 57/2000 --> loss: 0.00998930 < 0.01
Runtime --> total: 1.87sec (0.93ms/iter.)
Runtime: 21.05 sec
Complie Start
Complie End --> total: 19.59sec


  2%|▏         | 42/2000 [00:01<01:05, 30.12it/s]


Epoch: 43/2000 --> loss: 0.00999844 < 0.01
Runtime --> total: 1.40sec (0.70ms/iter.)
Runtime: 20.99 sec
Complie Start
Complie End --> total: 19.02sec


  3%|▎         | 56/2000 [00:01<01:04, 30.29it/s]


Epoch: 57/2000 --> loss: 0.00999411 < 0.01
Runtime --> total: 1.86sec (0.93ms/iter.)
Runtime: 20.88 sec
Complie Start
Complie End --> total: 18.72sec


  2%|▏         | 44/2000 [00:01<01:05, 30.02it/s]


Epoch: 45/2000 --> loss: 0.00999483 < 0.01
Runtime --> total: 1.48sec (0.74ms/iter.)
Runtime: 20.20 sec
Complie Start
Complie End --> total: 18.62sec


  3%|▎         | 65/2000 [00:02<01:03, 30.25it/s]


Epoch: 66/2000 --> loss: 0.00999522 < 0.01
Runtime --> total: 2.16sec (1.08ms/iter.)
Runtime: 20.77 sec
Complie Start
Complie End --> total: 19.12sec


  4%|▍         | 81/2000 [00:02<01:02, 30.51it/s]


Epoch: 82/2000 --> loss: 0.00999606 < 0.01
Runtime --> total: 2.66sec (1.33ms/iter.)
Runtime: 21.78 sec
Complie Start
Complie End --> total: 18.67sec


  5%|▌         | 104/2000 [00:03<01:01, 30.78it/s]

Epoch: 100/2000 --> total loss: 0.01123935


 10%|█         | 204/2000 [00:06<00:57, 31.00it/s]

Epoch: 200/2000 --> total loss: 0.01109009


 15%|█▌        | 304/2000 [00:09<00:55, 30.38it/s]

Epoch: 300/2000 --> total loss: 0.01077992


 20%|██        | 404/2000 [00:13<00:51, 30.86it/s]

Epoch: 400/2000 --> total loss: 0.01066878


 25%|██▌       | 504/2000 [00:16<00:49, 30.34it/s]

Epoch: 500/2000 --> total loss: 0.01050695


 30%|███       | 604/2000 [00:19<00:46, 30.21it/s]

Epoch: 600/2000 --> total loss: 0.01038185


 35%|███▌      | 704/2000 [00:22<00:42, 30.47it/s]

Epoch: 700/2000 --> total loss: 0.01033161


 40%|████      | 804/2000 [00:26<00:39, 30.43it/s]

Epoch: 800/2000 --> total loss: 0.01021169


 45%|████▌     | 904/2000 [00:29<00:35, 30.56it/s]

Epoch: 900/2000 --> total loss: 0.01007387


 49%|████▉     | 976/2000 [00:31<00:33, 30.67it/s]


Epoch: 977/2000 --> loss: 0.00999301 < 0.01
Runtime --> total: 31.83sec (15.92ms/iter.)
Runtime: 50.50 sec
Complie Start
Complie End --> total: 18.25sec


  3%|▎         | 51/2000 [00:01<01:04, 30.41it/s]


Epoch: 52/2000 --> loss: 0.00996890 < 0.01
Runtime --> total: 1.68sec (0.84ms/iter.)
Runtime: 19.94 sec
Complie Start
Complie End --> total: 17.92sec


  2%|▏         | 40/2000 [00:01<01:05, 29.99it/s]


Epoch: 41/2000 --> loss: 0.00998980 < 0.01
Runtime --> total: 1.34sec (0.67ms/iter.)
Runtime: 19.27 sec
Complie Start
Complie End --> total: 19.86sec


  1%|▏         | 27/2000 [00:00<01:06, 29.73it/s]


Epoch: 28/2000 --> loss: 0.00993900 < 0.01
Runtime --> total: 0.92sec (0.46ms/iter.)
Runtime: 20.78 sec
Complie Start
Complie End --> total: 18.23sec


  1%|▏         | 29/2000 [00:00<01:07, 29.33it/s]


Epoch: 30/2000 --> loss: 0.00995449 < 0.01
Runtime --> total: 1.00sec (0.50ms/iter.)
Runtime: 19.22 sec
Complie Start
Complie End --> total: 18.53sec


  1%|▏         | 27/2000 [00:00<01:06, 29.66it/s]

Epoch: 28/2000 --> loss: 0.00999631 < 0.01
Runtime --> total: 0.92sec (0.46ms/iter.)
Runtime: 19.45 sec
Total Runtime: 771.21 sec





In [None]:
from cmspinn.spinn_jax_viz import spinn_cube

In [None]:
# Trained parameters are read from the "SPINN" directory; VTK files are
# written to a sibling "SPINN_vtk" directory, created here if missing.
spinn_output_path = str(Path(b_bottom_paths).parent / 'SPINN')
vtk_output_path = str(Path(b_bottom_paths).parent / 'SPINN_vtk')
Path(vtk_output_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Export one VTK file per input frame from the trained parameters.
start = time.time()

npy_pattern = os.path.join(b_bottom_paths, '*.npy')
for b_bottom_path in sorted(glob.glob(npy_pattern)):
    # Same frame label as used for the training output folders: the file name
    # without its '.npy' suffix and without its first 9 characters.
    b_bottom_date = Path(b_bottom_path).stem[9:]

    # Inputs written by the training run for this frame.
    target_path = os.path.join(spinn_output_path, b_bottom_date)
    parameters_path = os.path.join(target_path, 'parameters.pickle')
    final_model_path = os.path.join(target_path, 'final_params.pickle')

    vtk_path = os.path.join(vtk_output_path, f'B_spinn_{b_bottom_date}.vtk')

    spinn = spinn_cube(final_model_path, parameters_path)
    spinn.calculate_magnetic_fields()
    spinn.grid.save(vtk_path)

    print(vtk_path)

runtime = time.time() - start
print(f'Runtime: {runtime:.2f} sec')

/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_083600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_084800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_090000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_091200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_092400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_093600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_094800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_100000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_101200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_102400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_103600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_104800.vtk
/userhome/jeon_m