In [22]:
import os

import functools

# Restrict this process to GPU index 5. This must be set before JAX initialises
# its GPU backend, which is why it appears ahead of `import jax` below.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import haiku as hk
import jax
import numpy as np

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import np_to_tensor_dict, divide_batch_feat
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES
from deeprte.config import make_config
from deeprte.model.modules import DeepRTE

## Load Dataset

In [13]:
# Evaluation dataset (MATLAB .mat file).
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory so the notebook runs elsewhere.
data_path = "/workspaces/deeprte/rte_data/matlab/eval-data/test_shape.mat"

data_pipeline = DataPipeline(data_path)
raw_features = data_pipeline.process(
    pre_shuffle=True, is_split_test_samples=True, num_test_samples=2
)

tensor_dict = np_to_tensor_dict(raw_features)
# Convert every leaf to a NumPy array. `jax.tree_util.tree_map` is used because
# the `jax.tree_map` alias is deprecated in newer JAX releases.
features = jax.tree_util.tree_map(np.asarray, tensor_dict)
# Presumably splits features with a leading batch axis from those shared across
# the batch (cf. `_BATCH_FEATURE_NAMES`) — TODO confirm against divide_batch_feat.
batched_features, unbatched_features = divide_batch_feat(features)

# Display the shape of every feature.
jax.tree_util.tree_map(lambda x: x.shape, features)

{'boundary': (8, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (8, 40344),
 'scattering_kernel': (8, 40344, 24),
 'self_scattering_kernel': (8, 24, 24),
 'sigma': (8, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

## Build Model

In [36]:
# Start from the default DeepRTE configuration and override the scatter model's
# `res_block_depth` (presumably the number of residual blocks — TODO confirm).
config = make_config()
config.model.green_function.scatter_model.res_block_depth = 2


def forward_fn(batch, is_training):
    """Haiku forward pass: build a DeepRTE module and apply it to ``batch``.

    The module is created inside the function so that ``hk.transform`` can
    manage its parameters. The model configuration is read from the global
    ``config`` when the function runs, not captured at definition time.

    Args:
      batch: Feature dict passed straight to the model.
      is_training: Forwarded to the model call.

    Returns:
      Whatever DeepRTE returns when asked to compute both loss and metrics.
    """
    return DeepRTE(config.model)(
        batch,
        is_training,
        compute_loss=True,
        compute_metrics=True,
    )


# Convert the impure Haiku forward function into the pure pair
# `rte_op.init` / `rte_op.apply` used below.
rte_op = hk.transform(forward_fn)

## Initialize Parameters

In [39]:
SEED = 42

# Seeded key sequence; each `next(rng)` yields a fresh PRNG key.
rng = hk.PRNGSequence(jax.random.PRNGKey(SEED))

# Jit-compile initialisation. `is_training=True` is bound through
# functools.partial, so it stays a Python constant rather than a traced argument.
init_net = jax.jit(functools.partial(rte_op.init, is_training=True))
# Initialise from the full feature dict (all examples in the batch).
params = init_net(next(rng), features)

# Display the shape of every parameter. (`jax.tree_util.tree_map` replaces the
# deprecated `jax.tree_map` alias.)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'DeepRTE/green_function/mlp/linear': {'bias': (1,), 'weights': (128, 1)},
 'DeepRTE/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_f

## Apply RTE Operator

In [49]:
# Jit-compiled inference: `is_training=False` is bound via functools.partial so
# it is a Python constant. Called as `rte_apply(params, rng_key, batch)`.
rte_apply = jax.jit(functools.partial(rte_op.apply, is_training=False))


def slice_batch(i: int, feat: dict, batch_keys=None) -> dict:
    """Select one example from a feature dict, keeping a batch axis of size 1.

    Entries whose key is a batched feature are sliced along axis 0 so that only
    example ``i`` remains (shape ``(1, ...)``). All other entries are passed
    through unchanged.

    Args:
      i: Index of the example to select. Negative values count from the end of
        the batch axis, as with normal Python indexing.
      feat: Mapping from feature name to an array-like with ``len()``/slicing.
      batch_keys: Names of the features that carry a leading batch axis.
        Defaults to ``_BATCH_FEATURE_NAMES``.

    Returns:
      A new dict with the same keys (in the same order) as ``feat``.

    Raises:
      IndexError: If ``i`` is out of range for a batched feature. A bare
        ``x[i : i + 1]`` would instead silently return an empty batch (and
        ``i = -1`` would give ``x[-1:0]``, which is also empty).
    """
    if batch_keys is None:
        batch_keys = _BATCH_FEATURE_NAMES

    def _take(x):
        # Slice example i out of x, keeping a leading axis of length 1.
        n = len(x)
        if not -n <= i < n:
            raise IndexError(f"batch index {i} out of range for batch of size {n}")
        j = i % n  # normalise negative indices so the slice is never empty
        return x[j : j + 1]

    return {k: _take(v) if k in batch_keys else v for k, v in feat.items()}

In [54]:
example = slice_batch(1, features)

%time out = rte_apply(params, next(rng), example)

CPU times: user 26.7 s, sys: 25.3 s, total: 52 s
Wall time: 52 s
