In [1]:
import os
from pathlib import Path

from absl import logging

# Optional GPU restriction. It only takes effect if set before JAX is imported
# (JAX is imported in the next cell).
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"

# Assumes the notebook is run from a sub-directory of the repo root, so the
# dataset lives in <repo>/data/rte.
DATA_DIR = Path.cwd().parent / "data" / "rte"

logging.set_verbosity("info")


In [2]:
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf
from deeprte.dataset import (
    Split,
    load,
    convert_dataset,
    get_nest_dict_shape,
    preprocess_grid,
)

# convert_dataset writes "<stem>_converted.npz" next to its input (see the
# "Saved converted dataset" log and the np.load of that file below). Re-running
# it on every kernel restart is slow and rewrites the archive, so only convert
# when the output is missing or older than the raw .npy it is built from.
RAW_BC_DELTA_PATH = DATA_DIR / "rte_2d_bc_delta_funcs.npy"
CONVERTED_BC_DELTA_PATH = DATA_DIR / "rte_2d_bc_delta_funcs_converted.npz"
if (
    not CONVERTED_BC_DELTA_PATH.exists()
    or CONVERTED_BC_DELTA_PATH.stat().st_mtime < RAW_BC_DELTA_PATH.stat().st_mtime
):
    convert_dataset(RAW_BC_DELTA_PATH)


INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
INFO:absl:Saved converted dataset /workspaces/deeprte/data/rte/rte_2d_bc_delta_funcs_converted.


In [3]:
# np.load on an .npz returns an NpzFile that holds the file open. Copy its arrays
# into a plain dict inside `with` so the handle is released (same pattern as the
# flat_dict_to_rte_data cell further down).
with np.load(DATA_DIR / "rte_2d_converted.npz") as npz_file:
    data1 = dict(npz_file)

print(tf.nest.map_structure(lambda x: x.shape, data1))


{'data/sigma_t': (500, 40, 40), 'data/sigma_a': (500, 40, 40), 'data/psi_bc': (500, 40, 12), 'data/psi_label': (500, 40, 40, 24), 'data/phi': (500, 40, 40), 'grid/w_angle': (24,), 'grid/v': (24, 2), 'grid/r': (40, 40, 2), 'grid/rv_prime': (40, 12, 4), 'grid/w_prime': (40, 12)}


In [8]:
# Mean over every element of the psi labels in rte_2d_converted.npz.
data1["data/psi_label"].mean()

0.12421012

In [5]:
# Materialise the converted delta-function archive into a dict so the NpzFile
# handle is closed instead of leaking for the lifetime of the kernel.
with np.load(DATA_DIR / "rte_2d_bc_delta_funcs_converted.npz") as npz_file:
    data2 = dict(npz_file)

print(tf.nest.map_structure(lambda x: x.shape, data2))

{'data/sigma_t': (500, 40, 40), 'data/sigma_a': (500, 40, 40), 'data/psi_bc': (500, 40, 12), 'data/psi_label': (500, 40, 40, 24), 'data/phi': (500, 40, 40), 'grid/w_angle': (24,), 'grid/v': (24, 2), 'grid/r': (40, 40, 2), 'grid/rv_prime': (40, 12, 4), 'grid/w_prime': (40, 12)}


In [12]:
# Population standard deviation (ddof=0) over every psi_label entry, i.e. the
# RMS deviation from the global mean.
data2["data/psi_label"].std()

0.5580785

In [13]:
# The original `del data2["__header__"], ...` cannot work. An NpzFile is a
# read-only mapping (item deletion raises TypeError), and those MATLAB-style
# metadata keys are not in this archive (see the key listing above), so a dict
# raises KeyError. Build a filtered copy instead. This works whether data2 is an
# NpzFile or an already-materialised dict.
data2 = {key: data2[key] for key in data2 if not key.startswith("__")}

In [40]:
# Validation split in evaluation mode. The batches printed below have leading
# dims (8, 3), which presumably correspond to
# [jax.local_device_count(), per-device batch] — TODO confirm against
# deeprte.dataset.load.
ds = load(
    DATA_DIR / "rte_2d_converted.npz",
    Split.VALID,
    is_training=False,
    batch_sizes=[jax.local_device_count(), 3],
    collocation_sizes=None,
    repeat=None,
)


In [41]:
import itertools

# Print the nested shapes of the first few batches. islice stops after exactly
# NUM_BATCHES_TO_SHOW items. The old enumerate/break pattern fetched one extra
# batch just to discard it.
NUM_BATCHES_TO_SHOW = 6
for batch in itertools.islice(ds, NUM_BATCHES_TO_SHOW):
    print(get_nest_dict_shape(batch))


INFO:absl:Data shapes, sigma_t: (500, 40, 40), sigma_a: (500, 40, 40), psi_bc: (500, 40, 12), psi_label: (500, 40, 40, 24), phi: (500, 40, 40)
INFO:absl:Grid shapes, w_angle: (24,), v: (24, 2), r: (40, 40, 2), rv_prime: (40, 12, 4), w_prime: (40, 12)


{'inputs': ((8, 38400, 2), (8, 38400, 2), F(x=(8, 1600, 2), y=(8, 3, 1600, 2)), F(x=(8, 480, 4), y=(8, 3, 480))), 'labels': (8, 3, 38400)}
{'inputs': ((8, 38400, 2), (8, 38400, 2), F(x=(8, 1600, 2), y=(8, 3, 1600, 2)), F(x=(8, 480, 4), y=(8, 3, 480))), 'labels': (8, 3, 38400)}
{'inputs': ((8, 38400, 2), (8, 38400, 2), F(x=(8, 1600, 2), y=(8, 3, 1600, 2)), F(x=(8, 480, 4), y=(8, 3, 480))), 'labels': (8, 3, 38400)}
{'inputs': ((8, 38400, 2), (8, 38400, 2), F(x=(8, 1600, 2), y=(8, 3, 1600, 2)), F(x=(8, 480, 4), y=(8, 3, 480))), 'labels': (8, 3, 38400)}


In [None]:
# Sanity check on a single batch: do the first two entries along the leading
# (device) axis hold different points? A zero sum would mean duplicated shards.
# get_nest_dict_shape is already imported in the dataset cell above.
for batch in ds.take(1):
    print(get_nest_dict_shape(batch))
    # Batches are {"inputs": (r, v, F(...), F(...)), "labels": ...} per the
    # shapes printed above. There is no "interior" key, so indexing it raised
    # KeyError.
    r, v = batch["inputs"][0], batch["inputs"][1]
    print(np.sum((r[1] - r[0]) ** 2))
    print(np.sum((v[1] - v[0]) ** 2))


In [None]:
from deeprte.utils import flat_dict_to_rte_data, get_model_haiku_params

# Haiku parameters for the "rte_2d_1" model, looked up under DATA_DIR.
params = get_model_haiku_params("rte_2d_1", DATA_DIR)

# Unflatten the "data/..." / "grid/..." keys while the archive is open; the
# context manager releases the file handle afterwards.
npz_path = DATA_DIR / "rte_2d_converted.npz"
with np.load(npz_path) as archive:
    rte_data = flat_dict_to_rte_data(archive)

data = rte_data["data"]
grid, num_grid_points = preprocess_grid(rte_data["grid"], is_training=False)


In [None]:
# (data shapes, grid shapes) as a 2-tuple for display.
tuple(get_nest_dict_shape(tree) for tree in (data, grid))


In [None]:
import ml_collections

from deeprte.models.rte_op import RTEOperator
from deeprte.modules import GreenFunction
from deeprte.typing import F

# Presumably the layer widths of the Green's-function MLP and the coefficient
# sub-networks. They must match the architecture the "rte_2d_1" params were
# saved with.
config = ml_collections.ConfigDict(
    dict(
        green_net=[128, 128, 128, 128, 1],
        coeffs_net=dict(weights=[64, 1], coeffs=[64, 2]),
    )
)

sol = RTEOperator(config, GreenFunction)


In [None]:
import time

# Seeded generator so the evaluated sample (and the plot below) is reproducible
# across Restart & Run All. The sample count comes from the data itself instead
# of a hard-coded 500.
SEED = 0
rng = np.random.default_rng(SEED)
num_samples = data["sigma_t"].shape[0]
i = int(rng.integers(num_samples))

# Keep sample i (leading axis) of every field and flatten the rest to (1, -1).
data_i = tf.nest.map_structure(lambda x: x[i : i + 1].reshape(1, -1), data)

psi_bc = data_i["psi_bc"]
# Channel-last stack: [..., 0] = sigma_t, [..., 1] = sigma_a.
sigma = np.stack([data_i["sigma_t"], data_i["sigma_a"]], axis=-1)
r, v, rv_prime, w_prime, w_angle = (
    grid["r"],
    grid["v"],
    grid["rv_prime"],
    grid["w_prime"],
    grid["w_angle"],
)

# perf_counter is monotonic and high-resolution, which suits interval timing.
# NOTE(review): if sol.rho is jit-compiled, this first call also includes
# compile time.
t_prev = time.perf_counter()
pred_rho = sol.rho(
    params, None, r, F(r, sigma), F(rv_prime, w_prime * psi_bc), (v, w_angle)
)
# Wait for the (possibly asynchronous) computation before stopping the clock.
pred_rho.block_until_ready()
dt = time.perf_counter() - t_prev

print(dt)


In [None]:
import matplotlib.pyplot as plt

# Spatial grid of rte_2d: 40 x 40 = 1600 points (see grid/r shape printed above).
GRID_SHAPE = (40, 40)

pred_rho = pred_rho.reshape(-1, *GRID_SHAPE).squeeze()
phi = data_i["phi"].reshape(GRID_SHAPE)
x_coords = r[..., 0].reshape(GRID_SHAPE)
y_coords = r[..., 1].reshape(GRID_SHAPE)

# nrows=1, ncols=2 already yields a 1-D array of two Axes.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 8))

# Left panel: pointwise absolute error of the prediction against phi.
cs = axs[0].contour(x_coords, y_coords, np.abs(pred_rho - phi))
axs[0].clabel(cs, inline=True, fontsize=10)
fig.colorbar(cs, ax=axs[0])
axs[0].set(
    title=f"|pred_rho - phi| (sample {i})", xlabel="r[..., 0]", ylabel="r[..., 1]"
)

# Right panel: reference (line) vs prediction (markers) along grid column 0.
# Separate calls so the "*" format and labels apply to the intended series.
axs[1].plot(x_coords[:, 0], phi[:, 0], label="phi (reference)")
axs[1].plot(x_coords[:, 0], pred_rho[:, 0], "*", label="pred_rho")
axs[1].set(title="Slice [:, 0]", xlabel="r[..., 0]")
axs[1].legend()

plt.show()
