In [1]:
import os
# Pin this notebook to GPU index 7; this cell runs before any jax/tensorflow import below.
os.environ.update(CUDA_VISIBLE_DEVICES='7')

In [2]:
from deeprte.config import make_config
# Build the default experiment configuration. Later cells read `config.model` when
# constructing models, and one cell overrides `res_block_depth` on it in place.
config = make_config()

In [3]:
# Load one evaluation .mat file, split off test samples, convert to TF tensors, and
# separate batched from unbatched features. All imports for this stage are grouped here.
import os

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import divide_batch_feat, np_to_tensor_dict
from deeprte.model.tf.rte_features import _COLLOCATION_FEATURE_NAMES

# Machine-specific default path; set DEEPRTE_EVAL_DATA to point at a different copy
# instead of editing the notebook.
data_path = os.environ.get(
    "DEEPRTE_EVAL_DATA",
    "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat",
)
# Value forwarded as `num_test_samples` to DataPipeline.process.
NUM_TEST_SAMPLES = 2

data_pipeline = DataPipeline(data_path)
data = data_pipeline.process(
    pre_shuffle=True,
    is_split_test_samples=True,
    num_test_samples=NUM_TEST_SAMPLES,
)

tf_data = np_to_tensor_dict(data)
batched_data, unbatched_data = divide_batch_feat(tf_data)

In [4]:
import tensorflow as tf
import haiku as hk
# Index 0 along the leading (presumably sample) axis of every batched feature.
_batched_data = tf.nest.map_structure(lambda tensor: tensor[0], batched_data)
# Rebinds `tf_data` to a single-sample dict; on key collisions unbatched_data wins.
tf_data = dict(_batched_data)
tf_data.update(unbatched_data)

In [5]:
# Display the TensorShape of every feature (tf.nest keeps the dict's original key order).
tf.nest.map_structure(lambda tensor: tensor.shape, tf_data)

{'sigma': TensorShape([1681, 2]),
 'psi_label': TensorShape([40344]),
 'scattering_kernel': TensorShape([40344, 24]),
 'self_scattering_kernel': TensorShape([24, 24]),
 'boundary': TensorShape([1968]),
 'position_coords': TensorShape([1681, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([40344, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

In [6]:
# Inspect the transport model's coefficient-net sub-config (printed below: attention_net widths).
config.model.green_function.scatter_model.transport_model.coefficient_net

attention_net:
  widths:
  - 64
  - 1

In [7]:
import jax
import jax.numpy as jnp
from deeprte.model.tf.feature_transform import select_feat

In [8]:
# Convert every TF tensor to a JAX array. jax.tree_util rebuilds the dict with sorted keys.
jnp_data = jax.tree_util.tree_map(jnp.array, tf_data)
# Stateful key stream: each later `next(rng)` consumes a new key, so results depend on
# the order in which cells are executed.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

In [9]:
# Override the scatter model's residual-block depth in place on the shared config. The
# forward functions below construct their models from `config.model` when they are called,
# so they see this value.
config.model.green_function.scatter_model.res_block_depth = 2

In [10]:
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES

In [11]:
from deeprte.model.modules import GreenFunction, TransportModule, ScatterModule

In [12]:
from deeprte.model.modules import RTEModel

In [13]:
def forward_rte(batch):
    """Apply a freshly constructed RTEModel(config.model) to `batch`.

    Intended to be wrapped by hk.transform. In this notebook it is traced with
    `_dummy_input`, whose collocation features hold a single point each.
    """
    return RTEModel(config.model)(batch)

In [14]:
# NOTE(review): `feature_name_list` is not referenced by any other cell shown here;
# kept only in case external code relies on it.
feature_name_list = ['phase_coords', "scattering_kernel"]

# Single-point example for tracing the un-vmapped model: every collocation feature keeps
# only its first entry (e.g. phase_coords (40344, 4) -> (4,)); other features are unchanged.
_dummy_input = {
    k: v[0] if k in _COLLOCATION_FEATURE_NAMES else v
    for k, v in jnp_data.items()
}

# Display the resulting shapes as the cell's last expression.
jax.tree_util.tree_map(lambda x: x.shape, _dummy_input)


{'boundary': (1968,), 'boundary_coords': (1968, 4), 'boundary_weights': (1968,), 'phase_coords': (4,), 'position_coords': (1681, 2), 'psi_label': (), 'scattering_kernel': (24,), 'self_scattering_kernel': (24, 24), 'sigma': (1681, 2), 'velocity_coords': (24, 2), 'velocity_weights': (24,)}


In [15]:
# Turn forward_rte into an (init, apply) pair, initialise on the single-point input
# (consuming one key from `rng`), and JIT-compile apply for reuse.
f_transformed = hk.transform(forward_rte)
params = jax.jit(f_transformed.init)(next(rng), _dummy_input)
apply = jax.jit(f_transformed.apply)
# Parameter shapes only; the arrays themselves would flood the output.
jax.tree_util.tree_map(jnp.shape, params)

{'RTE_model/green_function/mlp/linear': {'bias': (1,), 'weights': (128, 1)},
 'RTE_model/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 

In [16]:
# Forward pass on one collocation point; the printed shape is () (a scalar).
logits = apply(params, next(rng), _dummy_input)
print(jnp.shape(logits))

()


In [17]:
# Two-sample dummy batch. First, collocation features keep only their first point.
# Then batch features are duplicated along a new leading axis of size 2. A key in both
# lists gets both steps, in that order.
dummy_batch = {
    name: (jnp.stack((value, value)) if name in _BATCH_FEATURE_NAMES else value)
    for name, value in (
        (key, arr[0] if key in _COLLOCATION_FEATURE_NAMES else arr)
        for key, arr in jnp_data.items()
    )
}

# vmap in_axes spec: batch features are mapped over axis 0; all others are not mapped (None).
axis_dict = {name: (0 if name in _BATCH_FEATURE_NAMES else None) for name in jnp_data}

In [18]:
# Shapes of the two-sample dummy batch.
jax.tree_util.tree_map(jnp.shape, dummy_batch)

{'boundary': (2, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (4,),
 'position_coords': (1681, 2),
 'psi_label': (2,),
 'scattering_kernel': (2, 24),
 'self_scattering_kernel': (2, 24, 24),
 'sigma': (2, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [19]:
# Show the vmap axis spec. This is a copy of `axis_dict` with keys in sorted order, which is
# what an identity jax.tree_util.tree_map over it would produce.
dict(sorted(axis_dict.items()))

{'boundary': 0,
 'boundary_coords': None,
 'boundary_weights': None,
 'phase_coords': None,
 'position_coords': None,
 'psi_label': 0,
 'scattering_kernel': 0,
 'self_scattering_kernel': 0,
 'sigma': 0,
 'velocity_coords': None,
 'velocity_weights': None}

In [20]:
from deeprte.model.mapping import vmap
def forward(batch):
    """Batched RTEModel forward pass built with deeprte.model.mapping.vmap.

    NOTE(review): the next cell redefines `forward` (using hk.vmap directly) before this
    version is ever transformed or called, so this definition is effectively dead.
    """
    batched_model = vmap(RTEModel(config.model), use_hk=True, in_axes=(axis_dict,))
    return batched_model(batch)

In [21]:
from deeprte.model.mapping import vmap
def forward(batch):
    """Batched RTEModel forward pass via hk.vmap, mapped according to `axis_dict`.

    split_rng is False while Haiku runs init and True otherwise, as computed from
    hk.running_init() below.
    """
    batched_model = hk.vmap(
        RTEModel(config.model),
        in_axes=(axis_dict,),
        split_rng=(not hk.running_init()),
    )
    return batched_model(batch)

In [22]:
# Transform the hk.vmap-based `forward`, initialise on the two-sample dummy batch, and
# JIT-compile apply. The printed shapes match the un-vmapped model's parameter shapes.
f_transformed = hk.transform(forward)
params = jax.jit(f_transformed.init)(next(rng), dummy_batch)
apply = jax.jit(f_transformed.apply)
jax.tree_util.tree_map(jnp.shape, params)

{'RTE_model/green_function/mlp/linear': {'bias': (1,), 'weights': (128, 1)},
 'RTE_model/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 

In [23]:
# Forward pass on the two-sample dummy batch; one output per sample, shape (2,).
logits = apply(params, next(rng), dummy_batch)
print(jnp.shape(logits))

(2,)


In [24]:
from deeprte.model.modules import DeepRTE

In [26]:
def forward(batch):
    """Run the full DeepRTE module (built from config.model) on `batch`.

    The result observed in this notebook is a dict with 'loss' and 'rte_predictions'.

    NOTE(review): the two positional `True` flags are passed unnamed; their meaning is
    defined by DeepRTE.__call__, which is not shown here. Consider passing them by keyword.
    """
    return DeepRTE(config.model)(batch, True, True)

In [27]:
# Two-sample dummy batch with 3 collocation points per sample. Collocation features are cut
# to their first 3 entries, then batch features are duplicated along a new leading axis.
dummy_batch = {
    name: (jnp.stack((value, value)) if name in _BATCH_FEATURE_NAMES else value)
    for name, value in (
        (key, arr[:3] if key in _COLLOCATION_FEATURE_NAMES else arr)
        for key, arr in jnp_data.items()
    )
}

In [28]:
# Transform the DeepRTE-based `forward`, initialise on the 3-point dummy batch, and
# JIT-compile apply. The printed parameter names carry an extra leading 'RTE_model/' scope.
f_transformed = hk.transform(forward)
params = jax.jit(f_transformed.init)(next(rng), dummy_batch)
apply = jax.jit(f_transformed.apply)
jax.tree_util.tree_map(jnp.shape, params)

{'RTE_model/RTE_model/green_function/mlp/linear': {'bias': (1,),
  'weights': (128, 1)},
 'RTE_model/RTE_model/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'RTE_model/RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'RTE_model/RTE_model/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'RTE_model/RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'RTE_model/RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/RTE_model/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'RTE_model/RTE_model/green_function/scatter_model/transport

In [31]:
# Full DeepRTE forward pass on the 2-sample, 3-point dummy batch.
# NOTE(review): despite the name, `logits` here is a dict ('loss', 'rte_predictions').
logits = apply(params, next(rng), dummy_batch)

In [32]:
# Shapes of every entry in the output dict (losses are scalars).
jax.tree_util.tree_map(jnp.shape, logits)

{'loss': {'mse': (), 'rmspe': ()}, 'rte_predictions': (2, 3)}

In [38]:
# A batch of one real sample: batch features get a new leading axis of size 1; every other
# feature (including all collocation points) is passed through unchanged.
batch_data = {
    name: (jnp.expand_dims(value, 0) if name in _BATCH_FEATURE_NAMES else value)
    for name, value in jnp_data.items()
}

In [40]:
# Shapes of the single-sample batch.
jax.tree_util.tree_map(jnp.shape, batch_data)

{'boundary': (1, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (1, 40344),
 'scattering_kernel': (1, 40344, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [41]:
# Reuse the parameters initialised on the 3-point dummy batch for a full sample (all
# collocation points). Because the input shapes differ, the jitted apply is traced again.
logits = apply(params, next(rng), batch_data)

In [42]:
# Output shapes for the full sample: one prediction per collocation point.
jax.tree_util.tree_map(jnp.shape, logits)

{'loss': {'mse': (), 'rmspe': ()}, 'rte_predictions': (1, 40344)}