# Particle Life [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/maxencefaldor/cax/blob/main/examples/30_particle_life.ipynb)

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, on a machine with an NVIDIA GPU and CUDA 12, you can install JAX with:

In [None]:
%pip install -U "jax[cuda12]"

Then, install CAX from PyPI:

In [None]:
%pip install -U "cax[examples]"

## Imports

In [None]:
import jax
import jax.numpy as jnp
import mediapy
from cax.models.particle_life import ParticleLife, State
from flax import nnx

## Configuration

In [2]:
seed = 0

# Randomness: a raw JAX key (used for the attraction matrix and initial state)
# and an nnx.Rngs stream handed to the model, both derived from the same seed.
key = jax.random.key(seed)
rngs = nnx.Rngs(seed)

# World: dimensionality, particle count and boundary handling.
num_dims = 2
num_particles = 2048
boundary = "CIRCULAR"

# Integration: number of steps, step size and velocity damping.
num_steps = 1024
dt = 0.01
velocity_half_life = dt
# Velocity multiplier per step implied by the half-life,
# 0.5 ** (dt / velocity_half_life), which is 0.5 here. It is informational only:
# the model is configured with velocity_half_life, not with friction_factor.
friction_factor = jnp.power(0.5, dt / velocity_half_life)

# Inter-class forces: number of particle classes, interaction radius,
# beta shape parameter and overall force scale.
num_classes = 6
r_max = 0.15
beta = 0.3
force_factor = 1.0

## Initialize state

In [3]:
def init_state(key):
	"""Sample a random initial particle configuration.

	Every particle receives a class drawn uniformly from ``range(num_classes)``,
	a position drawn uniformly from ``[0, 1)`` in each of ``num_dims`` dimensions,
	and zero initial velocity. Sizes come from the configuration cell.

	Args:
		key: JAX PRNG key; it is split into one subkey for the classes and one
			for the positions.

	Returns:
		A ``State`` whose ``class_`` has shape ``(num_particles,)`` and whose
		``position`` and ``velocity`` have shape ``(num_particles, num_dims)``.
	"""
	class_key, position_key = jax.random.split(key)
	particle_shape = (num_particles, num_dims)

	return State(
		class_=jax.random.choice(class_key, num_classes, (num_particles,)),
		position=jax.random.uniform(position_key, particle_shape, minval=0.0, maxval=1.0),
		velocity=jnp.zeros(particle_shape),
	)

## Model

In [4]:
# Attraction matrix: one entry per (class, class) pair, each drawn uniformly
# from [-1, 1).
# NOTE(review): this split also advances `key`, which `init_state(key)` consumes
# later, so removing it would change the sampled initial particle layout.
key, subkey = jax.random.split(key)
A = jax.random.uniform(subkey, (num_classes, num_classes), minval=-1.0, maxval=1.0)

# Build the Particle Life model from the configuration values above.
# NOTE(review): `A` is never passed to `ParticleLife` (or used anywhere else in
# this notebook), so the model presumably builds its own attraction matrix,
# possibly from `rngs`. Confirm the constructor signature and pass `A=A` if it
# accepts one; otherwise the sample above is dead code.
ca = ParticleLife(
	num_classes=num_classes,
	rngs=rngs,
	r_max=r_max,
	beta=beta,
	boundary=boundary,
	dt=dt,
	velocity_half_life=velocity_half_life,
	force_factor=force_factor,
)

## Simulate and visualize

In [11]:
# Sample the initial particle configuration from the (already split) `key`.
state = init_state(key)

# Roll the model forward `num_steps` steps. `all_steps=True` presumably requests
# every intermediate state rather than only the final one; confirm against the
# ParticleLife API.
states = ca(state, num_steps=num_steps, all_steps=True)

In [None]:
# Prepend the initial state to every leaf so the video starts at t = 0. The
# rollout output presumably contains only the post-step states, hence this
# concatenation.
states = jax.tree.map(lambda x, xs: jnp.concatenate([x[None], xs]), state, states)

# Render all states in one vectorized call; each frame corresponds to one state.
frames = jax.vmap(lambda s: ca.render(s, particle_radius=0.003))(states)

# Play back at 1 / dt frames per second, which is real time if each step advances
# simulated time by `dt`. Use `round` rather than `int` so a value like
# 99.999... is not truncated to 99. Use `max(1, ...)` so the frame rate stays
# valid if dt is ever set above 1.
mediapy.show_video(frames, width=512, height=512, fps=max(1, round(1 / dt)))