In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '7'

In [2]:
from deeprte.config import make_config
config = make_config()

In [3]:
data_path = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat"

from deeprte.data.pipeline import DataPipeline
data_pipeline = DataPipeline(data_path)
data = data_pipeline.process(pre_shuffle=True, is_split_test_samples=True, num_test_samples= 2)

from deeprte.model.tf.rte_dataset import np_to_tensor_dict

tf_data = np_to_tensor_dict(data)

from deeprte.model.tf.rte_dataset import divide_batch_feat
batched_data, unbatched_data = divide_batch_feat(tf_data)

from deeprte.model.tf.rte_features import _COLLOCATION_FEATURE_NAMES

In [4]:
import tensorflow as tf
import haiku as hk
_batched_data = tf.nest.map_structure(lambda x: x[0], batched_data)
tf_data = {**_batched_data, **unbatched_data}

In [5]:
tf.nest.map_structure(lambda x: x.shape, tf_data)

{'sigma': TensorShape([1681, 2]),
 'psi_label': TensorShape([40344]),
 'scattering_kernel': TensorShape([40344, 24]),
 'self_scattering_kernel': TensorShape([24, 24]),
 'boundary': TensorShape([1968]),
 'position_coords': TensorShape([1681, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([40344, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

In [6]:
import jax
import jax.numpy as jnp

In [7]:
jnp_data = jax.tree_util.tree_map(lambda x: jnp.array(x), tf_data)
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

In [8]:
config.model.green_function.scatter_model.res_block_depth = 2

In [9]:
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES

In [10]:
dummy_batch = {k: v[0]  if k in _COLLOCATION_FEATURE_NAMES else v for k, v in jnp_data.items()}
dummy_batch = {k: jnp.stack([v,v])  if k in _BATCH_FEATURE_NAMES else v for k, v in dummy_batch.items()}

axis_dict = {k:0 if k in _BATCH_FEATURE_NAMES else None for k in jnp_data}

In [11]:
jax.tree_util.tree_map(lambda x: x.shape, dummy_batch)

{'boundary': (2, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (4,),
 'position_coords': (1681, 2),
 'psi_label': (2,),
 'scattering_kernel': (2, 24),
 'self_scattering_kernel': (2, 24, 24),
 'sigma': (2, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [12]:
# axis_dict['phase_coords'] = {0}
# axis_dict['psi_label'] = {0,1}
# axis_dict['scattering_kernel'] = {0,1}
jax.tree_util.tree_map(lambda x: x, axis_dict)

{'boundary': 0,
 'boundary_coords': None,
 'boundary_weights': None,
 'phase_coords': None,
 'position_coords': None,
 'psi_label': 0,
 'scattering_kernel': 0,
 'self_scattering_kernel': 0,
 'sigma': 0,
 'velocity_coords': None,
 'velocity_weights': None}

In [13]:
from deeprte.model.modules import DeepRTE

In [14]:
def forward(batch,):
    out = DeepRTE(config.model)(batch, True, True, False)
    return out

In [15]:
dummy_batch = {k: v[:3]  if k in _COLLOCATION_FEATURE_NAMES else v for k, v in jnp_data.items()}
dummy_batch = {k: jnp.stack([v,v])  if k in _BATCH_FEATURE_NAMES else v for k, v in dummy_batch.items()}

In [16]:
f_transformed = hk.transform(forward)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))
params = jax.jit(f_transformed.init)(next(rng), dummy_batch)
apply = jax.jit(f_transformed.apply)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'DeepRTE/green_function/mlp/linear': {'bias': (1,), 'weights': (128, 1)},
 'DeepRTE/green_function/scatter_model/__layer_stack_with_state/mlp/linear': {'bias': (1,
   128),
  'weights': (1, 128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'DeepRTE/green_function/scatter_model/transport_model/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_function/scatter_model/transport_model/transport_block_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 'DeepRTE/green_f

In [17]:
logits = apply(params, next(rng), dummy_batch)

In [18]:
jax.tree_util.tree_map(lambda x: x.shape, logits)

{'loss': {'mse': (), 'rmspe': ()}, 'rte_predictions': (2, 3)}

In [19]:
batch_data = {k: v[jnp.newaxis,...]  if k in _BATCH_FEATURE_NAMES else v for k, v in jnp_data.items()}

In [20]:
jax.tree_util.tree_map(lambda x: x.shape, batch_data)

{'boundary': (1, 1968),
 'boundary_coords': (1968, 4),
 'boundary_weights': (1968,),
 'phase_coords': (40344, 4),
 'position_coords': (1681, 2),
 'psi_label': (1, 40344),
 'scattering_kernel': (1, 40344, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1681, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [21]:
logits = apply(params, next(rng), batch_data)
# OOM

2022-12-04 13:38:43.418063: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 16.18GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 17378252800 bytes.

In [22]:
jax.tree_util.tree_map(lambda x: x.shape, logits)

{'loss': {'mse': (), 'rmspe': ()}, 'rte_predictions': (2, 3)}

In [23]:
def forward(batch,):
    out = DeepRTE(config.model)(batch, False, True, False)
    return out

In [24]:
f_transformed = hk.transform(forward)
apply = jax.jit(f_transformed.apply)
logits = apply(params, next(rng), dummy_batch)

In [25]:
logits

{'loss': {'mse': DeviceArray(0.0425264, dtype=float32),
  'rmspe': DeviceArray(294.1997, dtype=float32)},
 'rte_predictions': DeviceArray([[0.2120242 , 0.19744486, 0.21019216],
              [0.2120242 , 0.19744486, 0.21019216]], dtype=float32)}

In [26]:
logits = apply(params, next(rng), batch_data)

In [27]:
logits

{'loss': {'mse': DeviceArray(0.03196618, dtype=float32),
  'rmspe': DeviceArray(50.43785, dtype=float32)},
 'rte_predictions': DeviceArray([[0.21202493, 0.19744399, 0.21019182, ..., 0.1465061 ,
               0.13761234, 0.13230112]], dtype=float32)}