In [1]:
import jax
import jax.numpy as jnp
from jax import random as jr

# CompassWorld and CompassWorldState are imported from the pobax package below;
# skip this import if both classes are already defined earlier in this notebook.
from pobax.envs.jax.compass_world import CompassWorld, CompassWorldState

def ok(name):
    """Print a check-mark line announcing that the check called `name` passed."""
    message = f"✓ {name}"
    print(message)

def assert_array_equal(a, b, name):
    """Assert that `a` and `b` match in shape and in every element.

    Both inputs are converted with `jnp.asarray` first, so tuples and lists
    are accepted. On success the check is reported through `ok(name)`; on
    failure an `AssertionError` naming the check and showing both arrays is
    raised.
    """
    lhs = jnp.asarray(a)
    rhs = jnp.asarray(b)
    # Shapes are compared first; the `and` short-circuits, so arrays with
    # different shapes are never compared element-wise.
    assert lhs.shape == rhs.shape and jnp.all(lhs == rhs), f"{name} failed: {lhs} != {rhs}"
    ok(name)

def assert_true(cond, name):
    """Assert that `cond` is truthy, then report the check `name` via `ok`."""
    verdict = bool(cond)
    assert verdict, f"{name} failed"
    ok(name)

def assert_false(cond, name):
    """Assert that `cond` is falsy, then report the check `name` via `ok`."""
    verdict = bool(cond)
    assert not verdict, f"{name} failed"
    ok(name)


In [2]:
# Fixed-start environment (random_start=False): reset once and sanity-check
# what comes back.
env = CompassWorld(size=8, random_start=False)
params = env.default_params
key = jr.PRNGKey(0)
obs, state = env.reset_env(key, params)

# The observation should be a length-5 vector containing only 0s and 1s.
assert_array_equal(obs.shape, (5,), "obs shape (5,)")
assert_true(jnp.all(jnp.logical_or(obs == 0, obs == 1)), "obs is binary")

# The agent must start on an interior cell (row and column in 1 .. size-2) ...
y, x = state.pos
assert_true(
    all(1 <= int(coord) <= env.size - 2 for coord in (y, x)),
    "state.pos within bounds",
)
# ... and face one of the four headings, encoded as the integers 0-3.
assert_true(int(state.dir) in range(4), "state.dir in {0,1,2,3}")


✓ obs shape (5,)
✓ obs is binary
✓ state.pos within bounds
✓ state.dir in {0,1,2,3}


In [4]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

# The goal is expected on the middle row, next to the west wall (column 1),
# with goal heading West (3).
y_mid = jnp.int32((env.size - 1) // 2)
goal_pos = jnp.asarray([y_mid, 1], dtype=jnp.int32)
assert_array_equal(goal_pos, env._goal_pos, "goal at west-wall middle")
assert_true(int(env._goal_dir) == 3, "goal direction is West")

# Build the goal state by hand and query the environment's helpers directly.
goal_state = CompassWorldState(pos=goal_pos, dir=jnp.int32(3), t=jnp.int32(0))
done = env._done(goal_state.pos, goal_state.dir)
rew = env._reward(done)
obs_at_goal = env._obs_from_state(goal_state.pos, goal_state.dir)

assert_true(done, "done at goal")
assert_true(int(rew) == 0, "reward 0 at goal")

# Facing West at the west wall on the goal row: only the last flag (green) is
# set and the blue flag before it stays 0.
# Layout: obs = [N, E, S, b, g]
expected = jnp.array([0, 0, 0, 0, 1], dtype=jnp.uint8)
assert_array_equal(obs_at_goal, expected, "obs flags at goal (green=1, blue=0)")


✓ goal at west-wall middle
✓ goal direction is West
✓ done at goal
✓ reward 0 at goal
✓ obs flags at goal (green=1, blue=0)


In [5]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

# Pick a valid interior row other than the goal row: the next row down if it
# is still inside the walls, otherwise the row above.
y_mid = int((env.size - 1) // 2)
y_other = y_mid + 1
if y_other > env.size - 2:
    y_other = y_mid - 1

# Agent next to the west wall on that row, facing West (dir=3).
pos_other = jnp.asarray([y_other, 1], dtype=jnp.int32)
state_other = CompassWorldState(pos=pos_other, dir=jnp.int32(3), t=jnp.int32(0))

obs_other = env._obs_from_state(state_other.pos, state_other.dir)

# Off the goal row the wall reads blue instead of green.
# Layout: obs = [N, E, S, b, g]
expected_other = jnp.array([0, 0, 0, 1, 0], dtype=jnp.uint8)
assert_array_equal(obs_other, expected_other, "obs flags on west wall (non-goal row => blue)")


✓ obs flags on west wall (non-goal row => blue)


In [6]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

# Each case puts the agent on the interior cell next to one wall, facing that
# wall: (check name, [row, col], heading). Moving forward must leave the
# position unchanged.
clamp_cases = [
    ("north clamp on forward", [1, 4], 0),             # N
    ("south clamp on forward", [env.size - 2, 4], 2),  # S
    ("east clamp on forward", [4, env.size - 2], 1),   # E
    ("west clamp on forward", [4, 1], 3),              # W
]

for check_name, cell, heading in clamp_cases:
    state = CompassWorldState(pos=jnp.array(cell, jnp.int32), dir=jnp.int32(heading), t=jnp.int32(0))
    next_state = env.transition(state, jnp.int32(0))  # action 0 = forward
    assert_array_equal(next_state.pos, state.pos, check_name)


✓ north clamp on forward
✓ south clamp on forward
✓ east clamp on forward
✓ west clamp on forward


In [7]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

# Interior start cell, facing North (dir=0).
start = CompassWorldState(pos=jnp.array([3, 3], jnp.int32), dir=jnp.int32(0), t=jnp.int32(0))

# Action 1 turns right: North -> East (1).
st1 = env.transition(start, jnp.int32(1))
assert_true(int(st1.dir) == 1, "turn right to East")

# Action 2 turns left, again starting from `start`: North -> West (3).
st2 = env.transition(start, jnp.int32(2))
assert_true(int(st2.dir) == 3, "turn left to West")

# Action 0 moves forward: facing North the row index drops by one, and the
# expected position is clipped to the interior cells 1 .. size-2.
st3 = env.transition(start, jnp.int32(0))
interior_lo = jnp.array([1, 1], jnp.int32)
interior_hi = jnp.array([env.size - 2, env.size - 2], jnp.int32)
one_step_north = start.pos + jnp.array([-1, 0], jnp.int32)
expected_pos = jnp.clip(one_step_north, interior_lo, interior_hi)
assert_array_equal(st3.pos, expected_pos, "forward moves along dir")


✓ turn right to East
✓ turn left to West
✓ forward moves along dir


In [8]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

# Reset from 20 different keys. Every start must be an interior cell and must
# not already be the terminal configuration (goal cell AND goal heading).
keys = jr.split(jr.PRNGKey(123), 20)
for i, k in enumerate(keys):
    obs, st = env.reset_env(k, params)

    in_bounds = all(1 <= int(st.pos[axis]) <= env.size - 2 for axis in (0, 1))
    assert_true(in_bounds, f"reset {i} in-bounds")

    # Starting on the goal cell is acceptable as long as the heading differs
    # (the env is expected to adjust dir in that case).
    not_terminal = int(st.dir) != int(env._goal_dir) or not jnp.all(st.pos == env._goal_pos)
    assert_true(not_terminal, f"reset {i} not terminal")
ok("random_start sanity")


✓ reset 0 in-bounds
✓ reset 0 not terminal
✓ reset 1 in-bounds
✓ reset 1 not terminal
✓ reset 2 in-bounds
✓ reset 2 not terminal
✓ reset 3 in-bounds
✓ reset 3 not terminal
✓ reset 4 in-bounds
✓ reset 4 not terminal
✓ reset 5 in-bounds
✓ reset 5 not terminal
✓ reset 6 in-bounds
✓ reset 6 not terminal
✓ reset 7 in-bounds
✓ reset 7 not terminal
✓ reset 8 in-bounds
✓ reset 8 not terminal
✓ reset 9 in-bounds
✓ reset 9 not terminal
✓ reset 10 in-bounds
✓ reset 10 not terminal
✓ reset 11 in-bounds
✓ reset 11 not terminal
✓ reset 12 in-bounds
✓ reset 12 not terminal
✓ reset 13 in-bounds
✓ reset 13 not terminal
✓ reset 14 in-bounds
✓ reset 14 not terminal
✓ reset 15 in-bounds
✓ reset 15 not terminal
✓ reset 16 in-bounds
✓ reset 16 not terminal
✓ reset 17 in-bounds
✓ reset 17 not terminal
✓ reset 18 in-bounds
✓ reset 18 not terminal
✓ reset 19 in-bounds
✓ reset 19 not terminal
✓ random_start sanity


In [9]:
env = CompassWorld(size=8, random_start=True)
params = env.default_params

y_mid = int((env.size - 1) // 2)

# One cell east of the goal (column 2 on the goal row), already facing West.
adjacent = CompassWorldState(pos=jnp.asarray([y_mid, 2], jnp.int32), dir=jnp.int32(3), t=jnp.int32(0))

# A single forward action (0) should enter the goal cell and end the episode.
step_key = jr.PRNGKey(0)
obs, next_state, rew, done, info = env.step_env(step_key, adjacent, 0, params)

assert_true(done, "done after stepping into goal")
assert_true(int(rew) == 0, "reward 0 upon entering goal")

# On the goal row at the west wall, facing West => only the green flag is set.
expected_obs = jnp.array([0, 0, 0, 0, 1], dtype=jnp.uint8)
assert_array_equal(obs, expected_obs, "obs at goal after step")


✓ done after stepping into goal
✓ reward 0 upon entering goal
✓ obs at goal after step


In [10]:
import jax
import jax.numpy as jnp
from jax import random as jr

# Action codes passed to the environment's step/transition functions.
A_FWD = 0
A_RIGHT = 1
A_LEFT = 2

# Heading code -> compass letter, used when printing rollouts.
DIR_CH = dict(enumerate(("N", "E", "S", "W")))

def obs_str(obs):
    """Format the first five observation entries as 'N=. E=. S=. W=. G=.'.

    Each entry is converted with `int` before printing.
    """
    labels = ("N", "E", "S", "W", "G")
    return " ".join(f"{label}={int(obs[idx])}" for idx, label in enumerate(labels))

def step_name(a):
    """Return the short label for action code `a`.

    Known codes map to "FWD", "RIGHT" and "LEFT"; any other code is returned
    as the decimal string of `int(a)`.
    """
    code = int(a)
    labels = {0: "FWD", 1: "RIGHT", 2: "LEFT"}
    if code in labels:
        return labels[code]
    return str(code)

def rollout(env, params, actions, start_state=None, key=jr.PRNGKey(0)):
    """Step `env` through `actions`, printing a one-line trace per step.

    Args:
        env: environment exposing `reset_env(key, params)`,
            `step_env(key, state, action, params)` and
            `_obs_from_state(pos, dir)`.
        params: environment parameters, forwarded unchanged to the env.
        actions: iterable of action codes; each one is converted with `int`
            before it is handed to the env.
        start_state: state to start from. When None, the env is reset.
        key: JAX PRNG key. It is split before every use, so the reset and
            each step receive their own subkey instead of the same key.
            Consequently a random-start reset performed here does not
            coincide with calling `env.reset_env(key, params)` directly.

    Returns:
        The state after the last executed step. The rollout stops as soon as
        a step reports `done`.
    """
    if start_state is None:
        key, reset_key = jr.split(key)
        obs, state = env.reset_env(reset_key, params)
    else:
        state = start_state
        obs = env._obs_from_state(state.pos, state.dir)
    print(f"t={int(state.t):03d}  pos={tuple(map(int, state.pos))}  dir={DIR_CH[int(state.dir)]}  obs[{obs_str(obs)}]")

    for a in actions:
        # Never feed the same key to two calls: carry one half forward and
        # spend the other half on this step.
        key, step_key = jr.split(key)
        obs, state, rew, done, _ = env.step_env(step_key, state, int(a), params)
        print(f"t={int(state.t):03d}  act={step_name(a):>5}  pos={tuple(map(int, state.pos))}  dir={DIR_CH[int(state.dir)]}  rew={int(rew)}  done={bool(done)}  obs[{obs_str(obs)}]")
        if bool(done):
            break
    return state

def turn_towards(cur_dir, target_dir):
    """Return the turn actions that rotate from `cur_dir` to `target_dir`.

    Headings are compared modulo 4 and a right turn advances the heading by
    one. A difference of three quarter turns is done as a single left turn;
    otherwise the result is that many right turns (an empty list when the
    headings already match).
    """
    quarter_turns = (target_dir - cur_dir) % 4
    if quarter_turns == 3:
        return [A_LEFT]
    if quarter_turns in (0, 1, 2):
        return [A_RIGHT for _ in range(int(quarter_turns))]
    # Not a whole number of quarter turns: nothing to return.
    return None

def plan_to_goal(state, env):
    """Build an action list that walks the agent from `state` to the goal.

    The plan covers the row difference first, then the column difference, and
    finishes by turning to the goal heading. It assumes nothing blocks the
    straight-line moves.

    Args:
        state: object with `pos` (row, column) and `dir` (heading code:
            0=N, 1=E, 2=S, 3=W).
        env: object with `_goal_pos` (row, column) and, optionally,
            `_goal_dir`; West is assumed when `_goal_dir` is missing.

    Returns:
        A list of action codes: `A_FWD` moves plus the turns produced by
        `turn_towards`.
    """
    north, east, south, west = 0, 1, 2, 3

    row, col = map(int, state.pos)
    heading = int(state.dir)
    goal_row, goal_col = map(int, env._goal_pos)
    goal_dir = int(getattr(env, "_goal_dir", west))

    actions = []

    # Vertical leg: a smaller row index lies to the north.
    if row != goal_row:
        vertical_dir = north if row > goal_row else south
        actions += turn_towards(heading, vertical_dir)
        actions += [A_FWD] * abs(row - goal_row)
        heading = vertical_dir

    # Horizontal leg: a smaller column index lies to the west.
    if col != goal_col:
        horizontal_dir = west if col > goal_col else east
        actions += turn_towards(heading, horizontal_dir)
        actions += [A_FWD] * abs(col - goal_col)
        heading = horizontal_dir

    # Finish facing the goal heading (no actions if already facing it).
    actions += turn_towards(heading, goal_dir)
    return actions

# --- Manual sequence demo ---
# Fixed start (random_start=False), followed by a hand-written action list:
# right, 3x forward, left, 2x forward, left, forward.
env = CompassWorld(size=8, random_start=False)
params = env.default_params
actions = [A_RIGHT] + [A_FWD] * 3 + [A_LEFT] + [A_FWD] * 2 + [A_LEFT, A_FWD]
print("=== Manual sequence ===")
_ = rollout(env, params, actions, key=jr.PRNGKey(0))

# --- Auto-plan to goal from deterministic reset ---
# Reset explicitly, plan from the resulting state, then replay the plan from
# that same state.
print("\n=== Auto-plan to goal (deterministic reset) ===")
key = jr.PRNGKey(42)
obs0, st0 = env.reset_env(key, params)
acts_to_goal = plan_to_goal(st0, env)
print("Planned actions:", list(map(step_name, acts_to_goal)))
_ = rollout(env, params, acts_to_goal, start_state=st0, key=key)


=== Manual sequence ===
t=000  pos=(3, 3)  dir=N  obs[N=0 E=0 S=0 W=0 G=0]
t=001  act=RIGHT  pos=(3, 3)  dir=E  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=002  act=  FWD  pos=(3, 4)  dir=E  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=003  act=  FWD  pos=(3, 5)  dir=E  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=004  act=  FWD  pos=(3, 6)  dir=E  rew=-1  done=False  obs[N=0 E=1 S=0 W=0 G=0]
t=005  act= LEFT  pos=(3, 6)  dir=N  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=006  act=  FWD  pos=(2, 6)  dir=N  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=007  act=  FWD  pos=(1, 6)  dir=N  rew=-1  done=False  obs[N=1 E=0 S=0 W=0 G=0]
t=008  act= LEFT  pos=(1, 6)  dir=W  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]
t=009  act=  FWD  pos=(1, 5)  dir=W  rew=-1  done=False  obs[N=0 E=0 S=0 W=0 G=0]

=== Auto-plan to goal (deterministic reset) ===
Planned actions: ['LEFT', 'FWD', 'FWD']
t=000  pos=(3, 3)  dir=N  obs[N=0 E=0 S=0 W=0 G=0]
t=001  act= LEFT  pos=(3, 3)  dir=W  rew=-1  do

In [17]:
import jax
import jax.numpy as jnp
from jax import random as jr

# Action codes passed to the environment's step function.
A_FWD = 0
A_RIGHT = 1
A_LEFT = 2

# Heading code -> compass letter.
DIR_CH = {
    0: "N",
    1: "E",
    2: "S",
    3: "W",
}

# Heading code -> arrow glyph used for the agent when drawing the grid.
DIR_SYMBOL = {
    0: "^",
    1: ">",
    2: "v",
    3: "<",
}

def print_grid(env, state):
    """Print an ASCII picture of the grid with the goal and the agent.

    '#' marks the outer ring of cells, '.' every other empty cell, 'G' the
    goal cell, and the agent's cell shows the `DIR_SYMBOL` glyph for its
    heading. The agent is drawn last, so it hides the 'G' when it stands on
    the goal. A blank line is printed after the grid.
    """
    size = env.size
    edge = size - 1

    # Build all rows in one pass: border cells are walls, the rest is floor.
    grid = [
        ["#" if row in (0, edge) or col in (0, edge) else "." for col in range(size)]
        for row in range(size)
    ]

    # Goal first, agent second, so the agent wins when both share a cell.
    goal_row, goal_col = map(int, env._goal_pos)
    grid[goal_row][goal_col] = "G"

    agent_row, agent_col = map(int, state.pos)
    grid[agent_row][agent_col] = DIR_SYMBOL[int(state.dir)]

    picture = "\n".join(" ".join(cells) for cells in grid)
    print(picture)
    print()

def rollout_with_grid(env, params, actions, start_state=None, key=jr.PRNGKey(0)):
    """Step `env` through `actions`, drawing the grid before and after each step.

    Args:
        env: environment exposing `reset_env(key, params)` and
            `step_env(key, state, action, params)`, plus whatever
            `print_grid` reads from it.
        params: environment parameters, forwarded unchanged to the env.
        actions: iterable of action codes; each one is converted with `int`
            both for the env call and for the printed label.
        start_state: state to start from. When None, the env is reset.
        key: JAX PRNG key. It is split before every use, so the reset and
            each step receive their own subkey instead of the same key.
            Consequently a random-start reset performed here does not
            coincide with calling `env.reset_env(key, params)` directly.

    Returns:
        The state after the last executed step. The rollout stops as soon as
        a step reports `done`.
    """
    action_labels = {0: "FWD", 1: "RIGHT", 2: "LEFT"}

    if start_state is None:
        key, reset_key = jr.split(key)
        _, state = env.reset_env(reset_key, params)
    else:
        state = start_state

    print("t=0")
    print_grid(env, state)

    for t, a in enumerate(actions, 1):
        code = int(a)
        # Never feed the same key to two calls: carry one half forward and
        # spend the other half on this step.
        key, step_key = jr.split(key)
        _, state, rew, done, _ = env.step_env(step_key, state, code, params)
        # Unknown codes are shown as their number instead of raising after
        # the step has already been taken.
        label = action_labels.get(code, str(code))
        print(f"t={t}, act={label}, rew={int(rew)}, done={bool(done)}")
        print_grid(env, state)
        if bool(done):
            break
    return state

def turn_towards(cur_dir, target_dir):
    """Return the turn actions that rotate from `cur_dir` to `target_dir`.

    The heading difference is taken modulo 4. Zero, one or two quarter turns
    become that many right turns; three quarter turns become one left turn.
    """
    clockwise = (target_dir - cur_dir) % 4
    if clockwise == 3:
        # Three rights equal one left, so take the shorter way round.
        return [A_LEFT]
    for right_turns in range(3):
        if clockwise == right_turns:
            return [A_RIGHT] * right_turns
    # Not a whole number of quarter turns: nothing to return.
    return None

def plan_to_goal(state, env):
    """Build an action list that walks the agent from `state` to the goal.

    The plan covers the row difference first, then the column difference, and
    finishes by turning to the goal heading. It assumes nothing blocks the
    straight-line moves.

    Args:
        state: object with `pos` (row, column) and `dir` (heading code:
            0=N, 1=E, 2=S, 3=W).
        env: object with `_goal_pos` (row, column) and, optionally,
            `_goal_dir`; West is assumed when `_goal_dir` is missing.

    Returns:
        A list of action codes: `A_FWD` moves plus the turns produced by
        `turn_towards`.
    """
    north, east, south, west = 0, 1, 2, 3

    row, col = map(int, state.pos)
    heading = int(state.dir)
    goal_row, goal_col = map(int, env._goal_pos)
    goal_dir = int(getattr(env, "_goal_dir", west))

    actions = []

    # Vertical leg: a smaller row index lies to the north.
    if row != goal_row:
        vertical_dir = north if row > goal_row else south
        actions += turn_towards(heading, vertical_dir)
        actions += [A_FWD] * abs(row - goal_row)
        heading = vertical_dir

    # Horizontal leg: a smaller column index lies to the west.
    if col != goal_col:
        horizontal_dir = west if col > goal_col else east
        actions += turn_towards(heading, horizontal_dir)
        actions += [A_FWD] * abs(col - goal_col)
        heading = horizontal_dir

    # Finish facing the goal heading (no actions if already facing it).
    actions += turn_towards(heading, goal_dir)
    return actions

# --- Demo ---
# Random-start environment: reset once, plan a path from the sampled state to
# the goal, then replay the plan while drawing the grid after every step.
env = CompassWorld(size=8, random_start=True)
params = env.default_params

key = jr.PRNGKey(3)
obs0, st0 = env.reset_env(key, params)
actions = plan_to_goal(st0, env)

# Last expression of the cell: the returned final state is shown as the output.
rollout_with_grid(env, params, actions, st0, key)


t=0
# # # # # # # #
# . . . . . . #
# . . . . < . #
# G . . . . . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=1, act=LEFT, rew=-1, done=False
# # # # # # # #
# . . . . . . #
# . . . . v . #
# G . . . . . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=2, act=FWD, rew=-1, done=False
# # # # # # # #
# . . . . . . #
# . . . . . . #
# G . . . v . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=3, act=RIGHT, rew=-1, done=False
# # # # # # # #
# . . . . . . #
# . . . . . . #
# G . . . < . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=4, act=FWD, rew=-1, done=False
# # # # # # # #
# . . . . . . #
# . . . . . . #
# G . . < . . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=5, act=FWD, rew=-1, done=False
# # # # # # # #
# . . . . . . #
# . . . . . . #
# G . < . . . #
# . . . . . . #
# . . . . . . #
# . . . . . . #
# # # # # # # #

t=6, act=FWD, rew=-1, done=False
# # # # # # # #
# . .

CompassWorldState(pos=Array([3, 1], dtype=int32), dir=Array(3, dtype=int32), t=Array(7, dtype=int32))