# MPC Spacecraft Rendezvous in LEO

This notebook demonstrates **Model Predictive Control (MPC)** for spacecraft
rendezvous in low Earth orbit. A deputy satellite, starting from a randomly
sampled initial offset, approaches a chief located at the origin of the RTN
(Radial-Transverse-Normal) frame.

Key features:

- **Prediction model**: Linear Hill-Clohessy-Wiltshire (HCW) equations,
  discretised via the closed-form state transition matrix
- **Truth propagation**: Nonlinear two-body dynamics via astrojax
  `create_orbit_dynamics` with RK4 integration
- **Cost function**: L1 fuel cost (promotes sparse impulsive burns) plus a
  quadratic terminal penalty that acts as a soft docking constraint
- **Receding horizon**: At each step, solve the MPC, apply only the first
  control input, re-observe, and repeat

In [None]:
import cvxpy as cp
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from astrojax import (
    R_EARTH,
    Epoch,
    mean_motion,
    set_dtype,
    state_koe_to_eci,
    zero_eop,
)
from astrojax.integrators import rk4_step
from astrojax.orbit_dynamics import ForceModelConfig, create_orbit_dynamics
from astrojax.relative_motion import (
    hcw_derivative,
    hcw_stm,
    rotation_rtn_to_eci,
    state_eci_to_rtn,
    state_rtn_to_eci,
)

set_dtype(jnp.float64)


## Configuration

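All tunable parameters are collected here. The chief orbit is a near-circular
(e = 0.001) 500 km LEO with 45° inclination. That is close enough to circular
for the HCW model to be a good approximation. The deputy's initial position is
sampled uniformly from a box in RTN: 500 to 2000 m behind the chief along-track,
within ±500 m radially and within ±300 m cross-track. The initial relative
velocity is zero. Set `SEED` to an integer to make the offset reproducible.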

In [None]:
# --- Reproducibility ---
SEED = None                      # Master RNG seed (int = reproducible offset, None = new offset each run)
rng = np.random.default_rng(SEED)

# --- Chief orbit ---
SMA = R_EARTH + 500e3          # Semi-major axis [m]
ECC = 0.001                     # Eccentricity (near-circular)
INC = 45.0                      # Inclination [deg]
RAAN = 30.0                     # RAAN [deg]
AOP = 60.0                      # Argument of perigee [deg]
MA = 0.0                        # Mean anomaly [deg]

# --- Deputy initial offset in RTN [m, m/s] ---
# Sampled uniformly from the box [lo, hi] for each RTN axis
RTN_BOX = {
    "R": (-500.0, 500.0),       # Radial [m]
    "T": (-2000.0, -500.0),     # Along-track [m]
    "N": (-300.0, 300.0),       # Cross-track [m]
}
rtn_lo = np.array([RTN_BOX["R"][0], RTN_BOX["T"][0], RTN_BOX["N"][0]])
rtn_hi = np.array([RTN_BOX["R"][1], RTN_BOX["T"][1], RTN_BOX["N"][1]])
RTN_OFFSET = np.concatenate([rng.uniform(rtn_lo, rtn_hi), np.zeros(3)])

# --- Simulation ---
DT = 60.0                       # Control / MPC step [s]; RK4 truth uses 10 substeps per DT
N_ORBITS = 2.0                  # Simulation duration [orbits]
HORIZON_ORBITS = 0.25           # MPC prediction horizon [orbits]

# --- MPC ---
MAX_DV = 0.5                    # Maximum delta-v per axis per step [m/s]
TERMINAL_WEIGHT = 1e4           # Quadratic terminal penalty
FUEL_WEIGHT = 1.0               # L1 fuel penalty

# --- Derived ---
N_MEAN = float(mean_motion(SMA))
T_ORBIT = 2 * np.pi / N_MEAN   # Orbital period [s]
N_STEPS = int(np.round(N_ORBITS * T_ORBIT / DT))
HORIZON = int(np.round(HORIZON_ORBITS * T_ORBIT / DT))

print(f"Semi-major axis:  {SMA/1e3:.1f} km")
print(f"Mean motion:      {N_MEAN*1e3:.4f} mrad/s")
print(f"Orbital period:   {T_ORBIT/60:.1f} min")
print(f"Simulation:       {N_ORBITS} orbits = {N_STEPS} steps ({N_STEPS*DT/60:.0f} min)")
print(f"MPC horizon:      {HORIZON_ORBITS} orbits = {HORIZON} steps ({HORIZON*DT/60:.1f} min)")
print(f"RTN offset:       {RTN_OFFSET[:3]} m")
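The derived quantities above follow from Kepler's third law. The sketch below reproduces them with plain NumPy as a sanity check. It assumes that astrojax's `mean_motion` is the Keplerian $n = \sqrt{GM/a^3}$. The values of `GM_ASSUMED` and `R_EARTH_ASSUMED` are assumptions, and `derived_timing` is a hypothetical helper.

```python
import numpy as np

# Assumed constants; astrojax's R_EARTH and GM may differ in the last digits,
# which does not change the rounded step counts.
GM_ASSUMED = 3.986004418e14      # Earth gravitational parameter [m^3/s^2]
R_EARTH_ASSUMED = 6378136.3      # Earth equatorial radius [m]

def derived_timing(altitude, dt, n_orbits, horizon_orbits):
    """Mean motion, period and integer step counts for a given orbit altitude."""
    sma = R_EARTH_ASSUMED + altitude
    n = np.sqrt(GM_ASSUMED / sma**3)          # Keplerian mean motion [rad/s]
    period = 2.0 * np.pi / n                  # orbital period [s]
    n_steps = int(np.round(n_orbits * period / dt))
    horizon = int(np.round(horizon_orbits * period / dt))
    return n, period, n_steps, horizon

n_demo, period_demo, n_steps_demo, horizon_demo = derived_timing(500e3, 60.0, 2.0, 0.25)
print(f"n = {n_demo*1e3:.4f} mrad/s, T = {period_demo/60:.1f} min, "
      f"steps = {n_steps_demo}, horizon = {horizon_demo}")
```

For a 500 km altitude this gives roughly 1.107 mrad/s and a 94.6 min period. With `DT = 60` s, that means 189 simulation steps and a 24-step horizon.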

## Dynamics Setup

We build the initial ECI states for chief and deputy using astrojax
coordinate transformations, then construct the truth dynamics using
`create_orbit_dynamics` with a two-body force model. Each `DT` step is
propagated with ten RK4 substeps to keep the truth-model integration error
small.
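As a rough NumPy stand-in for what the truth propagator does, the sketch below integrates point-mass two-body gravity with classic RK4 substeps. It then checks that specific orbital energy is conserved over one 60 s step. It is not astrojax's implementation. The names `two_body`, `rk4_substeps` and `specific_energy` are hypothetical, and `GM_ASSUMED` is an assumed constant.

```python
import numpy as np

GM_ASSUMED = 3.986004418e14  # Earth gravitational parameter [m^3/s^2] (assumed)

def two_body(state):
    """Time derivative of an ECI state [r, v] under point-mass gravity."""
    r, v = state[:3], state[3:]
    acc = -GM_ASSUMED * r / np.linalg.norm(r) ** 3
    return np.concatenate([v, acc])

def rk4_substeps(state, dt, n_sub=10):
    """Advance the state by dt using n_sub classic RK4 steps of size dt / n_sub."""
    h = dt / n_sub
    for _ in range(n_sub):
        k1 = two_body(state)
        k2 = two_body(state + 0.5 * h * k1)
        k3 = two_body(state + 0.5 * h * k2)
        k4 = two_body(state + h * k3)
        state = state + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return state

def specific_energy(state):
    """Specific orbital energy v^2/2 - GM/r [J/kg]."""
    r, v = state[:3], state[3:]
    return 0.5 * v @ v - GM_ASSUMED / np.linalg.norm(r)

# Circular 500 km equatorial orbit, propagated for one 60 s control step
sma_demo = 6378136.3 + 500e3
x_start = np.array([sma_demo, 0.0, 0.0, 0.0, np.sqrt(GM_ASSUMED / sma_demo), 0.0])
x_end = rk4_substeps(x_start, 60.0)
rel_energy_drift = abs(specific_energy(x_end) / specific_energy(x_start) - 1.0)
radius_change = abs(np.linalg.norm(x_end[:3]) - sma_demo)
print(f"relative energy drift: {rel_energy_drift:.1e}, radius change: {radius_change:.1e} m")
```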

In [None]:
# Chief OE -> ECI
chief_oe = jnp.array([SMA, ECC, INC, RAAN, AOP, MA])
chief_eci = state_koe_to_eci(chief_oe, use_degrees=True)

# Deputy: chief + RTN offset -> ECI
rtn_offset = jnp.array(RTN_OFFSET)
deputy_eci = state_rtn_to_eci(chief_eci, rtn_offset)

# Verify the RTN state
true_rtn_0 = state_eci_to_rtn(chief_eci, deputy_eci)
print(f"Initial RTN position: {true_rtn_0[:3]} m")
print(f"Initial RTN velocity: {true_rtn_0[3:]} m/s")
print(f"Initial range:        {np.linalg.norm(np.array(true_rtn_0[:3])):.1f} m")

# Build truth dynamics (two-body)
epoch_0 = Epoch(2024, 6, 15, 12, 0, 0)
dynamics = create_orbit_dynamics(zero_eop(), epoch_0, ForceModelConfig.two_body())

# JIT-compile a single propagation step (10 RK4 substeps per DT for accuracy)
n_substeps = 10
sub_dt = DT / n_substeps

@jax.jit
def propagate_one_step(t, state):
    """Propagate a single ECI state forward by DT using RK4 substeps."""
    def body(carry, _):
        t_i, s = carry
        result = rk4_step(dynamics, t_i, s, sub_dt)
        return (t_i + sub_dt, result.state), None
    (t_out, state_out), _ = jax.lax.scan(body, (t, state), None, length=n_substeps)
    return t_out, state_out

# Verify a single zero-control step
t_test = 0.0
chief_test, deputy_test = chief_eci, deputy_eci
_, chief_test = propagate_one_step(t_test, chief_test)
_, deputy_test = propagate_one_step(t_test, deputy_test)
rtn_after = state_eci_to_rtn(chief_test, deputy_test)
print("\nAfter 1 zero-control step:")
print(f"  Range: {float(jnp.linalg.norm(rtn_after[:3])):.1f} m")


## HCW State Transition Matrix

The MPC prediction model uses the **closed-form** Hill-Clohessy-Wiltshire
state transition matrix (Vallado, Sec. 6.8). For a circular chief orbit
with mean motion $n$ and timestep $\Delta t$, the 6x6 STM $\Phi(\Delta t)$
maps the RTN state forward in time analytically.

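An impulsive $\Delta\mathbf{v}_i$ applied at time $\tau_i$ is modelled by
adding it to the velocity components before propagation, so the state at
time $t$ is

$$\mathbf{x}(t) = \Phi(t, t_0)\,\mathbf{x}_0 + \sum_{i:\,\tau_i \le t} \Phi(t, \tau_i)\,B\,\Delta\mathbf{v}_i$$

where $B$ maps a 3-vector $\Delta\mathbf{v}$ into the 6D state (velocity rows
only). Because the HCW system is time-invariant, $\Phi(t, \tau) = \Phi(t - \tau)$.
A single matrix $\Phi(\Delta t)$ therefore serves every step of a uniform time grid.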
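For reference, the closed-form STM can be written out explicitly. This NumPy sketch makes three assumptions:

- the state ordering is $[R, T, N, \dot R, \dot T, \dot N]$;
- $R$ points radially outward and $T$ along the chief's velocity;
- the HCW equations are $\ddot R = 3n^2R + 2n\dot T$, $\ddot T = -2n\dot R$ and $\ddot N = -n^2N$.

Whether astrojax's `hcw_stm` uses the same convention is an assumption; comparing against `A_d` would confirm it. The names `hcw_stm_np`, `hcw_rhs` and `propagate_hcw_rk4` are hypothetical.

```python
import numpy as np

def hcw_stm_np(t, n):
    """Closed-form HCW state transition matrix for x = [R, T, N, vR, vT, vN]."""
    c, s = np.cos(n * t), np.sin(n * t)
    return np.array([
        [4 - 3 * c,        0, 0,      s / n,            2 * (1 - c) / n,         0    ],
        [6 * (s - n * t),  1, 0,     -2 * (1 - c) / n,  (4 * s - 3 * n * t) / n, 0    ],
        [0,                0, c,      0,                0,                       s / n],
        [3 * n * s,        0, 0,      c,                2 * s,                   0    ],
        [-6 * n * (1 - c), 0, 0,     -2 * s,            4 * c - 3,               0    ],
        [0,                0, -n * s, 0,                0,                       c    ],
    ])

def hcw_rhs(x, n):
    """HCW equations in RTN: R'' = 3n^2 R + 2n T', T'' = -2n R', N'' = -n^2 N."""
    r, _, nc, vr, vt, vn = x
    return np.array([vr, vt, vn,
                     3 * n**2 * r + 2 * n * vt,
                     -2 * n * vr,
                     -(n**2) * nc])

def propagate_hcw_rk4(x, n, dt, n_sub=100):
    """Reference solution: integrate the HCW ODE with n_sub classic RK4 steps."""
    h = dt / n_sub
    for _ in range(n_sub):
        k1 = hcw_rhs(x, n)
        k2 = hcw_rhs(x + 0.5 * h * k1, n)
        k3 = hcw_rhs(x + 0.5 * h * k2, n)
        k4 = hcw_rhs(x + h * k3, n)
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return x

n_hcw = 1.1068e-3                      # roughly the 500 km mean motion [rad/s]
x_hcw = np.array([200.0, -1000.0, 100.0, 0.1, -0.05, 0.02])
stm_err = np.max(np.abs(hcw_stm_np(60.0, n_hcw) @ x_hcw
                        - propagate_hcw_rk4(x_hcw, n_hcw, 60.0)))
print(f"max |STM - RK4| over one step: {stm_err:.1e}")
```

The same two properties that make the STM useful in the MPC can be checked directly: $\Phi(0) = I$, and $\Phi(2\Delta t) = \Phi(\Delta t)\,\Phi(\Delta t)$.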

In [None]:
# Built-in hcw_stm signature: hcw_stm(t, n) -> 6x6 JAX array
A_d = np.array(hcw_stm(DT, N_MEAN))

# Input matrix: DeltaV in RTN maps to velocity components
B_d = np.zeros((6, 3))
B_d[3:6, :] = np.eye(3)

# print("HCW STM (Phi):")
# print(np.array2string(A_d, precision=4, suppress_small=True))

# --- Verification: compare STM vs RK4-integrated HCW ---
x0_test = jnp.array([200.0, -1000.0, 100.0, 0.0, 0.0, 0.0])

# STM prediction
x_stm = A_d @ np.array(x0_test)

# RK4 integration (10 substeps for accuracy)
sub_dt_test = DT / 10.0
x_rk4 = x0_test
for _ in range(10):
    x_rk4 = rk4_step(lambda _t, s: hcw_derivative(s, N_MEAN), 0.0, x_rk4, sub_dt_test).state

diff = np.abs(np.array(x_stm) - np.array(x_rk4))
print("\nSTM vs RK4 difference:")
print(f"  Position: {diff[:3]} m")
print(f"  Velocity: {diff[3:]} m/s")
print(f"  Max pos error: {diff[:3].max():.2e} m")

## MPC Formulation

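The MPC solves a convex optimisation problem at each step:

$$\min_{\mathbf{u}} \; w_f \sum_{k=0}^{H-1} \|\mathbf{u}_k\|_1 + w_t \|\mathbf{x}_H\|_2^2$$

subject to:
- Initial condition: $\mathbf{x}_0$ equals the current measured RTN state
- HCW dynamics: $\mathbf{x}_{k+1} = \Phi\,(\mathbf{x}_k + B\,\mathbf{u}_k)$
- Box constraints: $\|\mathbf{u}_k\|_\infty \leq \Delta v_{\max}$

The impulse $\mathbf{u}_k = \Delta\mathbf{v}_k$ is added to the state at time $t_k$.
The combined state is then propagated by the STM $\Phi = \Phi(t_{k+1}, t_k)$,
so each step reproduces $\Phi(t_{k+1}, t_k)(\mathbf{x}_k + B\,\Delta\mathbf{v}_k)$.

The **L1 fuel cost** promotes sparse control (fewer, larger burns), which
is physically realistic for impulsive thrusters. Summing absolute values per
axis corresponds to fuel use for independent thruster pairs along each RTN
axis.

The **quadratic terminal penalty** acts as a soft docking constraint. It drives
the terminal state toward zero without hard equality constraints. Such
constraints would make the problem infeasible whenever the bounded
$\Delta\mathbf{v}$ cannot reach the origin within $H$ steps.

Because $\|\mathbf{x}_H\|_2^2$ adds position and velocity errors with equal
weight, a 1 m position error costs as much as a 1 m/s velocity error. Residual
velocity is therefore only weakly penalised.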
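The recursion $\mathbf{x}_{k+1} = \Phi(\mathbf{x}_k + B\mathbf{u}_k)$ can be unrolled into a single linear map from the control sequence to the terminal state:

$$\mathbf{x}_H = \Phi^H\mathbf{x}_0 + \sum_{k=0}^{H-1}\Phi^{H-k}B\,\mathbf{u}_k$$

The sketch below builds this "condensed" map in NumPy and checks it against the step-by-step rollout. It uses a random near-identity matrix as a stand-in for the HCW STM; the identity holds for any $\Phi$. The helper names are hypothetical.

```python
import numpy as np

def condensed_terminal_map(Phi, B, H):
    """Return (Phi^H, Gamma) such that x_H = Phi^H @ x0 + Gamma @ u.ravel().

    u has shape (H, 3) and is flattened row-major as [u_0, u_1, ..., u_{H-1}].
    """
    nx, nu = B.shape
    Gamma = np.zeros((nx, nu * H))
    for k in range(H):
        # u_k is added at t_k, then propagated over the remaining H - k steps
        Gamma[:, k * nu:(k + 1) * nu] = np.linalg.matrix_power(Phi, H - k) @ B
    return np.linalg.matrix_power(Phi, H), Gamma

def rollout_terminal(Phi, B, x0, u):
    """Step-by-step recursion x_{k+1} = Phi @ (x_k + B @ u_k); returns x_H."""
    x = x0.copy()
    for u_k in u:
        x = Phi @ (x + B @ u_k)
    return x

# Demo with a random near-identity stand-in for the HCW STM
demo_rng = np.random.default_rng(0)
H_demo = 24
Phi_demo = np.eye(6) + 0.01 * demo_rng.standard_normal((6, 6))
B_demo = np.vstack([np.zeros((3, 3)), np.eye(3)])
x0_demo = demo_rng.standard_normal(6)
u_demo = 0.1 * demo_rng.standard_normal((H_demo, 3))

Phi_H, Gamma = condensed_terminal_map(Phi_demo, B_demo, H_demo)
xH_condensed = Phi_H @ x0_demo + Gamma @ u_demo.ravel()
xH_rollout = rollout_terminal(Phi_demo, B_demo, x0_demo, u_demo)
print(np.max(np.abs(xH_condensed - xH_rollout)))  # agreement to round-off
```

With this map the terminal penalty becomes $w_t\|\Phi^H\mathbf{x}_0 + \Gamma\mathbf{u}\|_2^2$, a convex quadratic in the stacked controls alone. The state variables could then be dropped from the cvxpy problem. Whether that is faster than the sparse form used below depends on the solver.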

In [None]:
def solve_mpc(
    x0: np.ndarray,
    A_d: np.ndarray,
    B_d: np.ndarray,
    horizon: int,
    max_dv: float,
    fuel_weight: float,
    terminal_weight: float,
) -> np.ndarray:
    """Solve the MPC optimisation for one step.

    Args:
        x0: Current 6D RTN state [m, m/s].
        A_d: 6x6 HCW state transition matrix.
        B_d: 6x3 control input matrix.
        horizon: Prediction horizon (number of steps).
        max_dv: Maximum delta-v per axis [m/s].
        fuel_weight: Weight on L1 fuel cost.
        terminal_weight: Weight on quadratic terminal penalty.

    Returns:
        First optimal control input [dv_R, dv_T, dv_N] in m/s.
    """
    H = horizon
    x = cp.Variable((H + 1, 6))
    u = cp.Variable((H, 3))

    cost = 0.0
    constraints = [x[0] == x0]

    for k in range(H):
        # Dynamics: apply impulse then propagate via STM
        constraints.append(x[k + 1] == A_d @ (x[k] + B_d @ u[k]))
        # Box constraint on control
        constraints.append(cp.norm_inf(u[k]) <= max_dv)
        # L1 fuel cost
        cost += fuel_weight * cp.norm1(u[k])

    # Quadratic terminal penalty
    cost += terminal_weight * cp.sum_squares(x[H])

    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve(solver=cp.CLARABEL, verbose=False)

    if prob.status not in ("optimal", "optimal_inaccurate"):
        print(f"  Warning: solver status = {prob.status}")
        return np.zeros(3)

    return np.array(u.value[0])


# Test the MPC with the initial state
x0_np = np.array(true_rtn_0)
u_test = solve_mpc(x0_np, A_d, B_d, HORIZON, MAX_DV, FUEL_WEIGHT, TERMINAL_WEIGHT)
print(f"Test MPC first control: {u_test} m/s")
print(f"Test MPC |u|_1: {np.sum(np.abs(u_test)):.4f} m/s")

## MPC Simulation Loop

At each timestep we:
1. Compute the relative RTN state from the chief and deputy ECI states
2. Solve the MPC to get the optimal control sequence
3. Apply the first delta-v impulse (rotated from RTN to ECI) to the deputy
4. Propagate both chief and deputy with nonlinear two-body dynamics
5. Record the trajectory and control history

The MPC uses the linear HCW model for prediction, while the truth
propagation uses nonlinear two-body dynamics. Because every step re-plans
from the measured state, receding horizon control limits the effect of this
model mismatch (linearisation error and the chief's small eccentricity).
Prediction errors are corrected at each step rather than accumulating over
the whole horizon.
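Step 3 needs only a rotation, not the full RTN-to-ECI state transform. An impulse is the difference of two inertial velocities at the same position, so the $\boldsymbol{\omega}\times\boldsymbol{\rho}$ term of the velocity transform cancels. The NumPy sketch below builds the rotation from the chief's ECI state and applies an RTN impulse.

Two assumptions apply:

- the RTN axes are $\hat{R} = \mathbf{r}/|\mathbf{r}|$, $\hat{N} \parallel \mathbf{r}\times\mathbf{v}$ and $\hat{T} = \hat{N}\times\hat{R}$;
- astrojax's `rotation_rtn_to_eci` returns the matrix with these columns.

The helper names and the chief state are hypothetical.

```python
import numpy as np

def rtn_to_eci_matrix(r_chief, v_chief):
    """3x3 matrix whose columns are the chief's R, T, N unit vectors in ECI."""
    r_hat = r_chief / np.linalg.norm(r_chief)
    h = np.cross(r_chief, v_chief)          # orbital angular momentum direction
    n_hat = h / np.linalg.norm(h)
    t_hat = np.cross(n_hat, r_hat)          # completes the right-handed triad
    return np.column_stack([r_hat, t_hat, n_hat])

def apply_rtn_impulse(deputy_eci, chief_eci, dv_rtn):
    """Add an RTN delta-v to the deputy's ECI velocity (position unchanged)."""
    M = rtn_to_eci_matrix(chief_eci[:3], chief_eci[3:])
    out = deputy_eci.copy()
    out[3:] += M @ dv_rtn
    return out

# Hypothetical chief state [m, m/s] and a deputy displaced in position only
chief_demo = np.array([7.0e6, 1.0e5, 2.0e5, -100.0, 5000.0, 5000.0])
deputy_demo = chief_demo + np.array([10.0, -500.0, 3.0, 0.0, 0.0, 0.0])
M_demo = rtn_to_eci_matrix(chief_demo[:3], chief_demo[3:])
deputy_burned = apply_rtn_impulse(deputy_demo, chief_demo, np.array([0.2, 0.0, 0.0]))
```

A purely radial burn changes the ECI velocity along the chief's position vector and leaves the deputy's position untouched. The simulation loop performs the same rotation with astrojax's `rotation_rtn_to_eci`.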

In [None]:
# Storage
rtn_history = np.zeros((N_STEPS + 1, 6))
dv_history = np.zeros((N_STEPS, 3))
rtn_history[0] = np.array(true_rtn_0)

# Reset ECI states for simulation
chief_sim = chief_eci
deputy_sim = deputy_eci
t = 0.0

for k in range(N_STEPS):
    # Current true RTN state
    x_rtn = np.array(state_eci_to_rtn(chief_sim, deputy_sim))

    # Solve MPC
    u_opt = solve_mpc(x_rtn, A_d, B_d, HORIZON, MAX_DV, FUEL_WEIGHT, TERMINAL_WEIGHT)

    # Apply delta-V in RTN -> convert to ECI
    rot = rotation_rtn_to_eci(chief_sim)  # RTN -> ECI rotation (avoids clashing with radial R used in plots)
    dv_eci = rot @ jnp.array(u_opt)
    deputy_sim = deputy_sim.at[3:6].add(dv_eci)

    # Propagate both from the same time
    t_k = t
    t, chief_sim = propagate_one_step(t_k, chief_sim)
    _, deputy_sim = propagate_one_step(t_k, deputy_sim)

    rtn_history[k + 1] = np.array(state_eci_to_rtn(chief_sim, deputy_sim))
    dv_history[k] = u_opt

    if (k + 1) % 20 == 0 or k == 0:
        rng_val = np.linalg.norm(rtn_history[k + 1, :3])
        dv_k = np.sum(np.abs(u_opt))
        print(f"Step {k+1:4d}/{N_STEPS}: range = {rng_val:8.1f} m, |dv| = {dv_k:.4f} m/s")

# Summary
total_dv_l1 = np.sum(np.abs(dv_history))
total_dv_l2 = np.sum(np.sqrt(np.sum(dv_history**2, axis=1)))
final_range = np.linalg.norm(rtn_history[-1, :3])
final_vel = np.linalg.norm(rtn_history[-1, 3:])

print("\n--- Summary ---")
print(f"Total delta-v (L1): {total_dv_l1:.4f} m/s")
print(f"Total delta-v (L2): {total_dv_l2:.4f} m/s")
print(f"Final range:        {final_range:.2f} m")
print(f"Final velocity:     {final_vel:.4f} m/s")

## 3D RTN Trajectory

In [None]:
T = rtn_history[:, 1]  # Along-track
R = rtn_history[:, 0]  # Radial
N = rtn_history[:, 2]  # Cross-track

fig = go.Figure()

# Trajectory
fig.add_trace(go.Scatter3d(
    x=T, y=N, z=R, mode="lines",
    line=dict(color="royalblue", width=3),
    name="Deputy trajectory",
))

# Start marker
fig.add_trace(go.Scatter3d(
    x=[T[0]], y=[N[0]], z=[R[0]], mode="markers",
    marker=dict(color="green", size=6, symbol="circle"),
    name="Start",
))

# Chief at origin (red diamond)
fig.add_trace(go.Scatter3d(
    x=[0], y=[0], z=[0], mode="markers",
    marker=dict(color="red", size=8, symbol="diamond"),
    name="Chief",
))

fig.update_layout(
    title="MPC Rendezvous Trajectory (RTN)",
    scene=dict(
        xaxis_title="Along-track T [m]",
        yaxis_title="Cross-track N [m]",
        zaxis_title="Radial R [m]",
        aspectmode="data",
    ),
    width=750, height=650,
    legend=dict(font=dict(size=11)),
)
fig.show()

In [None]:
T = rtn_history[:, 1]
R = rtn_history[:, 0]
N = rtn_history[:, 2]

fig = go.Figure()

# Trajectory
fig.add_trace(go.Scatter3d(
    x=T, y=N, z=R, mode="lines",
    line=dict(color="royalblue", width=3),
    name="Deputy trajectory",
))

# Start marker
fig.add_trace(go.Scatter3d(
    x=[T[0]], y=[N[0]], z=[R[0]], mode="markers",
    marker=dict(color="green", size=6, symbol="circle"),
    name="Start",
))

# Chief at origin (red diamond)
fig.add_trace(go.Scatter3d(
    x=[0], y=[0], z=[0], mode="markers",
    marker=dict(color="red", size=8, symbol="diamond"),
    name="Chief",
))

# Thrust lines — magnitude-scaled, coloured by dominant RTN component
dv_norms = np.linalg.norm(dv_history, axis=1)
active_idx = np.where(dv_norms > 1e-6)[0]
dv_active = dv_history[active_idx]
norms_active = dv_norms[active_idx]

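pos_extent = max(np.ptp(T), np.ptp(R), np.ptp(N), 1.0)
max_line = 0.12 * pos_extent  # longest thrust line = 12% of trajectory extent
# Guard against a run with no burns (max() of an empty array raises)
scale = max_line / norms_active.max() if norms_active.size else 0.0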

# Dominant axis: 0=R, 1=T, 2=N
dominant = np.argmax(np.abs(dv_active), axis=1)
color_map = {0: "crimson", 1: "darkorange", 2: "mediumseagreen"}
label_map = {0: "dv (R-dom)", 1: "dv (T-dom)", 2: "dv (N-dom)"}

# Build one trace per dominant-component group (for legend)
for axis_idx in range(3):
    mask = dominant == axis_idx
    if not mask.any():
        continue
    xs, ys, zs = [], [], []
    for i in np.where(mask)[0]:
        k = active_idx[i]
        dv = dv_active[i] * scale  # [R, T, N]
        xs += [T[k], T[k] + dv[1], None]
        ys += [N[k], N[k] + dv[2], None]
        zs += [R[k], R[k] + dv[0], None]
    fig.add_trace(go.Scatter3d(
        x=xs, y=ys, z=zs, mode="lines",
        line=dict(color=color_map[axis_idx], width=4),
        name=label_map[axis_idx],
    ))

fig.update_layout(
    title="MPC Rendezvous Trajectory with Thrust Vectors",
    scene=dict(
        xaxis_title="Along-track T [m]",
        yaxis_title="Cross-track N [m]",
        zaxis_title="Radial R [m]",
        aspectmode="data",
    ),
    width=750, height=650,
    legend=dict(font=dict(size=11)),
)
fig.show()

## Time Series

In [None]:
t_min = np.arange(N_STEPS + 1) * DT / 60.0
t_dv = np.arange(N_STEPS) * DT / 60.0

fig = make_subplots(
    rows=4, cols=1, shared_xaxes=True, vertical_spacing=0.04,
    subplot_titles=("Approach Profile", "RTN Position", "RTN Velocity", "Delta-v"),
)

# --- Range ---
rng_vals = np.linalg.norm(rtn_history[:, :3], axis=1)
fig.add_trace(go.Scatter(x=t_min, y=rng_vals, mode="lines", name="Range",
                         line=dict(width=1.5)), row=1, col=1)

# --- RTN position ---
for i, lbl in enumerate(["R", "T", "N"]):
    fig.add_trace(go.Scatter(x=t_min, y=rtn_history[:, i], mode="lines",
                             name=lbl, line=dict(width=1.5)), row=2, col=1)

# --- RTN velocity ---
for i, lbl in enumerate(["dR/dt", "dT/dt", "dN/dt"]):
    fig.add_trace(go.Scatter(x=t_min, y=rtn_history[:, 3 + i], mode="lines",
                             name=lbl, line=dict(width=1.5)), row=3, col=1)

# --- Delta-v stems (vertical lines from zero) ---
colors_dv = ["steelblue", "darkorange", "seagreen"]
labels_dv = ["dvR", "dvT", "dvN"]
for i in range(3):
    # Build stem segments: vertical line per timestep with None separators
    xs, ys = [], []
    for k in range(N_STEPS):
        if abs(dv_history[k, i]) > 1e-8:
            xs += [t_dv[k], t_dv[k], None]
            ys += [0.0, dv_history[k, i], None]
    fig.add_trace(go.Scatter(x=xs, y=ys, mode="lines",
                             name=labels_dv[i],
                             line=dict(color=colors_dv[i], width=1.5)), row=4, col=1)

# Axis labels
fig.update_yaxes(title_text="Range [m]", row=1, col=1)
fig.update_yaxes(title_text="Position [m]", row=2, col=1)
fig.update_yaxes(title_text="Velocity [m/s]", row=3, col=1)
fig.update_yaxes(title_text="\u0394v [m/s]", row=4, col=1)
fig.update_xaxes(title_text="Time [min]", row=4, col=1)

fig.update_layout(height=900, width=750, showlegend=True,
                  legend=dict(font=dict(size=10)))
fig.show()

In [None]:
# --- Summary statistics ---
initial_range = np.linalg.norm(rtn_history[0, :3])
final_range = np.linalg.norm(rtn_history[-1, :3])
total_dv_l1 = np.sum(np.abs(dv_history))
total_dv_l2 = np.sum(np.sqrt(np.sum(dv_history**2, axis=1)))
active_burns = np.sum(np.any(np.abs(dv_history) > 1e-6, axis=1))

print("=" * 45)
print("  MPC Rendezvous Summary")
print("=" * 45)
print(f"  Initial range:       {initial_range:10.2f} m")
print(f"  Final range:         {final_range:10.2f} m")
print(f"  Final velocity:      {np.linalg.norm(rtn_history[-1, 3:]):10.4f} m/s")
print(f"  Total delta-v (L1):  {total_dv_l1:10.4f} m/s")
print(f"  Total delta-v (L2):  {total_dv_l2:10.4f} m/s")
print(f"  Active burn steps:   {active_burns:10d} / {N_STEPS}")
print(f"  Sim duration:        {N_STEPS*DT/60:10.0f} min")
print("=" * 45)
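The two delta-v totals measure different things. The L1 total is what the MPC minimises and corresponds to per-axis thrusters. The L2 total is the summed impulse magnitude a single steerable thruster would need. For any 3-vector, $\|\mathbf{v}\|_2 \le \|\mathbf{v}\|_1 \le \sqrt{3}\,\|\mathbf{v}\|_2$, so the L1 figure is always the larger of the two. A small worked example with a hypothetical two-burn history:

```python
import numpy as np

def dv_totals(dv_history):
    """Return (L1 total, L2 total) for an (N, 3) array of RTN impulses [m/s]."""
    l1 = np.sum(np.abs(dv_history))                   # per-axis thruster fuel
    l2 = np.sum(np.linalg.norm(dv_history, axis=1))   # single-thruster magnitude
    return l1, l2

dv_demo = np.array([[0.3, 0.4, 0.0],
                    [0.0, 0.0, -0.1]])
l1_demo, l2_demo = dv_totals(dv_demo)
print(f"L1 = {l1_demo:.2f} m/s, L2 = {l2_demo:.2f} m/s")  # L1 = 0.80 m/s, L2 = 0.60 m/s
```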