In [40]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage


In [60]:
import jax
import jax.numpy as jnp
from jax import jit

@jit 
def edt_jax(binary_input: jnp.ndarray) -> jnp.ndarray:
    """Compute Euclidean distance transform of binary input."""
    # Initialize distance grid
    d = jnp.where(binary_input == 1, jnp.inf, 0.0)

    # Process each axis
    for axis in range(binary_input.ndim):
        # Move current axis to front for processing
        d = jnp.moveaxis(d, axis, 0)
        # Process each 1D line along current axis
        u = jnp.arange(d.shape[0])[:, None, None]  # Make broadcastable
        
        # For each position, compute distance to every possible target
        # and take minimum
        d_line = jnp.min((u - jnp.arange(d.shape[0])[None, :, None]) ** 2 + d, axis=1)
        d = jnp.moveaxis(d_line, 0, axis)

    return jnp.sqrt(d)

# Test
x = jnp.ones((5, 5))
x = x.at[2, 2].set(0) 
print(edt_jax(x))

[[2.828427  2.2360678 2.        2.2360678 2.828427 ]
 [2.2360678 1.4142135 1.        1.4142135 2.2360678]
 [2.        1.        0.        1.        2.       ]
 [2.2360678 1.4142135 1.        1.4142135 2.2360678]
 [2.828427  2.2360678 2.        2.2360678 2.828427 ]]


In [63]:
@jit 
def edt_jax(binary_input: jnp.ndarray, sampling: tuple = None) -> jnp.ndarray:
    """Compute Euclidean distance transform of binary input with per-dimension sampling."""
    d = jnp.where(binary_input == 1, jnp.inf, 0.0)
    
    sampling = sampling or (1.0,) * binary_input.ndim

    for axis in range(binary_input.ndim):
        d = jnp.moveaxis(d, axis, 0)
        u = jnp.arange(d.shape[0])[:, None, None] * sampling[axis]
        d_line = jnp.min((u - jnp.arange(d.shape[0])[None, :, None] * sampling[axis]) ** 2 + d, axis=1)
        d = jnp.moveaxis(d_line, 0, axis)

    return jnp.sqrt(d)

# Test with different sampling per dimension
x = jnp.ones((5, 5))
x = x.at[2, 2].set(0)
print(edt_jax(x, sampling=(1.0, 2.0)))  # Y dimension stretched by 2x

[[4.4721355 2.828427  2.        2.828427  4.4721355]
 [4.1231055 2.2360678 1.        2.2360678 4.1231055]
 [4.        2.        0.        2.        4.       ]
 [4.1231055 2.2360678 1.        2.2360678 4.1231055]
 [4.4721355 2.828427  2.        2.828427  4.4721355]]


In [65]:
print(edt_jax(a, sampling=(1.0, 2.0))) 

[[0.        1.        2.2360678 3.6055512 3.       ]
 [0.        0.        2.        2.828427  2.       ]
 [0.        1.        2.2360678 2.        1.       ]
 [0.        2.        2.2360678 1.        0.       ]
 [0.        2.        2.        0.        0.       ]]


In [66]:
ndimage.distance_transform_edt(a, sampling=(1.0, 2.0))

array([[0.        , 1.        , 2.23606798, 3.60555128, 3.        ],
       [0.        , 0.        , 2.        , 2.82842712, 2.        ],
       [0.        , 1.        , 2.23606798, 2.        , 1.        ],
       [0.        , 2.        , 2.23606798, 1.        , 0.        ],
       [0.        , 2.        , 2.        , 0.        , 0.        ]])

In [67]:
# Test with 3D array
x = jnp.ones((4, 4, 4))
x = x.at[2, 2, 2].set(0)  # Single point at center
print(edt_jax(x, sampling=(1.0, 2.0, 3.0)))  # Different sampling per dimension

ValueError: axis 2 is out of bounds for array of dimension 2

In [68]:
import jax
import jax.numpy as jnp
from jax import jit

@jit 
def edt_jax(binary_input: jnp.ndarray, sampling: tuple = None) -> jnp.ndarray:
    """Compute Euclidean distance transform of binary input with per-dimension sampling."""
    d = jnp.where(binary_input == 1, jnp.inf, 0.0)
    
    sampling = sampling or (1.0,) * binary_input.ndim
    ndim = binary_input.ndim

    for axis in range(ndim):
        d = jnp.moveaxis(d, axis, 0)
        u = jnp.arange(d.shape[0])
        for _ in range(ndim - 1):
            u = u[:, None]
        u = u * sampling[axis]
        
        target = jnp.arange(d.shape[0])
        for _ in range(ndim - 2):
            target = target[:, None]
        target = target[None, :, None] * sampling[axis]
        
        d_line = jnp.min((u - target) ** 2 + d, axis=1)
        d = jnp.moveaxis(d_line, 0, axis)

    return jnp.sqrt(d)

# Test with 3D array
x = jnp.ones((4, 4, 4))
x = x.at[2, 2, 2].set(0)  # Single point at center
print(edt_jax(x, sampling=(1.0, 2.0, 3.0)))  # Different sampling per dimension

[[[0.]]]


In [70]:
import jax
import jax.numpy as jnp
from jax import jit

@jit 
def edt_jax(binary_input: jnp.ndarray, sampling: tuple = None) -> jnp.ndarray:
    """Compute Euclidean distance transform of binary input with per-dimension sampling."""
    d = jnp.where(binary_input == 1, jnp.inf, 0.0)
    sampling = sampling or (1.0,) * binary_input.ndim

    for axis in range(binary_input.ndim):
        d = jnp.moveaxis(d, axis, 0)
        shape = [1] * (binary_input.ndim + 1)
        shape[0] = d.shape[0]
        shape[1] = d.shape[0]
        
        # Create position arrays for broadcasting
        u = jnp.arange(d.shape[0]).reshape([-1] + [1] * (binary_input.ndim-1)) * sampling[axis]
        v = jnp.arange(d.shape[0]).reshape([1, -1] + [1] * (binary_input.ndim-2)) * sampling[axis]
        
        d_line = jnp.min((u - v) ** 2 + d, axis=1)
        d = jnp.moveaxis(d_line, 0, axis)

    return jnp.sqrt(d)

# Test with 3D array
x = jnp.ones((4, 4, 4))
x = x.at[2, 2, 2].set(0)  # Single point at center
print(edt_jax(x, sampling=(1.0, 2.0, 3.0)))  # Different sampling per dimension

ValueError: axis 2 is out of bounds for array of dimension 2

In [71]:
import jax
import jax.numpy as jnp
from jax import jit

@jit 
def edt_jax(binary_input: jnp.ndarray, sampling: tuple = None) -> jnp.ndarray:
    d = jnp.where(binary_input == 1, jnp.inf, 0.0)
    sampling = sampling or (1.0,) * binary_input.ndim

    for axis in range(binary_input.ndim):
        d = jnp.moveaxis(d, axis, 0)
        shape = list(d.shape)
        coords = jnp.arange(shape[0]) * sampling[axis]
        diff_sq = (coords[:, None] - coords[None, :]) ** 2
        diff_sq = diff_sq.reshape((shape[0], shape[0]) + (1,) * (len(shape) - 1))
        d_line = jnp.min(diff_sq + d, axis=1)
        d = jnp.moveaxis(d_line, 0, axis)

    return jnp.sqrt(d)

# Test
x = jnp.ones((4, 4, 4))
x = x.at[2, 2, 2].set(0)
print(edt_jax(x, sampling=(1.0, 2.0, 3.0)))

[[[7.4833145 5.3851647 4.4721355 5.3851647]
  [6.6332493 4.1231055 2.828427  4.1231055]
  [6.3245554 3.6055512 2.        3.6055512]
  [6.6332493 4.1231055 2.828427  4.1231055]]

 [[7.28011   5.0990195 4.1231055 5.0990195]
  [6.403124  3.7416573 2.2360678 3.7416573]
  [6.0827622 3.1622777 1.        3.1622777]
  [6.403124  3.7416573 2.2360678 3.7416573]]

 [[7.2111025 5.        4.        5.       ]
  [6.3245554 3.6055512 2.        3.6055512]
  [6.        3.        0.        3.       ]
  [6.3245554 3.6055512 2.        3.6055512]]

 [[7.28011   5.0990195 4.1231055 5.0990195]
  [6.403124  3.7416573 2.2360678 3.7416573]
  [6.0827622 3.1622777 1.        3.1622777]
  [6.403124  3.7416573 2.2360678 3.7416573]]]


In [72]:
ndimage.distance_transform_edt(x, sampling=(1.0, 2.0, 3.0))

array([[[7.48331477, 5.38516481, 4.47213595, 5.38516481],
        [6.63324958, 4.12310563, 2.82842712, 4.12310563],
        [6.32455532, 3.60555128, 2.        , 3.60555128],
        [6.63324958, 4.12310563, 2.82842712, 4.12310563]],

       [[7.28010989, 5.09901951, 4.12310563, 5.09901951],
        [6.40312424, 3.74165739, 2.23606798, 3.74165739],
        [6.08276253, 3.16227766, 1.        , 3.16227766],
        [6.40312424, 3.74165739, 2.23606798, 3.74165739]],

       [[7.21110255, 5.        , 4.        , 5.        ],
        [6.32455532, 3.60555128, 2.        , 3.60555128],
        [6.        , 3.        , 0.        , 3.        ],
        [6.32455532, 3.60555128, 2.        , 3.60555128]],

       [[7.28010989, 5.09901951, 4.12310563, 5.09901951],
        [6.40312424, 3.74165739, 2.23606798, 3.74165739],
        [6.08276253, 3.16227766, 1.        , 3.16227766],
        [6.40312424, 3.74165739, 2.23606798, 3.74165739]]])