In [15]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab

try:
    import differt  # noqa: F401
except ImportError:
    import sys  # noqa: F401

    !{sys.executable} -m pip install differt[all] e3x

# Sampling Path Candidates with Machine Learning

This notebook aims at being a tutorial to reproduce the results presented in the paper
*Generative Flow Sets-based $\mathrm{E}(3)$-Invariant Ray Path Sampling for Faster Point-to-Point Ray Tracing: Introduction and Concept*,
and assumes you are familiar with its content.

**You can run it locally or with Google Colab** by cliking on the rocket
at the top of this page!

:::{tip}
On Google Colab, make sure to select a GPU or TPU runtime for a faster experience.
:::

If you find this tutorial useful and plan on using this tool for your publications,
please cite our work, see {ref}`citing`.

## Summary

In our work, we present a Machine Learning (ML) model that aims at reducing the computational complexity
of exhaustive Point-to-Point (P2P) Ray Tracing (ML) by learning how to sample path candidates.
For further details, please refer to the paper.

## Imports

We need to import quite a few Python modules, but all of them should be installed with `differt[all]`.

In [29]:
from collections.abc import Iterator
from functools import partial
from typing import Any

import e3x
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from beartype import beartype as typechecker
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, jaxtyped
from tqdm.notebook import trange

from differt.plotting import reuse, set_defaults
from differt.scene.sionna import get_sionna_scene
from differt.scene.triangle_scene import TriangleScene
from differt.utils import sample_points_in_bounding_box

In [30]:
set_defaults(
    "plotly"
)  # Our scene is simple, and Plotly is the best backend for online interactive plots :-)

file = get_sionna_scene("simple_street_canyon")
base_scene = TriangleScene.load_xml(file)
base_scene = base_scene.set_assume_quads(False)
# Set the above to 'True' of you want to sample quadrilaterals instead of triangles
base_scene.plot()

In [18]:
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def random_tx_rx(
    base_scene: TriangleScene, *, key: PRNGKeyArray
) -> TriangleScene:
    """
    Return a random scene with one TX and one RX, at random positions.

    TX is placed in the "upper" part of the scene, and RX is the "lower" part.
    This is just to increase the chances of having possible ray paths in between,
    while maintaining relatively realistic positions.
    
    Args:
        base_scene: The base scene from which the random scene
            is derived.
        key: The random key to be used.

    Returns:
        A new scene.
    """
    scene = base_scene
    bounding_box = scene.mesh.bounding_box
    min_z, max_z = bounding_box[:, 2]
    avg_z = 0.5 * (min_z + max_z)
    rx_bounding_box = bounding_box.at[:, 2].set([min_z + 1.5, avg_z])
    tx_bounding_box = bounding_box.at[:, 2].set([avg_z + 5, max_z + 5])
    key_tx, key_rx = jax.random.split(key, 2)
    scene = eqx.tree_at(
        lambda s: s.transmitters,
        scene,
        sample_points_in_bounding_box(tx_bounding_box, key=key_tx),
    )
    return eqx.tree_at(
        lambda s: s.receivers,
        scene,
        sample_points_in_bounding_box(rx_bounding_box, key=key_rx),
    )


@jaxtyped(typechecker=typechecker)
def random_scene(
    base_scene: TriangleScene, *, key: PRNGKeyArray
) -> TriangleScene:
    """
    Return a random scene with one TX and one RX, at random positions, and a random number of objects.

    The number of objects can be anywhere between 1 and the total number of objects in the scene.
    
    Args:
        base_scene: The base scene from which the random scene
            is derived.
        key: The random key to be used.

    Returns:
        A new scene.
    """
    scene = base_scene
    key_tx_rx, key_num_objects, key_sample_triangles = jax.random.split(
        key, 3
    )
    scene = random_tx_rx(scene, key=key_tx_rx)
    num_objects = scene.mesh.num_objects
    num_objects = jax.random.randint(
        key_num_objects, (), 1, num_objects + 1
    )
    return eqx.tree_at(
        lambda s: s.mesh,
        scene,
        scene.mesh.sample(int(num_objects), key=key_sample_triangles),
    )


def random_scenes(
    base_scene: TriangleScene, *, key: PRNGKeyArray
) -> Iterator[TriangleScene]:
    """
    Return an (infinite) interator over random scenes.
    
    Args:
        base_scene: The base scene from which random scenes
            are derived.
        key: The random key to be used.

            For each new scene, a new key will be derived from the
            one provided.

    Yields:
        An infinite number of random scenes.
    """
    while True:
        key, key_to_use = jax.random.split(key, 2)

        yield random_scene(base_scene, key=key_to_use)

In [19]:
key_example = jax.random.key(1234)

with reuse() as fig:
    scene = random_tx_rx(base_scene, key=key_example)
    scene.plot(showlegend=False)

    for i in range(3):
        paths = scene.compute_paths(order=i)
        print(
            f"(order = {i}) Found {int(paths.mask.sum()):2d} valid paths our of {paths.mask.size:4d} path candidates."
        )
        paths.plot(showlegend=False)

fig

(order = 0) Found  1 valid paths our of    1 path candidates.
(order = 1) Found  1 valid paths our of   74 path candidates.
(order = 2) Found  0 valid paths our of 5402 path candidates.


In [20]:
key_samples = jax.random.key(1234)

samples = random_scenes(base_scene, key=key_samples)

In [21]:
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def reward(
    path_candidate: Int[Array, "order"],
    scene: TriangleScene,
) -> Float[Array, " "]:
    """
    Reward a predicted path candidate depending on if it
    produces a valid path in the given scene.

    If the scene contains multiple TXs and RXs, then
    the maximum number of valid paths, i.e., 'max_valid_paths'
    is the number of TXs times the number of RXs.

    Args:
        path_candidate: The path candidate to evaluate.
        scene: The scene on which to evaluate the path candidate.

    Return:
        A possible reward, between 0 and 'max_valid_paths'.
    """
    paths = scene.compute_paths(path_candidates=path_candidate.reshape(1, -1))

    return paths.mask.astype(jnp.float32).sum()

In [67]:
class FlowModel(eqx.Module):
    """The flow model that returns flows between two states."""
    
    # Layers
    cell: eqx.nn.GRUCell
    features_2_flow: eqx.nn.MLP

    def __init__(
        self,
        # Hyperparameters
        num_features: int = 100,
        *,
        key: PRNGKeyArray,
    ):
        """
        Construct a GFlowNet model.

        Args:
            num_features: The size of the vector that represents each object.
            key: The random key to be used.
        """
        key_cell, key_flow = jax.random.split(key, 2)

        del key
        
        # 1 - Combine arbitrary many features from an (ordered) state and create a new feature vector
        self.cell = eqx.nn.GRUCell(
            input_size=num_features,
            hidden_size=num_features,
            key=key_cell,
        )
        
        del key_cell

        # 2 - Features to flow
        self.features_2_flow = eqx.nn.MLP(
            # Takes scene_features + obj_features[i] + state_features
            in_size=num_features * 3,
            out_size="scalar",
            width_size=500,
            depth=3,
            final_activation=jnp.exp,  # Positive flow only
            key=key_flow,
        )

        del key_flow

    @eqx.filter_jit
    @jaxtyped(typechecker=typechecker)
    def __call__(
        self,
        scene_features: Float[Array, " num_features"],
        object_features: Float[Array, "num_objects num_features"],
        state: Bool[Array, "num_objects order"],
    ) -> Float[Array, " num_objects"]:
        """
        Call this model in order to generate a new flow from a given state,
        and an abstract representation of the scene and its objects.

        Invariance with respect to E(3) should be guaranteed by the callee,
        i.e., the arguments passed to this function are treated regardless of
        any geometrical property.

        Args:
            scene_features: The vector of scene features.
            object_features: The array of each object's features.
            state: The current state, a one-hot encoding of the path candidate
                in construction. Only one element per column can be True.

        Returns:
            The array of flows, one per object.
        """
        num_objects, order = state.shape
        
        # [order]
        # note: 'fill_value=num_objects' is important as we need to generate 'out of bounds'
        #       indices for missing values (only current_order <= order are non zero)
        object_indices, _ = jnp.nonzero(
            state, size=order, fill_value=num_objects,
        )
        # [order num_features]
        # note: we tell JAX to replace out of bounds indices with zeros,
        #       as this will have no impact on the sum (see next step),
        #       and this contains information about the objects we already visited,
        #       as a sum of corresponding features (one object can appear multiple times)
        state_features = jnp.take(
            object_features, object_indices, axis=0, fill_value=0
        )

        # [order num_features]
        # note: we transform 'order' vectors of 'num_features' values into
        #       one vector of 'num_features' values.
        #       We could have used the 'sum' operation (over the leading axis),
        #       but then we loose information about the order in which features
        #       appeared. A cell can keep some information about ordering since
        #       cell(cell(a, b), c) isn't the same as, e.g., cell(cell(b, a), c).
        state_features = jax.lax.scan(
            lambda state, features: (self.cell(features, state), None),
            jnp.zeros_like(scene_features),
            state_features,
        )[0]

        # [num_objects]
        # note: the input (per object) looks as follows
        #       # [scene_features, state_features, object_features[i]]
        flow = jax.vmap(
            lambda *features: self.features_2_flow(
                jnp.concatenate(features)
            ),
            in_axes=(None, 0, None),
        )(
            scene_features,
            object_features,
            state_features,
        )

        return flow


class Model(eqx.Module):
    """The generative model that samples a path candidate from flows."""

    xyz_2_features: eqx.nn.MLP
    """The MLP that maps xyz coordinates to a larger feature vectors."""
    flow: FlowModel
    """The flow model."""
    inference: bool
    """Whether this model is used for training or inference."""

    def __init__(
        self,
        # Hyperparameters
        num_features: int = 100,
        hidden_size: int = 500,
        num_hidden_layers: int  = 3,
        *,
        inference: bool = False,
        key: PRNGKeyArray,
    ):
        """
        Construct a model.

        Args:
            num_features: The size of the vector that will represent each xyz triplet.
            hidden_size: The size hidden layers to use in the MLP.
            num_hidden_layers: The number of hidden layer to use in the MLP.
            key: The random key to be used.
        """
        self.inference = inference
        
        key_embeds, key_flow = jax.random.split(key, 2)

        del key

        # Layers

        # 1 - World coordinates to features
        self.xyz_2_features = eqx.nn.MLP(
            in_size=3,
            out_size=num_features,
            width_size=hidden_size,
            depth=num_hidden_layers,
            key=key_embeds,
        )
        
        del key_embeds

        # 2 - Generate flow
        self.flow = FlowModel(
            num_features=num_features,
            key=key_flow,
        )

        del key_flow

    @eqx.filter_jit
    @jaxtyped(typechecker=typechecker)
    def __call__(
        self,
        scene: TriangleScene,
        *,
        order: int,
        key: PRNGKeyArray,
    ) -> Int[Array, "{order}"] | tuple[Int[Array, "{order}"], Float[Array, "{order}+1 num_objects"]]:
        """
        Call this model to generate a path candidate of the given order.

        Args:
            scene: The triangle scene.
            order: The order of the path candidate.
            key: The random key to be used.

        Returns:
            A path candidate.

            If training mode, it also returns the array of flows.
        """
         # 1 - Extracting data

        num_objects = scene.mesh.num_objects
        num_triangles = scene.mesh.num_triangles

        tx = scene.transmitters.reshape(-1, 3)
        num_tx = tx.shape[0]
        rx = scene.receivers.reshape(-1, 3)
        num_rx = rx.shape[0]
        flat_tri = scene.mesh.triangle_vertices.reshape(-1, 3)

        # [num_tx+num_rx+3*num_triangles 3]
        xyz = jnp.concatenate((tx, rx, flat_tri), axis=0)

        # 2 -  Data normalization (for translation/scaling invariance)
        
        eps = jnp.finfo(xyz.dtype).eps  # Not sure this is needed in practice, but avoids div by zero
        mean = jnp.mean(xyz, axis=0, keepdims=True)
        std = jnp.std(xyz, axis=0, keepdims=True)

        # [num_tx+num_rx+3*num_triangles 3]
        xyz = (xyz - mean) / (std + eps)

        # 3 - Computing features

        # TODO: I would like to use the e3x module for E(3) invariance,
        # but it doesn't use the same ML framework as I use here, so I need
        # to check how to implement this.
        
        # [num_tx+num_rx+3*num_triangles 1 (max_degree+1)**2 num_features]
        # xyz_features = e3x.nn.basis(
        #     xyz,
        #     num=num_features,
        #     max_degree=2,
        #     radial_fn=partial(e3x.nn.triangular_window, limit=2.0),
        # )

        # [num_tx+num_rx+num_triangles num_features]
        xyz_features = jax.vmap(self.xyz_2_features)(xyz)

        # [num_features]
        scene_features = xyz_features.sum(axis=0)

        # [num_triangles num_features]
        tri_features = xyz_features[num_tx+num_rx:, ...].reshape(num_triangles, 3, *xyz_features.shape[1:]).sum(axis=1)

        # [num_objects 1 (max_degree+1)**2 num_features]
        # note: this is a no-op is 'scene.mesh.assume_quads' is False
        object_features = tri_features.reshape(num_objects, -1, *tri_features.shape[1:]).sum(axis=1)

        # 4 - Generating path candidates

        ScanR = Int[Array, " "]
        Flow = Float[Array, f" {num_objects}"]
        State = Bool[Array, f"{num_objects} {order}"]
        ScanC = tuple[
            Flow,
            State,
        ]

        if not self.inference:
            ScanR = tuple[ScanR, Flow]

        @jaxtyped(typechecker=typechecker)
        def scan_fn(
            carry: ScanC,
            key_and_current_order: tuple[PRNGKeyArray, Int[Array, " "]],
        ) -> tuple[ScanC, ScanR]:
            parent_edge_flow_prediction, state = carry
            key, current_order = key_and_current_order

            # Compute probability to flow to each child state
            p = parent_edge_flow_prediction / jnp.sum(
                parent_edge_flow_prediction
            )

            # Randomly choose a child state (i.e., triangle), based on 'p'
            object_index = jax.random.categorical(key=key, logits=jnp.log(p))

            # Update 'state' accordingly
            state = state.at[object_index, current_order].set(True)

            # Compute the new flow
            edge_flow_prediction = self.flow(scene_features, object_features, state)

            # Set flow[object_index] to zero to prevent consecutive duplicate indices
            # A flow of zero means that there is a zero probability to pick
            # objects[object_index] for the next state.
            edge_flow_prediction = edge_flow_prediction.at[object_index].set(
                0.0
            )  # out of bounds indices are ignored

            if self.inference:
                return (edge_flow_prediction, state), object_index
            else:
                return (edge_flow_prediction, state), (object_index, edge_flow_prediction)

        # Let's initialize the variables
        state = jnp.zeros((num_objects, order), dtype=bool)
        parent_edge_flow_prediction = self.flow(scene_features, object_features, state)

        # We use scan to efficiently loop over multiple states
        init = parent_edge_flow_prediction, state
        _, path_candidate = jax.lax.scan(
            scan_fn,
            init,
            xs=(jax.random.split(key, order), jnp.arange(order)),
        )

        if not self.inference:
            path_candidate, flows = path_candidate
            flows = jnp.vstack((parent_edge_flow_prediction, flows))

        if scene.mesh.assume_quads:
            return path_candidate * 2  # We map quad. indices to triangles indices

        if not self.inference:
            return path_candidate, flows
            
        return path_candidate

    @eqx.filter_jit
    @jaxtyped(typechecker=typechecker)
    def sample_path_candidates(
        self,
        *args: Any,
        num_path_candidates: int,
        order: int,
        key: PRNGKeyArray,
        **kwargs: Any,
    ) -> Int[Array, "{num_path_candidates} {order}"]:
        """
        Sample 'num_path_candidates' in a given scene.

        This is a convenient wrapper samples arbitrary many path candidates,
        and has the same output in both training and inference modes.
        """
        @jaxtyped(typechecker=typechecker)
        def scan_fun(_: None, key: PRNGKeyArray) -> tuple[None, Int[Array, f" {order}"]]:
            path_candidate = self(*args, order=order, key=key, **kwargs)

            if self.inference:
                return None, path_candidate

            return None, path_candidate[0]

        return jax.lax.scan(
            scan_fun,
            init=None,
            xs=jax.random.split(key, num_path_candidates),
        )[1]

In [68]:
key_model = jax.random.key(1234)

untrained_model = Model(key=key_model)

In [69]:
key_sample_untrained = jax.random.key(1234)

with reuse() as fig:
    scene = next(samples)
    scene.plot()

    path_candidates = untrained_model.sample_path_candidates(
        scene, order=1, num_path_candidates=1, key=key_sample_untrained
    )
    paths = scene.compute_paths(path_candidates=path_candidates)
    paths.plot()

fig

In [13]:
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def loss(
    model: Model,
    scene: TriangleScene,
    num_path_candidates: int = 20,
    num_tx_rx_pairs: int = 50,
    *,
    order: int,
    key: PRNGKeyArray,
) -> Float[Array, " "]:
    """
    Compute the loss of the model on a specific input scene.

    The loss is accumulated over the generation of many path candidates,
    repeated on for multiple TX / RX positions.
    """
    assert not model.inference, "Model cannot estimate loss in 'inference' mode."

    if scene.mesh.assume_quads:
        num_objects = scene.mesh.num_quads
    else:
        num_objects = scene.mesh.num_triangles

    ScanR = Int[Array, " "]
    ScanC = tuple[
        Float[Array, " "],
        Float[Array, f" {num_objects}"],
        Float[Array, f"{num_objects} {order}"],
    ]

    @jaxtyped
    def compute_flow_mismatch(model: Model, key: PRNGKeyArray, order: int) -> :
        path_candidate, flows = model(order=order, key=key)

        parent_flows = flows[:-1, :]
        edge_flows = flows[1:, :]

        flow_mismatch = jnp.where(

        )
        

    @jaxtyped(typechecker=typechecker)
    def scan_fn(
        carry: ScanC,
        key_and_current_order: tuple[PRNGKeyArray, Int[Array, " "]],
    ) -> tuple[ScanC, ScanR]:
        # We carry, for the currently generated path candidate:
        # - the flow mismatch (parent edges flow - current flow)
        # - the parent edge flow (only one parent leads to the current choice)
        # - the state (path candidate using one-hot encoding)
        flow_mismatch, parent_edge_flow_prediction, state = carry
        key, current_order = key_and_current_order

        # Turn positive flow into normalized probability in [0, 1]
        p = parent_edge_flow_prediction / jnp.sum(parent_edge_flow_prediction)

        object_index = jax.random.categorical(
            key=key, logits=jnp.log(p)
        )  # The triangle to choose

        # Indicate we have chosen triangles[object_index] as a candidate at 'current_order'
        state = state.at[object_index, current_order].set(1.0)

        edge_flow_prediction = flow_model(state, object_index, scene)

        pred_path_candidate, _ = jnp.nonzero(
            state, size=order, fill_value=num_objects
        )

        flow_mismatch += jnp.where(
            current_order == order - 1,  # Check whether we reached final state
            (  # Reached last state so (next) edge_flow_prediction is ignored
                parent_edge_flow_prediction[
                    object_index
                ]  # Each state s' has only one possible parent state s
                - reward(pred_path_candidate, scene)
            )
            ** 2,
            (  # Didn't reach last state so no reward
                parent_edge_flow_prediction[
                    object_index
                ]  # Each state s' has only one possible parent state s
                - jnp.sum(edge_flow_prediction)
            )
            ** 2,
        )

        return (flow_mismatch, edge_flow_prediction, state), object_index

    BatchC = Float[Array, " "]
    BatchR = Int[Array, f" {order}"]

    @jaxtyped(typechecker=typechecker)
    def batch_path_candidates(
        batch_loss: BatchC, key: PRNGKeyArray,
    ) -> tuple[BatchC, BatchR]:
        key_pred, key_scene = jax.random.split(key, 2)
        del key
        new_scene = random_tx_rx(scene, key=key_scene)
        flow_mismatch = jnp.array(0.0)
        object_index = jnp.array(
            num_objects
        )  # We didn't select any wall yet: out of bounds index
        state = jnp.zeros(
            (num_objects, order)
        )  # Same, but represented in one-hot encoding
        parent_edge_flow_prediction = flow_model(
            state, object_index, new_scene
        )  # Initial state's flow
        init = flow_mismatch, parent_edge_flow_prediction, state
        # Scan accumulates the mismatch and generated the path candidate
        (flow_mismatch, *_), pred_path_candidate = jax.lax.scan(
            scan_fn,
            init,
            xs=(jax.random.split(key_pred, order), jnp.arange(order)),
        )

        return batch_loss + flow_mismatch, pred_path_candidate

    @jaxtyped(typechecker=typechecker)
    def batch_tx_rx_positions(
            batch_loss: BatchC, key: PRNGKeyArray
        ) -> tuple[BatchC, None]:
        batch_loss, _ = jax.lax.scan(
            batch_path_candidates, batch_loss, xs=jax.random.split(key, num_path_candidates),
        )
        return batch_loss, None
        

    batch_loss = jnp.array(0.0)
    batch_loss, _ = jax.lax.scan(
        batch_tx_rx_positions, batch_loss, xs=jax.random.split(key, num_tx_rx_pairs)
    )

    return batch_loss

In [14]:
key_untrained_loss = jax.random.key(1234)

loss(untrained_model, scene, order=1, key=key_untrained_loss)

TypeCheckError: Type-check error whilst checking the parameters of __main__.FlowModel.__call__.
The problem arose whilst typechecking parameter 'scene_features'.
Actual value: f32[17,1]
Expected type: <class 'Float[Array, 'num_features']'>.
----------------------
Called with parameters: {
  'self':
  FlowModel(...),
  'scene_features':
  f32[17,1],
  'object_features':
  i32[],
  'state':
  TriangleScene(
    transmitters=f32[3],
    receivers=f32[3],
    mesh=TriangleMesh(
      vertices=f32[88,3],
      triangles=i32[17,3],
      face_colors=f32[17,3],
      face_materials=i32[17],
      material_names=(
        'mat-itu_wood',
        'mat-itu_glass',
        'mat-itu_marble',
        'mat-itu_brick',
        'mat-itu_concrete'
      ),
      object_bounds=None,
      assume_quads=False
    )
  )
}
Parameter annotations: (self, scene_features: Float[Array, 'num_features'], object_features: Float[Array, 'num_objects num_features'], state: Float[Array, 'num_objects order']) -> Any.


In [None]:
@jaxtyped(typechecker=typechecker)
def train(
    model: Model,
    train_samples: Iterator[TriangleScene],
    optim: optax.GradientTransformation,
    steps: int = 1_000,
    print_every: int = 100,
    *,
    order: int,
    key: PRNGKeyArray,
) -> tuple[
    Model,
    Float[Array, "{steps}"]
]:
    """
    Train a model on a sequence of training samples and returns the training loss average of samples.

    :param model: The model to train.
    :param train_samples: The training samples.
    :param val_samples: The validation samples.
    :param optim: The optimizer to use.
    :param steps: The number of optimization steps.
    :param print_every: The frequency at which the average loss is computed.
    :param order: The order of the paths to be trained on.
    :param key: The random key to be used.
    :return: The trained model, the steps,
        the train losses and the validation losses.
    """
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def make_step(
        model: Model,
        opt_state: optax.OptState,
        scene: TriangleScene,
        *,
        key: PRNGKeyArray,
    ) -> tuple[Model, optax.OptState, Float[Array, " "]]:
        loss_value, grads = eqx.filter_value_and_grad(loss)(
            model, scene, order=order, key=key,
        )
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    train_losses = []

    print_width = len(str(steps))

    with trange(steps, desc="", unit=" steps", leave=True) as bar:
        for (
            step,
            scene,
        ) in zip(bar,train_samples):
            key, key_step = jax.random.split(key, 2)

            model, opt_state, train_loss = make_step(
                model, opt_state, scene, key=key_step
            )
            train_losses.append(train_loss)

            if (step % print_every) == 0 or (step == steps - 1):

                bar.set_description(
                    f"Train loss @ iter. #{step:0{print_width}d} is {float(train_loss):.1f}"
                )

    return model, jnp.array(train_losses)

In [None]:
key_train = jax.random.key(1234)
key_train_model, key_train_samples = jax.random.split(key_train, 2)
train_samples = random_scenes(base_scene, key=key_train_samples)
optim = optax.adam(learning_rate=3e-5)

trained_model, train_losses = train(untrained_model, train_samples, optim, order=1, key=key_train_model)

In [None]:
plt.semilogy(train_losses)

In [None]:
key_eval_trained = jax.random.key(1234)

key_sample_trained, key_eval_scene = jax.random.split(key_eval_trained, 2)

with reuse() as fig:
    scene = random_scene(base_scene, key=key_eval_scene)
    scene.plot()

    path_candidates = trained_model.sample_path_candidates(
        scene, order=1, num_path_candidates=30, key=key_sample_trained
    )
    paths = scene.compute_paths(path_candidates=path_candidates)
    # paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))
    paths.plot()

fig

In [None]:
print("Base scene prediction:\n", trained_model.sample_path_candidates(
        base_scene, order=2, num_path_candidates=10, key=key_sample_untrained
).tolist())
key_rotation = jax.random.key(1234)
rotation = e3x.so3.random_rotation(key_rotation)
rotated_scene = base_scene.rotate(rotation)
print("Rotated scene prediction:\n", trained_model.sample_path_candidates(
        rotated_scene, order=2, num_path_candidates=10, key=key_sample_untrained
).tolist())