## Tips to keep in mind

- Prefer Struct-of-Arrays (SoA) over Array-of-Structs (AoS). Keep each attribute in its own contiguous NumPy array (pos, vel, mass, …) rather than in a list of Particle objects. This improves cache locality, enables vectorization, and is the layout Numba/JAX/CuPy perform best on.

- Avoid Python objects in inner loops. Use plain float32/float64 and int32/int64 NumPy arrays: no lists of dicts, no per-particle dataclasses.

- Use dense, fixed-shape arrays. Ragged arrays and variable-length lists hurt vectorization and JIT compilation.

- Separate dynamic from static data. Put evolving state (pos/vel) in one block and constants (mass/radius/type) in another, so that only the hot, changing data has to be streamed through memory every step.

- IDs vs. indices. Use a stable integer ID for external references, but store state in dense arrays addressed by a dense index. Keep an ID↔index mapping (a “sparse set”) so that add and remove are O(1): append at the end, and on removal swap the last active element into the hole.

In [1]:
from dataclasses import dataclass
import numpy as np
import scipy as sp
import numba as nb

In [None]:
@dataclass(eq=False)
class ParticleSet:
    """Fixed-capacity particle store in Struct-of-Arrays layout.

    Every per-particle attribute lives in its own preallocated NumPy array.
    Active particles always occupy the dense prefix ``[0:size)``; removal
    fills the hole with the last active row so the prefix stays dense.
    Each particle also receives a stable integer ID, and ``id_to_index``
    maps that ID to the row the particle currently occupies.

    ``eq=False``: the auto-generated dataclass ``__eq__`` would compare the
    NumPy arrays element-wise and raise "truth value of an array is
    ambiguous", so equality is plain object identity instead.
    """

    # --- Fixed-size storage ---
    capacity: int
    dimensions: int = 3
    dtype: np.dtype = np.float32    # any NumPy dtype-like, used for all float arrays

    # --- arrays (allocated in __post_init__; constructor values are overwritten) ---
    positions:  np.ndarray = None   # (capacity, dimensions)
    velocities: np.ndarray = None   # (capacity, dimensions)
    masses:     np.ndarray = None   # (capacity,)
    radii:      np.ndarray = None   # (capacity,)

    # --- bookkeeping ---
    alive_mask:   np.ndarray = None # (capacity,), True for active rows in [0:size)
    particle_ids: np.ndarray = None # (capacity,), int64 stable handles
    id_to_index:  dict = None       # maps particle_id -> dense index

    # --- active count (all active particles are stored in rows [0:size)) ---
    size: int = 0
    next_id: int = 0                # monotonically increasing ID source

    def __post_init__(self):
        """Allocate all per-particle arrays at full capacity."""
        N, D = self.capacity, self.dimensions
        self.positions  = np.zeros((N, D), dtype=self.dtype)
        self.velocities = np.zeros((N, D), dtype=self.dtype)
        self.masses     = np.ones(N, dtype=self.dtype)   # unused rows default to mass 1
        self.radii      = np.zeros(N, dtype=self.dtype)

        self.alive_mask   = np.zeros(N, dtype=bool)
        # Rows beyond `size` hold undefined IDs; only [0:size) is meaningful.
        self.particle_ids = np.empty(N, dtype=np.int64)
        self.id_to_index  = {}

    # --- add / spawn (batch) ---
    def add(self, positions, velocities, masses, radii):
        """Append k particles in one call.

        Args:
            positions:  array-like of shape (k, dimensions).
            velocities: array-like of shape (k, dimensions).
            masses:     array-like of shape (k,).
            radii:      array-like of shape (k,).

        Values are cast to ``self.dtype`` when written into storage.

        Returns:
            np.ndarray of the newly assigned particle IDs, int64, shape (k,).

        Raises:
            ValueError: if any input has the wrong shape, or if there is not
                enough free capacity for k more particles.
        """
        # Accept lists/tuples as well as ndarrays (no copy for ndarrays).
        positions  = np.asarray(positions)
        velocities = np.asarray(velocities)
        masses     = np.asarray(masses)
        radii      = np.asarray(radii)

        D = self.dimensions
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if positions.ndim != 2 or positions.shape[1] != D:
            raise ValueError(
                f"positions must have shape (k, {D}), got {positions.shape}")
        k = positions.shape[0]
        for name, arr, expected in (("velocities", velocities, (k, D)),
                                    ("masses", masses, (k,)),
                                    ("radii", radii, (k,))):
            if arr.shape != expected:
                raise ValueError(
                    f"{name} must have shape {expected}, got {arr.shape}")

        if self.size + k > self.capacity:
            raise ValueError("Not enough capacity to add particles")

        start = self.size
        end   = start + k

        self.positions[start:end]  = positions
        self.velocities[start:end] = velocities
        self.masses[start:end]     = masses
        self.radii[start:end]      = radii
        self.alive_mask[start:end] = True

        # --- assign stable, unique IDs ---
        new_ids = np.arange(self.next_id, self.next_id + k, dtype=np.int64)
        self.particle_ids[start:end] = new_ids
        for row, pid in enumerate(new_ids, start=start):
            self.id_to_index[int(pid)] = row
        self.next_id += k

        self.size = end
        return new_ids

    # --- remove (swap-remove keeps front dense) ---
    def remove_by_index(self, index: int):
        """Remove the particle at dense row `index` in O(1).

        The last active row is moved into the vacated slot so that active
        particles remain contiguous in [0:size). The moved particle's entry
        in ``id_to_index`` is updated accordingly.

        Raises:
            IndexError: if `index` is not in [0, size).
        """
        last_active = self.size - 1
        if index < 0 or index > last_active:
            raise IndexError("Index out of range")

        removed_id = int(self.particle_ids[index])

        if index != last_active:
            # Move (not swap) the last active row into the hole: the tail row
            # is marked inactive below, so its old contents are irrelevant.
            # alive_mask needs no move, since both rows are already True.
            for arr in (self.positions, self.velocities, self.masses,
                        self.radii, self.particle_ids):
                arr[index] = arr[last_active]
            # --- fix mapping for the moved particle ---
            self.id_to_index[int(self.particle_ids[index])] = index

        # --- mark tail as inactive and drop mapping for removed particle ---
        self.alive_mask[last_active] = False
        self.id_to_index.pop(removed_id, None)
        self.size -= 1

    def remove_by_id(self, particle_id: int):
        """Remove a particle using its stable ID.

        Raises:
            KeyError: if `particle_id` is not currently in the set.
        """
        index = self.id_to_index.get(int(particle_id))
        if index is None:
            raise KeyError("Unknown particle_id")
        self.remove_by_index(index)

    # --- convenience views ---
    def active_slice(self):
        """Slice covering all active rows (0..size-1)."""
        return slice(0, self.size)

    def active_arrays(self):
        """Return (positions, velocities, masses, radii) views of the active prefix.

        These are NumPy views, not copies: writing into them updates the set.
        """
        s = self.active_slice()
        return (self.positions[s], self.velocities[s], self.masses[s], self.radii[s])

In [None]:
# --- example usage ---
dims = 2
test_set = ParticleSet(capacity=100_000, dimensions=dims)
rng = np.random.default_rng(seed=0)  # seeded so the example is reproducible

# --- add 10 particles ---
num_particles = 10
ids = test_set.add(positions=rng.standard_normal((num_particles, dims), dtype=np.float32),
                   velocities=np.zeros((num_particles, dims), np.float32),
                   masses=np.ones(num_particles, np.float32),
                   radii=np.full(num_particles, 0.05, np.float32))

# --- remove by id ---
test_set.remove_by_id(int(ids[3]))

# --- pass active views to a kernel ---
positions, velocities, masses, radii = test_set.active_arrays()
assert positions.shape == (num_particles - 1, dims)