# MRP Extended Kalman Filter Demo

This demo introduces a JAXitude workflow for estimating the attitude evolution of a tumbling spacecraft from simulated data, using modified Rodrigues parameters (MRPs) as the attitude parameterization.  We'll make use of four JAXitude tools for this exercise:
- `jaxitude.quaternions.Quaternion()` to convert the simulated quaternion data to rotation matrices,
- `jaxitude.operations.noise.Heading` to simulate measurement error on the simulated heading data,
- `jaxitude.rodrigues.MRP()` to convert simulated attitude measurements to MRPs,
- `jaxitude.estimation.ekf.MRPEKF` to build the MRP extended Kalman filter algorithm.

As always with JAXitude, any vector data must be transformed to column vectors (i.e., `jnp.array([1., 0., 0.])` $\rightarrow$ `jnp.array([1., 0., 0.]).reshape((3, 1))`).
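
Here is a minimal illustration of this convention with plain `jax.numpy`. It makes no JAXitude calls, and the array names are only for illustration:

```python
import jax.numpy as jnp

# A single 3-vector as a flat (3,) array...
v_flat = jnp.array([1., 0., 0.])
# ...and as the (3, 1) column vector JAXitude expects.
v_col = v_flat.reshape((3, 1))

# A batch of N row vectors (e.g. N rows loaded from a CSV) becomes a stack of
# N column vectors with shape (N, 3, 1).
batch = jnp.arange(12.).reshape((4, 3))
batch_col = batch.reshape((4, 3, 1))

# Column vectors compose naturally with (3, 3) matrices via @.
R = jnp.eye(3)
print(v_flat.shape, v_col.shape, (R @ v_col).shape, batch_col.shape)
# prints: (3,) (3, 1) (3, 1) (4, 3, 1)
```

The same `reshape((n_steps, 3, 1))` trick is what turns the loaded CSV arrays into stacks of column vectors.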

## Step Zero: Load Simulation Data

We need to import the simulation data provided under the local directory `data/`.  We'll be using the `tumbling1/` data set for this demo.

In [None]:
from pathlib import Path
import sys
import os
sys.path.append(str(Path(os.getcwd()).parent))
from typing import Tuple

from numpy import genfromtxt
import jax.numpy as jnp
from jax import config
config.update('jax_enable_x64', True)  # Default to double float precision.

data_path = Path('data/') / 'tumbling1'

# Load observed gyroscope data.
with open(data_path / 'gyro-0.csv', 'r') as f:
    w_obs_data = jnp.asarray(genfromtxt(f, delimiter=',', skip_header=1))

# Load true gyroscope data.
with open(data_path / 'ref_gyro.csv', 'r') as f:
    w_true_data = jnp.asarray(genfromtxt(f, delimiter=',', skip_header=1))

# Load true attitude data (parameterized with quaternions).
with open(data_path / 'ref_att_quat.csv', 'r') as f:
    b_true_data = jnp.asarray(genfromtxt(f, delimiter=',', skip_header=1))

# Load simulation time data.
with open(data_path / 'time.csv', 'r') as f:
    t_data = jnp.asarray(genfromtxt(f, delimiter=',', skip_header=1))

# Remember: all vector data must be transformed to column vectors to utilize
# JAXitude calculations!
n_steps = jnp.shape(t_data)[0]
w_obs_data = w_obs_data.reshape((n_steps, 3, 1))
w_true_data = w_true_data.reshape((n_steps, 3, 1))
b_true_data = b_true_data.reshape((n_steps, 4, 1))

## Step One: Simulate Attitude Measurements

Attitude is always an orientation with respect to another frame.  For the `tumbling1` simulated data, the craft starts at latitude and longitude $(\phi=32^\circ, \lambda=120^\circ)$ with an altitude of $h=1000$ m.  Fortunately, we are only interested in tracking the craft's attitude relative to its starting orientation, which we set as a zero rotation.  This corresponds to an initial quaternion value of $\mathbf{b}_0=[1, 0, 0, 0]^T$.  Our attitude measurements, or heading vectors, will be with respect to $\mathbf{b}_0$.

These heading vectors need a bit more explanation.  Let's assume the heading measurement device is placed along the $y$-axis (facing east, in this case).  The measured heading vector at any later time $t$ during the craft's tumble will be $\mathbf{u}(t) = [\mathbf{R}(\mathbf{b}(t))]\mathbf{u}_0$, where $\mathbf{u}_0=[0,1,0]^T$ (the `u_0` array in the code below).  Note that $[\mathbf{R}(\mathbf{b}_0)]\mathbf{u}_0 = [\mathbf{I}_{3\times3}]\mathbf{u}_0 = \mathbf{u}_0$.
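
The heading model $\mathbf{u}(t) = [\mathbf{R}(\mathbf{b}(t))]\mathbf{u}_0$ can be sketched without JAXitude. The helper `quat_to_dcm` below is hypothetical. It builds a direction cosine matrix from a scalar-first quaternion using the common Schaub and Junkins convention, which may differ from `Quaternion()`'s own convention (for example, by a transpose). It uses the same $\mathbf{u}_0 = [0, 1, 0]^T$ as `u_0` in the code cell below.

```python
import jax.numpy as jnp


def quat_to_dcm(b: jnp.ndarray) -> jnp.ndarray:
    """(3, 3) DCM from a scalar-first (4, 1) quaternion [b0, b1, b2, b3]^T.

    Hypothetical helper; assumes the Schaub & Junkins DCM convention, which
    JAXitude's Quaternion() may not share.
    """
    b0, b1, b2, b3 = (b / jnp.linalg.norm(b)).ravel()  # normalize, then unpack
    return jnp.array([
        [b0**2 + b1**2 - b2**2 - b3**2, 2. * (b1 * b2 + b0 * b3), 2. * (b1 * b3 - b0 * b2)],
        [2. * (b1 * b2 - b0 * b3), b0**2 - b1**2 + b2**2 - b3**2, 2. * (b2 * b3 + b0 * b1)],
        [2. * (b1 * b3 + b0 * b2), 2. * (b2 * b3 - b0 * b1), b0**2 - b1**2 - b2**2 + b3**2],
    ])


u_0 = jnp.array([[0.], [1.], [0.]])        # sensor boresight, as in the demo
b_0 = jnp.array([[1.], [0.], [0.], [0.]])  # zero rotation

# Zero rotation: R(b_0) = I, so the heading is unchanged.
u_start = quat_to_dcm(b_0) @ u_0

# A 90 degree rotation about the z-axis (half-angle of 45 degrees).
half = jnp.pi / 4.
b_90z = jnp.array([[jnp.cos(half)], [0.], [0.], [jnp.sin(half)]])
u_later = quat_to_dcm(b_90z) @ u_0  # still a unit vector, now along x
```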

Adding noise to attitude or heading data is not trivial, since these objects are not closed under vector addition.  Practically speaking, if $\mathbf{u}$ is a heading vector, $\mathbf{u} + \delta\mathbf{u}$ is not guaranteed to be a valid heading vector.  That's because heading vectors are unit vectors ($|\mathbf{u}|=1$), but $|\mathbf{u} + \delta\mathbf{u}| \neq 1$ for most error vectors $\delta\mathbf{u}$!
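
A quick numerical check of this claim, using illustrative values only:

```python
import jax.numpy as jnp

u = jnp.array([[0.], [1.], [0.]])    # a valid (unit) heading vector
du = jnp.array([[0.1], [0.], [0.]])  # a small additive "error"

u_naive = u + du
# |u + du| = sqrt(1.01), roughly 1.005: no longer a unit vector.
norm_naive = jnp.linalg.norm(u_naive)
```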

JAXitude resolves this by instead generating a random rotation matrix $[\mathbf{R}(\delta\phi)]$, where $\delta\phi \sim \mathcal{N}(0, \sigma_{\delta\phi}^2)$ (`sigma_phi` in the code below), and then applying that rotation matrix to the 'true' input unit vector $\mathbf{u}$ to get the 'measured' output unit vector $\mathbf{u}^* = [\mathbf{R}(\delta\phi)]\mathbf{u}$.

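Note that this random rotation takes place about an axis that is perpendicular to the 'true' unit vector $\mathbf{u}$.  Any axis within this perpendicular plane is valid, and the axis direction is drawn from a uniform distribution.  Because the axis is perpendicular to $\mathbf{u}$, the angle between $\mathbf{u}$ and $\mathbf{u}^*$ is exactly $|\delta\phi|$.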
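
The noise model described above can be sketched in a few lines. This is a hypothetical re-implementation for illustration, not the source of `Heading.addnoise`. It builds an orthonormal basis of the plane perpendicular to $\mathbf{u}$ and draws the axis direction $\mathbf{k}$ uniformly in that plane. It then draws $\delta\phi \sim \mathcal{N}(0, \sigma_{\delta\phi}^2)$ and applies Rodrigues' rotation formula $[\mathbf{R}] = [\mathbf{I}] + \sin\delta\phi\,[\tilde{\mathbf{k}}] + (1-\cos\delta\phi)[\tilde{\mathbf{k}}]^2$, where $[\tilde{\mathbf{k}}]$ is the cross-product matrix of $\mathbf{k}$. The names `skew`, `rodrigues`, and `add_heading_noise` are our own.

```python
import jax.numpy as jnp
from jax.random import PRNGKey, split, normal, uniform


def skew(k: jnp.ndarray) -> jnp.ndarray:
    """Cross-product matrix [k~] of a 3-vector k (flat or (3, 1))."""
    k1, k2, k3 = k.ravel()
    return jnp.array([[0., -k3, k2],
                      [k3, 0., -k1],
                      [-k2, k1, 0.]])


def rodrigues(k: jnp.ndarray, angle) -> jnp.ndarray:
    """Rotation matrix for a rotation of `angle` rad about the unit axis k."""
    K = skew(k)
    return jnp.eye(3) + jnp.sin(angle) * K + (1. - jnp.cos(angle)) * (K @ K)


def add_heading_noise(key, u: jnp.ndarray, sigma) -> jnp.ndarray:
    """Hypothetical sketch: rotate the unit column vector u (3, 1) by
    dphi ~ N(0, sigma^2) about a random axis perpendicular to u."""
    key_axis, key_angle = split(key)
    u_flat = u.ravel()
    # Helper axis: the coordinate axis least aligned with u (never parallel).
    a = jnp.eye(3)[jnp.argmin(jnp.abs(u_flat))]
    e1 = jnp.cross(u_flat, a)
    e1 = e1 / jnp.linalg.norm(e1)
    e2 = jnp.cross(u_flat, e1)  # unit length because u is unit and u is perpendicular to e1
    # Axis direction drawn uniformly within the perpendicular plane.
    theta = uniform(key_axis, minval=0., maxval=2. * jnp.pi)
    k = jnp.cos(theta) * e1 + jnp.sin(theta) * e2
    # Gaussian rotation angle.
    dphi = sigma * normal(key_angle)
    return rodrigues(k, dphi) @ u


u = jnp.array([[0.], [1.], [0.]])
u_star = add_heading_noise(PRNGKey(0), u, 5. * jnp.pi / 180.)
```

Because the output is a rotated copy of a unit vector, it stays on the unit sphere by construction. That is exactly the property additive noise fails to preserve.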

One may ask, "Why don't we just inject noise into the simulated quaternion set directly?"  That's a fair question, and you can also corrupt quaternion data using the `jaxitude.operations.noise.Quaternion` functionality, but here we opt to simulate heading measurements from a fictitious heading measurement device!

In [None]:
from jax import vmap
from jax.random import PRNGKey, split

from jaxitude.base import find_DCM
from jaxitude import quaternions
from jaxitude.quaternions import Quaternion
from jaxitude.rodrigues import MRP
from jaxitude.operations.noise import Heading

# We need a random key to generate noise.
key = PRNGKey(1)

# Initial heading vector.
u_0 = jnp.array(
    [[0.],
     [1.],
     [0.]]
)

# Heading rotation angle standard deviation will be set to five degrees.
sigma_phi = 5. * jnp.pi / 180.  # Convert to radians for JAXitude!


# For this helper function, we'll define it for a single (key, b) argument pair
# and then vectorize it with vmap.
@vmap
def convert_corrupt(
    key: jnp.ndarray,
    b: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """ This helper function takes the quaternion set b, normalizes it,
        converts it to a rotation matrix to get the 'true' heading vector,
        corrupts that heading vector, and returns both the true heading and
        the simulated (measured) heading.
    """
    # We get the rotation matrix R(b) by calling the instantiated Quaternion
    # object.  Note that we also normalize the quaternion set.
    R = Quaternion(b / jnp.linalg.norm(b))()

    # Get the 'true' heading vector and renormalize it to guard against
    # round-off error.
    u_true = R @ u_0
    u_true = u_true / jnp.linalg.norm(u_true)

    # Corrupt the normalized u_true with a random heading rotation.
    _, subkey = split(key)
    return u_true, Heading.addnoise(subkey, u_true, sigma_phi)

# Let's also save the 'true' headings for comparison later.
key, subkey = split(key)
heading_true_data, heading_obs_data = convert_corrupt(
    split(subkey, n_steps),
    b_true_data
)
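
The `convert_corrupt` cell relies on a common JAX pattern. You write the function for one sample, wrap it with `vmap`, and pass one PRNG key per sample from `split(key, n)`. Here is a stripped-down, self-contained version of that pattern; the function `perturb` is hypothetical and simply adds Gaussian noise:

```python
import jax.numpy as jnp
from jax import vmap
from jax.random import PRNGKey, split, normal


@vmap
def perturb(key: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Per-sample function: perturb one (3, 1) vector using its own key."""
    return v + 0.01 * normal(key, v.shape)


n = 5
keys = split(PRNGKey(0), n)  # one independent key per sample
vs = jnp.ones((n, 3, 1))     # a stack of n column vectors
out = perturb(keys, vs)      # vmap maps over the leading axis of both inputs
```

Giving each sample its own key keeps the noise independent across time steps. Reusing a single key would produce identical noise draws.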


## Step Two: MRP Conversion

Since we are going to feed all this data into an MRP EKF algorithm, we next want to convert the heading measurements to MRPs.  Again, we'll make use of JAXitude to simplify this workflow.
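
For reference, the MRPs relate to a scalar-first unit quaternion $\mathbf{b} = [b_0, b_1, b_2, b_3]^T$ by $\mathbf{s} = [b_1, b_2, b_3]^T / (1 + b_0)$. The shadow set $\mathbf{s}^S = -\mathbf{s}/|\mathbf{s}|^2$ describes the same attitude. The sketch below implements these two formulas directly. The helper names are hypothetical, and `DCM.get_s()` may compute the MRPs differently (for instance, directly from the rotation matrix).

```python
import jax.numpy as jnp


def quat_to_mrp(b: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: MRP set s = [b1, b2, b3]^T / (1 + b0) from a
    scalar-first (4, 1) quaternion, returned as a (3, 1) column vector."""
    b = b / jnp.linalg.norm(b)
    # b and -b describe the same attitude; picking b0 >= 0 keeps |s| <= 1.
    b = jnp.where(b[0, 0] < 0., -b, b)
    return b[1:] / (1. + b[0, 0])


def mrp_shadow(s: jnp.ndarray) -> jnp.ndarray:
    """Shadow MRP set s^S = -s / |s|^2, which describes the same attitude."""
    return -s / jnp.sum(s**2)


# A 90 degree rotation about z gives s = [0, 0, tan(90 deg / 4)]^T.
half = jnp.pi / 4.
b = jnp.array([[jnp.cos(half)], [0.], [0.], [jnp.sin(half)]])
s = quat_to_mrp(b)
```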

In [None]:
from jaxitude.base import DCM

# We also want to get the measurement MRPs from the measured headings.
# Let's use the utility function 'find_DCM' from 'jaxitude.base' and vectorize
# with vmap.
@vmap
def mrp_from_heading(u: jnp.ndarray) -> jnp.ndarray:
    """ Helper function to get the measured MRP set from a heading measurement.
    """
    # find_DCM calculates the rotation matrix that rotates the first vector
    # onto the second.  We can then construct a DCM object and calculate the
    # MRP set s from there.
    return DCM(find_DCM(u_0, u)).get_s()

mrp_obs_data = mrp_from_heading(heading_obs_data)
mrp_true_data = mrp_from_heading(heading_true_data)
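
A single heading vector only pins down the attitude up to an arbitrary rotation about that vector, so a function like `find_DCM` has to choose one rotation. One natural choice is the smallest rotation carrying $\mathbf{a}$ onto $\mathbf{b}$. For unit vectors with $c = \mathbf{a}\cdot\mathbf{b} \neq -1$ and $\mathbf{v} = \mathbf{a}\times\mathbf{b}$, that rotation is $[\mathbf{R}] = [\mathbf{I}] + [\tilde{\mathbf{v}}] + [\tilde{\mathbf{v}}]^2/(1 + c)$. The sketch below implements that formula. We have not checked whether `find_DCM` uses the same construction or returns its transpose, so treat that correspondence as an assumption.

```python
import jax.numpy as jnp


def skew(v: jnp.ndarray) -> jnp.ndarray:
    """Cross-product matrix [v~] of a 3-vector v (flat or (3, 1))."""
    v1, v2, v3 = v.ravel()
    return jnp.array([[0., -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])


def min_rotation(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch: smallest rotation R with R @ a == b for unit
    (3, 1) vectors a, b.  Undefined when b == -a (c == -1)."""
    v = jnp.cross(a.ravel(), b.ravel())  # rotation axis scaled by sin(angle)
    c = jnp.vdot(a.ravel(), b.ravel())   # cos(angle)
    V = skew(v)
    return jnp.eye(3) + V + (V @ V) / (1. + c)


a = jnp.array([[0.], [1.], [0.]])
b = jnp.array([[1.], [2.], [2.]]) / 3.  # an arbitrary unit vector
R = min_rotation(a, b)
```

Any extra rotation about `a` applied first would also map `a` onto `b`. The minimal rotation is therefore a convention, not the unique answer.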

In [None]:
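# Sanity check: a measured heading should be a unit vector, and the matrix
# returned by find_DCM should be orthonormal (R @ R.T == I).
u_check = heading_obs_data[101]
print(jnp.linalg.norm(u_check))
R_check = DCM(find_DCM(u_0, u_check))()
print(jnp.allclose(R_check @ R_check.T, jnp.eye(3)))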

## Step Three: MRP EKF Algorithm
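
The filter in the cell below stacks the MRP set `s` and the gyroscope bias `b` into the state vector `x`. Between measurements, an MRP filter typically propagates $\mathbf{s}$ with the MRP kinematic equation $\dot{\mathbf{s}} = \tfrac{1}{4}[B(\mathbf{s})]\boldsymbol{\omega}$. Here $[B(\mathbf{s})] = (1 - \mathbf{s}^T\mathbf{s})[\mathbf{I}] + 2[\tilde{\mathbf{s}}] + 2\mathbf{s}\mathbf{s}^T$, and $\boldsymbol{\omega}$ is the bias-corrected gyro rate. Below is a minimal sketch of one propagation step. The first-order Euler step and the shadow-set switch are our own choices, and `MRPEKF.filter_step` may integrate differently.

```python
import jax.numpy as jnp


def skew(v: jnp.ndarray) -> jnp.ndarray:
    """Cross-product matrix [v~] of a 3-vector v (flat or (3, 1))."""
    v1, v2, v3 = v.ravel()
    return jnp.array([[0., -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])


def B_matrix(s: jnp.ndarray) -> jnp.ndarray:
    """B(s) in the MRP kinematic equation s_dot = 1/4 B(s) w."""
    s2 = jnp.sum(s**2)
    return (1. - s2) * jnp.eye(3) + 2. * skew(s) + 2. * (s @ s.T)


def propagate_mrp(s, w_obs, bias, dt):
    """Hypothetical sketch: one Euler step of s_dot = 1/4 B(s) (w_obs - bias),
    then a switch to the shadow set if |s| > 1.  All vectors are (3, 1) and
    rates are in rad/s."""
    w = w_obs - bias  # bias-corrected angular rate
    s_new = s + 0.25 * dt * (B_matrix(s) @ w)
    s2 = jnp.sum(s_new**2)
    return jnp.where(s2 > 1., -s_new / s2, s_new)


# Constant 0.1 rad/s spin about z for 1 s: s_z should approach tan(0.1 / 4).
s = jnp.zeros((3, 1))
w = jnp.array([[0.], [0.], [0.1]])
bias = jnp.zeros((3, 1))
for _ in range(100):
    s = propagate_mrp(s, w, bias, 0.01)
```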



In [None]:
from jax.scipy.linalg import block_diag
from jax.random import normal

from jaxitude.estimation.ekf import MRPEKF


# Our initial guess for the gyroscope bias is about three degrees per second
# on every axis (converted to rad/s).
b = jnp.full((3, 1), 3. * jnp.pi / 180.)

# Now we define the state vector x = [s; b], stacking the first measured MRP
# set on top of the bias guess.
x = jnp.vstack([mrp_obs_data[0], b])

# Initializing the state covariance matrix P is more difficult.  We will start
# with a block-diagonal matrix whose diagonal entries are sigmap_w = 1e-3 for
# the MRP block and sigmap_b = 0.05 for the bias block.
sigmap_w = 1e-3
sigmap_b = 0.05
P = block_diag(
    jnp.eye(3) * sigmap_w,
    jnp.eye(3) * sigmap_b
)

# The noise vector eta is also needed.
key, subkey1, subkey2 = split(key, 3)
eta_w = normal(subkey1, (3, 1)) * sigmap_w
eta_b = normal(subkey2, (3, 1)) * sigmap_b
eta = jnp.vstack([eta_w, eta_b])

# The process noise covariance will simply be the initial state covariance
# estimate.
Q = P.copy()

# The measurement noise is known from our simulated measurements and from
# comparing the gyroscope data.
sigma_w = 6e-4
R_w = jnp.eye(3) * sigma_w**2.
R_s = jnp.eye(3) * sigma_phi**2.

# Finally, the time steps are calculated from the data.
dt = t_data[1] - t_data[0]

x_list = []
eta_list = []
P_list = []

for i in range(100):
    x_list.append(x)
    eta_list.append(eta)
    P_list.append(P)

    x, eta, P = MRPEKF.filter_step(
        x, eta, P,
        w_obs_data[i] * jnp.pi / 180., mrp_obs_data[i],
        R_w, R_s, Q, dt
    )

x_arr = jnp.array(x_list)
eta_arr = jnp.array(eta_list)
P_arr = jnp.array(P_list)
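
The measurement update inside a filter step like `MRPEKF.filter_step` can be illustrated with a generic EKF update. In this sketch the 6-element state is `x = [s; b]` and the measurement is the MRP set itself, so the measurement Jacobian is $[\mathbf{H}] = \begin{bmatrix} \mathbf{I}_{3\times3} & \mathbf{0}_{3\times3} \end{bmatrix}$. This is a textbook sketch, not JAXitude's implementation. A practical MRP filter may also need to reconcile shadow sets before forming the innovation.

```python
import jax.numpy as jnp


def ekf_mrp_update(x, P, s_meas, R_s):
    """Hypothetical sketch of an EKF measurement update for x = [s; b] (6, 1)
    when the measurement is the (3, 1) MRP set s_meas."""
    H = jnp.hstack([jnp.eye(3), jnp.zeros((3, 3))])  # measurement picks out s
    y = s_meas - H @ x                               # innovation
    S = H @ P @ H.T + R_s                            # innovation covariance
    # Kalman gain K = P H^T S^{-1}; solving avoids an explicit inverse and the
    # transpose is valid because P and S are symmetric.
    K = jnp.linalg.solve(S, H @ P).T
    x_new = x + K @ y
    P_new = (jnp.eye(6) - K @ H) @ P
    return x_new, P_new


x0 = jnp.zeros((6, 1))
P0 = jnp.eye(6)
x1, P1 = ekf_mrp_update(x0, P0, jnp.array([[0.2], [0.], [0.]]), jnp.eye(3))
```

With equal prior and measurement variances, the estimate moves halfway toward the measurement and the MRP-block variance halves. The bias is unobservable from a single MRP measurement with a block-diagonal `P`, so it is unchanged.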


In [None]:
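import matplotlib.pyplot as plt

# Compare the second MRP component of the measurements and the EKF estimate.
plt.figure(figsize=(8, 5))
plt.plot(t_data[:100], mrp_obs_data[:100, 1, 0], alpha=0.5, label='measured $s_2$')
plt.plot(t_data[:100], x_arr[:100, 1, 0], alpha=0.5, label='EKF estimate $s_2$')
plt.ylim([-0.2, 0.2])
plt.xlabel('time')
plt.ylabel('$s_2$')
plt.legend()
plt.show()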

In [None]:
# Third component of the 'true' MRP set over the full simulation.
plt.plot(t_data, mrp_true_data[:, 2, 0], alpha=0.5)
