In [None]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab

try:
    import differt  # noqa: F401
except ImportError:
    import sys  # noqa: F401

    !{sys.executable} -m pip install differt[all]

# Coherence Map

In [2]:
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int, PRNGKeyArray
from tqdm.notebook import tqdm

from differt.geometry.paths import merge_cluster_ids
from differt.geometry.utils import min_distance_between_clusters, path_lengths
from differt.plotting import draw_image, draw_markers, reuse, set_defaults
from differt.scene.sionna import download_sionna_scenes, get_sionna_scene
from differt.scene.triangle_scene import TriangleScene

In [3]:
# Let's download Sionna scenes (from the main branch), so that one of them can
# then be loaded with 'get_sionna_scene' in the next cell
download_sionna_scenes()

In [None]:
# Our scene is simple, so Plotly is the best backend for online interactive plots :-)
set_defaults("plotly")

file = get_sionna_scene("simple_street_canyon")
scene = TriangleScene.load_xml(file)

# We place the transmitter (TX) at a fixed position
scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-33.0, 0.0, 32.0]))

# Warning: a too large batch can easily cause OOM issues; if so, reduce the
# 'chunk_size' value below.
batch = (200, 200)
scene_grid = scene.with_receivers_grid(*batch)
power = jnp.zeros(batch)
cluster_ids = jnp.zeros(batch, dtype=jnp.int32)
has_multipath = jnp.zeros(batch, dtype=bool)

x, y, z = jnp.unstack(scene_grid.receivers, axis=-1)


with reuse() as fig:
    scene.plot()

    # Results are accumulated over path orders 0 and 1, chunk by chunk
    for order in range(2):
        for paths in scene_grid.compute_paths(order=order, chunk_size=1_000):
            new_cluster_ids = paths.multipath_clusters()
            has_multipath |= paths.mask.any(axis=-1)
            cluster_ids = merge_cluster_ids(cluster_ids, new_cluster_ids)
            # Each (unmasked) path contributes 1 / length**2 to the received power
            power += (paths.mask / path_lengths(paths.vertices) ** 2).sum(axis=-1)

    draw_image(
        np.asarray(power),
        x=np.asarray(x[0, :]),
        y=np.asarray(y[:, 0]),
        z0=float(z.ravel()[0]),
    )

# Receivers for which 'has_multipath' is False are given the cluster id -1,
# for easier identification
cluster_ids = jnp.where(has_multipath, cluster_ids, -1)

fig

In [None]:
def random_rgb(key: PRNGKeyArray) -> str:
    """Draw a random color and format it as a Plotly ``rgb(r,g,b)`` string.

    Args:
        key: The random key used to draw the three color channels.

    Returns:
        A string like ``"rgb(12,200,37)"``, each channel being in ``[0, 255]``.
    """
    red, green, blue = jax.random.randint(key, (3,), 0, 256, dtype=jnp.uint8).tolist()
    return f"rgb({red},{green},{blue})"


def create_discrete_colorscale(
    key: PRNGKeyArray,
    cluster_ids: Int[Array, " *batch"],
    first_is_multipath_cluster: bool,
) -> list[list[float | str]]:
    """Build a Plotly colorscale that gives one random, uniform color per cluster id.

    With ``n = 1 + max_id - min_id``, each unique id ``i`` owns the band
    ``[(i - min_id) / n, (i + 1 - min_id) / n]``. The band is delimited by two
    colorscale entries that share the same color.

    Args:
        key: The random key, split into one sub-key (i.e., one color) per unique id.
        cluster_ids: The array of cluster ids.
        first_is_multipath_cluster: Whether the smallest id is the cluster of
            receivers *without* multipath. If so, its band is made fully
            transparent, which hides it.

    Returns:
        The colorscale, as a list of ``[position, color]`` pairs.
    """
    ids = jnp.unique(cluster_ids).tolist()
    lowest, highest = min(ids), max(ids)
    num_bands = highest - lowest + 1
    sub_keys = jax.random.split(key, len(ids))

    colorscale: list[list[float | str]] = []
    for cluster_id, sub_key in zip(ids, sub_keys, strict=False):
        color = random_rgb(sub_key)
        start = cluster_id - lowest
        colorscale.append([start / num_bands, color])
        colorscale.append([(start + 1) / num_bands, color])

    if first_is_multipath_cluster:  # Let's hide clusters with no multipath
        for entry in colorscale[:2]:
            entry[1] = "rgba(0,0,0,0)"

    return colorscale


# We renumber unique indices to be between 0 and num_unique_cluster_ids (excluded).
# Because `jax.numpy.unique` sorts entries, the first renumbered id always refers to
# the 'no multipath' cluster (id -1), if it exists.
renumbered_cluster_ids = jnp.unique(cluster_ids, return_inverse=True)[1].reshape(
    cluster_ids.shape
)
# We create a discrete colorscale; the 'no multipath' cluster, if present
# (i.e., if some receivers have no multipath), is made transparent
key = jax.random.PRNGKey(1234)
colorscale = create_discrete_colorscale(
    key, renumbered_cluster_ids, first_is_multipath_cluster=bool(~has_multipath.all())
)

# The figure seems to be broken for order > 1, possibly because 'colorscale' is too big
with reuse("plotly") as fig:
    scene.plot()
    draw_image(
        np.asarray(renumbered_cluster_ids),
        x=np.asarray(x[0, :]),
        y=np.asarray(y[:, 0]),
        z0=float(z.ravel()[0]),
        colorscale=colorscale,
        hovertemplate="Cluster id: %{surfacecolor}",
        showscale=False,
    )

fig

In [None]:
length_x = x.max() - x.min()
length_y = y.max() - y.min()
surface = length_x * length_y
# On a regular grid, the first and last receivers along an axis with n receivers
# are (n - 1) spacings apart. The bounding box above therefore contains
# (n_rows - 1) * (n_cols - 1) tiles. Dividing by the number of receivers instead
# would underestimate the area represented by each receiver.
num_tiles = int(np.prod([max(n - 1, 1) for n in x.shape]))
surface_per_point = surface / num_tiles  # Assumes receivers lie on a regular grid

unique_ids, points_per_cluster = jnp.unique(cluster_ids, return_counts=True)
# We remove the -1 'cluster', i.e., receivers with no multipath
points_per_cluster = points_per_cluster[unique_ids != -1]
points_per_cluster

In [None]:
import plotly.express as px

# Area covered by each multipath cluster: the number of receivers it contains
# times the area represented by a single receiver.
surface_per_cluster = surface_per_point * points_per_cluster

labels = {
    "x": "Surface",
    "y": "Normalized number of clusters occupying a given surface",
}

# Histogram of cluster surfaces. Bars are placed at the bin centers, and their
# heights are normalized so that they sum to one.
counts, bins = np.histogram(surface_per_cluster, bins=30)
bins = (bins[1:] + bins[:-1]) / 2

px.bar(x=bins, y=counts / counts.sum(), labels=labels)

In [None]:
min_dist = min_distance_between_clusters(scene_grid.receivers, cluster_ids)

for cluster_id in jnp.unique(cluster_ids):
    # The -1 id only marks receivers with no multipath: it is not a real cluster,
    # so it is skipped, consistently with the surface statistics of this notebook
    if int(cluster_id) == -1:
        continue

    same_cluster = cluster_ids == cluster_id
    mean_min_dist = jnp.mean(min_dist, where=same_cluster)
    std_min_dist = jnp.std(min_dist, where=same_cluster)

    print(  # noqa: T201
        f"Cluster id = {int(cluster_id):5d} has an average minimal distance to next cluster of {float(mean_min_dist):5.2f} (std: {float(std_min_dist):.2f})"
    )

In [None]:
# TODO: animate this slider ? https://plotly.com/python/animations/


from plotly.subplots import make_subplots

num_tx_positions = 100  # Number of TX positions, i.e., of slider steps


def _compute_cluster_ids(
    grid: TriangleScene,
) -> tuple[Int[Array, " *batch"], Array]:
    """Compute the multipath cluster id of every receiver in ``grid``.

    Paths of order 0 and 1 are considered.

    Returns:
        The cluster ids, set to -1 where ``has_multipath`` is False, and the
        ``has_multipath`` boolean array itself.
    """
    grid_cluster_ids = jnp.zeros(batch, dtype=jnp.int32)
    grid_has_multipath = jnp.zeros(batch, dtype=bool)

    for order in range(2):
        for paths in grid.compute_paths(order=order, chunk_size=1_000):
            new_cluster_ids = paths.multipath_clusters()
            grid_has_multipath |= paths.mask.any(axis=-1)
            grid_cluster_ids = merge_cluster_ids(grid_cluster_ids, new_cluster_ids)

    return jnp.where(grid_has_multipath, grid_cluster_ids, -1), grid_has_multipath


def _mean_min_dist_per_cluster(
    receivers: Array, grid_cluster_ids: Int[Array, " *batch"]
) -> Array:
    """Return the mean minimal distance to the next cluster, for each cluster.

    The -1 id (no multipath) is excluded, and clusters are ordered by increasing
    id, as returned by `jax.numpy.unique`.
    """
    grid_min_dist = min_distance_between_clusters(receivers, grid_cluster_ids)
    unique_cluster_ids = jnp.unique(grid_cluster_ids)
    unique_cluster_ids = unique_cluster_ids[unique_cluster_ids != -1]

    # Building the array at once avoids a full array copy per cluster ('.at[i].set')
    return jnp.asarray(
        [
            jnp.mean(grid_min_dist, where=grid_cluster_ids == cluster_id)
            for cluster_id in unique_cluster_ids
        ],
        dtype=jnp.float32,
    )


fig = make_subplots(
    rows=1,
    cols=2,
    column_widths=[0.7, 0.3],
    specs=[[{"type": "scene"}, {"type": "xy"}]],
)

with reuse(figure=fig) as fig:
    scene.plot(tx_kwargs={"visible": False}, row=1, col=1)

    # Traces drawn by 'scene.plot' are shared by all slider steps. We record their
    # initial visibility (the static TX is hidden) so that every step restores it,
    # instead of assuming a fixed number and order of traces.
    base_visibility = [trace.visible is not False for trace in fig.data]
    num_base_traces = len(base_visibility)
    num_step_traces = 3  # TX marker, coherence map, and histogram

    x_positions = jnp.linspace(x.min(), x.max(), num_tx_positions)

    # The TX is moved on a separate variable, and results use new names. This
    # leaves 'scene_grid', 'cluster_ids', etc. from previous cells untouched, so
    # re-running previous cells after this one gives the same results.
    moving_scene_grid = scene_grid

    for x_pos in tqdm(x_positions, leave=False):
        moving_scene_grid = eqx.tree_at(
            lambda s: s.transmitters,
            moving_scene_grid,
            moving_scene_grid.transmitters.at[0].set(x_pos),
        )
        tx_cluster_ids, tx_has_multipath = _compute_cluster_ids(moving_scene_grid)
        tx_renumbered_cluster_ids = jnp.unique(tx_cluster_ids, return_inverse=True)[
            1
        ].reshape(tx_cluster_ids.shape)
        tx_colorscale = create_discrete_colorscale(
            key,
            tx_renumbered_cluster_ids,
            first_is_multipath_cluster=bool(~tx_has_multipath.all()),
        )

        draw_markers(
            np.asarray(moving_scene_grid.transmitters.reshape(-1, 3)),
            labels=["tx"],
            showlegend=False,
            visible=False,
            row=1,
            col=1,
        )

        draw_image(
            np.asarray(tx_renumbered_cluster_ids),
            x=np.asarray(x[0, :]),
            y=np.asarray(y[:, 0]),
            z0=float(z.ravel()[0]),
            colorscale=tx_colorscale,
            hovertemplate="Cluster id: %{surfacecolor}",
            showscale=False,
            visible=False,
            row=1,
            col=1,
        )

        tx_mean_min_dist = _mean_min_dist_per_cluster(
            moving_scene_grid.receivers, tx_cluster_ids
        )
        tx_counts, tx_bin_edges = np.histogram(tx_mean_min_dist, bins=30)
        tx_bin_centers = 0.5 * (tx_bin_edges[:-1] + tx_bin_edges[1:])

        fig.add_bar(
            x=tx_bin_centers,
            y=tx_counts / tx_counts.sum(),
            showlegend=False,
            visible=False,
            row=1,
            col=2,
        )

    # Each step shows the shared traces plus the three traces of one TX position
    steps = []
    for i in range(num_tx_positions):
        visible = base_visibility + [False] * (num_step_traces * num_tx_positions)
        start = num_base_traces + num_step_traces * i
        visible[start : start + num_step_traces] = [True] * num_step_traces
        steps.append({"method": "update", "args": [{"visible": visible}]})

    sliders = [
        {
            "active": 0,
            "currentvalue": {"prefix": "TX index: "},
            "pad": {"t": 50},
            "steps": steps,
        }
    ]

    # Initially, show the traces of the first step (matching "active": 0)
    for trace in fig.data[num_base_traces : num_base_traces + num_step_traces]:
        trace.visible = True

    fig.update_layout(
        height=600,
        sliders=sliders,
        xaxis={"range": [0, 10]},  # Hard-coded range for the histogram's x-axis
        yaxis={"range": [0, 1]},
    )

fig