# Practical 2

Within this practical we consider data assimilation for a more complicated dynamical system, namely a double pendulum. The state space for this system is four-dimensional, and with this increase the grid-based methods 
used in the first practical become impractical: a single push-forward of a PDF with a resolution of 300 points along each axis would now require $300^4 = 8.1\times 10^{9}$ integrations of the equations of motion, 
while each `ProbabilityGrid` would need about 65 GB of memory to store the PDF values (at 8 bytes per value). To make progress, we instead need to consider more efficient approximate methods, with the focus here being on the ensemble Kalman filter. 
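The figures quoted above follow from simple arithmetic. The sketch below reproduces them, assuming the PDF is stored as one 8-byte float per grid node (the internal layout of `ProbabilityGrid` is not shown here), and contrasts this with storing an ensemble of 200 four-dimensional states, the ensemble size used later in this practical.

```python
# Cost of one grid-based push-forward in a 4D state space, assuming
# 300 grid points per axis and one float64 (8-byte) PDF value per node.
points_per_axis = 300
state_dim = 4
bytes_per_value = 8

n_nodes = points_per_axis ** state_dim           # one integration per node
grid_memory_gb = n_nodes * bytes_per_value / 1e9

# For contrast: an ensemble of 200 states needs only 200 x 4 numbers.
n_ensemble = 200
ensemble_memory_kb = n_ensemble * state_dim * bytes_per_value / 1e3

print(f"Grid nodes (integrations): {n_nodes:,}")
print(f"Grid PDF storage: {grid_memory_gb:.1f} GB")
print(f"Ensemble storage: {ensemble_memory_kb:.1f} kB")
```

A push-forward of the ensemble likewise needs only one integration per member (200 here) rather than one per grid node.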

A further aspect of the double pendulum is that its dynamics can be chaotic, a property that has a profound effect on the accuracy of the forecasts that can be generated. 

In [None]:
# Import the necessary libraries for this notebook, 
# installing pygeoinf if required. 
try:
    from pygeoinf import data_assimilation as da
except ImportError: 
    %pip install pygeoinf --quiet
    from pygeoinf import data_assimilation as da

import numpy as np
import matplotlib.pyplot as plt
from pygeoinf.data_assimilation.pendulum import double

plt.rcParams['animation.embed_limit'] = 100.0

## Part 1 - Chaotic dynamics

We start by setting up the double pendulum system and animating its dynamics for a given initial condition. The code below starts both arms at $120^\circ$ from the downward vertical with zero momenta, well away from the stable equilibrium. To see the rather simple behaviour close to the stable equilibrium, first set both angles to a few degrees; then experiment with larger initial perturbations, such as the default, to see what happens. 

In [None]:
# =============================================================================
# 1. PHYSICAL SETUP
# =============================================================================
# We define a standard double pendulum with equal lengths and masses. You can 
# later change these parameters if you wish. 
params = {"L1": 1.0, "L2": 1.0, "m1": 1.0, "m2": 1.0, "g": 1.0}

# Unpack for passing to the solver later
physics_args = (params["L1"], params["L2"], params["m1"], params["m2"], params["g"])


# =============================================================================
# 2. INITIAL CONDITIONS (STUDENT EXPERIMENT)
# =============================================================================
# The state vector is 4D: [theta1, theta2, p1, p2]
# Both angles are measured from the downward vertical.
# p1, p2: generalised momenta (related to the angular velocities)

theta1_init = np.deg2rad(120)
theta2_init = np.deg2rad(120)

p1_init = 0.0
p2_init = 0.0

# Create the state vector y0
y0 = np.array([theta1_init, theta2_init, p1_init, p2_init])

print(f"Initial State: {y0}")
print("Integration started...")


# =============================================================================
# 3. SOLVE DYNAMICS
# =============================================================================
# We integrate the system for 30 seconds, sampled at fps frames per second.

fps = 20
t0 = 0
t1 = 30
t_points = np.linspace(t0, t1, round(fps*(t1-t0)) + 1)  # spacing of exactly 1/fps

# Solve using the library function.
# double.eom is the Equation of Motion for the system
solution = da.solve_trajectory(
    eom_func=double.eom,
    y0=y0,
    t_points=t_points,
    args=physics_args,
    rtol=1e-10,  # High precision required for chaotic systems
    atol=1e-12,
)

print("Integration complete.")


# =============================================================================
# 4. VISUALISATION
# =============================================================================
# We use the built-in animator for the double pendulum.
# It includes a fading trail to help visualise the complexity of the path.

print("Generating Animation...")

anim = double.animate_pendulum(
    t_points=t_points,
    solution=solution,
    L1=params["L1"],
    L2=params["L2"],
    trail_len=50,  # Length of the "tail" following the bob
)
plt.close()

# Display the animation inline in the notebook as HTML
da.display_animation_html(anim)




You should now have seen that, for a sufficiently large initial perturbation, the dynamics of the double pendulum can be very complicated. We can get a further sense of this property by looking at the evolution of the system in phase space. To do so, we draw samples from a tightly peaked probability distribution for the initial state and then animate the evolution of this ensemble of states forward in time. Since we cannot plot the four-dimensional phase space directly, we instead make separate angle-momentum plots for the two bobs and show them side by side. 
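To put a number on how quickly the ensemble disperses, one can track its spread over time. The helpers below are a minimal sketch that assumes trajectories are stored with shape `(n_ensemble, state_dim, n_times)`, the layout described in the hint for `da.solve_ensemble` later in this notebook; `ensemble_spread` and `total_spread` are hypothetical names, not pygeoinf functions, and the synthetic array only stands in for real solver output.

```python
import numpy as np

def ensemble_spread(ensemble_trajectories):
    """Standard deviation of each state component across the ensemble.

    Input shape (n_ensemble, state_dim, n_times); output shape (state_dim, n_times).
    """
    return np.std(ensemble_trajectories, axis=0)

def total_spread(ensemble_trajectories):
    """Root-mean-square distance of the members from the ensemble mean at each time.

    Input shape (n_ensemble, state_dim, n_times); output shape (n_times,).
    """
    mean = ensemble_trajectories.mean(axis=0, keepdims=True)
    sq_dist = np.sum((ensemble_trajectories - mean) ** 2, axis=1)  # (n_ensemble, n_times)
    return np.sqrt(sq_dist.mean(axis=0))

# Synthetic stand-in for solver output: 500 members, 4 components, 6 times,
# with a spread that grows like exp(t) purely for illustration.
rng = np.random.default_rng(0)
times = np.linspace(0.0, 5.0, 6)
fake = 0.01 * np.exp(times) * rng.standard_normal((500, 4, times.size))
print(ensemble_spread(fake).shape)  # (4, 6)
print(total_spread(fake).round(4))
```

Applied to the `ensemble_trajectories` computed in the next cell, a plot of `np.log(total_spread(ensemble_trajectories))` against `t_eval` would indicate whether the spread grows roughly exponentially before it saturates.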

In [None]:
# 1. Define the Initial Distribution as a tightly peaked Gaussian
mean_state = np.array([np.deg2rad(120), np.deg2rad(120), 0.0, 0.0])
sigma = 0.01
covariance = np.diag([sigma**2, sigma**2, sigma**2, sigma**2]) 

# 2. Sample the Ensemble
n_ensemble = 100
rng = np.random.default_rng(seed=42)
initial_ensemble = rng.multivariate_normal(mean_state, covariance, size=n_ensemble)

# 3. Propagate Forward
t0 = 0.0
t1 = 30.0
dt = 0.05
t_eval = np.arange(t0, t1, dt)

print(f"Propagating {n_ensemble} particles... this may take a moment.")
ensemble_trajectories = da.solve_ensemble(
    eom_func=double.physics.eom,
    initial_conditions=initial_ensemble,
    t_points=t_eval,
    args=physics_args
)

# 4. Animate in Phase Space using the library function
anim = double.animate_ensemble_phase_space(
    t_points=t_eval,
    ensemble_trajectories=ensemble_trajectories
)
plt.close()

da.display_animation_html(anim)

The results of this section show, in a qualitative sense, that the dynamics of the double pendulum can be chaotic. Within this course we have not discussed this idea in detail, but a key aspect is that the system's evolution is very sensitive to its initial conditions, as is seen clearly in the behaviour of an initially tightly packed ensemble of states. This behaviour places limits on our ability to accurately forecast the future state of such a system.  
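The sensitivity to initial conditions can also be seen with just two trajectories. The sketch below is self-contained rather than using pygeoinf: it writes Hamilton's equations for a point-mass double pendulum in the same variables $(\theta_1, \theta_2, p_1, p_2)$, assuming $p_1, p_2$ are the canonical momenta (it has not been checked that this matches `double.eom` exactly), integrates them with a fixed-step fourth-order Runge-Kutta scheme, and compares two initial states that differ by $10^{-6}$ in $\theta_1$. Conservation of the total energy serves as a check on the integrator.

```python
import numpy as np

def rhs(y, L1=1.0, L2=1.0, m1=1.0, m2=1.0, g=1.0):
    """Hamilton's equations; y = [theta1, theta2, p1, p2] (assumed convention)."""
    th1, th2, p1, p2 = y
    d = th1 - th2
    D = m1 + m2 * np.sin(d) ** 2
    dth1 = (L2 * p1 - L1 * p2 * np.cos(d)) / (L1**2 * L2 * D)
    dth2 = ((m1 + m2) * L1 * p2 - m2 * L2 * p1 * np.cos(d)) / (m2 * L1 * L2**2 * D)
    N = m2 * L2**2 * p1**2 + (m1 + m2) * L1**2 * p2**2 - 2 * m2 * L1 * L2 * p1 * p2 * np.cos(d)
    c1 = p1 * p2 * np.sin(d) / (L1 * L2 * D)
    c2 = N * np.sin(2 * d) / (2 * L1**2 * L2**2 * D**2)
    dp1 = -(m1 + m2) * g * L1 * np.sin(th1) - c1 + c2
    dp2 = -m2 * g * L2 * np.sin(th2) + c1 - c2
    return np.array([dth1, dth2, dp1, dp2])

def hamiltonian(y, L1=1.0, L2=1.0, m1=1.0, m2=1.0, g=1.0):
    """Total energy: kinetic (written in the momenta) plus gravitational potential."""
    th1, th2, p1, p2 = y
    d = th1 - th2
    D = m1 + m2 * np.sin(d) ** 2
    N = m2 * L2**2 * p1**2 + (m1 + m2) * L1**2 * p2**2 - 2 * m2 * L1 * L2 * p1 * p2 * np.cos(d)
    kinetic = N / (2 * m2 * L1**2 * L2**2 * D)
    potential = -(m1 + m2) * g * L1 * np.cos(th1) - m2 * g * L2 * np.cos(th2)
    return kinetic + potential

def rk4(f, y0, dt, n_steps):
    """Fixed-step classical Runge-Kutta; returns states of shape (n_steps + 1, 4)."""
    ys = np.empty((n_steps + 1, y0.size))
    ys[0] = y0
    y = y0.copy()
    for k in range(n_steps):
        k1 = f(y)
        k2 = f(y + 0.5 * dt * k1)
        k3 = f(y + 0.5 * dt * k2)
        k4 = f(y + dt * k3)
        y = y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        ys[k + 1] = y
    return ys

dt, t_end = 0.005, 30.0
n_steps = round(t_end / dt)
y_a = np.array([np.deg2rad(120), np.deg2rad(120), 0.0, 0.0])
y_b = y_a + np.array([1e-6, 0.0, 0.0, 0.0])  # tiny perturbation of theta1

traj_a = rk4(rhs, y_a, dt, n_steps)
traj_b = rk4(rhs, y_b, dt, n_steps)
separation = np.linalg.norm(traj_a - traj_b, axis=1)
energy_drift = abs(hamiltonian(traj_a[-1]) - hamiltonian(traj_a[0]))

for t in (0, 10, 20, 30):
    print(f"t = {t:2d}: separation = {separation[round(t / dt)]:.2e}")
print(f"energy drift of trajectory a: {energy_drift:.1e}")
```

A fixed-step integrator keeps the sketch dependency-free; the notebook's own solver, with its tight `rtol` and `atol`, is the better choice for real work.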

## Part 2 - The Ensemble Kalman filter

We have seen in Lecture 2 that the ensemble Kalman filter provides a workable means of applying data assimilation to non-linear dynamical systems with high-dimensional state spaces. The aim of this section is for you to apply this method to a synthetic data set for the double pendulum system. 
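For reference, one common form of the EnKF analysis step is the stochastic, or perturbed-observation, update summarised below. It is written in the symbols used by the code in this notebook (where the forecast sample covariance is called $Q$); the notation in Lecture 2 may differ. Given a forecast ensemble $x_1^f, \dots, x_N^f$, an observation $y$ with error covariance $R$, and a linear observation operator $H$:

$$
\bar{x}^f = \frac{1}{N}\sum_{i=1}^{N} x_i^f, \qquad
Q = \frac{1}{N-1}\sum_{i=1}^{N}\left(x_i^f - \bar{x}^f\right)\left(x_i^f - \bar{x}^f\right)^T,
$$

$$
K = Q H^T\left(H Q H^T + R\right)^{-1}, \qquad
x_i^a = x_i^f + K\left(y + \epsilon_i - H x_i^f\right), \quad \epsilon_i \sim \mathcal{N}(0, R).
$$

Only the ensemble members themselves are propagated through the non-linear equations of motion; $Q$ is re-estimated from the ensemble at each observation time, so no explicit linearisation of the dynamics is needed.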

The code block below generates synthetic data for you to work with. At each observation time you are provided with noisy observations of the angles of both pendulum bobs. The ultimate aim is to produce a forecast of the pendulum's state at times beyond the final set of observations. An issue you will want to consider is the length of time over which your forecast remains accurate enough to be useful. 
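Conceptually, synthetic observations of this kind are the true state mapped through the observation operator plus Gaussian noise, $y_k = H x(t_k) + e_k$ with $e_k \sim \mathcal{N}(0, R)$. The function below is a minimal sketch of that idea under this assumption; it is not the implementation of `BayesianAssimilationProblem.generate_synthetic_data`, whose internals are not shown here, and `make_synthetic_observations` is a hypothetical name.

```python
import numpy as np

def make_synthetic_observations(true_states, H, R, rng):
    """Noisy observations y_k = H x_k + e_k, with e_k ~ N(0, R).

    true_states: array of shape (n_obs, state_dim), the truth at the observation times.
    Returns an array of shape (n_obs, obs_dim).
    """
    noise = rng.multivariate_normal(np.zeros(R.shape[0]), R, size=true_states.shape[0])
    return true_states @ H.T + noise

# Observe only the two angles, each with 10 degrees of noise, as in the cell below.
H = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0]])
obs_sigma = np.deg2rad(10)
R = np.diag([obs_sigma**2, obs_sigma**2])

rng = np.random.default_rng(0)
# Placeholder truth at five observation times (illustrative values only)
true_states = np.tile([np.deg2rad(120), np.deg2rad(120), 0.0, 0.0], (5, 1))
y_obs = make_synthetic_observations(true_states, H, R, rng)
print(y_obs.shape)  # (5, 2)
```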

In [None]:
# 1. Physical Parameters
# We use the same standard double pendulum setup
params = {
    'L1': 1.0, 'L2': 1.0, 
    'm1': 1.0, 'm2': 1.0, 
    'g': 1.0
}

# 2. "True" Initial Condition
# We stick with the chaotic regime: both arms start at 120 degrees from the downward vertical
# [theta1, theta2, p1, p2]
true_y0 = np.array([np.deg2rad(120), np.deg2rad(120), 0.0, 0.0])

# 3. Setup the Problem Manager
problem = da.BayesianAssimilationProblem(
    eom_func=double.eom,
    eom_args=(params['L1'], params['L2'], params['m1'], params['m2'], params['g'])
)

# 4. Define Observations (Multivariate)
# We observe BOTH angles (theta1, theta2) every dt_obs = 5 seconds, n_obs = 5 times,
# i.e. at t = 5, 10, 15, 20 and 25 s; observing then stops to leave a forecast period.
n_obs = 5
dt_obs = 5

t_obs_points = np.linspace(
    dt_obs,
    dt_obs*n_obs, 
    n_obs
)

# Observation Noise (Standard Deviation)
obs_sigma = np.deg2rad(10)

# Covariance Matrix R (2x2)
# We assume independent noise for each angle
R = np.diag([obs_sigma**2, obs_sigma**2])

# Observation Operator H (2x4)
# Maps state [th1, th2, p1, p2] -> observations [th1, th2]
H = np.array([
    [1.0, 0.0, 0.0, 0.0],  # Selects theta1
    [0.0, 1.0, 0.0, 0.0]   # Selects theta2
])

print(f"Generating synthetic data with {len(t_obs_points)} observation times...")

# Register observations
for t in t_obs_points:
    problem.add_observation(time=t, covariance=R, operator=H)

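# 5. Generate Truth and Synthetic Observations
# The true trajectory is simulated and noisy observations are drawn at the
# registered times. An extended truth covering the forecast window is
# simulated separately in the EnKF cell below.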
truth_data = problem.generate_synthetic_data(
    true_initial_condition=true_y0,
    dt_render=0.01,    
)

# Extract data for plotting
t_true = truth_data['t_ground_truth']
y_true = truth_data['state_ground_truth']

# Extract observations
# Note: each y_obs is a vector of length 2 (theta1, theta2) at each observation time
obs_times = [t for t, _ in problem.observations]
obs_vals = np.array([m.y_obs for _, m in problem.observations]) # Shape (N_obs, 2)


# =============================================================================
# VISUALISATION: Truth vs Observations
# =============================================================================
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

# Plot Theta 1
ax1.plot(t_true, y_true[0], 'k-', lw=1.5, label='True Trajectory')
ax1.errorbar(
    obs_times, obs_vals[:, 0], 
    yerr=obs_sigma, fmt='rx', capsize=4, label='Observations'
)
ax1.set_ylabel(r'$\theta_1$')
ax1.set_title('Double Pendulum: Tracking & Forecasting Problem')
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)

# Plot Theta 2
ax2.plot(t_true, y_true[1], 'k-', lw=1.5, label='True Trajectory')
ax2.errorbar(
    obs_times, obs_vals[:, 1], 
    yerr=obs_sigma, fmt='rx', capsize=4, label='Observations'
)
ax2.set_ylabel(r'$\theta_2$')
ax2.set_xlabel('Time [s]')
ax2.grid(True, alpha=0.3)


plt.tight_layout()
plt.show()

The next code block provides the basic structure for your assimilation loop, with hints highlighting useful methods to apply within your implementation. Once you have implemented the loop, your forecast of the state will be visualised and compared with the true trajectory.  
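As an alternative to the member-by-member loop in Step B, the whole analysis can be written with array operations. The sketch below implements the stochastic (perturbed-observation) update with the same inputs as the loop (`forecast_ensemble`, `y_obs_k`, `H`, `R`, `rng`); `enkf_analysis` is a hypothetical helper name, and it computes the gain with `np.linalg.solve` rather than an explicit inverse, which is generally better conditioned. The check at the end compares it with the exact Kalman update for a linear-Gaussian toy problem, where a large ensemble should approximately reproduce the exact posterior mean and variance.

```python
import numpy as np

def enkf_analysis(forecast_ensemble, y_obs, H, R, rng):
    """Stochastic EnKF analysis step.

    forecast_ensemble: (n_ensemble, state_dim), one member per row.
    y_obs: (obs_dim,); H: (obs_dim, state_dim); R: (obs_dim, obs_dim).
    Returns the analysis ensemble, same shape as forecast_ensemble.
    """
    n_ens = forecast_ensemble.shape[0]
    # One perturbed observation per member: y + eps_i, eps_i ~ N(0, R)
    perturbed = y_obs + rng.multivariate_normal(np.zeros(len(y_obs)), R, size=n_ens)
    # Forecast sample covariance (called Q in this notebook)
    Q = np.atleast_2d(np.cov(forecast_ensemble, rowvar=False))
    S = H @ Q @ H.T + R
    # K = Q H^T S^{-1}; since Q and S are symmetric, K^T = S^{-1} H Q
    K = np.linalg.solve(S, H @ Q).T
    innovations = perturbed - forecast_ensemble @ H.T  # (n_ens, obs_dim)
    return forecast_ensemble + innovations @ K.T

# Linear-Gaussian check: 2D state, only the first component observed.
rng = np.random.default_rng(0)
P = np.array([[1.0, 0.8], [0.8, 1.0]])
H = np.array([[1.0, 0.0]])
R = np.array([[0.5]])
y = np.array([1.0])

prior = rng.multivariate_normal(np.zeros(2), P, size=20000)
analysis = enkf_analysis(prior, y, H, R, rng)

K_exact = P @ H.T @ np.linalg.inv(H @ P @ H.T + R)
mean_exact = K_exact @ y  # the prior mean is zero
cov_exact = (np.eye(2) - K_exact @ H) @ P
print("EnKF mean:", analysis.mean(axis=0).round(3), " exact:", mean_exact.round(3))
print("EnKF var :", analysis.var(axis=0).round(3), " exact:", np.diag(cov_exact).round(3))
```

Inside the loop, `analysis_ensemble = enkf_analysis(forecast_ensemble, y_obs_k, H, R, rng)` would take the place of the four numbered sub-steps.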

In [None]:
# =============================================================================
# PART 2: THE ENSEMBLE KALMAN FILTER (Student Task)
# =============================================================================

# 1. Configuration & Extended Truth
# ---------------------------------
n_ensemble = 200
forecast_duration = 10.0  # Seconds to forecast AFTER last observation

# Determine timings based on the observations we generated earlier
t_last_obs = t_obs_points[-1]
t_final = t_last_obs + forecast_duration

# GENERATE EXTENDED TRUTH
# We re-run the true dynamics for the full duration (0 to t_final) so we have
# a "Ground Truth" to compare our forecast against.
print(f"Simulating 'Truth' up to t={t_final}s...")
t_true_full = np.arange(0, t_final, 0.05)
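y_true_full = da.solve_trajectory(
    double.physics.eom,
    true_y0,
    t_true_full,
    args=physics_args,
    rtol=1e-10,  # Same tight tolerances as in Part 1: in a chaotic system a looser
    atol=1e-12,  # integration can drift from the trajectory that produced the data
)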

# 2. Initial Ensemble (The Prior)
# -------------------------------
print(f"Initializing {n_ensemble} ensemble members...")
# We start with a guess centered on the truth but with some spread
prior_mean = true_y0
prior_cov = np.diag([0.05**2, 0.05**2, 0.01**2, 0.01**2]) 

rng = np.random.default_rng(seed=42)
current_ensemble = rng.multivariate_normal(prior_mean, prior_cov, size=n_ensemble)


# 3. The Assimilation Loop (TODO)
# -------------------------------
t_current = 0.0
print("Starting EnKF Assimilation...")

# Loop over the observations generated above
for t_obs, y_obs_k in zip(obs_times, obs_vals):
    
    dt_step = t_obs - t_current
    
    # =========================================================================
    # TODO STEP A: FORECAST
    # =========================================================================
    # Evolve the entire ensemble from t_current to t_obs using non-linear physics.
    # Hint: Use da.solve_ensemble(). It returns shape (N_ens, State_Dim, Time_Steps).
    # You only need the state at the final time step (index -1).
    
    forecast_ensemble = None 
    
    # --- SOLUTION ---
    forecast_traj = da.solve_ensemble(
        double.physics.eom,
        current_ensemble,
        np.array([0.0, dt_step]),
        args=physics_args
    )
    forecast_ensemble = forecast_traj[:, :, -1]
    # ----------------
    
    
    # =========================================================================
    # TODO STEP B: ANALYSIS
    # =========================================================================
    # Update the ensemble using the observation y_obs_k.
    
    analysis_ensemble = None
    
    if forecast_ensemble is not None:
        # 1. Perturb Observations
        # Create perturbed observations: y_p = y_obs + noise, where noise ~ N(0, R)
        # Hint: rng.multivariate_normal(np.zeros(2), R, size=n_ensemble)
        
        # obs_noise = ...
        # perturbed_obs = ...
        
        # 2. Compute Sample Covariance (Q) of the forecast
        # Hint: np.cov(forecast_ensemble, rowvar=False)
        
        # Q_sample = ...
        
        # 3. Compute Kalman Gain (K)
        # Hint: S = H @ Q @ H.T + R
        # Hint: K = Q @ H.T @ np.linalg.inv(S)
        
        # K_gain = ...
        
        # 4. Update Ensemble Members
        # Hint: Loop over i in range(n_ensemble)
        # x_new = x_old + K @ (y_perturbed - H @ x_old)
        
        # --- SOLUTION ---
        obs_noise = rng.multivariate_normal(np.zeros(2), R, size=n_ensemble)
        perturbed_obs = y_obs_k + obs_noise 
        #
        Q_sample = np.cov(forecast_ensemble, rowvar=False)
        #
        S = H @ Q_sample @ H.T + R
        K_gain = Q_sample @ H.T @ np.linalg.inv(S)
        #
        analysis_ensemble = np.zeros_like(forecast_ensemble)
        for i in range(n_ensemble):
            innovation = perturbed_obs[i] - (H @ forecast_ensemble[i])
            analysis_ensemble[i] = forecast_ensemble[i] + K_gain @ innovation
        # ----------------
        pass 

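    # Safety fallback if the TODOs are not implemented: carry the forecast
    # (or, failing that, the current ensemble) forward unchanged
    if analysis_ensemble is None:
        analysis_ensemble = forecast_ensemble if forecast_ensemble is not None else current_ensemble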

    # Store State
    current_ensemble = analysis_ensemble
    t_current = t_obs
    
    print(f"Assimilated t={t_current:.1f}s")


# 4. Forecast Beyond Data (TODO)
# ------------------------------
# Propagate the final analysis ensemble to 't_final'.

print("Running Extended Forecast...")

t_forecast = np.arange(t_current, t_final, 0.05)
final_forecast_traj = None

# --- SOLUTION ---
if len(t_forecast) > 0:
    final_forecast_traj = da.solve_ensemble(
        double.physics.eom,
        current_ensemble,
        t_forecast,
        args=physics_args
    )
# ----------------


# 5. Visualisation (Full State: Angles & Momenta)
# -----------------------------------------------
# We compare the Forecast (Mean + Uncertainty) against the Extended Truth.

if final_forecast_traj is not None:
    
    # Calculate Forecast Statistics (Mean & Std Dev)
    ens_mean = np.mean(final_forecast_traj, axis=0) 
    ens_std  = np.std(final_forecast_traj, axis=0)

    # Create 2x2 grid for Angles and Momenta
    fig, axes = plt.subplots(2, 2, figsize=(14, 10), sharex=True)
    
    plot_config = [
        (0, 0, r"$\theta_1$ [rad]", "Angle 1"),
        (0, 1, r"$\theta_2$ [rad]", "Angle 2"),
        (1, 0, r"$p_1$", "Momentum 1"),
        (1, 1, r"$p_2$", "Momentum 2")
    ]
    
    for row, col, ylabel, title in plot_config:
        state_idx = row * 2 + col
        ax = axes[row, col]
        
        # 1. Plot WHOLE True Path (Context + Future)
        ax.plot(t_true_full, y_true_full[state_idx], 'k-', lw=1.5, label='True Path')
        
        # 2. Plot Observations (Only for Angles, before forecast)
        if row == 0: 
            ax.errorbar(
                obs_times, obs_vals[:, col], 
                yerr=obs_sigma, 
                fmt='rx', capsize=3, label='Observations'
            )
        
        # 3. Plot Forecast Mean (In forecast range)
        ax.plot(t_forecast, ens_mean[state_idx], 'g--', lw=2, label='EnKF Forecast')
        
        # 4. Plot Forecast Uncertainty (Shaded)
        ax.fill_between(
            t_forecast, 
            ens_mean[state_idx] - 2*ens_std[state_idx], 
            ens_mean[state_idx] + 2*ens_std[state_idx], 
            color='green', alpha=0.2, label=r'Forecast Uncertainty ($2\sigma$)'
        )
        
        # Vertical line dividing Assimilation vs Forecast
        ax.axvline(t_last_obs, color='k', linestyle=':', lw=1.5)
        
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        
        # Legend only on first plot
        if row == 0 and col == 0:
            ax.legend(loc='lower left', fontsize='small')

    axes[1, 0].set_xlabel("Time [s]")
    axes[1, 1].set_xlabel("Time [s]")
    
    plt.suptitle("EnKF Performance: Truth vs Forecast", y=1.02, fontsize=14)
    plt.tight_layout()
    plt.show()
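One simple way to quantify how long the forecast remains accurate enough to be useful is to record the first time at which the forecast mean departs from the truth by more than a chosen tolerance. The helper below is a hypothetical sketch of that criterion; the tolerance is your choice (for example the observation error `obs_sigma` for the angles). Because `t_forecast` and `t_true_full` are different time grids, the commented call first interpolates the truth onto the forecast times with `np.interp`.

```python
import numpy as np

def forecast_horizon(t, truth, ens_mean, tol, components=(0, 1)):
    """First time in t at which |truth - ens_mean| exceeds tol for any of the
    selected state components; None if it never does.

    truth, ens_mean: arrays of shape (state_dim, n_times) on the time grid t.
    The default components (0, 1) are the two angles.
    """
    idx = list(components)
    error = np.abs(truth[idx] - ens_mean[idx])
    exceeded = np.flatnonzero((error > tol).any(axis=0))
    return None if exceeded.size == 0 else t[exceeded[0]]

# In the notebook one might call it as follows (not run here):
# truth_fc = np.array([np.interp(t_forecast, t_true_full, y_true_full[k]) for k in range(4)])
# horizon = forecast_horizon(t_forecast, truth_fc, ens_mean, tol=obs_sigma)

# Synthetic check: an angle error that grows linearly at 0.1 rad per second.
t = np.linspace(0.0, 10.0, 101)
truth = np.zeros((4, t.size))
ens_mean = truth.copy()
ens_mean[0] = 0.1 * t
print(forecast_horizon(t, truth, ens_mean, tol=0.525))
```

A criterion based on the error of the mean measures accuracy directly; checking whether the truth stays inside the shaded $2\sigma$ band instead measures whether the ensemble's uncertainty estimate is consistent, which is a different question.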