# Cloth Simulation

Simulating a cloth being lifted from a table by two opposite corners.

Features used:
- {class}`~jaxls.Var` subclassing for custom 3D point variables
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for potential energy terms
- Equality constraints (`constraint_eq_zero`) for pinning lifted corners
- Inequality constraints (`constraint_geq_zero`) for the table surface
- Multiple spring types: structural, shear, and bend springs
- Batched cost construction for efficient problem setup

In [1]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [2]:
import jax
import jax.numpy as jnp
import jaxls

## Variables and costs

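We define a 3D point variable and cost functions representing potential energy terms. Since jaxls minimizes the sum of squared costs, we design each energy cost so that its square equals twice the corresponding energy contribution (for gravity, approximately); a uniform factor of 2 does not change the minimizer. The table and anchor terms are constraints rather than energies: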

In [3]:
class Point3Var(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """A 3D point variable."""

In [4]:
@jaxls.Cost.factory
def spring_cost(
    vals: jaxls.VarValues,
    var_a: Point3Var,
    var_b: Point3Var,
    rest_length: jax.Array,
    stiffness: float,
) -> jax.Array:
    """Elastic potential energy for a Hookean spring.

    Returns sqrt(k) * (L - L₀), so the squared cost is k(L - L₀)²,
    twice the spring energy ½k(L - L₀)².
    """
    diff = vals[var_a] - vals[var_b]
    # The 1e-6 term keeps the gradient finite when the two points coincide
    length = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    return (length - rest_length) * jnp.sqrt(stiffness)


@jaxls.Cost.factory
def gravity_cost(
    vals: jaxls.VarValues,
    var: Point3Var,
    mass: float,
    g: float = 9.81,
) -> jax.Array:
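    """Gravitational potential energy mgh.

    Returns sqrt(2mg * softplus(h)), so the squared cost approaches 2mgh
    once h is well above 0. The resulting downward force is
    mg * sigmoid(h): close to the constant mg for lifted points, and
    halved at the table surface (h = 0).
    """
    z = vals[var][2]
    # softplus keeps the sqrt argument positive even if z dips below 0;
    # the 1e-8 term is an extra guard for the sqrt
    return jnp.sqrt(2.0 * mass * g * jax.nn.softplus(z) + 1e-8)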


@jaxls.Cost.factory(kind="constraint_geq_zero")
def table_constraint(
    vals: jaxls.VarValues,
    var: Point3Var,
) -> jax.Array:
    """Inequality constraint: point must stay above the table (z >= 0)."""
    return vals[var][2]


@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_constraint(
    vals: jaxls.VarValues,
    var: Point3Var,
    target: jax.Array,
) -> jax.Array:
    """Pin a point to a target position (hard constraint)."""
    return vals[var] - target
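
To check the scaling argument numerically, here is a plain-Python sketch that re-implements the two energy residuals outside jaxls (the helper names `spring_residual` and `gravity_residual` are ours, not part of jaxls). It shows that the squared spring residual is about twice the Hookean energy, and that the slope of the squared gravity residual is $2mg\,\sigma(z)$: it approaches the constant $2mg$ only once a point is well above the table, and is halved at $z = 0$ because of the softplus.

```python
import math


def spring_residual(pa, pb, rest_length, stiffness):
    # Mirrors spring_cost: smoothed length minus rest length, scaled by sqrt(k)
    length = math.sqrt(sum((a - b) ** 2 for a, b in zip(pa, pb)) + 1e-6)
    return (length - rest_length) * math.sqrt(stiffness)


def gravity_residual(z, mass, g=9.81):
    # Mirrors gravity_cost: sqrt(2 m g softplus(z) + eps)
    softplus = math.log1p(math.exp(z))
    return math.sqrt(2.0 * mass * g * softplus + 1e-8)


# Spring: squared residual vs. the Hookean energy 0.5 * k * (L - L0)^2
k, rest = 50.0, 0.5
r = spring_residual((0.0, 0.0, 0.0), (0.8, 0.0, 0.0), rest, k)
hooke_energy = 0.5 * k * (0.8 - rest) ** 2
print(r**2 / hooke_energy)  # close to 2; the 1e-6 smoothing term shifts it slightly

# Gravity: slope of the squared residual with respect to z, by central differences
m, g = 0.01, 9.81


def squared_slope(z, h=1e-5):
    upper = gravity_residual(z + h, m, g) ** 2
    lower = gravity_residual(z - h, m, g) ** 2
    return (upper - lower) / (2 * h)


for z in (0.0, 1.5, 10.0):
    # The ratio to 2mg equals sigmoid(z): 0.5 at the table, approaching 1 higher up
    print(z, squared_slope(z) / (2 * m * g))
```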

## Grid setup

Create a 12x12 grid of particles starting flat on the table (z=0). Two opposite corners will be lifted and anchored at a fixed height. We connect particles with three types of springs:

- Structural springs: Connect adjacent neighbors (horizontal and vertical)
- Shear springs: Connect diagonal neighbors (resist shearing)
- Bend springs: Connect skip-one neighbors (resist bending)

In [5]:
# Grid dimensions
cols, rows = 12, 12
num_points = cols * rows
spacing = 0.5
lift_height = 1.5  # Height to lift the corners


def idx(row: int, col: int) -> int:
    """Convert (row, col) to flat index.

    Args:
        row: Row index
        col: Column index

    Returns:
        Flat index into the points array
    """
    return row * cols + col


# Initial positions (regular grid flat on the table at z=0)
initial_positions = jnp.array(
    [[c * spacing, r * spacing, 0.0] for r in range(rows) for c in range(cols)]
)

print(f"Grid: {cols}x{rows} = {num_points} points")
print(f"Spacing: {spacing} units")
print(f"Lift height: {lift_height} units")

Grid: 12x12 = 144 points
Spacing: 0.5 units
Lift height: 1.5 units
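
`idx` uses row-major order, so the inverse mapping is a single `divmod`. The sketch below adds hypothetical `row_col` and `grid_position` helpers (not used elsewhere in this notebook) to show how a flat index maps back to its initial coordinates, with columns along x and rows along y:

```python
cols, rows = 12, 12
spacing = 0.5


def idx(row: int, col: int) -> int:
    # Row-major flattening, as in the cell above
    return row * cols + col


def row_col(flat: int) -> tuple[int, int]:
    # Hypothetical inverse of idx: divmod gives (flat // cols, flat % cols)
    return divmod(flat, cols)


def grid_position(flat: int) -> tuple[float, float, float]:
    # Columns map to x and rows map to y, matching initial_positions
    row, col = row_col(flat)
    return (col * spacing, row * spacing, 0.0)


print(row_col(idx(rows - 1, cols - 1)))  # (11, 11)
print(grid_position(idx(0, 3)))  # (1.5, 0.0, 0.0)
```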


In [6]:
# Create all point variables at once with batched IDs
all_point_vars = Point3Var(id=jnp.arange(num_points))

# Anchor indices: two opposite corners (diagonal lift)
anchor_indices = jnp.array([idx(0, 0), idx(rows - 1, cols - 1)])

# Lifted anchor positions (raise corners to lift_height)
anchor_positions = initial_positions[anchor_indices].at[:, 2].set(lift_height)

# Free point indices: all except anchored corners
all_indices = set(range(num_points))
anchored_set = set(anchor_indices.tolist())
free_indices = jnp.array(sorted(all_indices - anchored_set))

print(
    f"Anchored points: {len(anchor_indices)} (opposite corners, lifted to z={lift_height})"
)
print(f"Free points: {len(free_indices)}")

Anchored points: 2 (opposite corners, lifted to z=1.5)
Free points: 142
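
The free-point indices can also be computed with array operations instead of Python sets. A NumPy sketch of that alternative, using `np.setdiff1d` (which returns the sorted unique values of its first argument that are absent from the second) and, for comparison, a boolean mask; `jax.numpy.setdiff1d` exists as well, though under `jit` it needs a static `size` argument:

```python
import numpy as np

rows, cols = 12, 12
num_points = rows * cols
# idx(0, 0) and idx(rows - 1, cols - 1) in row-major order
anchor_indices = np.array([0, num_points - 1])

# Sorted values of arange(num_points) that are not anchors
free_indices = np.setdiff1d(np.arange(num_points), anchor_indices)

# Equivalent boolean-mask formulation
mask = np.ones(num_points, dtype=bool)
mask[anchor_indices] = False
free_from_mask = np.flatnonzero(mask)

print(len(free_indices), np.array_equal(free_indices, free_from_mask))  # 142 True
```

Either result could be wrapped in `jnp.asarray` before being passed as `Point3Var(id=...)`, as the notebook does with its own index arrays.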


In [7]:
# Build spring connectivity arrays

# Structural springs (adjacent neighbors)
struct_a, struct_b = [], []
# Horizontal springs
for r in range(rows):
    for c in range(cols - 1):
        struct_a.append(idx(r, c))
        struct_b.append(idx(r, c + 1))
# Vertical springs
for r in range(rows - 1):
    for c in range(cols):
        struct_a.append(idx(r, c))
        struct_b.append(idx(r + 1, c))

struct_a = jnp.array(struct_a)
struct_b = jnp.array(struct_b)
struct_rest_length = spacing

# Shear springs (diagonal neighbors)
shear_a, shear_b = [], []
for r in range(rows - 1):
    for c in range(cols - 1):
        # Diagonal down-right
        shear_a.append(idx(r, c))
        shear_b.append(idx(r + 1, c + 1))
        # Diagonal down-left
        shear_a.append(idx(r, c + 1))
        shear_b.append(idx(r + 1, c))

shear_a = jnp.array(shear_a)
shear_b = jnp.array(shear_b)
shear_rest_length = spacing * jnp.sqrt(2)

# Bend springs (skip-one neighbors for stiffness)
bend_a, bend_b = [], []
# Horizontal bend
for r in range(rows):
    for c in range(cols - 2):
        bend_a.append(idx(r, c))
        bend_b.append(idx(r, c + 2))
# Vertical bend
for r in range(rows - 2):
    for c in range(cols):
        bend_a.append(idx(r, c))
        bend_b.append(idx(r + 2, c))

bend_a = jnp.array(bend_a)
bend_b = jnp.array(bend_b)
bend_rest_length = spacing * 2

print(f"Structural springs: {len(struct_a)}")
print(f"Shear springs: {len(shear_a)}")
print(f"Bend springs: {len(bend_a)}")
print(f"Total springs: {len(struct_a) + len(shear_a) + len(bend_a)}")

Structural springs: 264
Shear springs: 242
Bend springs: 240
Total springs: 746
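
The loops above can also be written with array slicing: each spring family pairs a 2D array of flat indices with a shifted view of itself. A NumPy sketch; `grid_pairs` and `concat` are hypothetical helpers, and the shear pairs come out grouped by direction rather than interleaved per cell, which only changes the order of terms in the sum:

```python
import numpy as np


def grid_pairs(rows: int, cols: int, dr: int, dc: int):
    """Flat index pairs (a, b) where b is a's neighbor at offset (dr, dc)."""
    ids = np.arange(rows * cols).reshape(rows, cols)  # ids[r, c] == r * cols + c
    # Keep only the rows/columns whose shifted neighbor is still inside the grid
    r0, r1 = max(0, -dr), rows - max(0, dr)
    c0, c1 = max(0, -dc), cols - max(0, dc)
    a = ids[r0:r1, c0:c1]
    b = ids[r0 + dr : r1 + dr, c0 + dc : c1 + dc]
    return a.ravel(), b.ravel()


def concat(*pairs):
    # Join several (a, b) families into one pair of flat arrays
    return (
        np.concatenate([a for a, _ in pairs]),
        np.concatenate([b for _, b in pairs]),
    )


rows, cols = 12, 12
struct_a, struct_b = concat(grid_pairs(rows, cols, 0, 1), grid_pairs(rows, cols, 1, 0))
shear_a, shear_b = concat(grid_pairs(rows, cols, 1, 1), grid_pairs(rows, cols, 1, -1))
bend_a, bend_b = concat(grid_pairs(rows, cols, 0, 2), grid_pairs(rows, cols, 2, 0))

print(len(struct_a), len(shear_a), len(bend_a))  # 264 242 240
```

The counts follow directly from the offsets: $r(c-1) + (r-1)c$ structural, $2(r-1)(c-1)$ shear, and $r(c-2) + (r-2)c$ bend springs for an $r \times c$ grid.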


## Problem construction

Build the optimization problem with spring and gravity costs, plus constraints for the anchored corners and table surface:

In [8]:
# Spring stiffness parameters (N/m)
structural_stiffness = 50.0
shear_stiffness = 20.0
bend_stiffness = 10.0

# Physics parameters
mass_per_point = 0.01  # kg
g = 9.81  # m/s²

# Create all costs using batched construction
costs: list[jaxls.Cost] = [
    # Anchor constraints (corners lifted)
    anchor_constraint(Point3Var(id=anchor_indices), anchor_positions),
    # Table constraint (z >= 0 for free points)
    table_constraint(Point3Var(id=free_indices)),
    # Structural springs
    spring_cost(
        Point3Var(id=struct_a),
        Point3Var(id=struct_b),
        struct_rest_length,
        structural_stiffness,
    ),
    # Shear springs
    spring_cost(
        Point3Var(id=shear_a),
        Point3Var(id=shear_b),
        shear_rest_length,
        shear_stiffness,
    ),
    # Bend springs
    spring_cost(
        Point3Var(id=bend_a),
        Point3Var(id=bend_b),
        bend_rest_length,
        bend_stiffness,
    ),
    # Gravity on free points
    gravity_cost(Point3Var(id=free_indices), mass_per_point, g),
]

print(f"Created {len(costs)} batched cost objects")

Created 6 batched cost objects


## Solving

In [9]:
# Create initial values
initial_vals = jaxls.VarValues.make([all_point_vars.with_value(initial_positions)])

# Build and solve
problem = jaxls.LeastSquaresProblem(costs, [all_point_vars]).analyze()
solution = problem.solve(initial_vals)

[1mINFO    [0m | Building optimization problem with 1032 terms and 144 variables: 888 costs, 2 eq_zero, 0 leq_zero, 142 geq_zero
[1mINFO    [0m | Vectorizing constraint group with 2 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
[1mINFO    [0m | Vectorizing constraint group with 142 constraints (constraint_geq_zero), 1 variables each: augmented_table_constraint
[1mINFO    [0m | Vectorizing group with 746 costs, 2 variables each: spring_cost
[1mINFO    [0m | Vectorizing group with 142 costs, 1 variables each: gravity_cost
[1mINFO    [0m | Augmented Lagrangian: initial snorm=1.5000e+00, csupn=1.5000e+00, max_rho=2.3811e+02, constraint_dim=148
[1mINFO    [0m |  step #1: cost=971.7655 lambd=0.0005 inexact_tol=1.0e-02
[1mINFO    [0m |      - augmented_anchor_constraint(2): 952.45416 (avg 158.74237)
[1mINFO    [0m |      - augmented_table_constraint(142): 0.00000 (avg 0.00000)
[1mINFO    [0m |      - spring_cost(746): 0.00000 (avg 0.00000)

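In the first solver step above, the cost is dominated by the anchor constraint while the spring cost is 0.00000. A quick NumPy check of the residuals at the flat starting grid is consistent with that: the anchors start `lift_height` below their targets, the table constraint is already satisfied, and a spring at rest length has a residual of only about 7e-6 from the 1e-6 smoothing term. The 148 constraint entries also match `constraint_dim=148` (three per anchor plus one per free point). The initial `snorm`/`csupn` of 1.5 equals the largest violation computed here; we are assuming these fields report a norm of the constraint violation, which their names suggest but which we have not verified against the jaxls source.

```python
import numpy as np

rows = cols = 12
spacing, lift_height = 0.5, 1.5
num_points = rows * cols

# Flat starting grid: x from the column, y from the row, z = 0
x, y = np.meshgrid(np.arange(cols) * spacing, np.arange(rows) * spacing)
positions = np.stack([x.ravel(), y.ravel(), np.zeros(num_points)], axis=1)

anchor_indices = np.array([0, num_points - 1])
anchor_targets = positions[anchor_indices].copy()
anchor_targets[:, 2] = lift_height
free_indices = np.setdiff1d(np.arange(num_points), anchor_indices)

# anchor_constraint residuals: x - target, three entries per anchor
eq_residual = positions[anchor_indices] - anchor_targets
# table_constraint asks z >= 0, so the violation is max(0, -z)
geq_violation = np.maximum(0.0, -positions[free_indices, 2])
# One horizontal structural spring, evaluated like spring_cost with k = 50
d = positions[1] - positions[0]
spring_residual = (np.sqrt(np.sum(d**2) + 1e-6) - spacing) * np.sqrt(50.0)

print(np.abs(eq_residual).max())  # 1.5: the anchors start lift_height below target
print(geq_violation.max())  # zero: the flat grid satisfies z >= 0
print(eq_residual.size + free_indices.size)  # 148 constraint entries
print(spring_residual)  # about 7e-6, from the 1e-6 smoothing term only
```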
## Visualization

Compare the initial state (cloth flat on the table) with the optimized state (cloth lifted by its corners):

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def add_cloth_to_scene(
    fig: go.Figure,
    positions: jax.Array,
    row: int,
    col: int,
    show_table: bool = True,
) -> None:
    """Add cloth mesh and table to a subplot scene.

    Args:
        fig: Plotly figure to add traces to
        positions: Cloth node positions (num_points, 3)
        row: Subplot row index (1-indexed)
        col: Subplot column index (1-indexed)
        show_table: Whether to show the table surface
    """
    # Build triangle mesh indices for the cloth surface
    triangles_i, triangles_j, triangles_k = [], [], []
    for r in range(rows - 1):
        for c in range(cols - 1):
            # Lower-left triangle
            triangles_i.append(idx(r, c))
            triangles_j.append(idx(r + 1, c))
            triangles_k.append(idx(r, c + 1))
            # Upper-right triangle
            triangles_i.append(idx(r + 1, c))
            triangles_j.append(idx(r + 1, c + 1))
            triangles_k.append(idx(r, c + 1))

    # Add table surface
    if show_table:
        table_size = (cols - 1) * spacing + 1.5
        table_offset = -0.5
        fig.add_trace(
            go.Mesh3d(
                x=[table_offset, table_offset, table_size, table_size],
                y=[table_offset, table_size, table_size, table_offset],
                z=[-0.01, -0.01, -0.01, -0.01],
                i=[0, 0],
                j=[1, 2],
                k=[2, 3],
                color="rgb(139, 90, 43)",
                opacity=0.8,
                name="Table",
                showlegend=False,
                hoverinfo="skip",
            ),
            row=row,
            col=col,
        )

    # Cloth surface mesh
    fig.add_trace(
        go.Mesh3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            i=triangles_i,
            j=triangles_j,
            k=triangles_k,
            intensity=positions[:, 2],
            colorscale="Blues_r",
            showscale=False,
            opacity=1.0,
            name="Cloth",
            flatshading=False,
            lighting=dict(ambient=0.6, diffuse=0.8, specular=0.2, roughness=0.5),
            lightposition=dict(x=5, y=-5, z=10),
            hoverinfo="skip",
        ),
        row=row,
        col=col,
    )

    # Anchor points
    anchor_pos = positions[anchor_indices]
    fig.add_trace(
        go.Scatter3d(
            x=anchor_pos[:, 0],
            y=anchor_pos[:, 1],
            z=anchor_pos[:, 2],
            mode="markers",
            marker=dict(size=6, color="crimson"),
            name="Anchors",
            showlegend=False,
            hovertemplate="Anchor<br>(%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>",
        ),
        row=row,
        col=col,
    )


# Create side-by-side subplots
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("Initial: Flat on Table", "Optimized: Lifted by Corners"),
    horizontal_spacing=0.02,
)

# Add initial state (flat on table)
add_cloth_to_scene(fig, initial_positions, row=1, col=1)

# Add optimized state (lifted)
add_cloth_to_scene(fig, solution[all_point_vars], row=1, col=2)

# Configure 3D scenes
axis_range = 4
center_x = (cols - 1) * spacing / 2
center_y = (rows - 1) * spacing / 2

scene_config = dict(
    xaxis=dict(
        title="",
        range=[center_x - axis_range, center_x + axis_range],
        showbackground=False,
        showticklabels=False,
    ),
    yaxis=dict(
        title="",
        range=[center_y - axis_range, center_y + axis_range],
        showbackground=False,
        showticklabels=False,
    ),
    zaxis=dict(
        title="",
        range=[-1, axis_range + 1],
        showbackground=False,
        showticklabels=False,
    ),
    aspectmode="cube",
    camera=dict(eye=dict(x=1.3, y=-1.3, z=0.8)),
)

fig.update_layout(
    scene=scene_config,
    scene2=scene_config,
    height=450,
    margin=dict(t=40, b=20, l=20, r=20),
    showlegend=False,
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))
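
The triangle loop in `add_cloth_to_scene` splits every grid cell into two triangles. The same index arrays can be built with NumPy slicing; this sketch (the `cloth_triangles` helper is hypothetical) interleaves the two triangles of each cell so the output order matches the loop, and its three arrays could be passed as `i`, `j`, `k` to `go.Mesh3d`:

```python
import numpy as np

rows, cols = 12, 12


def cloth_triangles(rows: int, cols: int):
    """Vectorized version of the triangle loop in add_cloth_to_scene."""
    ids = np.arange(rows * cols).reshape(rows, cols)  # ids[r, c] == idx(r, c)
    v00 = ids[:-1, :-1].ravel()  # (r, c)
    v10 = ids[1:, :-1].ravel()  # (r + 1, c)
    v01 = ids[:-1, 1:].ravel()  # (r, c + 1)
    v11 = ids[1:, 1:].ravel()  # (r + 1, c + 1)
    # Interleave the two triangles of each cell to match the loop's ordering
    tri_i = np.stack([v00, v10], axis=1).ravel()
    tri_j = np.stack([v10, v11], axis=1).ravel()
    tri_k = np.stack([v01, v01], axis=1).ravel()
    return tri_i, tri_j, tri_k


tri_i, tri_j, tri_k = cloth_triangles(rows, cols)
print(len(tri_i))  # 242 triangles: two per grid cell
```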

The solver finds the equilibrium shape by minimizing total potential energy:

- **Spring energy**: $\frac{1}{2}k(L - L_0)^2$ for each spring (Hooke's law)
- **Gravitational energy**: $mgh$ for each point mass
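
A single-spring version of these two terms gives a sense of scale. For one point of mass $m$ hanging directly below a fixed anchor at height $H$ on a spring with stiffness $k$ and rest length $L_0$, the energy is $E(z) = \frac{1}{2}k(H - z - L_0)^2 + mgz$, and setting $dE/dz = 0$ gives an extension of $mg/k$. With the notebook's structural stiffness and point mass that is about 0.002 units, under 0.4% of the spacing. In the full cloth, the springs attached to the anchors share the weight of every point not resting on the table, so their extensions are not bounded by this single-point estimate. The anchor height and rest length below are illustrative choices:

```python
k = 50.0  # structural_stiffness from the notebook
m = 0.01  # mass_per_point
g = 9.81
H = 1.5  # anchor height (illustrative: lift_height)
L0 = 0.5  # rest length (illustrative: spacing)


def energy(z: float) -> float:
    # Hooke energy of the vertical spring (length H - z) plus gravitational energy
    return 0.5 * k * (H - z - L0) ** 2 + m * g * z


# dE/dz = -k (H - z - L0) + m g = 0  =>  extension H - z - L0 = m g / k
z_star = H - L0 - m * g / k
extension = H - z_star - L0
print(extension)  # approximately 0.001962, i.e. m * g / k
```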

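The table constraint ($z \geq 0$) prevents points from passing through the surface. Starting from a flat configuration, lifting two diagonal corners makes the cloth hang in a drape between them, with the springs resisting stretch, shear, and bending while gravity pulls each free point down.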

The three spring types work together:

- Structural springs resist stretching along rows and columns
- Shear springs resist in-plane shearing of each grid cell, which structural springs alone cannot prevent
- Bend springs resist folding across a row or column
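
A small geometric check illustrates why structural springs alone are not enough. Shearing one grid cell into a rhombus keeps all four edge lengths at the spacing $s$, so the structural springs store no energy, while the two diagonals change length and the shear springs push back. Similarly, folding a row at its middle point keeps both adjacent distances at $s$ but shortens the skip-one distance that a bend spring spans. This sketch uses plain Python with the notebook's spacing and stiffness values; the shear and fold angles are arbitrary choices:

```python
import math

s = 0.5  # spacing from the notebook
k_struct, k_shear, k_bend = 50.0, 20.0, 10.0  # stiffnesses from the notebook


def hooke(length, rest, k):
    # Spring energy 0.5 * k * (L - L0)^2
    return 0.5 * k * (length - rest) ** 2


# Shear one grid cell into a rhombus with interior angle theta (90 degrees = undeformed)
theta = math.radians(60)
p00 = (0.0, 0.0)
p10 = (s, 0.0)
p01 = (s * math.cos(theta), s * math.sin(theta))
p11 = (s + s * math.cos(theta), s * math.sin(theta))

edges = [math.dist(p00, p10), math.dist(p00, p01), math.dist(p10, p11), math.dist(p01, p11)]
diagonals = [math.dist(p00, p11), math.dist(p10, p01)]
structural_energy = sum(hooke(e, s, k_struct) for e in edges)  # zero up to rounding
shear_energy = sum(hooke(d, s * math.sqrt(2), k_shear) for d in diagonals)  # positive

# Fold a row of three points out of plane by angle phi at the middle point
phi = math.radians(45)
q0, q1 = (0.0, 0.0, 0.0), (s, 0.0, 0.0)
q2 = (s + s * math.cos(phi), 0.0, s * math.sin(phi))
adjacent = [math.dist(q0, q1), math.dist(q1, q2)]  # both still s
bend_energy = hooke(math.dist(q0, q2), 2 * s, k_bend)  # positive: skip-one distance shrinks

print(structural_energy, shear_energy, bend_energy)
```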

For more details on solver configuration, see {class}`jaxls.TrustRegionConfig` and {class}`jaxls.AugmentedLagrangianConfig`.