# SPINN (series)

In [None]:
from setproctitle import setproctitle

# Label this process so it is easy to spot in process listings such as `ps`/`top`.
_PROCESS_TITLE = "SPINN (series)"
setproctitle(_PROCESS_TITLE)

In [None]:
import os

In [None]:
# GPU/XLA configuration. These variables are read when JAX/CUDA initialize,
# so this cell has to run before `import jax`.
os.environ.update({
    # Allocate GPU memory on demand instead of preallocating most of it.
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    # Number devices by PCI bus ID so the index below matches nvidia-smi.
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    # Expose only GPU 2 to this process.
    "CUDA_VISIBLE_DEVICES": "2",
})

In [None]:
import jax 
import jax.numpy as jnp
import optax
import numpy as np 
import matplotlib.pyplot as plt
import time 
import pickle
from tqdm import trange

In [None]:
from cmspinn.spinn_jax import SPINN3d, generate_train_data, apply_model_spinn, update_model

In [None]:
import glob
import time
from pathlib import Path

In [None]:
# Input magnetograms live in <data_dir>/b_bottom/*.npy; SPINN results go to <data_dir>/SPINN.
b_bottom_paths = os.path.expanduser('~/workspace/_data/NOAA12673/b_bottom')
_data_dir = Path(b_bottom_paths).parent

b_bottom_list = sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy')))

spinn_output_path = os.path.join(_data_dir, 'SPINN')
os.makedirs(spinn_output_path, exist_ok=True)

In [None]:
# Pre-computed boundary-condition pickles (one per snapshot) sit next to b_bottom/.
spinn_BC_path = str(Path(b_bottom_paths).parent / 'SPINN_BC')

In [None]:
class SPINN_Trainer:
    """Build and train a ``SPINN3d`` model for a single magnetogram snapshot.

    The constructor builds the model, the Adam optimizer and the training data,
    and writes the hyper-parameters to ``<output_path>/parameters.pickle``.
    :meth:`train` runs the optimization loop and writes checkpoints.

    Args:
        output_path: Directory for ``parameters.pickle``, periodic
            ``params_<iter>.pickle`` checkpoints, ``final_params.pickle`` and
            ``losses.npy``. Created if it does not exist.
        BC_path: Pickle file whose content is passed to ``apply_model_spinn``
            together with the sampled training data.
        b_bottom: Bottom-boundary array unpacked as ``(Nx, Ny, _)``; only its
            first two extents are used here.
        Nz: Number of grid points along the third (z) axis.
        b_norm: Normalization constant; this class only records it in
            ``parameters.pickle``.
        transfer_learning_path: Optional pickle of previously trained
            parameters used as the starting point instead of a fresh init.
    """

    def __init__(self, output_path, BC_path, b_bottom, Nz, b_norm, transfer_learning_path=None):
        os.makedirs(output_path, exist_ok=True)

        Nx, Ny, _ = b_bottom.shape

        # Network / optimizer hyper-parameters.
        features = 256
        n_layers = 8
        feat_sizes = (features,) * n_layers
        r = 128
        out_dim = 3
        lr = 5e-4
        pos_enc = 0
        mlp = 'modified_mlp'

        # Saved so the model can be rebuilt later from the pickled parameters.
        parameters = {'feat_sizes': feat_sizes,
                      'r': r,
                      'out_dim': out_dim,
                      'Nx': Nx,
                      'Ny': Ny,
                      'Nz': Nz,
                      'b_norm': b_norm,
                      'pos_enc': pos_enc,
                      'mlp': mlp,
                      'lr': lr}

        parameters_path = os.path.join(output_path, "parameters.pickle")
        with open(parameters_path, "wb") as f:
            pickle.dump(parameters, f)

        seed = 111
        key = jax.random.PRNGKey(seed)
        key, init_key = jax.random.split(key, 2)
        # JAX keys must not be reused for independent random draws, so sampling
        # the training data gets its own key instead of sharing init_key.
        key, data_key = jax.random.split(key, 2)

        model = SPINN3d(feat_sizes, r, out_dim, pos_enc=pos_enc, mlp=mlp)
        if transfer_learning_path is None:
            params = model.init(
                init_key,
                jnp.ones((Nx, 1)),
                jnp.ones((Ny, 1)),
                jnp.ones((Nz, 1)),
            )
        else:
            # Warm start from an earlier run. pickle can execute arbitrary code,
            # so only load checkpoints produced by this pipeline.
            with open(transfer_learning_path, 'rb') as f:
                params = pickle.load(f)

        apply_fn = jax.jit(model.apply)
        optim = optax.adam(learning_rate=lr)
        state = optim.init(params)

        with open(BC_path, 'rb') as f:
            boundary_data = pickle.load(f)

        train_data = generate_train_data(data_key, Nx, Ny, Nz)

        self.apply_fn = apply_fn
        self.params = params
        self.train_boundary_data = [train_data, boundary_data]
        self.optim = optim
        self.state = state
        self.output_path = output_path

    def train(self, total_iterations, log_iterations, loss_threshold=0.001):
        """Optimize for up to ``total_iterations`` steps, stopping early on a small loss.

        One extra step is run before the timed loop so that JIT compilation is
        reported separately. Training stops as soon as a step's loss is below
        ``loss_threshold``; that step's update is not applied.

        Every ``log_iterations`` steps the loss is printed and the current
        parameters are written to ``params_<step>.pickle``. A falsy
        ``log_iterations`` disables these periodic checkpoints. At the end the
        parameters go to ``final_params.pickle`` and every recorded loss
        (including the compile step) goes to ``losses.npy``. The trained
        parameters and optimizer state are also stored back on ``self``.
        """
        params = self.params
        state = self.state

        losses = []
        print('Compile Start')
        start = time.time()
        loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
        losses.append(loss.item())
        params, state = update_model(self.optim, gradient, params, state)
        runtime = time.time() - start
        print(f'Compile End --> total: {runtime:.2f}sec')

        n_iterations = 0  # loop steps actually executed (differs on early stop)
        start = time.time()
        for e in trange(1, total_iterations + 1):
            n_iterations = e

            loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
            loss_value = loss.item()
            losses.append(loss_value)
            if loss_value < loss_threshold:
                print(f'Epoch: {e}/{total_iterations} --> loss: {loss_value:.8f} < {loss_threshold}')
                break

            params, state = update_model(self.optim, gradient, params, state)

            if log_iterations and e % log_iterations == 0:
                print(f'Epoch: {e}/{total_iterations} --> total loss: {loss_value:.8f}')
                params_path = os.path.join(self.output_path, f"params_{e}.pickle")
                with open(params_path, "wb") as f:
                    pickle.dump(params, f)

        final_params_path = os.path.join(self.output_path, "final_params.pickle")
        with open(final_params_path, "wb") as f:
            pickle.dump(params, f)

        np.save(os.path.join(self.output_path, 'losses.npy'), losses)

        # Keep the trained state so the instance reflects what was written to disk.
        self.params = params
        self.state = state

        runtime = time.time() - start
        # Divide by the steps actually run: the loop may stop early, and
        # total_iterations may be 0 or 1.
        ms_per_iter = runtime / max(n_iterations, 1) * 1000
        print(f'Runtime --> total: {runtime:.2f}sec ({ms_per_iter:.2f}ms/iter.)')

In [None]:
# Grid size along z and the normalization constant recorded with each run.
Nz = 160
b_norm = 2_500

# First snapshot: trained from scratch with the full budget.
epochs, log_iter = 50_000, 1_000

# Every later snapshot: warm-started from the previous one, shorter budget.
series_epochs, series_log_iter = 2_000, 100

# Training stops early once the loss falls below this value.
loss_threshold = 0.006

In [None]:
# Train snapshots in chronological (sorted-filename) order. The first one starts
# from scratch; each later one is warm-started from the previous snapshot's
# final parameters. Snapshots that already have final parameters are skipped,
# which makes the loop resumable after an interruption.
transfer_learning_path = None

start_time = time.time()

for b_bottom_path in sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy'))):
    # Drop a 9-character filename prefix (presumably 'b_bottom_') and '.npy'.
    b_bottom_date = os.path.basename(b_bottom_path)[9:-4]

    output_path = os.path.join(spinn_output_path, b_bottom_date)
    final_params_path = os.path.join(output_path, 'final_params.pickle')

    if os.path.exists(final_params_path):
        # Already trained: reuse its result as the next warm start.
        transfer_learning_path = final_params_path
        continue

    # Load the input only for snapshots that actually need training.
    with open(b_bottom_path, 'rb') as f:
        b_bottom = np.load(f)

    BC_path = os.path.join(spinn_BC_path, f'b_BC_{b_bottom_date}.pickle')

    if transfer_learning_path is None:
        n_epochs, n_log_iter = epochs, log_iter
    else:
        n_epochs, n_log_iter = series_epochs, series_log_iter

    trainer = SPINN_Trainer(output_path, BC_path, b_bottom, Nz, b_norm,
                            transfer_learning_path=transfer_learning_path)
    start = time.time()
    trainer.train(n_epochs, n_log_iter, loss_threshold=loss_threshold)
    runtime = time.time() - start
    print(f'Runtime: {runtime:.2f} sec')

    transfer_learning_path = final_params_path

print(f'Total Runtime: {time.time() - start_time:.2f} sec')

Complie Start
Complie End --> total: 63.27sec


  1%|          | 270/50000 [00:13<39:01, 21.23it/s]

In [None]:
from cmspinn.spinn_jax_viz import spinn_cube

In [None]:
# Trained models are read from <data_dir>/SPINN; VTK exports go to <data_dir>/SPINN_vtk.
_data_dir = Path(b_bottom_paths).parent
spinn_output_path = str(_data_dir / 'SPINN')
vtk_output_path = str(_data_dir / 'SPINN_vtk')
Path(vtk_output_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Export the field predicted by each trained snapshot to a VTK file.
start = time.time()

for b_bottom_path in sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy'))):
    # Drop a 9-character filename prefix (presumably 'b_bottom_') and '.npy'.
    b_bottom_date = os.path.basename(b_bottom_path)[9:-4]

    target_path = os.path.join(spinn_output_path, b_bottom_date)
    final_model_path = os.path.join(target_path, 'final_params.pickle')
    parameters_path = os.path.join(target_path, 'parameters.pickle')

    # Training can be interrupted before every snapshot has final parameters.
    # Skip those snapshots instead of aborting the whole export.
    if not os.path.exists(final_model_path):
        print(f'Skipping {b_bottom_date}: {final_model_path} not found')
        continue

    vtk_path = os.path.join(vtk_output_path, f'B_spinn_{b_bottom_date}.vtk')

    spinn = spinn_cube(final_model_path, parameters_path)
    spinn.calculate_magnetic_fields()
    spinn.grid.save(vtk_path)

    print(vtk_path)

runtime = time.time() - start
print(f'Runtime: {runtime:.2f} sec')

/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_083600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_084800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_090000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_091200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_092400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_093600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_094800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_100000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_101200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_102400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_103600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_104800.vtk
/userhome/jeon_m