In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Build a dataset

In [None]:
def draw_states_and_observations(
    rng_seed: int = 123,
    n_steps: int = 1000,
    sigma_a: float = 0.01,
    sigma_obs: float = 0.1,
):
    """Simulate a position/velocity trajectory and noisy position observations.

    Based on the truck example from Wikipedia's Kalman filter article. The
    state is ``[position, velocity]``. At every step the position advances by
    the current velocity, the velocity is multiplied by 0.9, and a random
    acceleration is added to both components. Only the position is observed,
    corrupted by additive Gaussian noise.

    NOTE(review): the 0.9 velocity decay appears to differ from the Wikipedia
    truck example (undamped velocity) -- confirm that it is intentional.

    Parameters
    ----------
    rng_seed : int
        Seed for the private ``np.random.RandomState`` used for all draws.
    n_steps : int
        Number of transitions to simulate after the initial state.
    sigma_a : float
        Standard deviation of the random acceleration.
    sigma_obs : float
        Standard deviation of the position observation noise.

    Returns
    -------
    (state_history, obs_history)
        Two lists of length ``n_steps + 1``: the states (arrays of shape
        ``(2,)``) and the corresponding scalar position observations. Index 0
        is the initial state ``[0, 0]`` and its observation.

    Raises
    ------
    ValueError
        If ``n_steps`` is negative.
    """
    if n_steps < 0:
        raise ValueError("n_steps must be non-negative, got {}".format(n_steps))

    # RNG setup
    rng = np.random.RandomState(rng_seed)

    # Model constants, built once instead of on every loop iteration.
    transition = np.array([[1, 1], [0, 0.9]])
    # Maps a scalar acceleration onto [position, velocity]; presumably
    # [dt**2 / 2, dt] with dt = 1.
    noise_gain = np.array([0.5, 1.0])

    def get_noisy_obs(state):
        """Return the position component of ``state`` plus observation noise."""
        v = rng.normal(loc=0.0, scale=sigma_obs, size=())
        return state[0] + v

    # Initial state. The draw order (observation noise for the initial state,
    # then process noise followed by observation noise at each step) determines
    # the output for a given seed, so it must not be reordered.
    state_history = [np.array([0.0, 0.0])]
    obs_history = [get_noisy_obs(state_history[0])]

    for _ in range(n_steps):
        old_state = state_history[-1]
        w = rng.normal(loc=0.0, scale=sigma_a, size=()) * noise_gain
        new_state = transition @ old_state + w
        state_history.append(new_state)
        obs_history.append(get_noisy_obs(new_state))

    return state_history, obs_history

In [None]:
# Simulate one trajectory with a fixed seed so that re-running the notebook
# yields the same states and observations.
state_history, obs_history = draw_states_and_observations(rng_seed=123)

In [None]:
# Compare the true position (first state component) with its noisy observations.
fig, ax = plt.subplots()
ax.plot([position for position, _ in state_history], label="True position")
ax.plot(obs_history, label="Noisy observation")
ax.set_xlabel("Time step")
ax.set_ylabel("Position")
ax.set_title("Simulated trajectory and observations")
ax.legend();

# Implement the Kalman Filter

In [None]:
class KalmanFilter:
    """Linear Kalman filter with Gaussian noise, using Wikipedia's notation.

    Assumed model, for a state of dimension ``n`` and an observation of
    dimension ``m``::

        x_k = F @ x_{k-1} + w_k,    w_k ~ N(0, Q)
        z_k = H @ x_k     + v_k,    v_k ~ N(0, R)

    Typical use is one ``predict()`` call per time step followed by one
    ``update(z)`` call when an observation is available. The current
    estimate is available through ``state_mean`` and ``state_covariance``.
    """

    def __init__(
        self,
        initial_state_mean: np.ndarray,
        initial_covariance: np.ndarray,
        state_transition_model: np.ndarray,
        process_noise_covariance: np.ndarray,
        observation_model: np.ndarray,
        observation_noise_covariance: np.ndarray
    ):
        """Store the model matrices and the initial estimate.

        Parameters
        ----------
        initial_state_mean : array of shape (n,)
        initial_covariance : array of shape (n, n)
        state_transition_model : array of shape (n, n)  (F)
        process_noise_covariance : array of shape (n, n)  (Q)
        observation_model : array of shape (m, n)  (H); a 1-D array of
            length n is treated as a single-row matrix.
        observation_noise_covariance : array of shape (m, m)  (R); a scalar
            is treated as a 1x1 matrix.

        Raises
        ------
        ValueError
            If any argument does not have the shape listed above.
        """
        # Use Wikipedia notation internally. np.array(...) copies the inputs so
        # later in-place changes by the caller cannot affect the filter.
        self._F = np.atleast_2d(np.array(state_transition_model, dtype=float))
        self._H = np.atleast_2d(np.array(observation_model, dtype=float))
        self._Q = np.atleast_2d(np.array(process_noise_covariance, dtype=float))
        self._R = np.atleast_2d(np.array(observation_noise_covariance, dtype=float))

        self._x = np.atleast_1d(np.array(initial_state_mean, dtype=float))
        self._P = np.atleast_2d(np.array(initial_covariance, dtype=float))

        if self._x.ndim != 1:
            raise ValueError(
                "initial_state_mean must be 1-D, got shape {}".format(self._x.shape)
            )
        n = self._x.shape[0]
        m = self._H.shape[0]
        expected_shapes = (
            ("initial_covariance", self._P, (n, n)),
            ("state_transition_model", self._F, (n, n)),
            ("process_noise_covariance", self._Q, (n, n)),
            ("observation_model", self._H, (m, n)),
            ("observation_noise_covariance", self._R, (m, m)),
        )
        for name, value, shape in expected_shapes:
            if value.shape != shape:
                raise ValueError(
                    "{} must have shape {}, got {}".format(name, shape, value.shape)
                )

    @property
    def state_mean(self) -> np.ndarray:
        """Copy of the current state mean estimate, shape (n,)."""
        return self._x.copy()

    @property
    def state_covariance(self) -> np.ndarray:
        """Copy of the current state covariance estimate, shape (n, n)."""
        return self._P.copy()

    def predict(self) -> np.ndarray:
        """Run a single timestep of the model and return the a priori estimates

        Propagates the mean and covariance through the transition model:
        ``x <- F x`` and ``P <- F P F^T + Q``.

        Returns
        -------
        np.ndarray
            Copy of the a priori state mean, shape (n,). The matching
            covariance is available as ``state_covariance``.
        """
        self._x = self._F @ self._x
        self._P = self._F @ self._P @ self._F.T + self._Q
        return self._x.copy()

    def update(self, observation: np.ndarray) -> np.ndarray:
        """Incorporate the observations and return the a posteriori estimates

        Parameters
        ----------
        observation : array of shape (m,), or a scalar when m == 1.

        Returns
        -------
        np.ndarray
            Copy of the a posteriori state mean, shape (n,). The matching
            covariance is available as ``state_covariance``.

        Raises
        ------
        ValueError
            If ``observation`` does not have shape (m,).
        numpy.linalg.LinAlgError
            If the innovation covariance ``H P H^T + R`` is singular.
        """
        z = np.atleast_1d(np.asarray(observation, dtype=float))
        m, n = self._H.shape
        if z.shape != (m,):
            raise ValueError(
                "observation must have shape {}, got {}".format((m,), z.shape)
            )

        H, R, P = self._H, self._R, self._P
        innovation = z - H @ self._x
        innovation_cov = H @ P @ H.T + R
        # Kalman gain K = P H^T S^-1, obtained by solving S^T K^T = H P^T
        # rather than forming an explicit inverse.
        gain = np.linalg.solve(innovation_cov.T, H @ P.T).T

        self._x = self._x + gain @ innovation
        # Joseph form: algebraically equal to (I - K H) P for the optimal gain,
        # but keeps P symmetric in the presence of rounding error.
        residual_map = np.eye(n) - gain @ H
        self._P = residual_map @ P @ residual_map.T + gain @ R @ gain.T
        return self._x.copy()