In [27]:
import numpy as np
import jax.numpy as jnp
import jax
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import haiku as hk
import functools
from deeprte.config import get_config
from deeprte.model.modules_v2 import DeepRTE
from deeprte.model.tf.input_pipeline import load_tf_data
from deeprte.model.data import flat_params_to_haiku

from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES

In [12]:
def slice_batch(i: int, feat: dict, batch_keys=None) -> dict:
    """Select example ``i`` from every batched feature, keeping a batch axis of size 1.

    Args:
      i: Index along the leading (batch) axis. Negative indices count from the
        end, as in ordinary Python indexing.
      feat: Mapping from feature name to array.
      batch_keys: Names of the features that carry a leading batch axis.
        Defaults to ``_BATCH_FEATURE_NAMES``. Features not listed are passed
        through unchanged.

    Returns:
      A new dict with the same keys as ``feat``.
    """
    if batch_keys is None:
        batch_keys = _BATCH_FEATURE_NAMES
    # ``i:i + 1`` is the empty slice ``-1:0`` when i == -1; an open stop fixes that.
    sel = slice(i, i + 1 if i != -1 else None)
    return {k: v[sel] if k in batch_keys else v for k, v in feat.items()}

# Fixed-seed key sequence; later cells draw keys from it with next(rng).
rng = hk.PRNGSequence(jax.random.PRNGKey(seed=42))

In [13]:
# Unwrap the experiment config down to the part that holds `model`.
config = get_config().experiment_kwargs.config
# Architecture override applied before building the model.
# NOTE(review): presumably must match the checkpoint loaded below — TODO confirm.
config.model.model_structure.green_function.scattering_module.res_block_depth = 2

In [14]:
# Evaluation data: which .mat files to load, and the directory holding them.
data_name_list = ["test_bc_1.mat"]
source_dir = "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/delta/"

# Parameters saved at step 1000000 of the train_delta run.
# NOTE(review): absolute container paths; consider a configurable root directory.
PARAMS_FILE = (
    "/workspaces/deeprte/ckpts/train_delta_2022-11-15T17:50:42"
    "/models/latest/step_1000000_2022-11-17T05:40:43"
    "/params.npz"
)

In [15]:
# Load the evaluation files and convert every leaf to a JAX array.
tf_data = load_tf_data(source_dir, data_name_list)
# `jax.tree_map` is deprecated (and removed in recent JAX releases); use the
# `jax.tree_util` spelling that the rest of this notebook already uses.
features = jax.tree_util.tree_map(jnp.array, tf_data)
jax.tree_util.tree_map(lambda x: x.shape, features)

{'boundary': (10, 1920),
 'boundary_coords': (1920, 4),
 'boundary_weights': (1920,),
 'phase_coords': (38400, 4),
 'position_coords': (1600, 2),
 'psi_label': (10, 38400),
 'scattering_kernel': (10, 38400, 24),
 'self_scattering_kernel': (10, 24, 24),
 'sigma': (10, 1600, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [32]:
# np.load on an .npz returns an NpzFile that keeps the file open and reads
# arrays lazily. Copy every array out inside `with` so the handle is closed.
with np.load(PARAMS_FILE) as npz:
    np_params = dict(npz)
# Regroup the flat checkpoint entries into Haiku's nested params structure.
params = flat_params_to_haiku(np_params)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'green_function/coefficient_net/attention_net/linear': {'bias': (64,),
  'weights': (6, 64)},
 'green_function/coefficient_net/attention_net/linear_1': {'bias': (1,),
  'weights': (64, 1)},
 'green_function/green_function_mlp/linear': {'bias': (128,),
  'weights': (10, 128)},
 'green_function/green_function_mlp/linear_1': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/green_function_mlp/linear_2': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/green_function_mlp/linear_3': {'bias': (128,),
  'weights': (128, 128)},
 'green_function/green_function_mlp/linear_4': {'bias': (1,),
  'weights': (128, 1)}}

In [28]:
def forward_fn(batch, is_training):
    """Build DeepRTE from the notebook's `config` and apply it to `batch`.

    Calls the model with `compute_loss=True` and `compute_metrics=False`, and
    returns whatever the model call returns.
    """
    model = DeepRTE(config.model.model_structure, config.model.global_config)
    return model(
        batch,
        is_training=is_training,
        compute_loss=True,
        compute_metrics=False,
    )


forward = hk.transform(forward_fn)
# Jitted entry points with the training flag bound:
# `apply` runs with is_training=False, `init` with is_training=True.
apply = jax.jit(functools.partial(forward.apply, is_training=False))
init = jax.jit(functools.partial(forward.init, is_training=True))

In [29]:
# Which sample along the leading (batch) axis of `features` to evaluate.
idx = 0

In [30]:
# Keep only sample `idx` of each batched feature (leading axis retained, size 1).
batch = slice_batch(idx, feat=features)
jax.tree_util.tree_map(lambda arr: arr.shape, batch)

{'boundary': (1, 1920),
 'boundary_coords': (1920, 4),
 'boundary_weights': (1920,),
 'phase_coords': (38400, 4),
 'position_coords': (1600, 2),
 'psi_label': (1, 38400),
 'scattering_kernel': (1, 38400, 24),
 'self_scattering_kernel': (1, 24, 24),
 'sigma': (1, 1600, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

In [37]:
# _params = init(next(rng), batch)
# jax.tree_util.tree_map(lambda x: x.shape, _params)

In [None]:
phi_pre = apply(params, next(rng), batch)


def velocity_integral(psi, batch):
    """Integrate a phase-space field over velocity with the quadrature weights.

    `psi` is assumed to be flattened over phase space as (..., n_pos * n_vel)
    in position-major order (for the data above: 38400 = 1600 * 24). It is
    reshaped to (..., n_pos, n_vel) and contracted with `velocity_weights`,
    which gives (..., n_pos).

    Raises:
      ValueError: if `phase_coords` is not position-major. In that case the
        reshape would mix positions and velocities.
    """
    n_pos = batch["position_coords"].shape[0]
    n_vel = batch["velocity_weights"].shape[0]
    # The first two phase_coords columns are taken as the spatial position, as
    # the plotting code does. Check that each run of n_vel rows shares a position.
    phase_xy = batch["phase_coords"][:, :2].reshape(n_pos, n_vel, 2)
    if not jnp.allclose(phase_xy, batch["position_coords"][:, None, :]):
        raise ValueError(
            "phase_coords is not position-major; fix the reshape in velocity_integral"
        )
    psi = psi.reshape(*psi.shape[:-1], n_pos, n_vel)
    return jnp.einsum("...pv,v->...p", psi, batch["velocity_weights"])


# psi_label has shape (1, n_pos * n_vel). A plain jnp.dot with the (n_vel,)
# weights fails on the mismatched contraction length, so reshape first.
phi_label = velocity_integral(batch["psi_label"], batch)
assert phi_pre.shape == phi_label.shape, (
    f"prediction shape {phi_pre.shape} != label shape {phi_label.shape}"
)
# Relative L2 error of the prediction against the label.
rel_l2_error = jnp.sqrt(jnp.mean((phi_label - phi_pre) ** 2) / jnp.mean(phi_label**2))
print(rel_l2_error, idx)

In [None]:
import matplotlib.pyplot as plt
import plotly.express as px
import matplotlib as mpl
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Sample 'viridis' on [0, 1.2]. Values above 1 are outside the colormap's
# domain and map to its "over" colour (the top colour by default), so the
# upper part of this listed colormap is a flat band of that colour.
viridis = mpl.colormaps['viridis'](np.linspace(0, 1.2, 128))
cmap = ListedColormap(viridis)
r = batch['phase_coords']


def _field_xy(field, coords_batch):
    """Return the (x, y) sample locations for a flattened field.

    The fields are scattered samples, not a regular grid. A field is matched by
    length either to `position_coords` or to the first two columns of
    `phase_coords`.
    """
    n = np.size(field)
    if n == coords_batch['position_coords'].shape[0]:
        xy = coords_batch['position_coords']
    elif n == coords_batch['phase_coords'].shape[0]:
        xy = coords_batch['phase_coords'][:, :2]
    else:
        raise ValueError(
            f"field with {n} values matches neither position_coords "
            f"({coords_batch['position_coords'].shape[0]}) nor phase_coords "
            f"({coords_batch['phase_coords'].shape[0]})"
        )
    xy = np.asarray(xy)
    return xy[:, 0], xy[:, 1]


def _plot_panel(fig, ax, field, title, coords_batch, panel_cmap):
    """Draw one field as a filled contour on `ax`, with its own colour bar."""
    x, y = _field_xy(field, coords_batch)
    # tricontourf triangulates scattered points. contourf needs a 2-D gridded Z,
    # which these per-point arrays are not.
    cs = ax.tricontourf(x, y, np.ravel(field), cmap=panel_cmap)
    ax.set_title(title, fontsize=20)
    ax.tick_params(axis='both', labelsize=15)
    # Attach the colour bar to this panel explicitly rather than the current axes.
    cbar = fig.colorbar(cs, ax=ax)
    cbar.ax.tick_params(labelsize=16)
    return cs


fig, _axs = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))
axs = _axs.flatten()

panels = [
    (phi_label, r"Exact $f(x,v)$"),
    (phi_pre, r"Predict $f(x,v)$"),
    (abs(phi_pre - phi_label), r"Absolute error"),
]
cs_1, cs_2, cs_3 = (
    _plot_panel(fig, ax, field, title, batch, cmap)
    for ax, (field, title) in zip(axs, panels)
)

plt.tight_layout()
plt.show()
print(np.sqrt(np.mean((phi_label - phi_pre)**2) / np.mean(phi_label**2)))