# Quickstart

In [None]:
from pathlib import Path

import jax.numpy as jnp
import plotly.graph_objects as go

from differt.geometry import TriangleMesh
from differt.geometry.triangle_mesh import (
    triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt.image_method import (
    image_method,
    consecutive_vertices_are_on_same_side_of_mirrors,
)
from differt.rt.utils import generate_all_path_candidates, rays_intersect_triangles

In [None]:
# Scene geometry, loaded from the Wavefront OBJ file `two_buildings.obj`.
# The relative path is resolved against the current working directory.
MESH_FILENAME = "two_buildings.obj"
mesh_file = Path(MESH_FILENAME)
mesh = TriangleMesh.load_obj(mesh_file)

In [None]:
# Plot the scene mesh at 50% opacity; the following cells add more traces to `fig`.
fig = mesh.plot(opacity=0.5)
fig  # last expression of the cell, so Jupyter displays the figure

In [None]:
# Transmitter (TX) and receiver (RX) positions as [x, y, z] coordinates.
tx = jnp.array([0.0, 4.9352, 22.0])
rx = jnp.array([0.0, 10.034, 1.50])

# Draw both nodes as labelled red markers. This call is the last expression of the
# cell, so the returned figure is displayed.
node_positions = [tx, rx]
fig.add_traces(
    go.Scatter3d(
        x=[position[0] for position in node_positions],
        y=[position[1] for position in node_positions],
        z=[position[2] for position in node_positions],
        marker={"size": 7, "color": "red"},
        mode="markers+text",
        text=["tx", "rx"],
        name="nodes",
    )
)

In [None]:
# Triangles to highlight in red.
highlighted = [8, 9, 22, 23]

# Mesh3d takes the vertex coordinates (x, y, z) and the triangle corner
# indices (i, j, k) as separate arrays, hence the transposes.
x, y, z = mesh.vertices.T
triangles = mesh.triangles[highlighted, :].T
i, j, k = triangles

# We actually only need one triangle per plane, so keep every other index: [8, 22].
select = jnp.array(highlighted[::2], dtype=int)

fig.add_traces(go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color="red"))

In [None]:
# Line colour per reflection order: index 0 is the direct path (no mirror),
# indices 1-3 are paths with 1-3 reflections.
color = ["black", "green", "orange", "yellow"]

# The path endpoints do not depend on the order, so set them once.
from_vertices = tx
to_vertices = rx

for order in range(3):
    # Mirrors are the first `order` entries of `select`: none, [8], then [8, 22].
    path_candidate = select[:order]
    # Any corner of a triangle lies on its plane; use the first one.
    mirror_vertices = mesh.vertices[mesh.triangles[path_candidate, 0], :]
    mirror_normals = mesh.normals[path_candidate, :]
    paths = image_method(from_vertices, to_vertices, mirror_vertices, mirror_normals)

    # Stack TX, the reflection points and RX into one polyline.
    full_paths = jnp.vstack((from_vertices, paths, to_vertices))

    fig.add_traces(
        go.Scatter3d(
            x=full_paths[:, 0],
            y=full_paths[:, 1],
            z=full_paths[:, 2],
            marker={"size": 0, "color": "black"},
            line={"color": color[order], "width": 3},
            name=f"Order {order}",
        )
    )
fig

In [None]:
fig.data = fig.data[:2]  # Keep only first 2 traces: geometry and TX/RX

# Corner coordinates of every triangle in the scene, tested as potential obstacles.
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)

num_triangles = mesh.triangles.shape[0]

# Each ray below spans a full path segment, so a hit at t ~= 1 is the segment's own
# end vertex (next mirror, or RX). Hits at or beyond this threshold are not counted
# as obstructions.
# NOTE(review): assumes `t` is measured in units of the ray direction vector.
max_obstruction_t = 0.999

for order in range(4):
    if order == 0:
        # Line of sight: the single direct TX -> RX path, shape [2, 1, 3].
        # It has no mirrors, so only the obstruction test below applies. The
        # obstruction test is what prevents a blocked direct path from being drawn.
        full_paths = jnp.stack((tx, rx))[:, None, :]
        valid = jnp.ones(full_paths.shape[1], dtype=bool)
    else:
        # Prepare input arrays: one column per candidate sequence of `order` triangles
        path_candidates = generate_all_path_candidates(num_triangles, order)
        num_candidates = path_candidates.shape[1]
        from_vertices = jnp.tile(tx, (num_candidates, 1))
        to_vertices = jnp.tile(rx, (num_candidates, 1))
        triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
        triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)
        # Only one vertex per triangle is needed to locate its plane
        mirror_vertices = triangle_vertices[..., 0, :]
        mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

        # Trace paths
        paths = image_method(
            from_vertices, to_vertices, mirror_vertices, mirror_normals
        )

        # Remove paths with vertices outside triangles
        inside = jnp.all(
            triangles_contain_vertices_assuming_inside_same_plane(
                triangle_vertices,
                paths,
            ),
            axis=0,
        )

        # TX, reflection points, RX for the remaining candidates
        full_paths = jnp.concatenate(
            (
                jnp.expand_dims(from_vertices[inside, :], axis=0),
                paths[:, inside, :],
                jnp.expand_dims(to_vertices[inside, :], axis=0),
            )
        )

        # Remove paths with vertices not on the same side of mirrors
        valid = jnp.all(
            consecutive_vertices_are_on_same_side_of_mirrors(
                full_paths,
                mirror_vertices[:, inside, ...],
                mirror_normals[:, inside, ...],
            ),
            axis=0,
        )

    # Remove paths with a segment blocked by any triangle of the scene.
    # One ray per segment, [order+1 num_paths 3], repeated against every triangle.
    ray_origins = full_paths[:-1, ...]
    ray_directions = jnp.diff(full_paths, axis=0)

    ray_origins = jnp.repeat(
        jnp.expand_dims(ray_origins, axis=-2), num_triangles, axis=-2
    )
    ray_directions = jnp.repeat(
        jnp.expand_dims(ray_directions, axis=-2), num_triangles, axis=-2
    )

    t, hit = rays_intersect_triangles(
        ray_origins,
        ray_directions,
        jnp.broadcast_to(all_triangle_vertices, ray_origins.shape + (3,)),
    )
    # A path is obstructed if any of its segments hits any triangle before its end
    obstructed = jnp.any((t < max_obstruction_t) & hit, axis=(0, 2))

    full_paths = full_paths[:, valid & ~obstructed, ...]

    fig.add_traces(
        [
            go.Scatter3d(
                x=full_paths[:, path_index, 0],
                y=full_paths[:, path_index, 1],
                z=full_paths[:, path_index, 2],
                marker=dict(
                    size=0,
                    color="black",
                ),
                line=dict(color=color[order], width=3),
                name=f"Order {order} #{path_index:02d}",
            )
            for path_index in range(full_paths.shape[1])
        ]
    )

fig