In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
import filterpy.kalman as fp
from filterpy.common import Q_discrete_white_noise

# Interactive matplotlib backend for figures in this notebook.
# NOTE(review): the `notebook` (nbagg) backend is for the classic Notebook UI;
# in JupyterLab / Notebook 7 `%matplotlib widget` (ipympl) is the usual choice — confirm target UI.
%matplotlib notebook
# Default rich-display format for Animation objects (the cells below call
# anim.to_jshtml() explicitly, so this only affects bare `anim` displays).
rc('animation', html='html5')

In [None]:
# --- Simulation timing (seconds) ---------------------------------------------
RUN_TIME = 20        # upper bound on simulated time; the run also stops at ground contact
TIME_STEP = 0.05     # integration / measurement step

# --- Initial kinematic state, as (x, y) components -------------------------------
INIT_POS = (0.0, 0.0)    # m
INIT_VEL = (3.0, 30.0)   # m/s
INIT_ACC = (0.0, 0.0)    # m/s^2 (not used by the cells below)

# Vertical gravitational acceleration (m/s^2); negative means "downwards".
G = -9.8

# --- Standard deviations of the simulated sensor noise ---------------------------
POS_MEAS_STDDEV = 0.6    # m
VEL_MEAS_STDDEV = 0.4    # m/s

In [None]:
# Define a particle
class Particle:
    """A point mass moving in 2D.

    Accelerations are accumulated with applyForce/applyGravity and consumed by
    update(), which advances one step with semi-implicit Euler (velocity is
    updated first, then position moves with the *new* velocity) and finally
    clears the accumulated acceleration.
    """

    def __init__(self, pos=(0.0, 0.0), vel=(0.0, 0.0), acc=(0.0, 0.0), mass=1.0):
        """Create a particle.

        Args:
            pos: initial position (x, y).
            vel: initial velocity (x, y).
            acc: initial accumulated acceleration (x, y), consumed by the next update().
            mass: mass used by applyForce to convert a force into an acceleration.
        """
        # Force float dtype: update()/applyForce() modify these arrays in place
        # with `+=`, which raises a casting error on integer arrays
        # (e.g. pos=(0, 0)).
        self.pos = np.array(pos, dtype=float)
        self.vel = np.array(vel, dtype=float)
        self.acc = np.array(acc, dtype=float)
        self.mass = mass
        self.age = 0.0

    def applyForce(self, force):
        """Accumulate the acceleration produced by `force` (a = F / mass)."""
        self.acc += np.asarray(force, dtype=float) / self.mass

    def applyGravity(self):
        """Accumulate the gravitational acceleration (0, G); independent of mass."""
        self.acc += np.array((0.0, G))

    def update(self, dt):
        """Advance the particle by one time step `dt` and clear the acceleration."""
        self.vel += self.acc * dt
        self.pos += self.vel * dt
        self.age += dt

        # Forces must be re-applied every step, so reset the accumulator.
        self.acc = np.zeros_like(self.acc)
        

In [None]:
# Create the particle
p1 = Particle(pos=INIT_POS, vel=INIT_VEL)

# Collect samples in Python lists and convert once at the end: np.append
# copies the whole array on every call, which makes the loop O(n^2).
pos_samples, vel_samples, time_samples = [], [], []

# Simulate the actual particle trajectory. Time is derived from an integer
# step counter so float round-off does not accumulate over many steps.
step = 0
t = 0.0
while t < RUN_TIME:
    # Particle.update mutates pos/vel in place, so record snapshots (copies).
    pos_samples.append(p1.pos.copy())
    vel_samples.append(p1.vel.copy())
    time_samples.append(t)
    p1.applyGravity()
    p1.update(dt=TIME_STEP)

    # Stop the trajectory once we go below ground (that state is not recorded).
    if p1.pos[1] < 0:
        break

    step += 1
    t = step * TIME_STEP

# reshape(-1, 2) keeps the (N, 2) layout even if no sample was recorded.
pos_actual = np.array(pos_samples, dtype=float).reshape(-1, 2)
vel_actual = np.array(vel_samples, dtype=float).reshape(-1, 2)
timestamp = np.array(time_samples, dtype=float)

In [None]:
# Add zero-mean Gaussian measurement noise. A seeded generator makes the noisy
# measurements (and therefore the filter output and every plot) reproducible
# on Restart & Run All.
NOISE_SEED = 42
rng = np.random.default_rng(NOISE_SEED)
pos_meas = pos_actual + rng.normal(loc=0.0, scale=POS_MEAS_STDDEV, size=pos_actual.shape)
vel_meas = vel_actual + rng.normal(loc=0.0, scale=VEL_MEAS_STDDEV, size=vel_actual.shape)

In [None]:
# Instantiate a Kalman filter
# State x = [x_pos, y_pos, x_vel, y_vel]^T. The measurement z holds the same
# four quantities: noisy position (pos_meas) and noisy velocity (vel_meas).
f = fp.KalmanFilter(dim_x=4, dim_z=4, dim_u=1)

dt = TIME_STEP
g = G

# Control transition matrix: maps the scalar control input u = g (vertical
# acceleration) into the state.
# NOTE: the y_pos entry is dt**2 rather than the textbook 0.5*dt**2 so the
# prediction matches Particle.update exactly. Particle.update advances the
# velocity first and then moves the position with the *new* velocity:
#     vy' = vy + g*dt,   y' = y + vy*dt + g*dt**2.
# With 0.5*dt**2 the model would drift from the simulated truth by
# 0.5*g*dt**2 every step. Revisit if the simulator's integrator changes.
f.B = np.array([
    [0.],       # x_pos
    [dt**2],    # y_pos
    [0.],       # x_vel
    [dt],       # y_vel
])

# Initial state: seeded from the first measurement (taken at t = 0) instead of
# an arbitrary guess.
assert len(pos_meas) > 0, "no measurements to filter"
f.x = np.array([[pos_meas[0, 0]],    # x_pos
                [pos_meas[0, 1]],    # y_pos
                [vel_meas[0, 0]],    # x_vel
                [vel_meas[0, 1]]])   # y_vel

# State transition matrix: constant velocity between steps
# (gravity enters through B).
f.F = np.array([[1., 0., dt, 0.],
                [0., 1., 0., dt],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.]])

# Measurement function: position and velocity are both observed directly.
f.H = np.eye(4)

# Covariance matrix: large initial uncertainty.
f.P *= 1000.

# Measurement noise: variances of the simulated sensors.
f.R = np.diag([POS_MEAS_STDDEV**2, POS_MEAS_STDDEV**2,
               VEL_MEAS_STDDEV**2, VEL_MEAS_STDDEV**2])

# Process noise: discrete white-noise model with one (pos, vel) block per
# axis. block_size=2 with order_by_dim=False lays Q out as [x, y, x', y'],
# matching the state ordering above, and dt matches the filter step.
PROCESS_NOISE_VAR = 0.13
f.Q = Q_discrete_white_noise(dim=2, dt=dt, var=PROCESS_NOISE_VAR,
                             block_size=2, order_by_dim=False)

# Process the measurements. Collect the estimates in a list; np.append in a
# loop is O(n^2).
pos_est_samples = []
for k, (z_pos, z_vel) in enumerate(zip(pos_meas, vel_meas)):
    # f.x already describes t = 0 (seeded above), so predict only from k = 1 on.
    if k > 0:
        f.predict(u=g)
    f.update(np.concatenate([z_pos, z_vel]))
    pos_est_samples.append(f.x[:2, 0].copy())

pos_est = np.array(pos_est_samples, dtype=float).reshape(-1, 2)
        

In [None]:
# Display the static results: actual vs. measured vs. filtered trajectory.
# A single explicit Axes (the old 15-inch-wide 1x2 grid left the right panel empty).
fig, ax = plt.subplots(figsize=(8, 6))

ax.plot(pos_actual[:, 0], pos_actual[:, 1], label="Actual")
ax.plot(pos_meas[:, 0], pos_meas[:, 1], '.', label="Measured")
ax.plot(pos_est[:, 0], pos_est[:, 1], '--', label="Estimated")

ax.set_xlabel("X Position (m)")
ax.set_ylabel("Y Position (m)")
ax.set_title("Position Tracking")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Visualize: animate the trajectories one sample per frame.
fig, ax = plt.subplots(figsize=(10, 6))

# Axis limits: pad the actual trajectory by MARGIN on every side except the
# bottom, where the lowest actual y (ground) is the lower bound.
MARGIN = 5
x_min = np.min(pos_actual[:, 0]) - MARGIN
x_max = np.max(pos_actual[:, 0]) + MARGIN
y_min = np.min(pos_actual[:, 1])
y_max = np.max(pos_actual[:, 1]) + MARGIN
ax.set_xlim((x_min, x_max))
ax.set_ylim((y_min, y_max))

line_actual, = ax.plot([], [], lw=2, label="Actual")
line_meas, = ax.plot([], [], '.', lw=2, label="Measured")
line_est, = ax.plot([], [], '--', lw=2, label="Estimated")
lines = (line_actual, line_meas, line_est)


def init():
    """Clear all three artists; used by FuncAnimation as the blit baseline."""
    for line in lines:
        line.set_data([], [])
    return lines


def animate(i):
    """Draw frame `i`: every sample up to and including index `i`."""
    n = i + 1  # inclusive slice, so the final frame shows the last sample
    line_actual.set_data(pos_actual[:n, 0], pos_actual[:n, 1])
    line_meas.set_data(pos_meas[:n, 0], pos_meas[:n, 1])
    line_est.set_data(pos_est[:n, 0], pos_est[:n, 1])
    return lines


# One frame per recorded sample. The count comes from the data length rather
# than from max(timestamp)/TIME_STEP, whose float truncation can drop a frame.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=len(pos_actual), interval=TIME_STEP*1000,
                               blit=True)

ax.legend()
ax.set_title("2D Projectile Tracking using a Kalman filter")
ax.set_xlabel("X Pos (m)")
ax.set_ylabel("Y Pos (m)")
HTML(anim.to_jshtml())

In [None]:
# Export the animation as a GIF. The "pillow" writer relies on Pillow (a
# matplotlib dependency); "imagemagick" needs an external ImageMagick binary.
# The frame rate defaults to the animation's interval.
anim.save('MovWave.gif', writer="pillow")