In [1]:
import os

import functools

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import haiku as hk
import jax
import numpy as np

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import np_to_tensor_dict
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES, _COLLOCATION_FEATURE_NAMES
from deeprte.config import get_config
from deeprte.model.modules_v2 import DeepRTE

## Load Dataset

In [2]:
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_path = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat"

# Read the .mat file and split off a small test set (2 samples) after shuffling.
data_pipeline = DataPipeline(data_path)
raw_features = data_pipeline.process(
    pre_shuffle=True, is_split_test_samples=True, num_test_samples=2
)

tensor_dict = np_to_tensor_dict(raw_features)
# Materialise every leaf as a NumPy array. `jax.tree_map` is deprecated (and
# removed in newer JAX releases); `jax.tree_util.tree_map` is the stable API.
features = jax.tree_util.tree_map(np.asarray, tensor_dict)

# Display the shape of every feature as a quick sanity check.
jax.tree_util.tree_map(lambda x: x.shape, features)

{'boundary': (8, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (8, 40344),
 'scattering_kernel': (8, 40344, 24),
 'self_scattering_kernel': (8, 24, 24),
 'sigma': (8, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

## Build Model

In [3]:
# Keep only the experiment's config subtree, and use a shallower scattering
# residual stack (depth 2) for this demo.
config = get_config().experiment_kwargs.config
config.model.green_function.scattering_module.res_block_depth = 2


def forward_fn(batch, is_training):
    """Haiku forward function: build a DeepRTE model and run it on `batch`.

    Loss computation is enabled and metric computation is disabled. The
    returned value is whatever `DeepRTE.__call__` produces.
    """
    model_cfg = config.model
    # NOTE(review): the same config object is passed for both constructor
    # arguments; confirm this matches DeepRTE's expected signature.
    model = DeepRTE(model_cfg, model_cfg)
    return model(batch, is_training, compute_loss=True, compute_metrics=False)


# Stateful Haiku transform: provides `.init` and `.apply`.
rte_op = hk.transform_with_state(forward_fn)

## Initialize Parameters

In [4]:
# Fixed seed so parameter initialisation is reproducible. The same key
# sequence is reused later to draw the apply-time key.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# Bind `is_training` up front so the jitted init takes only (rng, batch).
init_net = jax.jit(functools.partial(rte_op.init, is_training=True))
params, state = init_net(next(rng), features)

# Display parameter shapes. `jax.tree_map` is deprecated; use `jax.tree_util`.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'deeprte/green_function': {'proj_weights': (32,)},
 'deeprte/green_function/attenuation_module/attention': {'b': (64,),
  'key_w': (2, 64),
  'proj_w': (64,),
  'query_w': (4, 64)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_3': {'bias': (32,),
  'weights': (128, 32)},
 'deeprte/green_function/scattering_module/__layer_stack_no_state/scattering_layer': {'scattering_bias': (2,
   32),
  'scattering_weights': (2, 32, 32)}}

## Apply RTE Operator

In [8]:
# NOTE(review): apply runs with is_training=True; confirm that is intended here.
rte_apply = jax.jit(
    functools.partial(rte_op.apply, is_training=True)
)

def slice_batch(i: int, feat: dict):
    """Return a copy of `feat` reduced to batch entry `i`.

    Features listed in `_BATCH_FEATURE_NAMES` are sliced to `[i : i + 1]`,
    which keeps a leading axis of length 1. Every other feature is passed
    through as the same object.
    """
    sliced = {}
    for name, value in feat.items():
        if name in _BATCH_FEATURE_NAMES:
            value = value[i : i + 1]
        sliced[name] = value
    return sliced


def sample_colloctaion_points(indices: np.ndarray, feat: dict) -> dict:
    """Keep only the selected collocation points in every collocation feature.

    Features named in ``_COLLOCATION_FEATURE_NAMES`` are indexed with
    ``indices``. ``"phase_coords"`` is indexed along its leading axis; every
    other collocation feature is indexed along axis 1, with axis 0 left intact.
    All other features are passed through unchanged, as the same objects.

    Args:
      indices: integer index array selecting the points to keep, e.g.
        ``np.arange(n)``. The previous annotation said ``int``, but an array is
        what callers pass; a scalar int would drop the point axis.
      feat: mapping from feature name to array.

    Returns:
      A new dict with the same keys, in the same order, as ``feat``.
    """

    def _take(name, value):
        # Non-collocation features are shared by all points and are kept whole.
        if name not in _COLLOCATION_FEATURE_NAMES:
            return value
        # Per the shapes displayed in this notebook, phase_coords has no
        # leading batch axis, so its point axis is axis 0.
        if name == "phase_coords":
            return value[indices]
        return value[:, indices]

    return {name: _take(name, value) for name, value in feat.items()}


# Correctly spelled alias; the original (misspelled) name is kept for callers.
sample_collocation_points = sample_colloctaion_points

In [9]:
# Evaluate a single batch entry on its first few collocation points
# (the data were pre-shuffled by the pipeline).
sample_index = 2
num_points = 2

batch = slice_batch(sample_index, features)
batch = sample_colloctaion_points(np.arange(num_points), batch)

# `jax.tree_map` is deprecated; use the `jax.tree_util` equivalent.
jax.tree_util.tree_map(lambda x: x.shape, batch)

{'boundary': (1, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (2, 4),
 'position_coords': (1681, 2),
 'psi_label': (1, 2),
 'scattering_kernel': (1, 2, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [10]:
# Draw the next key from the same seeded sequence used for initialisation.
apply_key = next(rng)
out = rte_apply(params, state, apply_key, batch)

In [11]:
# Display the result of `rte_apply`. As shown below, it is ((loss, aux), state),
# where aux holds the "loss" metrics and the "rte_predictions".
out

((DeviceArray(0.01482137, dtype=float32),
  {'loss': {'mse': DeviceArray(0.01482137, dtype=float32),
    'rmspe': DeviceArray(3161726.8, dtype=float32)},
   'rte_predictions': DeviceArray([[-0.12058373, -0.12289134]], dtype=float32)}),
 {})