# Quickstart

In [None]:
from itertools import combinations
from pathlib import Path

import jax.numpy as jnp
import plotly.graph_objects as go

from differt.geometry import TriangleMesh
from differt.geometry.triangle_mesh import (
    triangles_contain_vertices_assuming_inside_same_plane,
)
from differt.rt.image_method import image_method

In [None]:
# Load the scene geometry (presumably two buildings, per the file name) from
# an OBJ file. The path is relative, so it is resolved against the current
# working directory of the notebook kernel.
mesh_file = Path("two_buildings.obj")
mesh = TriangleMesh.load_obj(mesh_file)

In [None]:
# Plot the mesh; opacity=0.5 presumably renders it semi-transparent so that
# the markers and paths added to this same `fig` in later cells stay visible.
fig = mesh.plot(opacity=0.5)
fig  # last expression of the cell: the notebook displays the figure

In [None]:
# Transmitter (tx) and receiver (rx) positions, given as [x, y, z].
tx = jnp.array([0.0, 4.9352, 22.0])
rx = jnp.array([0.0, 10.034, 1.5])

# Show both antennas as labelled red markers. `add_traces` returns the figure,
# so leaving this call as the last expression makes the notebook display it.
fig.add_traces(
    go.Scatter3d(
        mode="markers+text",
        text=["tx", "rx"],
        marker={"size": 7, "color": "red"},
        x=[tx[0], rx[0]],
        y=[tx[1], rx[1]],
        z=[tx[2], rx[2]],
    )
)

In [None]:
# Corner coordinates of every triangle, gathered through the index buffer.
# Presumably shape [num_triangles 3 3] (triangle, corner, xyz) — TODO confirm.
all_vertices = jnp.take(mesh.vertices, mesh.triangles, axis=0)
# NOTE: despite its name, this counts triangles (first axis of
# ``all_vertices``), i.e., the candidate mirrors, not mesh vertices.
n_vertices = all_vertices.shape[0]

indices = list(range(n_vertices))  # [0, 1, ..., n_vertices - 1]

for order in range(1, 3):  # This does not work above order = 2
    # Candidate paths: one column per path, one row per reflection.
    # NOTE(review): ``combinations`` only yields distinct triangles in
    # increasing index order, so (j, i) is never tried alongside (i, j).
    # [order num_paths]
    path_candidates = jnp.array(list(combinations(indices, order)), dtype=int).T
    num_paths = path_candidates.shape[1]
    from_vertices = jnp.tile(tx, (num_paths, 1))
    to_vertices = jnp.tile(rx, (num_paths, 1))

    # [order num_paths 3 3]: triangles hit, in order, by each candidate path.
    # Gathered once and reused for both the mirror points and the mask below.
    mirror_triangles = jnp.take(all_vertices, path_candidates, axis=0)
    # Any corner of a triangle lies on its plane, so use the first corner.
    # (Slicing with ``[..., 0]`` would instead take the x-coordinate of each
    # of the three corners, which is not a point of the mirror at all.)
    mirror_vertices = mirror_triangles[..., 0, :]
    mirror_normals = jnp.take(mesh.normals, path_candidates, axis=0)

    # [order num_paths 3]: one interaction point per reflection
    paths = image_method(from_vertices, to_vertices, mirror_vertices, mirror_normals)

    # Drop paths with any interaction point outside its triangle; presumably
    # the image method only places points on the (unbounded) mirror planes.
    # [order num_paths]
    mask = triangles_contain_vertices_assuming_inside_same_plane(
        mirror_triangles,
        paths,
    )
    # [num_paths]: a path is valid only if all of its reflections are
    mask = jnp.all(mask, axis=0)

    # [order+2 num_valid_paths 3]: tx, reflection points, rx
    full_paths = jnp.concatenate(
        (
            jnp.expand_dims(from_vertices[mask, :], axis=0),
            paths[:, mask, :],
            jnp.expand_dims(to_vertices[mask, :], axis=0),
        )
    )

    # One polyline trace per valid path.
    fig.add_traces(
        [
            go.Scatter3d(
                x=full_paths[:, i, 0],
                y=full_paths[:, i, 1],
                z=full_paths[:, i, 2],
                marker=dict(
                    size=0,
                    color="darkgreen",
                ),
                line=dict(color="green", width=3),
            )
            for i in range(full_paths.shape[1])
        ]
    )

# Pin the axis ranges of the 3D scene so the view does not auto-rescale
# to the extent of every trace added above.
fig.update_layout(
    scene={
        "xaxis": {"range": [-10, 10]},
        "yaxis": {"range": [-10, 30]},
        "zaxis": {"range": [0, 25]},
    },
)

fig  # last expression of the cell: the notebook displays the figure