## Setup

The cell below clones the code from GitHub and installs the necessary dependencies. The commands are commented out; uncomment them to run the setup.

In [None]:
# Uncomment to fetch and install DeepRTE. Each `!` line runs in its own shell,
# so `cd deeprte` is repeated on every line that must run inside the repo.
# ![ -d deeprte ] || git clone --depth=1 https://github.com/mazhengcn/deeprte.git
# !cd deeprte && git pull
# !cd deeprte && pip install -e ".[dev]"

### Import Packages

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

%load_ext autoreload
%autoreload 2

import json
import os
import time
from typing import Any

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
from absl import logging
from matplotlib.colors import ListedColormap

from deeprte.data import pipeline
from deeprte.model import model
from deeprte.model.data import flat_params_to_haiku

# Emit absl log messages at INFO level and above (the prediction cells below
# report progress and timings via logging.info).
logging.set_verbosity(logging.INFO)

# List the devices visible to this process; as the cell's last expression the
# list is displayed as output. The set is presumably narrowed by the
# CUDA_VISIBLE_DEVICES value assigned at the top of this cell — confirm.
jax.local_devices()

### Utility functions

In [None]:
def rmse(pred, target):
    """Return the *relative* root-mean-square error of ``pred`` against ``target``.

    Computed as ``sqrt(mean((pred - target)**2) / mean(target**2))``, i.e. the
    RMSE normalised by the root-mean-square of ``target``. Despite the name,
    the result is dimensionless rather than an absolute RMSE.
    """
    mean_squared_error = np.mean((pred - target) ** 2)
    mean_squared_target = np.mean(target**2)
    return np.sqrt(mean_squared_error / mean_squared_target)


def get_normalization_ratio(psi_range, boundary_range):
    psi_range = float(psi_range.split(" ")[-1])
    boundary_range = float(boundary_range.split(" ")[-1])
    return psi_range / boundary_range


def jnp_to_np(output: dict[str, Any]) -> dict[str, Any]:
    """Convert every JAX array inside ``output`` to a NumPy array, in place.

    Nested ``dict`` values are processed recursively (and keep their identity);
    values that are neither dicts nor JAX arrays are left untouched.

    Returns:
      The same ``output`` object, after mutation.
    """
    for key, value in output.items():
        if isinstance(value, dict):
            converted = jnp_to_np(value)
        elif isinstance(value, jnp.ndarray):
            converted = np.array(value)
        else:
            continue
        # Overwriting an existing key does not resize the dict, so this is
        # safe while iterating over items().
        output[key] = converted
    return output


def plot_phi(r, phi_pre, phi_label):
    """Draw exact, predicted and absolute-error contour plots side by side.

    Args:
      r: Coordinate array; ``r[..., 0]`` and ``r[..., 1]`` are used as the x and
        y coordinates of the contour grid.
      phi_pre: Predicted field on the grid of ``r``.
      phi_label: Reference field on the same grid.
    """
    fig, _axs = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))
    fig.subplots_adjust(hspace=0.3)
    axs = _axs.flatten()

    # Sampling up to 1.2 maps the values above 1 to the colormap's "over"
    # color (its last color by default), so the top of the scale saturates.
    viridis = matplotlib.colormaps["viridis"](np.linspace(0, 1.2, 128))
    cmap = ListedColormap(viridis)

    panels = (
        (phi_label, r"Exact $f(x,v)$"),
        (phi_pre, r"Predict $f(x,v)$"),
        (abs(phi_pre - phi_label), r"Absolute error"),
    )
    for ax, (field, title) in zip(axs, panels):
        contour = ax.contourf(r[..., 0], r[..., 1], field, cmap=cmap)
        ax.set_title(title, fontsize=20)
        ax.tick_params(axis="both", labelsize=15)
        # Pass ax explicitly: older matplotlib (< 3.6) otherwise attaches every
        # colorbar to the current axes (the last subplot).
        cbar = fig.colorbar(contour, ax=ax)
        cbar.ax.tick_params(labelsize=16)

    plt.tight_layout()
    plt.tight_layout()

## Load Dataset

In [None]:
# Path to the dataset (machine-specific; edit for your environment).
DATA_DIR = "/workspaces/deeprte/data/raw_data/eval_data/0311"
DATA_NAME = "test_random_kernel_0311.mat"

# Load and process the raw evaluation data.
data_pipeline = pipeline.DataPipeline(DATA_DIR, [DATA_NAME])
raw_feature_dict = data_pipeline.process()
num_examples = raw_feature_dict["shape"]["num_examples"]

# Display the shape of every array under "functions".
# jax.tree_map is deprecated; jax.tree_util.tree_map is the supported API.
jax.tree_util.tree_map(lambda x: x.shape, raw_feature_dict["functions"])

## Import Pre-trained Model

#### Load model config

In [None]:
MODEL_DIR = "/workspaces/deeprte/ckpts/saved_model"

# Read the config saved alongside the checkpoint. json.load parses straight
# from the file handle, so no intermediate text variable is needed (binding it
# to `str` would shadow the builtin for every later cell).
model_config_path = os.path.join(MODEL_DIR, "config.json")
with open(model_config_path, encoding="utf-8") as f:
    config = ml_collections.ConfigDict(json.load(f))
model_config = config.experiment_kwargs.config.model

model_config

#### Load model parameters

In [None]:
params_path = os.path.join(MODEL_DIR, "params.npz")

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code; only load checkpoints from a trusted source.
# The context manager closes the .npz file handle after conversion
# (assumes flat_params_to_haiku reads the arrays eagerly — TODO confirm).
with np.load(params_path, allow_pickle=True) as np_params:
    params = flat_params_to_haiku(np_params)

# Display the shape of every parameter (jax.tree_map is deprecated).
jax.tree_util.tree_map(lambda x: x.shape, params)

#### Construct model runner

In [None]:
# Build the model runner from the loaded config and parameters.
model_runner = model.RunModel(model_config, params, multi_devices=True)

# get_normalization_ratio comes from the utility-functions cell above.
# When the config enables normalization, the prediction cell multiplies the
# predicted psi by this ratio; None means no rescaling.
if model_config.data.is_normalization:
    normalization_dict = model_config.data.normalization_dict
    normalization_ratio = get_normalization_ratio(
        normalization_dict["psi_range"],
        normalization_dict["boundary_range"],
    )
else:
    normalization_ratio = None

## Predict and Evaluate

#### Predict and evaluate the i-th example in the dataset (`i` is a 0-based index)

In [None]:
# 0-based index of the example to evaluate (log messages report it as i + 1).
i = 1
# An out-of-range index would make the x[i : i + 1] slice below silently empty.
if not 0 <= i < num_examples:
    raise IndexError(
        f"Example index {i} is out of range for {num_examples} examples."
    )

#### Run prediction

In [None]:
logging.info("Predicting example %d/%d", i + 1, num_examples)

random_seed = 1
# If True, predict a second time so the logged time excludes JIT compilation.
benchmark = True

# Slice out example i while keeping a leading batch axis of size 1; the grid
# and shape metadata are shared by all examples.
feature_dict = {
    "functions": jax.tree_util.tree_map(
        lambda x: x[i : i + 1], raw_feature_dict["functions"]
    ),
    "grid": raw_feature_dict["grid"],
    "shape": raw_feature_dict["shape"],
}

# Run the model.
logging.info("Running model...")
processed_feature_dict = model_runner.process_features(feature_dict)
t_0 = time.time()
prediction_result = model_runner.predict(
    processed_feature_dict, random_seed=random_seed
)
# JAX dispatches work asynchronously; wait for the results so the measured
# time covers the actual computation.
jax.block_until_ready(prediction_result)
t_diff = time.time() - t_0
logging.info(
    "Total JAX model predict time "
    "(includes compilation time, see benchmark): %.1fs",
    t_diff,
)

if benchmark:
    t_0 = time.time()
    jax.block_until_ready(
        model_runner.predict(processed_feature_dict, random_seed=random_seed)
    )
    t_diff = time.time() - t_0
    logging.info(
        "Total JAX model predict time (excludes compilation time): %.1fs",
        t_diff,
    )

# Reshape the (possibly per-device) prediction to the label's shape.
psi_shape = feature_dict["functions"]["psi_label"].shape
predicted_psi = prediction_result["predicted_psi"].reshape(psi_shape)
if normalization_ratio is not None:
    predicted_psi = predicted_psi * normalization_ratio

# Weighted sum over the last axis (presumably the velocity quadrature) gives phi.
predicted_phi = jnp.sum(
    predicted_psi * feature_dict["grid"]["velocity_weights"],
    axis=-1,
)

# Remove jax dependency from results.
np_prediction_result = jnp_to_np(dict(prediction_result))

# Compute metrics (relative RMSE, as implemented by rmse()).
psi_label = feature_dict["functions"]["psi_label"]
phi_label = np.sum(
    psi_label * feature_dict["grid"]["velocity_weights"], axis=-1
)
psi_rmse = rmse(predicted_psi, psi_label)
phi_rmse = rmse(predicted_phi, phi_label)
metrics = {"psi_rmse": psi_rmse, "phi_rmse": phi_rmse}
logging.info("RMSE of psi: %f, RMSE of phi: %f\n", psi_rmse, phi_rmse)

### Visualize result

In [None]:
# Plot predicted vs. exact phi for the single example in the batch; the
# coordinates are reshaped to psi_shape[1:-1] (dropping the leading batch axis
# and the trailing axis) plus a final coordinate axis.
plot_phi(
    r=feature_dict["grid"]["position_coords"].reshape(*psi_shape[1:-1], -1),
    phi_pre=predicted_phi[0],
    phi_label=phi_label[0],
)