In [None]:
import os

# Pin the notebook to a single GPU. CUDA_VISIBLE_DEVICES must be set before JAX
# initialises its GPU backend, so keep this cell first and run it before
# importing jax. `setdefault` lets a value exported before launching Jupyter
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter lab`) take precedence over the default "7".
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

In [None]:
import haiku as hk
from haiku.testing import transform_and_run
import jax

from deeprte.model.config import model_config
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics
from deeprte.model.tf import input_pipeline
from deeprte.model.tf import rte_features

# Full model config (`c`) and its global sub-config (`gc`); the modules below
# are constructed from these.
c = model_config()
gc = c.global_config

SEED = 42
# Stateful PRNG key stream: every `next(rng)` below draws a new key, so the
# order of cells determines which key each draw receives.
rng = hk.PRNGSequence(SEED)

In [None]:
# Load one training batch; `batch` is reused by the DeepRTE and Green's-function
# cells below. split_ratio=0.8 presumably reserves 80% of the data for TRAIN — confirm
# against input_pipeline.load.
split = input_pipeline.Split.TRAIN
ds = input_pipeline.load(
    split=split,
    split_ratio=0.8,
    is_training=True,
    batch_sizes=[1],
    collocation_sizes=[120],
)
batch = next(ds)
# `jax.tree_map` is deprecated and removed in newer JAX releases;
# `jax.tree_util.tree_map` works across versions.
jax.tree_util.tree_map(lambda x: x.shape, batch)

## Test DeepRTE

In [None]:
def _deeprte_forward(*args, **kwargs):
    """Construct a DeepRTE module from config `c` and apply it to the inputs.

    Builds a Haiku module, so it must be called inside a Haiku transform.
    """
    model = modules.DeepRTE(c)
    return model(*args, **kwargs)


# Pure (init, apply) pair, used below for explicit parameter initialisation.
transformed_deeprte = hk.transform(_deeprte_forward)
# One-shot helper from haiku.testing: initialises fresh params and applies the
# forward pass on every call. Kept under the name `deeprte` for the cells below.
deeprte = transform_and_run(_deeprte_forward)

### Params

In [None]:
# Explicitly initialise parameters via the pure `init` function to inspect
# their shapes. NOTE: the Apply cells below call `deeprte` (a transform_and_run
# wrapper), which initialises its own parameters; `params` is not passed to them.
params = transformed_deeprte.init(
    next(rng),
    batch,
    is_training=True,
    compute_loss=False,
    compute_metrics=False,
)
# `jax.tree_map` is deprecated and removed in newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, params)

### Apply

#### Case 1: train mode

In [None]:
# Train mode: with compute_loss=True the call is unpacked as (total_loss, outputs).
# `deeprte` initialises fresh parameters on each call (transform_and_run).
total_loss, outputs = deeprte(
    batch,
    is_training=True,
    compute_loss=True,
    compute_metrics=True,
)
(
    total_loss,
    outputs["predicted_psi"].shape,
    outputs["loss"],
    outputs["metrics"],
)

#### Case 2: evaluation mode

In [None]:
# Evaluation mode: metrics requested but no loss, so the call yields `outputs` only.
outputs = deeprte(
    batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=True,
)
(outputs["predicted_psi"].shape, outputs["metrics"])

#### Case 3: inference mode

In [None]:
# Inference mode: neither loss nor metrics; only the prediction's shape is inspected.
outputs = deeprte(
    batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=False,
)
outputs["predicted_psi"].shape

## Test Green's function

In [None]:
# Reduce the batch to a single point for the Green's-function test:
# 1. take example 0 of every feature;
# 2. additionally index 0 along the leading axis of each phase feature
#    (presumably the phase-space point axis — confirm against rte_features).
# A dict comprehension avoids leaking loop variables `k`/`v` into the notebook
# namespace, where they are later reused as attention tensors.
single_example = jax.tree_util.tree_map(lambda x: x[0], batch)
one_point_example = {
    name: value[0] if name in rte_features.PHASE_FEATURE_NAMES else value
    for name, value in single_example.items()
}
# `jax.tree_map` is deprecated and removed in newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, one_point_example)

In [None]:
@transform_and_run
def green_fn(*args):
    """Apply a freshly initialised GreenFunction module to `args` with is_training=False."""
    green_function = modules.GreenFunction(c.green_function, gc)
    return green_function(*args, is_training=False)


# Evaluate with the same phase coordinates passed as both the first and second argument.
phase_coords = one_point_example["phase_coords"]
green_fn(phase_coords, phase_coords, one_point_example)

## Test Attenuation module

In [None]:
@transform_and_run
def attenuation_fn(*args):
    """Apply a freshly initialised Attenuation module to `args`."""
    attenuation = modules.Attenuation(c.green_function.attenuation, gc)
    return attenuation(*args)


# Random smoke-test inputs; the draw order from `rng` matches the original cell.
num_grid_points = 50
coord1 = coord2 = jax.random.uniform(next(rng), [4])
att_coeff = jax.random.uniform(next(rng), [num_grid_points, 3])

# Presumably 2-D spatial grid points — TODO confirm against Characteristics.from_tensor.
grid = jax.random.uniform(next(rng), [num_grid_points, 2])
char = Characteristics.from_tensor(grid)

attenuation_fn(coord1, coord2, att_coeff, char).shape

## Test Attention module

In [None]:
@transform_and_run
def attn_fn(*args):
    """Apply a freshly initialised Attention module to `args`."""
    attention = modules.Attention(c.green_function.attenuation.attention, gc)
    return attention(*args)


# Shapes suggest one 3-d query against 4 keys (dim 3) and 4 values (dim 2).
# Draw order from `rng` matches the original cell.
query_shape, keys_shape, values_shape = [3], [4, 3], [4, 2]
q = jax.random.uniform(next(rng), query_shape)
k = jax.random.uniform(next(rng), keys_shape)
v = jax.random.uniform(next(rng), values_shape)

# The fourth argument is passed as None (presumably an optional mask/bias — confirm).
attn_fn(q, k, v, None).shape