# Ray Tracing at City Scale

In [None]:
from pathlib import Path

import jax.numpy as jnp
import numpy as np

import differt.plotting as dplt
from differt.geometry import TriangleMesh
from differt.geometry.triangle_mesh import (
    triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt.image_method import (
    consecutive_vertices_are_on_same_side_of_mirrors,
    image_method,
)
from differt.rt.utils import (
    generate_all_path_candidates_chunks_iter,
    rays_intersect_triangles,
)

In [None]:
# Load the city geometry (Brussels) from a Wavefront OBJ file and draw it;
# `mesh.plot()` returns the canvas that every later drawing call reuses.
mesh_file = Path("bruxelles.obj")
mesh = TriangleMesh.load_obj(mesh_file)
canvas = mesh.plot()

# Transmitter (tx) and receiver (rx) positions, expressed in the same
# coordinate system as the mesh vertices.
tx = jnp.array([-40.0, 75, 30.0])
rx = jnp.array([+20.0, 108.034, 1.50])

# Label both end points on the canvas.
dplt.draw_markers(np.array([tx, rx]), ["tx", "rx"], canvas=canvas)

# Palette of path colors; presumably meant to be indexed by interaction
# order (0 = line of sight) — NOTE(review): confirm against the drawing code.
color = ["black", "green", "orange", "yellow"]

# Number of triangles in the scene, and the coordinates of all their
# vertices: [num_triangles 3 3]. Used for obstruction tests of every segment.
num_triangles = len(mesh.triangles)
all_triangle_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)

# Highest number of interactions (reflections) per path. Every path
# candidate is tested exhaustively, so the cost grows very quickly with the
# order: you probably don't want to go above 1.
max_order = 1

# Each path segment is cast as a ray `origin + t * direction`, where
# `direction` spans the whole segment (see `jnp.diff` below). Assuming
# `rays_intersect_triangles` returns that `t`, hits with t below this
# threshold lie strictly before the segment's end point, i.e. they are real
# obstructions rather than the segment's own destination triangle (or rx).
segment_end_threshold = 0.999

for order in range(max_order + 1):
    for path_candidates in generate_all_path_candidates_chunks_iter(
        num_triangles, order, chunk_size=2000
    ):
        # 1 - Prepare input arrays

        # path_candidates: [num_path_candidates order] triangle indices
        num_path_candidates = path_candidates.shape[0]

        # [num_path_candidates 3]
        from_vertices = jnp.tile(tx, (num_path_candidates, 1))
        to_vertices = jnp.tile(rx, (num_path_candidates, 1))

        # [num_path_candidates order 3]
        triangles = jnp.take(mesh.triangles, path_candidates, axis=0)

        # [num_path_candidates order 3 3]
        triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)

        # [num_path_candidates order 3]
        # Only one point per mirror plane is needed, so take the first
        # vertex of each triangle.
        mirror_vertices = triangle_vertices[..., 0, :]
        # [num_path_candidates order 3]
        mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

        # 2 - Trace paths

        # [num_path_candidates order 3]
        paths = image_method(
            from_vertices, to_vertices, mirror_vertices, mirror_normals
        )

        # 3 - Remove invalid paths

        # Remove paths with vertices outside their triangles.
        # [num_path_candidates]
        mask = jnp.all(
            triangles_contain_vertices_assuming_inside_same_plane(
                triangle_vertices,
                paths,
            ),
            axis=-1,
        )

        # [num_paths_inter order+2 3]
        full_paths = jnp.concatenate(
            (
                jnp.expand_dims(from_vertices[mask, ...], axis=-2),
                paths[mask, ...],
                jnp.expand_dims(to_vertices[mask, ...], axis=-2),
            ),
            axis=-2,
        )

        # Remove paths with vertices not on the same side of mirrors.
        # [num_paths_inter]
        same_side = jnp.all(
            consecutive_vertices_are_on_same_side_of_mirrors(
                full_paths, mirror_vertices[mask, ...], mirror_normals[mask, ...]
            ),
            axis=-1,
        )
        # Filter now, before the obstruction test: that test materializes
        # arrays of size num_paths * (order+1) * num_triangles, so running it
        # on already-rejected paths only wastes time and memory.
        # [num_paths_valid order+2 3]
        full_paths = full_paths[same_side, ...]

        # Remove paths whose segments are obstructed by any scene triangle.
        # [num_paths_valid order+1 3]
        ray_origins = full_paths[..., :-1, :]
        # [num_paths_valid order+1 3]
        ray_directions = jnp.diff(full_paths, axis=-2)

        # Test every segment against every triangle of the scene.
        # [num_paths_valid order+1 num_triangles 3]
        ray_origins = jnp.repeat(
            jnp.expand_dims(ray_origins, axis=-2), num_triangles, axis=-2
        )
        # [num_paths_valid order+1 num_triangles 3]
        ray_directions = jnp.repeat(
            jnp.expand_dims(ray_directions, axis=-2), num_triangles, axis=-2
        )

        # [num_paths_valid order+1 num_triangles] (both outputs)
        t, hit = rays_intersect_triangles(
            ray_origins,
            ray_directions,
            jnp.broadcast_to(all_triangle_vertices, (*ray_origins.shape, 3)),
        )
        # A path is obstructed if any of its segments hits any triangle.
        # [num_paths_valid]
        obstructed = jnp.any((t < segment_end_threshold) & hit, axis=(-1, -2))

        # 4 - Obtain final valid paths and plot

        # [num_paths_final order+2 3]
        full_paths = full_paths[~obstructed, ...]

        # One color per interaction order; wrap around if max_order is raised
        # beyond the size of the palette.
        dplt.draw_paths(
            full_paths, canvas=canvas, color=color[order % len(color)]
        )

# Hand-tuned camera viewpoint that frames the tx/rx area of the scene.
camera_state = {
    "scale_factor": 138.81554751457762,
    "center": (20.0, 108.034, 46.0),
    "fov": 45.0,
    "elevation": 13.0,
    "azimuth": 39.0,
    "roll": 0.0,
}

view = dplt.view_from_canvas(canvas)
view.camera.set_state(camera_state)

# Last expression of the notebook cell: renders the canvas as the output.
canvas