# Boids [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/maxencefaldor/cax/blob/main/examples/31_boids.ipynb)

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX with:

In [None]:
%pip install -U "jax[cuda12]"

Then, install CAX from PyPI:

In [None]:
%pip install -U "cax[examples]"

## Import

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from cax.models.boids import BoidPolicy, Boids, State
from flax import nnx

## Configuration

In [3]:
# Reproducibility: one seed drives both the JAX PRNG key (used below to sample
# initial conditions) and the Flax NNX RNG streams (used to initialize policies).
seed = 0

# World: number of spatial dimensions and flock size.
num_dims, num_boids = 2, 256

# Physics: number of integration steps and the time step.
num_steps, dt = 1024, 0.01

rngs = nnx.Rngs(seed)
key = jax.random.key(seed)

## Init state

In [4]:
def init_state(key):
	"""Sample a random initial flock state.

	Args:
		key: JAX PRNG key; it is split into one key for positions and one for
			velocities.

	Returns:
		State whose `position` and `velocity` each have shape
		(num_boids, num_dims), both drawn uniformly from [0, 1).

	"""
	# NOTE(review): jax.random.uniform defaults to [0, 1), so every velocity
	# component is non-negative and the flock starts with a net drift toward
	# (+1, ..., +1). Sample from [-1, 1) instead if an isotropic start is wanted.
	shape = (num_boids, num_dims)
	pos_key, vel_key = jax.random.split(key)
	return State(
		position=jax.random.uniform(pos_key, shape),
		velocity=jax.random.uniform(vel_key, shape),
	)

## Model

In [5]:
# Instantiate a boid policy with default parameters
boid_policy = BoidPolicy(rngs=rngs)  # CAX's built-in policy, default hyperparameters

# Wrap the policy in the Boids model. velocity_half_life=jnp.inf presumably
# disables velocity decay, and "CIRCULAR" presumably means wrap-around
# (periodic) boundaries — see cax.models.boids to confirm.
ca = Boids(boid_policy=boid_policy, dt=dt, velocity_half_life=jnp.inf, boundary="CIRCULAR")

## Visualize

In [6]:
# Draw a fresh subkey, sample initial conditions and roll the model forward.
# all_steps=True keeps every intermediate step so the trajectory can be rendered.
key, subkey = jax.random.split(key)
state = ca(init_state(subkey), num_steps=num_steps, all_steps=True)

In [None]:
# Render the simulated trajectory frame by frame into a list of RGB images.
boid_color = "white"  # single color for all boids

# Move the whole trajectory to host memory once, instead of converting one
# device array per frame inside the loop. One row per saved step, indexed as
# positions[step, boid, dim].
positions = np.asarray(state.position)

# Figure with an axes covering the whole canvas (no padding), black background,
# and limits matching the unit simulation domain.
fig = plt.figure(figsize=(6, 6), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_facecolor("black")
fig.patch.set_facecolor("black")
ax.axis("off")

# Create the scatter artist once; only its offsets change between frames.
scatter = ax.scatter(positions[0, :, 0], positions[0, :, 1], c=boid_color, s=5, marker=".")

images = []
for frame_positions in positions:
	scatter.set_offsets(frame_positions)
	fig.canvas.draw()
	# buffer_rgba() exposes the canvas's reusable pixel buffer, so copy it
	# before the next draw overwrites it.
	frame_rgba = np.array(fig.canvas.buffer_rgba())
	images.append(frame_rgba[..., :3])  # drop alpha channel

# Release the figure so it is neither displayed nor kept alive.
plt.close(fig)

In [10]:
# One frame per simulated step of length dt, so fps = 1 / dt plays in real time.
# round() rather than int(): floating-point error in 1 / dt (a value like
# 99.99999999999999) must not truncate the frame rate down by one.
media.show_video(images, fps=round(1 / dt), width=512, height=512)

0
This browser does not support the video tag.


## Boid simulation with a custom boid policy

In [110]:
class BoidPolicy(nnx.Module):
	"""Boid policy inspired by the neural network-based reference implementation.

	For one boid, the policy:
	1. finds its nearest neighbors, measuring distances on the unit torus;
	2. expresses their relative positions and velocities in the boid's own
	   heading-aligned frame;
	3. encodes each neighbor with a small MLP and averages the encodings over
	   neighbors inside the perception radius;
	4. decodes the average into a scalar lateral (steering) acceleration.
	"""

	def __init__(
		self,
		rngs: nnx.Rngs,
		num_neighbors: int = 16,  # Number of neighbors to consider
		perception: float = 0.1,  # Perception radius
		hidden_features: int = 8,  # Hidden layer size from reference
		acceleration_scale: float = 10.0,  # Scaling factor from reference
	):
		"""Initialize the boid policy.

		Args:
			rngs: Flax NNX RNG streams used to initialize the dense layers.
			num_neighbors: Maximum number of nearest neighbors fed to the network.
			perception: Radius beyond which neighbors are masked out.
			hidden_features: Width of the hidden layers.
			acceleration_scale: Multiplier applied to the network's tanh output.

		"""
		self.rngs = rngs
		self.num_neighbors = num_neighbors
		self.perception = perception
		self.acceleration_scale = acceleration_scale

		# Per-neighbor encoder: (local dx, local dy, local vx, local vy) -> hidden.
		self.dense1 = nnx.Linear(4, hidden_features, rngs=rngs)
		self.dense2 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)
		# Decoder applied to the neighbor-averaged encoding -> one scalar.
		self.dense3 = nnx.Linear(hidden_features, hidden_features, rngs=rngs)
		self.dense4 = nnx.Linear(hidden_features, 1, rngs=rngs)

	def _toroidal_vector(self, position_1: jax.Array, position_2: jax.Array) -> jax.Array:
		"""Return the shortest vector from position_1 to position_2 on the torus [0, 1]^n.

		Each component of the raw difference is shifted by -1 or +1 when its
		magnitude exceeds 0.5, i.e. when the path across the boundary is shorter.
		"""
		pos_diff = position_2 - position_1
		pos_diff = jnp.where(pos_diff > 0.5, pos_diff - 1.0, pos_diff)
		pos_diff = jnp.where(pos_diff < -0.5, pos_diff + 1.0, pos_diff)
		return pos_diff

	def _get_transformation_mats(self, position: jax.Array, velocity: jax.Array):
		"""Compute global-to-local and local-to-global transformation matrices.

		The local frame is centered at `position`, with its x axis along
		`velocity`. All matrices are 3x3 and act on homogeneous 2D coordinates,
		so `position` and `velocity` must have exactly two components.

		Returns:
			(global2local, local2global, global2local_rot, local2global_rot). The
			`_rot` variants contain only the rotation, with no translation, and
			are meant for direction vectors such as velocities.

		"""
		u, v = velocity / jnp.maximum(jnp.linalg.norm(velocity), 1e-8)  # Normalize velocity
		x, y = position

		# Global to local transformation (including translation)
		global2local = jnp.array([[u, v, -u * x - v * y], [-v, u, v * x - u * y], [0, 0, 1]])

		# Local to global transformation (including translation)
		local2global = jnp.array([[u, -v, x], [v, u, y], [0, 0, 1]])

		# Rotation-only matrices (for velocity)
		global2local_rot = jnp.array([[u, v, 0], [-v, u, 0], [0, 0, 1]])
		local2global_rot = jnp.array([[u, -v, 0], [v, u, 0], [0, 0, 1]])

		return global2local, local2global, global2local_rot, local2global_rot

	def __call__(self, state: State, boid_idx: int) -> jax.Array:
		"""Compute acceleration for a boid based on its neighbors.

		Args:
			state: State containing position and velocity of all boids.
			boid_idx: Index of the current boid.

		Returns:
			Acceleration vector for the boid in global coordinates, shape (2,).

		"""
		# Extract current boid's position and velocity
		xi = state.position[boid_idx]
		vi = state.velocity[boid_idx]

		# Shortest (wrapped) displacement from this boid to every boid. Both the
		# distance test AND the local-frame positions below use these vectors,
		# so a neighbor just across the boundary is seen at its true small
		# offset rather than at an offset of about 1.
		displacements = jax.vmap(lambda pos: self._toroidal_vector(xi, pos))(state.position)
		distances = jnp.sum(displacements**2, axis=-1)  # squared distances

		# Exclude the boid itself explicitly. The alternative, assuming it sorts
		# first, fails when another boid occupies the exact same position.
		distances = distances.at[boid_idx].set(jnp.inf)

		# Nearest neighbors. There are fewer of them if the flock is small; any
		# entry at infinite distance (self) is removed by the mask below.
		idx_neighbor = jnp.argsort(distances)[: self.num_neighbors]
		xn_rel = displacements[idx_neighbor]  # Neighbor offsets relative to xi
		vn = state.velocity[idx_neighbor]  # Neighbor velocities
		neighbor_distances = distances[idx_neighbor]

		# Create mask for neighbors within visual range
		mask = neighbor_distances < self.perception**2

		# Positions are already relative to xi, so only the rotations are needed.
		_, _, g2lr, l2gr = self._get_transformation_mats(xi, vi)
		rot_g2l = g2lr[:2, :2]
		rot_l2g = l2gr[:2, :2]

		# Row-vector form of rot_g2l @ w, applied to every neighbor at once.
		xn_local = xn_rel @ rot_g2l.T  # num_neighbors, 2
		vn_local = vn @ rot_g2l.T  # num_neighbors, 2

		# Prepare inputs for the neural network (scale positions as in reference)
		inputs = jnp.concatenate([50.0 * xn_local, vn_local], axis=-1)  # num_neighbors, 4

		# Neural network processing (similar to BoidNetwork)
		x = self.dense1(inputs)  # num_neighbors, hidden_features
		x = nnx.tanh(x)
		x = self.dense2(x)
		x = nnx.tanh(x)

		# Masked mean over neighbors (guarded against an empty mask)
		x = (x * mask[:, None]).sum(axis=0) / jnp.maximum(mask.sum(), 1e-8)  # hidden_features

		# Final layers
		x = self.dense3(x)
		x = nnx.tanh(x)
		x = self.dense4(x)
		x = nnx.tanh(x)  # Scalar output in (-1, 1)

		# Purely lateral steering in the local frame; zero if no neighbor is visible.
		dv_local = jax.lax.select(
			mask.sum() > 0,
			jnp.array([0.0, x[0]]),  # [x, y] in local frame
			jnp.zeros(2),
		)

		# Scale acceleration
		dv_local = dv_local * self.acceleration_scale

		# Rotate back to the global frame (a direction, so no translation).
		dv_global = rot_l2g @ dv_local

		return dv_global

In [111]:
# Instantiate a boid policy with default parameters
boid_policy = BoidPolicy(
	rngs=rngs,
	acceleration_scale=2.0,
)

ca = Boids(
	boid_policy=boid_policy,
	dt=dt,
	velocity_half_life=jnp.inf,
	boundary="CIRCULAR",
)

## Visualize

In [112]:
def init_state(key):
	"""Sample a random initial flock state with slow initial velocities.

	Args:
		key: JAX PRNG key; it is split into one key for positions and one for
			velocities.

	Returns:
		State whose `position` (uniform in [0, 1)) and `velocity` (uniform in
		[0, 0.1)) each have shape (num_boids, num_dims).

	"""
	shape = (num_boids, num_dims)
	pos_key, vel_key = jax.random.split(key)
	position = jax.random.uniform(pos_key, shape)
	# Scaled down by 10 so that, at the start, the policy's steering dominates
	# the initial motion.
	velocity = jax.random.uniform(vel_key, shape) * 0.1
	return State(position=position, velocity=velocity)

In [113]:
# Draw a fresh subkey, sample initial conditions and roll the model forward.
# all_steps=True keeps every intermediate step so the trajectory can be rendered.
key, subkey = jax.random.split(key)
state = ca(init_state(subkey), num_steps=num_steps, all_steps=True)

In [114]:
# Render the simulated trajectory frame by frame into a list of RGB images.
boid_color = "white"  # single color for all boids

# Move the whole trajectory to host memory once, instead of converting one
# device array per frame inside the loop. One row per saved step, indexed as
# positions[step, boid, dim].
positions = np.asarray(state.position)

# Figure with an axes covering the whole canvas (no padding), black background,
# and limits matching the unit simulation domain.
fig = plt.figure(figsize=(6, 6), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_facecolor("black")
fig.patch.set_facecolor("black")
ax.axis("off")

# Create the scatter artist once; only its offsets change between frames.
scatter = ax.scatter(positions[0, :, 0], positions[0, :, 1], c=boid_color, s=5, marker=".")

images = []
for frame_positions in positions:
	scatter.set_offsets(frame_positions)
	fig.canvas.draw()
	# buffer_rgba() exposes the canvas's reusable pixel buffer, so copy it
	# before the next draw overwrites it.
	frame_rgba = np.array(fig.canvas.buffer_rgba())
	images.append(frame_rgba[..., :3])  # drop alpha channel

# Release the figure so it is neither displayed nor kept alive.
plt.close(fig)

In [115]:
# One frame per simulated step of length dt, so fps = 1 / dt plays in real time.
# round() rather than int(): floating-point error in 1 / dt (a value like
# 99.99999999999999) must not truncate the frame rate down by one.
media.show_video(images, fps=round(1 / dt), width=512, height=512)

0
This browser does not support the video tag.
