
# Deformable Particles

This guide covers the :py:class:`~jaxdem.bonded_forces.deformable_particle.DeformableParticleModel`,
a bonded-force model that turns a collection of point particles (vertices) into
elastic, deformable bodies.

A deformable particle is defined by its mesh: a set of vertices connected by
elements (triangles in 3D, segments in 2D).
The model computes elastic forces from the mesh geometry. It penalises
deviations of element measure (area/length), body content (volume/area),
bending angle, and edge length from their reference values, and adds a
surface (or line) tension term.

Let's explore how to create, configure, and extend deformable particles.


## Creating a Deformable Particle
A deformable particle is created with
:py:meth:`~jaxdem.bonded_forces.BondedForceModel.create`, specifying the
registered name ``"deformableparticlemodel"`` and the mesh topology.

Vertices define particle positions, and elements define connectivity.
If reference (stress-free) quantities are not provided, they are computed
automatically from ``vertices``. If all required reference quantities are
provided explicitly, the ``vertices`` argument is optional.
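
To make "reference quantities" concrete, the sketch below computes two of them by hand for a unit square: the per-element measure (segment length) and the body content (enclosed area). It uses plain NumPy and is not the library's implementation. The names `ref_lengths`, `partial_area`, and `ref_area` are illustrative and are not attributes of the model.

```python
import numpy as np

# Square boundary: 4 vertices, 4 segments, counter-clockwise ordering.
vertices = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])

# Gather both endpoints of every segment: each has shape (M, 2).
p0 = vertices[elements[:, 0]]
p1 = vertices[elements[:, 1]]

# Element measure in 2D: segment length.
ref_lengths = np.linalg.norm(p1 - p0, axis=1)

# Body content in 2D: enclosed area via the shoelace formula,
# written as one signed contribution per segment.
partial_area = 0.5 * (p0[:, 0] * p1[:, 1] - p1[:, 0] * p0[:, 1])
ref_area = partial_area.sum()

print("reference lengths:", ref_lengths)
print("reference area:", ref_area)
```

With counter-clockwise ordering the signed area is positive; reversing the vertex order flips its sign.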

The connectivity arrays stored in the deformable particle contain
the particles' ``unique_id`` values from
:py:class:`~jaxdem.state.State`, not their positions in the state arrays.
Because of this, colliders that reorder (sort) particle arrays remain
compatible with deformable particles.
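
The following NumPy sketch illustrates why storing ``unique_id`` values makes connectivity robust to reordering. It assumes ids are the integers ``0..N-1``. The helper `edge_lengths` and the lookup table `row_of_id` are hypothetical and only show the idea; the library's internal lookup may differ.

```python
import numpy as np

pos = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
unique_id = np.arange(4)
edges = np.array([[0, 2], [1, 3]])  # stored as unique_id values, not row indices


def edge_lengths(pos, unique_id, edges):
    # Map each unique_id to its current row in the (possibly reordered) arrays.
    row_of_id = np.empty_like(unique_id)
    row_of_id[unique_id] = np.arange(unique_id.size)
    rows = row_of_id[edges]
    return np.linalg.norm(pos[rows[:, 0]] - pos[rows[:, 1]], axis=1)


before = edge_lengths(pos, unique_id, edges)

# Emulate a collider that sorts particles by applying a permutation.
perm = np.array([2, 0, 3, 1])
after = edge_lengths(pos[perm], unique_id[perm], edges)

print("same lengths after reordering:", np.allclose(before, after))
```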

For the simulation to run correctly, all vertices referenced by the
deformable model must exist in the simulation ``state``.



In [None]:
import jax
import jax.numpy as jnp
import jaxdem as jdem

# A simple square boundary in 2D: 4 vertices connected by 4 segments.
vertices_2d = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=float)
elements_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)
edges_2d = elements_2d  # In 2D the edges often coincide with elements.
adjacency_2d = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=int)

dp = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=edges_2d,
    element_adjacency=adjacency_2d,
    em=1.0,
    eb=0.5,
    el=0.3,
    gamma=0.1,
)
print("Created DP:", type(dp).__name__)

## Passing the Model to the System
There are two equivalent ways to attach a deformable particle model to a
:py:class:`~jaxdem.system.System`.

**Option 1** — pass the model object directly:



In [None]:
state = jdem.State.create(pos=vertices_2d)
system = jdem.System.create(state.shape, bonded_force_model=dp)

**Option 2** — pass the registered type name and keyword arguments.
:py:meth:`~jaxdem.system.System.create` will build the model internally:



In [None]:
system = jdem.System.create(
    state.shape,
    bonded_force_model_type="deformableparticlemodel",
    bonded_force_manager_kw=dict(
        vertices=vertices_2d,
        elements=elements_2d,
        edges=edges_2d,
        element_adjacency=adjacency_2d,
        em=1.0,
        eb=0.5,
        el=0.3,
        gamma=0.1,
    ),
)

If both are passed, the model object takes precedence and the keyword arguments are ignored.



## Coefficient Broadcasting
Every coefficient can be passed as a **scalar** or as a full array.
Scalar values are automatically broadcast to the correct shape determined by
the corresponding geometric entity:

.. list-table::
   :header-rows: 1

   * - Coefficient
     - Target shape
     - Description
   * - ``em``
     - ``(M,)`` — per element
     - Measure (area/length) stiffness
   * - ``gamma``
     - ``(M,)`` — per element
     - Surface/line tension
   * - ``eb``
     - ``(A,)`` — per adjacency pair
     - Bending stiffness
   * - ``el``
     - ``(E,)`` — per edge
     - Edge length stiffness
   * - ``ec``
     - ``(K,)`` — per body
     - Content (volume/area) stiffness

``ec`` is special: it is a **per-body** coefficient, not per-element.
The ``elements_id`` array maps each element to its parent body, so the
model knows which ``ec`` value to read for each element's content
contribution.
When only one body is present and ``elements_id`` is not provided,
``ec`` must have shape ``(1,)``.
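
The broadcasting rule in the table can be sketched with NumPy as follows. `broadcast_coeff` is a hypothetical helper, not part of JaxDEM; it only shows how a scalar or a full-length array ends up with the target shape.

```python
import numpy as np


def broadcast_coeff(value, n):
    """Return `value` as a float array of shape (n,).

    Scalars are repeated n times; arrays of length n pass through unchanged.
    Any other length raises ValueError from np.broadcast_to.
    """
    arr = np.asarray(value, dtype=float)
    return np.broadcast_to(arr, (n,)).copy()


M, E = 4, 4  # number of elements and edges in the square example
em = broadcast_coeff(2.0, M)                   # scalar, repeated per element
el = broadcast_coeff([0.1, 0.2, 0.3, 0.4], E)  # already per edge

print("em:", em)
print("el:", el)
```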



In [None]:
dp_scalar = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=edges_2d,
    em=2.0,  # broadcast to shape (4,)
    el=0.3,  # broadcast to shape (4,)
    gamma=0.1,  # broadcast to shape (4,)
)
print("em shape:", dp_scalar.em.shape)  # (4,)
print("el shape:", dp_scalar.el.shape)  # (4,)
print("gamma shape:", dp_scalar.gamma.shape)  # (4,)

## Lazy Array Creation
The constructor only allocates the arrays that are actually needed.
If a coefficient is ``None`` (i.e. not provided), the corresponding topology
and reference arrays are **not** stored, even if they were passed to the
constructor. This keeps the model lean when only a subset of energy terms is
active.



In [None]:
dp_edges_only = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=edges_2d,
    element_adjacency=adjacency_2d,
    el=0.3,  # Only edge springs are active.
)
print("elements stored?", dp_edges_only.elements is not None)  # False
print("edges stored?", dp_edges_only.edges is not None)  # True
print("adjacency stored?", dp_edges_only.element_adjacency is not None)  # False

Here, even though ``elements`` and ``element_adjacency`` were passed, they
are discarded because no coefficient that needs them (``em``, ``ec``,
``gamma``, ``eb``) was provided.



## 2D vs 3D Differences
JaxDEM supports both 2D and 3D deformable particles. The key differences
are:

.. list-table::
   :header-rows: 1

   * - Concept
     - 2D
     - 3D
   * - Elements
     - Segments ``(M, 2)``
     - Triangles ``(M, 3)``
   * - Measure
     - Segment length
     - Triangle area
   * - Content
     - Enclosed area
     - Enclosed volume
   * - Bending
     - Angle at shared vertex
     - Dihedral angle at shared edge
   * - ``element_adjacency_edges``
     - Not needed
     - ``(A, 2)``: vertex IDs of the shared edge (inferred from ``elements`` if omitted)

The dimension is inferred from the vertices: ``vertices.shape[-1]``
determines whether the model operates in 2D or 3D, and this must be
consistent with ``elements.shape[-1]``.

In 3D, each adjacency pair shares an edge (two vertices), and
``element_adjacency_edges`` stores those vertex IDs. If not provided,
the constructor will automatically infer them from the element connectivity.
In 2D, adjacencies share a single vertex and the edges array is not needed.
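
One way to infer the shared edge of each adjacent triangle pair is to intersect the vertex sets of the two triangles. The sketch below does this with NumPy for a tetrahedron. `shared_edges` is a hypothetical helper, and the ordering of the two vertex IDs it returns (sorted ascending) is an assumption that may not match the library's convention.

```python
import numpy as np

# Tetrahedron: 4 triangular faces and all 6 adjacent face pairs.
elements = np.array([[0, 1, 2], [0, 1, 3], [1, 2, 3], [0, 2, 3]])
adjacency = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]])


def shared_edges(elements, adjacency):
    out = []
    for a, b in adjacency:
        # Vertices present in both triangles, sorted ascending.
        common = np.intersect1d(elements[a], elements[b])
        if common.size != 2:
            raise ValueError(f"triangles {a} and {b} do not share exactly one edge")
        out.append(common)
    return np.array(out)


adjacency_edges = shared_edges(elements, adjacency)
print(adjacency_edges.shape)
print(adjacency_edges)
```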



In [None]:
# A minimal 3D example: a tetrahedron with 4 triangular faces.
vertices_3d = jnp.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.5, 0.5, 1.0]],
    dtype=float,
)
elements_3d = jnp.array(
    [[0, 1, 2], [0, 1, 3], [1, 2, 3], [0, 2, 3]],
    dtype=int,
)
adjacency_3d = jnp.array(
    [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]],
    dtype=int,
)

dp_3d = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_3d,
    elements=elements_3d,
    element_adjacency=adjacency_3d,
    em=1.0,
    eb=0.5,
)
print("3D elements shape:", dp_3d.elements.shape)  # (4, 3)
print("3D adjacency_edges shape:", dp_3d.element_adjacency_edges.shape)  # (6, 2)

## The ``elements_id`` Field
When a single deformable particle model contains **multiple bodies**, the
``elements_id`` array identifies which body each element belongs to.
This is required for the content energy term (``ec``), which is a per-body
quantity: the partial content contributions of each element are summed
per-body using ``elements_id``.

``elements_id`` has shape ``(M,)`` and contains integer body indices.
For example, if you have two bodies with 3 and 2 elements respectively:

```python
elements_id = jnp.array([0, 0, 0, 1, 1])
ec = jnp.array([0.5, 0.8])  # one value per body
```
When ``elements_id`` is not provided and ``ec`` is used, all elements are
assumed to belong to a single body (body 0), and ``ec`` must have shape
``(1,)``.
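
The per-body summation can be sketched as a segment sum over ``elements_id``. The NumPy example below uses two squares of different size and computes each body's enclosed area from per-segment shoelace contributions. The geometry and variable names are illustrative assumptions; only the role of ``elements_id`` and ``ec`` follows the description above.

```python
import numpy as np

# Two separate square bodies: a unit square and a 2x2 square shifted in x.
square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
vertices = np.concatenate([square, 2.0 * square + np.array([3.0, 0.0])])
elements = np.array(
    [[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4]]
)
elements_id = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # body index per element
ec = np.array([0.5, 0.8])                          # one stiffness per body

p0 = vertices[elements[:, 0]]
p1 = vertices[elements[:, 1]]

# Signed area contribution of each segment (shoelace formula).
partial_content = 0.5 * (p0[:, 0] * p1[:, 1] - p1[:, 0] * p0[:, 1])

# Sum the contributions per body using elements_id as the segment index.
content = np.bincount(elements_id, weights=partial_content, minlength=ec.size)

# Each element can look up the stiffness of its parent body.
ec_per_element = ec[elements_id]

print("content per body:", content)
print("ec per element:", ec_per_element)
```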



In [None]:
dp_two_bodies = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    elements_id=jnp.array([0, 0, 1, 1]),
    ec=jnp.array([0.5, 0.8]),
)
print("ec shape:", dp_two_bodies.ec.shape)  # (2,)
print("elements_id:", dp_two_bodies.elements_id)

## Vertex-Vertex Interactions
By default, vertices that belong to the **same deformable particle** (i.e.
share the same ``bond_id`` in the :py:class:`~jaxdem.state.State`) do
**not** interact through the regular contact forces. This avoids
self-collision within a single body.
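
Conceptually, this exclusion is a pair mask built from ``bond_id``. The NumPy sketch below builds such a mask for two bodies of four vertices each. `contact_mask` and its `allow_same_bond` argument are hypothetical and do not reflect how JaxDEM colliders are implemented.

```python
import numpy as np

bond_id = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # two bodies, four vertices each


def contact_mask(bond_id, allow_same_bond=False):
    """Boolean (N, N) mask: True where a pair may exchange contact forces."""
    n = bond_id.size
    mask = ~np.eye(n, dtype=bool)  # a vertex never interacts with itself
    if not allow_same_bond:
        # Drop pairs whose vertices belong to the same body.
        mask &= bond_id[:, None] != bond_id[None, :]
    return mask


mask_default = contact_mask(bond_id)
mask_self = contact_mask(bond_id, allow_same_bond=True)

print("interacting pairs (default):", int(mask_default.sum()))
print("interacting pairs (self-collision on):", int(mask_self.sum()))
```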

You can toggle this behaviour with the ``interact_same_bond_id``
flag on :py:class:`~jaxdem.system.System`:



In [None]:
state = jdem.State.create(pos=vertices_2d)

# Default: vertices within the same DP do NOT collide.
system_no_self = jdem.System.create(
    state.shape,
    bonded_force_model=dp,
    interact_same_bond_id=False,
)
print("Self-interaction:", bool(system_no_self.interact_same_bond_id))

# Enable self-collision within the same DP.
system_self = jdem.System.create(
    state.shape,
    bonded_force_model=dp,
    interact_same_bond_id=True,
)
print("Self-interaction:", bool(system_self.interact_same_bond_id))

## Adding and Merging Deformable Particles
Just like :py:class:`~jaxdem.state.State`, deformable particle models
support ``add`` and ``merge`` operations for building up complex
configurations from smaller pieces.

:py:meth:`~jaxdem.bonded_forces.deformable_particle.DeformableParticleModel.add`
creates a new body from raw arrays and merges it into an existing model.
It is equivalent to calling ``create`` followed by ``merge``.



In [None]:
dp_base = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=edges_2d,
    em=1.0,
    el=0.3,
)

# Add a second body.
new_verts = vertices_2d + jnp.array([2.0, 0.0])
dp_extended = dp_base.add(
    dp_base,
    vertices=new_verts,
    elements=elements_2d,
    edges=edges_2d,
    em=2.0,
    el=0.5,
)
print("Elements after add:", dp_extended.elements.shape)  # (8, 2)
print("Edges after add:", dp_extended.edges.shape)  # (8, 2)

:py:meth:`~jaxdem.bonded_forces.deformable_particle.DeformableParticleModel.merge`
concatenates two existing models. Vertex indices and body IDs are
automatically shifted so that references remain consistent: the second
model contributes new bodies, and its ``elements_id`` values are offset so
that they stay unique across the merged model.
When one side has a term and the other does not, missing coefficients
are padded with ``0`` and missing reference values are padded with ``1``.
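
The shifting and padding rules can be illustrated with a toy merge over plain dictionaries of NumPy arrays. This is a sketch under stated assumptions: `merge_sketch` is hypothetical, vertex IDs are assumed to shift by the number of vertices in the first model, and only the ``em`` coefficient is handled.

```python
import numpy as np


def merge_sketch(a, b, n_vertices_a):
    """Concatenate two toy models stored as dicts of arrays."""
    n_bodies_a = int(a["elements_id"].max()) + 1

    # Shift the second model's vertex IDs and body IDs past the first model's.
    elements = np.concatenate([a["elements"], b["elements"] + n_vertices_a])
    elements_id = np.concatenate([a["elements_id"], b["elements_id"] + n_bodies_a])

    # A coefficient missing on one side is padded with zeros,
    # so the corresponding energy term is inactive for those elements.
    em_a = a.get("em", np.zeros(len(a["elements"])))
    em_b = b.get("em", np.zeros(len(b["elements"])))
    em = np.concatenate([em_a, em_b])

    return {"elements": elements, "elements_id": elements_id, "em": em}


square_elements = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
model_a = {
    "elements": square_elements,
    "elements_id": np.zeros(4, dtype=int),
    "em": np.ones(4),
}
model_b = {
    "elements": square_elements,
    "elements_id": np.zeros(4, dtype=int),
}  # no "em" term on this side

merged = merge_sketch(model_a, model_b, n_vertices_a=4)
print(merged["elements"])
print(merged["elements_id"])
print(merged["em"])
```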



In [None]:
dp_a = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    em=1.0,
    gamma=0.1,
)

dp_b = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    elements=elements_2d,
    edges=edges_2d,
    el=0.5,
)

dp_merged = dp_a.merge(dp_a, dp_b)
print("Merged em:", dp_merged.em)  # em padded with 0 for dp_b's elements
print("Merged el:", dp_merged.el)  # el padded with 0 for dp_a's edges

## Batched Simulations with ``vmap``
Deformable particles work seamlessly with :py:func:`jax.vmap` for running
many independent simulations in parallel. Each simulation gets its own
``State`` and ``System`` (including its own bonded model).



In [None]:
from typing import Tuple


def create_sim(_i: jax.Array) -> Tuple[jdem.State, jdem.System]:
    state = jdem.State.create(pos=vertices_2d)
    dp_model = jdem.BondedForceModel.create(
        "deformableparticlemodel",
        vertices=state.pos,
        elements=elements_2d,
        edges=edges_2d,
        element_adjacency=adjacency_2d,
        em=[1.0],
        eb=[0.5],
        el=0.3,
        gamma=0.1,
    )
    system = jdem.System.create(state.shape, bonded_force_model=dp_model)
    return state, system


# Build a batch of 8 independent simulations.
states, systems = jax.vmap(create_sim)(jnp.arange(8))
print("Batched pos shape (B, N, dim):", states.pos.shape)

# Advance all simulations by 5 steps in parallel.
states, systems = systems.step(states, systems, n=5)
print("After stepping:", states.pos.shape)

## A Note on Edges
The ``edges`` array and the ``el`` coefficient define **spring connections**
between vertex pairs. These are fully independent from the ``elements``
connectivity: you can define edge springs that do not correspond to any
mesh element. This makes it possible to model springs that are external to
the mesh geometry, such as cross-bracing springs or tethers between
non-adjacent vertices.

Because edges are independent, they can be used alone (without elements) or
in combination with any other energy term.
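
As a geometric illustration of why cross-bracing springs are useful, the NumPy sketch below measures how the two diagonals of a square respond to a shear deformation. It only computes lengths and relative strain; the energy expression JaxDEM applies to these lengths is not reproduced here, and `lengths` is a hypothetical helper.

```python
import numpy as np

vertices = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
diagonals = np.array([[0, 2], [1, 3]])  # springs that are not mesh segments


def lengths(pos, edges):
    return np.linalg.norm(pos[edges[:, 1]] - pos[edges[:, 0]], axis=1)


rest = lengths(vertices, diagonals)  # reference (stress-free) lengths

# Shear the square by shifting the two top vertices in x.
sheared = vertices + np.array([[0.0, 0.0], [0.0, 0.0], [0.2, 0.0], [0.2, 0.0]])
strain = lengths(sheared, diagonals) / rest - 1.0

print("rest lengths:", rest)
print("relative strain under shear:", strain)
```

Under this shear one diagonal stretches and the other shortens, so the two springs resist the deformation from opposite directions.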



In [None]:
dp_extra_springs = jdem.BondedForceModel.create(
    "deformableparticlemodel",
    vertices=vertices_2d,
    edges=jnp.array([[0, 2], [1, 3]]),  # Diagonal springs (not mesh edges).
    el=0.5,
)
print("Diagonal springs — edges:", dp_extra_springs.edges)
print("Diagonal springs — elements stored?", dp_extra_springs.elements is not None)

## VTK Output
The :py:class:`~jaxdem.writers.VTKWriter` automatically detects when a
:py:class:`~jaxdem.bonded_forces.deformable_particle.DeformableParticleModel`
is attached to the system and writes additional VTK files:

* **deformable_elements** — the mesh elements (triangles/segments) with
  per-cell data: ``elements_id``, ``ec``, ``gamma``,
  ``initial_element_measures``, ``current_element_measures``,
  ``partial_content``, and ``element_normals``.
* **deformable_edges** — the edge springs with per-cell data:
  ``initial_edge_lengths``, ``current_edge_lengths``, and ``el``.
* **deformable_edge_adjacencies** — the adjacency pairs (hinge edges in 3D,
  hinge vertices in 2D) with per-cell data: ``initial_bendings``,
  ``current_bendings``, and ``eb``.

Only the writers whose corresponding energy terms are active are included.
For example, if ``el`` is ``None``, the edges writer is skipped entirely.
This output can be loaded in ParaView for visualisation and debugging.

