In [None]:
# Environment setup — uncomment when running in a fresh environment (e.g. Colab).
# NOTE(review): the NF2 install is unpinned; pin a commit/tag for reproducibility.
# `%pip` (rather than `!pip`) installs into the kernel's own environment.
# %pip install -q git+https://github.com/RobertJaro/NF2.git
# Downloads the ISEE NLFFF database cube (case 11158, 2011-02-13 12:00:00) that
# the data-loading cell reads with `load_nc`.
# !wget https://hinode.isee.nagoya-u.ac.jp/nlfff_database/v12/11158/20110213/11158_20110213_120000.nc

In [None]:
import os
# Restrict this process to GPU 0. This must run before `import torch` (and so
# before CUDA is initialised) to take effect, which is why it sits mid-imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LambdaCallback

from nf2.train.module import NF2Module, save
from nf2.evaluation.unpack import load_cube

# Local helpers (presumably this repository's utils/ package): .nc loading and
# an in-memory-array data module.
from utils.data_load import load_nc
from utils.data_loader import ArrayDataModule

In [None]:
# Output directory for the trained model, Lightning checkpoints and W&B files.
base_path = 'pinn'
os.makedirs(base_path, exist_ok=True)
save_path = os.path.join(base_path, 'extrapolation_result.nf2')

# Data / boundary settings, forwarded as keyword arguments to ArrayDataModule.
# `b_norm` is also used as the Gauss-per-unit factor for validation, and
# `spatial_norm` (times Mm_per_pixel) as the Mm-per-unit length factor.
# Batch sizes are plain ints: float literals like 1e4 fail anywhere a batch size
# is used directly as a tensor size, range bound or DataLoader batch_size, while
# any downstream int(...) cast yields the same value either way.
data_args = {
    "bin": 1,
    "height_mapping": {"z": [0.0]},
    "Mm_per_pixel": 1,
    "boundary": {"type": "open"},
    "height": 257,  # presumably matches the 257 z-levels (axis 2) of the reference cube
    "b_norm": 2500,
    "spatial_norm": 320,
    "batch_size": {"boundary": 10_000, "random": 20_000},
    "iterations": 100000,
    "work_directory": "tmp/isee_11158",
    "num_workers": 8,
}

# Network settings, forwarded as keyword arguments to NF2Module.
model_args = {
    "dim": 256,
    "use_vector_potential": False,
}

# Loss weights / schedules / optimiser settings, forwarded to NF2Module.
# `validation_interval` also drives checkpointing and the Trainer's val_check_interval.
training_args = {
    "max_epochs": 1,
    "lambda_b": {"start": 1e3, "end": 1, "iterations": 5e4},
    "lambda_div": 1e-1,
    "lambda_ff": 1e-1,
    "lambda_height_reg": 1e-3,
    "validation_interval": 10000,
    "lr_params": {"start": 5e-4, "end": 5e-5, "decay_iterations": 1e5},
}

# Full run configuration, stored alongside the model by `save`.
config = {'data': data_args, 'model': model_args, 'training': training_args}

In [None]:
# Reference cube from the downloaded .nc file, plus its bottom layer (index 0 of
# axis 2) kept as a length-1 axis; only that layer is handed to the data module.
b_true = load_nc('11158_20110213_120000.nc')
b_bottom = b_true[:, :, :1, :]

b_true.shape, b_bottom.shape

((513, 257, 257, 3), (513, 257, 1, 3))

In [None]:
# Weights & Biases logger for this run. log_model="all" uploads every checkpoint
# Lightning writes during training, not only the final one.
wandb_kwargs = dict(
    project="nf2",
    name="11158_20110213_120000",
    entity="mgjeon",
    id=None,
    offline=False,
    dir=base_path,
    log_model="all",
)
wandb_logger = WandbLogger(**wandb_kwargs)

In [None]:
# Data module built from the bottom boundary layer only; every other setting is
# forwarded from data_args.
data_module = ArrayDataModule(b_bottom, **data_args)

In [None]:
# Settings NF2Module uses when validating: the coordinate grid shape plus the
# factors that presumably map normalised field/length units back to Gauss / Mm.
validation_settings = dict(
    cube_shape=data_module.cube_dataset.coords_shape,
    gauss_per_dB=data_args['b_norm'],
    Mm_per_ds=data_module.Mm_per_pixel * data_args['spatial_norm'],
)

nf2 = NF2Module(validation_settings, **model_args, **training_args)

In [None]:
def _save_extrapolation(*_):
    """Write the current model, data module and config to ``save_path``.

    Registered as ``on_validation_end``; the (trainer, module) arguments Lightning
    passes are ignored, and ``nf2`` is looked up at call time.
    """
    save(save_path, nf2.model, data_module, config, nf2.height_mapping_model)


save_callback = LambdaCallback(on_validation_end=_save_extrapolation)

# Lightning checkpoint every `validation_interval` steps, plus a `last`
# checkpoint (save_last=True) that the fit cell resumes from.
checkpoint_callback = ModelCheckpoint(
    dirpath=base_path,
    every_n_train_steps=training_args['validation_interval'],
    save_last=True,
)

In [None]:
n_gpus = torch.cuda.device_count()

# Pass accelerator/devices/strategy only when they differ from the Trainer's
# own defaults. Explicit `None` is just the default on PyTorch Lightning 1.x,
# but 2.x defaults these to "auto" and may reject None; omitting the keys
# works on both.
hardware_kwargs = {}
if n_gpus >= 1:
    hardware_kwargs.update(accelerator='gpu', devices=n_gpus)
if n_gpus > 1:
    # NOTE(review): the 'dp' strategy was removed in PyTorch Lightning 2.0.
    # With CUDA_VISIBLE_DEVICES='0' set in the imports cell, n_gpus <= 1 here.
    hardware_kwargs['strategy'] = 'dp'

trainer = Trainer(max_epochs=training_args['max_epochs'],  # single source of truth: config
                  logger=wandb_logger,
                  num_sanity_val_steps=0,
                  val_check_interval=training_args['validation_interval'],
                  gradient_clip_val=0.1,
                  callbacks=[checkpoint_callback, save_callback],
                  **hardware_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the model. ckpt_path='last' asks Lightning to resume from the most recent
# `last` checkpoint (checkpoint_callback uses save_last=True in base_path).
# A previous full run of this cell reported: Runtime --> total: 8913.14sec (~2.5 h).
start = time.perf_counter()  # monotonic clock: unaffected by system clock changes

trainer.fit(nf2, data_module, ckpt_path='last')

runtime = time.perf_counter() - start
print(f'Runtime --> total: {runtime:.2f}sec')

In [None]:
# Write the final trained model (with data module and config) to save_path —
# the same call save_callback makes after each validation run.
save(save_path, nf2.model, data_module, config, nf2.height_mapping_model)