In [None]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab
#
# The installation is only attempted when `differt` cannot be imported, so the
# cell is a no-op in environments where the package is already available.

try:
    import differt  # noqa: F401
except ImportError:
    # `sys` is only referenced inside the shell escape below, which linters do
    # not see, hence the `noqa`.
    import sys  # noqa: F401

    # IPython shell escape: runs pip through the interpreter of the current kernel
    # (`sys.executable`), so the package is installed where this notebook runs.
    !{sys.executable} -m pip install differt[all]

In [None]:
from collections.abc import Iterator
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped

from differt.plotting import reuse, set_defaults
from differt.scene.sionna import get_sionna_scene
from differt.scene.triangle_scene import TriangleScene
from differt.utils import sample_points_in_bounding_box

In [None]:
# Select "plotly" as the default plotting backend for this notebook.
set_defaults(
    "plotly"
)  # Our scene is simple, and Plotly is the best backend for online interactive plots :-)

# Retrieve the "simple_street_canyon" Sionna scene and load it as a triangle scene.
# `base_scene` is the template from which all random scenes below are derived.
file = get_sionna_scene("simple_street_canyon")
base_scene = TriangleScene.load_xml(file)
# Last expression of the cell: the returned figure is shown as the cell output.
base_scene.plot()

In [None]:
@eqx.filter_jit
def random_tx_rx(base_scene: TriangleScene, *, key: PRNGKeyArray) -> TriangleScene:
    """
    Return a copy of the scene with a randomly placed transmitter and receiver.

    Both positions are drawn from boxes that share the x- and y-extent of the
    mesh bounding box; only the z-range differs:

    * the receiver is drawn between the lowest z and the mid-height of the mesh;
    * the transmitter is drawn between ``mid-height + 5`` and ``max z + 5``.

    Args:
        base_scene: The scene whose mesh defines the sampling region.
        key: The random key to be used.

    Returns:
        The scene with its transmitters and receivers replaced.
    """
    mesh_box = base_scene.mesh.bounding_box
    z_low, z_high = mesh_box[:, 2]
    z_mid = 0.5 * (z_low + z_high)

    # Receiver box: cap the upper z-bound at mid-height
    rx_box = mesh_box.at[1, 2].set(z_mid)
    # Transmitter box: from mid-height to the top of the mesh, shifted up by 5
    tx_box = mesh_box.at[:, 2].set([z_mid + 5, z_high + 5])

    tx_key, rx_key = jax.random.split(key, 2)
    tx_position = sample_points_in_bounding_box(tx_box, key=tx_key)
    rx_position = sample_points_in_bounding_box(rx_box, key=rx_key)

    scene_with_tx = eqx.tree_at(
        lambda s: s.transmitters,
        base_scene,
        tx_position,
    )
    return eqx.tree_at(
        lambda s: s.receivers,
        scene_with_tx,
        rx_position,
    )


def random_scene(base_scene: TriangleScene, *, key: PRNGKeyArray) -> TriangleScene:
    """
    Generate a random variation of the base scene.

    The transmitter and receiver are placed at random (see :func:`random_tx_rx`),
    then the mesh is replaced by a random sample of its triangles. The number of
    triangles kept is drawn from ``[num_triangles // 2, num_triangles)``.

    This function is not jitted: the number of kept triangles is converted to a
    Python ``int`` before sampling the mesh.

    Args:
        base_scene: The scene to derive the random scene from.
        key: The random key to be used.

    Returns:
        The randomized scene.
    """
    tx_rx_key, count_key, subset_key = jax.random.split(key, 3)

    scene_with_tx_rx = random_tx_rx(base_scene, key=tx_rx_key)

    total_triangles = scene_with_tx_rx.mesh.num_triangles
    kept_triangles = jax.random.randint(
        count_key, (), total_triangles // 2, total_triangles
    )
    sub_mesh = scene_with_tx_rx.mesh.sample(int(kept_triangles), key=subset_key)

    return eqx.tree_at(
        lambda s: s.mesh,
        scene_with_tx_rx,
        sub_mesh,
    )


def random_scenes(
    base_scene: TriangleScene, *, key: PRNGKeyArray
) -> Iterator[TriangleScene]:
    """
    Yield an endless stream of random scenes derived from the base scene.

    The key is split at every iteration: one half is consumed by
    :func:`random_scene`, the other half is carried over to the next iteration.

    Args:
        base_scene: The scene to derive the random scenes from.
        key: The random key to be used.

    Yields:
        A new random scene at each iteration; the iterator never terminates.
    """
    carried_key = key
    while True:
        carried_key, scene_key = jax.random.split(carried_key, 2)

        yield random_scene(base_scene, key=scene_key)


@jax.jit
def scene_2_sample(scene: TriangleScene) -> Float[Array, "2+num_triangles*3 3"]:
    """
    Flatten a scene into a single array of xyz-coordinates.

    The rows are, in order: the transmitter coordinates, the receiver
    coordinates, and then the vertices of every triangle of the mesh.

    Args:
        scene: The scene to convert.

    Returns:
        The stacked coordinates, one xyz-point per row.
    """
    tx_rows = scene.transmitters.reshape(-1, 3)
    rx_rows = scene.receivers.reshape(-1, 3)
    vertex_rows = scene.mesh.triangle_vertices.reshape(-1, 3)

    return jnp.concatenate((tx_rows, rx_rows, vertex_rows), axis=0)


# Fixed seed so that the notebook is reproducible.
# NOTE(review): this exact `key` is also passed, unsplit, to several later cells
# (scene generator, model initialization, sampling); JAX recommends splitting
# keys rather than reusing them - consider deriving one sub-key per consumer.
key = jax.random.key(1234)

# Everything plotted inside the `reuse()` context goes through `fig`
# (presumably drawn on the same figure - confirm against `differt.plotting.reuse`).
with reuse() as fig:
    # One example of a randomized scene
    scene = random_scene(base_scene, key=key)
    scene.plot()

    # Shape of the model input generated from this scene
    print(scene_2_sample(scene).shape)

    # Plot the paths for orders 0, 1, and 2
    for i in range(3):
        scene.compute_paths(order=i).plot()

# Last expression of the cell: display the figure
fig

In [None]:
# Infinite generator of random scenes; call `next(samples)` to draw a new one.
# NOTE(review): seeded with the same `key` as the example scene above.
samples = random_scenes(base_scene, key=key)

In [None]:
@jax.jit
@jaxtyped(typechecker=typechecker)
def reward(
    path_candidate: Int[Array, "order"],
    scene: TriangleScene,
) -> Float[Array, " "]:
    """
    Reward a path candidate depending on whether it produces
    a valid path in the given scene.

    Args:
        path_candidate: The path candidate to evaluate, as an array of
            triangle indices.
        scene: The scene in which the path is traced.

    Returns:
        The (non-negative) reward: ``1.0`` if any entry of the mask of the
        computed paths is true, ``0.0`` otherwise.
    """
    # 'compute_paths' is given a batch of path candidates, here a batch of one
    single_candidate_batch = jnp.expand_dims(path_candidate, axis=0)
    traced_paths = scene.compute_paths(path_candidates=single_candidate_batch)

    is_valid = traced_paths.mask.any()

    return is_valid.astype(jnp.float32)

In [None]:
class FlowModel(eqx.Module):
    """The flow model that returns flows between two states."""

    # Layers
    triangles_2_embeddings: eqx.nn.MLP
    """MLP that is applied to each triangle in parallel and
    returns the corresponding embeddings."""
    embeddings_2_flow: eqx.nn.MLP
    """MLP that maps each possible choice to some positive flow."""

    def __init__(
        self,
        # Hyperparameters
        num_embeddings: int = 100,
        *,
        key: PRNGKeyArray,
    ):
        """
        Constructs the flow model of a GFlowNet.

        Args:
            num_embeddings: The size of the vector that will represent each triangle.
            key: The random key to be used.
        """
        key1, key2 = jax.random.split(key, 2)

        # Layers
        self.triangles_2_embeddings = eqx.nn.MLP(
            in_size=9,  # 3 vertices x 3 coordinates
            out_size=num_embeddings,
            width_size=500,
            depth=3,
            key=key1,
        )
        self.embeddings_2_flow = eqx.nn.MLP(
            in_size=6
            + 2 * num_embeddings
            + 9,  # [tx_rx, state_embeddings, scene_embeddings, wall[i]]
            out_size="scalar",
            width_size=500,
            depth=3,
            final_activation=jnp.exp,  # Positive flow only
            key=key2,
        )

    @eqx.filter_jit
    def __call__(
        self,
        state: Float[Array, "num_triangles order"],
        triangle_index: Int[Array, " "],
        xyz: Float[Array, "2+num_triangles*3 3"],
    ) -> Float[Array, "num_triangles"]:
        """
        Call this model in order to generate new flows from a given state,
        the last selected triangle index, and some input scene.

        Args:
            state: The current state, a one-hot encoding of the path candidate
                in construction. Only one element per column can be non-zero.
            triangle_index: The index of the last triangle that was selected. Any index
                outside ``[0, num_triangles)``, i.e., a negative index or an index
                greater than or equal to ``num_triangles``, indicates that no triangle
                was previously selected.
            xyz: The array of xyz-coordinates, as returned by
                :func:`scene_2_sample`.

        Returns:
            The array of flows, one per triangle in the scene. The flow of the
            triangle at ``triangle_index`` (if any) is zero.
        """
        num_triangles, order = state.shape

        # Data normalization
        eps = 1e-5  # avoids a division by zero when a coordinate is constant
        mean = jnp.mean(xyz, axis=0, keepdims=True)
        std = jnp.std(xyz, axis=0, keepdims=True)

        xyz = (xyz - mean) / (std + eps)

        tx_rx = xyz[:2, :].reshape(6)
        triangles = xyz[2:, :].reshape(num_triangles, 9)

        # [num_triangles 6]
        # note: this input will be the same for every triangle
        tx_rx = jnp.tile(tx_rx, (num_triangles, 1))

        # [num_triangles num_embeddings]
        # note: each triangle is mapped to a vector of embeddings
        triangle_embeddings = jax.vmap(self.triangles_2_embeddings)(triangles)

        # [num_embeddings]
        # note: the scene is the sum of all embeddings
        scene_embeddings = jnp.sum(triangle_embeddings, axis=0)

        # [num_triangles num_embeddings]
        # note: this input will be the same for every triangle
        scene_embeddings = jnp.tile(scene_embeddings, (num_triangles, 1))

        # [order]
        # note: 'fill_value=num_triangles' is important as we need to generate 'out of bounds'
        #       indices for missing values (only current_order <= order are non zero)
        wall_indices, _ = jnp.nonzero(state, size=order, fill_value=num_triangles)

        # [order num_embeddings]
        # note: we tell JAX to replace out of bounds indices with zeros,
        #       as this will have no impact on the sum (see next step)
        state_embeddings = jnp.take(
            triangle_embeddings, wall_indices, axis=0, fill_value=0
        )

        # [num_embeddings]
        # note: this contains information about the triangles we already visited,
        #       as a sum of corresponding embeddings (one triangle can appear multiple times)
        state_embeddings = jnp.sum(state_embeddings, axis=0)

        # [num_triangles num_embeddings]
        # note: this input will be the same for every triangle
        state_embeddings = jnp.tile(state_embeddings, (num_triangles, 1))

        # [num_triangles]
        # note: the input (per triangle) looks as follows
        #       # [tx_rx, state_embeddings, scene_embeddings, triangles[i]]
        flow = jax.vmap(self.embeddings_2_flow)(
            jnp.hstack((tx_rx, state_embeddings, scene_embeddings, triangles))
        )

        # Set flow[triangle_index] to zero to prevent consecutive duplicate indices.
        # A flow of zero means that there is a zero probability to pick
        # triangles[triangle_index] for the next state.
        #
        # note: negative indices would otherwise wrap around (NumPy-style) and zero
        #       the flow of a valid triangle, so they are first mapped to the
        #       'out of bounds' index 'num_triangles'
        last_index = jnp.where(triangle_index < 0, num_triangles, triangle_index)

        # note: 'mode="drop"' explicitly requests that out of bounds indices are ignored
        flow = flow.at[last_index].set(0.0, mode="drop")

        return flow


class Model(eqx.Module):
    """Generative model that draws path candidates by following the predicted flows."""

    flow: FlowModel
    """The learnable flow model."""

    @eqx.filter_jit
    def __call__(
        self,
        xyz: Float[Array, "2+num_triangles*3 3"],
        *,
        order: int,
        key: PRNGKeyArray,
    ) -> Int[Array, "{order}"]:
        """
        Generate one path candidate of the given order.

        The candidate is built one triangle at a time: at each step, the next
        triangle is drawn at random with a probability proportional to the flow
        predicted by :attr:`flow` for the current state.

        Args:
            xyz: The array of xyz-coordinates, as returned by
                :func:`scene_2_sample`.
            order: The order of the path candidate.
            key: The random key to be used.

        Returns:
            A path candidate, i.e., one triangle index per interaction.
        """
        # The first two rows are the transmitter and the receiver,
        # the remaining rows are the triangle vertices (three per triangle)
        num_triangles = (xyz.shape[0] - 2) // 3

        StepOutput = Int[Array, " "]
        StepCarry = tuple[
            Float[Array, f" {num_triangles}"],
            Float[Array, f"{num_triangles} {order}"],
        ]

        @jaxtyped(typechecker=typechecker)
        def pick_next_triangle(
            carry: StepCarry, step_inputs: tuple[PRNGKeyArray, Int[Array, " "]]
        ) -> tuple[StepCarry, StepOutput]:
            flows, one_hot_state = carry
            step_key, position = step_inputs

            # Normalize the flows into probabilities, one per child state
            probabilities = flows / jnp.sum(flows)

            # Draw the next triangle according to those probabilities
            chosen_triangle = jax.random.categorical(
                key=step_key, logits=jnp.log(probabilities)
            )

            # Record the choice in the one-hot encoded state
            one_hot_state = one_hot_state.at[chosen_triangle, position].set(1.0)

            # Predict the flows leaving the new state
            next_flows = self.flow(one_hot_state, chosen_triangle, xyz)

            return (next_flows, one_hot_state), chosen_triangle

        # Nothing was selected yet: empty state and an out of bounds 'last index'
        empty_state = jnp.zeros((num_triangles, order))
        no_previous_triangle = jnp.array(num_triangles)
        initial_flows = self.flow(empty_state, no_previous_triangle, xyz)

        step_keys = jax.random.split(key, order)
        positions = jnp.arange(order)

        _, path_candidate = jax.lax.scan(
            pick_next_triangle,
            (initial_flows, empty_state),
            xs=(step_keys, positions),
        )

        return path_candidate

    @eqx.filter_jit
    def sample_path_candidates(
        self,
        *args: Any,
        num_path_candidates: int,
        key: PRNGKeyArray,
        **kwargs: Any,
    ) -> Int[Array, "{num_path_candidates} {order}"]:
        """
        Generate several path candidates, each with its own random key.

        Args:
            args: Positional arguments forwarded to :meth:`__call__`.
            num_path_candidates: The number of path candidates to generate.
            key: The random key to be used; it is split into one key per candidate.
            kwargs: Keyword arguments forwarded to :meth:`__call__`.

        Returns:
            The stacked path candidates, one per row.
        """

        def draw_candidate(_, candidate_key):
            # No state is carried between candidates
            return None, self(*args, key=candidate_key, **kwargs)

        candidate_keys = jax.random.split(key, num_path_candidates)
        _, path_candidates = jax.lax.scan(
            draw_candidate,
            init=None,
            xs=candidate_keys,
        )

        return path_candidates

In [None]:
# Untrained sampler: the flow model is freshly initialized with its default
# hyperparameters (`num_embeddings=100`).
# NOTE(review): initialized with the same `key` that seeds the scene generator.
model = Model(FlowModel(key=key))

In [None]:
# Draw a new random scene, sample first-order path candidates with the model,
# and plot the paths they produce.
with reuse() as fig:
    scene = next(samples)
    scene.plot()
    # Model input for this scene: tx, rx, then triangle vertices
    xyz = scene_2_sample(scene)

    # 30 independent candidates, each made of a single triangle index (order=1)
    path_candidates = model.sample_path_candidates(
        xyz, order=1, num_path_candidates=30, key=key
    )
    paths = scene.compute_paths(path_candidates=path_candidates)
    # Debugging aid: uncomment the next line to overwrite the mask with ones,
    # presumably so that paths from invalid candidates are plotted too.
    # paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))
    paths.plot()

# Last expression of the cell: display the figure
fig

In [None]:
# Number of triangles in the current random scene, i.e., the number of
# choices the model has at each step.
scene.mesh.num_triangles

In [None]:
# Sample a single third-order path candidate: an array of 3 triangle indices.
model(xyz, order=3, key=key)