In [1]:
import pathlib

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as orbax
from flax import nnx
from jax.sharding import PartitionSpec as P
from rte_dataset.builders import pipeline

from deeprte.configs import default
from deeprte.model import features
from deeprte.model.autoencoder import AutoEncoder
from deeprte.model.mapping import inference_subbatch
from deeprte.model.tf import rte_features
from deeprte.train_lib import utils

In [2]:
# Root directory of the trained checkpoint; the next cells read `config.yaml`
# and `infer/params` from under it.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
path = "/workspaces/deeprte/ckpts/g0.1-gaussian-source0121"

In [3]:
# Load the training-time config stored alongside the checkpoint.
config = default.get_config(f"{path}/config.yaml")

In [4]:
# Point parameter loading at the inference-only params saved with the checkpoint
# (a tuple, since several parameter trees may be merged on restore).
config.load_parameters_path = (f"{path}/infer/params",)

In [5]:
# Device mesh over all local devices, with axis names taken from the config.
devices_array = utils.create_device_mesh(config, devices=jax.local_devices())
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Two placements: fully replicated, and a data sharding whose partition spec
# leaves the leading axis unsplit and lays out the rest per `config.data_sharding`.
replicated_sharding = jax.sharding.NamedSharding(mesh, P(None))
data_sharding = jax.sharding.NamedSharding(mesh, P(None, *config.data_sharding))

# Phase-coordinate features use the data sharding; every other feature is replicated.
# NOTE(review): the inputs subbatched later in this notebook are the
# SOURCE_COORDS_FEATURES; unless they also appear in PHASE_COORDS_FEATURES they get
# replicated_sharding here — confirm that is intended.
feature_sharding = {
    name: (
        data_sharding
        if name in rte_features.PHASE_COORDS_FEATURES
        else replicated_sharding
    )
    for name in rte_features.FEATURES
}

# Fixed PRNG key (seed 0). Despite the name this is a raw key, not an nnx.Rngs;
# it is wrapped in nnx.Rngs(params=...) wherever the model is constructed.
rngs = jax.random.key(0)

In [6]:
# Abstract model state (presumably shape-only, per the helper's name) plus its
# sharding over `mesh`. In this notebook only `state_sharding` is used afterwards,
# as the jit in_shardings for params.
# The abstract state gets its own name: binding it to `state` would let a re-run
# of this cell overwrite the concrete `state` restored from the checkpoint.
abstract_state, state_sharding = utils.get_abstract_state(
    constructor=lambda config, rngs: AutoEncoder(config, rngs=nnx.Rngs(params=rngs)),
    tx=None,
    config=config,
    rng=rngs,
    mesh=mesh,
    is_training=False,
)

In [7]:
# Echo the parameter checkpoint path(s) as a sanity check before restoring them.
config.load_parameters_path


('/workspaces/deeprte/ckpts/g0.1-gaussian-source0121/infer/params',)

In [8]:
# Restore every parameter tree listed in the config and merge their "params"
# entries into a single dict. `|` is a shallow merge: on a top-level key
# collision, the tree from a later path wins.
# The loop variable is deliberately NOT named `path`. That global holds the
# checkpoint root used to build the config paths, and overwriting it would make
# re-running those cells read from `<params dir>/config.yaml`.
checkpointer = orbax.PyTreeCheckpointer()
_params = {}
for params_path in config.load_parameters_path:
    _res_params = checkpointer.restore(params_path)
    _params = _params | _res_params["params"]



In [9]:
# Rebuild the NNX model from the restored variable dict.
#   1st arg: factory for an abstract (nnx.eval_shape) AutoEncoder giving the structure;
#   2nd arg: the merged parameter dict restored above;
#   3rd arg: maps each variable path, dropping a trailing "value" component.
graphdef, state = utils.module_from_variables_dict(
    lambda: nnx.eval_shape(lambda: AutoEncoder(config, rngs=nnx.Rngs(params=rngs))),
    _params,
    lambda path: path if path[-1] != "value" else path[:-1],
)


In [10]:
# Build the inference state from the restored model state and keep only its params.
# NOTE(review): the first argument is None — presumably "no optimizer/transform";
# confirm against utils.init_infer_state.
params = utils.init_infer_state(None, state).params

In [11]:
# Report the model size as a quick check that the restored tree is complete.
num_params = utils.calculate_num_params_from_pytree(params)
print("Number of model params={}".format(num_params))

Number of model params=165760


In [60]:
def forward_fn(params, features, graphdef):
    """Apply the AutoEncoder's encoder and basis MLP to a batch of features.

    Args:
      params: model state to merge with ``graphdef``.
      features: dict of feature arrays. ``jax.vmap`` maps over the leading
        axis of every entry. Only "source", "source_coords" and
        "source_weights" are read; other entries are ignored.
      graphdef: NNX graph definition of the model (passed as a static argument
        when jitted).

    Returns:
      ``(basis, moments)``: the MLP evaluated at ``source_coords`` and the
      encoder output for ``source``, each with the mapped leading axis.
    """
    model = nnx.merge(graphdef, params)

    def _encode_one(example):
        # The former debug `print(x.shape)` calls were removed: under jax.jit
        # they run only at trace time and print again on every retrace.
        moments = model.encoder(
            example["source"], [example["source_coords"], example["source_weights"]]
        )
        basis = model.mlp(example["source_coords"])
        return basis, moments

    return jax.vmap(_encode_one)(features)

In [61]:
# Jitted forward pass. graphdef (positional arg 2) is static, so it is hashed
# into the compilation cache rather than traced; params and features are placed
# according to the shardings built earlier.
jit_predict_fn = jax.jit(
    forward_fn,
    static_argnums=(2,),
    in_shardings=(state_sharding.params, feature_sharding),
    out_shardings=data_sharding,
)

In [85]:
# Dataset to evaluate. Alternatives previously used here (same root):
#   raw_data/train/g0.1-q0.003/g0.1-q0.003.npz
#   raw_data/test/source-g0.1-qconstant/g0.1-qconstant.npz
data_path = pathlib.Path(
    "/workspaces/deeprte/data/raw_data/train",
    "g0.1-gaussian-alpha5",
    "g0.1-gaussian-alpha5.npz",
)
data_pipeline = pipeline.DataPipeline(data_path.parent, [data_path.name])
raw_feature_dict = data_pipeline.process()
# Only the processed features are needed; drop the reference to the pipeline.
del data_pipeline

In [86]:
# List which function arrays the raw dataset provides.
raw_feature_dict["functions"].keys()

dict_keys(['sigma', 'psi_label', 'scattering_kernel', 'boundary_scattering_kernel', 'self_scattering_kernel', 'boundary', 'source', 'source_label'])

In [87]:
# i = 0
# a = 1
# b = 0
# bc = (
#     a * raw_feature_dict["functions"]["boundary"][i : i + 1]
#     + b * raw_feature_dict["functions"]["boundary"][i + 1 : i + 2]
# )

In [88]:
# Pick a single sample (index `i`) from every function array. A length-1 slice
# keeps the leading batch axis; grid and shape are shared by all samples.
i = 0
feature_dict = dict(
    functions=jax.tree.map(lambda arr: arr[i : i + 1], raw_feature_dict["functions"]),
    grid=raw_feature_dict["grid"],
    shape=raw_feature_dict["shape"],
)


In [89]:
# Feature names treated as source-side inputs; these are split off and subbatched.
rte_features.SOURCE_COORDS_FEATURES

['source_coords', 'source_weights', 'source']

In [90]:
# Summarise the grid: show array shapes, pass any non-array leaf through unchanged.
jax.tree.map(
    lambda leaf: leaf if not isinstance(leaf, np.ndarray) else leaf.shape,
    feature_dict["grid"],
)

{'boundary_coords': (40, 40, 24, 4),
 'boundary_weights': (40, 40, 24),
 'phase_coords': (40, 40, 24, 4),
 'position_coords': (40, 40, 2),
 'source_coords': (40, 40, 24, 4),
 'source_weights': (40, 40, 24),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [91]:
# Flatten the single-sample dict into model features, then split them.
# `phase_feat` receives the source-side features (the ones subbatched later);
# `other_feat` receives everything else.
processed_feature_dict = features.np_data_to_features(feature_dict)
phase_feat, other_feat = features.split_feature(
    processed_feature_dict,
    filter=lambda x: x in rte_features.SOURCE_COORDS_FEATURES,
)

In [92]:
# Shapes of the processed features (non-array leaves are shown as-is).
jax.tree.map(
    lambda leaf: leaf if not isinstance(leaf, np.ndarray) else leaf.shape,
    processed_feature_dict,
)


{'boundary': (1, 38400),
 'boundary_coords': (1, 38400, 4),
 'boundary_weights': (1, 38400),
 'phase_coords': (1, 38400, 4),
 'position_coords': (1, 1600, 2),
 'scattering_kernel': (1, 38400, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1600, 2),
 'source': (1, 38400),
 'source_coords': (1, 38400, 4),
 'source_weights': (1, 38400),
 'velocity_coords': (1, 24, 2),
 'velocity_weights': (1, 24)}

In [93]:
# Confirm the split: only the source-side features should be listed here.
phase_feat.keys()

dict_keys(['source', 'source_coords', 'source_weights'])

In [94]:
# Run the jitted forward pass in chunks of 128 over the source-side features;
# the remaining features are passed whole to every chunk.
out = inference_subbatch(
    module=lambda x: jit_predict_fn(params, x, graphdef),
    batched_args=phase_feat,
    nonbatched_args=other_feat,
    subbatch_size=128,
    low_memory=False,
)

In [106]:
# out == (basis, moments) from forward_fn. The matrix below is the
# quadrature-weighted Gram matrix of the basis, sum_i w_i * B_ij * B_ik.
# NOTE(review): the key reads "grim_matrix" (likely meant "gram_matrix"); kept
# unchanged because downstream consumers may read it.
result = {
    "moments": out[1],
    "grim_matrix": jnp.einsum(
        "...ij,...ik->...jk",
        phase_feat["source_weights"][..., None] * out[0],
        out[0],
    ),
}

In [107]:
# Reconstruct the source as a moment-weighted combination of the basis functions:
# value_j = sum_k moments_k * basis_jk.
# NOTE(review): despite its name, `_boundary` is the reconstructed *source* (it is
# compared against phase_feat["source"] below). The jnp.squeeze calls assume a
# batch of exactly one sample; with more samples the einsum ranks would not match.
_boundary = jnp.einsum("k,jk->j", jnp.squeeze(result["moments"]), jnp.squeeze(out[0]))
_boundary.shape

(38400,)

In [108]:
def rmse(pred, target):
    """Relative RMS error of ``pred`` with respect to ``target``.

    Despite the name, the error is normalised by the RMS of ``target``:

        sqrt(mean((pred - target) ** 2) / mean(target ** 2))

    so it is scale-free, and 0 means an exact match. The means run over all
    elements. If ``target`` is identically zero the denominator is zero and
    the result is inf or nan.

    Returns:
      A 0-d JAX array.
    """
    error_power = jnp.mean((pred - target) ** 2)
    target_power = jnp.mean(target**2)
    return jnp.sqrt(error_power / target_power)


# Relative reconstruction error of the source from (basis, moments).
rmse(pred=_boundary, target=jnp.squeeze(phase_feat["source"]))


Array(0.0603061, dtype=float32)

In [109]:
# NOTE(review): placeholder — the single `result` dict is listed twice (same
# object) to exercise the concatenation below; replace with per-sample results.
results = [result] * 2


In [110]:
# Stack the per-sample result dicts leaf-by-leaf along the leading axis.
all_results = jax.tree.map(lambda *leaves: np.concatenate(leaves), *results)

In [111]:
# Shapes of the stacked results.
jax.tree.map(np.shape, all_results)

{'grim_matrix': (2, 128, 128), 'moments': (2, 128)}