In [1]:
import os

# Restrict JAX to the GPU with index 7. Set this before importing jax so the
# CUDA backend can never be initialised with the full device list.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import functools

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import np_to_tensor_dict, divide_batch_feat
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES, _COLLOCATION_FEATURE_NAMES
from deeprte.config import get_config

In [2]:
# NOTE(review): absolute container path; adjust for your environment.
data_path = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat"

data_pipeline = DataPipeline(data_path)
raw_features = data_pipeline.process(
    pre_shuffle=True, is_split_test_samples=True, num_test_samples=2
)

tensor_dict = np_to_tensor_dict(raw_features)
# Materialise every leaf as a NumPy array so later cells can slice freely.
features = jax.tree_util.tree_map(np.asarray, tensor_dict)

jax.tree_util.tree_map(lambda x: x.shape, features)

{'boundary': (8, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (8, 40344),
 'scattering_kernel': (8, 40344, 24),
 'self_scattering_kernel': (8, 24, 24),
 'sigma': (8, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [3]:
batched_feat, unbatched_feat = divide_batch_feat(tensor_dict)
# Keep only the first sample (index 0) of every batched feature.
batched_feat = jax.tree_util.tree_map(lambda leaf: leaf[0], batched_feat)
# Merge both groups into one flat dict; unbatched entries win on key clashes.
tf_dict = dict(batched_feat, **unbatched_feat)

In [4]:
def sample_colloctaion_points(indices: int, feat: dict):
    """Select collocation point(s) ``indices`` from an un-batched feature dict.

    Entries whose key is in ``_COLLOCATION_FEATURE_NAMES`` are indexed along
    their leading axis; every other entry is passed through unchanged.

    NOTE(review): this name is redefined later in the notebook with batched
    semantics; re-running cells that use this version after that redefinition
    will silently call the other one.
    """
    return {
        name: (value[indices] if name in _COLLOCATION_FEATURE_NAMES else value)
        for name, value in feat.items()
    }

In [5]:
# Take collocation point 0 of the first sample and convert every leaf to a
# JAX array. Relies on the *first* `sample_colloctaion_points` definition;
# re-running this after the later redefinition will fail on unbatched keys.
jnp_dict = sample_colloctaion_points(0, tf_dict)
jnp_dict = jax.tree_util.tree_map(jnp.array, jnp_dict)
jax.tree_util.tree_map(lambda x: x.shape, jnp_dict)

{'boundary': (1968,),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (4,),
 'position_coords': (1681, 2),
 'psi_label': (),
 'scattering_kernel': (24,),
 'self_scattering_kernel': (24, 24),
 'sigma': (1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [6]:
# Work with the experiment's model config and use a scattering stack of depth 2.
config = get_config().experiment_kwargs.config
config.model.model_structure.green_function.scattering_module.res_block_depth = 2

In [7]:
# Fixed seed so every key drawn from `rng` below is reproducible.
rng = hk.PRNGSequence(jax.random.PRNGKey(seed=42))

In [8]:
from deeprte.model.modules_v2 import AttenuationModule
from deeprte.model.geometry.characteristics import Characteristics

In [9]:
def forward_atte(*args):
    """Apply an ``AttenuationModule`` to ``args``; wrapped by ``hk.transform`` below.

    The config is read through ``config.model.model_structure`` like every
    other module config in this notebook; the previous path
    ``config.model.green_function`` skipped the ``model_structure`` level.
    """
    # TODO(review): confirm the attenuation config really lives under
    # `scattering_module` (the GreenFunction parameter tree shows
    # `attenuation_module` as a sibling of `scattering_module`) and whether
    # AttenuationModule also expects `config.model.global_config`.
    attenuation_cfg = (
        config.model.model_structure.green_function.scattering_module.attenuation_module
    )
    return AttenuationModule(attenuation_cfg)(*args)


forward_fn = hk.transform(forward_atte)

In [10]:
# Inputs for the standalone AttenuationModule experiment: the sampled phase
# coordinates, the first boundary coordinate, the sigma field, and the
# characteristics built from the position grid.
charc = Characteristics.from_tensor(jnp_dict["position_coords"])
dummy_inputs = (
    jnp_dict["phase_coords"],
    jnp_dict["boundary_coords"][0],
    jnp_dict["sigma"],
    charc,
)

In [11]:
from deeprte.model.modules_v2 import ScatteringLayer

In [12]:
# Random stand-ins for the ScatteringLayer inputs. `C` is not read before it
# is reassigned later, but drawing it still advances `rng`, so it is kept to
# preserve the key sequence used by subsequent cells.
A = jax.random.uniform(next(rng), (24, 64))
B = jax.random.uniform(next(rng), (24,))
C = jax.random.uniform(next(rng), (24, 64))

In [13]:
def forward(*args):
    """Run one ScatteringLayer on ``args``; wrapped by ``hk.transform`` below."""
    scattering_cfg = config.model.model_structure.green_function.scattering_module
    layer = ScatteringLayer(scattering_cfg, config.model.global_config)
    return layer(*args)

forward_fn = hk.transform(forward)

In [14]:
# ScatteringLayer is called with two positional inputs.
dummy_inputs = A, B

In [15]:
# JIT-compile Haiku's init/apply, initialise parameters from a fresh key and
# show the shapes of the resulting parameter tree.
init_net = jax.jit(forward_fn.init)
params = init_net(next(rng), *dummy_inputs)
apply = jax.jit(forward_fn.apply)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'scattering_layer': {'scattering_bias': (64,),
  'scattering_weights': (64, 64)}}

In [16]:
logits = apply(params, next(rng), *dummy_inputs)
jnp.shape(logits)

(64,)

In [17]:
from deeprte.model.modules_v2 import ScatteringModule

In [18]:
# NOTE(review): this mutates the shared global config *after* the
# ScatteringLayer cells above ran; re-running those cells now sees
# deterministic=True.
global_cfg = config.model.global_config
global_cfg.deterministic = True

In [19]:
def forward(*args):
    """Run the full ScatteringModule on ``args``; wrapped by ``hk.transform`` below."""
    scattering_cfg = config.model.model_structure.green_function.scattering_module
    module = ScatteringModule(scattering_cfg, config.model.global_config)
    return module(*args)

forward_fn = hk.transform(forward)

In [20]:
# Random stand-ins for the four ScatteringModule inputs (drawn in this order
# so the rng key sequence is unchanged).
A = jax.random.uniform(next(rng), (64,))
B = jax.random.uniform(next(rng), (24, 64))
C = jax.random.uniform(next(rng), (24,))
D = jax.random.uniform(next(rng), (24, 24))

In [21]:
# ScatteringModule is called with four positional inputs.
dummy_inputs = A, B, C, D

In [22]:
# JIT-compile Haiku's init/apply, initialise parameters from a fresh key and
# show the shapes of the resulting parameter tree.
init_net = jax.jit(forward_fn.init)
params = init_net(next(rng), *dummy_inputs)
apply = jax.jit(forward_fn.apply)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'scattering_module/__layer_stack_no_per_layer/scattering_layer': {'scattering_bias': (2,
   64),
  'scattering_weights': (2, 64, 64)}}

In [23]:
logits = apply(params, next(rng), *dummy_inputs)
jax.tree_util.tree_map(lambda arr: arr.shape, logits)

((64,), (24, 64))

In [24]:
from deeprte.model.modules_v2 import GreenFunction

In [25]:
def forward(*args):
    """Run a GreenFunction on ``args``; wrapped by ``hk.transform`` below."""
    green_cfg = config.model.model_structure.green_function
    green_fn = GreenFunction(green_cfg, config.model.global_config)
    return green_fn(*args)

forward_fn = hk.transform(forward)

In [26]:
# GreenFunction inputs: the sampled phase coordinates, the first boundary
# coordinate and the full feature dict. No copy is needed: nothing mutates it.
dummy_inputs = (jnp_dict["phase_coords"], jnp_dict["boundary_coords"][0], jnp_dict)

In [27]:
# JIT-compile Haiku's init/apply, initialise parameters from a fresh key and
# show the shapes of the resulting parameter tree.
init_net = jax.jit(forward_fn.init)
params = init_net(next(rng), *dummy_inputs)
apply = jax.jit(forward_fn.apply)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'green_function': {'proj_weights': (64,)},
 'green_function/attenuation_module/attention': {'b': (64,),
  'key_w': (2, 64),
  'proj_w': (64,),
  'query_w': (4, 64)},
 'green_function/attenuation_module/attenuation_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'green_function/attenuation_module/attenuation_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/attenuation_module/attenuation_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/attenuation_module/attenuation_mlp/linear_3': {'bias': (64,),
  'weights': (128, 64)},
 'green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer': {'scattering_bias': (2,
   64),
  'scattering_weights': (2, 64, 64)}}

In [28]:
from deeprte.model.modules_v2 import DeepRTE

In [29]:
def forward(*args):
    """Run the full DeepRTE model in training mode, requesting loss and metrics.

    Wrapped by ``hk.transform_with_state`` below.
    """
    model = DeepRTE(config.model.model_structure, config.model.global_config)
    return model(*args, is_training=True, compute_loss=True, compute_metrics=True)

forward_fn = hk.transform_with_state(forward)

In [30]:
def slice_batch(i: int, feat: dict, batch_size: int = 2) -> dict:
    """Return ``feat`` with its batched features sliced to samples ``[i, i + batch_size)``.

    Args:
        i: index of the first sample to keep.
        feat: feature dict; keys in ``_BATCH_FEATURE_NAMES`` are sliced along
            their leading axis, all other entries are passed through unchanged.
        batch_size: number of consecutive samples to keep (default 2, the
            value previously hard-coded).

    Returns:
        A new dict with the same keys as ``feat``.
    """
    return {
        k: v[i : i + batch_size] if k in _BATCH_FEATURE_NAMES else v
        for k, v in feat.items()
    }


def sample_colloctaion_points(indices, feat: dict) -> dict:
    """Select collocation point(s) ``indices`` from a *batched* feature dict.

    ``phase_coords`` has no batch axis and is indexed on axis 0; the other
    collocation features carry a leading batch axis and are indexed on axis 1.
    Non-collocation entries are passed through unchanged.

    Args:
        indices: an int or an integer index array (e.g. ``np.arange(n)``).
        feat: feature dict, typically the output of ``slice_batch``.

    NOTE(review): this redefines the earlier single-sample version of the same
    name; cells that relied on that version must not be re-run after this one.
    """
    ret = {}
    for k, v in feat.items():
        if k not in _COLLOCATION_FEATURE_NAMES:
            ret[k] = v
        elif k == "phase_coords":
            ret[k] = v[indices]
        else:
            ret[k] = v[:, indices]
    return ret

In [31]:
# Samples 2 and 3 of the dataset, restricted to the first two collocation points.
batch = slice_batch(2, features)
dummy_inputs = sample_colloctaion_points(np.arange(2), batch)
jax.tree_util.tree_map(lambda x: x.shape, dummy_inputs)

{'boundary': (2, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (2, 4),
 'position_coords': (1681, 2),
 'psi_label': (2, 2),
 'scattering_kernel': (2, 2, 24),
 'self_scattering_kernel': (2, 24, 24),
 'sigma': (2, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [32]:
# JIT-compile Haiku's stateful init/apply, initialise parameters and state from
# a fresh key and show the shapes of the parameter tree. (A dead
# `jax.jit(functools.partial(forward_fn.init,))` that was immediately
# overwritten has been removed.)
init_net = jax.jit(forward_fn.init)
params, states = init_net(next(rng), dummy_inputs)
apply = jax.jit(forward_fn.apply)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'deeprte/green_function': {'proj_weights': (64,)},
 'deeprte/green_function/attenuation_module/attention': {'b': (64,),
  'key_w': (2, 64),
  'proj_w': (64,),
  'query_w': (4, 64)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'deeprte/green_function/attenuation_module/attenuation_mlp/linear_3': {'bias': (64,),
  'weights': (128, 64)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer': {'scattering_bias': (2,
   64),
  'scattering_weights': (2, 64, 64)}}

In [33]:
logits = apply(params, states, next(rng), dummy_inputs)
jax.tree_util.tree_map(lambda arr: arr.shape, logits)

(((),
  {'loss': {'mse': (), 'rmspe': ()},
   'metrics': {'mse': (2,), 'rmspe': (2,)},
   'rte_predictions': (2, 2)}),
 {})