Tests a set of bare-bones, as-simple-as-possible ODEs and a ODE next-step prediction scheme that will be used as a measurement map (map from state vector to predicted state) in an unscented Kalman filter.

The goals of this notebook are:

- Describe the time-step equations
- Implement prediction of a state vector at a next time-step given the current state vector


In [20]:
import numpy as np
from scipy.spatial.transform import Rotation
from scipy.integrate import solve_ivp
import torch
import plotly.graph_objects as go
from ipywidgets import interact, IntSlider
from IPython.display import display

## Dynamics Implementation


In a Kalman filter, the value of a state vector $\hat x_{k+1|k}$ at $k + 1$ is predicted given all information known by a system up to time-step $k$. `rk45_step` uses scipy's adaptive RK45 solver and takes $f$ (specific force) and $\omega$ (angular velocity) as if derived from such a $\hat x_{k + 1 | k}$.

An intermediate objective of a Kalman filter is to predict a set of next sensor readings $\hat z_{k + 1 | k}$ as a function of $\hat x_{k + 1 | k}$. The error between the true readings $z_{k + 1}$ and this prediction is:

$$
z_{k+1} - \hat z_{k + 1 | k}
$$

and is called the innovation, used in downstream calculations.


In [21]:
def quat_multiply(q1, q2):
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2

    r1 = Rotation.from_quat([w1, x1, y1, z1], scalar_first=True)
    r2 = Rotation.from_quat([w2, x2, y2, z2], scalar_first=True)

    r_result = r1 * r2
    quat_xyzw = r_result.as_quat(scalar_first=True)

    return np.array([quat_xyzw[0], quat_xyzw[1], quat_xyzw[2], quat_xyzw[3]])


def rotate_vector_by_quat(v, q):
    w, x, y, z = q

    rot = Rotation.from_quat([w, x, y, z], scalar_first=True)

    return rot.apply(v)


def quat_conjugate(q):
    w, x, y, z = q
    return np.array([w, -x, -y, -z])


def dynamics(t, state, f, g, omega):
    p = state[0:3]
    v = state[3:6]
    q = state[6:10]

    f_rotated = rotate_vector_by_quat(f, q)

    omega_quat = np.array([0.0, omega[0], omega[1], omega[2]])

    p_dot = v
    v_dot = f_rotated + g
    q_dot = 0.5 * quat_multiply(q, omega_quat)

    return np.concatenate([p_dot, v_dot, q_dot])


def rk45_step(state, dt, f, g, omega):
    t0 = 0.0
    t_end = t0 + dt

    sol = solve_ivp(
        fun=lambda t, y: dynamics(t, y, f, g, omega),
        t_span=(t0, t_end),
        y0=state,
        method="RK45",
        rtol=1e-9,
        atol=1e-12,
    )

    state_new = sol.y[:, -1]

    q_new = state_new[6:10]
    q_new = q_new / np.linalg.norm(q_new)
    state_new[6:10] = q_new

    return state_new


## Mock Trajectory Data

These functions produce mock trajectory data. `get_world_frame_truth` returns world-frame forces and angular velocities that can be consistently transformed to body frame using any orientation.


In [22]:
def trajectory_r(t):
    x = 5.0 * torch.sin(0.5 * t)
    y = 3.0 * torch.cos(0.5 * t)
    z = 0.5 * t
    return torch.stack([x, y, z])


def trajectory_theta(t):
    roll = 0.1 * torch.sin(0.3 * t)
    pitch = 0.2 * torch.cos(0.4 * t)
    yaw = 0.5 * t
    return torch.stack([roll, pitch, yaw])


def axis_angle_to_quat(axis_angle):
    angle = torch.norm(axis_angle)
    if angle < 1e-8:
        return np.array([1.0, 0.0, 0.0, 0.0])
    axis = axis_angle / angle
    half_angle = angle / 2.0
    w = torch.cos(half_angle)
    xyz = axis * torch.sin(half_angle)
    quat = torch.cat([w.unsqueeze(0), xyz])
    return quat.detach().numpy()


def euler_to_quat(euler):
    roll, pitch, yaw = euler[0].item(), euler[1].item(), euler[2].item()
    rot = Rotation.from_euler("xyz", [roll, pitch, yaw])
    quat_xyzw = rot.as_quat()
    return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])


def get_world_frame_truth(t_val, g=np.array([0.0, 0.0, -9.81])):
    """
    Get world-frame truth data: acceleration (minus gravity) and angular velocity.
    These can then be transformed to body frame using any orientation.
    """
    dt_small = 1e-6

    t = torch.tensor(t_val, dtype=torch.float64)
    t_forward = torch.tensor(t_val + dt_small, dtype=torch.float64)
    t_backward = torch.tensor(t_val - dt_small, dtype=torch.float64)

    r = trajectory_r(t).detach().numpy()
    r_forward = trajectory_r(t_forward).detach().numpy()
    r_backward = trajectory_r(t_backward).detach().numpy()

    v = (r_forward - r_backward) / (2 * dt_small)
    a = (r_forward - 2 * r + r_backward) / (dt_small**2)

    theta = trajectory_theta(t)
    q_true = euler_to_quat(theta.detach())

    theta_forward = trajectory_theta(t_forward)
    q_forward = euler_to_quat(theta_forward.detach())

    q_dot_numerical = (q_forward - q_true) / dt_small

    q_conj = quat_conjugate(q_true)
    omega_quat = 2.0 * quat_multiply(q_conj, q_dot_numerical)
    omega_body = omega_quat[1:4]

    f_world = a - g

    return f_world, omega_body, r, v, a, q_true


In [23]:
t_start = 0.0
t_end = 30.0
dt = 0.0025
g = np.array([0.0, 0.0, -9.81])

time_steps = np.arange(t_start, t_end, dt)
n_steps = len(time_steps)

true_positions = np.zeros((n_steps, 3))
true_velocities = np.zeros((n_steps, 3))
true_accelerations = np.zeros((n_steps, 3))
true_quaternions = np.zeros((n_steps, 4))

pred_positions = np.zeros((n_steps, 3))
pred_velocities = np.zeros((n_steps, 3))
pred_accelerations = np.zeros((n_steps, 3))
pred_quaternions = np.zeros((n_steps, 4))

world_forces = np.zeros((n_steps, 3))
body_omegas = np.zeros((n_steps, 3))

print("Precomputing trajectory data...")
for i, t in enumerate(time_steps):
    f_world, omega_body, r, v, a, q = get_world_frame_truth(t, g)

    true_positions[i] = r
    true_velocities[i] = v
    true_accelerations[i] = a
    true_quaternions[i] = q

    world_forces[i] = f_world
    body_omegas[i] = omega_body

    if i == 0:
        pred_positions[i] = r
        pred_velocities[i] = v
        pred_accelerations[i] = a
        pred_quaternions[i] = q
    else:
        state = np.concatenate(
            [
                pred_positions[i - 1],
                pred_velocities[i - 1],
                pred_quaternions[i - 1],
            ]
        )

        f_body_pred = rotate_vector_by_quat(
            world_forces[i - 1], quat_conjugate(pred_quaternions[i - 1])
        )

        state_new = rk45_step(state, dt, f_body_pred, g, body_omegas[i - 1])

        pred_positions[i] = state_new[0:3]
        pred_velocities[i] = state_new[3:6]
        pred_quaternions[i] = state_new[6:10]

        f_rotated = rotate_vector_by_quat(f_body_pred, pred_quaternions[i - 1])
        pred_accelerations[i] = f_rotated + g

print(f"Computed {n_steps} time steps from t={t_start} to t={t_end}")


Precomputing trajectory data...
Computed 12000 time steps from t=0.0 to t=30.0


In [24]:
def quat_to_euler(q):
    w, x, y, z = q
    rot = Rotation.from_quat([x, y, z, w])
    return rot.as_euler("xyz")


def plot_trajectory(step_idx):
    step_idx = int(step_idx)

    all_positions = np.vstack([true_positions, pred_positions])
    x_range = [all_positions[:, 0].min() - 0.5, all_positions[:, 0].max() + 0.5]
    y_range = [all_positions[:, 1].min() - 0.5, all_positions[:, 1].max() + 0.5]
    z_range = [all_positions[:, 2].min() - 0.5, all_positions[:, 2].max() + 0.5]

    fig = go.Figure()

    fig.add_trace(
        go.Scatter3d(
            x=true_positions[: step_idx + 1, 0],
            y=true_positions[: step_idx + 1, 1],
            z=true_positions[: step_idx + 1, 2],
            mode="lines",
            name="True Trajectory",
            line=dict(color="blue", width=4),
        )
    )

    fig.add_trace(
        go.Scatter3d(
            x=pred_positions[: step_idx + 1, 0],
            y=pred_positions[: step_idx + 1, 1],
            z=pred_positions[: step_idx + 1, 2],
            mode="lines",
            name="RK45 Predicted",
            line=dict(color="red", width=4, dash="dash"),
        )
    )

    fig.add_trace(
        go.Scatter3d(
            x=[true_positions[step_idx, 0]],
            y=[true_positions[step_idx, 1]],
            z=[true_positions[step_idx, 2]],
            mode="markers",
            name="True Current",
            marker=dict(size=8, color="blue", symbol="circle"),
        )
    )

    fig.add_trace(
        go.Scatter3d(
            x=[pred_positions[step_idx, 0]],
            y=[pred_positions[step_idx, 1]],
            z=[pred_positions[step_idx, 2]],
            mode="markers",
            name="Pred Current",
            marker=dict(size=8, color="red", symbol="x"),
        )
    )

    fig.update_layout(
        title=f"3D Trajectory (t={time_steps[step_idx]:.2f}s)",
        scene=dict(
            xaxis=dict(title="X", range=x_range),
            yaxis=dict(title="Y", range=y_range),
            zaxis=dict(title="Z", range=z_range),
            aspectmode="cube",
        ),
        width=900,
        height=700,
        showlegend=True,
    )

    fig.show()

    true_euler = quat_to_euler(true_quaternions[step_idx])
    pred_euler = quat_to_euler(pred_quaternions[step_idx])

    print(f"\nRotation Angles at t={time_steps[step_idx]:.2f}s:")
    print(
        f"  True  - Roll: {true_euler[0]:8.4f} rad, Pitch: {true_euler[1]:8.4f} rad, Yaw: {true_euler[2]:8.4f} rad"
    )
    print(
        f"  RK45  - Roll: {pred_euler[0]:8.4f} rad, Pitch: {pred_euler[1]:8.4f} rad, Yaw: {pred_euler[2]:8.4f} rad"
    )
    print(
        f"  Error - Roll: {abs(true_euler[0] - pred_euler[0]):8.4f} rad, Pitch: {abs(true_euler[1] - pred_euler[1]):8.4f} rad, Yaw: {abs(true_euler[2] - pred_euler[2]):8.4f} rad"
    )


interact(
    plot_trajectory,
    step_idx=IntSlider(
        min=0, max=n_steps - 1, step=1, value=0, description="Time Step"
    ),
)


interactive(children=(IntSlider(value=0, description='Time Step', max=11999), Output()), _dom_classes=('widgetâ€¦

<function __main__.plot_trajectory(step_idx)>

## Single Step Prediction


In [25]:
t_start = 1.0
dt = 0.01
g = np.array([0.0, 0.0, -9.81])

f_world_t0, omega_t0, r_t0, v_t0, a_t0, q_t0 = get_world_frame_truth(t_start, g)

state = np.concatenate([r_t0, v_t0, q_t0])

f_world_t1, omega_t1, r_t1, v_t1, a_t1, q_t1 = get_world_frame_truth(t_start + dt, g)

f_body_t0 = rotate_vector_by_quat(f_world_t0, quat_conjugate(q_t0))

state_new = rk45_step(state, dt, f_body_t0, g, omega_t0)

f_rotated = rotate_vector_by_quat(f_body_t0, q_t0)
a_pred = f_rotated + g

print(f"RK45 Integration from t={t_start} to t={t_start + dt}:")
print("\nPredicted state (from RK45):")
print(
    f"  p (position):      [{state_new[0]:.6f}, {state_new[1]:.6f}, {state_new[2]:.6f}]"
)
print(
    f"  v (velocity):      [{state_new[3]:.6f}, {state_new[4]:.6f}, {state_new[5]:.6f}]"
)
print(f"  a (acceleration):  [{a_pred[0]:.6f}, {a_pred[1]:.6f}, {a_pred[2]:.6f}]")
print(
    f"  q (quaternion):    [{state_new[6]:.6f}, {state_new[7]:.6f}, {state_new[8]:.6f}, {state_new[9]:.6f}]"
)

print("\nTrue state (from trajectory):")
print(f"  p (position):      [{r_t1[0]:.6f}, {r_t1[1]:.6f}, {r_t1[2]:.6f}]")
print(f"  v (velocity):      [{v_t1[0]:.6f}, {v_t1[1]:.6f}, {v_t1[2]:.6f}]")
print(f"  a (acceleration):  [{a_t1[0]:.6f}, {a_t1[1]:.6f}, {a_t1[2]:.6f}]")
print(
    f"  q (quaternion):    [{q_t1[0]:.6f}, {q_t1[1]:.6f}, {q_t1[2]:.6f}, {q_t1[3]:.6f}]"
)

error_p = np.linalg.norm(state_new[0:3] - r_t1)
error_v = np.linalg.norm(state_new[3:6] - v_t1)
print(f"\nErrors:")
print(f"  Position error:    {error_p:.6e}")
print(f"  Velocity error:    {error_v:.6e}")


RK45 Integration from t=1.0 to t=1.01:

Predicted state (from RK45):
  p (position):      [2.419037, 2.625523, 0.505000]
  v (velocity):      [2.187981, -0.725789, 0.499996]
  a (acceleration):  [-0.599520, -0.658140, -0.000056]
  q (quaternion):    [0.963820, -0.008608, 0.092471, 0.249852]

True state (from trajectory):
  p (position):      [2.419037, 2.625523, 0.505000]
  v (velocity):      [2.187936, -0.725711, 0.500000]
  a (acceleration):  [-0.604850, -0.656364, 0.000000]
  q (quaternion):    [0.964435, -0.008552, 0.092610, 0.247416]

Errors:
  Position error:    2.954290e-07
  Velocity error:    8.942186e-05
