In [1]:
import os

# Restrict this process to a single GPU. CUDA reads this variable when it
# initializes, so it has to be set before TensorFlow/JAX touch the device.
os.environ.update(CUDA_VISIBLE_DEVICES="7")

In [2]:
from deeprte.config import get_config
# Load the project configuration object.
# NOTE(review): `config` is not referenced by any later cell in this notebook.
config = get_config()

In [3]:
# Directory holding the MATLAB .mat training files. Set the DEEPRTE_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
source_dir = os.environ.get(
    "DEEPRTE_DATA_DIR",
    "/workspaces/deeprte/rte_data/matlab/train-scattering-kernel-0405",
)
data_name_list = ["train_random_kernel_1.mat"]

In [4]:
from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import divide_batch_feat, np_to_tensor_dict

# Read the .mat files and process them. Judging by the keyword names, this
# shuffles the samples and holds out a small test split — confirm in DataPipeline.
data_pipeline = DataPipeline(source_dir, data_name_list)
data = data_pipeline.process(
    pre_shuffle=True,
    is_split_test_samples=True,
    num_test_samples=2,
)

# Convert to TF tensors, then separate the features that are later sliced
# per sample (`batched_feat`) from those shared by every element
# (`unbatched_feat`).
tf_data = np_to_tensor_dict(data)
batched_feat, unbatched_feat = divide_batch_feat(tf_data)

In [5]:
import tensorflow as tf

In [7]:
# Slice the per-sample features into a dataset of individual samples and group
# them into inner batches of 3 (a trailing partial batch is dropped).
ds = tf.data.Dataset.from_tensor_slices(batched_feat)
ds = ds.batch(3, drop_remainder=True)

def cat_batch(
    batched_feat,
    extra_feat=None,
):
    """Construct the model input structure for one dataset element.

    Intended for ``Dataset.map``: every element gets the shared features merged
    into its own per-sample features.

    Args:
        batched_feat: dict of per-sample features for one dataset element.
        extra_feat: dict of features to attach to every element. Defaults to
            the notebook-level ``unbatched_feat``. On key collisions its
            values win, matching ``dict.update`` semantics.

    Returns:
        A new dict with the keys of ``batched_feat`` (in order) followed by any
        new keys from ``extra_feat``. ``batched_feat`` itself is not modified.
    """
    if extra_feat is None:
        # Resolved at call time so the global can still be reassigned in
        # earlier cells.
        extra_feat = unbatched_feat
    return {**batched_feat, **extra_feat}

# Attach the shared features to every inner batch, then group inner batches
# into outer batches of 2 (again dropping any trailing partial batch).
ds = ds.map(cat_batch).batch(2, drop_remainder=True)

In [8]:
# Element structure after batch(3) -> map(cat_batch) -> batch(2): per-sample
# features gain leading (2, 3) dims, the merged shared features only (2,).
ds.element_spec

{'sigma': TensorSpec(shape=(2, 3, 1600, 2), dtype=tf.float32, name=None),
 'psi_label': TensorSpec(shape=(2, 3, 38400), dtype=tf.float32, name=None),
 'scattering_kernel': TensorSpec(shape=(2, 3, 38400, 24), dtype=tf.float32, name=None),
 'boundary_scattering_kernel': TensorSpec(shape=(2, 3, 1920, 24), dtype=tf.float32, name=None),
 'self_scattering_kernel': TensorSpec(shape=(2, 3, 24, 24), dtype=tf.float32, name=None),
 'boundary': TensorSpec(shape=(2, 3, 1920), dtype=tf.float32, name=None),
 'position_coords': TensorSpec(shape=(2, 1600, 2), dtype=tf.float32, name=None),
 'velocity_coords': TensorSpec(shape=(2, 24, 2), dtype=tf.float32, name=None),
 'phase_coords': TensorSpec(shape=(2, 38400, 4), dtype=tf.float32, name=None),
 'boundary_coords': TensorSpec(shape=(2, 1920, 4), dtype=tf.float32, name=None),
 'boundary_weights': TensorSpec(shape=(2, 1920), dtype=tf.float32, name=None),
 'velocity_weights': TensorSpec(shape=(2, 24), dtype=tf.float32, name=None)}

In [9]:
import numpy as np

In [10]:
from deeprte.model.tf.input_pipeline import split_feat,process_features, make_device_batch
import jax
# NOTE(review): the positional args are unnamed here; presumably 0.8 is a train
# fraction and True a shuffle flag — confirm against split_feat's signature.
# `train_feat` is not used by any later cell in this notebook.
train_feat = split_feat(tf_data, 0.8, True)
# Stale: `load_and_split_data` and `data_config` are not defined in this notebook.
# ds, unbatched_feat = load_and_split_data(tf_data, data_config, True)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
from deeprte.model.tf.input_pipeline import load_tf_data, tf_data_to_generator

In [12]:
# Load the full dataset (a dict of tensors) together with its normalization
# statistics. NOTE: this rebinds `ds`, which earlier held a tf.data.Dataset.
ds, normalization_dict = load_tf_data(
    source_dir,
    data_name_list,
    pre_shuffle=True,
    is_split_test_samples=True,
    num_test_samples=2,
)

In [13]:
# Shapes of the loaded tensors: per-sample entries carry a leading sample axis
# (498 here); the shared coordinate/weight entries do not.
tf.nest.map_structure(lambda x:x.shape, ds)

{'sigma': TensorShape([498, 1600, 2]),
 'psi_label': TensorShape([498, 38400]),
 'scattering_kernel': TensorShape([498, 38400, 24]),
 'boundary_scattering_kernel': TensorShape([498, 1920, 24]),
 'self_scattering_kernel': TensorShape([498, 24, 24]),
 'boundary': TensorShape([498, 1920]),
 'position_coords': TensorShape([1600, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([38400, 4]),
 'boundary_coords': TensorShape([1920, 4]),
 'boundary_weights': TensorShape([1920]),
 'velocity_weights': TensorShape([24])}

In [14]:
# Min/range statistics returned by load_tf_data (presumably used to rescale
# psi and boundary values — confirm inside load_tf_data).
normalization_dict

{'psi_min': 7.994067e-10,
 'psi_range': 0.11769425,
 'boundary_min': 0.0,
 'boundary_range': 0.13641933}

In [15]:
# Build a batch generator over the loaded data. The positional arguments are
# not named here; judging by the (2, 3, ...) leading dims it yields, [2, 3] is
# presumably the (outer, inner) batch layout — confirm in tf_data_to_generator.
# collocation_size matches the 500 sampled points seen in psi_label below.
tfds = tf_data_to_generator(
    ds,
    True,
    [2, 3],
    0.8,
    collocation_size=500,
    bc_collocation_size=1,
)

In [16]:
# Shapes of one generated batch.
# NOTE: next() consumes this batch, so `ret1` below is the *second* batch drawn.
tf.nest.map_structure(lambda x:x.shape, next(tfds))

{'sigma': (2, 3, 1600, 2),
 'psi_label': (2, 3, 500),
 'scattering_kernel': (2, 3, 500, 24),
 'boundary_scattering_kernel': (2, 3, 1920, 24),
 'self_scattering_kernel': (2, 3, 24, 24),
 'boundary': (2, 3, 1920),
 'position_coords': (2, 1600, 2),
 'velocity_coords': (2, 24, 2),
 'phase_coords': (2, 500, 4),
 'boundary_coords': (2, 1920, 4),
 'boundary_weights': (2, 1920),
 'velocity_weights': (2, 24),
 'sampled_boundary_coords': (2, 1, 4),
 'sampled_boundary': (2, 3, 1),
 'sampled_boundary_scattering_kernel': (2, 3, 1, 24)}

In [17]:
# Draw a batch to inspect (the shape-inspection cell above already consumed one).
ret1 = next(tfds)

In [18]:
# Expected (2, 3, 500): psi labels are presumably subsampled at the 500
# collocation points rather than the full 38400-point phase grid.
ret1["psi_label"].shape

(2, 3, 500)

In [24]:
# Both outer-batch entries of one batch should carry the same spatial grid.
np.allclose(ret1["position_coords"][0],ret1["position_coords"][1])

True

In [25]:
# Draw another batch to check consistency across batches.
ret2 = next(tfds)

In [28]:
# Same within-batch grid check for the next batch.
np.allclose(ret2["position_coords"][0],ret2["position_coords"][1])

True

In [29]:
# The spatial grid should also be identical across different batches.
np.allclose(ret1["position_coords"][0],ret2["position_coords"][1])

True