This notebook allows you to simply run the [JAX](https://jax.readthedocs.org) implementation of DeepRTE.

### Setup

The cell below downloads the code from Github and install necessary dependencies.

In [1]:
# ![ -d deeprte] || git clone --depth=1 https://github.com/mazhengcn/deeprte.git
# !cd deeprte && git pull
# !pip install -qr deeprte/requirements.txt

### Import packages

In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4,6"
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"

In [2]:
%load_ext autoreload
%autoreload 2

import jax
from matplotlib import pyplot as plt
import numpy as np
from jaxline import utils as jl_utils
import haiku as hk
from deeprte.train import Trainer
from deeprte.config import get_config

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
jax.local_devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

### Setup train config

In [4]:
# Path to the dataset
DATA_DIR = "/workspaces/deeprte/rte_data/matlab/train-scattering-kernel-0309"
DATA_NAMES = ["train_random_kernel_1.mat","train_random_kernel_2.mat"]

# define train params
num_epoch = 1
batch_size = 2
base_lr = 0.001
optimizer = "adam"
schedule_type = "exponential"
decay_rate = 0.96
transition_steps = 60000

In [5]:
# setup train config
config = get_config()
config = config.experiment_kwargs.config
config.dataset.source_dir = DATA_DIR
config.dataset.data_name_list = DATA_NAMES
config.dataset.train.batch_size = batch_size
config.training.num_epoch = num_epoch
config.training.optimizer.base_lr = base_lr
config.training.optimizer.schedule_type = schedule_type
config.training.optimizer.optimizer = optimizer
config.training.optimizer.decay_kwargs.decay_rate = decay_rate
config.training.optimizer.decay_kwargs.transition_steps = transition_steps

In [6]:
config.training

num_epoch: 1
num_epochs: 8000
optimizer:
  adam_kwargs: {}
  base_lr: 0.001
  decay_kwargs:
    decay_rate: 0.96
    transition_steps: 60000
  optimizer: adam
  scale_by_batch: false
  schedule_type: exponential

### Build model trainer

In [7]:
init_rng = 0
mode = "train"

e = Trainer(mode=mode, init_rng=init_rng, config=config)

### Train

In [9]:
rng = hk.PRNGSequence(42)
init_rng = jl_utils.bcast_local_devices(next(rng))
e._initialize_training()
# e._build_train_input()

In [None]:
scalars = []
for i in range(num_epoch):
    scalar = e.step(jl_utils.bcast_local_devices(i), jl_utils.bcast_local_devices(next(rng)))
    scalars["global_step"] = i
    scalars.append(scalar)

In [None]:
# total loss
total_loss = [scalar["train_total_loss"] for scalar in scalars]

# plt.plot(total_loss)

In [29]:
# restore params
params = e._params

### Evaluation

In [50]:
# define eval model
eval_model = jax.jit(lambda params, inputs: e.model.apply(params, None, next(rng), inputs, is_training=False, compute_loss=False,compute_metrics=True,))

In [47]:
from test import utils
from deeprte.model.tf.input_pipeline import load_tf_data
import jax.numpy as jnp

In [48]:
tf_data = load_tf_data(DATA_DIR, DATA_NAMES, normalization=False)
features = jax.tree_map(lambda x: jnp.array(x), tf_data)
data_feature = features[0]
batch = utils.slice_batch(0, data_feature)

# visualize shape
jax.tree_util.tree_map(lambda x: x.shape, batch)

{'boundary': (1, 1920),
 'boundary_coords': (1920, 4),
 'boundary_scattering_kernel': (1, 1920, 24),
 'boundary_weights': (1920,),
 'phase_coords': (38400, 4),
 'position_coords': (1600, 2),
 'psi_label': (1, 38400),
 'scattering_kernel': (1, 38400, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1600, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [51]:
ret = eval_model(params, batch)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at /tmp/ipykernel_2539128/4288000890.py:2 traced for pmap.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_2539128/4288000890.py:2 (<lambda>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3194 (run_cell_async)
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3373 (run_ast_nodes)
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3433 (run_code)
/tmp/ipykernel_2539128/2715311739.py:1 (<module>)
/tmp/ipykernel_2539128/4288000890.py:2 (<lambda>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [39]:
jax.tree_util.tree_map(lambda x: x.shape, ret)

({'metrics': {'mse': (2, 1), 'rmspe': (2, 1)},
  'predicted_solution': (2, 1, 120)},
 {})

In [43]:
ret[0]["metrics"]["rmspe"]

ShardedDeviceArray([[8.989382 ],
                    [0.9979299]], dtype=float32)