In [None]:
import jax
from flax import nnx
from jax.sharding import Mesh

from deeprte.configs import default
from deeprte.input_pipeline import input_pipeline_interface
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics
from deeprte.train_lib import utils
from deeprte.train_lib.multihost_dataloading import prefetch_to_device

In [None]:
# Seed for the shared `rng` stream, used for model init and random test inputs below.
SEED = 42
rng = nnx.Rngs(SEED)
config = default.get_config()

# Build the device mesh from the config and name its axes with `config.mesh_axes`.
device_array = utils.create_device_mesh(config)
global_mesh = Mesh(device_array, config.mesh_axes)
global_mesh

In [None]:
# Buffer size passed to `prefetch_to_device` when prefetching is enabled.
PREFETCH_BUFFER_SIZE = 4

train_iter, eval_iter = input_pipeline_interface.create_data_iterator(
    config, global_mesh
)

if config.prefetch_to_device:
    train_iter = prefetch_to_device(train_iter, PREFETCH_BUFFER_SIZE)

# Keep the inspected batch under a name so later cells can reuse it
# instead of discarding it after display.
batch = next(train_iter)
jax.tree.map(lambda x: (x.shape, x.sharding), batch)

## Test DeepRTE

In [None]:
# Instantiate the full DeepRTE model from `config`, passing the shared `rng`
# stream as its `rngs`.
# NOTE: the name `deeprte` matches the package name; only `from deeprte...`
# imports are used above, so nothing is shadowed in practice.
deeprte = modules.DeepRTE(config, rngs=rng)

# Show the module structure for inspection.
nnx.display(deeprte)

### Parameter count

In [None]:
# Split the model into a static graph definition and its full state; later
# cells rebuild the module from these inside jitted functions.
graphdef, params = nnx.split(deeprte)

# Count parameters only: the full state above may also hold non-parameter
# variables (e.g. RNG keys/counters) that should not be counted.
# `sum` also returns 0 on an empty tree, where `jax.tree.reduce` would raise.
num_params = sum(x.size for x in jax.tree.leaves(nnx.state(deeprte, nnx.Param)))
num_params

#### Case 1: train mode (`low_memory=False`)

A jitted forward pass only; no gradients are computed.

In [None]:
@jax.jit
def train_step(params, batch):
    """Jitted forward pass with ``low_memory=False`` ("train mode").

    Rebuilds a module from the global ``graphdef`` and the given ``params``
    state, sets ``low_memory=False`` on that rebuilt copy, and applies it to
    ``batch``. No gradients are computed here.
    """
    merged = nnx.merge(graphdef, params)
    merged.set_attributes(low_memory=False)
    outputs = merged(batch)
    return outputs

In [None]:
%time train_step(params, next(train_iter))

#### Case 2: evaluation / inference mode (`low_memory=True`)

In [None]:
@jax.jit
def eval_step(params, batch):
    """Jitted forward pass with ``low_memory=True`` (evaluation / inference).

    Rebuilds a module from the global ``graphdef`` and the given ``params``
    state, sets ``low_memory=True`` on that rebuilt copy, and applies it to
    ``batch``.
    """
    merged = nnx.merge(graphdef, params)
    merged.set_attributes(low_memory=True)
    outputs = merged(batch)
    return outputs

In [None]:
%%time

prediction = eval_step(params, batch)

print(prediction.shape)

## Test Green's function

In [None]:
# Fields that are additionally sliced to their first leading-axis entry.
PER_POINT_KEYS = ("phase_coords", "scattering_kernel")

# First example of the batch; for the per-point fields, keep only entry [0].
single_example = jax.tree.map(lambda x: x[0], batch)
one_point_example = {
    key: value[0] if key in PER_POINT_KEYS else value
    for key, value in single_example.items()
}

jax.tree.map(lambda x: x.shape, one_point_example)

In [None]:
green_fn = modules.GreenFunction(config=config, rngs=nnx.Rngs(0))

# Map over axis 0 of `boundary_coords` only; the phase coordinates and the
# example dict are broadcast (in_axes=None). The mapped axis is placed last.
batched_green_fn = jax.vmap(green_fn, in_axes=(None, 0, None), out_axes=-1)
phase_coords = one_point_example["phase_coords"]
boundary_coords = one_point_example["boundary_coords"]
y = batched_green_fn(phase_coords, boundary_coords, one_point_example)

y.shape

## Test Attenuation module

In [None]:
# Sizes of the random test inputs.
num_pairs, coord_dim = 5, 4
num_grid, grid_dim = 50, 2

# NOTE: draw order matters, because each `rng()` call advances the shared stream.
coord1 = jax.random.uniform(rng(), [num_pairs, coord_dim])
coord2 = jax.random.uniform(rng(), [num_pairs, coord_dim])
# NOTE(review): att_coeff is drawn with the same [50, 2] shape as the grid;
# confirm this is the layout `Attenuation` expects.
att_coeff = jax.random.uniform(rng(), [num_grid, 2])

grid = jax.random.uniform(rng(), [num_grid, grid_dim])
char = Characteristics.from_tensor(grid)

attenuation = modules.Attenuation(config=config, rngs=nnx.Rngs(0))
# Map over the leading axis of coord1/coord2; att_coeff and char are shared.
batched_attenuation = jax.vmap(attenuation, in_axes=(0, 0, None, None))
batched_attenuation(coord1, coord2, att_coeff, char).shape