In [None]:
import casadi as ca
import numpy as np

def integrator(f: ca.Function, x: ca.SX, u: ca.SX, dt: ca.SX):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step.

    Args:
        f: Callable evaluated as ``f(x, u)`` that returns the time
            derivative of ``x``.
        x: State at the start of the step.
        u: Input; the same value is passed to every stage evaluation.
        dt: Step length.

    Returns:
        The state after one step, ``x + dt/6 * (k1 + 2*k2 + 2*k3 + k4)``.
    """
    half_step = 0.5 * dt

    # Stage slopes: start of the interval, two midpoint estimates, and the end.
    slope_start = f(x, u)
    slope_mid_a = f(x + half_step * slope_start, u)
    slope_mid_b = f(x + half_step * slope_mid_a, u)
    slope_end = f(x + dt * slope_mid_b, u)

    weighted_slopes = slope_start + 2.0 * slope_mid_a + 2.0 * slope_mid_b + slope_end
    return x + (dt / 6.0) * weighted_slopes


# Symbolic dynamics model for estimation and control.
#
# State  x = [q_BI (4), omega (3), omega_w (3)]  -> 10 entries
# Input  u = [u_rw (3), u_mag (3)]               ->  6 entries
q_BI = ca.SX.sym('q_BI', 4)
omega = ca.SX.sym('omega', 3)
omega_w = ca.SX.sym("omega_w", 3)
x = ca.vertcat(q_BI, omega, omega_w)
u = ca.SX.sym('inputs', 6)
u_rw = u[:3]
u_mag = u[3:]

# Placeholder parameters (identity / unit values).
# NOTE(review): presumably body inertia, wheel inertia and the wheel-axis
# matrix -- confirm and replace with real values.
J_B = np.eye(3, 3)
J_w = np.ones(1)
A_hat = np.eye(3, 3)

# Numeric step length kept for callers that want a default value; the
# integrator below uses a symbolic step so that F can take dt as an input.
dt = 0.01
dt_sym = ca.SX.sym('dt')

# Quaternion kinematics, q_dot = 0.5 * q (x) [omega, 0].
# The state stores the quaternion scalar-last ([qx, qy, qz, qw]), so the
# derivative must use the same ordering: vector part first, scalar part last.
qx, qy, qz, qw = q_BI[0], q_BI[1], q_BI[2], q_BI[3]
wx, wy, wz = omega[0], omega[1], omega[2]
q_dot = 0.5 * ca.vertcat(
    qw*wx + qy*wz - qz*wy,
    qw*wy - qx*wz + qz*wx,
    qw*wz + qx*wy - qy*wx,
    -qx*wx - qy*wy - qz*wz,
)

# NOTE(review): A_hat is applied un-transposed in both directions below; this
# is indistinguishable from the transposed form while A_hat is the identity --
# confirm once a real wheel-axis matrix is used.
h_int = A_hat @ (J_w * (omega_w + A_hat @ omega))
tau_rw = A_hat @ u_rw  # electrical dynamics not included
tau_mag = u_mag  # TODO: implement

cross_term = ca.cross(omega, J_B @ omega + h_int)
total_torque = tau_mag - tau_rw - cross_term
omega_dot = ca.solve(J_B, total_torque)
omega_w_dot = u_rw / J_w - A_hat @ omega_dot

dx = ca.vertcat(q_dot, omega_dot, omega_w_dot)

# Continuous-time model: dx = f(x, u).
f = ca.Function("system_dynamics", [x, u], [dx])

# One RK4 step with a symbolic step length.
x_next = integrator(f, x, u, dt_sym)

# Discrete-time model: x_next = F(x, u, dt). Function inputs must be symbolic,
# hence dt_sym rather than the numeric dt.
F = ca.Function("F", [x, u, dt_sym], [x_next], ['x', 'u', 'dt'], ['x_next'])

In [None]:
from filterpy.kalman import ExtendedKalmanFilter
import casadi as ca

# TODO: implement extended kalman filter wrapper
class IMUEKF(ExtendedKalmanFilter):
    """Extended Kalman filter driven by CasADi process/measurement models.

    The continuous-time process model ``f(x, u)`` is discretised with
    ``integrator`` and the Jacobians needed by the filter are derived
    symbolically. All values handed to the filterpy base class are converted
    to NumPy arrays.
    """

    def __init__(self, dim_x, dim_z, f: ca.Function, g: ca.Function):
        """Build the discrete-time model and its Jacobians.

        Args:
            dim_x: Number of state entries.
            dim_z: Number of measurement entries.
            f: Process model, called as ``f(x, u)``; returns the state
                derivative.
            g: Measurement model, called as ``g(x)``; returns the predicted
                measurement.
        """
        super().__init__(dim_x, dim_z)

        self.f = f  # process model
        self.g = g  # measurement model

        # Fresh symbols: the models must be traced symbolically, not with the
        # numeric state estimate or with symbols from another scope.
        x_sym = ca.SX.sym("x", dim_x)
        u_sym = ca.SX.sym("u", f.size1_in(1))
        dt_sym = ca.SX.sym("dt")
        x_next = integrator(self.f, x_sym, u_sym, dt_sym)

        # Not stored as ``self.F``: the base class uses that attribute for the
        # numeric state-transition matrix in its covariance propagation.
        # F_step(x, u, dt) -> next state
        self.F_step = ca.Function("F_step", [x_sym, u_sym, dt_sym], [x_next])
        # F_jac(x, u, dt) -> d(next state)/dx
        self.F_jac = ca.Function(
            "F_jac", [x_sym, u_sym, dt_sym], [ca.jacobian(x_next, x_sym)]
        )
        # g_jac(x) -> dg/dx
        self.g_jac = ca.Function(
            "g_jac", [x_sym], [ca.jacobian(self.g(x_sym), x_sym)]
        )

    def predict_x(self, u, dt):
        """Propagate the state estimate by one step of length ``dt``.

        Also refreshes ``self.F`` with the step Jacobian evaluated at the
        state before propagation, for use in covariance propagation.
        """
        x_shape = np.shape(self.x)
        x_col = np.asarray(self.x, dtype=float).reshape(-1, 1)
        self.F = self.F_jac(x_col, u, dt).full()
        # Keep the caller's state shape (column vector or flat array).
        self.x = self.F_step(x_col, u, dt).full().reshape(x_shape)

    def predict(self, u, dt):
        """Propagate state and covariance by one step of length ``dt``.

        Overrides the base implementation, which calls ``predict_x`` without
        a step length.
        """
        self.predict_x(u, dt)
        self.P = self.F @ self.P @ self.F.T + self.Q

        # Save the prior, as the base class does.
        self.x_prior = np.copy(self.x)
        self.P_prior = np.copy(self.P)

    def update(self, z, R=None):
        """Fold measurement ``z`` into the estimate.

        Args:
            z: Measurement; the predicted measurement is reshaped to match it.
            R: Optional measurement noise override, forwarded to the base
                class.
        """
        z_shape = np.shape(z)

        def h_jacobian(x_est):
            # The base class needs a NumPy matrix, not a CasADi DM.
            return self.g_jac(x_est).full()

        def h_x(x_est):
            return self.g(x_est).full().reshape(z_shape)

        super().update(z, h_jacobian, h_x, R=R)
