In [None]:
import functools

import haiku as hk
from haiku.testing import transform_and_run
import jax
import jax.numpy as jnp
import numpy as np

from deeprte.model import config
from deeprte.model.tf.rte_features import (
    shape,
    FEATURES,
    BATCH_FEATURE_NAMES,
    COLLOCATION_FEATURE_NAMES,
)
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics

# Default DeepRTE model config and its shared global sub-config.
c = config.model_config()
gc = c.global_config

# Seed for the PRNG sequence that later cells draw keys from via `next(rng)`.
# Each `next(rng)` advances the sequence, so re-running a single cell gives
# different random inputs; use Restart & Run All to reproduce outputs.
SEED = 42
rng = hk.PRNGSequence(SEED)

In [None]:
# Toy problem sizes used to build a random feature batch.
# NOTE(review): 15 == 5 * 3; presumably phase coords pair every position with
# every velocity — confirm before changing either size independently.
get_feat_shape = functools.partial(
    shape,
    num_examples=10,
    num_position_coords=5,
    num_velocity_coords=3,
    num_phase_coords=15,
    num_boundary_coords=4,
)

# One uniform random array per feature name, each drawn with a fresh key.
# NOTE(review): every feature is filled with float uniforms regardless of its
# FEATURES entry — confirm no feature requires an integer dtype.
test_batch = {
    name: jax.random.uniform(next(rng), get_feat_shape(name)) for name in FEATURES
}

# No trailing comma: display the shape tree itself, not a 1-tuple wrapping it.
# (`jax.tree_map` is deprecated/removed in newer JAX; tree_util.tree_map is stable.)
jax.tree_util.tree_map(lambda x: x.shape, test_batch)

## Test DeepRTE

In [None]:
def _deeprte_forward(*args, **kwargs):
    """Build DeepRTE from the global config `c` and call it on the inputs.

    Haiku modules must be constructed inside a transformed function, so this
    forward function is what gets wrapped by `hk.transform` and
    `transform_and_run` below. It has its own name so `deeprte` is not reused
    for two different kinds of object.
    """
    return modules.DeepRTE(c)(*args, **kwargs)


# Pure (init, apply) pair; used explicitly to inspect parameters.
transformed_deeprte = hk.transform(_deeprte_forward)

# Convenience runner for quick shape/output checks.
# NOTE(review): per the haiku.testing docs, transform_and_run performs its own
# init (fixed seed) on every call, so `deeprte(...)` does not reuse `params`
# from the Params cell — confirm this is intended.
deeprte = transform_and_run(_deeprte_forward)

### Params

In [None]:
# Initialise DeepRTE parameters with the pure transformed function so the
# parameter tree can be inspected. Loss and metric computation are switched
# off for this init call.
params = transformed_deeprte.init(
    next(rng),
    test_batch,
    is_training=True,
    compute_loss=False,
    compute_metrics=False,
)

# Show the shape of every parameter. `jax.tree_map` is deprecated (and removed
# in recent JAX releases); `jax.tree_util.tree_map` works across versions.
jax.tree_util.tree_map(lambda x: x.shape, params)

### Apply

#### Case 1: train mode

In [None]:
# Training mode with loss and metrics enabled: the call returns a
# (total_loss, outputs) pair.
total_loss, outputs = deeprte(
    test_batch,
    is_training=True,
    compute_loss=True,
    compute_metrics=True,
)

(
    total_loss.shape,
    outputs["predicted_solution"].shape,
    outputs["loss"],
    outputs["metrics"],
)

#### Case 2: evaluation mode

In [None]:
# Evaluation mode: metrics on, loss off; the result is a single outputs mapping.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=True,
)

(outputs["predicted_solution"].shape, outputs["metrics"])

#### Case 3: inference mode

In [None]:
# Inference mode: neither loss nor metrics; only the prediction is inspected.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=False,
)

outputs["predicted_solution"].shape

## Test Green's function

In [None]:
def _select_first_point(name, value):
    """Return `value` reduced to its first entry along the indexed axes.

    Batch features lose their leading axis (index 0), and collocation features
    then lose their next leading axis (index 0). A feature listed in both name
    sets is therefore indexed twice, i.e. ``value[0][0]``.
    """
    if name in BATCH_FEATURE_NAMES:
        value = value[0]
    if name in COLLOCATION_FEATURE_NAMES:
        value = value[0]
    return value


# Build a new dict in one pass instead of mutating it while iterating, so
# re-running the cell is idempotent.
one_point_example = {
    name: _select_first_point(name, value) for name, value in test_batch.items()
}

jax.tree_util.tree_map(lambda x: x.shape, one_point_example)

In [None]:
@transform_and_run
def green_fn(*args):
    """Run GreenFunction (with is_training=False) on positional inputs."""
    green_function = modules.GreenFunction(c.green_function, gc)
    return green_function(*args, is_training=False)


# Evaluate with the same phase coordinates as both of the first two arguments.
phase_coords = one_point_example["phase_coords"]
green_fn(phase_coords, phase_coords, one_point_example).shape

## Test Attenuation module

In [None]:
@transform_and_run
def attenuation_fn(*args):
    """Run the Attenuation module on positional inputs inside Haiku."""
    return modules.Attenuation(c.green_function.attenuation, gc)(*args)


# Shared leading size of `att_coeff` and `grid`. Presumably both index the
# same sample points, so they must stay equal — confirm against Attenuation.
NUM_GRID_POINTS = 50

coord1 = coord2 = jax.random.uniform(next(rng), [4])
att_coeff = jax.random.uniform(next(rng), [NUM_GRID_POINTS, 3])

grid = jax.random.uniform(next(rng), [NUM_GRID_POINTS, 2])
char = Characteristics.from_tensor(grid)

attenuation_fn(coord1, coord2, att_coeff, char).shape

## Test Attention module

In [None]:
@transform_and_run
def attn_fn(*args):
    """Run the Attention module (attenuation's attention config) on inputs."""
    attention = modules.Attention(c.green_function.attenuation.attention, gc)
    return attention(*args)


# Toy inputs (q/k/v presumably query/key/value): q and k share last dim 3,
# k and v share leading dim 4.
# NOTE(review): the trailing `None` is presumably an optional mask — confirm
# against modules.Attention.
q = jax.random.uniform(next(rng), [5, 3])
k = jax.random.uniform(next(rng), [4, 3])
v = jax.random.uniform(next(rng), [4, 2])

attn_fn(q, k, v, None).shape