In [2]:
import os
# Default to GPU 2, but respect a device selection that the launcher or a job
# scheduler has already exported; overwriting it unconditionally can hide
# every GPU that was actually allocated to this process.
# This must run before JAX is imported, because the variable is read when the
# CUDA backend is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

In [3]:
import functools

import haiku as hk
from haiku.testing import transform_and_run
import jax
import jax.numpy as jnp
import numpy as np

from deeprte.model.config import model_config
from deeprte.model.tf.rte_features import (
    shape,
    FEATURES,
    BATCH_FEATURE_NAMES,
    COLLOCATION_FEATURE_NAMES,
)
from deeprte.model import modules
from deeprte.model.characteristics import Characteristics

# Fixed-seed stream of PRNG keys; every `next(rng)` below draws a fresh key,
# so the random test inputs are reproducible on a fresh kernel.
rng = hk.PRNGSequence(42)

# Model configuration used by all modules under test, plus a shorthand for
# its global section, which the sub-module constructors take separately.
c = model_config()
gc = c.global_config

In [4]:
# Resolve the placeholder dimensions of each feature to small concrete sizes
# so that a random test batch can be built quickly.
get_feat_shape = functools.partial(
    shape,
    num_examples=10,
    num_position_coords=5,
    num_velocity_coords=3,
    num_phase_coords=128 * 20,  # 2560 phase-space points
    num_boundary_coords=4,
)

# One uniformly random array per declared feature, each drawn with its own key.
test_batch = {}
for name in FEATURES:
    test_batch[name] = jax.random.uniform(next(rng), get_feat_shape(name))

# Show the shape of every feature (no trailing comma, so the cell displays the
# dict itself rather than a 1-tuple wrapping it).
jax.tree_util.tree_map(lambda x: x.shape, test_batch)

{'boundary': (10, 4),
 'boundary_coords': (4, 4),
 'boundary_weights': (4,),
 'phase_coords': (2560, 4),
 'position_coords': (5, 2),
 'psi_label': (10, 2560),
 'scattering_kernel': (10, 2560, 3),
 'self_scattering_kernel': (10, 3, 3),
 'sigma': (10, 5, 2),
 'velocity_coords': (3, 2),
 'velocity_weights': (3,)}

## Test DeepRTE

In [5]:
def deeprte(*args, **kwargs):
    """Build a `DeepRTE` module from the notebook config `c` and call it.

    All positional and keyword arguments are forwarded unchanged to the
    module's `__call__`. Must be executed inside a Haiku transform.
    """
    model = modules.DeepRTE(c)
    return model(*args, **kwargs)


# Pure (init, apply) pair built from the raw function; used below to inspect
# the parameter tree. This has to be created before the name is rebound.
transformed_deeprte = hk.transform(deeprte)
# Rebind `deeprte` to a directly callable version: `transform_and_run` wraps
# the function so that a plain call runs it under a Haiku transform.
deeprte = transform_and_run(deeprte)

### Params

In [6]:
# Initialise the parameters with a fresh key. Loss and metrics are switched
# off because only the parameter tree is of interest here.
params = transformed_deeprte.init(
    next(rng),
    test_batch,
    is_training=True,
    compute_loss=False,
    compute_metrics=False,
)
# `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
# supported spelling and returns the same result.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'deeprte/green_function/attenuation/attention/key': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attention/output_projection': {'b': (2,),
  'w': (64, 2)},
 'deeprte/green_function/attenuation/attention/query': {'b': (64,),
  'w': (4, 64)},
 'deeprte/green_function/attenuation/attention/value': {'b': (64,),
  'w': (2, 64)},
 'deeprte/green_function/attenuation/attenuation_linear': {'b': (128,),
  'w': (10, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_1': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/attenuation_linear_2': {'b': (128,),
  'w': (128, 128)},
 'deeprte/green_function/attenuation/output_projection': {'b': (32,),
  'w': (128, 32)},
 'deeprte/green_function/output_projection': {'w': (32, 1)},
 'deeprte/green_function/scattering_module/__layer_stack_no_per_layer/scattering_layer/linear': {'b': (2,
   32),
  'w': (2, 32, 32)}}

### Apply

#### Case 1: train mode

In [7]:
# Training mode: the call returns the scalar total loss together with the
# outputs dict, which also carries the loss and metric breakdowns.
total_loss, outputs = deeprte(
    test_batch,
    is_training=True,
    compute_loss=True,
    compute_metrics=True,
)
(
    total_loss.shape,
    outputs["predicted_solution"].shape,
    outputs["loss"],
    outputs["metrics"],
)

((),
 (10, 2560),
 {'mse': Array(0.45419142, dtype=float32),
  'rmspe': Array(1.1661558, dtype=float32)},
 {'mse': Array([0.09425422, 0.74106085, 0.3474791 , 1.5594501 , 0.6533639 ,
         0.30188802, 0.47934285, 0.09252547, 0.09372695, 0.17882262],      dtype=float32),
  'rmspe': Array([0.28221175, 2.218851  , 1.0404063 , 4.6692357 , 1.9562728 ,
         0.90389955, 1.4352269 , 0.27703562, 0.28063303, 0.5354227 ],      dtype=float32)})

#### Case 2: evaluation mode

In [8]:
# Evaluation mode: no loss is computed, so only the outputs dict is returned;
# metrics are still requested.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=True,
)
(outputs["predicted_solution"].shape, outputs["metrics"])

((10, 2560),
 {'mse': Array([0.09425423, 0.74105996, 0.34747884, 1.5594487 , 0.6533642 ,
         0.30188814, 0.47934175, 0.09252533, 0.09372683, 0.17882255],      dtype=float32),
  'rmspe': Array([0.2822118 , 2.2188485 , 1.0404055 , 4.6692314 , 1.9562738 ,
         0.9038999 , 1.4352236 , 0.2770352 , 0.28063267, 0.53542244],      dtype=float32)})

#### Case 3: inference mode

In [9]:
# Inference mode: neither loss nor metrics are requested; only the predicted
# solution is inspected.
outputs = deeprte(
    test_batch,
    is_training=False,
    compute_loss=False,
    compute_metrics=False,
)
outputs["predicted_solution"].shape

(10, 2560)

## Test Green's function

In [10]:
# Reduce the batch to a single example evaluated at a single collocation
# point: take index 0 along the leading axis of every feature listed in
# BATCH_FEATURE_NAMES, then index 0 again for every feature listed in
# COLLOCATION_FEATURE_NAMES. A feature in both lists is indexed twice; all
# other features are passed through untouched.
# Both steps are done in one pass so the dict is never rewritten while it is
# being iterated over.
one_point_example = {}
for name, value in test_batch.items():
    if name in BATCH_FEATURE_NAMES:
        value = value[0]
    if name in COLLOCATION_FEATURE_NAMES:
        value = value[0]
    one_point_example[name] = value

# `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
# supported spelling and returns the same result.
jax.tree_util.tree_map(lambda x: x.shape, one_point_example)

{'boundary': (4,),
 'boundary_coords': (4, 4),
 'boundary_weights': (4,),
 'phase_coords': (4,),
 'position_coords': (5, 2),
 'psi_label': (),
 'scattering_kernel': (3,),
 'self_scattering_kernel': (3, 3),
 'sigma': (5, 2),
 'velocity_coords': (3, 2),
 'velocity_weights': (3,)}

In [11]:
@transform_and_run
def green_fn(*args):
    """Construct a `GreenFunction` module and evaluate it in non-training mode.

    Positional arguments are forwarded unchanged to the module call.
    """
    green_function = modules.GreenFunction(c.green_function, gc)
    return green_function(*args, is_training=False)


# The same phase-space point is deliberately passed for both coordinate
# arguments; this is only a shape check on the scalar output.
phase_coord = one_point_example["phase_coords"]
green_fn(phase_coord, phase_coord, one_point_example).shape

()

## Test Attenuation module

In [12]:
@transform_and_run
def attenuation_fn(*args):
    """Construct an `Attenuation` module and call it with the given arguments."""
    attenuation = modules.Attenuation(c.green_function.attenuation, gc)
    return attenuation(*args)


# Random inputs for a shape check. Both coordinate arguments refer to the
# same 4-vector, and the keys are drawn in the order coordinate, coefficient,
# grid.
coord1 = jax.random.uniform(next(rng), shape=(4,))
coord2 = coord1
att_coeff = jax.random.uniform(next(rng), shape=(50, 3))

grid = jax.random.uniform(next(rng), shape=(50, 2))
char = Characteristics.from_tensor(grid)

attenuation_fn(coord1, coord2, att_coeff, char).shape

(32,)

## Test Attention module

In [13]:
@transform_and_run
def attn_fn(*args):
    """Construct an `Attention` module and call it with the given arguments."""
    attention = modules.Attention(c.green_function.attenuation.attention, gc)
    return attention(*args)


# A single query vector against four key/value pairs. The fourth positional
# argument is passed as None (presumably an optional mask; confirm against
# `modules.Attention`).
q = jax.random.uniform(next(rng), shape=(3,))
k = jax.random.uniform(next(rng), shape=(4, 3))
v = jax.random.uniform(next(rng), shape=(4, 2))

attn_fn(q, k, v, None).shape

(2,)