In [178]:
import jax
import jax.numpy as jnp
from jax import lax
import scipy.ndimage
from jax.lax import reduce_window

In [222]:
def uniform_filter_jax(input, size, axis=None):
    """Uniform (box-mean) filter in JAX, mirroring SciPy's default behavior.

    Intended to reproduce ``scipy.ndimage.uniform_filter(input, size)`` when
    ``axis`` is None, and ``scipy.ndimage.uniform_filter1d(input, size, axis)``
    when ``axis`` is given, using SciPy's defaults ``mode='reflect'`` and
    ``origin=0``.

    Parameters
    ----------
    input : array_like
        Array to filter. Non-floating input (ints, bools) is promoted to a
        floating dtype. Unlike SciPy, the result is NOT cast back to the
        input's integer dtype.
    size : int or sequence of int
        Window length. With ``axis=None``, an int is used for every axis and
        a sequence gives one length per axis. With ``axis`` set, it must be
        an int and is applied along that axis only.
    axis : int, optional
        If given, filter only along this axis (negative values count from
        the end).

    Returns
    -------
    jax.Array
        The filtered array, with the same shape as ``input``.

    Raises
    ------
    ValueError
        If a size sequence does not have one entry per dimension, or if any
        size is < 1.
    """
    input = jnp.asarray(input)
    # lax.conv_general_dilated requires both operands to share a dtype, and
    # the averaging kernel is fractional, so non-floating input is promoted.
    if not jnp.issubdtype(input.dtype, jnp.inexact):
        input = input.astype(jnp.promote_types(input.dtype, jnp.float32))
    ndim = input.ndim

    # Per-axis window lengths. A length-1 window leaves that axis unchanged,
    # which lets the single-axis case share the N-D code path.
    if axis is None:
        try:
            sizes = tuple(int(s) for s in size)
        except TypeError:  # scalar size (Python or NumPy int): same on every axis
            sizes = (int(size),) * ndim
    else:
        per_axis = [1] * ndim
        per_axis[axis] = int(size)
        sizes = tuple(per_axis)

    if len(sizes) != ndim:
        raise ValueError(
            f"size has {len(sizes)} entries but input has {ndim} dimensions"
        )
    if any(s < 1 for s in sizes):
        raise ValueError(f"all filter sizes must be >= 1, got {sizes}")
    if ndim == 0:
        # No axes to average over: the filter is the identity.
        return input

    # SciPy's default mode='reflect' is half-sample symmetric
    # (d c b a | a b c d | d c b a). That is NumPy/JAX mode='symmetric',
    # not 'edge' and not NumPy's 'reflect'. For even sizes, SciPy's window
    # covers one more sample on the left than on the right, hence the
    # asymmetric (s // 2, (s - 1) // 2) padding, which keeps the output
    # shape equal to the input shape.
    pad_width = [(s // 2, (s - 1) // 2) for s in sizes]
    padded = jnp.pad(input, pad_width, mode="symmetric")

    kernel = jnp.ones(sizes, dtype=input.dtype)
    kernel = kernel / kernel.size  # normalize so the window computes a mean

    result = lax.conv_general_dilated(
        padded[None, None],  # add batch and channel dims
        kernel[None, None],  # add out-channel and in-channel dims
        window_strides=(1,) * ndim,
        padding="VALID",  # borders were already padded explicitly above
        # Avoid reduced-precision (e.g. TF32/bf16) conv on accelerators so
        # results stay comparable with SciPy's double-precision accumulation.
        precision=lax.Precision.HIGHEST,
    )
    return result[0, 0]  # drop the batch and channel dims

In [223]:
# Small example image: a 3x3 float32 ramp holding 1..9 in row-major order.
image = jnp.array(
    [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]],
    dtype=jnp.float32,
)

In [224]:
# Apply a 3x3 box (mean) filter using the JAX implementation.
filtered_image = uniform_filter_jax(image, size=3)
print(filtered_image)

[[2.3333333 3.        3.666667 ]
 [4.3333335 5.        5.666667 ]
 [6.333333  7.        7.6666665]]


In [225]:
# Same 3x3 float32 example as the earlier cell (values 1..9, row-major),
# rebuilt here so the SciPy comparison cell can also be run on its own.
image = jnp.arange(1, 10, dtype=jnp.float32).reshape(3, 3)

In [226]:
# Reference result from SciPy. mode="reflect" is SciPy's default boundary
# mode; it is spelled out because it is the behavior the JAX version must match.
scipy_filtered_image = scipy.ndimage.uniform_filter(image, size=3, mode="reflect")
print(scipy_filtered_image)

[[2.3333333 3.        3.6666667]
 [4.3333335 5.        5.6666665]
 [6.3333335 7.        7.6666665]]


In [229]:
# jnp.allclose broadcasts its arguments, so a shape mismatch (e.g. (3, 3) vs
# (3,)) could pass silently. Check the shapes explicitly before the values.
assert filtered_image.shape == scipy_filtered_image.shape, (
    f"shape mismatch: {filtered_image.shape} vs {scipy_filtered_image.shape}"
)
assert jnp.allclose(filtered_image, scipy_filtered_image), (
    "JAX uniform filter disagrees with scipy.ndimage.uniform_filter"
)