# Understanding `_move_with_push` in PushWorld

This notebook walks through the wave-front push mechanics step by step.


In [2]:
# Uncomment to run notebook

# import sys
# import os

# # Add project root to path for imports
# project_root = os.path.abspath(os.path.join(os.getcwd(), "..", "..", ".."))
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

# import jax
# import jax.numpy as jnp
# import jax.lax as lax
# import numpy as np

# # Now import from the package
# from jaxued.environments.pushworld.level import Level, prefabs, GRID_SIZE, MAX_PIXELS
# from jaxued.environments.pushworld.env import PushWorld, EnvState, Actions, DISPLACEMENTS

# print("Imports successful!")


Imports successful!


## 1. Create a Simple Puzzle

Let's use the `TrivialPush` puzzle:

```
.  .  .  .  .
.  A  M1  G1  .
.  .  .  .  .
```

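In this sketch the agent (A) is at position (1, 1), M1 is at (2, 1), and G1 is at (3, 1).
The loaded level is padded with walls to a 10x10 grid, so the coordinates printed below are offset by 2 in each axis: A at (3, 3), M1 at (4, 3), G1 at (5, 3).
If the agent moves RIGHT, it should push M1 to the right.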


In [3]:
# Load the puzzle
level = Level.from_str(prefabs["TrivialPush"])

print("Puzzle layout:")
print(level.to_str())
print()
print(f"Agent position: {level.agent_pos}")
print(f"M1 position: {level.m1_pos}")
print(f"G1 position: {level.g1_pos}")
print(f"Wall map shape: {level.wall_map.shape}")


Puzzle layout:
W  W  W  W  W  W  W  W  W  W
W  W  W  W  W  W  W  W  W  W
W  W  .  .  .  .  .  W  W  W
W  W  .  A  M1  G1  .  W  W  W
W  W  .  .  .  .  .  W  W  W
W  W  .  .  .  .  .  W  W  W
W  W  .  .  .  .  .  W  W  W
W  W  W  W  W  W  W  W  W  W
W  W  W  W  W  W  W  W  W  W
W  W  W  W  W  W  W  W  W  W

Agent position: [[ 3  3]
 [-1 -1]
 [-1 -1]]
M1 position: [[ 4  3]
 [-1 -1]
 [-1 -1]]
G1 position: [[ 5  3]
 [-1 -1]
 [-1 -1]]
Wall map shape: (10, 10)


In [4]:
# Create environment and initialize state
env = PushWorld()
state = env.init_state_from_level(level)

print("Initial State:")
print(f"  agent_pos: {state.agent_pos}")
print(f"  m1_pos: {state.m1_pos}")
print(f"  m2_pos: {state.m2_pos}")
print(f"  g1_pos: {state.g1_pos}")


Initial State:
  agent_pos: [[ 3  3]
 [-1 -1]
 [-1 -1]]
  m1_pos: [[ 4  3]
 [-1 -1]
 [-1 -1]]
  m2_pos: [[-1 -1]
 [-1 -1]
 [-1 -1]]
  g1_pos: [[ 5  3]
 [-1 -1]
 [-1 -1]]


## 2. Understanding the Coordinate System

- Each object can have up to `MAX_PIXELS=3` pixels (polyominoes)
- Coordinates are stored as `(MAX_PIXELS, 2)` arrays: `[[x0, y0], [x1, y1], [x2, y2]]`
- Invalid/unused pixels are marked with `-1`
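
To make the padding convention concrete, here is a small sketch that stores a two-pixel object in a `(MAX_PIXELS, 2)` array and recovers its validity mask. The helper `pad_pixels` is hypothetical and not part of the PushWorld package, and `MAX_PIXELS` is redefined locally so the snippet runs on its own.

```python
import jax.numpy as jnp

MAX_PIXELS = 3  # local copy of the constant, so the sketch is self-contained


def pad_pixels(pixels):
    """Hypothetical helper: pad a list of (x, y) pairs to MAX_PIXELS rows with -1."""
    arr = jnp.full((MAX_PIXELS, 2), -1, dtype=jnp.int32)
    if len(pixels) > 0:
        arr = arr.at[: len(pixels)].set(jnp.asarray(pixels, dtype=jnp.int32))
    return arr


# A horizontal domino occupying (4, 3) and (5, 3)
domino = pad_pixels([(4, 3), (5, 3)])

# Same validity rule the notebook uses later: both x and y must be >= 0
valid = (domino[:, 0] >= 0) & (domino[:, 1] >= 0)

print(domino)  # rows: [4 3], [5 3], [-1 -1]
print(valid)   # [ True  True False]
```

The fixed shape is what lets every object be stacked and processed with the same vectorized operations, regardless of how many pixels it really has.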


In [5]:
print(f"MAX_PIXELS = {MAX_PIXELS}")
print(f"GRID_SIZE = {GRID_SIZE}")
print()
print("Agent position array (shape {}):".format(state.agent_pos.shape))
print(state.agent_pos)
print(
    "  -> First pixel at (x={}, y={})".format(
        state.agent_pos[0, 0], state.agent_pos[0, 1]
    )
)
print("  -> Other pixels are -1 (unused)")


MAX_PIXELS = 3
GRID_SIZE = 10

Agent position array (shape (3, 2)):
[[ 3  3]
 [-1 -1]
 [-1 -1]]
  -> First pixel at (x=3, y=3)
  -> Other pixels are -1 (unused)


## 3. Action and Displacement

When the agent takes an action, the action index selects a displacement vector `(dx, dy)`. The y axis grows downward, so UP decreases y:

- UP (0): (0, -1)
- RIGHT (1): (1, 0)
- DOWN (2): (0, 1)
- LEFT (3): (-1, 0)


In [6]:
print("DISPLACEMENTS array:")
print(DISPLACEMENTS)
print()

# Let's move RIGHT
action = Actions.right
displacement = DISPLACEMENTS[action]
print(f"Action: {action.name} (index {action})")
print(
    f"Displacement: {displacement}  (move x by {displacement[0]}, y by {displacement[1]})"
)


DISPLACEMENTS array:
[[ 0 -1]
 [ 1  0]
 [ 0  1]
 [-1  0]]

Action: right (index 1)
Displacement: [1 0]  (move x by 1, y by 0)


## 4. Step-by-Step: `_move_with_push`

Now let's break down the method into individual steps.

### Step 4.1: Stack all object coordinates

We stack the agent and the 4 movable objects into a single array so that all objects can be processed with one vectorized operation.


In [15]:
# Stack coordinates: [agent, m1, m2, m3, m4]
coords = jnp.stack(
    [
        state.agent_pos,
        state.m1_pos,
        state.m2_pos,
        state.m3_pos,
        state.m4_pos,
    ],
    axis=0,
)

print(
    f"Stacked coords shape: {coords.shape}  (N=5 objects, MAX_PIXELS=3 pixels each, 2 coords)"
)
print()
print("Object 0 (Agent):")
print(coords[0])
print()
print("Object 1 (M1):")
print(coords[1])
print()
print("Object 2 (M2) - empty/unused:")
print(coords[2])


Stacked coords shape: (5, 3, 2)  (N=5 objects, MAX_PIXELS=3 pixels each, 2 coords)

Object 0 (Agent):
[[ 3  3]
 [-1 -1]
 [-1 -1]]

Object 1 (M1):
[[ 4  3]
 [-1 -1]
 [-1 -1]]

Object 2 (M2) - empty/unused:
[[-1 -1]
 [-1 -1]
 [-1 -1]]


### Step 4.2: Compute displaced positions

For each pixel of each object, compute where it would be after applying the displacement.


In [8]:
def masked_displacement(coords, displacement):
    """
    Compute displaced coordinates, but only for valid pixels (not -1).

    Args:
        coords: (MAX_PIXELS, 2) - coordinates of one object's pixels
        displacement: (2,) - displacement vector

    Returns:
        disp: (MAX_PIXELS, 2) - displaced coordinates (unchanged for invalid pixels)
        valid: (MAX_PIXELS,) - mask of which pixels are valid
    """
    # A pixel is valid if both x and y are >= 0
    valid = (coords[:, 0] >= 0) & (coords[:, 1] >= 0)

    # Compute displacement for all pixels
    disp_all = coords + displacement

    # Only apply displacement to valid pixels
    disp = jnp.where(valid[:, None], disp_all, coords)

    return disp, valid


# Apply to agent's coordinates
agent_disp, agent_valid = masked_displacement(state.agent_pos, displacement)
print("Agent displacement:")
print(f"  Original: {state.agent_pos[0]}")
print(f"  Displaced: {agent_disp[0]}")
print(f"  Valid mask: {agent_valid}")


Agent displacement:
  Original: [3 3]
  Displaced: [4 3]
  Valid mask: [ True False False]


In [16]:
# Apply to ALL objects using vmap
all_disp, all_valid = jax.vmap(lambda c: masked_displacement(c, displacement))(coords)

print(f"all_disp shape: {all_disp.shape}  (5 objects, 3 pixels, 2 coords)")
print(f"all_valid shape: {all_valid.shape}  (5 objects, 3 pixels)")
print()
print("Displaced positions for each object:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {coords[i, 0]} -> {all_disp[i, 0]}  (valid: {all_valid[i, 0]})")


all_disp shape: (5, 3, 2)  (5 objects, 3 pixels, 2 coords)
all_valid shape: (5, 3)  (5 objects, 3 pixels)

Displaced positions for each object:
  Agent: [3 3] -> [4 3]  (valid: True)
  M1: [4 3] -> [5 3]  (valid: True)
  M2: [-1 -1] -> [-1 -1]  (valid: False)
  M3: [-1 -1] -> [-1 -1]  (valid: False)
  M4: [-1 -1] -> [-1 -1]  (valid: False)
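
One detail worth spelling out: the validity mask has to come from the original coordinates, not the displaced ones. The sketch below uses a toy two-object array (not the puzzle state) and moves a pixel LEFT from the grid edge. The real pixel lands on x = -1, which would look like padding if validity were recomputed after the move. It also shows `in_axes=(0, None)` as an alternative to the lambda for sharing one displacement across objects.

```python
import jax
import jax.numpy as jnp


def masked_displacement(coords, displacement):
    # Validity is decided on the ORIGINAL coordinates
    valid = (coords[:, 0] >= 0) & (coords[:, 1] >= 0)
    disp = jnp.where(valid[:, None], coords + displacement, coords)
    return disp, valid


# Map over the object axis of coords; broadcast the single displacement
displace_all = jax.vmap(masked_displacement, in_axes=(0, None))

coords = jnp.array([
    [[0, 2], [-1, -1]],    # one real pixel on the left edge
    [[-1, -1], [-1, -1]],  # unused object slot
])
left = jnp.array([-1, 0])

all_disp, all_valid = displace_all(coords, left)

print(all_disp[0])   # real pixel moved to x = -1, padding untouched
print(all_valid[0])  # [ True False]: the moved pixel is still marked valid
```

Because `all_valid[0, 0]` stays `True`, the later out-of-bounds check can see that a real pixel left the grid.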


### Step 4.3: Initialize the wave-front

The "wave-front" algorithm starts with the agent and propagates through objects it collides with.

- `frontier`: Objects currently being checked for collisions
- `pushed`: Objects that are part of the push chain
- `broken`: Whether the push chain hit a wall/obstacle


In [17]:
N = 5  # Number of objects (agent + 4 movables)

# Start with only agent in the frontier
frontier = jnp.array([True, False, False, False, False])

# Agent is already "pushed" (part of the chain)
pushed = jnp.zeros((N,), dtype=jnp.bool_).at[0].set(True)

# Not broken yet
broken = jnp.array(False)

print("Initial state:")
print(f"  frontier: {frontier}  <- Agent is in the frontier")
print(f"  pushed:   {pushed}    <- Agent is pushed")
print(f"  broken:   {broken}")


Initial state:
  frontier: [ True False False False False]  <- Agent is in the frontier
  pushed:   [ True False False False False]    <- Agent is pushed
  broken:   False


### Step 4.4: Check for wall/OOB collisions

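Before propagating, compute for every object whether its displaced position would hit a wall or leave the grid. The result is a per-object `blocked` flag; the frontier mask is applied to it later, in Step 4.6b.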


In [19]:
# Let's unpack check_blocked step by step

grid_size = GRID_SIZE
wall_map = state.wall_map

print("=== Step 4.4a: Extract x and y coordinates ===")
print()

# all_disp has shape (5, 3, 2) = (N objects, MAX_PIXELS, 2 coords)
# We extract x coords (index 0) and y coords (index 1)
xs = all_disp[..., 0]  # Shape: (5, 3) - x coordinate for each pixel of each object
ys = all_disp[..., 1]  # Shape: (5, 3) - y coordinate for each pixel of each object

print(f"xs shape: {xs.shape}")
print(f"ys shape: {ys.shape}")
print()
print("X coordinates (displaced) for each object:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {xs[i]}")
print()
print("Y coordinates (displaced) for each object:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {ys[i]}")


=== Step 4.4a: Extract x and y coordinates ===

xs shape: (5, 3)
ys shape: (5, 3)

X coordinates (displaced) for each object:
  Agent: [ 4 -1 -1]
  M1: [ 5 -1 -1]
  M2: [-1 -1 -1]
  M3: [-1 -1 -1]
  M4: [-1 -1 -1]

Y coordinates (displaced) for each object:
  Agent: [ 3 -1 -1]
  M1: [ 3 -1 -1]
  M2: [-1 -1 -1]
  M3: [-1 -1 -1]
  M4: [-1 -1 -1]


In [20]:
print("=== Step 4.4b: Clamp coordinates for safe array indexing ===")
print()

# Problem: Some coordinates are -1 (invalid/unused pixels)
# We can't index wall_map[-1, -1] - it would wrap around!
# Solution: Clamp to valid range [0, grid_size-1] for the lookup

xs_clamped = jnp.clip(xs, 0, grid_size - 1)
ys_clamped = jnp.clip(ys, 0, grid_size - 1)

print(f"grid_size = {grid_size}")
print()
print("Before clamping (M2 has -1 values):")
print(f"  M2 xs: {xs[2]}")
print()
print("After clamping (now safe for indexing):")
print(f"  M2 xs_clamped: {xs_clamped[2]}")
print()
print(
    "Note: Clamping doesn't change valid coordinates, only fixes invalid ones for safe indexing"
)


=== Step 4.4b: Clamp coordinates for safe array indexing ===

grid_size = 10

Before clamping (M2 has -1 values):
  M2 xs: [-1 -1 -1]

After clamping (now safe for indexing):
  M2 xs_clamped: [0 0 0]

Note: Clamping doesn't change valid coordinates, only fixes invalid ones for safe indexing
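
The wrap-around mentioned above is easy to reproduce on a toy wall map (a 4x4 grid invented for this sketch, not the puzzle's `wall_map`). Either way the result for a padding pixel is discarded by the validity mask in a later step; clamping only guarantees that the lookup reads a well-defined in-range cell.

```python
import jax.numpy as jnp

grid_size = 4
# Toy wall map: only the bottom-right cell is a wall
wall_map = jnp.zeros((grid_size, grid_size), dtype=jnp.bool_).at[-1, -1].set(True)

# A padding pixel stored as (-1, -1)
x, y = jnp.array(-1), jnp.array(-1)

# Unclamped lookup: -1 wraps to the last row/column and reports a wall
wrapped = wall_map[y, x]

# Clamped lookup: -1 becomes 0, an in-range cell that is free in this toy map
clamped = wall_map[jnp.clip(y, 0, grid_size - 1), jnp.clip(x, 0, grid_size - 1)]

print(bool(wrapped), bool(clamped))  # True False
```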


In [21]:
print("=== Step 4.4c: Look up wall values at displaced positions ===")
print()

# wall_map is a 2D boolean array where True = wall
# We look up: wall_map[y, x] for each displaced pixel
# Note: y comes first because arrays are indexed [row, col] = [y, x]

raw_vals = wall_map[ys_clamped, xs_clamped]

print(f"wall_map shape: {wall_map.shape}")
print(f"raw_vals shape: {raw_vals.shape}  (same as xs/ys: 5 objects, 3 pixels)")
print()
print("Wall values at displaced positions:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name} at ({xs_clamped[i]}, {ys_clamped[i]}): wall={raw_vals[i]}")
print()
print("Interpretation: True means there's a wall at that position")


=== Step 4.4c: Look up wall values at displaced positions ===

wall_map shape: (10, 10)
raw_vals shape: (5, 3)  (same as xs/ys: 5 objects, 3 pixels)

Wall values at displaced positions:
  Agent at ([4 0 0], [3 0 0]): wall=[False  True  True]
  M1 at ([5 0 0], [3 0 0]): wall=[False  True  True]
  M2 at ([0 0 0], [0 0 0]): wall=[ True  True  True]
  M3 at ([0 0 0], [0 0 0]): wall=[ True  True  True]
  M4 at ([0 0 0], [0 0 0]): wall=[ True  True  True]

Interpretation: True means there's a wall at that position


In [22]:
print("=== Step 4.4d: Mask out invalid pixels ===")
print()

# Problem: We looked up wall values for ALL pixels, including invalid ones
# Invalid pixels (like M2's -1,-1) got clamped to (0,0) and looked up
# But we should ignore those results!

# all_valid tells us which pixels are real (True) vs padding (False)
print("all_valid (which pixels are real):")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {all_valid[i]}")

print()

# Only count wall hits for valid pixels
wall_vals = jnp.where(all_valid, raw_vals, False)

print("wall_vals after masking (False for invalid pixels):")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {wall_vals[i]}")


=== Step 4.4d: Mask out invalid pixels ===

all_valid (which pixels are real):
  Agent: [ True False False]
  M1: [ True False False]
  M2: [False False False]
  M3: [False False False]
  M4: [False False False]

wall_vals after masking (False for invalid pixels):
  Agent: [False False False]
  M1: [False False False]
  M2: [False False False]
  M3: [False False False]
  M4: [False False False]


In [23]:
print("=== Step 4.4e: Check if any pixel hits a wall ===")
print()

# For each object, check if ANY of its pixels hit a wall
# jnp.any(..., axis=1) reduces across pixels (axis 1) for each object
hit_wall = jnp.any(wall_vals, axis=1)

print(f"hit_wall shape: {hit_wall.shape}  (one value per object)")
print()
print("Does each object hit a wall?")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {hit_wall[i]}")


=== Step 4.4e: Check if any pixel hits a wall ===

hit_wall shape: (5,)  (one value per object)

Does each object hit a wall?
  Agent: False
  M1: False
  M2: False
  M3: False
  M4: False


In [24]:
print("=== Step 4.4f: Check for out-of-bounds ===")
print()

# Check if coordinates are within [0, grid_size)
in_bounds = (xs >= 0) & (xs < grid_size) & (ys >= 0) & (ys < grid_size)

print("in_bounds for each pixel:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {in_bounds[i]}")

print()
print("Note: Invalid pixels (-1, -1) show as out of bounds, but that's expected")


=== Step 4.4f: Check for out-of-bounds ===

in_bounds for each pixel:
  Agent: [ True False False]
  M1: [ True False False]
  M2: [False False False]
  M3: [False False False]
  M4: [False False False]

Note: Invalid pixels (-1, -1) show as out of bounds, but that's expected


In [25]:
print("=== Step 4.4g: Mask and combine OOB check ===")
print()

# For invalid pixels, treat them as "in bounds" (don't count as OOB)
valid_in_bounds = jnp.where(all_valid, in_bounds, True)

print("valid_in_bounds (invalid pixels treated as True):")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {valid_in_bounds[i]}")

print()

# Object is OOB if ANY valid pixel is out of bounds
# ~jnp.all(...) = NOT(all pixels in bounds) = at least one is OOB
oob = ~jnp.all(valid_in_bounds, axis=1)

print("Is each object out of bounds?")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {oob[i]}")


=== Step 4.4g: Mask and combine OOB check ===

valid_in_bounds (invalid pixels treated as True):
  Agent: [ True  True  True]
  M1: [ True  True  True]
  M2: [ True  True  True]
  M3: [ True  True  True]
  M4: [ True  True  True]

Is each object out of bounds?
  Agent: False
  M1: False
  M2: False
  M3: False
  M4: False


In [26]:
print("=== Step 4.4h: Final blocked status ===")
print()

# An object is blocked if it hits a wall OR goes out of bounds
blocked = hit_wall | oob

print("Final blocked status (hit_wall OR oob):")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: blocked={blocked[i]}  (hit_wall={hit_wall[i]}, oob={oob[i]})")

print()
print("Summary: Neither Agent nor M1 is blocked, so the push can proceed!")


=== Step 4.4h: Final blocked status ===

Final blocked status (hit_wall OR oob):
  Agent: blocked=False  (hit_wall=False, oob=False)
  M1: blocked=False  (hit_wall=False, oob=False)
  M2: blocked=False  (hit_wall=False, oob=False)
  M3: blocked=False  (hit_wall=False, oob=False)
  M4: blocked=False  (hit_wall=False, oob=False)

Summary: Neither Agent nor M1 is blocked, so the push can proceed!
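
Steps 4.4a through 4.4h can be collected into one function. The version below is a reconstruction from this walkthrough, named `check_blocked_sketch` to make clear that it is not the package's own `check_blocked`, which may differ in signature and details. The toy inputs are invented so that the wall case and the out-of-bounds case can be told apart.

```python
import jax.numpy as jnp


def check_blocked_sketch(all_disp, all_valid, wall_map):
    """Per-object blocked flag (reconstruction of steps 4.4a-h, not package code).

    all_disp:  (N, MAX_PIXELS, 2) displaced (x, y) coordinates
    all_valid: (N, MAX_PIXELS) mask of real (non-padding) pixels
    wall_map:  (H, W) boolean array, True = wall, indexed [y, x]
    """
    height, width = wall_map.shape
    xs, ys = all_disp[..., 0], all_disp[..., 1]

    # Clamp only for the lookup; the bounds test below uses the raw values
    raw_vals = wall_map[jnp.clip(ys, 0, height - 1), jnp.clip(xs, 0, width - 1)]
    hit_wall = jnp.any(raw_vals & all_valid, axis=1)

    # An object is out of bounds if any REAL pixel leaves the grid
    in_bounds = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
    oob = jnp.any(all_valid & ~in_bounds, axis=1)

    return hit_wall | oob


# Toy 4x4 grid with a wall column at x = 3
wall_map = jnp.zeros((4, 4), dtype=jnp.bool_).at[:, 3].set(True)
all_disp = jnp.array([
    [[2, 1], [-1, -1]],    # lands on a free cell
    [[3, 1], [-1, -1]],    # lands on the wall column
    [[1, 4], [-1, -1]],    # leaves the grid through the bottom edge
    [[-1, -1], [-1, -1]],  # unused object slot
])
all_valid = jnp.array([[True, False], [True, False], [True, False], [False, False]])

blocked = check_blocked_sketch(all_disp, all_valid, wall_map)
print(blocked)  # [False  True  True False]
```

`jnp.any(all_valid & ~in_bounds, axis=1)` is the same test as `~jnp.all(jnp.where(all_valid, in_bounds, True), axis=1)` used in Step 4.4g, written without the intermediate array.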


### Step 4.5: Compute collision matrix

Check if any displaced pixel of object `i` overlaps with any stationary pixel of object `j`.


In [29]:
print("=== Step 4.5a: Understanding the goal ===")
print()
print(
    "We want to check: if object i moves, does it collide with object j (staying still)?"
)
print()
print("Data we have:")
print(f"  all_disp shape: {all_disp.shape}  <- displaced positions for each object")
print(f"  coords shape: {coords.shape}      <- original positions for each object")
print(f"  all_valid shape: {all_valid.shape}  <- which pixels are real vs padding")
print()
print("We need to compare EVERY pixel of displaced object i")
print("with EVERY pixel of stationary object j")
print()
print("Example: Does displaced Agent overlap with stationary M1?")
print(f"  Agent displaced to: {all_disp[0, 0]}  (only 1 valid pixel)")
print(f"  M1 is at: {coords[1, 0]}  (only 1 valid pixel)")
print(f"  Are they the same? {jnp.array_equal(all_disp[0, 0], coords[1, 0])}")


=== Step 4.5a: Understanding the goal ===

We want to check: if object i moves, does it collide with object j (staying still)?

Data we have:
  all_disp shape: (5, 3, 2)  <- displaced positions for each object
  coords shape: (5, 3, 2)      <- original positions for each object
  all_valid shape: (5, 3)  <- which pixels are real vs padding

We need to compare EVERY pixel of displaced object i
with EVERY pixel of stationary object j

Example: Does displaced Agent overlap with stationary M1?
  Agent displaced to: [4 3]  (only 1 valid pixel)
  M1 is at: [4 3]  (only 1 valid pixel)
  Are they the same? True


In [30]:
print("=== Step 4.5b: Broadcasting magic - compare all pairs at once ===")
print()

# The key trick: use broadcasting to compare ALL pixel combinations
# We reshape arrays to enable broadcasting:
#   all_disp: (N, MAX_PIXELS, 2) -> (N, MAX_PIXELS, 1, 1, 2)
#   coords:   (N, MAX_PIXELS, 2) -> (1, 1, N, MAX_PIXELS, 2)
# Then == broadcasts to: (N, MAX_PIXELS, N, MAX_PIXELS, 2)

print("Step-by-step broadcasting:")
print()
print("Original shapes:")
print(f"  all_disp: {all_disp.shape}  (5 objects, 3 pixels, 2 coords)")
print(f"  coords:   {coords.shape}  (5 objects, 3 pixels, 2 coords)")
print()

# Reshape for broadcasting
all_disp_expanded = all_disp[:, :, None, None, :]  # (N, MAX_PIXELS, 1, 1, 2)
coords_expanded = coords[None, None, :, :, :]  # (1, 1, N, MAX_PIXELS, 2)

print("After adding dimensions:")
print(f"  all_disp_expanded: {all_disp_expanded.shape}  (N, MAX_PIXELS, 1, 1, 2)")
print(f"  coords_expanded:   {coords_expanded.shape}  (1, 1, N, MAX_PIXELS, 2)")
print()

# Compare - broadcasts to (N, MAX_PIXELS, N, MAX_PIXELS, 2)
eq = all_disp_expanded == coords_expanded

print("After broadcasting ==:")
print(f"  eq shape: {eq.shape}")
print()
print("Interpretation of axes:")
print("  eq[i, p_i, j, p_j, c] = True if:")
print("    object i's pixel p_i (displaced)")
print("    has same coordinate c (0=x, 1=y)")
print("    as object j's pixel p_j (stationary)")


=== Step 4.5b: Broadcasting magic - compare all pairs at once ===

Step-by-step broadcasting:

Original shapes:
  all_disp: (5, 3, 2)  (5 objects, 3 pixels, 2 coords)
  coords:   (5, 3, 2)  (5 objects, 3 pixels, 2 coords)

After adding dimensions:
  all_disp_expanded: (5, 3, 1, 1, 2)  (N, MAX_PIXELS, 1, 1, 2)
  coords_expanded:   (1, 1, 5, 3, 2)  (1, 1, N, MAX_PIXELS, 2)

After broadcasting ==:
  eq shape: (5, 3, 5, 3, 2)

Interpretation of axes:
  eq[i, p_i, j, p_j, c] = True if:
    object i's pixel p_i (displaced)
    has same coordinate c (0=x, 1=y)
    as object j's pixel p_j (stationary)


In [31]:
print("=== Step 4.5c: Check if BOTH x AND y match ===")
print()

# For two pixels to collide, BOTH x AND y must match
# eq has shape (N, MAX_PIXELS, N, MAX_PIXELS, 2) where last axis is [x_match, y_match]
# We reduce with jnp.all along the last axis

pixel_eq = jnp.all(eq, axis=-1)

print(f"pixel_eq shape: {pixel_eq.shape}  (N, MAX_PIXELS, N, MAX_PIXELS)")
print()
print("Interpretation:")
print("  pixel_eq[i, p_i, j, p_j] = True if pixel p_i of displaced object i")
print(
    "                            is at the SAME position as pixel p_j of stationary object j"
)
print()

# Let's look at a specific example: Agent (i=0) vs M1 (j=1)
print("Example: Agent (i=0) pixel 0 vs M1 (j=1) pixel 0")
print(f"  Agent displaced pixel 0: {all_disp[0, 0]}")
print(f"  M1 stationary pixel 0:   {coords[1, 0]}")
print(f"  pixel_eq[0, 0, 1, 0] = {pixel_eq[0, 0, 1, 0]}")
print()
print(
    "  They match! Agent's displaced position (4, 3) == M1's original position (4, 3)"
)


=== Step 4.5c: Check if BOTH x AND y match ===

pixel_eq shape: (5, 3, 5, 3)  (N, MAX_PIXELS, N, MAX_PIXELS)

Interpretation:
  pixel_eq[i, p_i, j, p_j] = True if pixel p_i of displaced object i
                            is at the SAME position as pixel p_j of stationary object j

Example: Agent (i=0) pixel 0 vs M1 (j=1) pixel 0
  Agent displaced pixel 0: [4 3]
  M1 stationary pixel 0:   [4 3]
  pixel_eq[0, 0, 1, 0] = True

  They match! Agent's displaced position (4, 3) == M1's original position (4, 3)


In [32]:
print("=== Step 4.5d: The problem with invalid pixels ===")
print()

# Problem: Invalid pixels (-1, -1) can also "match" each other!
print("Recall all_valid - which pixels are real:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {all_valid[i]}")
print()

# Look at M2 vs M3 - both are entirely invalid
print("Problem case: M2 (i=2) vs M3 (j=3)")
print(f"  M2 displaced pixel 0: {all_disp[2, 0]}  (invalid: -1, -1)")
print(f"  M3 stationary pixel 0: {coords[3, 0]}  (invalid: -1, -1)")
print(f"  pixel_eq[2, 0, 3, 0] = {pixel_eq[2, 0, 3, 0]}  <- OOPS! False positive!")
print()
print("They do match: masked_displacement leaves padding pixels at (-1, -1),")
print("so padding compares equal to padding. We must mask out invalid pixels.")


=== Step 4.5d: The problem with invalid pixels ===

Recall all_valid - which pixels are real:
  Agent: [ True False False]
  M1: [ True False False]
  M2: [False False False]
  M3: [False False False]
  M4: [False False False]

Problem case: M2 (i=2) vs M3 (j=3)
  M2 displaced pixel 0: [-1 -1]  (invalid: -1, -1)
  M3 stationary pixel 0: [-1 -1]  (invalid: -1, -1)
  pixel_eq[2, 0, 3, 0] = True  <- OOPS! False positive!

They do match: masked_displacement leaves padding pixels at (-1, -1),
so padding compares equal to padding. We must mask out invalid pixels.


In [33]:
print("=== Step 4.5e: Mask out invalid pixel comparisons ===")
print()

# A collision only counts if BOTH pixels are valid
# We need to expand all_valid to match pixel_eq's shape

# all_valid has shape (5, 3) - one bool per pixel per object
# We need (5, 3, 5, 3) for the masking

# For the displaced object (i, p_i): expand to (N, MAX_PIXELS, 1, 1)
valid_i = all_valid[:, :, None, None]
print(f"valid_i shape: {valid_i.shape}  <- validity of displaced object's pixels")

# For the stationary object (j, p_j): expand to (1, 1, N, MAX_PIXELS)
valid_j = all_valid[None, None, :, :]
print(f"valid_j shape: {valid_j.shape}  <- validity of stationary object's pixels")

print()

# Combine: collision only valid if BOTH pixels are valid
valid_collision = pixel_eq & valid_i & valid_j

print(f"valid_collision shape: {valid_collision.shape}")
print()
print("valid_collision[i, p_i, j, p_j] is True only if:")
print("  1. pixel_eq[i, p_i, j, p_j] is True (same position)")
print("  2. all_valid[i, p_i] is True (displaced pixel is real)")
print("  3. all_valid[j, p_j] is True (stationary pixel is real)")


=== Step 4.5e: Mask out invalid pixel comparisons ===

valid_i shape: (5, 3, 1, 1)  <- validity of displaced object's pixels
valid_j shape: (1, 1, 5, 3)  <- validity of stationary object's pixels

valid_collision shape: (5, 3, 5, 3)

valid_collision[i, p_i, j, p_j] is True only if:
  1. pixel_eq[i, p_i, j, p_j] is True (same position)
  2. all_valid[i, p_i] is True (displaced pixel is real)
  3. all_valid[j, p_j] is True (stationary pixel is real)


In [34]:
print("=== Step 4.5f: Reduce to object-level collision matrix ===")
print()

# valid_collision has shape (5, 3, 5, 3) - pixel-level comparisons
# We want coll_mat with shape (5, 5) - object-level collisions

# Object i collides with object j if ANY of their pixels collide
# Reduce along pixel axes (1 and 3)

coll_mat = jnp.any(valid_collision, axis=(1, 3))

print(f"coll_mat shape: {coll_mat.shape}  (5 objects x 5 objects)")
print()
print("Collision matrix (rows = displaced obj, cols = stationary obj):")
print()
print("        Agent  M1     M2     M3     M4")
for i, name in enumerate(["Agent", "M1   ", "M2   ", "M3   ", "M4   "]):
    row = "  ".join([str(bool(coll_mat[i, j])).ljust(5) for j in range(5)])
    print(f"  {name}  {row}")


=== Step 4.5f: Reduce to object-level collision matrix ===

coll_mat shape: (5, 5)  (5 objects x 5 objects)

Collision matrix (rows = displaced obj, cols = stationary obj):

        Agent  M1     M2     M3     M4
  Agent  False  True   False  False  False
  M1     False  False  False  False  False
  M2     False  False  False  False  False
  M3     False  False  False  False  False
  M4     False  False  False  False  False


In [35]:
print("=== Step 4.5g: Interpret the collision matrix ===")
print()

print("Reading the matrix:")
print("  coll_mat[i, j] = True means:")
print("    'If object i moves in the direction, it will occupy a cell")
print("     where object j currently is (before j moves)'")
print()

# Find all collisions
print("All collisions found:")
for i in range(5):
    for j in range(5):
        if coll_mat[i, j]:
            names = ["Agent", "M1", "M2", "M3", "M4"]
            print(f"  {names[i]} (displaced) -> {names[j]} (stationary)")

print()
print("In our TrivialPush puzzle:")
print("  Agent at (3,3) moves right to (4,3)")
print("  M1 is at (4,3)")
print("  -> Agent's new position overlaps M1's current position")
print("  -> Therefore Agent 'pushes' M1!")
print()
print("Note: coll_mat[0, 0] would be True if Agent collided with itself,")
print("but that doesn't happen because Agent's displaced position (4,3)")
print("is different from Agent's original position (3,3).")


=== Step 4.5g: Interpret the collision matrix ===

Reading the matrix:
  coll_mat[i, j] = True means:
    'If object i moves in the direction, it will occupy a cell
     where object j currently is (before j moves)'

All collisions found:
  Agent (displaced) -> M1 (stationary)

In our TrivialPush puzzle:
  Agent at (3,3) moves right to (4,3)
  M1 is at (4,3)
  -> Agent's new position overlaps M1's current position
  -> Therefore Agent 'pushes' M1!

Note: coll_mat[0, 0] would be True if Agent collided with itself,
but that doesn't happen because Agent's displaced position (4,3)
is different from Agent's original position (3,3).
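
Steps 4.5b through 4.5f also fit in one short function. The sketch below is a reconstruction from this walkthrough, not the package code, and runs on a toy two-object layout chosen to show a case the single-pixel puzzle cannot: a multi-pixel object moving along its own axis overlaps its own old cells, so the diagonal entry `coll_mat[i, i]` can be `True`.

```python
import jax.numpy as jnp


def collision_matrix_sketch(coords, all_disp, all_valid):
    """coll_mat[i, j]: displaced object i overlaps stationary object j."""
    # (N, P, 1, 1, 2) == (1, 1, N, P, 2) -> (N, P, N, P, 2), then AND over x/y
    pixel_eq = jnp.all(
        all_disp[:, :, None, None, :] == coords[None, None, :, :, :], axis=-1
    )
    # A match only counts when both pixels are real (not padding)
    both_valid = all_valid[:, :, None, None] & all_valid[None, None, :, :]
    return jnp.any(pixel_eq & both_valid, axis=(1, 3))


# Object 0: horizontal domino at (3, 3), (4, 3). Object 1: single pixel at (5, 3).
coords = jnp.array([
    [[3, 3], [4, 3]],
    [[5, 3], [-1, -1]],
])
all_valid = jnp.all(coords >= 0, axis=-1)
displacement = jnp.array([1, 0])  # RIGHT
all_disp = jnp.where(all_valid[..., None], coords + displacement, coords)

coll_mat = collision_matrix_sketch(coords, all_disp, all_valid)
print(coll_mat)
# [[ True  True]
#  [False False]]
```

The self-collision in `coll_mat[0, 0]` is harmless for propagation because an object that is already in the push chain is filtered out by the `~pushed` mask in Step 4.6e.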


### Step 4.6: Propagate the wave-front

Find which objects are hit by the current frontier and add them to the push chain.
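
Before going through one iteration by hand, it may help to see the whole loop in one place. The following is a minimal sketch of the propagation using `lax.while_loop`, assuming the agent is object 0 and that `coll_mat` and `blocked` have already been computed. The function name `propagate_push_sketch` and the loop structure are assumptions made for illustration; the loop inside `_move_with_push` may be organized differently.

```python
import jax.numpy as jnp
from jax import lax


def propagate_push_sketch(coll_mat, blocked):
    """Wave-front propagation from the agent (index 0). Returns (pushed, broken)."""
    n = coll_mat.shape[0]
    start = jnp.zeros((n,), dtype=jnp.bool_).at[0].set(True)

    def cond(carry):
        frontier, _, broken = carry
        # Keep going while there is something to process and nothing hit a wall
        return jnp.any(frontier) & ~broken

    def body(carry):
        frontier, pushed, broken = carry
        # The chain breaks if any frontier object is blocked
        broken = broken | jnp.any(blocked & frontier)
        # Objects hit by any frontier object
        neighbors = jnp.any(coll_mat & frontier[:, None], axis=0)
        # Only objects not yet in the chain form the next frontier
        new_frontier = neighbors & ~pushed
        return new_frontier, pushed | new_frontier, broken

    _, pushed, broken = lax.while_loop(cond, body, (start, start, jnp.array(False)))
    return pushed, broken


# Toy case with 3 objects: the agent (0) hits object 1, nothing else collides
coll_mat = jnp.zeros((3, 3), dtype=jnp.bool_).at[0, 1].set(True)

pushed_free, broken_free = propagate_push_sketch(
    coll_mat, jnp.array([False, False, False])
)
pushed_wall, broken_wall = propagate_push_sketch(
    coll_mat, jnp.array([False, True, False])
)

print(pushed_free, bool(broken_free))  # [ True  True False] False
print(pushed_wall, bool(broken_wall))  # [ True  True False] True
```

Every iteration with a non-empty frontier adds at least one new object to `pushed`, so the loop runs at most N times. When `broken` ends up `True`, nothing should move; otherwise the objects flagged in `pushed` take their displaced positions.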


In [37]:
print("=== Step 4.6a: Recall the current state ===")
print()

print("We have these variables tracking the wave-front propagation:")
print()
print(f"frontier = {frontier}")
print("  ^ Objects we're currently processing (only Agent at start)")
print()
print(f"pushed = {pushed}")
print("  ^ Objects that are part of the push chain so far")
print()
print(f"blocked = {blocked}")
print("  ^ Which objects would hit a wall/OOB if they moved")
print()
print("Goal: Find which objects the frontier collides with, add them to the chain")


=== Step 4.6a: Recall the current state ===

We have these variables tracking the wave-front propagation:

frontier = [ True False False False False]
  ^ Objects we're currently processing (only Agent at start)

pushed = [ True False False False False]
  ^ Objects that are part of the push chain so far

blocked = [False False False False False]
  ^ Which objects would hit a wall/OOB if they moved

Goal: Find which objects the frontier collides with, add them to the chain


In [38]:
print("=== Step 4.6b: Check if any frontier object is blocked ===")
print()

# If a frontier object hits a wall, the entire push chain breaks
# We check: is ANY object both in the frontier AND blocked?

print("blocked & frontier:")
print(f"  blocked:  {blocked}")
print(f"  frontier: {frontier}")
print(f"  AND:      {blocked & frontier}")
print()

blocked_any = jnp.any(blocked & frontier)

print(f"blocked_any = jnp.any(blocked & frontier) = {blocked_any}")
print()
if blocked_any:
    print("STOP! A frontier object hit a wall. Push chain is broken.")
else:
    print("No frontier object is blocked. We can continue propagating.")


=== Step 4.6b: Check if any frontier object is blocked ===

blocked & frontier:
  blocked:  [False False False False False]
  frontier: [ True False False False False]
  AND:      [False False False False False]

blocked_any = jnp.any(blocked & frontier) = False

No frontier object is blocked. We can continue propagating.


In [39]:
print("=== Step 4.6c: Find neighbors - the broadcasting setup ===")
print()

# We want to find: which objects are hit by ANY object in the frontier?
# We have coll_mat[i, j] = True if displaced object i hits stationary object j

print("Recall the collision matrix:")
print()
print("        Agent  M1     M2     M3     M4")
for i, name in enumerate(["Agent", "M1   ", "M2   ", "M3   ", "M4   "]):
    row = "  ".join([str(bool(coll_mat[i, j])).ljust(5) for j in range(5)])
    print(f"  {name}  {row}")

print()
print("We want to select only rows where frontier[i] = True")
print(f"frontier = {frontier}  <- only Agent (index 0) is True")
print()

# Reshape frontier for broadcasting
frontier_mat = frontier[:, None]  # (N, 1)
print("frontier_mat = frontier[:, None]")
print(f"  Shape: {frontier_mat.shape}  (column vector)")
print(f"  Value:\n{frontier_mat}")


=== Step 4.6c: Find neighbors - the broadcasting setup ===

Recall the collision matrix:

        Agent  M1     M2     M3     M4
  Agent  False  True   False  False  False
  M1     False  False  False  False  False
  M2     False  False  False  False  False
  M3     False  False  False  False  False
  M4     False  False  False  False  False

We want to select only rows where frontier[i] = True
frontier = [ True False False False False]  <- only Agent (index 0) is True

frontier_mat = frontier[:, None]
  Shape: (5, 1)  (column vector)
  Value:
[[ True]
 [False]
 [False]
 [False]
 [False]]


In [40]:
print("=== Step 4.6d: Mask and reduce to find neighbors ===")
print()

# Multiply coll_mat by frontier_mat (for booleans this acts as logical AND) to zero out non-frontier rows
masked_coll = coll_mat * frontier_mat

print("coll_mat * frontier_mat:")
print("  This zeros out rows where frontier[i] = False")
print()
print("        Agent  M1     M2     M3     M4")
for i, name in enumerate(["Agent", "M1   ", "M2   ", "M3   ", "M4   "]):
    row = "  ".join([str(bool(masked_coll[i, j])).ljust(5) for j in range(5)])
    frontier_str = "<- frontier" if frontier[i] else "<- not frontier"
    print(f"  {name}  {row}  {frontier_str}")

print()

# Now reduce: for each column j, is there ANY row i where masked_coll[i,j] = True?
neighbors = jnp.any(masked_coll, axis=0)

print("neighbors = jnp.any(masked_coll, axis=0)")
print(f"  Reduce down columns -> {neighbors}")
print()
print("Interpretation: neighbors[j] = True if object j is hit by any frontier object")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    print(f"  {name}: {neighbors[i]}")


=== Step 4.6d: Mask and reduce to find neighbors ===

coll_mat * frontier_mat:
  This zeros out rows where frontier[i] = False

        Agent  M1     M2     M3     M4
  Agent  False  True   False  False  False  <- frontier
  M1     False  False  False  False  False  <- not frontier
  M2     False  False  False  False  False  <- not frontier
  M3     False  False  False  False  False  <- not frontier
  M4     False  False  False  False  False  <- not frontier

neighbors = jnp.any(masked_coll, axis=0)
  Reduce down columns -> [False  True False False False]

Interpretation: neighbors[j] = True if object j is hit by any frontier object
  Agent: False
  M1: True
  M2: False
  M3: False
  M4: False


In [41]:
print("=== Step 4.6e: Compute new frontier (unpushed neighbors) ===")
print()

# The new frontier is: neighbors that haven't been pushed yet
# This prevents infinite loops if objects form a cycle

print("neighbors (objects hit by frontier):")
print(f"  {neighbors}")
print()
print("pushed (objects already in push chain):")
print(f"  {pushed}")
print()
print("~pushed (NOT pushed = not yet in chain):")
print(f"  {~pushed}")
print()

new_frontier = neighbors & (~pushed)

print("new_frontier = neighbors & (~pushed):")
print(f"  {new_frontier}")
print()
print("Interpretation:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    if new_frontier[i]:
        print(f"  {name}: Added to frontier (was hit, not yet pushed)")
    elif neighbors[i] and pushed[i]:
        print(f"  {name}: Already pushed (hit but skip)")
    else:
        print(f"  {name}: Not hit by frontier")


=== Step 4.6e: Compute new frontier (unpushed neighbors) ===

neighbors (objects hit by frontier):
  [False  True False False False]

pushed (objects already in push chain):
  [ True False False False False]

~pushed (NOT pushed = not yet in chain):
  [False  True  True  True  True]

new_frontier = neighbors & (~pushed):
  [False  True False False False]

Interpretation:
  Agent: Not hit by frontier
  M1: Added to frontier (was hit, not yet pushed)
  M2: Not hit by frontier
  M3: Not hit by frontier
  M4: Not hit by frontier
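
In TrivialPush the `~pushed` mask changes nothing, so its purpose is easy to miss. The following sketch uses a hypothetical state (not produced by the environment) in which the frontier hits an object that is already in the chain, to show what the mask filters out.

```python
import jax.numpy as jnp

# Hypothetical state partway through a push: Agent and M1 are already in the chain.
pushed = jnp.array([True, True, False, False, False])

# Suppose the current frontier hits M2 and also overlaps the Agent again.
neighbors = jnp.array([True, False, True, False, False])

# Only objects that were hit AND are not yet in the chain join the frontier.
new_frontier = neighbors & (~pushed)
print(new_frontier)
```

Without the mask, the Agent would re-enter the frontier and the loop could revisit the same objects indefinitely.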


In [42]:
print("=== Step 4.6f: Update the pushed set ===")
print()

# Add new frontier objects to the pushed set
new_pushed = pushed | new_frontier

print("pushed | new_frontier:")
print(f"  pushed:       {pushed}")
print(f"  new_frontier: {new_frontier}")
print(f"  OR result:    {new_pushed}")
print()

print("After this iteration:")
print("  Objects in push chain: ", end="")
names = ["Agent", "M1", "M2", "M3", "M4"]
pushed_names = [names[i] for i in range(5) if new_pushed[i]]
print(", ".join(pushed_names))
print()

print("What happens next in the loop:")
print("  1. frontier becomes new_frontier (now M1)")
print("  2. Check if M1 is blocked")
print("  3. Find what M1 collides with")
print("  4. Repeat until frontier is empty")
print()
print("In TrivialPush: M1 doesn't collide with anything else,")
print("so the next iteration will have empty new_frontier and the loop ends.")


=== Step 4.6f: Update the pushed set ===

pushed | new_frontier:
  pushed:       [ True False False False False]
  new_frontier: [False  True False False False]
  OR result:    [ True  True False False False]

After this iteration:
  Objects in push chain: Agent, M1

What happens next in the loop:
  1. frontier becomes new_frontier (now M1)
  2. Check if M1 is blocked
  3. Find what M1 collides with
  4. Repeat until frontier is empty

In TrivialPush: M1 doesn't collide with anything else,
so the next iteration will have empty new_frontier and the loop ends.


In [43]:
print("=== Step 4.6g: Summary - Wave-front propagation ===")
print()

print("Visual representation of one iteration:")
print()
print("  BEFORE:           AFTER:")
print("  frontier: [A]     frontier: [M1]")
print("  pushed:   [A]     pushed:   [A, M1]")
print()
print("  ┌───┬───┬───┐     ┌───┬───┬───┐")
print("  │   │ A │M1 │     │   │ A │M1 │")
print("  │   │ → │   │  →  │   │   │ → │")
print("  └───┴───┴───┘     └───┴───┴───┘")
print("       ↑                    ↑")
print("   A in frontier       M1 now in frontier")
print("   pushes M1           (will check what M1 hits)")
print()
print("The wave-front 'ripples' through the chain of objects!")
print("Each iteration, we check what the current frontier hits,")
print("and those become the next frontier.")


=== Step 4.6g: Summary - Wave-front propagation ===

Visual representation of one iteration:

  BEFORE:           AFTER:
  frontier: [A]     frontier: [M1]
  pushed:   [A]     pushed:   [A, M1]

  ┌───┬───┬───┐     ┌───┬───┬───┐
  │   │ A │M1 │     │   │ A │M1 │
  │   │ → │   │  →  │   │   │ → │
  └───┴───┴───┘     └───┴───┴───┘
       ↑                    ↑
   A in frontier       M1 now in frontier
   pushes M1           (will check what M1 hits)

The wave-front 'ripples' through the chain of objects!
Each iteration, we check what the current frontier hits,
and those become the next frontier.
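
To see the ripple over more than one iteration, the sketch below traces the frontier through a chain Agent -> M1 -> M2. It assumes a fixed, hypothetical collision matrix and leaves out the wall check, so it only illustrates how the frontier and pushed sets evolve; it is not the environment's implementation.

```python
import jax.numpy as jnp

names = ["Agent", "M1", "M2", "M3", "M4"]

# Hypothetical fixed collision matrix for a chain Agent -> M1 -> M2.
coll_mat = jnp.zeros((5, 5), dtype=bool)
coll_mat = coll_mat.at[0, 1].set(True)  # displaced Agent overlaps M1
coll_mat = coll_mat.at[1, 2].set(True)  # displaced M1 overlaps M2

frontier = jnp.array([True, False, False, False, False])
pushed = frontier

history = []
while bool(frontier.any()):
    # Record which objects are in the frontier at the start of this iteration.
    history.append([n for n, f in zip(names, frontier.tolist()) if f])
    neighbors = jnp.any(coll_mat & frontier[:, None], axis=0)
    frontier = neighbors & (~pushed)
    pushed = pushed | frontier

print(history)          # [['Agent'], ['M1'], ['M2']]
print(pushed.tolist())  # [True, True, True, False, False]
```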


### Step 4.7: Apply the push

If no blocking occurred, apply the displacement to all pushed objects.


In [44]:
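# Only one iteration was traced above. In TrivialPush the next iteration
# finds no new neighbors, so new_pushed is already the final pushed set.
final_pushed = new_pushed
final_broken = blocked_any  # In this case, False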

# Agent moved OK if it was pushed and we didn't break
moved_ok = final_pushed[0] & (~final_broken)
print(f"Move successful? {moved_ok}")
print()

# Apply displacement to all pushed objects, but only if the push was not blocked
should_move = (final_pushed & ~final_broken)[:, None, None]  # (5, 1, 1) for broadcasting
moved_coords = jnp.where(should_move, all_disp, coords)

print("Final positions:")
for i, name in enumerate(["Agent", "M1", "M2", "M3", "M4"]):
    was_pushed = final_pushed[i]
    print(f"  {name}: {coords[i, 0]} -> {moved_coords[i, 0]}  (pushed: {was_pushed})")


Move successful? True

Final positions:
  Agent: [3 3] -> [4 3]  (pushed: True)
  M1: [4 3] -> [5 3]  (pushed: True)
  M2: [-1 -1] -> [-1 -1]  (pushed: False)
  M3: [-1 -1] -> [-1 -1]  (pushed: False)
  M4: [-1 -1] -> [-1 -1]  (pushed: False)
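
The run above exercises only the unblocked case (`final_broken` is `False`). The sketch below contrasts the blocked and unblocked outcomes on a small hypothetical coordinate array. The helper `apply_push` and the way padding pixels are kept at `-1` are assumptions made for this illustration, not the environment's code.

```python
import jax.numpy as jnp

# Hypothetical coordinates with shape (objects, pixels, xy); -1 marks padding.
coords = jnp.array([
    [[3, 3], [-1, -1]],   # Agent
    [[4, 3], [-1, -1]],   # M1
    [[7, 7], [-1, -1]],   # M2 (not part of the chain)
])
disp = jnp.array([1, 0])  # RIGHT

# Shift only real pixels so that padding stays at -1 (an assumption of this sketch).
valid = (coords >= 0).all(axis=-1, keepdims=True)
all_disp = jnp.where(valid, coords + disp, coords)

pushed = jnp.array([True, True, False])

def apply_push(pushed, broken):
    # Move an object only if it is in the chain AND the chain was not blocked.
    should_move = (pushed & ~broken)[:, None, None]
    return jnp.where(should_move, all_disp, coords)

free = apply_push(pushed, jnp.array(False))
blocked = apply_push(pushed, jnp.array(True))

print(free[:, 0])     # Agent and M1 shift right, M2 stays put
print(blocked[:, 0])  # nothing moves
```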


## 5. Full Test: Execute the action using the environment


In [45]:
# Reset and step
rng = jax.random.PRNGKey(0)
obs, state = env.reset_env_to_level(rng, level, env.default_params)

print("Before action:")
print(f"  Agent: {state.agent_pos[0]}")
print(f"  M1: {state.m1_pos[0]}")
print()

# Take RIGHT action
obs, new_state, reward, done, info = env.step_env(
    rng, state, Actions.right, env.default_params
)

print("After RIGHT action:")
print(f"  Agent: {new_state.agent_pos[0]}")
print(f"  M1: {new_state.m1_pos[0]}")
print(f"  Reward: {reward}")
print(f"  Done: {done}")


Before action:
  Agent: [3 3]
  M1: [4 3]

After RIGHT action:
  Agent: [4 3]
  M1: [5 3]
  Reward: 9.989999771118164
  Done: True


## 6. Test Chain Push

Let's try a puzzle where the agent pushes M1 into M2, so both objects have to move together.


In [46]:
chain_level = Level.from_str(prefabs["ChainPush"])
print("ChainPush puzzle:")
print(chain_level.to_str())
print()
print(f"Agent: {chain_level.agent_pos[0]}")
print(f"M1: {chain_level.m1_pos[0]}")
print(f"M2: {chain_level.m2_pos[0]}")


ChainPush puzzle:
W  W  W  W  W  W  W  W  W  W
W  W  W  W  W  W  W  W  W  W
W  W  .  .  .  .  .  .  W  W
W  W  .  A  M1  M2  .  .  W  W
W  W  .  .  .  G1  .  .  W  W
W  W  .  .  .  .  .  .  W  W
W  W  .  .  .  .  .  .  W  W
W  W  .  .  .  .  .  .  W  W
W  W  W  W  W  W  W  W  W  W
W  W  W  W  W  W  W  W  W  W

Agent: [3 3]
M1: [4 3]
M2: [5 3]


In [47]:
obs, chain_state = env.reset_env_to_level(rng, chain_level, env.default_params)

print("Before:")
print(f"  Agent: {chain_state.agent_pos[0]}")
print(f"  M1: {chain_state.m1_pos[0]}")
print(f"  M2: {chain_state.m2_pos[0]}")
print()

# Push RIGHT - should push both M1 and M2
obs, new_chain_state, reward, done, info = env.step_env(
    rng, chain_state, Actions.right, env.default_params
)

print("After RIGHT (chain push):")
print(f"  Agent: {new_chain_state.agent_pos[0]}")
print(f"  M1: {new_chain_state.m1_pos[0]}")
print(f"  M2: {new_chain_state.m2_pos[0]}")


Before:
  Agent: [3 3]
  M1: [4 3]
  M2: [5 3]

After RIGHT (chain push):
  Agent: [4 3]
  M1: [5 3]
  M2: [6 3]


## Summary

The `_move_with_push` method uses a **wave-front algorithm**:

1. **Stack** all object coordinates into a single array
2. **Compute displaced positions** for all objects
3. **Initialize frontier** with just the agent
4. **Loop until the frontier is empty** (or the push is blocked):
   - If any frontier object would hit a wall → mark the push as broken and stop
   - Find the objects that collide with the displaced frontier objects
   - Collided objects that are not already pushed become the next frontier and join the pushed set
5. **If not broken**, apply the displacement to all pushed objects
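
The steps above can be sketched end to end with `lax.while_loop`. This is a simplified toy version, not the actual `_move_with_push`: it assumes single-pixel objects stored as `(x, y)` rows with the agent at index 0, a boolean `wall_map` indexed as `[y, x]`, and `-1` for unused slots. The names `move_with_push`, `hits_wall`, and `coll` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def move_with_push(pos, wall_map, disp):
    """Toy wave-front push for single-pixel objects; pos[0] is the agent."""
    n = pos.shape[0]
    active = (pos >= 0).all(axis=-1)
    moved = pos + disp                                    # step 2
    # hits_wall[i]: displaced object i lands on a wall cell
    hits_wall = wall_map[moved[:, 1], moved[:, 0]] & active
    # coll[i, j]: displaced object i lands on object j's current cell
    coll = (moved[:, None, :] == pos[None, :, :]).all(-1)
    coll = coll & active[:, None] & active[None, :] & ~jnp.eye(n, dtype=bool)

    frontier0 = jnp.zeros(n, dtype=bool).at[0].set(True)  # step 3
    init = (frontier0, frontier0, jnp.array(False))

    def cond(carry):
        frontier, _, broken = carry
        return frontier.any() & ~broken

    def body(carry):
        frontier, pushed, broken = carry
        broken = broken | (frontier & hits_wall).any()
        neighbors = (coll & frontier[:, None]).any(axis=0)
        new_frontier = neighbors & ~pushed
        return new_frontier, pushed | new_frontier, broken

    _, pushed, broken = lax.while_loop(cond, body, init)  # step 4
    should_move = (pushed & ~broken)[:, None]             # step 5
    return jnp.where(should_move, moved, pos), broken

# A 5x8 room with a one-cell wall border.
wall_map = jnp.ones((5, 8), dtype=bool).at[1:4, 1:7].set(False)
right = jnp.array([1, 0])
push = jax.jit(move_with_push)

# Chain with free space ahead: all three objects shift right.
free_pos, free_broken = push(jnp.array([[1, 2], [2, 2], [3, 2]]), wall_map, right)
# Chain pressed against the right wall: the push is blocked and nothing moves.
stuck_pos, stuck_broken = push(jnp.array([[4, 2], [5, 2], [6, 2]]), wall_map, right)

print(free_pos, free_broken)
print(stuck_pos, stuck_broken)
```

Carrying `broken` in the loop state lets the loop stop as soon as a wall is hit, and the final `jnp.where` then leaves every position unchanged.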
