In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '7'

In [2]:
from deeprte.config import make_config
config = make_config()

In [3]:
data_path = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat"

from deeprte.data.pipeline import DataPipeline
data_pipeline = DataPipeline(data_path)
data = data_pipeline.process(pre_shuffle=True, is_split_test_samples=True, num_test_samples= 2)

from deeprte.model.tf.rte_dataset import np_to_tensor_dict

tf_data = np_to_tensor_dict(data)

from deeprte.model.tf.rte_dataset import divide_batch_feat
batched_data, unbatched_data = divide_batch_feat(tf_data)

from deeprte.model.tf.rte_features import _COLLOCATION_FEATURE_NAMES

In [4]:
import tensorflow as tf
import haiku as hk
batched_data = tf.nest.map_structure(lambda x: x[0], batched_data)
tf_data = {**batched_data, **unbatched_data}

In [5]:
tf.nest.map_structure(lambda x: x.shape, tf_data)

{'sigma': TensorShape([1681, 2]),
 'psi_label': TensorShape([40344]),
 'scattering_kernel': TensorShape([40344, 24]),
 'boundary': TensorShape([1968]),
 'position_coords': TensorShape([1681, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([40344, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

In [6]:
from deeprte.model.modules import CoefficientNet

In [7]:
config.model.green_function.coefficient_net

attention_net:
  widths:
  - 64
  - 1
pointwise_mlp:
  widths:
  - 64
  - 2

In [8]:
import jax
import jax.numpy as jnp
from deeprte.model.tf.feature_transform import select_feat

In [9]:
jnp_data = jax.tree_util.tree_map(lambda x: jnp.array(x), tf_data)

In [10]:
def forward_coeff(position, velocity, coeff_position, coeff_values,):
        out = CoefficientNet(config.model.green_function.coefficient_net)(position, velocity, coeff_position, coeff_values,)

        return out

In [11]:
feature_name_list = ['phase_coords','position_coords', 'sigma']
_dummy_input = select_feat(feature_name_list)(jnp_data)
# dummy_input = jax.tree_util.tree_map(lambda x: x[0], dummy_input)
_dummy_input['phase_coords'] = _dummy_input['phase_coords'][0]
print(jax.tree_util.tree_map(lambda x: x.shape, _dummy_input))


dummy_input = (_dummy_input['phase_coords'][:2], _dummy_input['phase_coords'][2:], _dummy_input['position_coords'], _dummy_input['sigma'])
# dummy_input = jax.tree_util.tree_map(lambda x: jnp.array(x), dummy_input)

{'phase_coords': (4,), 'position_coords': (1681, 2), 'sigma': (1681, 2)}


In [12]:
f_transformed = hk.transform(forward_coeff)
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

params = f_transformed.init(next(rng), *dummy_input)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'coefficient_net/attention_net/linear': {'bias': (64,), 'weights': (6, 64)},
 'coefficient_net/attention_net/linear_1': {'bias': (1,), 'weights': (64, 1)}}

In [13]:
logits = f_transformed.apply(params, next(rng), *dummy_input)
print(logits)

[5.491967e-05 7.553206e-03]


In [14]:
from deeprte.model.modules import GreenFunctionBlock

In [15]:
config.model.green_function

coefficient_net:
  attention_net:
    widths:
    - 64
    - 1
  pointwise_mlp:
    widths:
    - 64
    - 2
green_function_mlp:
  widths:
  - 128
  - 128
  - 128
  - 128
green_res_block:
  depth: 2

In [16]:
def forward_green(position, velocity, position_prime, velocity_prime, coeff_position, coeff_values):
        out = GreenFunctionBlock(config.model.green_function)(position, velocity, position_prime, velocity_prime,coeff_position, coeff_values)

        return out

In [18]:
feature_name_list = ['phase_coords','position_coords', 'sigma', 'boundary_coords']
_dummy_input = select_feat(feature_name_list)(jnp_data)
# dummy_input = jax.tree_util.tree_map(lambda x: x[0], dummy_input)
_dummy_input['phase_coords'] = _dummy_input['phase_coords'][0]
_dummy_input['boundary_coords'] = _dummy_input['boundary_coords'][0]
print(jax.tree_util.tree_map(lambda x: x.shape, _dummy_input))


dummy_input = (_dummy_input['phase_coords'][:2], _dummy_input['phase_coords'][2:], _dummy_input['boundary_coords'][:2], _dummy_input['boundary_coords'][2:], _dummy_input['position_coords'], _dummy_input['sigma'],)

{'boundary_coords': (4,), 'phase_coords': (4,), 'position_coords': (1681, 2), 'sigma': (1681, 2)}


In [19]:
f_transformed = hk.transform(forward_green)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

params = f_transformed.init(next(rng), *dummy_input)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'green_function/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'green_function/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'green_function/green_function_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'green_function/green_function_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/green_function_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/green_function_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)}}

In [20]:
logits = f_transformed.apply(params, next(rng), *dummy_input)
print(logits.shape)

(128,)


In [21]:
from deeprte.model.modules import GreenFunctionNet

In [224]:
import ml_collections
from typing import Optional
from deeprte.model.tf.rte_dataset import TensorDict
from deeprte.model.tf.rte_features import NUM_DIM
from deeprte.model.integrate import quad
from deeprte.model.mapping import vmap
from deeprte.model.networks import MLP, Linear
from deeprte.model.layer_stack import layer_stack
class GreenFunctionNet(hk.Module):
    """Green function net."""

    def __init__(
        self,
        config: ml_collections.ConfigDict,
        name: Optional[str] = "green_function_res_block",
    ):
        super().__init__(name=name)

        self.config = config

    def __call__(
        self,
        phase_coords: jax.Array,
        boundary_coords: jax.Array,
        # scattering_kernel_coeff: jnp.ndarray,  # [Nv*,]
        coeff_position: jax.Array,
        coeff_values: jax.Array,
        scattering_kernel_coords: jax.Array,  # ((u,u*):[Nv*,4], (1-P(u,u*))*omega:[Nv*,])
        scattering_kernel: jax.Array,
        batch: TensorDict,
    ) -> jax.Array:

        c = self.config

        green_func_module = GreenFunctionBlock(c)

        inputs = (
            phase_coords[:NUM_DIM],
            phase_coords[NUM_DIM:],
            boundary_coords[:NUM_DIM],
            boundary_coords[NUM_DIM:],
            coeff_position,
            coeff_values,
        )

        green_fn_block_output = green_func_module(*inputs)

        if c.green_res_block.depth > 0:

            coords_star, weights = (
                scattering_kernel_coords,
                (1 - scattering_kernel) * batch["velocity_weights"],
            )
            res_block_inputs = (*inputs[:3], coeff_position, coeff_values)

            def _green_res_fn(green_fn):
                def func(*inputs):
                    green_fn_kernel_quad = quad(
                        green_fn,
                        (coords_star, weights),
                        argnum=3,
                        use_hk=True,
                    )(*inputs[:3], coeff_position, coeff_values)  # shape: [N_latent]
                    green_fn_res = MLP(
                        c.green_function_mlp.widths[-1:],
                        activate_final=True,
                    )(green_fn_kernel_quad)
                    return green_fn_res

                return func

            def _green_res_block(block_output, green_func_module):
                func = _green_res_fn(green_func_module)
                block_output += func(
                    *res_block_inputs,
                )
                green_func_module = func

                return block_output, green_func_module
                
            block_output = green_fn_block_output
            # for _ in range(c.green_res_block.depth):
                
            #     block_output, green_func_module = _green_res_block(block_output, green_func_module)
            block_output, green_func_module = layer_stack(3, with_state=False)(_green_res_block)(block_output, green_func_module)

        green_fn_block_output = MLP([1])(
            green_fn_block_output,
        )
        return green_fn_block_output  # shape: [1]


In [225]:
def forward_green_net(phase_coords, boundary_coords, coeff_coords, coeff_values, scattering_kernel_coords, scattering_kernel, batch):
        out = GreenFunctionNet(config.model.green_function)(phase_coords, boundary_coords, coeff_coords, coeff_values, scattering_kernel_coords, scattering_kernel, batch)

        return out

In [226]:
feature_name_list = ['phase_coords','position_coords', 'sigma', 'boundary_coords', "velocity_coords", "scattering_kernel"]
_dummy_input = select_feat(feature_name_list)(jnp_data)
# dummy_input = jax.tree_util.tree_map(lambda x: x[0], dummy_input)
for k in _COLLOCATION_FEATURE_NAMES:
    if k in _dummy_input:
        _dummy_input[k] = _dummy_input[k][0]
_dummy_input['boundary_coords'] = _dummy_input['boundary_coords'][0]
print(jax.tree_util.tree_map(lambda x: x.shape, _dummy_input))


dummy_input = (_dummy_input['phase_coords'], _dummy_input['boundary_coords'], _dummy_input['position_coords'], _dummy_input['sigma'],_dummy_input['velocity_coords'], _dummy_input['scattering_kernel'], jnp_data)

{'boundary_coords': (4,), 'phase_coords': (4,), 'position_coords': (1681, 2), 'scattering_kernel': (24,), 'sigma': (1681, 2), 'velocity_coords': (24, 2)}


In [227]:
config.model.green_function.green_res_block.depth = 1

In [228]:
f_transformed = hk.transform(forward_green_net)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

params = f_transformed.init(next(rng), *dummy_input)
jax.tree_util.tree_map(lambda x: x.shape, params)

ValueError: 'green_function_res_block/green_function/coefficient_net/attention_net/linear/weights' with retrieved shape (6,) does not match shape=(6, 64) dtype=dtype('float32')

In [223]:
logits = f_transformed.apply(params, next(rng), *dummy_input)
print(logits)

-0.7100657


In [97]:
def func(x):
    return MLP(config.model.green_function.green_function_mlp.widths[-1:])(x)

In [98]:
def forward(x):
    x = layer_stack(3, with_state=False)(func)(x)
    return x

In [104]:
x = jax.random.uniform(next(rng), (128,))

In [106]:
f_transformed = hk.transform(forward)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

params = f_transformed.init(next(rng), x)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'__layer_stack_no_state/mlp/linear': {'bias': (3, 128),
  'weights': (3, 128, 128)}}

In [108]:
logits = f_transformed.apply(params, next(rng), x)
print(logits.shape)

(128,)
