# SE(3) pose graph

In this notebook, we solve a 3D pose graph optimization problem: estimating robot poses
with full 6 degrees of freedom (3D position and orientation) from noisy relative measurements.

Extending pose graph optimization to 3D enables SLAM for drones, underwater vehicles, and handheld mapping devices. The core idea remains the same as in 2D: chaining relative motion measurements between poses accumulates drift, while loop closures, which recognize previously visited locations, provide global constraints that correct it.

This example uses the sphere2500 dataset, where poses are arranged on a sphere surface.

Features used:
- {class}`~jaxls.SE3Var` for SE(3) robot poses
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` with batched edge construction
- g2o dataset with ~2500 poses and loop closures

In [None]:
import sys
from loguru import logger

# Drop loguru's existing handlers and log to stdout instead, showing only the
# level name (left-aligned, padded to 8 characters) and the message. The
# trailing ";" keeps the notebook from echoing the value logger.add() returns.
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [None]:
import pathlib

import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np

## Loading the dataset

Parse the g2o file to extract poses and edges. The sphere2500 dataset has poses arranged on a sphere surface with loop closures connecting nearby poses:

In [None]:
@jax.jit
def parse_precision_matrix(components: jax.Array) -> jax.Array:
    """Turn g2o upper-triangular information entries into a square-root precision.

    Args:
        components: The 21 upper-triangular entries of a symmetric 6x6
            precision (information) matrix, listed row by row.

    Returns:
        A (6, 6) upper-triangular matrix U with U.T @ U equal to the precision
        matrix, i.e. the transpose of its lower Cholesky factor.
    """
    rows, cols = jnp.triu_indices(6)
    upper = jnp.zeros((6, 6)).at[rows, cols].set(components)
    # Mirror the strict upper triangle into the lower one; the diagonal appears
    # in both `upper` and `upper.T`, so subtract it once.
    symmetric = upper + upper.T - jnp.diag(jnp.diagonal(upper))
    lower = jnp.linalg.cholesky(symmetric)
    return lower.T


def parse_g2o_se3(path: pathlib.Path) -> dict:
    """Parse a 3D g2o file (VERTEX_SE3:QUAT and EDGE_SE3:QUAT format).

    Vertices are placed in the returned array by their g2o id, so the file may
    list them in any order. Edge endpoints are returned as raw vertex ids and
    are used as row indices into ``poses``, so vertex ids must be exactly
    ``0 .. N-1``. Lines with any other tag are ignored.

    Args:
        path: Path to the g2o file

    Returns:
        Dictionary with 'poses' (N, 7) array of (x, y, z, qx, qy, qz, qw) rows
        ordered by vertex id, and 'edges' list of tuples
        (i, j, x, y, z, qx, qy, qz, qw, precision_components) where
        precision_components is a (21,) array of the upper-triangular
        information-matrix entries.

    Raises:
        ValueError: On a record with the wrong number of fields, duplicate or
            non-contiguous vertex ids, or an edge referencing an unknown vertex.
    """
    poses_by_id: dict[int, tuple[float, ...]] = {}
    edges = []  # (i, j, x, y, z, qx, qy, qz, qw, precision_components)

    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue
            tag = parts[0]

            if tag == "VERTEX_SE3:QUAT":
                # Format: VERTEX_SE3:QUAT id x y z qx qy qz qw.
                if len(parts) != 9:
                    raise ValueError(
                        f"{path}:{line_no}: expected 9 fields for {tag}, got {len(parts)}"
                    )
                idx = int(parts[1])
                if idx in poses_by_id:
                    raise ValueError(f"{path}:{line_no}: duplicate vertex id {idx}")
                poses_by_id[idx] = tuple(float(v) for v in parts[2:9])

            elif tag == "EDGE_SE3:QUAT":
                # Format: EDGE_SE3:QUAT i j x y z qx qy qz qw info_upper_tri(21)
                if len(parts) != 3 + 7 + 21:
                    raise ValueError(
                        f"{path}:{line_no}: expected 31 fields for {tag}, got {len(parts)}"
                    )
                i, j = int(parts[1]), int(parts[2])
                numerical = [float(v) for v in parts[3:]]
                x, y, z = numerical[0:3]
                qx, qy, qz, qw = numerical[3:7]
                precision_comps = np.array(numerical[7:])
                edges.append((i, j, x, y, z, qx, qy, qz, qw, precision_comps))

    num_poses = len(poses_by_id)
    # Downstream code indexes poses by vertex id, so ids must be 0..N-1.
    if sorted(poses_by_id) != list(range(num_poses)):
        raise ValueError(f"{path}: vertex ids must be exactly 0..{num_poses - 1}")
    for i, j, *_ in edges:
        if not (0 <= i < num_poses and 0 <= j < num_poses):
            raise ValueError(f"{path}: edge ({i}, {j}) references an unknown vertex")

    # reshape keeps the (N, 7) shape even when the file has no vertices.
    poses = np.array(
        [poses_by_id[k] for k in range(num_poses)], dtype=float
    ).reshape(num_poses, 7)
    return {"poses": poses, "edges": edges}

In [None]:
# Load the sphere2500 dataset; the path is relative to the working directory.
g2o_path = pathlib.Path("./data/sphere2500.g2o")
data = parse_g2o_se3(g2o_path)

num_poses = len(data["poses"])
num_edges = len(data["edges"])

# Classify edges by vertex ids: consecutive ids (j == i + 1) are odometry,
# everything else is a loop closure.
edge_pairs = [(i, j) for i, j, *_ in data["edges"]]
odometry_edges = [(i, j) for i, j in edge_pairs if j == i + 1]
loop_closure_edges = [(i, j) for i, j in edge_pairs if j != i + 1]

print(f"Poses: {num_poses}")
print(
    f"Edges: {num_edges} ({len(odometry_edges)} odometry, {len(loop_closure_edges)} loop closures)"
)

## Variables and costs

Use {class}`~jaxls.SE3Var` for poses on SE(3). Create batched costs for efficient optimization:

In [None]:
# Create batched pose variables.
# One SE3Var per vertex, with ids 0..num_poses-1, so that they line up with the
# rows of data["poses"] and with the vertex ids used as edge endpoints.
pose_vars = jaxls.SE3Var(id=jnp.arange(num_poses))


@jaxls.Cost.factory
def between_cost(
    vals: jaxls.VarValues,
    var_a: jaxls.SE3Var,
    var_b: jaxls.SE3Var,
    measured: jaxlie.SE3,
    sqrt_precision: jax.Array,
) -> jax.Array:
    """Whitened residual of a relative-pose measurement between two poses.

    The predicted relative transform T_a^{-1} T_b is compared to the
    measurement via log(measured^{-1} T_a^{-1} T_b), which is zero when they
    agree. The result is then weighted by the square-root precision.
    """
    pose_a, pose_b = vals[var_a], vals[var_b]
    predicted_rel = pose_a.inverse() @ pose_b
    discrepancy = measured.inverse() @ predicted_rel
    return sqrt_precision @ discrepancy.log()


@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jaxlie.SE3,
) -> jax.Array:
    """Equality constraint pinning a pose to `target`, removing gauge freedom."""
    offset = vals[var].inverse() @ target
    return offset.log()

In [None]:
# Build edge arrays for batched cost construction.
edge_i = jnp.array([e[0] for e in data["edges"]])
edge_j = jnp.array([e[1] for e in data["edges"]])

# Measured relative poses. g2o stores each quaternion as (qx, qy, qz, qw), i.e.
# xyzw order. Text files round the components, so renormalize to unit length:
# rotating vectors as q v q* and inverting by conjugation are only exact for
# unit quaternions. reshape keeps well-defined shapes when there are no edges.
measured_quat_xyzw = np.array(
    [e[5:9] for e in data["edges"]], dtype=float
).reshape(-1, 4)
measured_quat_xyzw /= np.linalg.norm(measured_quat_xyzw, axis=-1, keepdims=True)
measured_poses = jaxlie.SE3.from_rotation_and_translation(
    rotation=jaxlie.SO3.from_quaternion_xyzw(jnp.array(measured_quat_xyzw)),
    translation=jnp.array(
        np.array([e[2:5] for e in data["edges"]], dtype=float).reshape(-1, 3)
    ),
)

# Sqrt precision matrices (upper Cholesky factors), one per edge.
precision_comps = jnp.array([e[9] for e in data["edges"]])
sqrt_precisions = jax.vmap(parse_precision_matrix)(precision_comps)

print(f"Batched edge arrays: {edge_i.shape[0]} edges")

## Solving

In [None]:
# Initial poses from the g2o file. Rows are (x, y, z, qx, qy, qz, qw), so the
# quaternion is in xyzw order. Renormalize it: text files round the
# components, and rotating vectors as q v q* or inverting by conjugation is
# only exact for unit quaternions. (Not done in place, so data["poses"] is
# left untouched.)
initial_quat_xyzw = np.asarray(data["poses"][:, 3:7], dtype=float)
initial_quat_xyzw = initial_quat_xyzw / np.linalg.norm(
    initial_quat_xyzw, axis=-1, keepdims=True
)
initial_poses = jaxlie.SE3.from_rotation_and_translation(
    rotation=jaxlie.SO3.from_quaternion_xyzw(jnp.array(initial_quat_xyzw)),
    translation=jnp.array(data["poses"][:, 0:3]),  # x, y, z
)

# Create costs using batched construction.
costs: list[jaxls.Cost] = [
    # All between factors in one batched call.
    between_cost(
        jaxls.SE3Var(id=edge_i),
        jaxls.SE3Var(id=edge_j),
        measured_poses,
        sqrt_precisions,
    ),
    # Anchor first pose.
    anchor_cost(
        jaxls.SE3Var(id=0),
        jaxlie.SE3(wxyz_xyz=initial_poses.wxyz_xyz[0]),
    ),
]

initial_vals = jaxls.VarValues.make([pose_vars.with_value(initial_poses)])

# Build and analyze problem.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars]).analyze()

In [None]:
# Solve with Gauss-Newton.
# NOTE(review): trust_region=None is what selects plain (undamped) steps here,
# per the comment above; confirm against the jaxls solve() documentation.
solution = problem.solve(initial_vals, trust_region=None)

## Visualization

Compare the initial trajectory with the optimized result. The poses form a sphere surface with loop closures connecting nearby points:

In [None]:
# Pull the (N, 3) pose positions out of the initial guess and the optimized
# solution as NumPy arrays for plotting.
initial_xyz = np.array(initial_vals[pose_vars].translation())
optimized_xyz = np.array(solution[pose_vars].translation())

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def get_trajectory_trace(
    positions: np.ndarray,
    name: str,
    color: str,
) -> go.Scatter3d:
    """Build a marker-only 3D scatter trace for a set of positions.

    Args:
        positions: Array of shape (N, 3) whose columns are x, y, z
        name: Legend label for the trace
        color: Color applied to every marker

    Returns:
        Plotly Scatter3d trace drawing one small marker per position
    """
    xs, ys, zs = (positions[:, axis] for axis in range(3))
    marker_style = dict(color=color, size=2)
    hover = "(%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>"
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="markers",
        marker=marker_style,
        name=name,
        hovertemplate=hover,
    )


def get_loop_closure_traces(
    positions: np.ndarray,
    edges: list[tuple[int, int]],
    color: str,
    max_edges: int = 300,
) -> list[go.Scatter3d]:
    """Create loop closure edge traces, evenly subsampled for performance.

    Args:
        positions: Position array (N, 3)
        edges: List of (i, j) index pairs for loop closure edges
        color: Line color
        max_edges: Maximum number of edges to display; values <= 0 draw none

    Returns:
        List containing a single Plotly Scatter3d trace for all edges
    """
    # Keep every `step`-th edge. Ceiling division guarantees at most
    # `max_edges` edges survive. Floor division could keep nearly twice as
    # many: 599 edges with max_edges=300 gives step 1 and draws all 599.
    if max_edges <= 0:
        sampled = []
    else:
        step = max(1, -(-len(edges) // max_edges))
        sampled = edges[::step]

    # Draw all segments as one polyline; the None after each segment breaks
    # the line so separate edges are not joined to each other.
    x_coords = []
    y_coords = []
    z_coords = []
    for i, j in sampled:
        x_coords.extend([positions[i, 0], positions[j, 0], None])
        y_coords.extend([positions[i, 1], positions[j, 1], None])
        z_coords.extend([positions[i, 2], positions[j, 2], None])

    return [
        go.Scatter3d(
            x=x_coords,
            y=y_coords,
            z=z_coords,
            mode="lines",
            line=dict(color=color, width=1),
            opacity=0.3,
            name="Loop closures",
            hoverinfo="skip",
        )
    ]

In [None]:
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
    subplot_titles=("Initial", "Optimized"),
    horizontal_spacing=0.02,
)

# Each panel gets its loop-closure lines first, then the trajectory markers:
# column 1 shows the initial guess, column 2 the optimized result.
panels = (
    (1, initial_xyz, "steelblue"),
    (2, optimized_xyz, "forestgreen"),
)
for col, xyz, marker_color in panels:
    panel_traces = [
        *get_loop_closure_traces(xyz, loop_closure_edges, "tomato"),
        get_trajectory_trace(xyz, "Trajectory", marker_color),
    ]
    for trace in panel_traces:
        trace.showlegend = False
        fig.add_trace(trace, row=1, col=col)

# Shared cubic bounds (10% padding) so both panels use the same scale.
all_positions = np.concatenate([initial_xyz, optimized_xyz])
center = all_positions.mean(axis=0)
max_range = np.abs(all_positions - center).max() * 1.1

camera = dict(
    eye=dict(x=1.5, y=1.5, z=1.0),
    up=dict(x=0, y=0, z=1),
)

axis_configs = {
    f"{axis_name}axis": dict(
        range=[mid - max_range, mid + max_range], showbackground=False
    )
    for axis_name, mid in zip("xyz", center)
}
scene_config = dict(aspectmode="cube", **axis_configs, camera=camera)

fig.update_layout(
    scene=scene_config,
    scene2=scene_config,
    height=500,
    showlegend=False,
    margin=dict(t=40, b=20, l=20, r=20),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization refines the noisy initial pose estimates using loop closure constraints. The resulting trajectory forms a clean sphere surface, demonstrating SE(3) pose graph optimization in 3D.

For more on Lie group variables, see {class}`jaxls.SE3Var` and {class}`jaxls.SE2Var`.