# IterativeEikonal
Approximate the viscosity solution of the Eikonal equation
$$\begin{cases} \Vert \nabla_G W(g) \Vert = 1, & g \in G \setminus S, \\
W(g) = 0, & g \in S \subset G. \end{cases}$$
This method is based on the paper ["A PDE Approach to Data-Driven Sub-Riemannian Geodesics in $SE(2)$" (2015) by E. J. Bekkers, R. Duits, A. Mashtakov, and G. R. Sanguinetti](https://doi.org/10.1137/15M1018460).

In [None]:
import iterativeeikonal as eik
import taichi as ti
ti.init(arch=ti.gpu, debug=False)
import numpy as np
import scipy as sp
import diplib as dip
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib widget

## $\mathbb{R}^2$

### Flat space
Here we solve the Eikonal equation on the subset $G = [-1, 1] \times [-1, 1] \subset \mathbb{R}^2$, with a single source point: $S = \{(0, 0)\}$. In this case the viscosity solution is known exactly: it is the Euclidean distance to the origin, $W(x, y) = \sqrt{x^2 + y^2}$.

There are numerous ways to numerically solve the Eikonal equation. We could use [Fast Marching](https://en.wikipedia.org/wiki/Fast_marching_method). This is an efficient method, but it requires a lot of work to make it compatible with non-Euclidean domains. 

In [None]:
def fast_marching_R2(f):
    """Solve |grad W| = 1 / f on [-1, 1] x [-1, 1] with fast marching.

    The N x N grid matches ``np.linspace(-1, 1, N)`` along both axes, so the
    spacing is ``2 / (N - 1)``. The source (W = 0) is the grid point
    ``(N + 1) // 2`` in padded coordinates, i.e. the centre for odd N.

    ``point_stages`` holds 0 = accepted (final), 1 = trial, 2 = far. The
    one-cell padding frame is marked accepted with W = 100, so it is never
    updated. Trial points are kept in a binary heap keyed on ``(W, i, j)``.
    This selects the same point as an ``argmin`` over all trial points
    (smallest W, ties broken in row-major order) without scanning the whole
    grid on every step.

    Note: the iterative retinal solver below drives |grad W| towards ``cost``,
    so the equivalent problem here uses ``f = 1 / cost``.

    Args:
        f: square 2D array of speeds, one value per grid point.

    Returns:
        The (N, N) array of arrival times W.
    """
    import heapq

    N = f.shape[0]
    f = pad_array(f, 1.)
    # Spacing of np.linspace(-1, 1, N): N points including both end points.
    dxy = 2. / (N - 1)
    point_stages = pad_array(2 * np.ones((N, N), dtype=int), 0)
    W = pad_array(np.full(shape=(N, N), fill_value=100.), 100.)
    i_star = j_star = (N + 1) // 2
    point_stages[i_star, j_star] = 0
    W[i_star, j_star] = 0
    trial_heap = []
    while True:
        # Turn the non-accepted neighbours of the last accepted point into
        # trial points, recompute their W, and queue them with the new value.
        update_neighbours(i_star, j_star, point_stages, W, dxy, f)
        for i, j in sees_point(i_star, j_star):
            if point_stages[i, j] == 1:
                heapq.heappush(trial_heap, (W[i, j], i, j))
        # Pop until an up-to-date trial entry is found. Entries whose point has
        # since been accepted, or whose W has since changed, are stale
        # (lazy deletion).
        while trial_heap:
            w, i, j = heapq.heappop(trial_heap)
            if point_stages[i, j] == 1 and w == W[i, j]:
                break
        else:
            # No trial points left: every reachable point has been accepted.
            return unpad_array(W)
        i_star, j_star = i, j
        point_stages[i_star, j_star] = 0

def pad_array(array, fill_value):
    """Return a copy of the 2D `array` framed by a one-cell border of `fill_value`.

    The result has shape (rows + 2, cols + 2). Its dtype is that of
    ``ones(dtype=array.dtype) * fill_value``, so an integer array padded with
    a float becomes a float array.
    """
    rows, cols = array.shape[0], array.shape[1]
    framed = fill_value * np.ones((rows + 2, cols + 2), dtype=array.dtype)
    framed[1:rows + 1, 1:cols + 1] = array
    return framed

def unpad_array(padded_array):
    """Strip the one-cell border added by `pad_array`, returning a view of the interior."""
    interior = (slice(1, -1), slice(1, -1))
    return padded_array[interior]

def update_neighbours(i_star, j_star, point_stages, W, dxy, f):
    """Mark the non-accepted 4-neighbours of (i_star, j_star) as trial and refresh their W.

    A neighbour with stage 0 (accepted) is left untouched. Any other neighbour
    has its stage set to 1 (trial), and W at that point is recomputed in place
    by `update_W`.
    """
    for neighbour in sees_point(i_star, j_star):
        if point_stages[neighbour] == 0:
            continue  # accepted points are frozen
        point_stages[neighbour] = 1
        update_W(*neighbour, W, dxy, f)

def update_W(i, j, W, dxy, f):
    """Recompute W[i, j] in place with the first-order upwind Eikonal update.

    Solves the discretised equation |grad W| = 1 / f[i, j] at grid point
    (i, j). It uses the smallest neighbour value along axis 0 (Wx) and along
    axis 1 (Wy), as returned by `gradient_W`. With local step cost
    h = dxy / f[i, j]:

    * if |Wx - Wy| >= h, only the smaller neighbour is upwind and
      W = min(Wx, Wy) + h;
    * otherwise W is the larger root of (W - Wx)^2 + (W - Wy)^2 = h^2.

    The threshold must be h, not dxy. For |Wx - Wy| < h the quadratic has a
    real root that is at least max(Wx, Wy). Comparing against dxy instead can
    take the square root of a negative number (NaN), or give a value below
    max(Wx, Wy), when f[i, j] > 1. It overestimates W when f[i, j] < 1.
    """
    Wx, Wy = gradient_W(i, j, W)
    # Cost of one grid step at this point.
    h = dxy / f[i, j]
    if np.abs(Wx - Wy) >= h:
        W[i, j] = min(Wx, Wy) + h
    else:
        W[i, j] = (Wx + Wy + np.sqrt((Wx + Wy) ** 2 - 2 * (Wx ** 2 + Wy ** 2 - h ** 2))) / 2

def gradient_W(i, j, W):
    """Return the smallest neighbour value of W[i, j] along each axis.

    Despite the name, no gradient is computed. The first value is
    min(W[i+1, j], W[i-1, j]) (axis 0) and the second is
    min(W[i, j+1], W[i, j-1]) (axis 1). These are the upwind values used by
    `update_W`.
    """
    along_axis_0 = (W[i + 1, j], W[i - 1, j])
    along_axis_1 = (W[i, j + 1], W[i, j - 1])
    return min(along_axis_0), min(along_axis_1)

def sees_point(i_star, j_star):
    """Return the four axis-aligned neighbours of (i_star, j_star).

    Order: +1 along axis 0, -1 along axis 0, +1 along axis 1, -1 along axis 1.
    """
    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))
    return tuple((i_star + di, j_star + dj) for di, dj in offsets)

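Alternatively, we can apply an iterative method developed by [Bekkers et al.](https://doi.org/10.1137/15M1018460). It repeatedly applies the explicit update $W \leftarrow W + \varepsilon \left(1 - \Vert \nabla W \Vert\right)$, re-imposing $W = 0$ on $S$ after every step; a steady state of this update satisfies $\Vert \nabla W \Vert = 1$.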

In [None]:
def iterative_method_flat_R2(N, n):
    """Discretise [-1, 1] x [-1, 1] into `N` points in each direction, and apply the iterative solution method `n` times.

    The grid matches ``np.linspace(-1, 1, N)`` along each axis, so the spacing
    is ``2 / (N - 1)``. The source is the grid point ``(N + 1) // 2`` in padded
    coordinates (the centre for odd N). ``W = 0`` is re-imposed there after
    every explicit step of ``step_W``.

    NOTE: this relies on the flat-space versions of ``get_boundary_conditions``,
    ``get_initial_derivatives`` and ``step_W`` defined in this cell. The
    retinal cell further down redefines these names with different
    signatures, so re-run this cell before calling this function again after
    that cell has run.

    Returns:
        The distance map as a NumPy array, with the padding removed by
        ``eik.cleanarrays.unpad_array`` (presumably of shape ``(N, N)``).
    """
    # Spacing of np.linspace(-1, 1, N): N points including both end points.
    dxy = 2. / (N - 1)
    # Explicit time step, taken as a fixed fraction of the grid spacing.
    eps = dxy / 4

    W = get_initial_W(N)
    boundarypoints, boundaryvalues = get_boundary_conditions(N)

    eik.cleanarrays.apply_boundary_conditions(W, boundarypoints, boundaryvalues)

    dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy = get_initial_derivatives(W)
    for _ in tqdm(range(n)):
        step_W(W, dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy, dxy, eps)
        eik.cleanarrays.apply_boundary_conditions(W, boundarypoints, boundaryvalues)

    W_np = W.to_numpy()
    return eik.cleanarrays.unpad_array(W_np)

def get_initial_W(N, initial_condition=100.):
    """Create the padded Taichi field that holds the initial guess for W.

    An N x N array filled with `initial_condition` is padded with the same
    value (``eik.cleanarrays.pad_array`` with ``pad_shape=1``). It is then
    copied into a new f32 Taichi field of the padded shape.
    """
    padded = eik.cleanarrays.pad_array(
        np.full((N, N), initial_condition), pad_value=initial_condition, pad_shape=1
    )
    field = ti.field(dtype=ti.f32, shape=padded.shape)
    field.from_numpy(padded)
    return field

def get_boundary_conditions(N):
    """Return Taichi fields that pin W to 0 at the single grid point ((N + 1) // 2, (N + 1) // 2).

    The index refers to the one-pixel-padded W field; for odd N it is the
    centre of the unpadded grid. Returns (points, values): a length-1 field of
    i32 2-vectors and a length-1 f32 field holding 0.
    """
    centre = (N + 1) // 2
    points = ti.Vector.field(n=2, dtype=ti.i32, shape=1)
    points.from_numpy(np.array([[centre, centre]], dtype=int))
    values = ti.field(shape=1, dtype=ti.f32)
    values.from_numpy(np.array([0.], dtype=float))
    return points, values

def get_initial_derivatives(W):
    """Allocate six f32 Taichi work fields with the same shape as `W`.

    Returned in the order (dx_forward, dx_backward, dy_forward, dy_backward,
    abs_dx, abs_dy). They are scratch space handed to
    ``eik.derivativesR2.abs_derivatives`` by ``step_W``.
    """
    return tuple(ti.field(dtype=ti.f32, shape=W.shape) for _ in range(6))

# One explicit iteration of the iterative Eikonal scheme on flat R^2:
#     W <- W + eps * (1 - sqrt(abs_dx**2 + abs_dy**2)),
# applied at every index of W (the padding included).
# The derivative fields are scratch space, refilled on every call by
# eik.derivativesR2.abs_derivatives. Presumably abs_dx / abs_dy hold upwind
# |dW/dx| and |dW/dy| (TODO confirm against eik.derivativesR2).
# A steady state of this update satisfies |grad W| = 1.
@ti.kernel
def step_W(
    W:ti.template(), 
    dx_forward: ti.template(), 
    dx_backward: ti.template(), 
    dy_forward: ti.template(), 
    dy_backward: ti.template(), 
    abs_dx: ti.template(), 
    abs_dy: ti.template(), 
    dxy: ti.f32, 
    eps: ti.f32
):
    eik.derivativesR2.abs_derivatives(W, dxy, dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy)
    for I in ti.grouped(W):
        # Move W towards the solution of |grad W| = 1 with time step eps.
        W[I] += (1 - ti.math.sqrt(abs_dx[I] ** 2 + abs_dy[I] ** 2)) * eps

In [None]:
N, n = 51, 250  # grid points per axis, iterations of the iterative method
xs, ys = np.meshgrid(np.linspace(-1, 1, N), np.linspace(-1, 1, N))
# Exact viscosity solution for a source at the origin: the distance to it.
W_exact = np.sqrt(xs ** 2 + ys ** 2)
# Unit speed everywhere, so fast marching should also recover the distance.
f = np.ones_like(W_exact)
W_iterative_big = iterative_method_flat_R2(N, n)
W_fast_marching = fast_marching_R2(f)

In [None]:

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 5))
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.resizable = False
# Contour the exact solution first and reuse its levels for both numerical
# solutions. Otherwise each call picks its own levels from its own data range,
# and the three sets of curves are not drawn at the same values of W.
contour = ax.contour(xs, ys, W_exact)
ax.contour(xs, ys, W_fast_marching, levels=contour.levels, linestyles="dotted")
ax.contour(xs, ys, W_iterative_big, levels=contour.levels, linestyles="dashed")
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("solid: exact, dotted: fast marching, dashed: iterative")
fig.colorbar(contour, label="W(x, y)");

### Retinal image
Computing the distance map on flat $\mathbb{R}^2$ is not very impressive. We would also like to handle spaces that are not flat, such as the data-driven geometry of a retinal image. To track the vessels in a retinal image, it makes sense to assign a low cost to the parts of the image that are vessels and a high cost to the parts that are not. We must therefore first compute such a cost function.

To do this, we use Frangi vesselness filtering.

#### Frangi vesselness filtering

In [None]:
ds = 8  # downsampling factor along each axis
# Open the image in a context manager so that the file handle is released.
# resize/convert return new in-memory images, so they remain usable after
# the file is closed.
with Image.open("E46_OD_best.tif") as retinal_image:
    width, height = retinal_image.size
    retinal_image_gray_ds = retinal_image.resize((width // ds, height // ds)).convert("L")
retinal_array_unnormalised = np.array(retinal_image_gray_ds).astype(np.float64)
# Scale the grey values so that the maximum becomes 1.
retinal_array = retinal_array_unnormalised / retinal_array_unnormalised.max()

Frangi vesselness filtering uses Gaussian derivatives. It turns out that the standard SciPy function for computing Gaussian derivatives, `scipy.ndimage.gaussian_filter()`, is rather inaccurate, especially at small scales: the DIPlib derivative computed below is much sharper. We therefore use the Gaussian derivatives provided by another package, DIPlib.

In [None]:
# Compare the second-order Gaussian derivative of the negated image, with
# sigma = 0.25, computed by SciPy and by DIPlib.
# SciPy's `order` is given per NumPy axis, so (0, 2) differentiates twice
# along axis 1 (the columns).
# NOTE(review): DIPlib orders dimensions as (x, y), reversed w.r.t. NumPy, so
# its derivative order (2, 0) presumably requests the same derivative.
# Confirm in the DIPlib docs.
Lxx_sp = sp.ndimage.gaussian_filter(-retinal_array, sigma=0.25, order=(0, 2))
print(f"Lxx[SciPy] in [{Lxx_sp.min()}, {Lxx_sp.max()}]")
# Rescale to [0, 1]: shift the minimum to 0, then divide by the new maximum.
Lxx_sp_shift = Lxx_sp - Lxx_sp.min()
Lxx_sp = Lxx_sp_shift / Lxx_sp_shift.max()
Lxx_dip = np.array(dip.Gauss(-retinal_array, (0.25, 0.25), (2, 0)))
print(f"Lxx[DIPlib] in [{Lxx_dip.min()}, {Lxx_dip.max()}]")
# Same rescaling to [0, 1] for the DIPlib result.
Lxx_dip_shift = Lxx_dip - Lxx_dip.min()
Lxx_dip = Lxx_dip_shift / Lxx_dip_shift.max()

In [None]:
# Show the rescaled SciPy and DIPlib results side by side; the helper's second return value is kept as `fig`.
_, fig = eik.cleanarrays.view_image_arrays_side_by_side((Lxx_sp, Lxx_dip))

In [None]:
# Scales for the multiscale Frangi filter, in downsampled pixels (not divided by ds).
scales = np.array([0.25, 0.5, 1., 2., 3., 4.], dtype=float)
# Pixels with a non-zero grey value; eroded below to drop the region boundary.
mask = retinal_array > 0
# The image is negated, presumably so that dark vessels give a positive response.
vesselness = eik.costfunctions.multiscale_frangi_filter_R2(-retinal_array, 0.3, 0.3, scales)
# Zero the response within 4 erosion steps of the boundary of the mask.
vesselness *= sp.ndimage.binary_erosion(mask, iterations=4)
print(f"Vesselness in [{vesselness.min()}, {vesselness.max()}]")
vesselness /= vesselness.max()

In [None]:
# Show the vesselness map with the source and target pixels overlaid.
image_vesselness = eik.cleanarrays.convert_array_to_image(vesselness)
fig, ax = plt.subplots()
ax.imshow(image_vesselness, cmap="gray", origin="lower")
# scatter takes (x, y) = (column, row). These are the (row, column) points
# source_point = (217, 278) and target_point = (211, 118) used further down,
# with the two coordinates swapped.
ax.scatter(278, 217, label="Source")
ax.scatter(118, 211, label="Target")
ax.legend();

Compared to the result produced by the Frangi filter in the Mathematica notebook "code A - Vesselness in SE(2).nb", this vesselness score is much sharper.

#### Cost function
Given the vesselness function, we compute the cost function as
$$\mathrm{cost}_{\lambda, p}(x, y) = \frac{1}{1 + \lambda \vert \mathrm{vesselness}(x, y) \vert^p}.$$

In [None]:
# Cost from the vesselness; presumably cost_{lambda, p} with lambda = 1000 and
# p = 5 (TODO confirm the argument order of eik.costfunctions.cost_function_R2).
cost = eik.costfunctions.cost_function_R2(vesselness, 1000, 5)
# By the formula above, 1 / cost = 1 + lambda * |vesselness|^p, so this image
# should look like the vesselness map.
_, fig = eik.cleanarrays.view_image_array(1 / cost)
# Disabled: rescale so that the minimum cost becomes 1.
# cost /= cost.min()

#### Distance map
We are now able to compute the distance map, which is the viscosity solution of the Eikonal equation
$$\begin{cases} \Vert \nabla_{\mathrm{cost}} W(x, y) \Vert = 1, & (x, y) \in \mathbb{R}^2 \setminus \{(x_0, y_0)\}, \\
W(x, y) = 0, & (x, y) = (x_0, y_0), \end{cases}$$
where $\nabla_{\mathrm{cost}}$ denotes the data-driven gradient, whose norm is
$$\Vert \nabla_{\mathrm{cost}} W(x, y) \Vert = \sqrt{\mathrm{cost}_{\lambda, p}^{-2}(x, y) (\vert \partial_x W(x, y) \vert^2 + \vert \partial_y W(x, y) \vert^2)}.$$

In [None]:
def iterative_method_retinal_R2(cost_np, source_point, target_point, tol=1e-4, n_max=30000):
    """Approximate the cost-weighted distance map from `source_point` on an image grid.

    Repeatedly applies the explicit update ``W <- W + eps * (1 - |grad W| / cost)``
    (see ``step_W``) and re-imposes ``W = 0`` at the source after every step.
    It stops once ``|dW/dt|`` at `target_point` is at most `tol`, or after
    `n_max` steps. All internal arrays carry a one-pixel padding, so both points
    are shifted by +1 before they are used as indices.

    Args:
        cost_np: 2D NumPy array with one cost value per pixel (need not be square).
        source_point: (row, column) index of the source pixel in `cost_np`.
        target_point: (row, column) index of the pixel at which convergence is
            monitored.
        tol: stop once ``|dW/dt|`` at the target is at most this value. It has
            been observed to stall around 2.2e-4, in which case the loop stops
            at `n_max` instead.
        n_max: maximum number of steps.

    Returns:
        The distance map as a NumPy array, with the padding removed by
        ``eik.cleanarrays.unpad_array`` (presumably the shape of `cost_np`).

    NOTE: relies on the versions of ``get_padded_cost``, ``get_boundary_conditions``,
    ``get_initial_derivatives`` and ``step_W`` defined in this cell. The flat-space
    cell above defines functions with the same names but different signatures.
    """
    N = cost_np.shape[0]
    # TODO: dxy has no physical meaning yet (how big is one pixel?).
    dxy = 2. / (N + 1)
    # TODO: the time step probably has to depend on the data.
    eps = dxy / 500

    cost = get_padded_cost(cost_np)

    # Initial condition W = 100, allocated with the full (possibly non-square)
    # shape of the cost array plus the one-pixel padding.
    W_np = eik.cleanarrays.pad_array(np.full(cost_np.shape, 100.), pad_value=100., pad_shape=1)
    W = ti.field(dtype=ti.f32, shape=W_np.shape)
    W.from_numpy(W_np)

    # Shift image indices by +1 to account for the padding, for both points.
    i_source, j_source = source_point[0] + 1, source_point[1] + 1
    i_target, j_target = target_point[0] + 1, target_point[1] + 1

    boundarypoints, boundaryvalues = get_boundary_conditions((i_source, j_source))
    eik.cleanarrays.apply_boundary_conditions(W, boundarypoints, boundaryvalues)

    dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy, dW_dt = get_initial_derivatives(W)
    dW_dt_target = 100.
    n = 0
    while abs(dW_dt_target) > tol and n < n_max:
        dW_dt_target = step_W(
            W, cost, dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy,
            dxy, eps, dW_dt, i_target, j_target,
        )
        eik.cleanarrays.apply_boundary_conditions(W, boundarypoints, boundaryvalues)
        n += 1

    if abs(dW_dt_target) <= tol:
        print(f"Converged after {n} steps!")
    else:
        print(f"Did not converge within {n_max} steps.")
    print(f"dW/dt at target: {dW_dt_target}")

    return eik.cleanarrays.unpad_array(W.to_numpy())

def get_padded_cost(cost_unpadded):
    """Copy the cost array, padded with a border of 1.0, into a new f32 Taichi field.

    Padding uses ``eik.cleanarrays.pad_array`` with ``pad_shape=1``.
    """
    padded = eik.cleanarrays.pad_array(cost_unpadded, pad_value=1., pad_shape=1)
    field = ti.field(dtype=ti.f32, shape=padded.shape)
    field.from_numpy(padded)
    return field

def get_initial_W(N, initial_condition=100.):
    """Create the padded Taichi field that holds the initial guess for W.

    Identical to the flat-space version above. An N x N array filled with
    `initial_condition` is padded with the same value
    (``eik.cleanarrays.pad_array`` with ``pad_shape=1``) and copied into a new
    f32 Taichi field of the padded shape.
    """
    padded = eik.cleanarrays.pad_array(
        np.full((N, N), initial_condition), pad_value=initial_condition, pad_shape=1
    )
    field = ti.field(dtype=ti.f32, shape=padded.shape)
    field.from_numpy(padded)
    return field

def get_boundary_conditions(source_point):
    """Build the Taichi fields that pin W to 0 at a single grid index.

    `source_point` (a pair of ints) is copied as-is, so it must already be an
    index into the field the conditions are applied to. Returns
    (points, values): a length-1 field of i32 2-vectors and a length-1 f32
    field holding 0.
    """
    row, col = source_point
    points = ti.Vector.field(n=2, dtype=ti.i32, shape=1)
    points.from_numpy(np.array([[row, col]], dtype=int))
    values = ti.field(shape=1, dtype=ti.f32)
    values.from_numpy(np.array([0.], dtype=float))
    return points, values

def get_initial_derivatives(W):
    """Allocate seven f32 Taichi work fields with the same shape as `W`.

    Returned in the order (dx_forward, dx_backward, dy_forward, dy_backward,
    abs_dx, abs_dy, dW_dt). ``step_W`` writes dW/dt into the last field and
    hands the others to ``eik.derivativesR2.abs_derivatives``.
    """
    return tuple(ti.field(dtype=ti.f32, shape=W.shape) for _ in range(7))

# One explicit iteration of the data-driven iterative Eikonal scheme:
#     dW/dt = 1 - sqrt(abs_dx**2 + abs_dy**2) / cost,    W <- W + eps * dW/dt,
# evaluated at every index of W (the padding included). The per-point update
# is stored in dW_dt.
# Returns dW/dt at (i_target, j_target), which the caller uses as a
# convergence indicator. A steady state satisfies |grad W| = cost.
# The derivative fields are scratch space refilled by
# eik.derivativesR2.abs_derivatives. Presumably abs_dx / abs_dy are upwind
# |dW/dx| and |dW/dy| (TODO confirm).
@ti.kernel
def step_W(
    W:ti.template(), 
    cost:ti.template(),
    dx_forward: ti.template(), 
    dx_backward: ti.template(), 
    dy_forward: ti.template(), 
    dy_backward: ti.template(), 
    abs_dx: ti.template(), 
    abs_dy: ti.template(), 
    dxy: ti.f32, 
    eps: ti.f32,
    dW_dt: ti.template(),
    i_target: ti.i32,
    j_target: ti.i32
) -> ti.f32:
    eik.derivativesR2.abs_derivatives(W, dxy, dx_forward, dx_backward, dy_forward, dy_backward, abs_dx, abs_dy)
    for I in ti.grouped(W):
        # Divide by cost rather than multiplying by cost ** -1: Taichi appeared not to allow negative exponents.
        dW_dt[I] = (1 - (ti.math.sqrt((abs_dx[I] ** 2 + abs_dy[I] ** 2)) / cost[I]))
        W[I] += dW_dt[I] * eps
    return dW_dt[i_target, j_target] # Returning this value was observed to make the kernel slightly slower.

The quality of the approximation spreads out from the source point. Hence, if we are interested only in the distance map to points that are not that far away (in terms of the true distance map), then we do not have to perform as many iterations.

In [None]:
source_point = (217, 278)  # (row, column) in the downsampled image
target_point = (211, 118)  # (row, column) in the downsampled image
# A distance map over the entire image reportedly needs about 200000 iterations.
W_iterative = iterative_method_retinal_R2(cost, source_point, target_point)
# x varies along the columns and y along the rows. The grids then have the
# same (rows, columns) shape as W_iterative, also for non-square images.
xs_retinal, ys_retinal = np.meshgrid(np.linspace(-1, 1, cost.shape[1]), np.linspace(-1, 1, cost.shape[0]))
# Fast marching (fast_marching_R2) is too slow on a grid of this size, so it is not run here.
# Inspect W in the 5 x 5 window centred on the source point.
W_iterative[source_point[0] - 2:source_point[0] + 3, source_point[1] - 2:source_point[1] + 3]

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 5))
# Hide the interactive widget chrome (`%matplotlib widget`) and fix the size.
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.resizable = False
# Vesselness as background, stretched over [-1, 1] x [-1, 1] to share the
# coordinates of xs_retinal / ys_retinal; origin="lower" puts row 0 at the bottom.
ax.imshow(vesselness, cmap="gray", extent=[-1, 1, -1, 1], origin="lower")
contour = ax.contour(xs_retinal, ys_retinal, W_iterative, levels=np.linspace(0., .15, 5)) # TODO: five hand-picked levels in [0, 0.15]; the plot does not look great yet.
# Mark the source and target; the coordinate grids are indexed [row, column].
ax.scatter(xs_retinal[source_point[0], source_point[1]], ys_retinal[source_point[0], source_point[1]], label="Source")
ax.scatter(xs_retinal[target_point[0], target_point[1]], ys_retinal[target_point[0], target_point[1]], label="Target")
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.legend()
fig.colorbar(contour, label="$W(x, y)$");

### Geodesic tracking
Now that we have the distance map (from a certain source point), we are able to compute the geodesics between the source point and any other point $(x^*, y^*)$ using backtracking:
$$\begin{dcases} \dot{\gamma}(t) = -\left(\frac{\partial_x W(\gamma(t))}{\mathrm{cost}_{\lambda, p}^2(\gamma(t))}, \frac{\partial_y W(\gamma(t))}{\mathrm{cost}_{\lambda, p}^2(\gamma(t))}\right), & t > 0, \\
\gamma(0) = (x^*, y^*). & \end{dcases}$$

Target: $(211, 118)$ in (row, column) array indices, i.e. the point plotted at $(x, y) = (118, 211)$ in the vesselness image above.