In [1]:
import os

# Expose only CUDA device 7 to this process. This cell runs before the
# TensorFlow / JAX imports further down the notebook.
# NOTE(review): the device index is machine-specific — confirm before re-running elsewhere.
os.environ.update({"CUDA_VISIBLE_DEVICES": "7"})

In [2]:
from deeprte.config import make_config
# Build the configuration object with no arguments; later cells read (and in one
# place override) entries under `config.model.green_function`.
config = make_config()

In [3]:
from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import divide_batch_feat, np_to_tensor_dict
from deeprte.model.tf.rte_features import _COLLOCATION_FEATURE_NAMES

# NOTE(review): absolute, machine-specific path — adjust when running outside
# the original dev container.
data_path = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat"

# Load the MATLAB file and run the preprocessing pipeline.
data_pipeline = DataPipeline(data_path)
data = data_pipeline.process(
    pre_shuffle=True,
    is_split_test_samples=True,
    num_test_samples=2,
)

# Convert the pipeline output into a dict of tensors (presumably NumPy -> TensorFlow,
# judging by the helper's name), then split it into the features that carry a
# leading batch axis and those that do not.
tf_data = np_to_tensor_dict(data)
batched_data, unbatched_data = divide_batch_feat(tf_data)

In [4]:
import tensorflow as tf
import haiku as hk

# Take element 0 along the leading axis of every batched feature.
_batched_data = tf.nest.map_structure(lambda feat: feat[0], batched_data)

# Merge the single-sample features with the unbatched ones; on a key clash the
# unbatched entry wins. NOTE: this rebinds `tf_data` from the loading cell.
tf_data = dict(_batched_data)
tf_data.update(unbatched_data)

In [5]:
# Display the shape of every feature in the merged single-sample dict.
tf.nest.map_structure(lambda feat: feat.shape, tf_data)

{'sigma': TensorShape([1681, 2]),
 'psi_label': TensorShape([40344]),
 'scattering_kernel': TensorShape([40344, 24]),
 'self_scattering_kernel': TensorShape([24, 24]),
 'boundary': TensorShape([1968]),
 'position_coords': TensorShape([1681, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([40344, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

In [6]:
from deeprte.model.modules import CoefficientNet

In [7]:
# Display the coefficient-net sub-config nested under the scatter model's transport model.
config.model.green_function.scatter_model.transport_model.coefficient_net

attention_net:
  widths:
  - 64
  - 1

In [8]:
import jax
import jax.numpy as jnp
from deeprte.model.tf.feature_transform import select_feat

In [9]:
# Convert every leaf of the TensorFlow feature dict into a JAX array.
jnp_data = jax.tree_util.tree_map(jnp.array, tf_data)

# Key stream seeded with 42; each `next(rng)` in later cells yields a fresh key.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

In [20]:
from deeprte.model.modules import ScatterModel

In [21]:
def forward_scatter(x, x_prime, v_prime, batch):
    """Build a ScatterModel from the notebook config and evaluate it once.

    Intended to be wrapped with `hk.transform`, since the module is constructed
    inside the function body.

    Args:
        x: first positional input of the model (the notebook passes
            `phase_coords[:2]` of a single collocation point).
        x_prime: second positional input (the notebook passes
            `boundary_coords[:2]` of a single boundary point).
        v_prime: third positional input (the notebook passes
            `boundary_coords[2:]` of the same boundary point).
        batch: feature dict forwarded to the model unchanged.

    Returns:
        Whatever the ScatterModel call returns.
    """
    scatter_model = ScatterModel(config.model.green_function.scatter_model)
    return scatter_model(x, x_prime, v_prime, batch)

In [22]:
feature_name_list = ['phase_coords', 'boundary_coords']
_dummy_input = select_feat(feature_name_list)(jnp_data)

# Keep only the first row of every selected collocation feature.
for name in _COLLOCATION_FEATURE_NAMES:
    if name not in _dummy_input:
        continue
    _dummy_input[name] = _dummy_input[name][0]

# boundary_coords is reduced to its first row explicitly (the printed shapes
# suggest the loop above does not touch it — TODO confirm).
_dummy_input['boundary_coords'] = _dummy_input['boundary_coords'][0]
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, _dummy_input))


# Positional arguments for forward_scatter(x, x_prime, v_prime, batch).
dummy_input = (
    _dummy_input['phase_coords'][:2],     # x
    _dummy_input['boundary_coords'][:2],  # x_prime
    _dummy_input['boundary_coords'][2:],  # v_prime
    jnp_data,                             # batch
)

{'boundary_coords': (4,), 'phase_coords': (4,)}


In [23]:
# Override the scatter model's residual-block depth in place; this mutates the
# shared `config`, so every model built after this cell sees the new value.
config.model.green_function.scatter_model.res_block_depth = 2

In [24]:
# Turn the module-building function into a pure (init, apply) pair.
f_transformed = hk.transform(forward_scatter)

# Initialise parameters under jit with a fresh key from the shared `rng` stream,
# and keep a jitted apply for the cells below.
params = jax.jit(f_transformed.init)(next(rng), *dummy_input)
apply = jax.jit(f_transformed.apply)

# Display the shape of every parameter array.
jax.tree_util.tree_map(lambda param: param.shape, params)

{'scatter_model/__layer_stack_no_state/mlp/linear': {'bias': (1, 128),
  'weights': (1, 128, 128)},
 'scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)}}

In [25]:
%timeit logits = apply(params, next(rng), *dummy_input).block_until_ready()
print(logits.shape)

518 µs ± 83.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
(128,)


In [26]:
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES

In [27]:
# Per-feature mapping axis: 0 for features listed in _BATCH_FEATURE_NAMES,
# None for everything else. Used as the `in_axes` entry for the batch dict
# in the vmapped forward_scatter.
axis_dict = {
    name: (0 if name in _BATCH_FEATURE_NAMES else None)
    for name in jnp_data
}

In [28]:
# Display the per-feature axis mapping (0 = mapped over axis 0, None = not mapped).
axis_dict

{'boundary': 0,
 'boundary_coords': None,
 'boundary_weights': None,
 'phase_coords': None,
 'position_coords': None,
 'psi_label': 0,
 'scattering_kernel': 0,
 'self_scattering_kernel': 0,
 'sigma': 0,
 'velocity_coords': None,
 'velocity_weights': None}

In [29]:
# Build a synthetic batch of size 2 by stacking each batch feature with itself;
# features not listed in _BATCH_FEATURE_NAMES are passed through unchanged.
_batch_dict = {
    name: (jnp.stack([feat, feat]) if name in _BATCH_FEATURE_NAMES else feat)
    for name, feat in jnp_data.items()
}

In [30]:
# Display the shape of every entry of the synthetic batch.
jax.tree_util.tree_map(lambda feat: feat.shape, _batch_dict)

{'boundary': (2, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (2, 40344),
 'scattering_kernel': (2, 40344, 24),
 'self_scattering_kernel': (2, 24, 24),
 'sigma': (2, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [31]:
from deeprte.model.mapping import vmap
def forward_scatter(x, x_prime, v_prime, batch):
    """Evaluate a ScatterModel mapped over the batch features with `hk.vmap`.

    NOTE: this redefines (shadows) the unbatched `forward_scatter` from an
    earlier cell; after this cell runs, only the vmapped version is reachable
    under that name.

    Args:
        x, x_prime, v_prime: passed with `in_axes=None`, i.e. not mapped and
            shared by every batch element.
        batch: feature dict; each entry is mapped along the axis given by the
            notebook-level `axis_dict` (0 for batch features, None otherwise).

    Returns:
        The output of the vmapped ScatterModel call.
    """
    scatter_model = ScatterModel(config.model.green_function.scatter_model)
    batched_model = hk.vmap(
        scatter_model,
        in_axes=(None, None, None, axis_dict),
        # False while initialising, True otherwise.
        split_rng=(not hk.running_init()),
    )
    return batched_model(x, x_prime, v_prime, batch)

In [32]:
# Same point inputs as before, but with the synthetic size-2 batch dict as the
# last argument, matching forward_scatter(x, x_prime, v_prime, batch).
dummy_input = (
    _dummy_input['phase_coords'][:2],     # x
    _dummy_input['boundary_coords'][:2],  # x_prime
    _dummy_input['boundary_coords'][2:],  # v_prime
    _batch_dict,                          # batch
)

In [33]:
# Re-transform: `forward_scatter` now refers to the vmapped definition.
f_transformed = hk.transform(forward_scatter)

# Initialise parameters under jit with a fresh key from the shared `rng` stream,
# and keep a jitted apply for the next cell.
params = jax.jit(f_transformed.init)(next(rng), *dummy_input)
apply = jax.jit(f_transformed.apply)

# Display the shape of every parameter array.
jax.tree_util.tree_map(lambda param: param.shape, params)

{'scatter_model/__layer_stack_no_state/mlp/linear': {'bias': (1, 128),
  'weights': (1, 128, 128)},
 'scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)}}

In [34]:
# Run the jitted forward pass on the batched dummy inputs with a fresh key
# from `rng`, then report the output shape.
logits = apply(params, next(rng), *dummy_input)
print(logits.shape)

(2, 24, 128)


In [35]:
from deeprte.model.modules import GreenFunction, TransportModel, ScatterModel

In [36]:
def forward_green(x, v, x_prime, v_prime, scattering_kernel, batch):
    """Build a GreenFunction from the notebook config and evaluate it once.

    Intended to be wrapped with `hk.transform`, since the module is constructed
    inside the function body.

    Args:
        x, v: the notebook passes `phase_coords[:2]` and `phase_coords[2:]`
            of a single collocation point.
        x_prime, v_prime: the notebook passes `boundary_coords[:2]` and
            `boundary_coords[2:]` of a single boundary point.
        scattering_kernel: the notebook passes the first row of the
            `scattering_kernel` feature.
        batch: feature dict forwarded to the model unchanged.

    Returns:
        Whatever the GreenFunction call returns.
    """
    green_function = GreenFunction(config.model.green_function)
    return green_function(x, v, x_prime, v_prime, scattering_kernel, batch)

In [37]:
feature_name_list = ['phase_coords', 'boundary_coords', "scattering_kernel"]
_dummy_input = select_feat(feature_name_list)(jnp_data)

# Keep only the first row of every selected collocation feature.
for name in _COLLOCATION_FEATURE_NAMES:
    if name not in _dummy_input:
        continue
    _dummy_input[name] = _dummy_input[name][0]

# boundary_coords is reduced to its first row explicitly (the printed shapes
# suggest the loop above does not touch it — TODO confirm).
_dummy_input['boundary_coords'] = _dummy_input['boundary_coords'][0]
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, _dummy_input))


# Positional arguments for
# forward_green(x, v, x_prime, v_prime, scattering_kernel, batch).
dummy_input = (
    _dummy_input['phase_coords'][:2],     # x
    _dummy_input['phase_coords'][2:],     # v
    _dummy_input['boundary_coords'][:2],  # x_prime
    _dummy_input['boundary_coords'][2:],  # v_prime
    _dummy_input['scattering_kernel'],    # scattering_kernel
    jnp_data,                             # batch
)

{'boundary_coords': (4,), 'phase_coords': (4,), 'scattering_kernel': (24,)}


In [38]:
# Turn the GreenFunction-building function into a pure (init, apply) pair.
f_transformed = hk.transform(forward_green)

# Initialise parameters under jit with a fresh key from the shared `rng` stream,
# and keep a jitted apply for the next cell.
params = jax.jit(f_transformed.init)(next(rng), *dummy_input)
apply = jax.jit(f_transformed.apply)

# Display the shape of every parameter array.
jax.tree_util.tree_map(lambda param: param.shape, params)

{'green_function/mlp/linear': {'bias': (1,), 'weights': (128, 1)},
 'green_function/scatter_model/__layer_stack_no_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/transport_model/coefficient_net/attention_net/linear': {'bias': (6

In [39]:
# Evaluate the jitted GreenFunction forward pass on the dummy inputs with a
# fresh key from `rng`, then report the output shape.
logits = apply(params, next(rng), *dummy_input)
print(logits.shape)

()
