# SPINN (series)

In [None]:
#| default_exp spinn_jax_trainer

In [None]:
from setproctitle import setproctitle
# Give this long-running training process a recognisable name in process listings (e.g. `ps`, `top`).
setproctitle("SPINN (series)")

In [None]:
#| export
import os

In [None]:
# Runtime configuration for XLA/CUDA. This must execute before `import jax` (next cell)
# for the settings to take effect.
os.environ.update(
    XLA_PYTHON_CLIENT_PREALLOCATE="false",  # disable JAX's up-front GPU memory preallocation
    CUDA_DEVICE_ORDER="PCI_BUS_ID",         # enumerate GPUs by PCI bus id
    CUDA_VISIBLE_DEVICES="2",               # expose only GPU #2 to this process
)

In [None]:
#| export
import jax 
import jax.numpy as jnp
import optax
import numpy as np 
import matplotlib.pyplot as plt
import time 
import pickle
from tqdm import trange

In [None]:
#| export
from cmspinn.spinn_jax import SPINN3d, generate_train_data, apply_model_spinn, update_model

In [None]:
#| export
import glob
import time
from pathlib import Path

In [None]:
# Input magnetograms (one .npy file per snapshot) and the root directory for trained SPINN models,
# which sits next to the input directory.
b_bottom_paths = os.path.expanduser('~/workspace/_data/NOAA12673/b_bottom')
b_bottom_list = sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy')))

spinn_output_path = os.path.join(os.fspath(Path(b_bottom_paths).parent), 'SPINN')
os.makedirs(spinn_output_path, exist_ok=True)

In [None]:
# Boundary-condition pickles (b_BC_<date>.pickle) live in a sibling directory of b_bottom.
spinn_BC_path = str(Path(b_bottom_paths).parent / 'SPINN_BC')

In [None]:
import logging

In [None]:
# Route INFO-level messages from the root logger to both a log file and the console.
# Existing handlers are detached *and closed* first, so re-running this cell neither
# duplicates output nor leaks open log-file handles.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for hdlr in logger.handlers[:]:  # iterate over a copy: removeHandler mutates the list
    logger.removeHandler(hdlr)
    hdlr.close()

# Same data root as the input paths above (expanded from '~' rather than hard-coded),
# created if missing so FileHandler does not fail on a fresh machine.
log_dir = os.path.expanduser('~/workspace/_data')
os.makedirs(log_dir, exist_ok=True)
logger.addHandler(logging.FileHandler(os.path.join(log_dir, "info_series_log.log")))  # file handler
logger.addHandler(logging.StreamHandler())  # console handler

In [None]:
#| export
class SPINN_Trainer:
    """Fit a ``SPINN3d`` model to one magnetogram snapshot and checkpoint it.

    The constructor builds (or warm-starts) the model parameters, an Adam
    optimizer and the training data, and writes ``parameters.pickle`` to
    ``output_path``. :meth:`train` runs the optimisation loop and writes
    ``params_<iter>.pickle`` checkpoints, ``final_params.pickle`` and
    ``losses.npy`` into the same directory.
    """

    def __init__(self, output_path, BC_path, b_bottom, Nz, b_norm, transfer_learning_path=None, logger=None):
        """Set up model parameters, optimizer state and training data.

        Args:
            output_path: Directory for ``parameters.pickle`` and all checkpoints;
                created if it does not exist.
            BC_path: Pickle file with the boundary data that is passed, together
                with the generated collocation data, to ``apply_model_spinn``.
            b_bottom: Bottom-boundary array. It must be 3-D; only its first two
                dimensions (``Nx``, ``Ny``) are used here.
            Nz: Number of grid points along z.
            b_norm: Normalisation constant. In this class it is only recorded in
                ``parameters.pickle``.
            transfer_learning_path: Optional pickle of previously trained
                parameters to start from instead of a fresh initialisation.
            logger: Optional ``logging.Logger``. When ``None``, messages are
                printed to stdout instead.
        """
        os.makedirs(output_path, exist_ok=True)
        # Set before any logging so that _log() works during construction.
        self.logger = logger
        self.output_path = output_path

        Nx, Ny, _ = b_bottom.shape

        # Network / optimizer hyper-parameters.
        features = 512
        n_layers = 8
        feat_sizes = (features,) * n_layers
        r = 512
        out_dim = 3
        lr = 5e-4
        pos_enc = 0
        mlp = 'modified_mlp'

        # Domain extents normalised by Nx (so x always maps to 2.0); passed to
        # generate_train_data and stored for later reconstruction.
        n_max_x = 2*(Nx/Nx)
        n_max_y = 2*(Ny/Nx)
        n_max_z = 2*(Nz/Nx)

        parameters = {
            'feat_sizes': feat_sizes,
            'r': r,
            'out_dim': out_dim,
            'Nx': Nx,
            'Ny': Ny,
            'Nz': Nz,
            'b_norm': b_norm,
            'pos_enc': pos_enc,
            'mlp': mlp,
            'lr': lr,
            'n_max_x': n_max_x,
            'n_max_y': n_max_y,
            'n_max_z': n_max_z,
        }
        self._log(parameters)

        parameters_path = os.path.join(output_path, "parameters.pickle")
        with open(parameters_path, "wb") as f:
            pickle.dump(parameters, f)

        seed = 111
        key = jax.random.PRNGKey(seed)
        key, subkey = jax.random.split(key, 2)
        # NOTE(review): ``subkey`` is used both for parameter init and for
        # generate_train_data below; JAX recommends a fresh key per consumer.
        # Kept unchanged so existing runs stay reproducible.

        model = SPINN3d(feat_sizes, r, out_dim, pos_enc=pos_enc, mlp=mlp)
        if transfer_learning_path is None:
            params = model.init(
                subkey,
                jnp.ones((Nx, 1)),
                jnp.ones((Ny, 1)),
                jnp.ones((Nz, 1)),
            )
        else:
            # Warm start from a previous snapshot's parameters. pickle can run
            # arbitrary code on load: only point this at files this pipeline wrote.
            with open(transfer_learning_path, 'rb') as f:
                params = pickle.load(f)
        apply_fn = jax.jit(model.apply)
        optim = optax.adam(learning_rate=lr)
        state = optim.init(params)

        with open(BC_path, 'rb') as f:
            boundary_data = pickle.load(f)

        train_data = generate_train_data(subkey, Nx, Ny, Nz, n_max_x, n_max_y, n_max_z)

        self.apply_fn = apply_fn
        self.params = params
        self.train_boundary_data = [train_data, boundary_data]
        self.optim = optim
        self.state = state

    def _log(self, msg):
        """Send ``msg`` to ``self.logger`` at INFO level, or print it when no logger is set."""
        if self.logger is None:
            print(msg)
        else:
            self.logger.info(msg)

    @staticmethod
    def _runtime_message(runtime, n_iterations):
        """Format the end-of-training timing summary.

        ``n_iterations`` is the number of loop iterations actually executed. The
        per-iteration figure is omitted when it is not positive, which avoids a
        division by zero.
        """
        if n_iterations <= 0:
            return f'Runtime --> total: {runtime:.2f}sec'
        return f'Runtime --> total: {runtime:.2f}sec ({(runtime / n_iterations * 1000):.2f}ms/iter.)'

    def train(self, total_iterations, log_iterations, loss_threshold=0.001):
        """Run optimizer updates until ``total_iterations`` or until ``loss < loss_threshold``.

        One warm-up step (loss + update) runs before the loop and is timed
        separately, so JIT compilation is not included in the loop timing.

        Args:
            total_iterations: Maximum number of loop iterations after the warm-up step.
            log_iterations: Every this many iterations, the loss is logged and the
                parameters are saved to ``params_<iter>.pickle``. A non-positive
                value disables periodic logging/checkpoints.
            loss_threshold: Training stops, without applying that step's update,
                as soon as the loss drops below this value.

        Side effects:
            Writes ``final_params.pickle`` and ``losses.npy`` (the warm-up loss
            followed by one loss per iteration) to ``self.output_path``. It also
            stores the trained parameters and optimizer state back on the
            instance, so a later call continues from them.
        """
        params = self.params
        state = self.state
        losses = []

        self._log('Compile Start')
        start = time.time()
        loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
        losses.append(loss.item())
        params, state = update_model(self.optim, gradient, params, state)
        runtime = time.time() - start
        self._log(f'Compile End --> total: {runtime:.2f}sec')

        n_iterations = 0  # iterations actually executed (may stop early)
        start = time.time()
        for e in trange(1, total_iterations + 1):
            n_iterations = e
            loss, gradient = apply_model_spinn(self.apply_fn, params, self.train_boundary_data)
            loss_value = loss.item()
            losses.append(loss_value)
            if loss_value < loss_threshold:
                self._log(f'Epoch: {e}/{total_iterations} --> loss: {loss_value:.8f} < {loss_threshold}')
                break

            params, state = update_model(self.optim, gradient, params, state)

            if log_iterations > 0 and e % log_iterations == 0:
                self._log(f'Epoch: {e}/{total_iterations} --> total loss: {loss_value:.8f}')
                params_path = os.path.join(self.output_path, f"params_{e}.pickle")
                with open(params_path, "wb") as f:
                    pickle.dump(params, f)

        final_params_path = os.path.join(self.output_path, "final_params.pickle")
        with open(final_params_path, "wb") as f:
            pickle.dump(params, f)

        np.save(os.path.join(self.output_path, 'losses.npy'), losses)

        runtime = time.time() - start
        self._log(self._runtime_message(runtime, n_iterations))

        # Keep the trained state on the instance so a subsequent train() resumes from it.
        self.params = params
        self.state = state

In [None]:
# Grid / normalisation settings passed to every SPINN_Trainer.
Nz = 160
b_norm = 2500

# Iterations and log interval when training WITHOUT a warm start.
epochs, log_iter = 50_000, 1_000

# Iterations and log interval when warm-starting from the previous snapshot.
series_epochs, series_log_iter = 2_000, 100

# Training stops early once the loss falls below this value.
loss_threshold = 5e-3

In [None]:
# Train every snapshot in sorted (file-name) order.
# - With no previous result to warm-start from, a snapshot is trained for `epochs` iterations.
# - Otherwise it is warm-started from the previous snapshot's final parameters and trained
#   for `series_epochs` iterations.
# - Snapshots that already have final_params.pickle are skipped (so an interrupted run can
#   resume), but they still serve as the warm start for the next snapshot.
transfer_learning_path = None

start_time = time.time()

for b_bottom_path in sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy'))):
    # Strip a 9-character file-name prefix and the '.npy' suffix to get the snapshot date
    # (presumably files are named 'b_bottom_<date>.npy' -- confirm).
    b_bottom_date = os.path.basename(b_bottom_path)[9:-4]

    output_path = os.path.join(spinn_output_path, b_bottom_date)
    final_params_path = os.path.join(output_path, 'final_params.pickle')

    # Check for an existing result before loading any data for this snapshot.
    if os.path.exists(final_params_path):
        transfer_learning_path = final_params_path
        continue

    b_bottom = np.load(b_bottom_path)
    BC_path = os.path.join(spinn_BC_path, f'b_BC_{b_bottom_date}.pickle')

    if transfer_learning_path is None:
        n_epochs, n_log_iter = epochs, log_iter
    else:
        n_epochs, n_log_iter = series_epochs, series_log_iter

    trainer = SPINN_Trainer(output_path, BC_path, b_bottom, Nz, b_norm,
                            transfer_learning_path=transfer_learning_path, logger=logger)
    start = time.time()
    trainer.train(n_epochs, n_log_iter, loss_threshold=loss_threshold)
    runtime = time.time() - start
    logger.info(f'Runtime: {runtime:.2f} sec')

    # The next snapshot warm-starts from this one.
    transfer_learning_path = final_params_path

logger.info(f'Total Runtime: {time.time() - start_time:.2f} sec')

{'feat_sizes': (512, 512, 512, 512, 512, 512, 512, 512), 'r': 512, 'out_dim': 3, 'Nx': 344, 'Ny': 224, 'Nz': 160, 'b_norm': 2500, 'pos_enc': 0, 'mlp': 'modified_mlp', 'lr': 0.0005, 'n_max_x': 2.0, 'n_max_y': 1.302325581395349, 'n_max_z': 0.9302325581395349}
Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
Complie Start
Complie End --> total: 43.42sec
  2%|▏         | 998/50000 [01:15<1:01:33, 13.27it/s]Epoch: 1000/50000 --> total loss: 0.01732743
  4%|▍         | 1998/50000 [02:30<1:00:28, 13.23it/s]Epoch: 2000/50000 --> total loss: 0.01275583
  6%|▌         | 2998/50000 [03:46<59:32, 13.16it/s]  Epoch: 3000/50000 --> total loss: 0.01144158
  8%|▊         | 3998/50000 [05:02<58:10, 13.18it/s]  Epoch: 4000/50000 --> total loss: 0.00828985
 10%|▉         | 4998/50000 [06:18<56:59, 13.16it/s]  Epoch: 5000/50000 --> total loss: 0.00736446
 12%|█▏        

In [None]:
from cmspinn.spinn_jax_viz import spinn_cube

In [None]:
# Directories next to b_bottom: trained SPINN checkpoints (read) and VTK exports (written).
spinn_output_path, vtk_output_path = (
    os.path.join(Path(b_bottom_paths).parent, subdir) for subdir in ('SPINN', 'SPINN_vtk')
)
os.makedirs(vtk_output_path, exist_ok=True)

In [None]:
# Export each trained snapshot's magnetic field to a VTK file.
start = time.time()

for b_bottom_path in sorted(glob.glob(os.path.join(b_bottom_paths, '*.npy'))):
    # Same date extraction as in the training loop (strip 9-char prefix and '.npy').
    b_bottom_date = os.path.basename(b_bottom_path)[9:-4]

    target_path = os.path.join(spinn_output_path, b_bottom_date)
    final_model_path = os.path.join(target_path, 'final_params.pickle')
    parameters_path = os.path.join(target_path, 'parameters.pickle')

    if not os.path.exists(final_model_path):
        # Training for this snapshot has not finished (or never ran): skip it instead of
        # aborting the export of all remaining snapshots.
        print(f'Skipping {b_bottom_date}: {final_model_path} not found')
        continue

    vtk_path = os.path.join(vtk_output_path, f'B_spinn_{b_bottom_date}.vtk')

    spinn = spinn_cube(final_model_path, parameters_path)
    spinn.calculate_magnetic_fields()
    spinn.grid.save(vtk_path)

    print(vtk_path)

runtime = time.time() - start
print(f'Runtime: {runtime:.2f} sec')

/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_083600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_084800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_090000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_091200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_092400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_093600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_094800.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_100000.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_101200.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_102400.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_103600.vtk
/userhome/jeon_mg/workspace/_data/NOAA12673/SPINN_vtk/B_spinn_20170906_104800.vtk
/userhome/jeon_m

In [None]:
# Write a separator line via the root logger to mark the end of this notebook run.
logger.info("===============================================================")

