## Setup

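The (commented-out) cell below installs the `deeprte` package from the local repository checkout together with the necessary dependencies; uncomment it when running in a fresh environment.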

In [1]:
# !cd .. && pip install -e "."
# !pip install matplotlib

### Import Packages

In [2]:
%load_ext autoreload
%autoreload 2

import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

import json
import os
import time
from typing import Any

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from absl import logging
from matplotlib.colors import ListedColormap

import dill

import sys
sys.path.append("/root/projects/deeprte")

from deeprte.data import pipeline
from deeprte.model import model
from deeprte.model.utils import flat_params_to_haiku

import pathlib

logging.set_verbosity(logging.INFO)

jax.local_devices()

[cuda(id=0)]

### Utility functions

In [3]:
def rmse(pred, target):
    """Relative root-mean-square error of ``pred`` with respect to ``target``.

    Despite the name, the mean squared error is normalized by the mean square
    of ``target``::

        sqrt(mean((pred - target) ** 2) / mean(target ** 2))

    i.e. the relative L2 error; 0 means a perfect match.

    Args:
      pred: Predicted values (array-like, broadcastable with ``target``).
      target: Reference values.

    Returns:
      The relative error as a scalar.
    """
    residual = pred - target
    mean_sq_residual = np.mean(residual**2)
    mean_sq_target = np.mean(target**2)
    return np.sqrt(mean_sq_residual / mean_sq_target)


def get_normalization_ratio(psi_range, boundary_range):
    """Return ``psi_range / boundary_range``, the factor applied to predicted psi.

    The ranges come from the checkpoint's ``normalization_dict`` and are stored
    as strings (e.g. ``'0.14552617'``). The numeric value is taken from the last
    whitespace-separated token, so padded or labelled strings such as
    ``'psi_range 0.1455'`` also work. Plain numbers are accepted as-is.

    Args:
      psi_range: Normalization range of psi (str or number).
      boundary_range: Normalization range of the boundary data (str or number).

    Returns:
      ``psi_range / boundary_range`` as a float.

    Raises:
      ValueError: If a string argument is blank or its last token is not a
        number.
      ZeroDivisionError: If ``boundary_range`` is zero.
    """

    def _to_float(value) -> float:
        # Non-string values (float, numpy scalar, ...) are converted directly.
        if not isinstance(value, str):
            return float(value)
        # `split()` (not `split(" ")`) ignores leading/trailing/repeated spaces.
        tokens = value.split()
        if not tokens:
            raise ValueError(f"Expected a numeric range, got {value!r}")
        return float(tokens[-1])

    return _to_float(psi_range) / _to_float(boundary_range)


def jnp_to_np(output: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``output`` with JAX arrays converted to NumPy arrays.

    Nested dicts are converted recursively; any other value is carried over
    unchanged. The input mapping and its nested dicts are not modified.

    Args:
      output: Possibly nested dict whose leaves may be JAX arrays.

    Returns:
      A new dict with the same keys. Every ``jnp.ndarray`` leaf is replaced by
      an ``np.ndarray`` (a host copy made with ``np.array``).
    """
    converted: dict[str, Any] = {}
    for key, value in output.items():
        if isinstance(value, dict):
            converted[key] = jnp_to_np(value)
        elif isinstance(value, jnp.ndarray):
            converted[key] = np.array(value)
        else:
            converted[key] = value
    return converted

## Load Dataset

In [4]:
# Location and file name(s) of the dataset to evaluate on.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
DATA_DIR = "/root/projects/deeprte/data/raw_data/train/g0.1-sigma_a3-sigma_t6"
DATA_NAME = ["g0.1-sigma_a3-sigma_t6_eval.mat"]

data_pipeline = pipeline.DataPipeline(DATA_DIR, DATA_NAME)
raw_feature_dict = data_pipeline.process()
num_examples = raw_feature_dict["shape"]["num_examples"]

# Show the array shape of every function-valued feature.
# (`jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable API.)
jax.tree_util.tree_map(lambda x: x.shape, raw_feature_dict["functions"])

{'boundary': (100, 160, 12),
 'boundary_scattering_kernel': (100, 160, 12, 24),
 'psi_label': (100, 40, 40, 24),
 'scattering_kernel': (100, 40, 40, 24, 24),
 'self_scattering_kernel': (100, 24, 24),
 'sigma': (100, 40, 40, 2)}

In [5]:
# Show the array shape of every grid entry (coordinates and quadrature weights).
jax.tree_util.tree_map(lambda x: x.shape, raw_feature_dict["grid"])

{'boundary_coords': (160, 12, 4),
 'boundary_weights': (160, 12),
 'phase_coords': (40, 40, 24, 4),
 'position_coords': (40, 40, 2),
 'velocity_coords': (24, 2),
 'velocity_weights': (24,)}

## Load Pre-trained Model

#### Load model config

In [6]:
MODEL_DIR = "/root/projects/deeprte/data/ckpts/g0.1-sigma_a3-sigma_t6_2023-05-14T18:50:04/models/latest/step_500000_2023-05-23T12:05:36"

# The checkpoint directory stores the model configuration as JSON; wrap it in a
# ConfigDict so fields can be read with attribute access (model_config.data...).
model_config_path = os.path.join(MODEL_DIR, "model.json")
with open(model_config_path) as config_file:
    model_config = ml_collections.ConfigDict(json.load(config_file))

model_config

data:
  is_normalization: true
  normalization_dict:
    boundary_min: '0.0'
    boundary_range: '0.19474658'
    psi_min: '3.8484396e-10'
    psi_range: '0.14552617'
global_config:
  bc_loss_weights: 1.0
  deterministic: true
  loss_weights: 5.0
  subcollocation_size: 8
  w_init: glorot_uniform
green_function:
  attenuation:
    attention:
      key_chunk_size: 128
      key_dim: 32
      num_head: 2
      output_dim: 2
      value_dim: null
    latent_dim: 128
    num_layer: 4
    output_dim: 16
  scattering:
    latent_dim: 16
    num_layer: 2

#### Load model parameters

In [7]:
params_path = os.path.join(MODEL_DIR, "params.npz")

# `np.load` on an .npz returns an NpzFile that keeps the file open and reads
# arrays lazily; read every array inside a context manager so the handle is
# released once the parameters are in memory.
# NOTE(review): allow_pickle=True lets a crafted file execute arbitrary code —
# only load trusted checkpoints, and try allow_pickle=False if the archive holds
# no object arrays.
with np.load(params_path, allow_pickle=True) as np_params:
    flat_params = {name: np_params[name] for name in np_params.files}
params = flat_params_to_haiku(flat_params)

# Show the parameter shapes per module.
jax.tree_util.tree_map(lambda x: x.shape, params)

2024-02-21 09:07:40.168825: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{'deeprte/green_function/attenuation/attention/key': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attention/output_projection': {'b': (2,),
  'w': (64, 2)},
 'deeprte/green_function/attenuation/attention/query': {'b': (64,),
  'w': (4, 64)},
 'deeprte/green_function/attenuation/attention/value': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attenuation_linear': {'b': (128,),
  'w': (10, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_1': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_2': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/output_projection': {'b': (16,),
  'w': (128, 16)},
 'deeprte/green_function/output_projection': {'w': (16, 1)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/layer_norm': {'offset': (2,
   16),
  'scale': (2, 16)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer/linear':

In [8]:
import haiku as hk

# Total number of scalar parameters across all modules of the model.
num_params = hk.data_structures.tree_size(params)
num_params

37954

#### Construct model runner

In [9]:
model_runner = model.RunModel(model_config, params, multi_devices=False)

# Factor used later to rescale the predicted psi when the checkpoint was
# trained with normalization. `get_normalization_ratio` comes from the
# utility-functions cell; it is not redefined here, to avoid a second,
# shadowing copy.
if model_config.data.is_normalization:
    normalization_dict = model_config.data.normalization_dict
    normalization_ratio = get_normalization_ratio(
        normalization_dict["psi_range"],
        normalization_dict["boundary_range"],
    )
else:
    normalization_ratio = None

## Predict and Evaluate

#### Predict and evaluate the last `num_eval` examples in the dataset

In [10]:
# Number of examples, taken from the end of the dataset, to predict and evaluate.
num_eval = 2

# Root directory for outputs; the next cell creates a timestamped run
# sub-directory under it.
output_dir_base = "test"

In [11]:
import datetime

# Put this run's outputs in a timestamped sub-directory of the output root.
date_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Join with pathlib: `Path(...)` accepts both the str root from the config cell
# and a Path, so re-running this cell does not raise (string concatenation with
# "/" fails once output_dir_base has become a Path).
output_dir_base = pathlib.Path(output_dir_base) / date_time
output_dir_base.mkdir(parents=True, exist_ok=True)

print(output_dir_base)

test/20240221_090740


#### Run prediction

In [12]:
# Loop-invariant settings shared by every evaluated example.
random_seed = 1
# When True, predict a second time so the logged time excludes compilation.
benchmark = True

for i in range(num_examples - num_eval, num_examples):
    timings = {}

    logging.info("Predicting example %d/%d", i + 1, num_examples)

    output_dir = output_dir_base / f"example_{i}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Slice out example i while keeping a leading axis of size 1; grid and shape
    # metadata are shared by all examples.
    feature_dict = {
        "functions": jax.tree_util.tree_map(
            lambda x: x[i : i + 1], raw_feature_dict["functions"]
        ),
        "grid": raw_feature_dict["grid"],
        "shape": raw_feature_dict["shape"],
    }

    # Write out features as a pickled (dill) dictionary.
    with open(output_dir / "features.dill", "wb") as f:
        dill.dump(feature_dict, f)

    # Run the model.
    logging.info("Running model...")
    t_0 = time.time()
    processed_feature_dict = model_runner.process_features(feature_dict)
    timings["process_features"] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(
        processed_feature_dict, random_seed=random_seed
    )
    t_diff = time.time() - t_0
    timings["predict_and_compile"] = t_diff

    logging.info(
        "Total JAX model predict time "
        "(includes compilation time, see --benchmark): %.1fs",
        t_diff,
    )

    if benchmark:
        # Repeat the same call to measure prediction time without compilation.
        t_0 = time.time()
        model_runner.predict(processed_feature_dict, random_seed=random_seed)
        t_diff = time.time() - t_0
        logging.info(
            "Total JAX model predict time " "(excludes compilation time): %.1fs",
            t_diff,
        )

    psi_shape = feature_dict["functions"]["psi_label"].shape
    t_0 = time.time()
    # Flatten any per-device leading axes (multi-device output), then restore
    # the shape of the label array.
    predicted_psi = (
        prediction_result["predicted_psi"].reshape(1, -1).reshape(psi_shape)
    )
    if normalization_ratio:
        # Rescale by psi_range / boundary_range from the checkpoint config.
        predicted_psi = predicted_psi * normalization_ratio

    # phi: psi summed over the velocity axis with the velocity quadrature weights.
    predicted_phi = jnp.sum(
        predicted_psi * feature_dict["grid"]["velocity_weights"],
        axis=-1,
    )
    # JAX dispatches asynchronously; wait for the result so the timing is real.
    predicted_phi.block_until_ready()
    timings["compute_psi_and_phi"] = time.time() - t_0
    prediction_result.update(
        {"predicted_psi": predicted_psi, "predicted_phi": predicted_phi}
    )

    # Remove jax dependency from results.
    np_prediction_result = jnp_to_np(dict(prediction_result))

    # Compute metrics (relative L2 errors; see `rmse`).
    psi_label = feature_dict["functions"]["psi_label"]
    phi_label = np.sum(psi_label * feature_dict["grid"]["velocity_weights"], axis=-1)
    psi_rmse = rmse(predicted_psi, psi_label)
    phi_rmse = rmse(predicted_phi, phi_label)
    # Stored as strings: numpy/jax scalars are not directly JSON-serializable.
    metrics = {"psi_rmse": str(psi_rmse), "phi_rmse": str(phi_rmse)}
    logging.info("RMSE of psi: %f, RMSE of phi: %f\n", psi_rmse, phi_rmse)

    np_result = {
        **np_prediction_result,
        "psi_label": psi_label,
        "phi_label": phi_label,
    }
    with open(output_dir / "result.dill", "wb") as f:
        dill.dump(np_result, f)

    with open(output_dir / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=4)

    with open(output_dir / "timings.json", "w") as f:
        json.dump(timings, f, indent=4)

INFO:absl:Predicting example 3/100
INFO:absl:Running model...


INFO:absl:Running predict with shape(feat) = {'boundary': (1, 1920), 'boundary_coords': (1, 1920, 4), 'boundary_weights': (1, 1920), 'phase_coords': (1, 38400, 4), 'position_coords': (1, 1600, 2), 'scattering_kernel': (1, 38400, 24), 'self_scattering_kernel': (1, 24, 24), 'sigma': (1, 1600, 2), 'velocity_coords': (1, 24, 2), 'velocity_weights': (1, 24)}
INFO:absl:Output shape was {'predicted_psi': (1, 38400)}
INFO:absl:Total JAX model predict time (includes compilation time, see --benchmark): 33.9s
INFO:absl:Running predict with shape(feat) = {'boundary': (1, 1920), 'boundary_coords': (1, 1920, 4), 'boundary_weights': (1, 1920), 'phase_coords': (1, 38400, 4), 'position_coords': (1, 1600, 2), 'scattering_kernel': (1, 38400, 24), 'self_scattering_kernel': (1, 24, 24), 'sigma': (1, 1600, 2), 'velocity_coords': (1, 24, 2), 'velocity_weights': (1, 24)}
INFO:absl:Output shape was {'predicted_psi': (1, 38400)}
INFO:absl:Total JAX model predict time (excludes compilation time): 21.3s
INFO:absl