In [1]:
import jax
import jax.numpy as jnp
from flax import nnx

%load_ext autoreload
%autoreload 2

from deeprte.model.autoencoder import AutoEncoder
from deeprte.model.characteristics import Characteristics
from deeprte.model.config import DeepRTEConfig
from deeprte.model.deeprte import (
    Attenuation,
    DeepRTE,
    GreenFunction,
    MultiHeadAttention,
    Scattering,
)
from deeprte.model.modules import constructor

In [2]:
cfg = DeepRTEConfig()
cfg

rng = nnx.Rngs(jax.random.PRNGKey(0))

In [3]:
# test Characteristics
x, y = jnp.meshgrid(jnp.linspace(-1, 1, 10), jnp.linspace(-1, 1, 10))
# v = jnp.zeros((10, 10, 2))
print(x.shape, y.shape)
phase_coords = jnp.stack((x, y), axis=-1)
# phase_coords = jnp.concatenate([phase_coords, v], axis=-2)
phase_coords.shape
phase_coords = phase_coords.reshape(-1, 2)
charac = Characteristics.from_tensor(phase_coords)
# charac.apply_to_point(phase_coords)


(10, 10) (10, 10)


In [4]:
# test MultiHeadAttention
attn = MultiHeadAttention(
    cfg.num_heads, [4, 2, 2], cfg.qkv_dim, cfg.optical_depth_dim, rngs=rng
)
coord1 = jax.random.normal(rng.params(), (4,))
# nnx.display(attn)
out = attn(
    coord1,
    phase_coords,
    jnp.ones((100, 2)),
)


In [5]:
out.shape

(16,)

In [6]:
# test Attenuation
att = Attenuation(cfg, rngs=rng)
nnx.display(att)
out = att(coord1, jnp.ones((100, 2)), charac)


In [7]:
out.shape

(128,)

In [8]:
# test Scattering
scat = Scattering(cfg, rngs=rng)
nnx.display(scat)
coord = jax.random.normal(rng.params(), (100, 4))
self_act = jax.vmap(att, in_axes=(0, None, None))(coord, jnp.ones((100, 2)), charac)
out = scat(
    act=out,
    self_act=self_act,
    kernel=jnp.ones((100)),
    self_kernel=jnp.ones((100, 100)),
)


In [24]:
# test BoundaryBasisRepresentation
ate = AutoEncoder(cfg, rngs=rng)
nnx.display(ate)
inputs = {
    "source_coords": jnp.ones((4, 100, 4)),
    "source": jnp.ones((4, 100)),
    "source_weights": jnp.ones((4, 100)),
    "phase_coords": jnp.ones((4, 80, 4)),
}
quadratures = (
    inputs["source_coords"][0],
    inputs["source_weights"][0],
)
out = ate.encoder(inputs["source"][0], quadratures)
# out = basis(inputs)

# points, weights = quadratures
# out = jax.vmap(basis.basis_function_encoder, in_axes=(0,), out_axes=-1)(points)
# print(out.shape, weights.shape)


In [25]:
out.shape

(64,)

In [26]:
_out = ate.decoder(out, inputs["phase_coords"][0])

In [27]:
_out.shape

(80,)

In [28]:
out = ate(inputs)

In [29]:
out.shape

(4, 80)

: 

In [13]:
model = GreenFunction(cfg, rngs=rng)

inputs.update(
    {
        "position_coords": phase_coords,
        "velocity_coords": jax.random.normal(rng.params(), (24, 2)),
        "velocity_weights": jnp.ones((24)),
        "self_scattering_kernel": jnp.ones((24, 24)),
        "scattering_kernel": jnp.ones((10, 24)),
        "sigma": jnp.ones((100, 2)),
    }
)
out = model(coord1, inputs)

out.shape


(10, 128)

In [14]:
model = DeepRTE(cfg, rngs=rng)

inputs.update(
    {
        "moments": jnp.ones((128,)),
        "basis_inner_product": jnp.ones((128, 128)),
        "phase_coords": jnp.ones((10, 4)),
    }
)
expend_axes_inputs = jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), inputs)
out = model(expend_axes_inputs)


  expend_axes_inputs = jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), inputs)


In [15]:
out.shape

(1, 10)