# Particle Life

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX with:

In [None]:
%pip install -U "jax[cuda12]"

Then, install CAX from PyPI:

In [None]:
%pip install -U "cax[examples]"

## Imports

In [1]:
import json
import os

import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.update.residual_update import ResidualUpdate
from flax import nnx
from tqdm.auto import tqdm


# Enable 64-bit floats globally. This has to run before any arrays are
# created for them to default to float64.
jax.config.update("jax_enable_x64", True)
# Root PRNG key; later cells split it to draw every random quantity.
key = jax.random.key(0)

2025-02-24 02:10:17.361561: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# --- Simulation constants ---------------------------------------------------
num_dims = 2  # spatial dimensions
dt = 0.02  # integration time step
t_half = 0.04  # time for a velocity to halve under friction alone
friction_factor = jnp.power(0.5, dt / t_half)  # per-step velocity multiplier

num_particles = 1024
num_classes = 6
r_max = 0.1  # interaction cut-off radius (forces vanish beyond it)

# --- Random initial state (keys are split one at a time, in this order) -----
# A[a, b]: how strongly class a is drawn toward class b, uniform in [-1, 1).
key, subkey = jax.random.split(key)
A = jax.random.uniform(
    subkey,
    shape=(num_classes, num_classes),
    minval=-1.0,
    maxval=1.0,
)

# One class label per particle.
key, subkey = jax.random.split(key)
classes = jax.random.choice(subkey, num_classes, shape=(num_particles,))

# Positions uniformly distributed in the unit box.
key, subkey = jax.random.split(key)
positions = jax.random.uniform(
    subkey,
    shape=(num_particles, num_dims),
    minval=0.0,
    maxval=1.0,
)

# Every particle starts at rest.
velocities = jnp.zeros(shape=(num_particles, num_dims))

# Pairwise lookup table: attraction_factors[i, j] = A[classes[i], classes[j]].
attraction_factors = A[classes[:, None], classes[None, :]]

In [3]:
beta = 0.3
def get_forces(distances, attraction_factors):
    """Evaluate the Particle Life force profile for every pair.

    With ``r = distances / r_max`` the profile is piecewise linear:

    * ``r <= beta``: a universal repulsion ramp from -1 at ``r = 0`` up to 0 at
      ``r = beta``, independent of ``attraction_factors``.
    * ``beta < r <= 1``: a triangular bump that is 0 at both ends and reaches
      ``attraction_factors`` at ``r = (1 + beta) / 2``.
    * ``r > 1``: no interaction.

    Args:
        distances: Pairwise distances, in the same units as ``r_max``.
        attraction_factors: Per-pair attraction strengths, broadcastable
            against ``distances``.

    Returns:
        Signed force magnitudes with the broadcast shape of the inputs.
    """
    # Rescale into units of r_max as a new array; an in-place `/=` would
    # silently modify a NumPy array passed in by the caller.
    r = distances / r_max
    return jnp.select(
        condlist=[r <= beta, (r > beta) & (r <= 1)],
        choicelist=[
            r / beta - 1,
            attraction_factors * (1 - jnp.abs(2 * r - 1 - beta) / (1 - beta)),
        ],
        default=0.,
    )

In [4]:
# Plain Euclidean pairwise geometry of the initial state: directions[i, j] is
# p_i - p_j, with no periodic wrap. get_accelerations recomputes its own
# periodically wrapped geometry rather than using these arrays.
directions = jnp.subtract(positions[:, None, :], positions[None, :, :])
distances = jnp.linalg.norm(directions, axis=-1)
# Unit directions, forced to 0 wherever two points coincide (including i == j).
directions_norm = jnp.where(
    jnp.expand_dims(distances == 0, -1),
    0.,
    directions / jnp.expand_dims(distances, -1),
)

In [5]:
force_factor = 10.

@jax.jit
def get_accelerations(positions, factors=None):
    """Compute the Particle Life acceleration of every particle.

    Pairwise displacements use the minimum-image convention on the periodic
    unit box, so particles interact across the wrap-around edges.

    Args:
        positions: Particle coordinates, presumably shaped
            (num_particles, num_dims) with values in [0, 1). The wrap below
            assumes every coordinate difference lies in (-1, 1).
        factors: Optional pairwise attraction table, where ``factors[i, j]``
            is the strength with which particle ``i`` is drawn toward particle
            ``j``. When omitted, the module-level ``attraction_factors`` is
            used. Under ``jax.jit`` that global is read when the function is
            traced, so later reassignments may not be seen until the
            compilation cache is cleared.

    Returns:
        An array shaped like ``positions`` holding each particle's
        acceleration, scaled by ``force_factor * r_max``.
    """
    if factors is None:
        factors = attraction_factors

    # pos_diff[i, j] = p_j - p_i: the displacement from particle i to particle j.
    pos_diff = positions[None, :, :] - positions[:, None, :]
    # Minimum image: fold each component into [-0.5, 0.5] so the shortest
    # periodic displacement is used.
    pos_diff = jnp.where(pos_diff > 0.5, pos_diff - 1.0, pos_diff)
    pos_diff = jnp.where(pos_diff < -0.5, pos_diff + 1.0, pos_diff)

    distances = jnp.linalg.norm(pos_diff, axis=-1)

    # Unit directions i -> j. Coincident pairs (always including i == j) get a
    # zero direction. The divisor is made safe first so no 0/0 NaN is ever
    # produced. A NumPy-style `divide(..., where=...)` is avoided because JAX
    # does not support it, and NumPy would leave masked entries uninitialised.
    coincident = distances == 0.
    safe_distances = jnp.where(coincident, 1., distances)
    directions_norm = jnp.where(
        coincident[..., None], 0., pos_diff / safe_distances[..., None]
    )

    forces = get_forces(distances, factors)
    # Net acceleration on i: sum over j of (signed force) * (unit direction to j).
    return force_factor * r_max * jnp.sum(forces[..., None] * directions_norm, axis=1)


In [6]:
# Class
key, subkey = jax.random.split(key)
classes = jax.random.choice(subkey, num_classes, (num_particles,))

# Position
key, subkey = jax.random.split(key)
positions = jax.random.uniform(subkey, (num_particles, num_dims), minval=0., maxval=1.)

# Velocity
velocities = jnp.zeros((num_particles, num_dims))

position_list = [positions]
for i in range(1024):
    accelerations = get_accelerations(positions)
    velocities = friction_factor * velocities + accelerations * dt
    positions += velocities * dt

    # Apply periodic boundary conditions
    positions = positions % 1.0

    position_list.append(positions)

In [None]:
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Allow the embedded JS/HTML animation to grow to 50 MB. Matplotlib measures
# this rcParam in megabytes, so the value is 50, not 50 * 1024 * 1024.
plt.rcParams['animation.embed_limit'] = 50

# One colour per class, evenly spaced along the rainbow colormap; each
# particle takes the colour of its class.
colors = plt.cm.rainbow(np.linspace(0, 1, num_classes))
particle_colors = colors[classes]

# Square black canvas covering the unit box.
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_facecolor('black')
fig.patch.set_facecolor('black')

# Scatter of the initial frame; `s` is the marker area in points^2.
scatter = ax.scatter(position_list[0][:, 0], position_list[0][:, 1], c=particle_colors, s=10)

# Animation update function
def update(frame):
    """Move the markers to the positions recorded at step ``frame``."""
    scatter.set_offsets(position_list[frame])
    return scatter,

# One animation frame per recorded simulation step, 50 ms apart.
anim = animation.FuncAnimation(
    fig, update, frames=len(position_list),
    interval=50, blit=True
)

plt.tight_layout()
# Close the figure so Jupyter does not also render it as a static image
# below the animation; the animation can still render it.
plt.close(fig)

# Display animation
HTML(anim.to_jshtml())
