# $\mathbb{W}_2$ Distance Map
In this notebook we compute the distance map in $\mathbb{W}_2$ with the $\mathbb{R}^2$ and $\mathbb{M}_2$ cost functions, corresponding to Figures 4b and 4d in ["Crossing-Preserving Geodesic Tracking on Spherical Images"](https://arxiv.org/abs/2504.03388v1). The cost function is selected with the `cost_domain` parameter (`"R2"` or `"M2"`) below.

In [None]:
# Standard library
from copy import deepcopy

# Third-party
import numpy as np
import taichi as ti

# Initialise taichi before any eikivp import (presumably eikivp expects an initialised
# taichi runtime). Request less memory than the VRAM on your device so as not to mix RAM and VRAM.
ti.init(arch=ti.gpu, debug=False, device_memory_GB=3.5)

# Local package
import eikivp
from eikivp.R2.vesselness import import_vesselness as import_vesselness_R2
from eikivp.M2.vesselness import import_vesselness as import_vesselness_M2
from eikivp.utils import cost_function
from eikivp.W2.costfunction import cost
from eikivp.W2.plus.distancemap import export_W

## Parameters

In [2]:
cost_domain = "R2"
image_name = "E46_OD_best"
image_file_name = f"data\{image_name}.tif"
match cost_domain:
    case "M2":
        σ_s_list = np.array((0.5**3, 0.5)) # np.array((1.5, 2.))
        σ_o = 0.5 * 0.75**2
        σ_s_ext = 1.
        σ_o_ext = 0.01
        V_params = {
            "σ_s_list": σ_s_list,
            "σ_o": σ_o,
            "σ_s_ext": σ_s_ext,
            "σ_o_ext": σ_o_ext,
            "image_name": image_name 
        }
        V = import_vesselness_M2(V_params, "storage\\vesselness")
        dim_I, dim_J, dim_K = V.shape
    case "R2":
        dim_K = 32
        scales = np.array((0.125, 0.5), dtype=float)
        α = 0.5/np.sqrt(2)
        γ = 3/4
        ε = np.sqrt(0.2)
        V_params = {
            "scales": scales,
            "α": α,
            "γ": γ,
            "ε": ε,
            "image_name": image_name 
        }
        V = import_vesselness_R2(V_params, "storage\\vesselness")
        dim_I, dim_J = V.shape
        V = np.array(dim_K * [V]).transpose(1, 2, 0)

In [3]:
# W2 grid: index arrays of shape (dim_I, dim_J, dim_K), mapped to real (α, β, φ) coordinates.
Is, Js, Ks = np.indices((dim_I, dim_J, dim_K))
a = 13 / 21  # NOTE(review): model constant passed to `cost` below; meaning not visible here
α_min, α_max = -0.837758, 0.837758
β_min, β_max = -0.962727, 0.962727
φ_min, φ_max = 0, 2 * np.pi
# α and β include both endpoints (divide by dim - 1); φ ∈ [0, 2π) excludes the
# endpoint (divide by dim_K), i.e. φ is treated as periodic.
dα = (α_max - α_min) / (dim_I - 1)
dβ = (β_max - β_min) / (dim_J - 1)
dφ = (φ_max - φ_min) / dim_K
αs, βs, φs = eikivp.W2.utils.coordinate_array_to_real(Is, Js, Ks, α_min, β_min, φ_min, dα, dβ, dφ)

# Grid of the cost domain (x, y, θ), whose parameters `cost` uses together with the W2 grid.
c = np.cos(np.pi/3)  # NOTE(review): model constant passed to `cost`; meaning not visible here
x_min, x_max = -0.866025, 0.866025
y_min, y_max = -0.866025, 0.866025
θ_min, θ_max = 0., 2 * np.pi
# A single spacing is used for x and y, derived from dim_I; this assumes the x–y
# grid is square (TODO confirm when dim_I != dim_J).
dxy = (x_max - x_min) / (dim_I - 1)
dθ = (θ_max - θ_min) / dim_K

In [4]:
# Cost parameters (λ, p are passed to cost_function) and the solver parameter ξ.
λ = 500
p = 2
ξ = 6.

# Source point in real W2 coordinates (α, β, φ), converted to grid indices.
source_point_real = (0.177528, 0.159588, 2.37002)
source_point = eikivp.W2.utils.coordinate_real_to_array(
    *source_point_real, α_min, β_min, φ_min, dα, dβ, dφ
)

# Parameters identifying the exported distance map: the vesselness parameters
# extended with the cost/solver settings (insertion order preserved).
# NOTE(review): "target_point" is recorded as "default" although the solver call
# below passes an explicit target_point — confirm this is intended.
W_params = deepcopy(V_params)
W_params.update({
    "λ": λ,
    "p": p,
    "ξ": ξ,
    "source_point": source_point,
    "target_point": "default",
    "cost_domain": cost_domain,
})

In [5]:
# Target point in real W2 coordinates (α, β, φ), converted to grid indices.
# It is passed to the solver as `target_point` and marked in the plots.
target_point_real = (-0.721357, 0.218753, 2.65495)
target_point = eikivp.W2.utils.coordinate_real_to_array(*target_point_real, α_min, β_min, φ_min, dα, dβ, dφ)

In [6]:
# Cost from the vesselness V with parameters λ and p. The result is named C_M2 even
# when cost_domain == "R2" (V then has the stacked orientation axis).
C_M2 = cost_function(V, λ, p)
# Presumably evaluates that cost on the W2 grid (αs, βs, φs), using the
# cost-domain grid parameters (x_min, y_min, θ_min, dxy, dθ) and constants a, c.
C = cost(C_M2, αs, βs, φs, a, c, x_min, y_min, θ_min, dxy, dθ)

In [None]:
# Cost minimised over orientation, with source and target marked. Each point gets an
# arrow of length 0.1 whose direction is set by its orientation φ.
fig, ax, _ = eikivp.visualisations.plot_image_array_W2(C.min(-1), α_min, α_max, β_min, β_max)
for point_label, point_real, point_index in (
    ("Source", source_point_real, source_point),
    ("Target", target_point_real, target_point),
):
    # The slice [1::-1] turns (α, β, φ) into (β, α), used as the plot's (x, y).
    β_α = point_real[1::-1]
    φ_point = φs[point_index]
    ax.scatter(*β_α, label=point_label)
    ax.arrow(*β_α, 0.1 * np.sin(φ_point), 0.1 * np.cos(φ_point), width=0.01)
ax.legend();

## Compute Distance Map

In [None]:
# Solve the eikonal problem on W2 for cost C from source_point. The outputs are
# presumably the distance map W and its gradient grad_W (their names suggest this).
W, grad_W = eikivp.eikonal_solver_W2_plus(
    C,
    source_point,
    ξ,
    dα,
    dβ,
    dφ,
    αs,
    φs,
    target_point=target_point,
    n_max=1e4,
    n_max_initialisation=1e4,
    n_check=2e3,
    n_check_initialisation=2e3,
    tol=1e-3,
    initial_condition=200.,
)

In [None]:
# Contours of W minimised over orientation, drawn on top of the cost minimised over orientation.
fig, ax, _ = eikivp.visualisations.plot_image_array_W2(
    C.min(-1), α_min, α_max, β_min, β_max, figsize=(12, 10)
)
# Five contour levels from 0 up to 2.5 times the distance at the target.
max_distance = W[target_point] * 2.5
_, _, contour = eikivp.visualisations.plot_contour_W2(
    W.min(-1),
    αs[..., 0],
    βs[..., 0],
    fig=fig,
    ax=ax,
    levels=np.linspace(0., max_distance, 5),
)
for point_label, point_real in (("Source", source_point_real), ("Target", target_point_real)):
    # (α, β, φ) -> (β, α) as plot (x, y).
    ax.scatter(*point_real[1::-1], label=point_label)
ax.set_aspect("equal")
fig.colorbar(contour, label="$\\min_φ W(α, β, φ)$")
ax.legend();

In [10]:
# Export the distance map W and its gradient grad_W, together with W_params, under storage\distance.
export_W(W, grad_W, W_params, "storage\\distance")