## Setup

The cells below download the code from GitHub and install the necessary dependencies.

### Import Packages

In [None]:
# Automatically reload edited modules (e.g. the deeprte package) before each
# cell runs, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os

# Optional: uncomment to restrict which GPUs are visible. This has to run
# before JAX initialises its GPU backend (JAX is imported in the next cell).
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"


In [None]:
import json
import os
import time
from typing import Any

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from absl import logging
from matplotlib.colors import ListedColormap
from deeprte.model.engine import RteEngine

from deeprte.data import pipeline

# Show INFO-level absl log messages (progress and timings are logged below).
logging.set_verbosity(logging.INFO)

# List the devices JAX can use in this process (shown as the cell output).
jax.local_devices()

### Utility functions

In [None]:
def rmse(pred, target):
    """Return the relative root-mean-square error of ``pred`` w.r.t. ``target``.

    The mean squared error is normalised by the mean square of ``target``::

        sqrt(mean((pred - target) ** 2) / mean(target ** 2))

    so the value is dimensionless and 0 means a perfect match. If ``target``
    is identically zero the denominator is zero (NumPy then yields inf/nan).
    """
    mean_sq_error = np.mean((pred - target) ** 2)
    mean_sq_target = np.mean(target**2)
    return np.sqrt(mean_sq_error / mean_sq_target)


def get_normalization_ratio(psi_range, boundary_range):
    """Return ``psi_range / boundary_range`` as a float.

    Each argument is parsed by taking its last whitespace-separated token and
    converting it to ``float``. Plain numeric strings (``"0.1434"``), strings
    with a leading label (``"range 0.1434"``), strings with stray surrounding
    whitespace or tabs, and already-numeric values are all accepted.

    Args:
      psi_range: range value for psi (e.g. the ``"psi_range"`` entry of a
        normalization dict), as a string or number.
      boundary_range: range value for the boundary data (e.g. the
        ``"boundary_range"`` entry), as a string or number.

    Returns:
      The float ratio ``psi_range / boundary_range``.

    Raises:
      ValueError: if an argument is empty or its last token is not a number.
      ZeroDivisionError: if ``boundary_range`` parses to zero.
    """

    def _parse(value) -> float:
        # str() lets numeric inputs through; split() with no separator ignores
        # leading/trailing whitespace and tabs, unlike split(" "), which would
        # leave an empty last token and make float('') fail.
        tokens = str(value).split()
        if not tokens:
            raise ValueError(f"Cannot parse a number from {value!r}")
        return float(tokens[-1])

    return _parse(psi_range) / _parse(boundary_range)


def jnp_to_np(output: dict[str, Any]) -> dict[str, Any]:
    """Convert every JAX array in a (possibly nested) dict to a NumPy array.

    The dict is modified in place: nested dicts are processed recursively,
    JAX arrays are replaced by ``np.array`` copies, and any other value is
    left untouched. The same (mutated) dict object is returned for convenience.
    """
    for key in output:
        value = output[key]
        if isinstance(value, dict):
            output[key] = jnp_to_np(value)
        elif isinstance(value, jnp.ndarray):
            output[key] = np.array(value)
    return output


def plot_phi(r, phi_pre, phi_label):
    """Plot the reference field, the prediction and their absolute error.

    Draws three filled-contour panels side by side, each with its own colorbar.

    Args:
      r: position coordinates with the x/y components in the last axis;
        ``r[..., 0]`` and ``r[..., 1]`` must have the same 2-D shape as
        ``phi_pre`` and ``phi_label``.
      phi_pre: predicted field on the grid ``r``.
      phi_label: reference ("exact") field on the grid ``r``.
    """
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))

    # Samples above 1.0 clip to the colormap's top colour, so roughly the
    # upper sixth of each colour scale is saturated.
    viridis = matplotlib.colormaps["viridis"](np.linspace(0, 1.2, 128))
    cmap = ListedColormap(viridis)

    panels = (
        (phi_label, r"Exact $f(x,v)$"),
        (phi_pre, r"Predict $f(x,v)$"),
        (abs(phi_pre - phi_label), r"Absolute error"),
    )
    for ax, (values, title) in zip(axs.flatten(), panels):
        cs = ax.contourf(r[..., 0], r[..., 1], values, cmap=cmap)
        ax.set_title(title, fontsize=20)
        ax.tick_params(axis="both", labelsize=15)
        # Pass ``ax`` explicitly. Without it, older Matplotlib versions attach
        # the colorbar to the *current* axes (the last subplot created), not
        # to the panel that was just drawn.
        cbar = fig.colorbar(cs, ax=ax)
        cbar.ax.tick_params(labelsize=16)

    plt.tight_layout()

## Load Dataset

In [None]:
# Dataset location: the directory plus the .mat file name(s) to load from it.
DATA_DIR = "/workspaces/deeprte/data/raw_data/test/bc1-g0.5"
DATA_NAME = ["bc1-g0.5.mat"]

data_pipeline = pipeline.DataPipeline(DATA_DIR, DATA_NAME)
raw_feature_dict = data_pipeline.process()
del data_pipeline  # only the processed features are needed from here on

num_examples = raw_feature_dict["shape"]["num_examples"]

# Cell output: the array shape of every entry under "functions".
jax.tree.map(lambda arr: arr.shape, raw_feature_dict["functions"])

## Import Pre-trained model

#### Load model config

In [None]:
import importlib.util
import sys

MODEL_DIR = "/workspaces/deeprte/ckpts/g0.5/infer/params"
CONFIG_PATH = "/workspaces/deeprte/deeprte/configs/default.py"


def _load_source(module_name, path):
    """Import the Python source file at ``path`` as module ``module_name``.

    Drop-in replacement for ``imp.load_source``; the ``imp`` module was
    deprecated in Python 3.4 and removed in Python 3.12.

    Raises:
      ImportError: if no import spec/loader can be created for ``path``.
      FileNotFoundError: if ``path`` does not exist.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    # imp.load_source registered the module in sys.modules; keep that behaviour.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


config_module = _load_source("config", CONFIG_PATH)
config = config_module.get_config()
config.load_parameters_path = MODEL_DIR

config

#### Load model parameters

In [None]:
# Build the inference engine from the config. The previous cell set
# ``config.load_parameters_path`` to MODEL_DIR, presumably so the engine loads
# the pretrained parameters from there (RteEngine is defined in
# deeprte.model.engine — confirm there).
rte_engine = RteEngine(config)

#### Configure output normalization

In [None]:

# get_normalization_ratio comes from the "Utility functions" cell. It is
# deliberately not redefined here, so the two copies cannot drift apart.

# Normalization statistics stored as strings.
# NOTE(review): these values are hard-coded; confirm they match the data the
# checkpoint in MODEL_DIR was trained with.
normalization_dict = {
    "psi_min": "7.593978e-11",
    "psi_range": "0.14345199",
    "boundary_min": "0.0",
    "boundary_range": "0.19116795",
}

# Factor the model's psi prediction is multiplied by in the prediction cell.
normalization_ratio = get_normalization_ratio(
    normalization_dict["psi_range"],
    normalization_dict["boundary_range"],
)

normalization_ratio

## Predict and Evaluate

#### Predict and evaluate the i-th example in the dataset

In [None]:
# Pick a random example to evaluate. Draw from the actual dataset size rather
# than a hard-coded bound: an out-of-range i would make x[i : i + 1] silently
# return empty arrays. Set i to a fixed index (e.g. i = 1) for a reproducible run.
i = np.random.randint(num_examples)

#### Run prediction

In [None]:
logging.info("Predicting example %d/%d", i + 1, num_examples)

# Set to False to skip the second timing run.
benchmark = True

# Slice out example i while keeping a leading batch axis of size 1.
feature_dict = {
    "functions": jax.tree.map(lambda x: x[i : i + 1], raw_feature_dict["functions"]),
    "grid": raw_feature_dict["grid"],
    "shape": raw_feature_dict["shape"],
}

# Run the model.
logging.info("Running rte engine...")
processed_feature_dict = rte_engine.process_features(feature_dict)

# JAX dispatches work asynchronously, so wait for the result before reading
# the clock; otherwise only the dispatch time would be measured.
t_0 = time.time()
prediction = jax.block_until_ready(rte_engine.predict(processed_feature_dict))
t_diff = time.time() - t_0
logging.info(
    "Total JAX model predict time (first call, may include compilation time): %.6fs",
    t_diff,
)

if benchmark:
    # Repeat the call; presumably compiled code is reused, so this measures
    # steady-state prediction time.
    t_0 = time.time()
    jax.block_until_ready(rte_engine.predict(processed_feature_dict))
    t_diff = time.time() - t_0
    logging.info(
        "Total JAX model predict time " "(excludes compilation time): %.6fs",
        t_diff,
    )

psi_shape = feature_dict["functions"]["psi_label"].shape
# The prediction may be laid out per device; a C-order reshape straight to the
# label's shape merges any such leading axes into a single example.
predicted_psi = prediction.reshape(psi_shape)
if normalization_ratio:
    # Rescale by the ratio computed in the normalization cell.
    predicted_psi = predicted_psi * normalization_ratio

# phi: weighted sum of psi over its last axis using the velocity quadrature weights.
predicted_phi = jnp.sum(
    predicted_psi * feature_dict["grid"]["velocity_weights"],
    axis=-1,
)

# Compute metrics against the reference solution.
psi_label = feature_dict["functions"]["psi_label"]
phi_label = np.sum(psi_label * feature_dict["grid"]["velocity_weights"], axis=-1)
psi_rmse = rmse(predicted_psi, psi_label)
phi_rmse = rmse(predicted_phi, phi_label)
metrics = {"psi_rmse": psi_rmse, "phi_rmse": phi_rmse}
logging.info("RMSE of psi: %f, RMSE of phi: %f\n", psi_rmse, phi_rmse)

### Visualize result

In [None]:
# Plot phi for the evaluated example. Positions are reshaped onto the spatial
# grid given by psi_shape[1:-1] (psi_shape without its first and last axes,
# presumably the batch and velocity axes), and [0] drops the size-1 batch axis
# created by slicing x[i : i + 1].
plot_phi(
    feature_dict["grid"]["position_coords"].reshape(*psi_shape[1:-1], -1),
    predicted_phi[0],
    phi_label[0],
)