In [1]:
import os

import functools

# Must run before `import jax` below so JAX only sees GPU 5.
# NOTE(review): this overrides any CUDA_VISIBLE_DEVICES already set in the
# environment; consider os.environ.setdefault if that is not intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import haiku as hk
import jax
import numpy as np

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import np_to_tensor_dict
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES, _COLLOCATION_FEATURE_NAMES
from deeprte.config import get_config
from deeprte.model.modules import DeepRTE

## Load Dataset

In [2]:
# Evaluation data file. Override with the DEEPRTE_EVAL_DATA environment variable
# instead of editing this absolute, machine-specific default path.
data_path = os.environ.get(
    "DEEPRTE_EVAL_DATA",
    "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat",
)

data_pipeline = DataPipeline(data_path)
# See DataPipeline.process for the exact meaning of these flags.
raw_features = data_pipeline.process(
    pre_shuffle=True, is_split_test_samples=True, num_test_samples=2
)

# np_to_tensor_dict presumably yields framework tensors (module lives under
# deeprte.model.tf); convert every leaf back to a host numpy array.
# jax.tree_util.tree_map replaces the deprecated (removed in JAX 0.6) jax.tree_map.
tensor_dict = np_to_tensor_dict(raw_features)
features = jax.tree_util.tree_map(np.asarray, tensor_dict)

jax.tree_util.tree_map(lambda x: x.shape, features)

{'boundary': (8, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (8, 40344),
 'scattering_kernel': (8, 40344, 24),
 'self_scattering_kernel': (8, 24, 24),
 'sigma': (8, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [3]:
from deeprte.model.modules import TransportModule, ScatterModule

In [16]:
from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections

from deeprte.data.pipeline import FeatureDict
from deeprte.model.integrate import quad
from deeprte.model.layer_stack import layer_stack
from deeprte.model.mapping import vmap
from deeprte.model.networks import MLP
from deeprte.model.tf.rte_features import (
    _BATCH_FEATURE_NAMES,
    _COLLOCATION_FEATURE_NAMES,
    NUM_DIM,
)
def glorot_uniform():
    """Build a Glorot (Xavier) uniform initializer for Haiku parameters.

    This is ``hk.initializers.VarianceScaling`` configured with a unit scale,
    ``fan_avg`` mode and a uniform distribution. A fresh initializer object is
    returned on every call.
    """
    initializer = hk.initializers.VarianceScaling(
        distribution="uniform",
        mode="fan_avg",
        scale=1.0,
    )
    return initializer
from deeprte.model.tf.rte_features import (
    _BATCH_FEATURE_NAMES,
    _COLLOCATION_FEATURE_NAMES,
    NUM_DIM,
)
class GreenFunction(hk.Module):
    """Haiku module mapping a pair of phase-space points to a scalar value.

    The name suggests a learned Green's function G((x, v), (x', v')). In this
    notebook it is evaluated at a single ``phase_coords`` row and a single
    ``boundary_coords`` row.

    Structure (all sub-configs read from ``config.scatter_model``):
      * a ``TransportModule`` produces a ``width``-dimensional feature vector,
        where ``width`` is the last entry of
        ``transport_model.transport_block_mlp.widths``;
      * if ``res_block_depth == 0`` the result is that vector projected onto
        ``out_layer_weights``;
      * otherwise a scattering residual, built from transport features at
        every discrete velocity, is added before the projection.
    """

    def __init__(
        self,
        config,
        name: Optional[str] = "green_function",
    ):
        """Store the GreenFunction config.

        Args:
          config: config node exposing ``scatter_model`` (with
            ``transport_model`` and ``res_block_depth``). Presumably an
            ``ml_collections.ConfigDict`` -- TODO confirm.
          name: Haiku module name, used as the parameter-path prefix.
        """
        super().__init__(name=name)

        self.config = config

    def __call__(
        self,
        coords: jax.Array,
        coords_prime: jax.Array,
        scattering_kernel: jax.Array,
        batch: FeatureDict,
    ) -> jax.Array:
        """Evaluate the kernel between ``coords`` and ``coords_prime``.

        Args:
          coords: phase-space point. ``[:NUM_DIM]`` is the position x and
            ``[NUM_DIM:]`` the velocity v (assumes a 1-D vector -- TODO confirm).
          coords_prime: second point, split the same way into (x', v').
          scattering_kernel: per-discrete-velocity scattering weights for this
            point. Must broadcast with ``batch["velocity_weights"]`` to a 1-D
            vector over velocities, as required by the ``"j,jk->k"`` einsum below.
          batch: feature dict for a single example. Always reads
            ``position_coords`` and ``sigma``. When ``res_block_depth > 0`` it
            also reads ``velocity_coords``, ``velocity_weights`` and
            ``self_scattering_kernel``.

        Returns:
          A scalar: the dot product of a ``width``-vector with
          ``out_layer_weights``.
        """

        c = self.config

        x, v = coords[:NUM_DIM], coords[NUM_DIM:]
        x_prime, v_prime = (
            coords_prime[:NUM_DIM],
            coords_prime[NUM_DIM:],
        )
        # Feature width of the transport network's last layer; sizes all params below.
        width = c.scatter_model.transport_model.transport_block_mlp.widths[-1]
        # One TransportModule instance. It is reused under vmap below, so both
        # calls share the same transport parameters.
        trans_module = TransportModule(c.scatter_model.transport_model)

        green_fn_output = trans_module(
            x,
            v,
            x_prime,
            v_prime,
            batch["position_coords"],
            batch["sigma"],
        )
        # Created before the depth-0 early return so every depth shares this parameter.
        out_layer_weights = hk.get_parameter(
            name="out_layer_weights",
            shape=[
                width,
            ],
            init=glorot_uniform(),
        )

        if c.scatter_model.res_block_depth == 0:
            return jnp.einsum("i,i", green_fn_output, out_layer_weights)

        # prepare inputs
        # Weights of the form (1 - kernel) * velocity_weights. res_weights_vstar is
        # only consumed by ScatterModule, i.e. when res_block_depth > 1.
        res_weights_vstar = (1 - batch["self_scattering_kernel"]) * batch[
            "velocity_weights"
        ]
        res_weights_v = (1 - scattering_kernel) * batch["velocity_weights"]

        output_v = green_fn_output
        # Transport features at every discrete velocity: vmap over argument 1 (v).
        output_vstar = vmap(trans_module, argnums=frozenset([1]), use_hk=True,)(
            x,
            batch["velocity_coords"],
            x_prime,
            v_prime,
            batch["position_coords"],
            batch["sigma"],
        )  # shape: [N_v*, N_latent]

        # stack layer
        if c.scatter_model.res_block_depth > 1:

            # layer_stack body. Its argument `x` is the (output_v, output_vstar)
            # carry tuple, not the position `x` above.
            def block(x):
                output_v, output_vstar = x
                output_v, output_vstar = ScatterModule(c.scatter_model)(
                    output_v,
                    output_vstar,
                    res_weights_v,
                    res_weights_vstar,
                )
                return output_v, output_vstar

            # res_block_depth - 1 ScatterModule layers. Their params get a leading
            # axis of that size.
            res_stack = layer_stack(c.scatter_model.res_block_depth - 1)(block)

            output_v, output_vstar = res_stack((output_v, output_vstar))

        weights = hk.get_parameter(
            name="weights", shape=[width, width], init=glorot_uniform()
        )
        # NOTE(review): the bias uses the Glorot initializer rather than zeros;
        # confirm this is intended.
        bias = hk.get_parameter(
            name="bias",
            shape=[
                width,
            ],
            init=glorot_uniform(),
        )

        # Weighted sum over velocities, then one dense tanh layer: tanh(W @ r + b).
        res_v = jnp.einsum("j,jk->k", res_weights_v, output_vstar)
        res_v = jax.nn.tanh(jnp.einsum("ik,k->i", weights, res_v) + bias)

        # Residual connection, then the final projection to a scalar.
        green_fn_output = output_v + res_v

        return jnp.einsum("i,i", green_fn_output, out_layer_weights)


In [17]:
# Keep only the `experiment_kwargs.config` sub-config of the default configuration.
config = get_config().experiment_kwargs.config

In [18]:
# With res_block_depth = 2, GreenFunction stacks res_block_depth - 1 = 1
# ScatterModule layer (visible as the leading axis of size 1 in the params below).
config.model.green_function.scatter_model.res_block_depth = 2

In [19]:
def forward_fn(*args):
    """Haiku forward function that builds a GreenFunction and applies it.

    ``args`` are passed through unchanged to ``GreenFunction.__call__``
    (coords, coords_prime, scattering_kernel, batch). The global ``config`` is
    read when this function runs, not when it is defined.
    """
    green_fn = GreenFunction(config.model.green_function)
    return green_fn(*args)


# Turn the impure Haiku function into a pure (init, apply) pair.
forward = hk.transform(forward_fn)

In [8]:
# Deterministic key stream (seed 42); each next(rng) yields a fresh PRNG key.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

In [9]:
def slice_batch(i: int, feat: dict):
    """Pick batch element ``i`` from every batched feature, dropping the batch axis.

    Entries whose key is in ``_BATCH_FEATURE_NAMES`` are indexed with ``[i]``.
    Every other entry is passed through as the same object. A new dict is
    returned, and key order is preserved.
    """
    sliced = {}
    for key, value in feat.items():
        sliced[key] = value[i] if key in _BATCH_FEATURE_NAMES else value
    return sliced


def sample_colloctaion_points(indices: "int | slice | np.ndarray", feat: dict) -> dict:
    """Select collocation points ``indices`` from every collocation feature.

    ``indices`` may be a single integer (which drops the point axis), a slice,
    or an integer array such as ``np.arange(300)``. The old ``int``-only
    annotation was wrong for the array use.

    Features whose key is in ``_COLLOCATION_FEATURE_NAMES`` are indexed along
    their point axis:
      * ``"phase_coords"`` has no batch axis, so it is indexed on axis 0;
      * the other collocation features (e.g. ``psi_label``,
        ``scattering_kernel``) carry a leading batch axis and are indexed on
        axis 1.
    All remaining features are passed through unchanged.

    Returns:
      A new dict. ``feat`` itself is not modified.
    """
    sampled = {}
    for name, value in feat.items():
        if name not in _COLLOCATION_FEATURE_NAMES:
            sampled[name] = value
        elif name == "phase_coords":
            sampled[name] = value[indices]
        else:
            sampled[name] = value[:, indices]
    return sampled

In [10]:
# One unbatched example for initialising GreenFunction: collocation point 0 of
# batch element 0, paired with the first boundary point.
_dummy_input = slice_batch(0, sample_colloctaion_points(0, features))
_dummy_input["boundary_coords"] = _dummy_input["boundary_coords"][0]
dummy_input = (
    _dummy_input["phase_coords"],
    _dummy_input["boundary_coords"],
    _dummy_input["scattering_kernel"],
    _dummy_input,
)

In [20]:
# Jit-compile Haiku's pure functions; initialisation consumes one key from `rng`.
apply = jax.jit(forward.apply)
params = jax.jit(forward.init)(next(rng), *dummy_input)

In [21]:
# Parameter tree with every leaf replaced by its shape.
jax.tree_util.tree_map(np.shape, params)

{'green_function': {'bias': (64,),
  'out_layer_weights': (64,),
  'weights': (64, 64)},
 'green_function/__layer_stack_no_state/scatter_model': {'bias': (1, 64),
  'weights': (1, 64, 64)},
 'green_function/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'green_function/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'green_function/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'green_function/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/transport_model/transport_block_mlp/linear_3': {'bias': (64,),
  'weights': (128, 64)}}

In [22]:
# Single evaluation of the jitted GreenFunction at the dummy example.
out = apply(params, next(rng), *dummy_input)

In [23]:
# GreenFunction returns a scalar, so the expected shape is ().
jax.tree_util.tree_map(np.shape, out)

()

## Build Model

In [None]:
# Rebuild the config from scratch so this section does not inherit mutations
# made to `config` earlier in the notebook.
config = get_config().experiment_kwargs.config
config.model.green_function.scatter_model.res_block_depth = 2


def forward_fn(batch, is_training):
    """Haiku forward function for the full DeepRTE operator.

    Loss and metric computation are disabled. Note that this redefines the
    ``forward_fn`` used for the standalone GreenFunction experiment above.
    """
    rte_model = DeepRTE(config.model)
    return rte_model(
        batch,
        is_training,
        compute_loss=False,
        compute_metrics=False,
    )


rte_op = hk.transform(forward_fn)

## Initialize Parameters

In [None]:
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# is_training is bound through functools.partial, so jax.jit only traces over
# (rng, batch).
# NOTE(review): this initialises on the full `features` dict (all batch
# elements and collocation points); a small slice may be much cheaper if
# parameter shapes do not depend on it -- confirm.
init_net = jax.jit(functools.partial(rte_op.init, is_training=True))
params = init_net(next(rng), features)

# jax.tree_util.tree_map replaces the deprecated (removed in JAX 0.6) jax.tree_map.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'DeepRTE/green_function/mlp/linear': {'bias': (1,), 'weights': (32, 1)},
 'DeepRTE/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   32),
  'weights': (1, 32, 32)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (32,),
  'weights': (128, 32)},
 'DeepRTE/green_functio

## Apply RTE Operator

In [None]:
# Jitted evaluation entry point. is_training=False is bound via partial, so jit
# traces only (params, rng, batch).
rte_apply = jax.jit(functools.partial(rte_op.apply, is_training=False))

def slice_batch(i: int, feat: dict):
    """Pick batch element ``i`` while keeping a leading batch axis of size 1.

    Entries whose key is in ``_BATCH_FEATURE_NAMES`` are sliced with
    ``[i : i + 1]``, so rank is preserved (e.g. ``sigma`` goes from
    ``(8, 1681, 2)`` to ``(1, 1681, 2)``). Other entries pass through
    unchanged. Note that this redefines the earlier, axis-dropping
    ``slice_batch``.
    """

    def _take(key, value):
        return value[i : i + 1] if key in _BATCH_FEATURE_NAMES else value

    return {key: _take(key, value) for key, value in feat.items()}


def sample_colloctaion_points(indices: "int | slice | np.ndarray", feat: dict) -> dict:
    """Select collocation points ``indices`` from every collocation feature.

    ``indices`` may be a single integer (which drops the point axis), a slice,
    or an integer array such as ``np.arange(300)``. The old ``int``-only
    annotation was wrong for the array use.

    Features whose key is in ``_COLLOCATION_FEATURE_NAMES`` are indexed along
    their point axis:
      * ``"phase_coords"`` has no batch axis, so it is indexed on axis 0;
      * the other collocation features (e.g. ``psi_label``,
        ``scattering_kernel``) carry a leading batch axis and are indexed on
        axis 1.
    All remaining features are passed through unchanged.

    Returns:
      A new dict. ``feat`` itself is not modified.
    """
    sampled = {}
    for name, value in feat.items():
        if name not in _COLLOCATION_FEATURE_NAMES:
            sampled[name] = value
        elif name == "phase_coords":
            sampled[name] = value[indices]
        else:
            sampled[name] = value[:, indices]
    return sampled

In [None]:
# Batch element 2 (kept as a size-1 batch axis by slice_batch's i:i+1 slice),
# restricted to its first 300 collocation points.
batch = sample_colloctaion_points(np.arange(300), slice_batch(2, features))

# jax.tree_util.tree_map replaces the deprecated (removed in JAX 0.6) jax.tree_map.
jax.tree_util.tree_map(lambda x: x.shape, batch)

{'boundary': (1, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (300, 4),
 'position_coords': (1681, 2),
 'psi_label': (1, 300),
 'scattering_kernel': (1, 300, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [None]:
%time out = rte_apply(params, next(rng), batch)

CPU times: user 1min 18s, sys: 1.95 s, total: 1min 20s
Wall time: 1min 12s
