# Quickstart

In [None]:
from pathlib import Path

import jax.numpy as jnp
import numpy as np

import differt.plotting as dplt
from differt.geometry import TriangleMesh
from differt.geometry.triangle_mesh import (
    triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt.image_method import (
    consecutive_vertices_are_on_same_side_of_mirrors,
    image_method,
)
from differt.rt.utils import generate_all_path_candidates, rays_intersect_triangles

In [None]:
dplt.use("plotly")  # Let's use the Plotly backend (presumably the default for the later dplt calls)

In [None]:
# Load the scene geometry as a triangle mesh from an OBJ file (judging by its
# name, two buildings). The path is relative to the current working directory.
mesh_file = Path("two_buildings.obj")
mesh = TriangleMesh.load_obj(mesh_file)

In [None]:
# Plot the mesh; opacity=0.5 presumably makes it semi-transparent so that the
# paths drawn later stay visible. Later cells keep drawing into this `fig`.
fig = mesh.plot(opacity=0.5)
fig  # Last expression of the cell: displays the figure

In [None]:
# Transmitter (tx) and receiver (rx) positions, as [x, y, z] coordinates
tx = jnp.array([0.0, 4.9352, 22.0])
rx = jnp.array([0.0, 10.034, 1.50])

# Add both nodes to the existing figure; dplt expects NumPy arrays, so the two
# JAX vectors are stacked into a single (2, 3) NumPy array
dplt.draw_markers(
    np.stack((tx, rx)),
    labels=["tx", "rx"],
    figure=fig,
    name="nodes",
)

In [None]:
# Triangles to highlight as mirrors. Judging by the every-other-index selection
# below, they come in pairs lying in the same plane: (8, 9) and (22, 23).
select = [8, 9, 22, 23]

# differt.plotting (dplt) works with NumPy arrays, not JAX arrays
vertices = np.asarray(mesh.vertices)
triangles = np.asarray(mesh.triangles[select, :])

# One triangle per plane is enough to define a mirror, so keep every other
# index, i.e., [8, 22]. From here on, `select` is a JAX integer array.
select = jnp.asarray(select[0::2], dtype=int)

# Highlight the four selected triangles in red (last expression of the cell)
dplt.draw_mesh(vertices, triangles, figure=fig, color="red")

In [None]:
# One color per path order (index = number of reflections)
color = ["black", "green", "orange", "yellow", "blue"]

# Without any validity check, trace the direct path (no mirror), then the path
# reflecting on the first selected plane, then on both selected planes.
for path_candidate in (select[:n] for n in range(3)):
    from_vertices = tx
    to_vertices = rx
    # A mirror is described by one point of its plane and its normal
    mirror_vertices = mesh.vertices[mesh.triangles[path_candidate, 0], :]
    mirror_normals = mesh.normals[path_candidate, :]
    paths = image_method(from_vertices, to_vertices, mirror_vertices, mirror_normals)

    # TX, then the reflection points, then RX, stacked row-wise
    full_paths = jnp.vstack((from_vertices, paths, to_vertices))

    num_reflections = len(path_candidate)
    dplt.draw_paths(
        full_paths,
        figure=fig,
        marker={"size": 0, "color": "black"},
        line={"color": color[num_reflections], "width": 3},
        name=f"Order {num_reflections}",
    )

fig

In [None]:
# Drop the previously drawn paths: keep only the geometry and TX/RX traces
fig.data = fig.data[:2]

# Coordinates of the 3 vertices of every triangle, [num_triangles 3 3]
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)
num_triangles = mesh.triangles.shape[0]

# Trace orders 1 to 4 over the whole mesh and keep only the valid paths.
# NOTE: the direct path (order 0) is not re-traced in this cell.
for order in range(1, 5):
    # ---- Step 1: build the batched inputs -----------------------------------

    # Candidate sequences of `order` triangles, [num_path_candidates order]
    path_candidates = generate_all_path_candidates(num_triangles, order)
    num_path_candidates = path_candidates.shape[0]

    # Same TX / RX for every candidate, [num_path_candidates 3]
    from_vertices = jnp.broadcast_to(tx, (num_path_candidates, *tx.shape))
    to_vertices = jnp.broadcast_to(rx, (num_path_candidates, *rx.shape))

    # Vertex indices of the triangles along each candidate, [num_path_candidates order 3]
    triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
    # ... and their coordinates, [num_path_candidates order 3 3]
    triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)

    # A single vertex per triangle is enough to locate its plane,
    # [num_path_candidates order 3] (both arrays)
    mirror_vertices = triangle_vertices[..., 0, :]
    mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

    # ---- Step 2: reflection points from the image method --------------------

    # [num_path_candidates order 3]
    paths = image_method(from_vertices, to_vertices, mirror_vertices, mirror_normals)

    # ---- Step 3: discard invalid paths --------------------------------------

    # (a) Every reflection point must lie inside its triangle, [num_path_candidates]
    mask = triangles_contain_vertices_assuming_inside_same_plane(
        triangle_vertices,
        paths,
    ).all(axis=-1)

    # TX + reflection points + RX of the remaining paths,
    # [num_paths_inter order+2 3]
    full_paths = jnp.concatenate(
        (
            from_vertices[mask][..., None, :],
            paths[mask, ...],
            to_vertices[mask][..., None, :],
        ),
        axis=-2,
    )

    # (b) Consecutive vertices must be on the same side of each mirror,
    # [num_paths_inter] (the right-hand side still reads the mask from (a))
    mask = consecutive_vertices_are_on_same_side_of_mirrors(
        full_paths, mirror_vertices[mask, ...], mirror_normals[mask, ...]
    ).all(axis=-1)

    # (c) No path segment may be obstructed by a triangle of the scene.
    # Each segment is a ray from one vertex towards the next one,
    # [num_paths_inter order+1 3] (both arrays)
    ray_origins = full_paths[..., :-1, :]
    ray_directions = jnp.diff(full_paths, axis=-2)

    # Pair every segment with every triangle, [num_paths_inter order+1 num_triangles 3]
    batch_shape = (*ray_origins.shape[:-1], num_triangles, ray_origins.shape[-1])
    ray_origins = jnp.broadcast_to(ray_origins[..., None, :], batch_shape)
    ray_directions = jnp.broadcast_to(ray_directions[..., None, :], batch_shape)

    # [num_paths_inter order+1 num_triangles] (both outputs)
    t, hit = rays_intersect_triangles(
        ray_origins,
        ray_directions,
        jnp.broadcast_to(all_triangle_vertices, (*ray_origins.shape, 3)),
    )
    # Since each direction spans the whole segment, `t` is presumably a fraction
    # of the segment length; the 0.999 threshold presumably avoids flagging the
    # triangle on which the next vertex lies (t = 1) — TODO confirm
    intersect = (t < 0.999) & hit
    # [num_paths_inter]
    intersect = intersect.any(axis=(-1, -2))
    mask = mask & ~intersect

    # ---- Step 4: keep the valid paths and plot them -------------------------

    # [num_paths_final order+2 3]
    full_paths = full_paths[mask, ...]

    dplt.draw_paths(
        full_paths,
        figure=fig,
        marker={"size": 0, "color": "black"},
        line={"color": color[order], "width": 3},
        name=f"Order {order}",
    )

fig