In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import os
import json

os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import haiku as hk
import tensorflow as tf
from deeprte.config import get_config
from deeprte.model.modules import DeepRTE
from deeprte.model.tf.input_pipeline import load_tf_data
from deeprte.model.data import flat_params_to_haiku

from deeprte.model.tf.rte_features import BATCH_FEATURE_NAMES,COLLOCATION_FEATURE_NAMES,BOUNDARY_FEATURE_NAMES

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from deeprte.model.tf.rte_dataset import (
    TensorDict,
    divide_batch_feat,
    make_collocation_axis,
    make_boudary_sample_axis,
    np_to_tensor_dict,
)

In [3]:
from deeprte.model.tf.data_transforms import construct_batch

In [4]:
source_dir = "/workspaces/deeprte/rte_data/matlab/eval-data/scattering-kernel/0311/"
data_name_list = ["test_bc1.mat"]
# data_name_list = ["test_bc1_fixedvar_scattering_kernel.mat"]
# data_name_list = ["test_itr.mat"]

PARAMS_FILE = "/workspaces/deeprte/ckpts/train_scattering_kernel_2023-02-16T14:04:06/models/latest/step_1200000_2023-02-19T22:37:17/params.npz"
CONFIG_PATH = "/workspaces/deeprte/ckpts/train_scattering_kernel_2023-02-16T14:04:06/config.json"

In [5]:
from deeprte.data.pipeline import DataPipeline

data_pipeline = DataPipeline(source_dir, data_name_list)

data = data_pipeline.process(
    pre_shuffle=False,
    pre_shuffle_seed=0,
    is_split_test_samples=False,
    num_test_samples=None,
    normalization=False,
    save_path=None,
)

In [6]:
tf_data = load_tf_data(source_dir, data_name_list, normalization=False)
features = jax.tree_map(lambda x: np.array(x), tf_data)
jax.tree_util.tree_map(lambda x: x.shape, features)

({'boundary': (10, 1920),
  'boundary_coords': (1920, 4),
  'boundary_scattering_kernel': (10, 1920, 24),
  'boundary_weights': (1920,),
  'phase_coords': (38400, 4),
  'position_coords': (1600, 2),
  'psi_label': (10, 38400),
  'scattering_kernel': (10, 38400, 24),
  'self_scattering_kernel': (10, 24, 24),
  'sigma': (10, 1600, 2),
  'velocity_coords': (24, 2),
  'velocity_weights': (24,)},
 {})

In [7]:
data_feature = features[0]

In [8]:
sum(data_feature["boundary_weights"][:12])

0.01250000053551048

In [9]:
data_feature["velocity_coords"][:,1]

array([ 0.2554414 ,  0.28708965,  0.69309574,  0.2513426 ,  0.68668073,
        0.9380233 ,  0.2554414 ,  0.69309574,  0.28708965,  0.9380233 ,
        0.68668073,  0.2513426 , -0.2554414 , -0.28708965, -0.69309574,
       -0.2513426 , -0.68668073, -0.9380233 , -0.2554414 , -0.69309574,
       -0.28708965, -0.9380233 , -0.68668073, -0.2513426 ], dtype=float32)

In [10]:
batched_data, unbatched_data = divide_batch_feat(data_feature)

In [11]:
ds = tf.data.Dataset.from_tensor_slices(batched_data)

In [12]:
ds = ds.batch(1, drop_remainder=True)

In [13]:
func = construct_batch(unbatched_feat=unbatched_data,
            collocation_sizes=(12,13),
            collocation_features=(make_collocation_axis(),make_boudary_sample_axis(),),
            total_grid_sizes=(40*40*24,1920) ,
            generator=tf.random.Generator.from_seed(seed=0),
            is_training=True,
            is_replacing=(True, False))

In [14]:
ds = ds.map(func)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [15]:
ds.element_spec

{'boundary': TensorSpec(shape=(1, 1920), dtype=tf.float32, name=None),
 'boundary_scattering_kernel': TensorSpec(shape=(1, 1920, 24), dtype=tf.float32, name=None),
 'psi_label': TensorSpec(shape=(1, 12), dtype=tf.float32, name=None),
 'scattering_kernel': TensorSpec(shape=(1, 12, 24), dtype=tf.float32, name=None),
 'self_scattering_kernel': TensorSpec(shape=(1, 24, 24), dtype=tf.float32, name=None),
 'sigma': TensorSpec(shape=(1, 1600, 2), dtype=tf.float32, name=None),
 'boundary_coords': TensorSpec(shape=(1920, 4), dtype=tf.float32, name=None),
 'boundary_weights': TensorSpec(shape=(1920,), dtype=tf.float32, name=None),
 'phase_coords': TensorSpec(shape=(12, 4), dtype=tf.float32, name=None),
 'position_coords': TensorSpec(shape=(1600, 2), dtype=tf.float32, name=None),
 'velocity_coords': TensorSpec(shape=(24, 2), dtype=tf.float32, name=None),
 'velocity_weights': TensorSpec(shape=(24,), dtype=tf.float32, name=None),
 'sampled_boundary_coords': TensorSpec(shape=(13, 4), dtype=tf.float3

In [16]:
ds = ds.batch(4, drop_remainder=True)

In [17]:
ds.element_spec

{'boundary': TensorSpec(shape=(4, 1, 1920), dtype=tf.float32, name=None),
 'boundary_scattering_kernel': TensorSpec(shape=(4, 1, 1920, 24), dtype=tf.float32, name=None),
 'psi_label': TensorSpec(shape=(4, 1, 12), dtype=tf.float32, name=None),
 'scattering_kernel': TensorSpec(shape=(4, 1, 12, 24), dtype=tf.float32, name=None),
 'self_scattering_kernel': TensorSpec(shape=(4, 1, 24, 24), dtype=tf.float32, name=None),
 'sigma': TensorSpec(shape=(4, 1, 1600, 2), dtype=tf.float32, name=None),
 'boundary_coords': TensorSpec(shape=(4, 1920, 4), dtype=tf.float32, name=None),
 'boundary_weights': TensorSpec(shape=(4, 1920), dtype=tf.float32, name=None),
 'phase_coords': TensorSpec(shape=(4, 12, 4), dtype=tf.float32, name=None),
 'position_coords': TensorSpec(shape=(4, 1600, 2), dtype=tf.float32, name=None),
 'velocity_coords': TensorSpec(shape=(4, 24, 2), dtype=tf.float32, name=None),
 'velocity_weights': TensorSpec(shape=(4, 24), dtype=tf.float32, name=None),
 'sampled_boundary_coords': TensorS

In [18]:
from deeprte.model.tf.input_pipeline import tf_data_to_generator

In [19]:
g = tf_data_to_generator(tf_data[0], is_training=True, batch_sizes=[1,4], collocation_size=12, bc_collocation_size=13)

In [20]:
tf.nest.map_structure(lambda x: x.shape, next(g))

{'sigma': (1, 4, 1600, 2),
 'psi_label': (1, 4, 12),
 'scattering_kernel': (1, 4, 12, 24),
 'boundary_scattering_kernel': (1, 4, 1920, 24),
 'self_scattering_kernel': (1, 4, 24, 24),
 'boundary': (1, 4, 1920),
 'position_coords': (1, 1600, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 12, 4),
 'boundary_coords': (1, 1920, 4),
 'boundary_weights': (1, 1920),
 'velocity_weights': (1, 24),
 'sampled_boundary_coords': (1, 13, 4),
 'sampled_boundary': (1, 4, 13),
 'sampled_boundary_scattering_kernel': (1, 4, 13, 24)}

In [21]:
dummy_inputs = next(g)

In [22]:
import functools
from deeprte.config import get_config

In [23]:
config = get_config()
config = config.experiment_kwargs.config

In [24]:
rng = hk.PRNGSequence(jax.random.PRNGKey(42))

In [25]:
def forward_fn(batch, is_training):
    out = DeepRTE(config.model)(batch, is_training=is_training, compute_loss=False, compute_metrics=False)
    return out

forward = hk.transform(forward_fn)
apply = jax.jit(functools.partial(forward.apply, is_training = False))
init = jax.jit(functools.partial(forward.init, is_training = False))

In [26]:
params = init(next(rng), dummy_inputs)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 4, e.g. axis 0 of args[0][0]['psi_label'] of type float32[4,12];
  * one axis had size 1: axis 0 of args[0][0]['phase_coords'] of type float32[1,12,4]