In [2]:
import functools

import haiku as hk
from haiku.testing import transform_and_run
import jax
import jax.numpy as jnp
import numpy as np

from deeprte.model.config import model_config
from deeprte.model.tf.rte_features import (
    shape,
    FEATURES,
    BATCH_FEATURE_NAMES,
    COLLOCATION_FEATURE_NAMES,
)
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics

# Full DeepRTE model config, plus the global config passed to every sub-module
# constructed in the cells below.
c = model_config()
gc = c.global_config

# One PRNG sequence consumed by all cells below; fixed seed for reproducibility.
SEED = 42
rng = hk.PRNGSequence(SEED)

In [3]:
# Bind the feature-shape helper to fixed, small problem sizes used by every
# test cell in this notebook.
get_feat_shape = functools.partial(
    shape,
    num_examples=10,
    num_position_coords=5,
    num_velocity_coords=3,
    num_phase_coords=15,
    num_boundary_coords=4,
)

# One uniform-random tensor per feature declared in FEATURES. A fresh key is
# drawn from `rng` for each feature, in FEATURES' iteration order.
test_batch = {
    name: jax.random.uniform(next(rng), get_feat_shape(name))
    for name in FEATURES
}

# Display the shape of every feature. There is deliberately no trailing comma:
# a comma would wrap the displayed dict in a 1-tuple.
jax.tree_util.tree_map(lambda x: x.shape, test_batch)

({'boundary': (10, 4),
  'boundary_coords': (4, 4),
  'boundary_weights': (4,),
  'phase_coords': (15, 4),
  'position_coords': (5, 2),
  'psi_label': (10, 15),
  'scattering_kernel': (10, 15, 3),
  'self_scattering_kernel': (10, 3, 3),
  'sigma': (10, 5, 2),
  'velocity_coords': (3, 2),
  'velocity_weights': (3,)},)

## Test DeepRTE

In [4]:
def _deeprte_forward(*args, **kwargs):
    """Construct DeepRTE from the notebook config `c` and call it.

    Must run inside a Haiku transform. All positional and keyword arguments
    are forwarded unchanged to the module call.
    """
    model = modules.DeepRTE(c)
    return model(*args, **kwargs)


# Pure (init, apply) pair, used below when parameters are handled explicitly.
transformed_deeprte = hk.transform(_deeprte_forward)
# haiku.testing helper: each call initialises parameters and then applies them.
deeprte = transform_and_run(_deeprte_forward)

In [10]:
def loss(params, rng):
    """Total DeepRTE loss on `test_batch` for the given parameters.

    Args:
      params: parameter pytree, e.g. from `transformed_deeprte.init`.
      rng: PRNG key forwarded to `transformed_deeprte.apply`.

    Returns:
      The first element of the `(total_loss, outputs)` pair that DeepRTE
      returns when `compute_loss=True`. The auxiliary outputs are discarded.
    """
    # NOTE(review): is_training=False while differentiating; confirm this is intended.
    total_loss, _unused_outputs = transformed_deeprte.apply(
        params,
        rng,
        test_batch,
        is_training=False,
        compute_loss=True,
        compute_metrics=False,
    )
    return total_loss

In [11]:
# Gradient of `loss` with respect to its first argument (`params`) only.
grad_fn = jax.grad(loss, argnums=0)

In [12]:
# `params` is only created in the "Params" section further down. On a fresh
# kernel (Restart & Run All), referencing it here raises NameError, so this
# cell initialises its own parameter set and is self-contained.
grad_params = transformed_deeprte.init(
    next(rng),
    test_batch,
    is_training=True,
    compute_loss=False,
    compute_metrics=False,
)
# NOTE: despite its name, `logits` holds d(loss)/d(params), a pytree with the
# same structure as the parameters, not model logits. The name is kept
# because the display cell that follows reads it.
logits = grad_fn(grad_params, next(rng))

In [14]:
# Shape of every leaf of the gradient pytree. `jax.tree_map` is deprecated;
# `jax.tree_util.tree_map` is the equivalent that works on every JAX version.
jax.tree_util.tree_map(lambda x: x.shape, logits)

{'deeprte/green_function/attenuation/attention/key': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attention/output_projection': {'b': (2,),
  'w': (64, 2)},
 'deeprte/green_function/attenuation/attention/query': {'b': (64,),
  'w': (4, 64)},
 'deeprte/green_function/attenuation/attention/value': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attenuation_linear': {'b': (128,),
  'w': (10, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_1': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_2': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/output_projection': {'b': (32,),
  'w': (128, 32)},
 'deeprte/green_function/output_projection': {'w': (32, 1)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer/linear': {'b': (2,
   32),
  'w': (2, 32, 32)}}

### Params

In [6]:
# Initialise DeepRTE parameters from `test_batch` (loss and metrics disabled).
params = transformed_deeprte.init(
    next(rng),
    test_batch,
    is_training=True,
    compute_loss=False,
    compute_metrics=False,
)
# Show the shape of every parameter. `jax.tree_map` is deprecated; use
# `jax.tree_util.tree_map`.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'deeprte/green_function/attenuation/attention/key': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attention/output_projection': {'b': (2,),
  'w': (64, 2)},
 'deeprte/green_function/attenuation/attention/query': {'b': (64,),
  'w': (4, 64)},
 'deeprte/green_function/attenuation/attention/value': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attenuation_linear': {'b': (128,),
  'w': (10, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_1': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_2': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/output_projection': {'b': (32,),
  'w': (128, 32)},
 'deeprte/green_function/output_projection': {'w': (32, 1)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer/linear': {'b': (2,
   32),
  'w': (2, 32, 32)}}

### Apply

#### Case 1: train mode

In [5]:
# Training mode: the call returns (total_loss, outputs); outputs carries the
# prediction, the loss breakdown and per-example metrics.
total_loss, outputs = deeprte(
    test_batch,
    is_training=True,
    compute_loss=True,
    compute_metrics=True,
)
(
    total_loss.shape,
    outputs["predicted_solution"].shape,
    outputs["loss"],
    outputs["metrics"],
)

((),
 (10, 15),
 {'mse': DeviceArray(0.40339795, dtype=float32),
  'rmspe': DeviceArray(1.0573285, dtype=float32)},
 {'mse': DeviceArray([0.13312022, 0.9188513 , 0.19137558, 1.3161011 , 0.54276913,
               0.19607523, 0.3501287 , 0.12373666, 0.1253971 , 0.13642465],            dtype=float32),
  'rmspe': DeviceArray([0.36891827, 2.5464277 , 0.5303623 , 3.6473327 , 1.5041851 ,
               0.5433865 , 0.9703175 , 0.34291345, 0.34751505, 0.37807587],            dtype=float32)})

#### Case 2: evaluation mode

In [6]:
# Evaluation mode: no loss is computed, so the call returns only the outputs
# dict, which still includes per-example metrics.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=True,
)
(outputs["predicted_solution"].shape, outputs["metrics"])

((10, 15),
 {'mse': DeviceArray([0.13312022, 0.9188513 , 0.19137558, 1.3161011 , 0.54276913,
               0.19607523, 0.3501287 , 0.12373666, 0.1253971 , 0.13642465],            dtype=float32),
  'rmspe': DeviceArray([0.36891827, 2.5464277 , 0.5303623 , 3.6473327 , 1.5041851 ,
               0.5433865 , 0.9703175 , 0.34291345, 0.34751505, 0.37807587],            dtype=float32)})

#### Case 3: inference mode

In [None]:
# Inference mode: neither loss nor metrics; only the prediction is inspected.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=False,
)
predicted_solution = outputs["predicted_solution"]
predicted_solution.shape

## Test Green's function

In [None]:
# Reduce `test_batch` to a single example at a single collocation point.
# - Features listed in BATCH_FEATURE_NAMES lose their leading axis (presumably
#   the num_examples axis; TODO confirm).
# - Features listed in COLLOCATION_FEATURE_NAMES then lose their (current)
#   leading axis too, so a feature in both lists is indexed twice.
# Builds a new dict in one pass instead of rewriting values of the same dict
# while iterating over it.
one_point_example = {}
for name, value in test_batch.items():
    if name in BATCH_FEATURE_NAMES:
        value = value[0]
    if name in COLLOCATION_FEATURE_NAMES:
        value = value[0]
    one_point_example[name] = value

# `jax.tree_map` is deprecated; use `jax.tree_util.tree_map`.
jax.tree_util.tree_map(lambda x: x.shape, one_point_example)

In [None]:
@transform_and_run
def green_fn(*args):
    """Initialise a GreenFunction module and apply it to `args` with is_training=False."""
    green_function = modules.GreenFunction(c.green_function, gc)
    return green_function(*args, is_training=False)


# Pass the same `phase_coords` tensor as both coordinate arguments, plus the
# full single-point feature dict.
phase_coords = one_point_example["phase_coords"]
green_fn(phase_coords, phase_coords, one_point_example).shape

## Test Attenuation module

In [None]:
@transform_and_run
def attenuation_fn(*args):
    """Initialise an Attenuation module and apply it to `args`."""
    attenuation = modules.Attenuation(c.green_function.attenuation, gc)
    return attenuation(*args)


num_grid_points = 50

# The same random 4-vector is used for both coordinate arguments (presumably a
# phase-space point; TODO confirm).
coord1 = coord2 = jax.random.uniform(next(rng), [4])
att_coeff = jax.random.uniform(next(rng), [num_grid_points, 3])

# Random 2-d grid wrapped as Characteristics.
grid = jax.random.uniform(next(rng), [num_grid_points, 2])
char = Characteristics.from_tensor(grid)

attenuation_fn(coord1, coord2, att_coeff, char).shape

## Test Attention module

In [None]:
@transform_and_run
def attn_fn(*args):
    return modules.Attention(c.green_function.attenuation.attention, gc)(*args)


q = jax.random.uniform(next(rng), [5, 3])
k = jax.random.uniform(next(rng), [4, 3])
v = jax.random.uniform(next(rng), [4, 2])

attn_fn(q, k, v, None).shape