In [2]:
import jax
from flax import nnx
from jax.sharding import Mesh

from deeprte.configs import default
from deeprte.input_pipeline import input_pipeline_interface
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics
from deeprte.train_lib import utils
from deeprte.train_lib.multihost_dataloading import prefetch_to_device

In [3]:
# Seed for the default nnx.Rngs stream used throughout this notebook.
SEED = 42
rng = nnx.Rngs(SEED)
config = default.get_config()

# Build the device mesh described by the config; axis names come from config.mesh_axes.
device_array = utils.create_device_mesh(config)
global_mesh = Mesh(device_array, config.mesh_axes)
global_mesh

Mesh(device_ids=array([[[0],
        [1]]]), axis_names=('data', 'fsdp', 'tensor'))

In [5]:
train_iter, eval_iter = input_pipeline_interface.create_data_iterator(
    config, global_mesh
)

if config.prefetch_to_device:
    # NOTE(review): 4 is presumably the number of batches to prefetch — confirm
    # against prefetch_to_device's signature.
    train_iter = prefetch_to_device(train_iter, 4)

# Keep one batch around: the evaluation and Green's-function cells below reuse
# `batch`, so it must be defined for Restart & Run All to succeed.
batch = next(train_iter)
jax.tree.map(lambda x: (x.shape, x.sharding), batch)

{'boundary': ((8, 1920),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'boundary_coords': ((8, 1920, 4),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'boundary_scattering_kernel': ((8, 1920, 24),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'boundary_weights': ((8, 1920),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'phase_coords': ((8, 128, 4),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'position_coords': ((8, 1600, 2),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', 'fsdp', 'tensor'),))),
 'psi_label': ((8, 128),
  NamedSharding(mesh=Mesh('data': 1, 'fsdp': 2, 'tensor': 1), spec=PartitionSpec(('data', '

## Test DeepRTE

In [6]:
# Instantiate the full DeepRTE model from the config and render its module tree.
deeprte = modules.DeepRTE(config, rngs=rng)
nnx.display(deeprte)

### Parameter count

In [7]:
graphdef, params = nnx.split(deeprte)

# `params` holds the model's full state (everything nnx.split returns) and is what
# the jitted steps below merge back with `graphdef`. For the count, filter down to
# nnx.Param variables so any non-parameter state is not included.
sum(leaf.size for leaf in jax.tree.leaves(nnx.state(deeprte, nnx.Param)))

37954

#### Case 1: train mode (`low_memory=False`)

In [8]:
@jax.jit
def train_step(params, batch):
    """Jitted forward pass with ``low_memory=False`` (no gradients are computed).

    Rebuilds the model from the module-level ``graphdef`` (captured when traced)
    and the given state ``params``, sets ``low_memory=False`` via
    ``set_attributes``, and returns the model's output for ``batch``.
    """
    model = nnx.merge(graphdef, params)
    model.set_attributes(low_memory=False)
    predictions = model(batch)
    return predictions

In [19]:
%time train_step(params, next(train_iter))

CPU times: user 2.25 s, sys: 249 ms, total: 2.5 s
Wall time: 184 ms


Array([[0.00025377, 0.00032031, 0.00033375, ..., 0.00025622, 0.00030578,
        0.00028832],
       [0.00024331, 0.00028391, 0.00029136, ..., 0.0002352 , 0.00028226,
        0.00027476],
       [0.00023557, 0.00028888, 0.00029844, ..., 0.00023773, 0.00028315,
        0.00026535],
       ...,
       [0.00022531, 0.00027221, 0.00028031, ..., 0.00022271, 0.00026713,
        0.00025431],
       [0.00025352, 0.00030149, 0.00030755, ..., 0.00025014, 0.0002999 ,
        0.00028768],
       [0.00023094, 0.00027108, 0.00027972, ..., 0.00022046, 0.00026263,
        0.00025774]], dtype=float32)

#### Case 2: evaluation / inference mode (`low_memory=True`)

In [None]:
@jax.jit
def eval_step(params, batch):
    """Jitted forward pass with ``low_memory=True``.

    Rebuilds the model from the module-level ``graphdef`` (captured when traced)
    and the given state ``params``, sets ``low_memory=True`` via
    ``set_attributes``, and returns the model's output for ``batch``.
    """
    model = nnx.merge(graphdef, params)
    model.set_attributes(low_memory=True)
    predictions = model(batch)
    return predictions

In [None]:
%%time

prediction = eval_step(params, batch)

print(prediction.shape)

## Test Green's function

In [None]:
# First example of the batch; for the keys listed below, additionally keep only
# the first entry (e.g. a single phase point) so the Green's function can be
# evaluated at one point.
single_example = jax.tree.map(lambda x: x[0], batch)
one_point_example = {
    key: value[0] if key in ("phase_coords", "scattering_kernel") else value
    for key, value in single_example.items()
}

jax.tree.map(lambda x: x.shape, one_point_example)

In [None]:
green_fn = modules.GreenFunction(config=config, rngs=nnx.Rngs(0))

# Map over the leading axis of boundary_coords only; the phase point and the
# example dict are shared. Results are stacked along the last axis.
batched_green_fn = jax.vmap(green_fn, in_axes=(None, 0, None), out_axes=-1)
y = batched_green_fn(
    one_point_example["phase_coords"],
    one_point_example["boundary_coords"],
    one_point_example,
)

y.shape

## Test Attenuation module

In [None]:
# Random uniform inputs; keys are drawn from `rng` in this fixed order.
coord1 = jax.random.uniform(rng(), [5, 4])
coord2 = jax.random.uniform(rng(), [5, 4])
att_coeff = jax.random.uniform(rng(), [50, 2])

grid = jax.random.uniform(rng(), [50, 2])
char = Characteristics.from_tensor(grid)

attenuation = modules.Attenuation(config=config, rngs=nnx.Rngs(0))
# Map over the 5 rows of coord1/coord2; att_coeff and the characteristics are shared.
batched_attenuation = jax.vmap(attenuation, in_axes=(0, 0, None, None))
batched_attenuation(coord1, coord2, att_coeff, char).shape