In [1]:
import os

# Select the GPU before JAX (or anything that imports it, e.g. haiku) is
# loaded: CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA
# backend is initialized, so setting it after `import jax` is fragile.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import functools
import json

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from deeprte.config import get_config
from deeprte.model.data import flat_params_to_haiku
from deeprte.model.modules import DeepRTE
from deeprte.model.tf.input_pipeline import load_tf_data
from deeprte.model.tf.rte_features import BATCH_FEATURE_NAMES

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def slice_batch(i: int, feat: dict, batch_feature_names=None):
    """Select a single example from a batched feature dict.

    Features whose key is in ``batch_feature_names`` are sliced along their
    leading axis to ``[i:i+1]``, so the batch axis is kept with size 1. Every
    other feature (shared grids, weights, ...) is returned unchanged.

    Args:
        i: Example index along the batch axis. Negative indices count from the
            end, as in normal Python indexing.
        feat: Mapping from feature name to array.
        batch_feature_names: Keys that carry a leading batch axis. Defaults to
            ``BATCH_FEATURE_NAMES`` from ``deeprte.model.tf.rte_features``.

    Returns:
        A new dict with the same keys (in the same order) as ``feat``.

    Raises:
        IndexError: If ``i`` is out of range for a batched feature. A plain
            slice would silently return an empty array instead.
    """
    if batch_feature_names is None:
        batch_feature_names = BATCH_FEATURE_NAMES
    # `x[-1:0]` is empty, so the last element needs an open-ended stop.
    stop = None if i == -1 else i + 1
    sliced = {}
    for key, value in feat.items():
        if key in batch_feature_names:
            size = len(value)
            if not -size <= i < size:
                raise IndexError(
                    f"index {i} out of range for feature {key!r} with batch size {size}"
                )
            sliced[key] = value[i:stop]
        else:
            sliced[key] = value
    return sliced

# Fixed seed so anything drawn from `rng` is reproducible across runs.
SEED = 42
rng = hk.PRNGSequence(jax.random.PRNGKey(SEED))

In [3]:
# Evaluation data: two test sets from the same directory, compared below.
source_dir = "/workspaces/deeprte/rte_data/matlab/eval-data/scattering-kernel/0311"
data_name_list_bc1 = ["test_bc1.mat"]
data_name_list_v1 = ["test_deltax_fixsk_v1.mat"]

# Checkpoint under evaluation. PARAMS_FILE and CONFIG_PATH are both derived
# from the same training-run directory so the model config matches the
# weights. Previously CONFIG_PATH pointed at the older run
# train_scattering_kernel_2023-03-09T07:12:20 (step_555000_2023-03-11T06:44:08),
# while PARAMS_FILE came from the 2023-03-28 run.
# NOTE(review): assumes this run directory contains a config.json, as the
# 2023-03-09 run does — confirm.
CKPT_DIR = "/workspaces/deeprte/ckpts/train_scattering_kernel_2023-03-28T04:14:34"
PARAMS_FILE = os.path.join(
    CKPT_DIR, "models/latest/step_570000_2023-03-30T05:30:18/params.npz"
)
CONFIG_PATH = os.path.join(CKPT_DIR, "config.json")

In [4]:
# Position of the test sample extracted from each dataset via `slice_batch`.
idx: int = 0

In [5]:
from deeprte.data.pipeline import DataPipeline

# Run the raw-data pipeline on the bc1 test set.
# NOTE(review): the returned `data` is not used in the visible cells
# (save_path=None) — confirm whether this call is still needed.
data_pipeline = DataPipeline(source_dir, data_name_list_bc1)
data = data_pipeline.process(
    pre_shuffle=False,
    pre_shuffle_seed=0,
    is_split_test_samples=False,
    num_test_samples=None,
    normalization=False,
    save_path=None,
)

# Load the same files and convert every leaf to a JAX array.
# `jax.tree_map` is deprecated (removed in newer JAX releases); use the
# `jax.tree_util` spelling, as the last line of this cell already does.
tf_data = load_tf_data(source_dir, data_name_list_bc1, normalization=False)
features = jax.tree_util.tree_map(jnp.array, tf_data)

# Keep sample `idx` (batch axis retained with size 1) and show feature shapes.
data_feature = features[0]
batch_bc1 = slice_batch(idx, data_feature)
jax.tree_util.tree_map(lambda x: x.shape, batch_bc1)

{'boundary': (1, 1920),
 'boundary_coords': (1920, 4),
 'boundary_scattering_kernel': (1, 1920, 24),
 'boundary_weights': (1920,),
 'phase_coords': (38400, 4),
 'position_coords': (1600, 2),
 'psi_label': (1, 38400),
 'scattering_kernel': (1, 38400, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1600, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [6]:
# Run the raw-data pipeline on the v1 test set (same settings as for bc1).
# NOTE(review): the returned `data` is not used in the visible cells
# (save_path=None) — confirm whether this call is still needed.
data_pipeline = DataPipeline(source_dir, data_name_list_v1)
data = data_pipeline.process(
    pre_shuffle=False,
    pre_shuffle_seed=0,
    is_split_test_samples=False,
    num_test_samples=None,
    normalization=False,
    save_path=None,
)

# Load the same files and convert every leaf to a JAX array.
# `jax.tree_map` is deprecated (removed in newer JAX releases); use the
# `jax.tree_util` spelling, as the last line of this cell already does.
tf_data = load_tf_data(source_dir, data_name_list_v1, normalization=False)
features = jax.tree_util.tree_map(jnp.array, tf_data)

# Keep sample `idx` (batch axis retained with size 1) and show feature shapes.
data_feature = features[0]
batch_v1 = slice_batch(idx, data_feature)
jax.tree_util.tree_map(lambda x: x.shape, batch_v1)

{'boundary': (1, 1920),
 'boundary_coords': (1920, 4),
 'boundary_scattering_kernel': (1, 1920, 24),
 'boundary_weights': (1920,),
 'phase_coords': (38400, 4),
 'position_coords': (1600, 2),
 'psi_label': (1, 38400),
 'scattering_kernel': (1, 38400, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1600, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [7]:
# Compare each feature of the selected bc1 and v1 samples.
# `np.allclose` broadcasts its arguments, so require equal shapes first;
# otherwise e.g. a (1920,) and a (1, 1920) array could be reported as equal.
for key in batch_v1:
    bc1_value, v1_value = batch_bc1[key], batch_v1[key]
    same = bc1_value.shape == v1_value.shape and np.allclose(bc1_value, v1_value)
    print(key, same)

boundary False
boundary_coords True
boundary_scattering_kernel False
boundary_weights True
phase_coords True
position_coords True
psi_label False
scattering_kernel False
self_scattering_kernel False
sigma True
velocity_coords True
velocity_weights True


In [8]:
# Display the self-scattering kernel of the selected bc1 sample. In the
# recorded comparison output, this feature differs between bc1 and v1.
batch_bc1["self_scattering_kernel"]

DeviceArray([[[1.0797689 , 1.1066376 , 1.1066376 , 1.0989072 ,
               1.1177467 , 1.0989072 , 1.0245289 , 1.043865  ,
               0.96030015, 1.0490503 , 0.98438615, 0.9231771 ,
               0.97192734, 0.9043545 , 0.9043545 , 0.88027394,
               0.86493325, 0.88027394, 1.0245289 , 0.96030015,
               1.043865  , 0.9231771 , 0.98438615, 1.0490503 ],
              [1.1040435 , 1.2325007 , 1.1556625 , 1.2792213 ,
               1.2562976 , 1.1557316 , 0.9613707 , 0.9884659 ,
               0.8404819 , 1.0193497 , 0.88928044, 0.79427195,
               0.90694857, 0.78341657, 0.8404931 , 0.75075686,
               0.76545   , 0.8361339 , 1.0427945 , 0.98848176,
               1.1556737 , 0.951418  , 1.0891303 , 1.2134529 ],
              [1.1040435 , 1.1556625 , 1.2325007 , 1.1557316 ,
               1.2562976 , 1.2792213 , 1.0427945 , 1.1556737 ,
               0.98848176, 1.2134529 , 1.0891303 , 0.951418  ,
               0.90694857, 0.8404931 , 0.78341657, 0.