In [1]:
import scipy.io
import numpy as np

import sys

sys.path.append("/root/projects/deeprte")

from deeprte.data import pipeline

In [2]:
# Configuration: which test case to downsample. The case name is stored once and
# reused so the input directory, input file and output file cannot drift apart.
CASE_NAME = "bc1-g0.1"
DATA_DIR = f"/root/projects/deeprte/data/test/{CASE_NAME}"
DATA_NAME = [f"{CASE_NAME}.mat"]
OUT_NAME = f"{CASE_NAME}_meshdownsampling.mat"  # file written by the savemat cell

In [3]:
# Load the full-resolution dataset: a dict of arrays plus loadmat's "__"-prefixed metadata keys.
data = scipy.io.loadmat(f"{DATA_DIR}/{DATA_NAME[0]}")

In [4]:
# List the variables stored in the .mat file; keys starting with "__" are loadmat metadata, not data.
data.keys()

dict_keys(['__header__', '__version__', '__globals__', 'config', 'ct', 'omega_prime', 'phi', 'psi_bc', 'psi_label', 'rand_params', 'rv_prime', 'scattering_kernel', 'sigma_a', 'sigma_t', 'st', 'w_angle', 'x', 'y', '__function_workspace__'])

In [5]:
# rv_prime's first axis (164 = 4 * 41) is what the downsampling cell splits into 4 blocks of n = 41 rows.
data["rv_prime"].shape

(164, 12, 4)

In [6]:
# Downsample the mesh 2x by keeping every other point.
# This cell rebinds entries of `data` in place, so it is only valid on the freshly
# loaded arrays from `loadmat`. The guard below makes an accidental second run fail
# loudly instead of silently re-slicing already-downsampled arrays.
n = 41  # rows per boundary segment; the boundary arrays stack 4 segments (164 = 4 * 41)
stride = 2  # keep every `stride`-th point


def downsample_segments(arr, n, axis, stride=2):
    """Subsample each of the 4 stacked segments of `arr` independently.

    `arr` holds 4 consecutive blocks of `n` entries along `axis`. For each block this
    keeps entries 0, stride, 2*stride, ... *of that block*. Slicing per block (rather
    than `::stride` over all 4*n entries at once) restarts the stride at every block
    boundary, so each block keeps its own first entry even though `n` is odd.

    Args:
        arr: array whose `axis` holds the 4 stacked blocks.
        n: number of entries per block along `axis`.
        axis: axis along which the blocks are stacked.
        stride: subsampling step within each block.

    Returns:
        A list of the 4 subsampled blocks, in order, ready for `np.concatenate`.
    """
    blocks = []
    for i in range(4):
        index = [slice(None)] * np.ndim(arr)
        index[axis] = slice(n * i, n * (i + 1), stride)
        blocks.append(arr[tuple(index)])
    return blocks


# Guard: the full-resolution rv_prime has 4 * n rows (shape (164, 12, 4) above).
assert data["rv_prime"].shape[0] == 4 * n, (
    f"rv_prime has {data['rv_prime'].shape[0]} rows, expected {4 * n}; "
    "was this cell already run? Reload the .mat file first."
)

# Grid fields: keep every other point along axes 1 and 2.
data["phi"] = data["phi"][:, ::stride, ::stride]
data["psi_label"] = data["psi_label"][:, ::stride, ::stride]
data["x"] = np.squeeze(data["x"])[::stride]
data["y"] = np.squeeze(data["y"])[::stride]
data["sigma_a"] = data["sigma_a"][:, ::stride, ::stride]
data["sigma_t"] = data["sigma_t"][:, ::stride, ::stride]

# Boundary fields: subsample each of the 4 segments separately, then restack them.
list_psi_bc = downsample_segments(data["psi_bc"], n, axis=1, stride=stride)
data["psi_bc"] = np.concatenate(list_psi_bc, axis=-2)

# NOTE(review): unlike psi_bc / rv_prime, the omega_prime segments are restacked
# along the LAST axis, giving shape (m, 4 * k) rather than (4 * m, k). If the
# pipeline's `boundary_weights` (shape (20, 48) in its output) is built from
# omega_prime, this layout is intended -- confirm before changing the axis.
list_omega_prime = downsample_segments(data["omega_prime"], n, axis=0, stride=stride)
data["omega_prime"] = np.concatenate(list_omega_prime, axis=-1)

list_rv_prime = downsample_segments(data["rv_prime"], n, axis=0, stride=stride)
data["rv_prime"] = np.concatenate(list_rv_prime, axis=0)

# Each segment now has len(range(0, n, stride)) points (21 for n = 41, stride = 2).
n_per_segment = len(range(0, n, stride))
assert data["rv_prime"].shape[0] == 4 * n_per_segment
assert data["psi_bc"].shape[-2] == 4 * n_per_segment

In [7]:
# Write the downsampled arrays into the same directory as the input, under OUT_NAME.
scipy.io.savemat(f"{DATA_DIR}/{OUT_NAME}", data)

In [8]:
# After per-segment downsampling: 4 segments * 21 points = 84 rows on the second-to-last axis.
data["psi_bc"].shape

(100, 84, 12)

In [9]:
# phi was sliced with [:, ::2, ::2], keeping every other point along axes 1 and 2.
data["phi"].shape

(100, 21, 21)

In [10]:
import jax

# Reload the downsampled file through the project pipeline to check the feature
# shapes it produces. DATA_DIR and OUT_NAME come from the config cell, so this
# reads exactly the file the savemat cell wrote instead of a re-typed copy of its name.
DATA_NAME = [OUT_NAME]

data_pipeline = pipeline.DataPipeline(DATA_DIR, DATA_NAME)
raw_feature_dict = data_pipeline.process()
num_examples = raw_feature_dict["shape"]["num_examples"]

del data_pipeline  # only the processed feature dict is used from here on

# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, raw_feature_dict["functions"])

{'boundary': (100, 80, 12),
 'boundary_scattering_kernel': (100, 80, 12, 24),
 'psi_label': (100, 20, 20, 24),
 'scattering_kernel': (100, 20, 20, 24, 24),
 'self_scattering_kernel': (100, 24, 24),
 'sigma': (100, 20, 20, 2)}

In [11]:
# Shapes of the grid arrays produced by the pipeline.
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, raw_feature_dict["grid"])

{'boundary_coords': (80, 12, 4),
 'boundary_weights': (20, 48),
 'phase_coords': (20, 20, 24, 4),
 'position_coords': (20, 20, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [12]:
# Size summary the pipeline reports for the downsampled file.
raw_feature_dict["shape"]

{'num_examples': 100,
 'num_position_coords': 400,
 'num_velocity_coords': 24,
 'num_phase_coords': 9600,
 'num_boundary_coords': 960}

In [13]:
# Inspect the pipeline's boundary coordinate array (numpy elides the middle of large arrays).
raw_feature_dict["grid"]["boundary_coords"]

array([[[ 0.        ,  0.025     ,  0.2554414 , -0.2554414 ],
        [ 0.        ,  0.025     ,  0.28708965, -0.69309574],
        [ 0.        ,  0.025     ,  0.69309574, -0.28708965],
        ...,
        [ 0.        ,  0.025     ,  0.9380233 ,  0.2513426 ],
        [ 0.        ,  0.025     ,  0.68668073,  0.68668073],
        [ 0.        ,  0.025     ,  0.2513426 ,  0.9380233 ]],

       [[ 0.        ,  0.075     ,  0.2554414 , -0.2554414 ],
        [ 0.        ,  0.075     ,  0.28708965, -0.69309574],
        [ 0.        ,  0.075     ,  0.69309574, -0.28708965],
        ...,
        [ 0.        ,  0.075     ,  0.9380233 ,  0.2513426 ],
        [ 0.        ,  0.075     ,  0.68668073,  0.68668073],
        [ 0.        ,  0.075     ,  0.2513426 ,  0.9380233 ]],

       [[ 0.        ,  0.125     ,  0.2554414 , -0.2554414 ],
        [ 0.        ,  0.125     ,  0.28708965, -0.69309574],
        [ 0.        ,  0.125     ,  0.69309574, -0.28708965],
        ...,
        [ 0.        ,  0.12