# Square-Root Unscented Kalman Filter Tutorial

This notebook walks through the SR-UKF Python bindings with a hands-on
pendulum tracking example.

## What is a Kalman Filter?

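A Kalman filter **optimally fuses** two imperfect information sources (optimal in the minimum-mean-squared-error sense when the system is linear and the noise is Gaussian):

1. A **model** predicting how the system evolves (whose errors accumulate over time)
2. **Sensor readings** (which are noisy)

By weighting each source according to its uncertainty, the filter produces an estimate with lower uncertainty than either source alone.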
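In the scalar, linear-Gaussian case this fusion reduces to an inverse-variance weighted average, which is exactly a one-dimensional Kalman update with a direct measurement. The sketch below uses a hypothetical `fuse` helper (not part of `srukf`) with made-up numbers to show that the fused variance is smaller than both inputs.

```python
def fuse(mean_a, var_a, mean_b, var_b):
    """Fuse two independent Gaussian estimates of the same scalar quantity.

    Hypothetical helper for illustration: the scalar Kalman update with a
    direct measurement (H = 1).
    """
    k = var_a / (var_a + var_b)            # gain: how far to move toward estimate b
    mean = mean_a + k * (mean_b - mean_a)
    var = (1.0 - k) * var_a                # equals var_a * var_b / (var_a + var_b)
    return mean, var


# Model predicts 1.00 rad (variance 0.04); the sensor reads 0.90 rad (variance 0.01)
m, v = fuse(1.00, 0.04, 0.90, 0.01)
print(f"fused mean = {m:.3f}, fused variance = {v:.4f}")  # → fused mean = 0.920, fused variance = 0.0080
```

The result sits closer to the more certain source (the sensor), and its variance of 0.008 is below both 0.04 and 0.01.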

## Why Unscented?

The classic Kalman filter assumes **linear** dynamics and measurements. For nonlinear
systems, the **Unscented Kalman Filter** (UKF) uses a clever trick:

> "It's easier to approximate a probability distribution than to
> approximate an arbitrary nonlinear function."

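Instead of linearizing the model (as the Extended Kalman Filter, or EKF, does),
the UKF picks a small, deterministic set of **sigma points** (2N + 1 of them for
an N-dimensional state) that capture the mean and covariance, propagates each
point through the nonlinear function exactly, and then reconstructs the mean
and covariance from weighted sums of the results.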
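To make the sigma-point idea concrete, here is a minimal NumPy sketch of the scaled unscented transform. It is not the `srukf` implementation: `unscented_transform` is a hypothetical helper using the common scaled sigma-point weights with the `alpha`, `beta`, `kappa` parameters listed in the tuning table at the end. It compares the UT estimate of E[sin X] for X ~ N(π/4, 0.5²) against the exact value sin(μ)·exp(−σ²/2) and against EKF-style linearization, which simply evaluates sin(μ).

```python
import numpy as np


def unscented_transform(f, mean, cov, alpha=1.0, beta=2.0, kappa=None):
    """Propagate N(mean, cov) through f with 2n+1 scaled sigma points.

    Hypothetical helper for illustration; not part of the srukf package.
    """
    mean = np.atleast_1d(np.asarray(mean, dtype=float))
    cov = np.atleast_2d(np.asarray(cov, dtype=float))
    n = mean.size
    if kappa is None:
        kappa = 3.0 - n                       # classic heuristic
    lam = alpha**2 * (n + kappa) - n
    # Columns of the scaled Cholesky factor are the sigma-point offsets
    offsets = np.linalg.cholesky((n + lam) * cov).T          # row i = column i
    sigmas = np.vstack([mean, mean + offsets, mean - offsets])  # (2n+1, n)

    wm = np.full(2 * n + 1, 0.5 / (n + lam))  # mean weights
    wc = wm.copy()                            # covariance weights
    wm[0] = lam / (n + lam)
    wc[0] = wm[0] + (1.0 - alpha**2 + beta)

    ys = np.array([np.atleast_1d(f(s)) for s in sigmas])     # (2n+1, m)
    y_mean = wm @ ys
    dev = ys - y_mean
    y_cov = (wc[:, None] * dev).T @ dev
    return y_mean, y_cov


mu, sigma = np.pi / 4, 0.5
ut_mean, _ = unscented_transform(np.sin, [mu], [[sigma**2]])
exact = np.sin(mu) * np.exp(-sigma**2 / 2)   # E[sin X] for X ~ N(mu, sigma^2)
linearized = np.sin(mu)                      # EKF-style: f evaluated at the mean
print(f"exact={exact:.4f}  UT={ut_mean[0]:.4f}  linearized={linearized:.4f}")
```

Three function evaluations are enough for the UT mean to land close to the exact value, while linearization ignores the curvature of sin and overestimates it. For a linear function the transform reproduces the mean and covariance exactly.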

In [None]:
import numpy as np
import math
from srukf import UnscentedKalmanFilter

## Step 1: Define the System

We'll track a damped pendulum with state `[angle, angular_velocity]`.
We can only measure the angle (with noise).
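For reference, here is the model in equation form, with g = `G`, L = `L` and b = `B` from the code cell that follows, and the 0.1 rad measurement noise used in Step 3. The process model discretizes the first two lines with one RK4 step of length Δt.

$$
\begin{aligned}
\dot\theta &= \omega, \\
\dot\omega &= -\frac{g}{L}\sin\theta - b\,\omega, \\
z_k &= \theta_k + v_k, \qquad v_k \sim \mathcal{N}(0,\ 0.1^2).
\end{aligned}
$$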

In [None]:
G, L, B = 9.81, 1.0, 0.1  # gravity, length, damping

def pendulum_rk4(theta, omega, dt):
    """RK4 integration of the pendulum ODE."""
    def deriv(th, om):
        return om, -(G/L)*math.sin(th) - B*om
    k1t, k1o = deriv(theta, omega)
    k2t, k2o = deriv(theta + k1t*dt/2, omega + k1o*dt/2)
    k3t, k3o = deriv(theta + k2t*dt/2, omega + k2o*dt/2)
    k4t, k4o = deriv(theta + k3t*dt, omega + k3o*dt)
    return (
        theta + (dt/6)*(k1t + 2*k2t + 2*k3t + k4t),
        omega + (dt/6)*(k1o + 2*k2o + 2*k3o + k4o),
    )

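def process_model(x, dt=0.01):
    """State transition: advance [theta, omega] by one RK4 step of length dt."""
    th, om = pendulum_rk4(x[0], x[1], dt)
    return np.array([th, om])

def measurement_model(x):
    """Measurement function: the sensor observes the angle only."""
    return np.array([x[0]])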
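Before filtering, it is worth checking that the process model behaves physically. The sketch below re-implements the same RK4 scheme in vector form (hypothetical `deriv`, `rk4_step` and `energy` helpers, so the cell runs on its own) and compares the mechanical energy per unit mass after 10 s with and without damping. With `B = 0.1` the energy should drain away; with no damping, RK4 at `dt = 0.01` should conserve it closely.

```python
import math

import numpy as np

G, L = 9.81, 1.0  # same gravity and length as the cell above


def deriv(x, b):
    """Time derivative of [theta, omega] for a pendulum with damping b."""
    theta, omega = x
    return np.array([omega, -(G / L) * math.sin(theta) - b * omega])


def rk4_step(x, dt, b):
    """One classic RK4 step (same scheme as pendulum_rk4, in vector form)."""
    k1 = deriv(x, b)
    k2 = deriv(x + dt / 2 * k1, b)
    k3 = deriv(x + dt / 2 * k2, b)
    k4 = deriv(x + dt * k3, b)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def energy(x):
    """Mechanical energy per unit mass: kinetic + potential (zero at the bottom)."""
    theta, omega = x
    return 0.5 * (L * omega) ** 2 + G * L * (1 - math.cos(theta))


def simulate(b, t_end=10.0, dt=0.01):
    """Return (initial energy, energy after t_end) starting from rest at pi/4."""
    x0 = np.array([math.pi / 4, 0.0])
    x = x0.copy()
    for _ in range(round(t_end / dt)):
        x = rk4_step(x, dt, b)
    return energy(x0), energy(x)


e0, e_damped = simulate(b=0.1)
_, e_undamped = simulate(b=0.0)
print(f"E0={e0:.4f}  damped={e_damped:.4f}  undamped={e_undamped:.4f}")
```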

## Step 2: Create the Filter

In [None]:
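ukf = UnscentedKalmanFilter(
    state_dim=2,
    meas_dim=1,
    process_noise_sqrt=0.01 * np.eye(2),  # square root of the process noise covariance Q
    meas_noise_sqrt=np.array([[0.1]]),    # matches the 0.1 rad sensor noise simulated below
)

# Set initial conditions
ukf.x = np.array([math.pi/4 + 0.05, 0.05])  # slightly wrong guess (true start: [pi/4, 0])
ukf.S = 0.2 * np.eye(2)  # square-root covariance: P = S @ S.T = 0.04 * I (std 0.2 per state)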

print(f"Initial state: {ukf.x}")
print(f"Initial P trace: {np.trace(ukf.P):.4f}")
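The filter stores the square root `S` of the covariance rather than `P` itself. Assuming the bindings follow the usual convention P = S·Sᵀ (consistent with the `ukf.P` value printed above, but an assumption here), a full covariance matrix can be turned into a valid `S` with a Cholesky factorization. The correlated prior `P0` below is made up for illustration.

```python
import numpy as np

# Hypothetical prior: angle std 0.2 rad, velocity std 0.5 rad/s, correlation 0.3
# -> off-diagonal term 0.3 * 0.2 * 0.5 = 0.03
P0 = np.array([[0.04, 0.03],
               [0.03, 0.25]])
S0 = np.linalg.cholesky(P0)       # lower-triangular factor with P0 = S0 @ S0.T
print(S0)
print(np.allclose(S0 @ S0.T, P0))

# The diagonal choice used above corresponds to a plain variance of 0.04 per state
S_diag = 0.2 * np.eye(2)
print(S_diag @ S_diag.T)
```

Under the same assumption, `ukf.S = S0` would install this correlated prior; if the bindings expect an upper-triangular factor instead, `S0.T` would be the one to pass.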

## Step 3: Run the Filter

In [None]:
rng = np.random.default_rng(42)
dt = 0.01
n_steps = 1000
meas_every = round(0.05 / dt)  # 20 Hz sensor: one measurement every 5 steps

theta_true, omega_true = math.pi/4, 0.0

results = {'time': [], 'true': [], 'meas': [], 'est': [], 'uncertainty': []}

for step in range(n_steps):
    t = (step + 1) * dt  # time at the end of this step

    # True dynamics, with a small random disturbance on each state
    theta_true, omega_true = pendulum_rk4(
        theta_true + rng.normal(0, 0.01*dt),
        omega_true + rng.normal(0, 0.01*dt), dt)

    # Predict
    ukf.predict(process_model, dt=dt)

    # Measure at 20 Hz (counting steps avoids floating-point drift in a time accumulator)
    if (step + 1) % meas_every == 0:
        z = np.array([theta_true + rng.normal(0, 0.1)])  # sensor noise std = 0.1 rad
        ukf.update(z, measurement_model)
        results['meas'].append((t, z[0]))

    results['time'].append(t)
    results['true'].append(theta_true)
    results['est'].append(ukf.x[0])
    results['uncertainty'].append(math.sqrt(ukf.P[0, 0]))  # 1-sigma angle std (rad)

print(f"Final state: {ukf.x}")
print(f"Final P trace: {np.trace(ukf.P):.6f}")
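To put a number on "better than either source alone", one can compare the filter's error with the raw measurement error at the measurement times. The hypothetical `accuracy_summary` helper below assumes the `results` layout built in the loop above (parallel lists under `'time'`, `'true'`, `'est'`, and `(t, z)` pairs under `'meas'` that reuse the same `t` values). It is exercised here on a tiny hand-made dictionary.

```python
import numpy as np


def accuracy_summary(results):
    """Return (RMSE of raw measurements, RMSE of filter estimate) at measurement times.

    Hypothetical helper; assumes the `results` dict layout from the loop above.
    """
    true_at = dict(zip(results['time'], results['true']))
    est_at = dict(zip(results['time'], results['est']))
    ts = [t for t, _ in results['meas']]
    zs = np.array([zval for _, zval in results['meas']])
    truth = np.array([true_at[t] for t in ts])
    est = np.array([est_at[t] for t in ts])
    rmse_meas = float(np.sqrt(np.mean((zs - truth) ** 2)))
    rmse_est = float(np.sqrt(np.mean((est - truth) ** 2)))
    return rmse_meas, rmse_est


# Tiny hand-made example (made-up values) with the same layout
toy = {'time': [0.05, 0.10], 'true': [0.50, 0.52], 'est': [0.51, 0.52],
       'meas': [(0.05, 0.60), (0.10, 0.45)]}
rmse_meas, rmse_est = accuracy_summary(toy)
print(f"measurement RMSE = {rmse_meas:.4f} rad, filter RMSE = {rmse_est:.4f} rad")
```

Calling `accuracy_summary(results)` after the Step 3 loop gives the same comparison for the pendulum run.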

## Step 4: Visualize Results

In [None]:
try:
    import matplotlib.pyplot as plt
    
    time = np.array(results['time'])
    true = np.degrees(results['true'])
    est = np.degrees(results['est'])
    unc = np.degrees(results['uncertainty'])
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    
    ax1.plot(time, true, 'g-', alpha=0.8, label='True')
    if results['meas']:
        mt, mz = zip(*results['meas'])
        ax1.scatter(mt, np.degrees(mz), c='red', s=8, alpha=0.5, label='Measurements')
    ax1.plot(time, est, 'b-', lw=2, label='SR-UKF')
    ax1.fill_between(time, est-unc, est+unc, alpha=0.2, color='blue', label='±1σ')
    ax1.set_ylabel('Angle (deg)')
    ax1.set_title('Pendulum Tracking')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    err = np.degrees(np.abs(np.array(results['true']) - np.array(results['est'])))
    ax2.plot(time, err, 'r-')
    ax2.set_xlabel('Time (s)')
    ax2.set_ylabel('Absolute error (deg)')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
except ImportError:
    print('Install matplotlib for plots: pip install matplotlib')
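The shaded band can also be checked numerically. If the filter is consistent and its errors are roughly Gaussian, the true angle should fall inside the ±1σ band about 68% of the time; a much lower fraction suggests the filter is overconfident (noise parameters set too small). The hypothetical `sigma_coverage` helper below computes that fraction and is validated on synthetic errors drawn exactly from N(0, σ²).

```python
import numpy as np


def sigma_coverage(true, est, sigma, k=1.0):
    """Fraction of steps where |true - est| <= k * sigma (hypothetical helper)."""
    true, est, sigma = (np.asarray(a, dtype=float) for a in (true, est, sigma))
    return float(np.mean(np.abs(true - est) <= k * sigma))


# Synthetic check: errors drawn exactly from N(0, sigma^2) should land near 0.683
rng = np.random.default_rng(0)
sigma = np.full(100_000, 0.05)
est = np.zeros_like(sigma)
true = est + rng.normal(0.0, sigma)
frac = sigma_coverage(true, est, sigma)
print(f"fraction inside ±1σ: {frac:.3f}")
```

For the pendulum run, `sigma_coverage(results['true'], results['est'], results['uncertainty'])` applies the same test to the recorded angle estimates (all in radians).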

## Parameter Tuning Guide

| Parameter | Typical range | Effect |
|-----------|---------------|--------|
| `process_noise_sqrt` | 0.001 to 1.0 (problem-dependent, in state units) | Larger = trust the model less, adapt faster |
| `meas_noise_sqrt` | 0.01 to 10.0 (problem-dependent, in measurement units) | Larger = trust measurements less, smoother estimate |
| `alpha` | 1e-3 to 1.0 | Sigma-point spread; smaller = points tighter around the mean |
| `beta` | 2.0 | Optimal for Gaussian distributions; rarely needs changing |
| `kappa` | 0 or 3 − N (N = state dimension) | Secondary scaling; 0 is usually fine |
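The first two rows can be illustrated in the simplest possible setting: a scalar linear random walk observed directly. This is a plain Kalman filter rather than the SR-UKF, so treat it as an analogy; the hypothetical `steady_state_gain` helper solves the scalar Riccati equation in closed form, with `q` and `r` playing the role of the `*_sqrt` parameters (standard deviations, so Q = q² and R = r²).

```python
import math


def steady_state_gain(q, r):
    """Steady-state Kalman gain for a scalar random walk (hypothetical helper).

    Model: x_k = x_{k-1} + w_k, z_k = x_k + v_k with std(w) = q, std(v) = r.
    Solves M**2 - Q*M - Q*R = 0 for the predicted variance M, then K = M / (M + R).
    """
    Q, R = q * q, r * r
    M = (Q + math.sqrt(Q * Q + 4.0 * Q * R)) / 2.0
    return M / (M + R)


for q, r in [(0.01, 0.1), (0.1, 0.1), (0.01, 1.0)]:
    print(f"q={q:<5} r={r:<5} K={steady_state_gain(q, r):.3f}")
```

A larger `q` raises the gain (the estimate follows new measurements more closely and adapts faster), while a larger `r` lowers it (a smoother, slower estimate). In this scalar case the gain depends only on the ratio q/r; for q = r it equals (√5 − 1)/2 ≈ 0.618.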