In [None]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab

try:
    import differt  # noqa: F401
except ImportError:
    import sys  # noqa: F401

    !{sys.executable} -m pip install differt[all]

(advanced_path_tracing)=
# Advanced Path Tracing

DiffeRT provides both high-level and low-level interfaces to Path Tracing.

This tutorial provides a quick tour of what you can do with the **lower-level** API,
and the logic used to perform Ray Tracing (RT).

## Example on a simple scene

Before diving into a complex scene, it is worth starting with a very simple one.

:::{note}
All the logic presented in this section is contained in the
{meth}`TriangleScene.compute_paths<differt.scene.TriangleScene.compute_paths>` method.

It also contains additional post-processing steps to avoid degenerate solutions,
as well as optimized routines,
which we omit here.
:::

### Necessary imports

Because we are going for the lower-level way, we will need quite a few imports.

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Bool, Float

import differt.plotting as dplt
from differt.geometry import (
    TriangleMesh,
    assemble_paths,
    fibonacci_lattice,
    triangles_contain_vertices_assuming_inside_same_plane,
    viewing_frustum,
)
from differt.rt import (
    consecutive_vertices_are_on_same_side_of_mirrors,
    first_triangles_hit_by_rays,
    generate_all_path_candidates,
    image_method,
    image_of_vertices_with_respect_to_mirrors,
    rays_intersect_triangles,
)

### Loading a mesh

For each supported mesh type, we provide utilities
to load a mesh from a file.

In [None]:
mesh_file = "two_buildings.obj"  # Very simple scene with two buildings
mesh = TriangleMesh.load_obj(mesh_file)

### Plotting your setup

Here, we will use Plotly as the plotting backend, because it renders very nicely, especially on the web.
On larger scenes, you will likely need something more performant, like Vispy, see [Choosing your plotting backend](plotting_backend.ipynb#choosing-your-plotting-backend).

In [None]:
dplt.set_defaults("plotly")  # Let's use the Plotly backend

fig = mesh.plot(opacity=0.5)
fig

Ray Tracing without start and end points is not very interesting.
Let's add one transmitter (TX) and one receiver (RX) in the scene, represented by their 3D coordinates.

In [None]:
tx = jnp.array([0.0, 4.9352, 22.0])
rx = jnp.array([0.0, 10.034, 1.50])

dplt.draw_markers(
    np.array([tx, rx]), labels=["tx", "rx"], figure=fig, name="nodes"
)

### How we trace rays

Ray Tracing can be implemented in many ways, depending on the desired performance, the level of accuracy needed,
or the representation of the geometry.

Here, we will implement exhaustive (also referred to as *deterministic* or *exact*) RT. That is, we want to generate all possible paths from TX to RX, that undergo up to a maximum number of interactions with the environment. Interactions can be reflections, diffractions, etc.

One way to generate all possible paths is to represent the problem as a graph. Then, the goal is to find all the paths from the node corresponding to TX to the node corresponding to RX, possibly visiting intermediate nodes, each of which corresponds to a specific primitive (or object) in the scene, here a triangle.

A graph algorithm will therefore generate a list of *path candidates*. We use the word *candidate* to emphasize that this is not a real path (i.e., not 3D coordinates), but only an ordered list of nodes to visit, for a given path.
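
To make the graph view concrete, here is a minimal pure-Python sketch that enumerates path candidates with a depth-first search. It is only an illustration under our own assumptions: `all_path_candidates` and its arguments are hypothetical names (not part of DiffeRT), primitives are the nodes `0..n-1`, TX and RX are represented implicitly through two visibility lists, and the order-0 candidate (line of sight) is always yielded.

```python
from collections.abc import Iterator, Sequence


def all_path_candidates(
    adjacency: Sequence[Sequence[bool]],
    from_adjacency: Sequence[bool],
    to_adjacency: Sequence[bool],
    order: int,
) -> Iterator[list[int]]:
    """Yield every ordered list of `order` primitive indices going from TX to RX.

    adjacency[i][j] is True if primitive j can be reached from primitive i,
    from_adjacency[i] if primitive i can be reached from TX, and
    to_adjacency[i] if RX can be reached from primitive i.
    """
    num_primitives = len(adjacency)
    if order == 0:
        yield []  # Line of sight: no intermediate primitive (TX -> RX assumed allowed)
        return

    def visit(path: list[int]) -> Iterator[list[int]]:
        if len(path) == order:
            if to_adjacency[path[-1]]:
                yield list(path)  # Copy, since `path` is mutated in place
            return
        for j in range(num_primitives):
            if adjacency[path[-1]][j]:
                path.append(j)
                yield from visit(path)
                path.pop()

    for i in range(num_primitives):
        if from_adjacency[i]:
            yield from visit([i])
```

With three mutually visible triangles and no self-loops (a triangle cannot reflect onto itself), order 2 yields the six candidates `[0, 1]`, `[0, 2]`, `[1, 0]`, `[1, 2]`, `[2, 0]`, and `[2, 1]`.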

Then, it is the role of the **path tracing** method (e.g., {func}`image_method<differt.rt.image_method>` or {func}`fermat_path_on_planar_mirrors<differt.rt.fermat_path_on_planar_mirrors>`) to determine the exact coordinates of that path.
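
For a single planar mirror, the idea behind the image method fits in a few lines: mirror TX across the mirror plane, then intersect the straight line from this image to RX with the plane. The NumPy sketch below covers only that single-reflection case, with hypothetical helper names; it is not DiffeRT's {func}`image_method<differt.rt.image_method>`, which handles any number of mirrors and batched inputs. It also assumes that TX and RX lie strictly on the same side of the mirror, otherwise the line may never cross the plane.

```python
import numpy as np


def image_of_point(point, mirror_vertex, mirror_normal):
    """Mirror `point` across the plane through `mirror_vertex` with normal `mirror_normal`."""
    n = mirror_normal / np.linalg.norm(mirror_normal)
    return point - 2.0 * np.dot(point - mirror_vertex, n) * n


def single_reflection_point(tx, rx, mirror_vertex, mirror_normal):
    """Specular reflection point on an (infinite) planar mirror, via the image method."""
    n = mirror_normal / np.linalg.norm(mirror_normal)
    tx_image = image_of_point(tx, mirror_vertex, n)
    direction = rx - tx_image
    # Intersect the line (tx_image -> rx) with the mirror plane
    t = np.dot(mirror_vertex - tx_image, n) / np.dot(direction, n)
    return tx_image + t * direction
```

For instance, with TX at `(0, 0, 2)`, RX at `(3, 0, 1)`, and the mirror `z = 0`, the reflection point is `(2, 0, 0)`: the incoming segment `(2, 0, -2)` and the outgoing segment `(1, 0, 1)` indeed make equal angles with the plane.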

Let's select a subset of our primitives to illustrate what we have just described.

In [None]:
select = [
    8,  # Red
    9,  # Red
    22,  # Green
    23,  # Green
]  # In practice, you will never hard-code the primitive indices yourself

vertices = mesh.vertices
triangles = mesh.triangles[select, :]

dplt.draw_mesh(vertices, triangles[:2, :], figure=fig, color="red")
dplt.draw_mesh(vertices, triangles[2:, :], figure=fig, color="green")

Looking at the above, we can clearly see that a line-of-sight (LOS) path between TX and RX exists.

With a bit of thinking, we could also imagine that a path with one or more reflections might join TX and RX.

For example, <kbd>TX -> Red surface -> RX</kbd> will likely produce a valid path.
The same logic can be applied to <kbd>TX -> Red surface -> Green surface -> RX</kbd>.

In [None]:
# A list of colors to easily differentiate paths
color = ["black", "green", "orange", "yellow", "blue"]

select = jnp.array(
    select[::2],
    dtype=int,
)  # We actually only need one triangle per plane, so [8, 22]

# Iterate through path candidates
#
#                         ┌> order 0
#                         |           ┌> order 1
#                         |           |           ┌> order 2
for path_candidate in [select[:0], select[:1], select[:2]]:
    # 1 - Prepare input arrays
    mirror_vertices = mesh.vertices[mesh.triangles[path_candidate, 0], :]
    mirror_normals = mesh.normals[path_candidate, :]

    # 2 - Trace paths

    path = image_method(
        tx, rx, mirror_vertices, mirror_normals
    )

    # 3 - ??

    # 4 - Obtain final valid paths and plot

    # The full path is [tx, path, rx]
    full_path = jnp.concatenate(
        (
            tx[None, :],
            path,
            rx[None, :],
        ),
    )

    # Then we plot it
    dplt.draw_paths(
        full_path,
        figure=fig,
        marker={
            "size": 0,
            "color": "black",
        },
        line={"color": color[len(path_candidate)], "width": 3},
        name=f"Order {len(path_candidate)}",
    )

fig

Nice! Thanks to the {func}`image_method<differt.rt.image_method>`, we successfully generated the paths we just mentioned.

### Scaling on more paths and more surfaces

Manually identifying the surfaces of interest and generating all possible path candidates can rapidly become tedious as the number of surfaces or the path order increases.

For this purpose, we created the {func}`generate_all_path_candidates<differt.rt.generate_all_path_candidates>` function. Written in Rust for performance purposes, this function can generate millions of path candidates per second!

This is all nice, but it has one important side effect: if you generate all possible path candidates, how do you remove invalid paths that, e.g., cross a building?

This is where our third step comes into play: we need to validate our paths against a series of checks. We can usually identify three types of checks:

1. **Are path coordinates within the boundary of their respective objects?** Often, objects are assumed to be infinite (e.g., a triangle is replaced by its whole plane). A check is then performed to verify whether the solution lies within the object's boundaries;
2. **Are all interactions valid?** E.g., do all reflections occur with an angle of reflection equal to the angle of incidence? Most path tracing methods have some failure cases where they can return degenerate solutions;
3. **Does any object in the scene obstruct the path?** Usually, the path is first computed without taking the surrounding objects into account, which can produce paths that cross buildings.
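
Rule 2 can take many forms. One simple necessary condition for a specular reflection, which the cell below applies through {func}`consecutive_vertices_are_on_same_side_of_mirrors<differt.rt.consecutive_vertices_are_on_same_side_of_mirrors>`, is that the vertices just before and just after each reflection point lie on the same side of the mirror. A minimal NumPy version for a single path could look as follows; the function name is hypothetical, and we do not claim it matches DiffeRT's implementation (in particular regarding vertices lying exactly on a mirror plane, which this sketch rejects).

```python
import numpy as np


def same_side_of_mirrors(full_path, mirror_vertices, mirror_normals):
    """Check, for each reflection, that its neighboring vertices are on the same side.

    full_path: [order+2, 3] array (TX, reflection points, RX).
    mirror_vertices, mirror_normals: [order, 3] arrays, one point and normal per mirror.
    Returns a boolean array of shape [order].
    """
    prev_vertices = full_path[:-2]  # Vertex before each reflection point
    next_vertices = full_path[2:]  # Vertex after each reflection point
    # Signed (unnormalized) distances to each mirror plane
    side_prev = np.sum((prev_vertices - mirror_vertices) * mirror_normals, axis=-1)
    side_next = np.sum((next_vertices - mirror_vertices) * mirror_normals, axis=-1)
    # Same strictly positive or strictly negative sign
    return side_prev * side_next > 0.0
```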

A possible implementation of the above rules, applied to the {func}`image_method<differt.rt.image_method>`, is provided below. A lot of the code is just broadcasting arrays into the right shapes, to benefit from the vectorized computations on arrays, i.e., instead of using *slow* Python for-loops.
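
Rule 3 boils down to ray-triangle intersection tests. As an illustration, here is a plain NumPy version of the classical Möller–Trumbore algorithm for one ray and one triangle. It is a sketch, not DiffeRT's {func}`rays_intersect_triangles<differt.rt.rays_intersect_triangles>` (which works on batches), and the `eps` tolerance is our own choice. Because each ray direction is the difference between two consecutive path vertices, the returned `t` is expressed in units of that segment: a triangle only obstructs the segment when `0 < t < 1`.

```python
import numpy as np


def ray_intersects_triangle(origin, direction, triangle, eps=1e-9):
    """Möller–Trumbore ray/triangle test.

    triangle: [3, 3] array of vertices. Returns (t, hit), where the
    intersection point is origin + t * direction.
    """
    v0, v1, v2 = triangle
    edge1 = v1 - v0
    edge2 = v2 - v0
    p = np.cross(direction, edge2)
    det = np.dot(edge1, p)
    if abs(det) < eps:  # Ray (nearly) parallel to the triangle's plane
        return np.inf, False
    inv_det = 1.0 / det
    s = origin - v0
    u = np.dot(s, p) * inv_det  # First barycentric coordinate
    if u < 0.0 or u > 1.0:
        return np.inf, False
    q = np.cross(s, edge1)
    v = np.dot(direction, q) * inv_det  # Second barycentric coordinate
    if v < 0.0 or u + v > 1.0:
        return np.inf, False
    t = np.dot(edge2, q) * inv_det
    return t, t > 0.0
```

This is why the cell below flags an obstruction only when `t < 0.999`, slightly below 1, to tolerate numerical errors at the path's own vertices.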

In [None]:
fig.data = fig.data[:2]  # Keep only first 2 traces: geometry and TX/RX

# [num_triangles 3 3]
all_triangle_vertices = mesh.triangle_vertices

num_triangles = mesh.num_triangles

for order in range(5):
    # 1 - Prepare input arrays

    # [num_path_candidates order]
    path_candidates = generate_all_path_candidates(num_triangles, order)
    num_path_candidates = path_candidates.shape[0]

    # [num_path_candidates order 3]
    triangles = jnp.take(mesh.triangles, path_candidates, axis=0)

    # [num_path_candidates order 3 3]
    triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)

    # [num_path_candidates order 3]
    mirror_vertices = triangle_vertices[
        ...,
        0,
        :,
    ]  # Only one vertex per triangle is needed
    # [num_path_candidates order 3]
    mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

    # 2 - Trace paths

    # [num_path_candidates order 3]
    paths = image_method(
        tx, rx, mirror_vertices, mirror_normals
    )

    # 3 - Remove invalid paths

    # 3.1 - Remove paths with vertices outside triangles
    # [num_path_candidates order]
    mask = triangles_contain_vertices_assuming_inside_same_plane(
        triangle_vertices,
        paths,
    )
    # [num_path_candidates]
    mask = jnp.all(mask, axis=-1)

    # [num_paths_inter order+2 3]
    full_paths = assemble_paths(
        tx[None, None, :],
        paths[mask, ...],
        rx[None, None, :],
    )
    # 3.2 - Remove paths with vertices not on the same side of mirrors
    # [num_paths_inter order]
    mask = consecutive_vertices_are_on_same_side_of_mirrors(
        full_paths,
        mirror_vertices[mask, ...],
        mirror_normals[mask, ...],
    )

    # [num_paths_inter]
    mask = jnp.all(mask, axis=-1)  # We will actually remove them later

    # 3.3 - Remove paths that are obstructed by other objects
    # [num_paths_inter order+1 3]
    ray_origins = full_paths[..., :-1, :]
    # [num_paths_inter order+1 3]
    ray_directions = jnp.diff(full_paths, axis=-2)

    # [num_paths_inter order+1 num_triangles], [num_paths_inter order+1 num_triangles]
    t, hit = rays_intersect_triangles(
        ray_origins[..., None, :],
        ray_directions[..., None, :],
        all_triangle_vertices[None, None, ...],
    )
    # In theory, we could do t < 1.0 (because t == 1.0 means we are perfectly on a surface,
    # which is probably desirable, e.g., from a reflection) but in practice numerical
    # errors accumulate and will make this check impossible.
    # [num_paths_inter order+1 num_triangles]
    intersect = (t < 0.999) & hit
    #  [num_paths_inter]
    intersect = jnp.any(intersect, axis=(-1, -2))
    #  [num_paths_inter]
    mask = mask & ~intersect

    # 4 - Obtain final valid paths and plot

    #  [num_paths_final]
    full_paths = full_paths[mask, ...]

    dplt.draw_paths(
        full_paths,
        figure=fig,
        marker={
            "size": 0,
            "color": "black",
        },
        line={"color": color[order], "width": 3},
        name=f"Order {order}",
    )

fig

Another path tracing method that is fully compatible with the above cell is {func}`fermat_path_on_planar_mirrors<differt.rt.fermat_path_on_planar_mirrors>`. You can safely use it instead of {func}`image_method<differt.rt.image_method>`, and it should produce the same result. Note that Fermat path tracing is much slower than the Image method, but it can be applied to other types of interactions than pure specular reflection. This is left as an exercise to the reader.

## Example on more complex scenes

Most of the code we presented so far scales pretty well on larger scenes. However, there is one notable
exception: {func}`generate_all_path_candidates<differt.rt.generate_all_path_candidates>`.

With a bit of maths[^1], we can determine that a call to `generate_all_path_candidates(num_triangles, order)` generates an array of size $\texttt{num\_triangles}\,(\texttt{num\_triangles}-1)^{\texttt{order}-1} \times \texttt{order}$.
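
This formula is easy to double-check numerically. The sketch below (plain Python, hypothetical function names) compares the closed form with a brute-force enumeration on a tiny graph, and estimates the memory such an array would need, assuming 8-byte integer indices (an assumption; the actual dtype may differ).

```python
import itertools


def num_path_candidates(num_triangles: int, order: int) -> int:
    """Closed-form number of path candidates (no triangle visited twice in a row)."""
    if order == 0:
        return 1  # Only the line-of-sight candidate
    return num_triangles * (num_triangles - 1) ** (order - 1)


def brute_force_num_path_candidates(num_triangles: int, order: int) -> int:
    """Enumerate all index sequences and count those without consecutive repeats."""
    return sum(
        all(a != b for a, b in zip(candidate, candidate[1:]))
        for candidate in itertools.product(range(num_triangles), repeat=order)
    )


# Memory footprint of a [num_path_candidates, order] array with 8-byte integers
num_bytes = num_path_candidates(10_000, 3) * 3 * 8  # ~2.4e13 bytes, i.e., ~24 TB
```

For example, `num_path_candidates(5, 3)` and `brute_force_num_path_candidates(5, 3)` both return `80`, while a modest scene of 10,000 triangles at order 3 would already need about 24 TB.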

On scenes with many triangles, this rapidly becomes too big to fit in any computer memory. To circumvent this issue, we also provide an iterator variant, {func}`generate_all_path_candidates_chunks_iter<differt.rt.generate_all_path_candidates_chunks_iter>`, that produces arrays of a smaller size, defined by the `chunk_size` argument.
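
To illustrate the idea of chunking, here is a hypothetical pure-Python generator that yields all path candidates of a given order as NumPy arrays of at most `chunk_size` rows, so that only one chunk lives in memory at a time. It reproduces neither the signature nor the Rust implementation of {func}`generate_all_path_candidates_chunks_iter<differt.rt.generate_all_path_candidates_chunks_iter>`; it only shows the principle.

```python
import itertools
from collections.abc import Iterator

import numpy as np


def path_candidates_chunks(
    num_triangles: int, order: int, chunk_size: int
) -> Iterator[np.ndarray]:
    """Yield [<= chunk_size, order] integer arrays of path candidates, lazily."""
    # Lazy stream of candidates without consecutive repeated triangles
    candidates = (
        candidate
        for candidate in itertools.product(range(num_triangles), repeat=order)
        if all(a != b for a, b in zip(candidate, candidate[1:]))
    )
    # Pull at most `chunk_size` candidates at a time until the stream is exhausted
    while chunk := list(itertools.islice(candidates, chunk_size)):
        yield np.array(chunk, dtype=int)
```

With 4 triangles and order 2, the 12 candidates come out as chunks of 5, 5, and 2 rows. Note that chunking bounds the memory usage, not the total amount of work.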

While this offers a solution to the memory allocation issue, this does not reduce the number of path candidates. To reduce this number, you somehow need to prune a subset of the path candidates before you actually generate them.

Recalling the graph analogy we mentioned above, we can implement this behavior by disconnecting some primitives (i.e., triangles) in the graph.
There is no unique solution to this challenge, but we provide a small utility to estimate the visibility matrix between objects in a given scene: {func}`triangles_visible_from_vertices<differt.rt.triangles_visible_from_vertices>`.

Then, from this visibility matrix, which is actually just an adjacency matrix of the nodes in the graph,
we can instantiate a {class}`DiGraph<differt_core.rt.DiGraph>` from the {mod}`differt_core.rt` module.
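
A useful consequence of this graph view is that path candidates can be counted without enumerating them: the number of walks of length $k$ between two nodes is given by the $k$-th power of the adjacency matrix. The NumPy sketch below (a hypothetical helper, not a DiffeRT API) uses this to show how disconnecting nodes, i.e., setting entries of the visibility matrix or vectors to `False`, directly shrinks the number of candidates. The two boolean vectors play the role of visibility from TX and from RX; integer overflow is ignored for brevity.

```python
import numpy as np


def count_path_candidates(adjacency, from_adjacency, to_adjacency, order):
    """Number of path candidates with exactly `order` interactions.

    adjacency: [n, n] boolean matrix, True if node j can be reached from node i
        (a zero diagonal forbids visiting the same node twice in a row).
    from_adjacency, to_adjacency: [n] boolean visibility vectors from TX and RX.
    """
    if order == 0:
        return 1  # Line of sight only (assumes TX and RX are connected)
    # walks[i, j]: number of walks from node i to node j using order - 1 edges
    walks = np.linalg.matrix_power(adjacency.astype(np.int64), order - 1)
    return int(
        from_adjacency.astype(np.int64) @ walks @ to_adjacency.astype(np.int64)
    )
```

For a complete graph of 4 nodes without self-loops, order 2 gives $4 \times 3 = 12$ candidates; if TX only sees nodes 0 and 1, this drops to 6.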

[^1]: The first node to be visited can be any of the `num_triangles` nodes. For the next nodes, we will have to choose between `num_triangles - 1` nodes, because we do not allow for loops (i.e., cycles of unit length) in our graph.

### Numbers getting crazy

To illustrate what we said above, we will load a much larger scene that contains quite a few objects, i.e., triangles.

A transmitter and a receiver are placed in the scene at example positions.

In [None]:
from differt.scene import TriangleScene

mesh_file = "bruxelles.obj"
mesh = TriangleMesh.load_obj(mesh_file)

tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])

scene = TriangleScene(transmitters=tx, receivers=rx, mesh=mesh)
scene.plot(backend="vispy")

In [None]:
mesh.num_objects  # This is the number of triangles

This number is not actually that big: for large cities, it can easily exceed a million.
However, it is large enough to present serious challenges when it comes to performing exhaustive RT.

Using the core library, we can compute the exact number of path candidates one would have to try for a given
number of interactions.

In [None]:
from differt_core.rt import CompleteGraph

graph = CompleteGraph(mesh.num_objects)

from_ = graph.num_nodes  # Index of TX in the graph
to = from_ + 1  # Index of RX in the graph
order = 2  # Number of interactions
depth = order + 2  # + 2 because we add TX and RX nodes

f"{len(graph.all_paths(from_, to, depth)):.3e}"

That means that there are over 200 million second-order reflection paths to test... We need to reduce that number!

### Assuming quadrilaterals

In many cases, a scene is simply a collection of quadrilaterals, each split into
two triangles. This is not always true, and probably not the case for our scene, but we
will assume it is.

Using {func}`set_assume_quads<differt.geometry.TriangleMesh.set_assume_quads>`, the
mesh will now tell all other functions to use, when available, optimized routines for
quadrilateral facets.

In [None]:
mesh = mesh.set_assume_quads(True)
mesh.num_objects  # This is now the number of quadrilaterals, exactly half the number of triangles

Again, we can compute the number of path candidates, and see that it is reduced by a factor of almost 4.

In general, the reduction factor is nearly $2^{\texttt{order}}$.
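
To see where this factor comes from, write $n$ for the number of triangles and $k$ for the order ($k \ge 1$), and assume every quadrilateral is made of exactly two triangles, i.e., there are $n/2$ quadrilaterals. Applying the path-candidate count $n(n-1)^{k-1}$ to both meshes gives the ratio:

$$
\frac{n\,(n-1)^{k-1}}{\frac{n}{2}\left(\frac{n}{2}-1\right)^{k-1}}
= 2\left(\frac{2\,(n-1)}{n-2}\right)^{k-1}
\xrightarrow[n \to \infty]{} 2 \cdot 2^{k-1} = 2^{k}.
$$

For $k = 2$, the ratio is $4(n-1)/(n-2)$, i.e., slightly more than 4 for large $n$.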

In [None]:
graph = CompleteGraph(mesh.num_objects)

from_ = graph.num_nodes
to = from_ + 1
order = 2
depth = order + 2

f"{len(graph.all_paths(from_, to, depth)):.3e}"  # Roughly a quarter of the previous length

### Determining TX's visibility

Another way to reduce the number of path candidates is to indicate to the graph
that TX cannot reach all objects in the scene, but only a subset of the objects.

Such information can be obtained by estimating the visibility vector of some TX,
and using it when creating the path candidates iterator.

If one knows the location of the receiving antenna, a similar logic can be used
to compute the ``to_adjacency`` vector, which is also a visibility vector, but from RX.

On the other hand, if the mesh is fixed but TX and RX are not, it is also possible to compute
the visibility vector of each triangle in the scene, thereby constructing the visibility
matrix of the scene, and use it to construct the graph with
{meth}`DiGraph.from_adjacency_matrix<differt_core.rt.DiGraph.from_adjacency_matrix>`.
As computing such a matrix can be extremely expensive, it is recommended to do it
as a pre-processing step and save the resulting matrix to a file.
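
As an illustration of that pre-processing step, a boolean visibility matrix can be stored compactly by packing eight entries per byte. The NumPy sketch below only covers storage: the helper names are hypothetical, and the matrix itself would have to be computed beforehand, for instance from one visibility vector per triangle as described above.

```python
import numpy as np


def pack_visibility_matrix(visibility: np.ndarray) -> np.ndarray:
    """Pack a boolean [n, n] matrix into uint8, 8 entries per byte (row-wise)."""
    return np.packbits(visibility, axis=-1)


def unpack_visibility_matrix(packed: np.ndarray, n: int) -> np.ndarray:
    """Inverse of pack_visibility_matrix; `n` drops the padding bits of each row."""
    return np.unpackbits(packed, axis=-1, count=n).astype(bool)


def save_visibility_matrix(path: str, visibility: np.ndarray) -> None:
    """Save the packed matrix and its size in a compressed .npz archive."""
    np.savez_compressed(
        path, packed=pack_visibility_matrix(visibility), n=visibility.shape[-1]
    )


def load_visibility_matrix(path: str) -> np.ndarray:
    """Load a matrix saved with save_visibility_matrix."""
    with np.load(path) as data:
        return unpack_visibility_matrix(data["packed"], int(data["n"]))
```

Note that `np.savez_compressed` appends the `.npz` extension if the path does not already have it.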

The code below shows how to estimate[^2] the objects (i.e., triangles) seen by TX.
For this example, visible triangles are colored in red, and hidden ones in black.

[^2]: It is an estimate because a fixed number of rays (see {func}`triangles_visible_from_vertices<differt.rt.triangles_visible_from_vertices>`) is launched from TX, and increasing (resp. decreasing) this number will increase (resp. decrease) the accuracy of the method.
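
A common way to launch a fixed number of nearly evenly spread rays is a Fibonacci lattice on the unit sphere; DiffeRT exposes {func}`fibonacci_lattice<differt.geometry.fibonacci_lattice>`, imported above. The NumPy sketch below only shows the classical full-sphere construction, under our own conventions; it is not that function's implementation and ignores any frustum restriction. Since each of the $N$ directions covers a solid angle of about $4\pi/N$, the angular spacing between neighboring rays shrinks roughly like $1/\sqrt{N}$, which is why more rays give a finer visibility estimate.

```python
import numpy as np


def fibonacci_sphere(num_points: int) -> np.ndarray:
    """Return [num_points, 3] unit vectors spread nearly uniformly over the sphere."""
    golden_angle = np.pi * (3.0 - np.sqrt(5.0))  # ~2.39996 rad
    i = np.arange(num_points)
    z = 1.0 - 2.0 * (i + 0.5) / num_points  # Uniform in z, i.e., equal-area bands
    radius = np.sqrt(1.0 - z**2)  # Radius of the horizontal circle at height z
    phi = golden_angle * i  # Azimuth advances by the golden angle
    return np.stack((radius * np.cos(phi), radius * np.sin(phi), z), axis=-1)
```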

In [None]:
from differt.rt import triangles_visible_from_vertices

tx = jnp.array([-40.0, 75, 30.0])

default_color = jnp.array([[0.2, 0.2, 0.2]])  # Hidden, black
visible_color = jnp.array([[1.0, 0.2, 0.2]])  # Visible, red
visible_triangles = triangles_visible_from_vertices(
    tx,
    mesh.triangle_vertices,
)

mesh = mesh.set_face_colors(default_color)
mesh = mesh.set_face_colors(
    mesh.face_colors.at[visible_triangles].set(visible_color)
)

with dplt.reuse("vispy") as canvas:
    dplt.draw_markers(
        np.array([tx]), ["tx"], size=7, text_kwargs={"font_size": 2000}
    )
    mesh.plot()

canvas

A visibility vector is simply an array of booleans, each entry indicating whether the corresponding
object (here, a triangle) can be seen from TX.

The number of visible triangles is then the sum of all true entries in the array.

In [None]:
visible_triangles.sum() / mesh.num_triangles  # ~ 33% of triangles are seen from TX

It is also possible to get the number of visible quadrilaterals by
counting visible triangles by pairs. If any of the two triangles
forming a quadrilateral is visible, then this quadrilateral **is considered visible**.

In [None]:
visible_quads = visible_triangles.reshape(mesh.num_quads, 2).any(axis=-1)
visible_quads.sum() / mesh.num_quads  # ~ 43% of quadrilaterals are seen from TX

We can then use this result to inform the graph about the limited number of faces
visible from TX.

As expected, the number of path candidates gets reduced to about 43% of the previous value.

However, 43% visibility is probably too high to justify switching from a
{class}`CompleteGraph<differt_core.rt.CompleteGraph>`
to a {class}`DiGraph<differt_core.rt.DiGraph>`,
as iterating through the latter is noticeably slower (the former is optimized for complete graphs).

In [None]:
from differt_core.rt import CompleteGraph, DiGraph

graph = DiGraph.from_complete_graph(CompleteGraph(mesh.num_quads))
from_, to = graph.insert_from_and_to_nodes(
    from_adjacency=np.asarray(visible_quads)
)

f"{graph.all_paths(from_, to, order + 2).count():.3e}"  # ~ 43% of the previous length

### What about Ray Launching

Eventually, all the above solutions hit a ceiling at some point,
where the sheer number of path candidates outweighs any possible optimization.

In those cases, Ray Launching (RL) can be used as an alternative to exhaustive RT,
as the number of launched rays (and thus of path candidates) is usually fixed, a bit like when estimating the
visibility from TX.
This is, in fact, what tools like Sionna use to compute coverage maps.
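
In RL, a ray cannot be expected to hit a point-like RX exactly, so a common criterion (often called a reception sphere) accepts a ray if it passes within some distance of RX before hitting any obstacle. The squared distance from a point $P$ to the ray $O + t\,\mathbf{d}$ is $\|\mathbf{d} \times (P - O)\|^2 / \|\mathbf{d}\|^2$, reached at $t = \mathbf{d} \cdot (P - O) / \|\mathbf{d}\|^2$. The NumPy sketch below applies this test to a batch of rays; the names and conventions are our own, and it only illustrates the idea rather than any DiffeRT routine.

```python
import numpy as np


def rays_pass_near_point(origins, directions, point, max_dist, t_max=None):
    """Check which rays pass within `max_dist` of `point` (reception sphere test).

    origins, directions: [num_rays, 3]; point: [3].
    t_max: optional [num_rays] ray parameter of the first obstacle hit.
    """
    to_point = point - origins  # [num_rays, 3]
    d2 = np.sum(directions**2, axis=-1)  # Squared norms of the directions
    t = np.sum(directions * to_point, axis=-1) / d2  # Projection parameter
    # Squared point-to-line distance: |d x (P - O)|^2 / |d|^2
    dist2 = np.sum(np.cross(directions, to_point) ** 2, axis=-1) / d2
    ok = (t > 0.0) & (dist2 < max_dist**2)  # In front of the origin and close enough
    if t_max is not None:
        ok &= t < t_max  # Point must be reached before the first obstacle
    return ok
```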

Currently, DiffeRT does not provide any convenient RL routine, but it is on the roadmap,
so stay tuned!

If you want to contribute to extending DiffeRT, please feel free to reach out on GitHub!

In [None]:
mesh = TriangleMesh.load_obj(mesh_file)  # Reload mesh to reset colors

# [num_triangles 3 3]
triangle_vertices = mesh.triangle_vertices

num_triangles = mesh.num_triangles

with dplt.reuse("vispy") as canvas:
    dplt.draw_markers(
        np.array([tx, rx]), ["tx", "rx"], size=7, text_kwargs={"font_size": 2000}
    )
    mesh.plot()

    num_rays = int(1e4)
    max_dist = 0.1**2  # Squared distance threshold (avoids computing a sqrt)
    max_order = 1

    # Viewing frustum, as seen from TX, that encloses all triangle vertices.
    # This avoids launching rays where there are no objects.
    frustum = viewing_frustum(tx, triangle_vertices.reshape(-1, 3))

    # [num_rays 3]
    ray_origins = jnp.broadcast_to(tx, (num_rays, 3))
    ray_directions = fibonacci_lattice(num_rays, frustum=frustum)

    # Scan carry: the current rays
    ScanC = tuple[
        Float[Array, f"{num_rays} 3"],  # Ray origins
        Float[Array, f"{num_rays} 3"],  # Ray directions
    ]
    # Scan outputs, stacked over iterations
    ScanR = tuple[
        Float[Array, f"{num_rays} 3"],  # Path vertices (next hit points)
        Bool[Array, f"{num_rays}"],  # Whether rays pass close (i.e., < max_dist) to RX
    ]

    def scan_fun(
        ray_origins_and_directions: ScanC, _: None
    ) -> tuple[ScanC, ScanR]:
        ray_origins, ray_directions = ray_origins_and_directions

        # 1 - Compute next intersection with triangles

        # [num_rays]
        triangles, t_hit = first_triangles_hit_by_rays(
            ray_origins,
            ray_directions,
            triangle_vertices,
        )  # This may generate jnp.inf values, so we will need to be careful with those

        # 2 - Check if the rays pass near RX

        # [num_rays 3]
        ray_origins_to_rx = (
            rx - ray_origins
        )

        # [num_rays]
        ray_directions_sq_norm = jax.lax.integer_pow(ray_directions, 2).sum(axis=-1)

        # [num_rays]
        ray_distances_to_rx = (
            jax.lax.integer_pow(jnp.cross(ray_directions, ray_origins_to_rx), 2).sum(
                axis=-1
            )
            / ray_directions_sq_norm
        )  # Squared distance from rays to RX: |d x (rx - o)|^2 / |d|^2
        # [num_rays]
        t_rx = (
            jnp.sum(ray_directions * ray_origins_to_rx, axis=-1)
            / ray_directions_sq_norm
        )  # Ray parameter of RX's orthogonal projection onto each ray
        # [num_rays]
        masks = (
            (t_rx > 0.0)  # RX is in front of the ray origin
            & (t_rx < t_hit)  # RX is reached before the first triangle hit
            & (ray_distances_to_rx < max_dist)  # RX is close enough
        )  # Whether rays pass near RX

        # 3 - Update rays

        # [num_rays 3]
        mirror_normals = jnp.take(mesh.normals, triangles, axis=0)

        ray_origins += t_hit[..., None] * ray_directions
        # Specular reflection: d' = d - 2 (d . n) n (assumes unit normals)
        ray_directions = (
            ray_directions
            - 2
            * jnp.sum(ray_directions * mirror_normals, axis=-1, keepdims=True)
            * mirror_normals
        )

        return (ray_origins, ray_directions), (
            ray_origins,
            masks,
        )

    # [max_order+1 num_rays 3], [max_order+1 num_rays]
    _, (paths, masks) = jax.lax.scan(
        scan_fun, (ray_origins, ray_directions), length=max_order + 1
    )

    # [num_rays max_order+1 3], [num_rays max_order+1]
    paths = jnp.moveaxis(paths, 0, 1)
    masks = jnp.moveaxis(masks, 0, 1)

    for order in range(max_order + 1):
        # Rays passing near RX after `order` reflections:
        # their path is [tx, first `order` hit points, rx]
        full_paths = assemble_paths(
            tx[None, None, :],
            # [num_valid_rays order 3]
            paths[masks[..., order], :order, :],
            rx[None, None, :],
        )

        dplt.draw_paths(full_paths)

canvas
