# SMPL-H shape fitting

Optimizing SMPL-H body shape parameters to achieve a target height.

**Inputs:** Target height, initial shape parameters (zeros)  
**Outputs:** Shape (beta) parameters that produce a body with the desired height

Features used:
- {class}`~jaxls.Var` for shape parameters
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` with `kind="constraint_eq_zero"` to express the height requirement as an equality constraint
- Augmented Lagrangian solver for constrained optimization

Note: the SMPL-H implementation here is minimal. For full-featured SMPL models in `jaxls`, see [egoallo](https://github.com/brentyi/egoallo) or [VideoMimic](https://github.com/hongsukchoi/videomimic).

In [1]:
import sys
from loguru import logger

# Drop every existing loguru handler (including its default stderr sink), then
# route log records to stdout formatted as "LEVEL    | message".
logger.remove()
_stdout_handler_id = logger.add(
    sys.stdout, format="<level>{level: <8}</level> | {message}"
)

In [2]:
import io
import pathlib
import urllib.request
import zipfile

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
import jaxls
import numpy as np
from jax import Array

## Download SMPL-H model

The SMPL-H model represents human body shape using a low-dimensional parameterization. Shape variations are controlled by beta parameters that deform a template mesh.

In [None]:
# Download SMPL-H model if not already present.
smplh_path = pathlib.Path("/tmp/SMPLH_NEUTRAL.npz")

if not smplh_path.exists():
    print("Downloading SMPL-H model...")
    url = "https://brentyi.github.io/viser-example-assets/SMPLH_NEUTRAL.zip"
    # Bound the network wait so a stalled connection errors out instead of
    # hanging the notebook indefinitely.
    with urllib.request.urlopen(url, timeout=60) as response:
        zip_data = io.BytesIO(response.read())
    # Extract into the directory that holds `smplh_path`, so the download
    # location and the path checked above cannot drift apart.
    with zipfile.ZipFile(zip_data) as zf:
        zf.extractall(smplh_path.parent)
    # This cell expects the archive to contain SMPLH_NEUTRAL.npz at its top
    # level; fail loudly here rather than later inside `SmplhModel.load`.
    if not smplh_path.exists():
        raise FileNotFoundError(
            f"Archive {url} did not contain {smplh_path.name} at its top level."
        )
    print(f"Downloaded to {smplh_path}")
else:
    print(f"Using cached model at {smplh_path}")

## SMPL-H model implementation

A minimal implementation of the SMPL-H body model. Shape is controlled by beta parameters, which are PCA coefficients that linearly combine learned shape basis vectors to deform the template mesh.

In [4]:
@jdc.pytree_dataclass
class SmplhModel:
    """SMPL-H human body model (shape-only subset).

    Holds the template mesh and the linear shape blend shapes; vertices for a
    given set of betas are ``v_template + shapedirs @ betas``.
    """

    faces: Array
    """Vertex indices for mesh faces, shape (faces, 3)."""
    v_template: Array
    """Template mesh vertices, shape (verts, 3)."""
    shapedirs: Array
    """Shape blend shape bases, shape (verts, 3, n_betas)."""

    @staticmethod
    def load(npz_path: pathlib.Path) -> "SmplhModel":
        """Load model arrays from an SMPL-H ``.npz`` file.

        Reads the ``f`` (faces), ``v_template`` and ``shapedirs`` entries and
        converts them to int32 / float32 JAX arrays.
        """
        # NOTE(review): allow_pickle=True lets object arrays stored in the file
        # execute code when loaded. The three entries read below look numeric,
        # so allow_pickle=False may suffice -- confirm before tightening.
        # The `with` closes the file handle that np.load keeps open for .npz
        # archives; every array is copied into JAX before the handle closes.
        with np.load(npz_path, allow_pickle=True) as params:
            return SmplhModel(
                faces=jnp.array(params["f"].astype(np.int32)),
                v_template=jnp.array(params["v_template"].astype(np.float32)),
                shapedirs=jnp.array(params["shapedirs"].astype(np.float32)),
            )

    def get_vertices(self, betas: Array) -> Array:
        """Compute mesh vertices for given shape parameters.

        Args:
            betas: Shape coefficients with the beta axis last, shape
                ``(..., num_betas)``. Only the first ``num_betas`` blend shapes
                of the model are used.

        Returns:
            Vertices of shape ``(..., verts, 3)``.

        Raises:
            ValueError: If more betas are given than the model has blend shapes.
        """
        num_betas = betas.shape[-1]
        available = self.shapedirs.shape[-1]
        if num_betas > available:
            # Without this check, the slice below silently truncates and the
            # einsum fails with an opaque shape-mismatch error.
            raise ValueError(
                f"Got {num_betas} betas, but the model only has "
                f"{available} shape blend shapes."
            )
        # Apply shape blend shapes: v = v_template + shapedirs @ betas.
        offsets = jnp.einsum(
            "vxb,...b->...vx", self.shapedirs[:, :, :num_betas], betas
        )
        return self.v_template + offsets

    def get_height(self, betas: Array) -> Array:
        """Compute body height as the extent of the vertices along the y axis.

        Height is the range of the y-coordinate (SMPL uses y-up), i.e.
        ``max(y) - min(y)`` over all vertices. Batched ``betas`` of shape
        ``(..., num_betas)`` give heights of shape ``(...)``.
        """
        up_coords = self.get_vertices(betas)[..., 1]
        return jnp.max(up_coords, axis=-1) - jnp.min(up_coords, axis=-1)

In [5]:
# Build the body model from the cached .npz archive.
model = SmplhModel.load(smplh_path)

# Reference point: height of the undeformed template, i.e. all 16 betas zero.
template_height = float(model.get_height(jnp.zeros((16,))))
print(
    f"Template mesh: {model.v_template.shape[0]} vertices, {model.faces.shape[0]} faces"
)
print(f"Template height (beta=0): {template_height:.3f} m")

Template mesh: 6890 vertices, 13776 faces
Template height (beta=0): 1.717 m


## Problem setup

We optimize the first 10 beta parameters so that the body reaches a target height of 2.0 meters (about 28 cm taller than the 1.717 m template), while regularizing the betas toward zero to keep the body shape natural.

In [6]:
# Desired body height (meters) and how many leading betas are optimized.
TARGET_HEIGHT = 2.0
NUM_BETAS = 10


def _zero_betas() -> jax.Array:
    """Default BetaVar value: NUM_BETAS zeros, i.e. the template shape."""
    return jnp.zeros(NUM_BETAS)


class BetaVar(jaxls.Var[jax.Array], default_factory=_zero_betas):
    """Optimization variable holding SMPL-H beta (shape) coefficients."""


beta_var = BetaVar(id=0)


@jaxls.Cost.factory(kind="constraint_eq_zero")
def height_constraint(
    vals: jaxls.VarValues,
    var: BetaVar,
    model: SmplhModel,
    target_height: float,
) -> jax.Array:
    """Equality constraint that is zero when the body reaches ``target_height``.

    The residual is the mesh height for the current betas minus the target.
    """
    # The residual is a single scalar; jaxls accepts scalar residuals.
    return model.get_height(vals[var]) - target_height


@jaxls.Cost.factory
def beta_regularization(
    vals: jaxls.VarValues,
    var: BetaVar,
    weight: float,
) -> jax.Array:
    """Soft prior that pulls every beta toward zero (the template shape).

    The residual is the beta vector scaled by ``weight``.
    """
    scaled_betas = vals[var] * weight
    return scaled_betas

## Solving

When constraints are present, jaxls automatically uses an Augmented Lagrangian method. The solver iteratively adjusts Lagrange multipliers and penalty parameters to satisfy the constraint.

In [None]:
# Cost terms: a hard height constraint plus a soft pull of the betas to zero.
costs: list[jaxls.Cost] = [
    height_constraint(beta_var, model, TARGET_HEIGHT),
    beta_regularization(beta_var, weight=0.5),
]
problem = jaxls.LeastSquaresProblem(costs, [beta_var])

# Start the solver from the template shape (all betas zero).
initial_betas = jnp.zeros(NUM_BETAS)
initial_vals = jaxls.VarValues.make([beta_var.with_value(initial_betas)])

# Visualize the problem structure.
problem.show()

In [None]:
# Analyze the problem; the analyzed problem returned here is what gets solved.
problem = problem.analyze()

# Convert the JAX scalar to a Python float before applying a format spec,
# as the other height printouts in this notebook already do.
print(f"Initial height: {float(model.get_height(initial_betas)):.3f} m")
print(f"Target height: {TARGET_HEIGHT:.3f} m")

In [8]:
# Run the solver. The problem contains an equality constraint, so jaxls
# switches to its Augmented Lagrangian method without extra configuration.
solution = problem.solve(
    initial_vals,
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    linear_solver="dense_cholesky",
)

# Read back the optimized betas and measure the resulting body.
optimized_betas = solution[beta_var]
final_height = model.get_height(optimized_betas)

print(f"\nOptimized height: {float(final_height):.4f} m")
print(f"Height error: {abs(float(final_height) - TARGET_HEIGHT) * 100:.2f} cm")
print(f"Beta norm: {float(jnp.linalg.norm(optimized_betas)):.3f}")

[1mINFO    [0m | Augmented Lagrangian: initial snorm=2.8264e-01, csupn=2.8264e-01, max_rho=1.0000e+01, constraint_dim=1
[1mINFO    [0m |  step #0: cost=0.0000 lambd=0.0005
[1mINFO    [0m |      - beta_regularization(1): 0.00000 (avg 0.00000)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.79885 (avg 0.79885)
[1mINFO    [0m |      accepted=True ATb_norm=4.61e-01 cost_prev=0.7988 cost_new=0.5705
[1mINFO    [0m |  step #1: cost=0.1629 lambd=0.0003
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.40754 (avg 0.40754)
[1mINFO    [0m |  step #2: cost=0.1629 lambd=0.0005
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.40754 (avg 0.40754)
[1mINFO    [0m |  step #3: cost=0.1629 lambd=0.0010
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_co

## Visualization

Compare the template mesh (beta=0) with the optimized shape side by side.

In [9]:
import plotly.graph_objects as go
from IPython.display import HTML


def create_mesh_trace(
    vertices: np.ndarray,
    faces: np.ndarray,
    color: str,
    name: str,
    x_offset: float = 0.0,
    z_offset: float = 0.0,
) -> go.Mesh3d:
    """Create a shaded plotly 3D mesh trace for a y-up (SMPL) mesh.

    The mesh is re-expressed in plotly's z-up frame with a proper rotation,
    plotly ``(x, y, z) = (-x_smpl, z_smpl, y_smpl)``. Swapping only y and z
    would be a reflection: it mirrors the body and reverses triangle winding.
    Negating x as well keeps the transform a rotation, and the body still faces
    plotly's +y direction.

    Args:
        vertices: Mesh vertices, one row per vertex, y-up.
        faces: Triangle vertex indices, one row of three indices per face.
        color: Mesh color.
        name: Legend label.
        x_offset: Shift applied along plotly's x axis.
        z_offset: Shift applied along plotly's z (vertical) axis.
    """
    return go.Mesh3d(
        x=-vertices[:, 0] + x_offset,
        y=vertices[:, 2],
        z=vertices[:, 1] + z_offset,
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color=color,
        opacity=1.0,
        name=name,
        flatshading=False,
        lighting=dict(
            ambient=0.4,
            diffuse=0.8,
            specular=0.3,
            roughness=0.5,
            fresnel=0.2,
        ),
        lightposition=dict(x=100, y=200, z=300),
        showlegend=True,
    )


def create_height_marker(
    vertices: np.ndarray,
    color: str,
    x_offset: float = 0.0,
    z_offset: float = 0.0,
) -> list[go.Scatter3d]:
    """Build plotly traces that mark a body's vertical extent.

    Returns two traces with identical endpoints: a line spanning the body's
    height, followed by dots at its bottom and top. The SMPL y range (height)
    is drawn along plotly's z axis, shifted by ``z_offset``. The marker sits at
    ``x_offset`` and 0.08 below the smallest SMPL z value on plotly's y axis,
    i.e. behind the body.
    """
    up_coords = vertices[:, 1]
    bottom = up_coords.min() + z_offset
    top = up_coords.max() + z_offset
    depth = vertices[:, 2].min() - 0.08

    def _endpoints() -> dict:
        # Fresh lists per trace so the two traces never share data objects.
        return dict(
            x=[x_offset] * 2,
            y=[depth] * 2,
            z=[bottom, top],
            showlegend=False,
        )

    line_trace = go.Scatter3d(
        mode="lines", line=dict(color=color, width=4), **_endpoints()
    )
    dot_trace = go.Scatter3d(
        mode="markers", marker=dict(size=5, color=color), **_endpoints()
    )
    return [line_trace, dot_trace]


def create_ground_plane(
    x_range: tuple[float, float],
    y_range: tuple[float, float],
    z_level: float = 0.0,
    color: str = "lightgray",
) -> go.Mesh3d:
    """Build a flat, semi-transparent rectangle at height ``z_level``.

    The rectangle spans ``x_range`` by ``y_range`` and is made of two triangles
    (corners 0-1-2 and 0-2-3). It is hidden from the legend and from hover.
    """
    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range
    corners = [(x_lo, y_lo), (x_hi, y_lo), (x_hi, y_hi), (x_lo, y_hi)]
    return go.Mesh3d(
        x=[cx for cx, _ in corners],
        y=[cy for _, cy in corners],
        z=[z_level] * len(corners),
        i=[0, 0],
        j=[1, 2],
        k=[2, 3],
        color=color,
        opacity=0.5,
        showlegend=False,
        hoverinfo="skip",
    )

In [None]:
# NumPy vertex/face arrays for the template and the optimized body.
initial_verts = np.array(model.get_vertices(initial_betas))
optimized_verts = np.array(model.get_vertices(optimized_betas))
faces = np.array(model.faces)

# Heights shown in the legend labels.
initial_height = float(model.get_height(initial_betas))
optimized_height = float(final_height)

# Vertical shifts that put each body's lowest vertex (SMPL y-up) on the floor
# at z=0.
initial_z_offset = -initial_verts[:, 1].min()
optimized_z_offset = -optimized_verts[:, 1].min()

# Horizontal distance between the two bodies.
x_offset = 1.5

fig = go.Figure()

# Floor shared by both bodies.
fig.add_trace(
    create_ground_plane(x_range=(-1.5, 1.5), y_range=(-0.5, 0.5), z_level=0.0)
)

# Initial body (left) then optimized body (right); each contributes its mesh
# followed by its height-marker traces, in that order.
bodies = [
    (
        initial_verts,
        "steelblue",
        f"Initial (height={initial_height:.2f}m)",
        x_offset / 2,
        initial_z_offset,
    ),
    (
        optimized_verts,
        "forestgreen",
        f"Optimized (height={optimized_height:.2f}m)",
        -x_offset / 2,
        optimized_z_offset,
    ),
]
for verts, mesh_color, label, body_x, body_z in bodies:
    fig.add_trace(
        create_mesh_trace(
            verts, faces, mesh_color, label, x_offset=body_x, z_offset=body_z
        )
    )
    for trace in create_height_marker(
        verts, "tomato", x_offset=body_x, z_offset=body_z
    ):
        fig.add_trace(trace)

# Hide the axes, keep true proportions, and set the camera and legend.
fig.update_layout(
    scene=dict(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
        aspectmode="data",
        camera=dict(
            eye=dict(x=2.0, y=0.5, z=0.3),
            center=dict(x=0, y=0, z=0.4),
        ),
    ),
    height=500,
    margin=dict(t=20, b=20, l=20, r=20),
    legend=dict(
        yanchor="top",
        y=0.95,
        xanchor="left",
        x=0.02,
        bgcolor="rgba(255,255,255,0.8)",
    ),
)

# Embed the interactive figure, loading plotly.js from its CDN.
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization finds shape parameters that satisfy the height constraint while keeping the body shape natural (small beta norm). The regularization prevents extreme deformations that could produce unrealistic body shapes.

For more on constrained optimization, see {doc}`/guide/advanced/constraints`.