In [1]:
import jax
from jax import jit, lax

from jax_md import space, partition

from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.simulator.behaviors import behavior_bank


# World geometry: side length of the periodic box, and the cutoff radius
# handed to the neighbor list below.
box_size = 100.
neighbor_radius = 100.

# Displacement / shift functions for a periodic box of side `box_size`.
displacement, shift = space.periodic(side=box_size)

# Build the (init, step) function pair from the rigid-body dynamics factory.
dynamics_fn = dynamics_rigid
init_fn, step_fn = dynamics_fn(
    displacement,
    shift,
    map_dim=2,
    dt=0.1,
    behavior_bank=behavior_bank,
)

# Sparse-format neighbor list defined over the same box.
neighbor_fn = partition.neighbor_list(
    displacement,
    box_size,
    r_cutoff=neighbor_radius,
    dr_threshold=10.,
    capacity_multiplier=1.5,
    format=partition.Sparse,
)

@jit
def update_fn(_, state_and_neighbors):
    """Advance the simulation by one step (loop body for `lax.fori_loop`).

    Args:
        _: Loop index; ignored.
        state_and_neighbors: `(state, neighbors)` pair carried through the loop.

    Returns:
        A `(next_state, refreshed_neighbors)` pair with the same layout as
        the input carry.
    """
    current_state, neighbor_list = state_and_neighbors
    # Refresh the neighbor list at the current agent centers before stepping.
    refreshed = neighbor_list.update(current_state.position.center)
    next_state = step_fn(state=current_state, neighbor=refreshed)
    return next_state, refreshed



{'FEAR': 0, 'AGGRESSION': 1, 'LOVE': 2, 'SHY': 3, 'manual': 4, 'noop': 5}


In [2]:
import jax.numpy as jnp
from jax_md.rigid_body import RigidBody

# Initial state for 2 agents (one leading-axis entry per agent).
init_state_kwargs = dict(
    # Agent indices.
    idx=jnp.array([0, 1]),
    # Initial positions (center) and initial orientations.
    position=RigidBody(
        center=jnp.array([[75.83518, 94.59901], [71.06458, 86.82656]]),
        orientation=jnp.array([4.8367124, 4.5396805]),
    ),
    # Mass (center) and moment of inertia (orientation).
    mass=RigidBody(
        center=jnp.array([1., 1.]),
        orientation=jnp.array([0.125, 0.125]),
    ),
    # Initial proximeter values.
    prox=jnp.array([[0., 0.], [0., 0.]]),
    # Initial motor values.
    motor=jnp.array([[0., 0.], [0., 0.]]),
    # Initial behavior. For the available behaviors (and where to implement
    # your own): from vivarium.simulator.behaviors import behavior_name_map
    behavior=jnp.array([1, 1]),
    # Agent wheel diameter.
    wheel_diameter=jnp.array([2., 2.]),
    # Agent base diameter.
    base_length=jnp.array([5., 5.]),
    # Unused for now.
    speed_mul=jnp.array([0.1, 0.1]),
    # Unused for now.
    theta_mul=jnp.array([0.1, 0.1]),
    # Maximum sensing distance of the proximeters.
    proxs_dist_max=jnp.array([100., 100.]),
    # Minimum cosine of the proximeter sensing angle (0.0 means it can only
    # sense objects between + or - pi/2).
    proxs_cos_min=jnp.array([0., 0.]),
    # RGB color (only used for rendering for now).
    color=jnp.array([[0., 0., 1.], [0., 0., 1.]]),
    # Distinguishes different categories of entities; unused for now
    # (they are all agents).
    entity_type=jnp.array([0, 0]),
)

key = jax.random.PRNGKey(0)
state = init_fn(key, **init_state_kwargs)

# Allocate the neighbor list from the initial agent centers.
neighbors = neighbor_fn.allocate(state.position.center)

In [3]:
# Inspect the current simulation state (as produced by `init_fn` when the
# cells are run in order).
print(state)

NVEState(position=RigidBody(center=Array([[75.83518, 94.59901],
       [71.06458, 86.82656]], dtype=float32), orientation=Array([4.8367124, 4.5396805], dtype=float32)), momentum=RigidBody(center=Array([[0., 0.],
       [0., 0.]], dtype=float32), orientation=Array([ 0., -0.], dtype=float32)), force=RigidBody(center=Array([[0., 0.],
       [0., 0.]], dtype=float32), orientation=Array([0., 0.], dtype=float32)), mass=RigidBody(center=Array([[1.],
       [1.]], dtype=float32), orientation=Array([0.125, 0.125], dtype=float32)), idx=Array([0, 1], dtype=int32), prox=Array([[0., 0.],
       [0., 0.]], dtype=float32), motor=Array([[0., 0.],
       [0., 0.]], dtype=float32), behavior=Array([1, 1], dtype=int32), wheel_diameter=Array([2., 2.], dtype=float32), base_length=Array([5., 5.], dtype=float32), speed_mul=Array([0.1, 0.1], dtype=float32), theta_mul=Array([0.1, 0.1], dtype=float32), proxs_dist_max=Array([100., 100.], dtype=float32), proxs_cos_min=Array([0., 0.], dtype=float32), color=Array([[

In [4]:
# Number of simulation steps fused into a single `lax.fori_loop` call.
step_per_lax = 100
# Number of `lax.fori_loop` calls; total steps = n_lax_calls * step_per_lax.
n_lax_calls = 1000

for _ in range(n_lax_calls):
    # Run one chunk of steps. `state` is only replaced once the chunk has
    # completed without a neighbor-list overflow.
    new_state, neighbors = lax.fori_loop(0, step_per_lax, update_fn, (state, neighbors))

    # If the neighbor list can't fit in the allocation, rebuild it but bigger.
    if neighbors.did_buffer_overflow:
        print('REBUILDING')
        # The new buffer is allocated from the positions reached by the
        # overflowed chunk; the chunk itself is then replayed from the last
        # accepted `state`.
        neighbors = neighbor_fn.allocate(new_state.position.center)
        new_state, neighbors = lax.fori_loop(0, step_per_lax, update_fn, (state, neighbors))

        # A second overflow right after reallocation is not handled here.
        assert not neighbors.did_buffer_overflow, \
            'neighbor list overflowed again right after reallocation'

    state = new_state

In [5]:
# Inspect the simulation state after the stepping loop has run.
print(state)

NVEState(position=RigidBody(center=Array([[47.09005 , 84.10769 ],
       [49.025436, 81.99092 ]], dtype=float32), orientation=Array([5.4292183, 8.622202 ], dtype=float32)), momentum=RigidBody(center=Array([[-0.00117491, -0.00106411],
       [-0.00038874, -0.00035334]], dtype=float32), orientation=Array([ 0.00193335, -0.00053703], dtype=float32)), force=RigidBody(center=Array([[-0.00080221, -0.00072303],
       [-0.00096553, -0.00089671]], dtype=float32), orientation=Array([ 0.00180245, -0.00191506], dtype=float32)), mass=RigidBody(center=Array([[1.],
       [1.]], dtype=float32), orientation=Array([0.125, 0.125], dtype=float32)), idx=Array([0, 1], dtype=int32), prox=Array([[0.9713183, 0.       ],
       [0.       , 0.9713183]], dtype=float32), motor=Array([[0.       , 0.9713183],
       [0.9713183, 0.       ]], dtype=float32), behavior=Array([1, 1], dtype=int32), wheel_diameter=Array([2., 2.], dtype=float32), base_length=Array([5., 5.], dtype=float32), speed_mul=Array([0.1, 0.1], dtype

In [77]:
# Display the mapping from behavior names to their integer ids.
# NOTE(review): this import sits outside the notebook's first import cell and
# the cell's execution count is out of sequence -- confirm the notebook still
# runs top to bottom on a fresh kernel.
from vivarium.simulator.behaviors import behavior_name_map
behavior_name_map

{'FEAR': 0, 'AGGRESSION': 1, 'LOVE': 2, 'SHY': 3, 'manual': 4, 'noop': 5}