In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ex17_1_lib import DNLDS, EKF, SUT, SPKF 

In [None]:
# 1D rocket model.
# State vector: x = [s, v, m]; the input u is indexed as u[0].
n = 3  # number of states


def f(x, t, u):
    """Continuous-time model: return dx/dt for state x and input u.

    The argument t is accepted but not used by the model.
    """
    v, m = x[1], x[2]
    s_dot = v
    # NOTE(review): the bracket is multiplied by x[2]; if x[2] really is the mass,
    # Newton's law would divide instead. dfdx matches the form used here, so it is
    # kept as is -- confirm against the exercise statement.
    v_dot = (u[0] - 0.2*v**2)*m
    # Parsed as -(0.01**2) * u[0]: x[2] decreases in proportion to the input.
    m_dot = -0.01**2*u[0]
    return np.array([s_dot, v_dot, m_dot])


def h(x):
    """Measurement model: only the first state component is observed."""
    return np.array([x[0]])


def dfdx(x, u):
    """Jacobian of f with respect to the state, evaluated at (x, u)."""
    v, m = x[1], x[2]
    return np.array([
        [0, 1, 0],
        [0, -0.2*2*v*m, u[0] - 0.2*v**2],
        [0, 0, 0],
    ])


def dhdx(x):
    """Jacobian of h with respect to the state (constant, independent of x)."""
    return np.array([[1, 0, 0]])


# Matrices handed to DNLDS / EKF below (presumably noise input matrix, process-noise
# and measurement-noise covariances -- confirm in ex17_1_lib).
G = np.eye(n)
Q = np.diag([0.1**2, 0.001**2, 0**2])
R = np.diag([1**2])

x0 = np.array([10.0,0.0,1.0])       # initial state of the simulated system
x0_est = np.array([10,0,0.56])      # initial state guess for the estimator
P0 = np.diag([10**-8,10**-8,1**2])  # presumably the initial estimate covariance

u_scale = 15  # the random input is drawn uniformly from [0, 1) and scaled by this

In [None]:
# Init Dynamical System (the simulated system that produces the "true" states).
# It receives the model Q exactly as defined in the model cell.
nls = DNLDS(f, h, G, Q, R, x0)

# Init Kalman Filter
# The filter works with an inflated Q. It is stored under its own name so that Q is
# never overwritten: re-running this cell neither compounds the inflation nor hands
# an already inflated Q to DNLDS.
Q_filter = Q + np.diag([0.05**2, 0.0005**2, 10**-6])

# EKF
# Start the filter from the initial guess x0_est rather than from the true state x0:
# starting at the truth makes the estimation trivial and would pass the very same
# array object to both the simulated system and the filter.
# NOTE: the name `filter` shadows the Python builtin; it is kept because the
# simulation cell refers to it.
filter = EKF(f, h, dfdx, dhdx, G, Q_filter, R, x0_est, P0)

# OR

# Sigma Point Kalman Filter
#alpha = 0.0001
#kappa = 0.0
#beta = 2.0
#sut = SUT(alpha, beta, kappa, n)
#filter =  SPKF(f,h,G,Q_filter,R,x0_est,P0,sut,variant=1) # 0 - normal UKF, 1 - IUKF, 2 - UKFz

In [None]:
# Logging
# Store copies: if the library updates its state array in place, appending the
# array itself would make every logged sample point at the same (latest) values.
x_true = [np.copy(nls.x)]
x_est = [np.copy(filter.x)]

T = 15*60 # in seconds
dt = 0.1 # seconds
# Round before converting: int() alone truncates, which drops a step whenever
# T/dt lands just below a whole number because of floating-point round-off.
n_steps = int(round(T/dt))

# Seeded generator so the random input sequence is the same on every run.
# (Any noise drawn inside the library is not controlled by this seed.)
rng = np.random.default_rng(0)

for i in range(n_steps):
    if nls.x[0] <= 0:
        # stop if we are crashing, we did not model ground level
        print("Crashed at time: ", i*dt, " seconds.")
        break

    # Random input, uniform in [0, u_scale)
    u = rng.random(1)*u_scale

    # Advance the simulated system and the filter prediction with the same input
    nls.step(u,dt,1)
    filter.predict(u,dt,1)

    # Correct the prediction with the new measurement
    meas = nls.output()
    filter.update(meas,1) # 0 - simple covariance update, 1 - Joseph covariance update

    # Logging
    x_true.append(np.copy(nls.x))
    x_est.append(np.copy(filter.x))

# One row per logged sample
x_true = np.array(x_true)
x_est = np.array(x_est)

In [None]:
# Plot the logs
# Select the style once, before any figure is created, so that every figure
# (including the first) is drawn with it. The seaborn style sheets were renamed to
# 'seaborn-v0_8-*' in matplotlib 3.6 and the old names were removed in 3.8, so try
# the new name first and fall back to the old one on older installations.
for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break

# One figure per state component: true vs. estimated trajectory
for i in range(nls.x.shape[0]):
    _, ax = plt.subplots(1)
    ax.plot(x_true[:,i],color='blue', linestyle='solid', marker='o',
        markerfacecolor='blue', markersize=4, label='True State')
    ax.plot(x_est[:,i],color='orange', linestyle='solid', marker='o',
        markerfacecolor='orange', markersize=4, label='Est State')
    ax.set_title('State '+str(i))
    ax.set_xlabel('Time step')  # the x-axis is the index of the logged sample
    ax.legend()  # attach the legend to this axes explicitly

print("Mass:\n", x_true[:,2])