# Beam Bending

A 2D cantilever beam bending problem: finding the equilibrium shape of a beam fixed at one end with a point load applied at the tip.

Features used:
- {class}`~jaxls.Var` subclassing for custom variable types
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for defining elastic energy costs
- Equality constraints (`constraint_eq_zero`) for the fixed-end boundary conditions and the prescribed tip displacement
- Batched cost construction for efficiency

We model a steel ruler (1 m × 30 mm × 3 mm) clamped at one end with a small weight hanging from the tip.

In [1]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [2]:
import jax
import jax.numpy as jnp
import jaxls

## Variables and costs

We model the beam as a series of nodes along its length and solve for each node's 2D position. Consecutive nodes define segments, and we minimize the elastic bending and stretching energy subject to constraints at the clamped end and at the tip.

In [3]:
class NodeVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(2)):
    """A 2D node position variable."""

Cost functions define the physics of the beam based on Euler-Bernoulli beam theory:

1. Bending energy: $E_b = \frac{1}{2} EI \int \kappa^2 \, ds$ where $\kappa$ is curvature
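2. Stretching energy: $E_s = \frac{1}{2} EA \int \epsilon^2 \, ds$ where $\epsilon$ is axial strain
3. Tip displacement constraint: the load is not applied as an external force; instead, the tip is constrained to the analytical deflection $\delta = \frac{FL^3}{3EI}$, and the solver finds the shape of the beam in between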
4. Fixed-end constraints: Position and slope at the clamped end
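
In discrete form, each integral becomes a sum over segments of length $\Delta s$. The residuals returned by `bending_cost` and `stretching_cost` are scaled so that the solver's sum of squares reproduces the energy. Here $\theta_i$ is the turning angle at interior node $i$ and $\ell_j$ is the current length of segment $j$ (the rest length in this notebook equals $\Delta s$):

$$
\kappa_i = \frac{\sin\theta_i}{\Delta s}, \qquad
\epsilon_j = \frac{\ell_j - \Delta s}{\Delta s}
$$

$$
r^{b}_i = \sqrt{EI\,\Delta s}\,\kappa_i, \qquad
r^{s}_j = \sqrt{EA\,\Delta s}\,\epsilon_j
$$

$$
\sum_i \left(r^{b}_i\right)^2 + \sum_j \left(r^{s}_j\right)^2
= \sum_i EI\,\kappa_i^2\,\Delta s + \sum_j EA\,\epsilon_j^2\,\Delta s
\approx 2\,(E_b + E_s)
$$

The factor $\frac{1}{2}$ is dropped from both terms, which doubles the objective without moving its minimizer.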

In [4]:
@jaxls.Cost.factory
def bending_cost(
    vals: jaxls.VarValues,
    node_a: NodeVar,
    node_b: NodeVar,
    node_c: NodeVar,
    EI: float,
    ds: float,
) -> jax.Array:
    """Discrete bending energy: EI * kappa^2 * ds.

    Curvature kappa is approximated by the angle change between segments
    divided by segment length.
    """
    p_a = vals[node_a]
    p_b = vals[node_b]
    p_c = vals[node_c]

    # Vectors along segments
    v1 = p_b - p_a
    v2 = p_c - p_b

    # Sine of the turning angle between segments (sin(theta) ≈ theta for small angles)
    cross = v1[0] * v2[1] - v1[1] * v2[0]
    len1 = jnp.sqrt(jnp.sum(v1**2) + 1e-12)
    len2 = jnp.sqrt(jnp.sum(v2**2) + 1e-12)
    sin_theta = cross / (len1 * len2)

    # Curvature ≈ theta / ds
    kappa = sin_theta / ds

    # Bending energy contribution: sqrt(EI * ds) * kappa (squared by solver)
    return jnp.sqrt(EI * ds) * kappa


@jaxls.Cost.factory
def stretching_cost(
    vals: jaxls.VarValues,
    node_a: NodeVar,
    node_b: NodeVar,
    rest_length: float,
    EA: float,
    ds: float,
) -> jax.Array:
    """Discrete stretching energy: EA * strain^2 * ds."""
    diff = vals[node_b] - vals[node_a]
    length = jnp.sqrt(jnp.sum(diff**2) + 1e-12)
    strain = (length - rest_length) / rest_length
    return jnp.sqrt(EA * ds) * strain


@jaxls.Cost.factory(kind="constraint_eq_zero")
def fixed_position_constraint(
    vals: jaxls.VarValues,
    node: NodeVar,
    target: jax.Array,
) -> jax.Array:
    """Pin a node to a fixed position."""
    return vals[node] - target


@jaxls.Cost.factory(kind="constraint_eq_zero")
def fixed_angle_constraint(
    vals: jaxls.VarValues,
    node_a: NodeVar,
    node_b: NodeVar,
    target_direction: jax.Array,
) -> jax.Array:
    """Constrain the angle of a segment to match a target direction."""
    segment = vals[node_b] - vals[node_a]
    seg_norm = segment / (jnp.sqrt(jnp.sum(segment**2)) + 1e-12)
    cross = seg_norm[0] * target_direction[1] - seg_norm[1] * target_direction[0]
    return cross


@jaxls.Cost.factory(kind="constraint_eq_zero")
def tip_displacement_constraint(
    vals: jaxls.VarValues,
    node: NodeVar,
    target_y: float,
) -> jax.Array:
    """Constrain the tip to a specific vertical displacement."""
    return vals[node][1] - target_y
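
As a quick sanity check of the curvature approximation in `bending_cost`, the sketch below places three points a chord length $\Delta s$ apart on a circle of known radius $R$ and evaluates the same $\sin\theta / \Delta s$ formula with plain `jax.numpy` (no jaxls involved). For $\Delta s \ll R$ the result should approach $1/R$. The helper `discrete_curvature` is hypothetical; it simply mirrors the arithmetic inside the cost.

```python
import jax.numpy as jnp


def discrete_curvature(p_a: jnp.ndarray, p_b: jnp.ndarray, p_c: jnp.ndarray, ds: float) -> jnp.ndarray:
    """Hypothetical helper mirroring bending_cost: sin(turning angle) / ds."""
    v1 = p_b - p_a
    v2 = p_c - p_b
    cross = v1[0] * v2[1] - v1[1] * v2[0]  # z-component of v1 x v2 (positive for a left turn)
    len1 = jnp.sqrt(jnp.sum(v1**2) + 1e-12)
    len2 = jnp.sqrt(jnp.sum(v2**2) + 1e-12)
    return cross / (len1 * len2) / ds


# Three points on a circle of radius R centered at (0, R), separated by chord length ds.
R = 2.0
ds = 0.05
dphi = 2.0 * jnp.arcsin(ds / (2.0 * R))  # central angle subtending a chord of length ds
phis = jnp.array([-dphi, 0.0, dphi])
# x = R sin(phi), y = R (1 - cos(phi)) written as 2 R sin^2(phi / 2) for float32 accuracy
points = jnp.stack([R * jnp.sin(phis), 2.0 * R * jnp.sin(phis / 2.0) ** 2], axis=1)

kappa = discrete_curvature(points[0], points[1], points[2], ds)
print(float(kappa), 1.0 / R)
```

On a circle the turning angle between consecutive chords equals the central angle, so the only error left is the $\sin\theta \approx \theta$ approximation, which shrinks as $\Delta s / R$ decreases.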

## Physical parameters

We model a thin steel ruler:
- Length: 1 m
- Cross-section: 30 mm wide × 3 mm thick
- Material: Steel (E = 200 GPa)
- Tip load: 3 N (roughly the weight of a 300 g mass)

The analytical tip deflection for a cantilever beam is $\delta = \frac{FL^3}{3EI}$.
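
Working the numbers by hand, with width $b$ = 30 mm and thickness $h$ = 3 mm, gives the values that the next cell prints:

$$
I = \frac{b h^3}{12} = \frac{(0.030)(0.003)^3}{12} = 6.75 \times 10^{-11}\ \text{m}^4,
\qquad
EI = (200 \times 10^{9})(6.75 \times 10^{-11}) = 13.5\ \text{N}\,\text{m}^2
$$

$$
\delta = \frac{F L^3}{3 EI} = \frac{(3)(1)^3}{3\,(13.5)} \approx 0.07407\ \text{m} = 74.07\ \text{mm}
$$

The tip therefore moves by about 7% of the beam length.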

In [18]:
# Geometry
num_nodes = 21
beam_length = 1.0  # [m]
segment_length = beam_length / (num_nodes - 1)

# Material properties: steel
E = 200e9  # Young's modulus [Pa]
width = 0.030  # [m] = 30 mm
thickness = 0.003  # [m] = 3 mm
A = width * thickness  # Cross-sectional area [m^2]
I = width * thickness**3 / 12  # Second moment of area [m^4]

# Stiffnesses
EI = E * I  # Bending stiffness [N·m^2]
EA = E * A  # Axial stiffness [N]

# Applied load
tip_force = 3.0  # [N]

# Analytical solution for comparison
delta_analytical = tip_force * beam_length**3 / (3 * EI)

print(f"Beam: L = {beam_length} m, {width * 1000:.0f} mm × {thickness * 1000:.0f} mm")
print(f"Bending stiffness EI = {EI:.2f} N·m²")
print(f"Tip load F = {tip_force} N")
print(f"Analytical tip deflection: {delta_analytical * 1000:.2f} mm")

# Initial configuration: straight horizontal beam
initial_positions = jnp.array([[i * segment_length, 0.0] for i in range(num_nodes)])

Beam: L = 1.0 m, 30 mm × 3 mm
Bending stiffness EI = 13.50 N·m²
Tip load F = 3.0 N
Analytical tip deflection: 74.07 mm


## Problem construction

For efficiency, we use batched construction: we create arrays of node indices and pass them to each cost factory in a single call, so each cost type becomes one vectorized cost object rather than one object per node.
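
To make the batching concrete, the sketch below builds the same index arrays for a hypothetical 5-node beam using plain `jax.numpy`; each row lists the nodes that one cost term touches. With `num_nodes = 21` the same pattern yields 19 bending triplets and 20 stretching pairs, matching the group sizes reported in the solver log.

```python
import jax.numpy as jnp

# A tiny 5-node beam, only to inspect the index arrays used for batching.
n = 5
bending_a = jnp.arange(n - 2)     # first node of each triplet
bending_b = jnp.arange(1, n - 1)  # middle (interior) node of each triplet
bending_c = jnp.arange(2, n)      # last node of each triplet
stretch_a = jnp.arange(n - 1)     # first node of each segment
stretch_b = jnp.arange(1, n)      # second node of each segment

# Each row is one cost term: (a, b, c) triplets for bending, (a, b) pairs for stretching.
triplets = jnp.stack([bending_a, bending_b, bending_c], axis=1)
pairs = jnp.stack([stretch_a, stretch_b], axis=1)
print(triplets.tolist())  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(pairs.tolist())     # [[0, 1], [1, 2], [2, 3], [3, 4]]
```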

In [19]:
# Create batched node variables
node_vars = NodeVar(id=jnp.arange(num_nodes))

# Indices for bending costs (triplets of consecutive nodes)
bending_a = jnp.arange(num_nodes - 2)
bending_b = jnp.arange(1, num_nodes - 1)
bending_c = jnp.arange(2, num_nodes)

# Indices for stretching costs (pairs of consecutive nodes)
stretch_a = jnp.arange(num_nodes - 1)
stretch_b = jnp.arange(1, num_nodes)

# Build costs using batched construction
costs: list[jaxls.Cost] = [
    # Bending energy at interior nodes
    bending_cost(
        NodeVar(id=bending_a),
        NodeVar(id=bending_b),
        NodeVar(id=bending_c),
        EI,
        segment_length,
    ),
    # Stretching energy for all segments
    stretching_cost(
        NodeVar(id=stretch_a),
        NodeVar(id=stretch_b),
        segment_length,
        EA,
        segment_length,
    ),
    # Fixed position at clamped end
    fixed_position_constraint(NodeVar(id=0), initial_positions[0]),
    # Fixed angle: first segment must be horizontal
    fixed_angle_constraint(
        NodeVar(id=0),
        NodeVar(id=1),
        jnp.array([1.0, 0.0]),
    ),
    # Tip displacement from analytical solution
    tip_displacement_constraint(NodeVar(id=num_nodes - 1), -delta_analytical),
]

print(f"Created {len(costs)} cost objects")

Created 5 cost objects


## Solving

In [20]:
# Create initial values
initial_vals = jaxls.VarValues.make([node_vars.with_value(initial_positions)])

# Build and solve
problem = jaxls.LeastSquaresProblem(costs, [node_vars]).analyze()
solution = problem.solve(initial_vals)

[1mINFO    [0m | Building optimization problem with 42 terms and 21 variables: 39 costs, 3 eq_zero, 0 leq_zero, 0 geq_zero
[1mINFO    [0m | Vectorizing group with 19 costs, 3 variables each: bending_cost
[1mINFO    [0m | Vectorizing group with 20 costs, 2 variables each: stretching_cost
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_fixed_position_constraint
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_tip_displacement_constraint
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 2 variables each: augmented_fixed_angle_constraint
[1mINFO    [0m | Augmented Lagrangian: initial snorm=7.4074e-02, csupn=7.4074e-02, max_rho=1.0000e+01, constraint_dim=4
[1mINFO    [0m |  step #1: cost=0.0549 lambd=0.0005 inexact_tol=1.0e-02
[1mINFO    [0m |      - bending_cost(19): 0.00000 (avg 0.00000)
[1mINFO    [0m

## Visualization

In [21]:
import plotly.graph_objects as go
from IPython.display import HTML

# Extract positions
initial_pos = initial_vals[node_vars]
final_pos = solution[node_vars]

# Compute tip deflection
tip_deflection = float(final_pos[-1, 1])
print(f"Tip deflection: {tip_deflection * 1000:.2f} mm")

# Create visualization (convert to mm)
fig = go.Figure()

# Initial (straight) beam
fig.add_trace(
    go.Scatter(
        x=[float(p) * 1000 for p in initial_pos[:, 0]],
        y=[float(p) * 1000 for p in initial_pos[:, 1]],
        mode="lines+markers",
        name="Initial (straight)",
        line=dict(color="lightgray", width=3, dash="dash"),
        marker=dict(size=5, color="lightgray"),
    )
)

# Deflected beam
fig.add_trace(
    go.Scatter(
        x=[float(p) * 1000 for p in final_pos[:, 0]],
        y=[float(p) * 1000 for p in final_pos[:, 1]],
        mode="lines+markers",
        name="Deflected",
        line=dict(color="steelblue", width=4),
        marker=dict(size=6, color="steelblue"),
    )
)

# Fixed end indicator
fig.add_trace(
    go.Scatter(
        x=[0],
        y=[0],
        mode="markers",
        name="Fixed end",
        marker=dict(size=15, color="crimson", symbol="square"),
    )
)

# Tip load indicator (arrow)
fig.add_annotation(
    x=float(final_pos[-1, 0]) * 1000,
    y=float(final_pos[-1, 1]) * 1000,
    ax=0,
    ay=-40,
    xref="x",
    yref="y",
    axref="pixel",
    ayref="pixel",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowwidth=3,
    arrowcolor="orange",
)

fig.update_layout(
    title=f"Cantilever Beam: {tip_force} N Tip Load → {abs(tip_deflection) * 1000:.1f} mm Deflection",
    xaxis=dict(title="x [mm]"),
    yaxis=dict(title="y [mm]", scaleanchor="x", scaleratio=1),
    height=400,
    showlegend=True,
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
    margin=dict(t=80, b=50, l=50, r=50),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Tip deflection: -74.63 mm


## Varying load animation

We can visualize how the beam responds to increasing load by solving for a range of tip forces, each converted to a prescribed tip displacement. Using `jax.lax.scan`, we solve sequentially and use each solution as the initial guess for the next (warm-starting), which helps convergence for larger deflections:

In [23]:
# Solve for a range of tip forces using scan with warm-starting.
n_frames = 15
tip_forces = jnp.linspace(0.0, tip_force, n_frames)
deltas = tip_forces * beam_length**3 / (3 * EI)


def solve_for_delta(
    current_vals: jaxls.VarValues, delta: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Solve beam problem for a given tip displacement, using current_vals as warm start."""
    costs_f: list[jaxls.Cost] = [
        bending_cost(
            NodeVar(id=bending_a),
            NodeVar(id=bending_b),
            NodeVar(id=bending_c),
            EI,
            segment_length,
        ),
        stretching_cost(
            NodeVar(id=stretch_a),
            NodeVar(id=stretch_b),
            segment_length,
            EA,
            segment_length,
        ),
        fixed_position_constraint(NodeVar(id=0), initial_positions[0]),
        fixed_angle_constraint(NodeVar(id=0), NodeVar(id=1), jnp.array([1.0, 0.0])),
        tip_displacement_constraint(NodeVar(id=num_nodes - 1), -delta),
    ]
    problem_f = jaxls.LeastSquaresProblem(costs_f, [node_vars]).analyze()
    solution = problem_f.solve(current_vals, verbose=False)
    return solution, solution[node_vars]


# Solve all frames sequentially with warm-starting
_, frame_solutions = jax.lax.scan(solve_for_delta, initial_vals, deltas)
print(f"Generated {len(frame_solutions)} frames with shape {frame_solutions.shape}")

[1mINFO    [0m | Building optimization problem with 42 terms and 21 variables: 39 costs, 3 eq_zero, 0 leq_zero, 0 geq_zero
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 2 variables each: augmented_fixed_angle_constraint
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_tip_displacement_constraint
[1mINFO    [0m | Vectorizing group with 19 costs, 3 variables each: bending_cost
[1mINFO    [0m | Vectorizing group with 20 costs, 2 variables each: stretching_cost
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_fixed_position_constraint
Generated 15 frames with shape (15, 21, 2)
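
The warm-starting in `solve_for_delta` relies on how `jax.lax.scan` threads a carry: the step function receives the previous carry and one element of `xs`, and returns a new carry plus a per-step output that `scan` stacks along a new leading axis. The toy sketch below is not part of jaxls; it reuses the same structure on a scalar problem, running a fixed number of Newton iterations on $x + x^3 = t$ for each target $t$, starting from the previous solution. The names `newton_solve` and `step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def newton_solve(x0: jax.Array, target: jax.Array, iters: int = 5) -> jax.Array:
    """Solve x + x**3 = target with a fixed number of Newton steps starting at x0."""

    def body(_, x):
        residual = x + x**3 - target
        jac = 1.0 + 3.0 * x**2  # derivative of x + x**3
        return x - residual / jac

    return jax.lax.fori_loop(0, iters, body, x0)


def step(carry: jax.Array, target: jax.Array) -> tuple[jax.Array, jax.Array]:
    # Warm start: the previous solution (carry) is the initial guess.
    x = newton_solve(carry, target)
    # Return the new carry for the next step, and the per-step output to be stacked.
    return x, x


targets = jnp.linspace(0.0, 3.0, 15)
final_x, all_x = jax.lax.scan(step, jnp.zeros(()), targets)
print(all_x.shape)  # (15,)
```

Because `scan` traces its step function once, the problem inside `solve_for_delta` is built and analyzed a single time, which is consistent with the single set of "Vectorizing" log lines above despite 15 solves.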


In [24]:
# Build animation frames (frame_solutions is now shape (n_frames, num_nodes, 2)).
frames = []
for i in range(n_frames):
    pos = frame_solutions[i]
    force_val = float(tip_forces[i])
    deflection = float(pos[-1, 1]) * 1000
    frames.append(
        go.Frame(
            data=[
                go.Scatter(
                    x=[float(p) * 1000 for p in pos[:, 0]],
                    y=[float(p) * 1000 for p in pos[:, 1]],
                    mode="lines+markers",
                    line=dict(color="steelblue", width=4),
                    marker=dict(size=6, color="steelblue"),
                ),
                go.Scatter(
                    x=[0],
                    y=[0],
                    mode="markers",
                    marker=dict(size=15, color="crimson", symbol="square"),
                ),
            ],
            name=str(i),
            layout=go.Layout(
                title=f"Tip Load: {force_val:.1f} N → Deflection: {abs(deflection):.1f} mm"
            ),
        )
    )

# Initial frame.
init_pos = frame_solutions[0]
fig_anim = go.Figure(
    data=[
        go.Scatter(
            x=[float(p) * 1000 for p in init_pos[:, 0]],
            y=[float(p) * 1000 for p in init_pos[:, 1]],
            mode="lines+markers",
            line=dict(color="steelblue", width=4),
            marker=dict(size=6, color="steelblue"),
            name="Beam",
        ),
        go.Scatter(
            x=[0],
            y=[0],
            mode="markers",
            marker=dict(size=15, color="crimson", symbol="square"),
            name="Fixed end",
        ),
    ],
    frames=frames,
    layout=go.Layout(
        title="Tip Load: 0.0 N → Deflection: 0.0 mm",
        xaxis=dict(title="x [mm]", range=[-50, 1050]),
        yaxis=dict(title="y [mm]", range=[-100, 20], scaleanchor="x", scaleratio=1),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=1.15,
                x=0.5,
                xanchor="center",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                frame=dict(duration=150, redraw=True),
                                fromcurrent=True,
                                transition=dict(duration=50),
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            dict(
                                frame=dict(duration=0, redraw=False), mode="immediate"
                            ),
                        ],
                    ),
                ],
            )
        ],
        sliders=[
            dict(
                active=0,
                yanchor="top",
                xanchor="left",
                currentvalue=dict(
                    prefix="Force: ", suffix=" N", visible=True, xanchor="center"
                ),
                pad=dict(b=10, t=50),
                steps=[
                    dict(
                        args=[
                            [str(i)],
                            dict(
                                frame=dict(duration=0, redraw=True),
                                mode="immediate",
                                transition=dict(duration=0),
                            ),
                        ],
                        label=f"{float(tip_forces[i]):.1f}",
                        method="animate",
                    )
                    for i in range(n_frames)
                ],
                x=0.1,
                y=0,
                len=0.8,
            )
        ],
        height=400,
        showlegend=False,
        margin=dict(t=80, b=80),
    ),
)
HTML(fig_anim.to_html(full_html=False, include_plotlyjs="cdn", auto_play=False))

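The animation shows how the beam deflects as the tip load increases from 0 to 3 N. Because each frame prescribes the tip displacement from the analytical formula, the linear growth of tip deflection with force is built into the problem; what the solver actually determines is the shape of the beam between the clamped end and the tip, which follows from minimizing the bending and stretching energy.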

Key observations:
- Fixed end (red square): Position and slope are constrained
- Smooth deflection curve: The bending stiffness EI resists curvature
- Linear response: the prescribed tip deflection $\delta = \frac{FL^3}{3EI}$ is proportional to the load
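
Since the tip displacement is imposed, a stronger check is to compare the whole deflected shape with the small-deflection Euler-Bernoulli solution for an end-loaded cantilever, $w(x) = \frac{F x^2 (3L - x)}{6EI}$, which reduces to $\delta = \frac{FL^3}{3EI}$ at $x = L$. The sketch below evaluates it at the undeformed node positions using the parameters from this notebook (copied in, so the block runs on its own).

```python
import jax.numpy as jnp

# Parameters copied from the notebook's physical-parameter cell.
E = 200e9          # Young's modulus [Pa]
width = 0.030      # [m]
thickness = 0.003  # [m]
L = 1.0            # beam length [m]
F = 3.0            # tip load [N]
num_nodes = 21

EI = E * width * thickness**3 / 12   # bending stiffness [N*m^2]
x = jnp.linspace(0.0, L, num_nodes)  # undeformed node x-positions [m]

# Small-deflection Euler-Bernoulli curve for a cantilever with an end load
# (magnitude, positive downward).
w = F * x**2 * (3 * L - x) / (6 * EI)

print(f"w(L) = {float(w[-1]) * 1000:.2f} mm")
```

In the notebook, `-w` could be plotted alongside `final_pos[:, 1]`; the negative sign accounts for the beam deflecting toward negative $y$.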

For more details on the solver, see {class}`jaxls.LeastSquaresProblem`.