In [8]:
from functools import partial

import jax
import jax.numpy as jnp
@partial(jax.jit, static_argnames=("grid_size",))
def padded_translate(arr: jax.Array, shift, grid_size: int = 4) -> jax.Array:
    """Translate a ``grid_size`` x ``grid_size`` array by ``shift`` with zero fill.

    Unlike a plain ``jnp.roll``, content pushed past an edge is dropped rather
    than wrapped around, and vacated cells are filled with zeros.

    Args:
        arr: 2-D array; assumed to be ``grid_size`` x ``grid_size``
            (NOTE(review): other shapes are not exercised here — confirm).
        shift: ``(row_shift, col_shift)`` or a scalar applied to both axes.
            Positive values move content down / right. Any shift whose
            magnitude is >= ``grid_size`` moves everything off the grid, and
            that axis becomes all zeros.
        grid_size: Side length of the returned window. It is a static argument
            under ``jit`` because it determines array shapes. Each distinct
            value triggers a recompile.

    Returns:
        The top-left ``grid_size`` x ``grid_size`` window of the translated
        array, with the same dtype as ``arr``.
    """
    # Pad a full grid of zeros on the bottom/right. Any shift in
    # [-grid_size, grid_size] then pushes content into this padding instead of
    # letting jnp.roll wrap it back into the visible window.
    padded = jnp.pad(arr, ((0, grid_size), (0, grid_size)), mode="constant")

    # Larger shifts would wrap around the padded array. Clipping them is safe
    # because a shift of magnitude grid_size already empties the window.
    clipped_shift = jnp.clip(jnp.asarray(shift), -grid_size, grid_size)
    rolled = jnp.roll(padded, shift=clipped_shift, axis=(0, 1))

    # Keep the top-left window, i.e. the array's original position.
    return rolled[:grid_size, :grid_size]


In [7]:
# An L-shaped piece in the top-left corner of a 4x4 grid.
arr = jnp.array(
    [
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]
)

# Move it three columns to the right. The two column-0 cells land in column 3;
# the cell at (1, 1) falls off the right edge and is dropped.
padded_translate(arr, (0, 3))

Array([[0, 0, 0, 1],
       [0, 0, 0, 1],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=int32)

In [None]:
# Walk through padded_translate's steps by hand on a diagonal 4x4 grid.
arr = jnp.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]])

# Step 1: pad the bottom/right with (side - 1) zeros per axis. The width is
# derived from arr's shape instead of hard-coding 3, so the cell stays correct
# if the grid size changes.
pad_rows, pad_cols = arr.shape[0] - 1, arr.shape[1] - 1
padded = jnp.pad(arr, ((0, pad_rows), (0, pad_cols)), mode='constant')

In [None]:
# Step 2: move content 2 rows down and 1 column right. The 3 rows/cols of zero
# padding absorb the overflow, so nothing wraps back into the top-left window.
rolled = jnp.roll(padded, (2, 1), (0, 1))

In [None]:
# Step 3: crop back to the original window, the top-left block with arr's shape.
# Slicing rows only (rolled[:4]) would keep all 7 padded columns and give a 4x7 array.
rolled[:arr.shape[0], :arr.shape[1]]