In [1]:
# Notebook to test the simulation (version 12)

In [146]:
import jax.random as random
import jax.numpy as jnp
from structs import *
from functions import *
from sim_v12 import *


# ---- simulated environment parameter values ----
params_input = Params(content={
    "x_max": 10.0,
    "y_max": 10.0,
    "energy_begin_max": 100.0,
    "eat_rate": 0.2,
    "radius": 2.0,
    "mass_begin": 5.0
})

key = random.PRNGKey(0)

In [147]:
# Sheep class: create_agent()
# active sheep

sheep_active = Sheep.create_agent(
    type="sheep",
    params=params_input,
    id=1,
    active_state=True,
    key=key
)
state = sheep_active.state.content

assert isinstance(sheep_active, Sheep)

# Position must be inside bounds
assert jnp.all(state["x"] >= -10.0) and jnp.all(state["x"] <= 10.0)
assert jnp.all(state["y"] >= -10.0) and jnp.all(state["y"] <= 10.0)

# Angle in valid range
assert jnp.all(state["ang"] >= -jnp.pi) and jnp.all(state["ang"] <= jnp.pi)

# Velocities must be zero
assert jnp.all(state["x_dot"] == 0)
assert jnp.all(state["y_dot"] == 0)
assert jnp.all(state["ang_dot"] == 0)

# Energy initialization
assert state["energy"] >= 50.0 and state["energy"] <= 100.0

# Energy offer must equal energy * eat_rate
assert jnp.isclose(state["energy_offer"], state["energy"] * params_input.content["eat_rate"])

In [148]:
# Create inactive sheep
sheep_inactive = Sheep.create_agent(
    type="sheep",
    params=params_input,
    id=2,
    active_state=False,
    key=key
)

inactive_state = sheep_inactive.state.content

# Check forced inactive placeholder values
assert jnp.all(inactive_state["x"] == -1.0)
assert jnp.all(inactive_state["y"] == -1.0)
assert jnp.all(inactive_state["energy"] == -1.0)
assert jnp.all(inactive_state["energy_offer"] == -1.0)

# Velocities should remain zero
assert jnp.all(inactive_state["x_dot"] == 0)
assert jnp.all(inactive_state["y_dot"] == 0)
assert jnp.all(inactive_state["ang_dot"] == 0)

In [149]:
# Sheep class: step_agent()

key = random.PRNGKey(42)

# --- fake agent for testing ---
agent_state = State(content={
    "x": jnp.array([0.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([1.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([10.0]),
    "energy_offer": jnp.array([2.0])
})

agent_params = Params(content={
    "eat_rate": 0.2
})

agent = Sheep(
    id=1,
    state=agent_state,
    params=agent_params,
    active_state=1,
    agent_type=1,
    age=0.0,
    key=key,
    policy=None
)

# --- inputs to step_agent ---
step_input = Params(content={"energy_intake": jnp.array([5.0])})

step_params = Params(content={
    "dt": 1.0,
    "damping": 0.1,
    "x_max_arena": 10.0,
    "y_max_arena": 10.0
})

In [150]:
def step_agent(agent, input, step_params):
    def step_active_agent():
        # input
        energy_intake = input.content["energy_intake"] # also handles energy output (if eaten by wolves)

        # current agent state
        energy = agent.state.content["energy"]

        x = agent.state.content["x"]
        y = agent.state.content["y"]
        ang = agent.state.content["ang"]
        x_dot = agent.state.content["x_dot"] # current x_velocity
        y_dot = agent.state.content["y_dot"] # current y_velocity
        ang_dot = agent.state.content["ang_dot"]

        dt = step_params.content["dt"]
        damping = step_params.content["damping"]
        x_max_arena = step_params.content["x_max_arena"]
        y_max_arena = step_params.content["y_max_arena"]

        eat_rate = agent.params.content["eat_rate"]

        key, *subkeys = random.split(agent.key, 5)

        # sample random movement
        forward_action = jax.random.uniform(subkeys[0], (), minval=0.0, maxval=1.0)
        angular_action = jax.random.uniform(subkeys[1], (), minval=-1.0, maxval=1.0)

        # fixed base speed (with noise)
        speed = ((LINEAR_ACTION_OFFSET + SHEEP_LINEAR_ACTION_SCALE * forward_action) *
                 (1 + NOISE_SCALE * jax.random.normal(subkeys[2], ())))
        ang_speed = SHEEP_ANGULAR_SPEED_SCALE * angular_action * (1 + NOISE_SCALE * jax.random.normal(subkeys[3], ()))

        # updated positions
        #x_new = jnp.clip(x + dt * x_dot, -x_max_arena, x_max_arena)  # no wrap-around for now; may change it later
        #y_new = jnp.clip(y + dt * y_dot, -y_max_arena, y_max_arena)
        # wraparound
        x_new = jnp.mod(x + dt * x_dot + x_max_arena, 2 * x_max_arena) - x_max_arena
        y_new = jnp.mod(y + dt * y_dot + y_max_arena, 2 * y_max_arena) - y_max_arena
        ang_new = jnp.mod(ang + dt * ang_dot + jnp.pi, 2 * jnp.pi) - jnp.pi

        x_dot_new = speed * jnp.cos(ang) - dt * x_dot * damping
        y_dot_new = speed * jnp.sin(ang) - dt * y_dot * damping
        ang_dot_new = ang_speed - dt * ang_dot * damping

        # fixed metabolic cost
        metabolic_cost = BASIC_METABOLIC_COST_SHEEP
        energy_new = energy + energy_intake - metabolic_cost # energy_intake already includes loss to wolves

        new_energy_offer = energy_new * eat_rate

        agent_is_dead = energy_new[0] <= 0.0

        new_state_content = {"x": x_new, "y": y_new, "x_dot": x_dot_new, "y_dot": y_dot_new, "ang": ang_new, "ang_dot": ang_dot_new,
                             "energy": energy_new, "energy_offer": new_energy_offer}
        new_state = State(content=new_state_content)

        return jax.lax.cond(
            agent_is_dead,
            lambda _: agent.replace(state=new_state, active_state=0),  # mark as dead/inactive
            lambda _: agent.replace(state=new_state, key=key, age=agent.age + dt),
            None
        )
    def step_inactive_agent():
        return agent

    return jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)


In [151]:
BASIC_METABOLIC_COST_SHEEP = 0.35
result = step_agent(agent, step_input, step_params)

In [152]:
print(f"Initial state:")
print(f"  Position: ({agent.state.content['x'][0]:.2f}, {agent.state.content['y'][0]:.2f})")
print(f"  Velocity: ({agent.state.content['x_dot'][0]:.2f}, {agent.state.content['y_dot'][0]:.2f})")
print(f"  Energy: {agent.state.content['energy'][0]:.2f}")
print(f"  Age: {agent.age:.2f}")
print(f"  Active: {bool(agent.active_state)}")

print(f"\nAfter step:")
print(f"  Position: ({result.state.content['x'][0]:.2f}, {result.state.content['y'][0]:.2f})")
print(f"  Velocity: ({result.state.content['x_dot'][0]:.2f}, {result.state.content['y_dot'][0]:.2f})")
print(f"  Energy: {result.state.content['energy'][0]:.2f}")
print(f"  Age: {result.age:.2f}")
print(f"  Active: {bool(result.active_state)}")

Initial state:
  Position: (0.00, 0.00)
  Velocity: (1.00, 0.00)
  Energy: 10.00
  Age: 0.00
  Active: True

After step:
  Position: (1.00, 0.00)
  Velocity: (71.64, 0.00)
  Energy: 14.65
  Age: 1.00
  Active: True


In [153]:
expected_energy_offer = result.state.content['energy'][0] * agent_params.content['eat_rate']
print(expected_energy_offer)
print(result.state.content['energy_offer'][0])
assert jnp.isclose(result.state.content['energy_offer'][0], expected_energy_offer), "Energy offer should be energy * eat_rate"


2.93
2.93


In [154]:
def reset_agent(agent, reset_params):
    x_max = agent.params.content["x_max"]
    y_max = agent.params.content["y_max"]
    energy_begin_max = agent.params.content["energy_begin_max"]
    eat_rate = agent.params.content["eat_rate"]
    key = agent.key

    key, *subkeys = random.split(key, 5)
    x = random.uniform(subkeys[0], shape=(1,), minval=-x_max, maxval=x_max)
    y = random.uniform(subkeys[1], shape=(1,), minval=-y_max, maxval=y_max)
    ang = random.uniform(subkeys[2], shape=(1,), minval=-jnp.pi, maxval=jnp.pi)
    x_dot = jnp.zeros((1,), dtype=jnp.float32)
    y_dot = jnp.zeros((1,), dtype=jnp.float32)
    ang_dot = jnp.zeros((1,), dtype=jnp.float32)

    energy = random.uniform(subkeys[3], shape=(1,), minval=0.5 * energy_begin_max, maxval=energy_begin_max)

    state_content = {"x": x, "y": y, "ang": ang, "x_dot": x_dot, "y_dot": y_dot, "ang_dot": ang_dot,
                     "energy": energy}
    state = State(content=state_content)

    return agent.replace(state=state, age=0.0, active_state=1, key=key)

In [155]:
# Create an agent that has been running for a while
key = random.PRNGKey(100)
aged_state = State(content={
    "x": jnp.array([5.0]),
    "y": jnp.array([-3.0]),
    "ang": jnp.array([1.5]),
    "x_dot": jnp.array([2.0]),
    "y_dot": jnp.array([1.5]),
    "ang_dot": jnp.array([0.3]),
    "energy": jnp.array([15.0])  # Low energy
})

aged_params = Params(content={
    "x_max": 10.0,
    "y_max": 10.0,
    "energy_begin_max": 100.0,
    "eat_rate": 0.2,
    "radius": 2.0,
    "mass": 5.0
})

aged_agent = Sheep(
    id=1,
    state=aged_state,
    params=aged_params,
    active_state=1,
    agent_type=1,
    age=50.0,
    key=key,
    policy=None
)

In [156]:
reset_result = reset_agent(aged_agent, None)

print(f"\nAfter reset:")
print(f"  Position: ({reset_result.state.content['x'][0]:.2f}, {reset_result.state.content['y'][0]:.2f})")
print(f"  Velocity: ({reset_result.state.content['x_dot'][0]:.2f}, {reset_result.state.content['y_dot'][0]:.2f})")
print(f"  Angle: {reset_result.state.content['ang'][0]:.4f}")
print(f"  Energy: {reset_result.state.content['energy'][0]:.2f}")
print(f"  Age: {reset_result.age:.2f}")
print(f"  Active: {reset_result.active_state}")


After reset:
  Position: (3.51, 9.05)
  Velocity: (0.00, 0.00)
  Angle: 0.1900
  Energy: 91.57
  Age: 0.00
  Active: 1


In [157]:
dead_state = State(content={
    "x": jnp.array([-1.0]),  # Placeholder for dead
    "y": jnp.array([-1.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([-1.0])  # Dead
})

dead_agent = Sheep(
    id=2,
    state=dead_state,
    params=aged_params,
    active_state=0,  # Inactive
    agent_type=1,
    age=25.0,
    key=random.PRNGKey(200),
    policy=None
)

In [158]:
reset_dead = reset_agent(dead_agent, None)

print(f"\nAfter reset:")
print(f"  Position: ({reset_dead.state.content['x'][0]:.2f}, {reset_dead.state.content['y'][0]:.2f})")
print(f"  Velocity: ({reset_dead.state.content['x_dot'][0]:.2f}, {reset_dead.state.content['y_dot'][0]:.2f})")
print(f"  Energy: {reset_dead.state.content['energy'][0]:.2f}")
print(f"  Age: {reset_dead.age:.2f}")
print(f"  Active: {reset_dead.active_state}")


After reset:
  Position: (-9.92, 0.10)
  Velocity: (0.00, 0.00)
  Energy: 56.25
  Age: 0.00
  Active: 1


In [159]:
print(f"Before reset - params:")
print(f"  x_max: {aged_agent.params.content['x_max']}")
print(f"  y_max: {aged_agent.params.content['y_max']}")
print(f"  energy_begin_max: {aged_agent.params.content['energy_begin_max']}")
print(f"  eat_rate: {aged_agent.params.content['eat_rate']}")
print(f"  radius: {aged_agent.params.content['radius']}")
print(f"  mass: {aged_agent.params.content['mass']}")
print(f"  ID: {aged_agent.id}")
print(f"  Agent type: {aged_agent.agent_type}")

reset_params_test = reset_agent(aged_agent, None)

print(f"\nAfter reset - params:")
print(f"  x_max: {reset_params_test.params.content['x_max']}")
print(f"  y_max: {reset_params_test.params.content['y_max']}")
print(f"  energy_begin_max: {reset_params_test.params.content['energy_begin_max']}")
print(f"  eat_rate: {reset_params_test.params.content['eat_rate']}")
print(f"  radius: {reset_params_test.params.content['radius']}")
print(f"  mass: {reset_params_test.params.content['mass']}")
print(f"  ID: {reset_params_test.id}")
print(f"  Agent type: {reset_params_test.agent_type}")

# All params should be identical
assert reset_params_test.params.content['x_max'] == aged_agent.params.content['x_max']
assert reset_params_test.params.content['y_max'] == aged_agent.params.content['y_max']
assert reset_params_test.params.content['energy_begin_max'] == aged_agent.params.content['energy_begin_max']
assert reset_params_test.params.content['eat_rate'] == aged_agent.params.content['eat_rate']
assert reset_params_test.id == aged_agent.id
assert reset_params_test.agent_type == aged_agent.agent_type


Before reset - params:
  x_max: 10.0
  y_max: 10.0
  energy_begin_max: 100.0
  eat_rate: 0.2
  radius: 2.0
  mass: 5.0
  ID: 1
  Agent type: 1

After reset - params:
  x_max: 10.0
  y_max: 10.0
  energy_begin_max: 100.0
  eat_rate: 0.2
  radius: 2.0
  mass: 5.0
  ID: 1
  Agent type: 1


In [160]:
key_test_agent = Sheep(
    id=3,
    state=aged_state,
    params=aged_params,
    active_state=1,
    agent_type=1,
    age=10.0,
    key=random.PRNGKey(300),
    policy=None
)

print(f"Original key: {key_test_agent.key}")

reset_key_test = reset_agent(key_test_agent, None)

print(f"Reset key: {reset_key_test.key}")

# Key should be different after reset
key_changed = not jnp.array_equal(reset_key_test.key, key_test_agent.key)
print(f"\nKey changed: {key_changed}")
assert key_changed, "Random key should be updated"

Original key: [  0 300]
Reset key: [ 643892670 3797580258]

Key changed: True


In [202]:
# Test 3
def calculate_sheep_energy_intake(sheep: Sheep):
    def sheep_local_density(one_sheep, all_sheep):
        """Calculate how many other sheep are within this sheep's energy radius"""
        xs_sheep = all_sheep.state.content["x"].reshape(-1)
        ys_sheep = all_sheep.state.content["y"].reshape(-1)
        active_sheep = all_sheep.active_state.astype(bool)

        x_sheep = one_sheep.state.content["x"]
        y_sheep = one_sheep.state.content["y"]

        # calculate distance to all other sheep
        distances = jnp.linalg.norm(jnp.stack((xs_sheep - x_sheep, ys_sheep - y_sheep), axis=1), axis=1).reshape(-1)

        energy_radius = SHEEP_RADIUS* 3.0

        cond = jnp.logical_and(distances <= energy_radius, active_sheep)
        is_near = jnp.where(cond, 1.0, 0.0)
        num_sheep_in_radius = jnp.sum(is_near)

        return num_sheep_in_radius, is_near

    active_mask = sheep.active_state.astype(bool)

    num_sheep_in_radius, is_near_matrix = jax.vmap(sheep_local_density, in_axes=(0, None))(sheep, sheep)

    energy_share = jnp.divide(BASE_ENERGY_RATE, jnp.maximum(num_sheep_in_radius, 1.0))
    energy_intake = energy_share * active_mask

    return energy_intake

In [208]:
keys = random.split(random.PRNGKey(999), 4)

# Sheep 1: at origin
sheep1_params = Params(content={
    "x_max": 100.0,
    "y_max": 100.0,
    "energy_begin_max": 100.0,
    "eat_rate": 0.2,
    "radius": SHEEP_RADIUS,
    "mass_begin": 5.0
})

sheep1 = Sheep.create_agent(1, sheep1_params, 0, 1, keys[0])
sheep1_state = State(content={
    "x": jnp.array([0.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([50.0])
})
sheep1 = sheep1.replace(state=sheep1_state)

# Sheep 2: close to sheep 1 (distance = 10, within radius of 15)
sheep2 = Sheep.create_agent(1, sheep1_params, 1, 1, keys[1])
sheep2_state = State(content={
    "x": jnp.array([10.0]),
    "y": jnp.array([2.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([60.0])
})
sheep2 = sheep2.replace(state=sheep2_state)

# Sheep 3: far from others (distance = 100, outside radius)
sheep3 = Sheep.create_agent(1, sheep1_params, 2, 1, keys[2])
sheep3_state = State(content={
    "x": jnp.array([100.0]),
    "y": jnp.array([5.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([70.0])
})
sheep3 = sheep3.replace(state=sheep3_state)

sheep4 = Sheep.create_agent(1, sheep1_params, 3, 1, keys[3])
sheep4_state = State(content={
    "x": jnp.array([15.0]),
    "y": jnp.array([15.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([80.0])
})
sheep4 = sheep4.replace(state=sheep4_state)

In [209]:
sheep_set = jax.tree.map(lambda *xs: jnp.stack(xs), sheep1, sheep2, sheep3, sheep4)

In [210]:
print("Created 4 test sheep:")
print(f"Sheep 1: position ({sheep1.state.content['x'][0]:.1f}, {sheep1.state.content['y'][0]:.1f}), energy={sheep1.state.content['energy'][0]:.1f}, active={sheep1.active_state}")
print(f"Sheep 2: position ({sheep2.state.content['x'][0]:.1f}, {sheep2.state.content['y'][0]:.1f}), energy={sheep2.state.content['energy'][0]:.1f}, active={sheep2.active_state}")
print(f"Sheep 3: position ({sheep3.state.content['x'][0]:.1f}, {sheep3.state.content['y'][0]:.1f}), energy={sheep3.state.content['energy'][0]:.1f}, active={sheep3.active_state}")
print(f"Sheep 4: position ({sheep4.state.content['x'][0]:.1f}, {sheep4.state.content['y'][0]:.1f}), energy={sheep4.state.content['energy'][0]:.1f}, active={sheep4.active_state}")

print(f"\nEnergy radius: {SHEEP_RADIUS * 3.0}")

Created 4 test sheep:
Sheep 1: position (0.0, 0.0), energy=50.0, active=1
Sheep 2: position (10.0, 2.0), energy=60.0, active=1
Sheep 3: position (100.0, 5.0), energy=70.0, active=1
Sheep 4: position (15.0, 15.0), energy=80.0, active=1

Energy radius: 15.0
Distance sheep1-sheep2: 10.0 (within radius - should share)
Distance to sheep3: 100.0 (outside radius - alone)


In [211]:
BASE_ENERGY_RATE = 0.6
energy_intake = calculate_sheep_energy_intake(sheep_set)

In [212]:
print(f"Sheep 1: {energy_intake[0]:.4f} (expected: {BASE_ENERGY_RATE/2:.4f} - shares with sheep 1 and 2)")
print(f"Sheep 2: {energy_intake[1]:.4f} (expected: {BASE_ENERGY_RATE/2:.4f} - shares with sheep 1 and 2)")
print(f"Sheep 3: {energy_intake[2]:.4f} (expected: {BASE_ENERGY_RATE:.4f} - alone)")
print(f"Sheep 4: {energy_intake[3]:.4f} (expected: {BASE_ENERGY_RATE/3:.4f} - shares with 1 and 2)")

Sheep 1: 0.3000 (expected: 0.3000 - shares with sheep 1 and 2)
Sheep 2: 0.2000 (expected: 0.3000 - shares with sheep 1 and 2)
Sheep 3: 0.6000 (expected: 0.6000 - alone)
Sheep 4: 0.3000 (expected: 0.2000 - shares with 1 and 2)


In [279]:
def wolves_sheep_interactions(sheep: Sheep, wolves: Wolf):
    def wolf_sheep_interaction(one_wolf, sheep):
        xs_sheep = sheep.state.content["x"]
        ys_sheep = sheep.state.content["y"]
        x_wolf = one_wolf.state.content["x"]
        y_wolf = one_wolf.state.content["y"]

        active_sheep = sheep.active_state

        wolf_radius = one_wolf.params.content["radius"]

        distances = jnp.linalg.norm(jnp.stack((xs_sheep - x_wolf, ys_sheep - y_wolf), axis=1), axis=1).reshape(-1)
        is_in_range = jnp.where(jnp.logical_and(distances <= wolf_radius, active_sheep), 1.0, 0.0) # only consider active sheep

        # find the closest sheep; wolf can only catch one sheep at a time
        distances_masked = jnp.where(is_in_range > 0, distances, jnp.inf)
        closest_sheep_idx = jnp.argmin(distances_masked)

        is_catching_sheep = jnp.zeros_like(is_in_range)
        is_catching_sheep = is_catching_sheep.at[closest_sheep_idx].set(
            jnp.where(distances_masked[closest_sheep_idx] < jnp.inf, 1.0, 0.0)
        )
        return is_catching_sheep

    is_catching_matrix = jax.vmap(wolf_sheep_interaction, in_axes=(0, None))(wolves, sheep) # shape (num_wolves, num_sheep)
    is_being_fed_on = jnp.any(is_catching_matrix, axis=0)  # shape (num_sheep,) - t/f if sheep is being fed on by any wolf

    #split energy among wolves if multiple wolves target same sheep
    num_wolves_at_sheep = jnp.maximum(jnp.sum(is_catching_matrix, axis=0), 1.0)
    energy_sharing_matrix = jnp.divide(is_catching_matrix, num_wolves_at_sheep)

    energy_offer_per_sheep = sheep.state.content["energy"].reshape(-1) * EAT_RATE_SHEEP

    # calculate energy intake for each wolf
    energy_intake_wolves = jnp.multiply(energy_sharing_matrix, energy_offer_per_sheep)
    energy_intake_wolves = jnp.sum(energy_intake_wolves, axis=1).reshape(-1)

    # calculate energy loss for each sheep
    energy_loss_sheep = jnp.where(is_being_fed_on, energy_offer_per_sheep, 0.0)

    return energy_loss_sheep, energy_intake_wolves

In [296]:
# wolf-sheep interaction
keys = random.split(random.PRNGKey(50), 4)

# Wolf 1: at origin
wolf1_params = Params(content={
    "x_max": 100.0,
    "y_max": 100.0,
    "energy_begin_max": 100.0,
    "radius": WOLF_RADIUS,
    "mass_begin": 5.0
})

wolf1 = Wolf.create_agent(2, wolf1_params, 1, 1, keys[0])
wolf1_state = State(content={
    "x": jnp.array([0.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([50.0])
})
wolf1 = wolf1.replace(state=wolf1_state)


sheep1_params = Params(content={
    "x_max": 100.0,
    "y_max": 100.0,
    "energy_begin_max": 100.0,
    "eat_rate": EAT_RATE_SHEEP,
    "radius": SHEEP_RADIUS,
    "mass_begin": 5.0
})
# Sheep 1: close to wolf 1 (distance = 5, within wolf radius of 5)
sheep5 = Sheep.create_agent(1, sheep1_params, 5, 1, keys[1])
sheep5_state = State(content={
    "x": jnp.array([5.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([60.0])
})
sheep5 = sheep5.replace(state=sheep5_state)

# Sheep 2: in wolf range but farther than sheep 1 (distance = 6, outside radius)
sheep6 = Sheep.create_agent(1, sheep1_params, 6, 1, keys[2])
sheep6_state = State(content={
    "x": jnp.array([4.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([50.0])
})
sheep6 = sheep6.replace(state=sheep6_state)

sheep7 = Sheep.create_agent(1, sheep1_params, 7, 1, keys[3])
sheep7_state = State(content={
    "x": jnp.array([13.0]),
    "y": jnp.array([0.0]),
    "ang": jnp.array([0.0]),
    "x_dot": jnp.array([0.0]),
    "y_dot": jnp.array([0.0]),
    "ang_dot": jnp.array([0.0]),
    "energy": jnp.array([20.0])
})
sheep7 = sheep7.replace(state=sheep7_state)

In [297]:
wolf_set = jax.tree.map(lambda *xs: jnp.stack(xs), wolf1)
sheep_set = jax.tree.map(lambda *xs: jnp.stack(xs), sheep5, sheep6, sheep7)

EAT_RATE_SHEEP = 0.3
WOLF_RADIUS = 5.0

In [298]:
energy_loss_sheep, energy_intake_wolves = wolves_sheep_interactions(sheep_set, wolf_set)

In [299]:
print(energy_loss_sheep)
print(energy_intake_wolves)

[ 0.       15.000001  0.      ]
[15.000001]


In [300]:
def create_world(params, key):
    sheep_params = params.content["sheep_params"]
    wolf_params = params.content["wolf_params"]

    num_sheep = sheep_params["num_sheep"]
    key, sheep_key = random.split(key, 2)

    x_max_array = jnp.tile(jnp.array([sheep_params["x_max"]]), (num_sheep,))
    y_max_array = jnp.tile(jnp.array([sheep_params["y_max"]]), (num_sheep,))
    energy_begin_max_array = jnp.tile(jnp.array([sheep_params["energy_begin_max"]]), (num_sheep,))
    eat_rate_array = jnp.tile(jnp.array([sheep_params["eat_rate"]]), (num_sheep,))
    radius_array = jnp.tile(jnp.array([sheep_params["radius"]]), (num_sheep,))
    mass_array = jnp.tile(jnp.array([sheep_params["mass_begin"]]), (num_sheep,))

    sheep_create_params = Params(content= {
        "x_max": x_max_array,
        "y_max": y_max_array,
        "energy_begin_max": energy_begin_max_array,
        "eat_rate": eat_rate_array,
        "radius": radius_array,
        "mass_begin": mass_array
    })

    sheep = create_agents(agent=Sheep, params=sheep_create_params, num_agents=num_sheep, num_active_agents=num_sheep,
                          agent_type=sheep_params["agent_type"], key=sheep_key)

    sheep_set = Set(num_agents=num_sheep, num_active_agents=num_sheep, agents=sheep, id=0, set_type=sheep_params["agent_type"],
                    params=None, state=None, policy=None, key=None)


    num_wolf = wolf_params["num_wolf"]
    key, wolf_key = random.split(key, 2)
    x_max_array = jnp.tile(jnp.array([wolf_params["x_max"]]), (num_wolf,))
    y_max_array = jnp.tile(jnp.array([wolf_params["y_max"]]), (num_wolf,))
    energy_begin_max_array = jnp.tile(jnp.array([wolf_params["energy_begin_max"]]), (num_wolf,))
    radius_array = jnp.tile(jnp.array([wolf_params["radius"]]), (num_wolf,))
    mass_array = jnp.tile(jnp.array([wolf_params["mass_begin"]]), (num_wolf,))

    wolf_create_params = Params(content= {"x_max": x_max_array,
                                          "y_max": y_max_array,
                                          "energy_begin_max": energy_begin_max_array,
                                          "radius": radius_array,
                                          "mass_begin": mass_array
    })

    wolves = create_agents(agent=Wolf, params=wolf_create_params, num_agents=num_wolf, num_active_agents=num_wolf,
                           agent_type=wolf_params["agent_type"], key=wolf_key)

    wolf_set = Set(num_agents=num_wolf, num_active_agents=num_wolf, agents=wolves, id=2, set_type=wolf_params["agent_type"],
                   params=None, state=None, policy=None, key=None)


    return PredatorPreyWorld(sheep_set=sheep_set, wolf_set=wolf_set)


In [304]:
test_params = Params(content={
    "sheep_params": {
        "x_max": 500.0,
        "y_max": 500.0,
        "energy_begin_max": 50.0,
        "mass_begin": 5.0,
        "eat_rate": 0.4,
        "radius": 5.0,
        "agent_type": 1,
        "num_sheep": 10  # Number of sheep to create
    },
    "wolf_params": {
        "x_max": 500.0,
        "y_max": 500.0,
        "energy_begin_max": 50.0,
        "mass_begin": 7.0,
        "radius": 7.0,
        "agent_type": 3,
        "num_wolf": 5  # Number of wolves to create
    }
})

# Input 2: random key
test_key = random.PRNGKey(41)

In [305]:
world = PredatorPreyWorld.create_world(test_params, test_key)


In [306]:
print(f"Number of sheep: {world.sheep_set.num_agents}")
print(f"Number of wolves: {world.wolf_set.num_agents}")
print(f"Number of active sheep: {world.sheep_set.num_active_agents}")
print(f"Number of active wolves: {world.wolf_set.num_active_agents}")

# Access individual agents
print(f"\nFirst sheep position: ({world.sheep_set.agents.state.content['x'][0][0]:.2f}, {world.sheep_set.agents.state.content['y'][0][0]:.2f})")
print(f"First sheep energy: {world.sheep_set.agents.state.content['energy'][0][0]:.2f}")
print(f"First wolf position: ({world.wolf_set.agents.state.content['x'][0][0]:.2f}, {world.wolf_set.agents.state.content['y'][0][0]:.2f})")

Number of sheep: 10
Number of wolves: 5
Number of active sheep: 10
Number of active wolves: 5

First sheep position: (446.09, -164.32)
First sheep energy: 42.14
First wolf position: (311.57, 341.78)


In [307]:
def step_world(pp_world, _t):
    sheep_set = pp_world.sheep_set
    wolf_set = pp_world.wolf_set

    energy_intake_from_environment = jit_calculate_sheep_energy_intake(sheep_set.agents)
    energy_loss_sheep, energy_intake_wolves = jit_wolves_sheep_interactions(sheep_set.agents, wolf_set.agents)


    sheep_step_input = Signal(content={"energy_intake": energy_intake_from_environment - energy_loss_sheep})
    sheep_step_params = Params(content={"dt": Dt,
                                        "damping": DAMPING,
                                        "metabolic_cost_speed": METABOLIC_COST_SPEED,
                                        "metabolic_cost_angular": METABOLIC_COST_ANGULAR,
                                        "x_max_arena": WORLD_SIZE_X,
                                        "y_max_arena": WORLD_SIZE_Y,
    })
    sheep_set = jit_step_agents(Sheep.step_agent, sheep_step_params, sheep_step_input, sheep_set)


    wolf_step_input = Signal(content={"energy_intake": energy_intake_wolves})
    wolf_step_params = Params(content={"dt": Dt,
                                       "damping": DAMPING,
                                       "metabolic_cost_speed": METABOLIC_COST_SPEED,
                                       "metabolic_cost_angular": METABOLIC_COST_ANGULAR,
                                       "x_max_arena": WORLD_SIZE_X,
                                       "y_max_arena": WORLD_SIZE_Y,
    })
    wolf_set = jit_step_agents(Wolf.step_agent, wolf_step_params, wolf_step_input, wolf_set)


    render_data = Signal(content={"sheep_xs": sheep_set.agents.state.content["x"].reshape(-1, 1),
                                  "sheep_ys": sheep_set.agents.state.content["y"].reshape(-1, 1),
                                  "sheep_angles": sheep_set.agents.state.content["ang"].reshape(-1, 1),
                                  "sheep_energy": sheep_set.agents.state.content["energy"].reshape(-1, 1),
                                  "wolf_xs": wolf_set.agents.state.content["x"].reshape(-1, 1),
                                  "wolf_ys": wolf_set.agents.state.content["y"].reshape(-1, 1),
                                  "wolf_angles": wolf_set.agents.state.content["ang"].reshape(-1, 1),
                                  "wolf_energy": wolf_set.agents.state.content["energy"].reshape(-1, 1)
    })

    return pp_world.replace(sheep_set=sheep_set, wolf_set=wolf_set), render_data


In [308]:
Dt = 0.1
DAMPING = 0.1
METABOLIC_COST_SPEED = 0.01
METABOLIC_COST_ANGULAR = 0.05
WORLD_SIZE_X = 10000.0
WORLD_SIZE_Y = 10000.0

In [309]:
test_params = Params(content={
    "sheep_params": {
        "x_max": 500.0,
        "y_max": 500.0,
        "energy_begin_max": 50.0,
        "mass_begin": 5.0,
        "eat_rate": 0.4,
        "radius": 5.0,
        "agent_type": 1,
        "num_sheep": 3  # Small number for testing
    },
    "wolf_params": {
        "x_max": 500.0,
        "y_max": 500.0,
        "energy_begin_max": 50.0,
        "mass_begin": 7.0,
        "radius": 7.0,
        "agent_type": 3,
        "num_wolf": 2  # Small number for testing
    }
})

test_key = random.PRNGKey(123)
test_world = PredatorPreyWorld.create_world(test_params, test_key)

In [310]:
timestep = 0

print(f"Number of sheep: {test_world.sheep_set.num_agents}")
print(f"Number of wolves: {test_world.wolf_set.num_agents}")

Number of sheep: 3
Number of wolves: 2


In [312]:
print(f"\nSheep energies:")
for i in range(test_world.sheep_set.num_agents):
    energy = test_world.sheep_set.agents.state.content['energy'][i][0]
    x = test_world.sheep_set.agents.state.content['x'][i][0]
    y = test_world.sheep_set.agents.state.content['y'][i][0]
    print(f"  Sheep {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f}")

print(f"\nWolf energies:")
for i in range(test_world.wolf_set.num_agents):
    energy = test_world.wolf_set.agents.state.content['energy'][i][0]
    x = test_world.wolf_set.agents.state.content['x'][i][0]
    y = test_world.wolf_set.agents.state.content['y'][i][0]
    print(f"  Wolf {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f}")

# Call the function
updated_world, render_data = step_world(test_world, timestep)


Sheep energies:
  Sheep 0: position=(199.68, 475.66), energy=43.36
  Sheep 1: position=(171.15, 469.74), energy=46.19
  Sheep 2: position=(-492.80, -223.64), energy=37.85

Wolf energies:
  Wolf 0: position=(273.37, 82.20), energy=25.67
  Wolf 1: position=(479.30, 305.46), energy=43.72


In [313]:
print(f"Sheep energies (after step):")
for i in range(updated_world.sheep_set.num_agents):
    energy = updated_world.sheep_set.agents.state.content['energy'][i][0]
    x = updated_world.sheep_set.agents.state.content['x'][i][0]
    y = updated_world.sheep_set.agents.state.content['y'][i][0]
    old_energy = test_world.sheep_set.agents.state.content['energy'][i][0]
    energy_change = energy - old_energy
    print(f"  Sheep {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f} (change: {energy_change:+.4f})")

print(f"\nWolf energies (after step):")
for i in range(updated_world.wolf_set.num_agents):
    energy = updated_world.wolf_set.agents.state.content['energy'][i][0]
    x = updated_world.wolf_set.agents.state.content['x'][i][0]
    y = updated_world.wolf_set.agents.state.content['y'][i][0]
    old_energy = test_world.wolf_set.agents.state.content['energy'][i][0]
    energy_change = energy - old_energy
    print(f"  Wolf {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f} (change: {energy_change:+.4f})")


Sheep energies (after step):
  Sheep 0: position=(199.68, 475.66), energy=43.64 (change: +0.2800)
  Sheep 1: position=(171.15, 469.74), energy=46.47 (change: +0.2800)
  Sheep 2: position=(-492.80, -223.64), energy=38.13 (change: +0.2800)

Wolf energies (after step):
  Wolf 0: position=(273.38, 82.20), energy=25.63 (change: -0.0400)
  Wolf 1: position=(479.30, 305.46), energy=43.68 (change: -0.0400)


In [314]:
positions_changed = not jnp.allclose(
    updated_world.sheep_set.agents.state.content['x'],
    test_world.sheep_set.agents.state.content['x']
)
print(f"✓ Sheep positions changed: {positions_changed}")

✓ Sheep positions changed: False


In [316]:
# Instead of jnp.allclose(), check the actual differences:
position_diff_x = updated_world.sheep_set.agents.state.content['x'] - test_world.sheep_set.agents.state.content['x']
position_diff_y = updated_world.sheep_set.agents.state.content['y'] - test_world.sheep_set.agents.state.content['y']

print("\nPosition changes:")
for i in range(test_world.sheep_set.num_agents):
    dx = position_diff_x[i][0]
    dy = position_diff_y[i][0]
    distance_moved = jnp.sqrt(dx**2 + dy**2)
    print(f"  Sheep {i}: dx={dx:.6f}, dy={dy:.6f}, distance={distance_moved:.6f}")


Position changes:
  Sheep 0: dx=0.000198, dy=-0.000183, distance=0.000270
  Sheep 1: dx=-0.000412, dy=-0.000092, distance=0.000422
  Sheep 2: dx=0.000031, dy=0.000244, distance=0.000246


In [318]:
new_world, render_data = step_world(updated_world, timestep)

In [320]:
print(f"Sheep energies (after step):")
for i in range(new_world.sheep_set.num_agents):
    energy = new_world.sheep_set.agents.state.content['energy'][i][0]
    x = new_world.sheep_set.agents.state.content['x'][i][0]
    y = new_world.sheep_set.agents.state.content['y'][i][0]
    old_energy = updated_world.sheep_set.agents.state.content['energy'][i][0]
    energy_change = energy - old_energy
    print(f"  Sheep {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f} (change: {energy_change:+.4f})")

print(f"\nWolf energies (after step):")
for i in range(new_world.wolf_set.num_agents):
    energy = new_world.wolf_set.agents.state.content['energy'][i][0]
    x = new_world.wolf_set.agents.state.content['x'][i][0]
    y = new_world.wolf_set.agents.state.content['y'][i][0]
    old_energy = updated_world.wolf_set.agents.state.content['energy'][i][0]
    energy_change = energy - old_energy
    print(f"  Wolf {i}: position=({x:.2f}, {y:.2f}), energy={energy:.2f} (change: {energy_change:+.4f})")

Sheep energies (after step):
  Sheep 0: position=(199.84, 475.44), energy=43.92 (change: +0.2800)
  Sheep 1: position=(176.04, 470.54), energy=46.75 (change: +0.2800)
  Sheep 2: position=(-492.43, -224.71), energy=38.41 (change: +0.2800)

Wolf energies (after step):
  Wolf 0: position=(282.16, 72.52), energy=25.59 (change: -0.0400)
  Wolf 1: position=(478.37, 301.26), energy=43.64 (change: -0.0400)


In [321]:
# Test 3: Render data has correct shape
expected_sheep_shape = (test_world.sheep_set.num_agents, 1)
expected_wolf_shape = (test_world.wolf_set.num_agents, 1)
print(f"✓ Render data sheep shape correct: {render_data.content['sheep_xs'].shape == expected_sheep_shape}")
print(f"✓ Render data wolf shape correct: {render_data.content['wolf_xs'].shape == expected_wolf_shape}")

# Test 4: World structure preserved
print(f"✓ Same number of sheep: {updated_world.sheep_set.num_agents == test_world.sheep_set.num_agents}")
print(f"✓ Same number of wolves: {updated_world.wolf_set.num_agents == test_world.wolf_set.num_agents}")

✓ Render data sheep shape correct: True
✓ Render data wolf shape correct: True
✓ Same number of sheep: True
✓ Same number of wolves: True
