# Log Plotting with Forward Horizon and GP Visualization
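This notebook combines:
- Forward-horizon plotting from log files: the actual trajectory, the reference horizon, and the reachable-tube bounds
- Time-varying Gaussian Process (TVGPR) visualization of the estimated wind disturbances over time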

In [259]:
import os
import sys

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import HTML
from matplotlib.patches import Rectangle
from matplotlib.ticker import MultipleLocator


# Add path to GP modules (they live outside any installed package).
# NOTE(review): absolute, machine-specific path; consider making it configurable.
sys.path.insert(0, '/home/egmc/ws_px4_rta_mm_gpr/src/px4_rta_mm_gpr/px4_rta_mm_gpr/jax_mm_rta')
import jax.numpy as jnp
from TVGPR import TVGPR  # Time-Varying Gaussian Process Regression

In [260]:
# Configure matplotlib for animation in Jupyter
%matplotlib notebook

plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Helvetica", "Arial"],
    "font.size": 14,
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12
})

In [261]:
# Load the flight log. Despite the .log extension it is parsed as CSV.
log_file = 'log_files/12Y10Z__WIND_EST_ON__WIND_LIN_ON.log'
data = pd.read_csv(log_file)

# Quick summary of what was loaded: schema, size and a preview.
print("Columns:", list(data.columns))
print("Total rows: {}".format(data.shape[0]))
print("\nFirst few rows:")
print(data.head())

Columns: ['time', 'x', 'y', 'z', 'yaw', 'ctrl_comp_time', 'rollout_comptime', 'y_ref', 'z_ref', 'yaw_ref', 'throttle', 'roll_rate', 'pitch_rate', 'yaw_rate', 'save_tube_pyH', 'save_tube_pyL', 'save_tube_pzH', 'save_tube_pzL', 'wy', 'wz']
Total rows: 14304

First few rows:
        time         x         y          z       yaw  ctrl_comp_time  \
0  15.003586 -0.011672  4.114041 -12.502148  0.000022        0.157984   
1  15.578735 -0.014770  4.076770 -12.528515  0.002012        0.168496   
2  15.672786 -0.015628  4.061078 -12.533593  0.002420        0.002579   
3  15.680060 -0.015628  4.061078 -12.533593  0.002420        0.002847   
4  15.689127 -0.015439  4.061344 -12.533191  0.002416        0.004645   

   rollout_comptime     y_ref      z_ref   yaw_ref  throttle  roll_rate  \
0          0.099910  4.102231 -12.476642 -0.355815  0.592726  -1.000000   
1          0.073061  4.097934 -12.466997 -0.375815  0.592204  -0.935647   
2          0.072788  4.092911 -12.455792 -0.395815  0.598690  -

In [262]:
# Number of forward-looking predictions per time step
num_horizon = 12

# Reference/tube data is organized as `num_horizon` consecutive rows per time
# step, so the number of complete groups bounds how far we can animate.
max_timestep = len(data) // num_horizon
print(f"Total rows in data: {len(data)}")
print(f"Maximum animatable timestep: {max_timestep} (can animate n=0 to n={max_timestep-1})")

# State and wind columns - every row carries actual state data.
time_all, y_all, yaw_all, wy_all, wz_all = (
    data[column].values for column in ('time', 'y', 'yaw', 'wy', 'wz')
)
# Negate z so that up is positive in every plot.
z_all = -data['z'].values

# Reference horizon (z negated like the state).
y_ref_all = data['y_ref'].values
z_ref_all = -data['z_ref'].values

# Tube bounds. Negating z reverses the ordering of its bounds, so the logged
# low bound becomes the plotted high bound and vice versa.
pyH_all, pyL_all = data['save_tube_pyH'].values, data['save_tube_pyL'].values
pzH_all = -data['save_tube_pzL'].values
pzL_all = -data['save_tube_pzH'].values

Total rows in data: 14304
Maximum animatable timestep: 1192 (can animate n=0 to n=1191)


In [263]:
# Helper function to draw a planar quadrotor
def quadplot(ypos, zpos, rot, scale):
    """Return the outline of a planar quadrotor as plottable coordinates.

    The outline is defined in the body frame: a U-shaped frame spanning
    x in [-1, 1] with a small curved segment at each tip. It is scaled by
    ``scale``, rotated counter-clockwise by ``rot`` radians, and translated
    so that the body origin lands at ``(ypos, zpos)``.

    Args:
        ypos: Horizontal (y) position of the body origin.
        zpos: Vertical (z) position of the body origin.
        rot: Rotation angle in radians.
        scale: Uniform scale factor applied before rotating.

    Returns:
        Tuple ``(xs, ys)`` of 1-D numpy arrays holding the outline's
        horizontal and vertical coordinates.
    """
    # Curve parameter shared by both tip segments.
    theta = 0.3 - 0.5 * np.arange(np.pi, 1.5 * np.pi, 0.2)
    tip_x = np.cos(theta)
    tip_y = np.sin(theta) * np.cos(theta)

    # Body-frame outline as a 2 x N array: left tip, frame corners, right tip.
    body = scale * np.vstack((
        np.hstack((0.4 * tip_x - 1, -1, -1, 1, 1, 0.4 * tip_x + 1)),
        np.hstack((0.3 * tip_y + 0.4, 0.4, 0, 0, 0.4, 0.3 * tip_y + 0.4)),
    ))

    cos_r, sin_r = np.cos(rot), np.sin(rot)
    world = np.array([[cos_r, -sin_r],
                      [sin_r, cos_r]]) @ body
    return world[0, :] + ypos, world[1, :] + zpos


# Build Time-Varying Gaussian Process models for the wind forces:
#   wy(t, z): input = [time, height z],    output = wy
#   wz(t, y): input = [time, y position],  output = wz

# Number of initial observations to seed the TVGPR
n_init_obs = 20
init_indices = np.arange(0, min(n_init_obs, max_timestep))

# Observation rows are [time, position, output]; drop any row containing a NaN.
init_obs_wy = np.column_stack((time_all, z_all, wy_all))[init_indices]
valid_mask_wy = ~np.isnan(init_obs_wy).any(axis=1)
init_obs_wy = init_obs_wy[valid_mask_wy]

init_obs_wz = np.column_stack((time_all, y_all, wz_all))[init_indices]
valid_mask_wz = ~np.isnan(init_obs_wz).any(axis=1)
init_obs_wz = init_obs_wz[valid_mask_wz]

print(f"Building TVGPR for wy(t,z) with {len(init_obs_wy)} initial observations")
print(f"Building TVGPR for wz(t,y) with {len(init_obs_wz)} initial observations")

# Initial models; these globals are replaced during the animation.
tvgp_wy_init = TVGPR(jnp.array(init_obs_wy), sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)
tvgp_wz_init = TVGPR(jnp.array(init_obs_wz), sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)

print("Time-Varying GP models created successfully!")

Building TVGPR for wy(t,z) with 20 initial observations
Building TVGPR for wz(t,y) with 20 initial observations
Time-Varying GP models created successfully!


In [None]:
# ====== CONTROL PARAMETERS ======
SHOW_WIND_Y = True   # Show gray arrows for wind in Y direction
SHOW_WIND_Z = True   # Show orange arrows for wind in Z direction
WIND_Y_SCALE = 15    # Scale for Y-direction arrows (smaller = larger arrows)
WIND_Z_SCALE = 20    # Scale for Z-direction arrows (smaller = larger arrows)
SHOW_LEGENDS = False  # Draw a legend on every panel
# Fixed x-axis limits per panel; set to None to derive them from the data range.
MAIN_XLIM = (-1, 4.5)   # trajectory panel (horizontal axis: y)
GP_WY_XLIM = (0, 14)    # w_y GP panel (horizontal axis: z)
GP_WZ_XLIM = (-2, 6)    # w_z GP panel (horizontal axis: y)
# ================================

FONTSIZE = 12
LINEWIDTH = 1.0
QUADSIZE = 0.2
OBSSIZE = 40


def _gp_band(model, query):
    """Evaluate a TVGPR row by row on ``query`` (rows of [time, position]).

    Returns:
        (mean, three_sigma): flat jnp arrays, one entry per query row.
    """
    mean = jnp.array([model.mean(jnp.array(q)) for q in query]).reshape(-1)
    three_sigma = 3 * jnp.array([jnp.sqrt(model.variance(jnp.array(q))) for q in query]).reshape(-1)
    return mean, three_sigma


def _draw_gp_panel(axis, obs, query_points, mean, three_sigma):
    """Plot observations (colored by time) and the GP mean with +/-3 sigma curves.

    Returns:
        (scatter, mean_line, upper_line, lower_line) so the animation can update them.
    """
    scatter = axis.scatter(obs[:, 1], obs[:, 2], label='Observations', c=obs[:, 0],
                           cmap='viridis', s=OBSSIZE, edgecolors='black', zorder=10)
    mean_line, = axis.plot(query_points, mean, label='GP Mean',
                           color='tab:blue', linewidth=LINEWIDTH)
    # LaTeX markup instead of a literal Unicode sigma: when text.usetex is
    # enabled, a raw Unicode sigma makes LaTeX fail once the label is rendered.
    upper_line, = axis.plot(query_points, mean + three_sigma, label=r'$3\sigma$ Bounds',
                            color='lightblue', linewidth=LINEWIDTH)
    lower_line, = axis.plot(query_points, mean - three_sigma,
                            color='lightblue', linewidth=LINEWIDTH)
    return scatter, mean_line, upper_line, lower_line


# Set up the figure: trajectory (ax), GP for wy (ax2), GP for wz (ax3)
fig, (ax, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 4))

# Data range plus a 15% margin (NaNs ignored)
y_valid = y_all[~np.isnan(y_all)]
z_valid = z_all[~np.isnan(z_all)]
y_margin = (np.max(y_valid) - np.min(y_valid)) * 0.15
z_margin = (np.max(z_valid) - np.min(z_valid)) * 0.15
y_min, y_max = np.min(y_valid) - y_margin, np.max(y_valid) + y_margin
z_min, z_max = np.min(z_valid) - z_margin, np.max(z_valid) + z_margin

# ===== MAIN PLOT (ax) =====
# Regular grid of points at which the estimated wind field is drawn
n_arrows_y = 8
n_arrows_z = 10
arrow_y_coords = np.linspace(y_min, y_max, n_arrows_y)
arrow_z_coords = np.linspace(z_min, z_max, n_arrows_z)
arrow_grid_y, arrow_grid_z = np.meshgrid(arrow_y_coords, arrow_z_coords)
arrow_points = np.column_stack((arrow_grid_y.ravel(), arrow_grid_z.ravel()))

# Initial wind vectors: wy(t, z) gives the horizontal component,
# wz(t, y) the vertical one.
t_init = time_all[0]
arrow_dirs_y = np.zeros_like(arrow_points)
arrow_dirs_z = np.zeros_like(arrow_points)
for i, (y_pt, z_pt) in enumerate(arrow_points):
    if SHOW_WIND_Y:
        arrow_dirs_y[i, 0] = float(tvgp_wy_init.mean(jnp.array([t_init, z_pt]))[0][0])
    if SHOW_WIND_Z:
        arrow_dirs_z[i, 1] = float(tvgp_wz_init.mean(jnp.array([t_init, y_pt]))[0][0])

if SHOW_WIND_Y:
    wind_arrows_y = ax.quiver(arrow_points[:, 0], arrow_points[:, 1],
                              arrow_dirs_y[:, 0], arrow_dirs_y[:, 1],
                              color='tab:gray', alpha=0.5, label='Wind Y', scale=WIND_Y_SCALE)
else:
    wind_arrows_y = None

if SHOW_WIND_Z:
    wind_arrows_z = ax.quiver(arrow_points[:, 0], arrow_points[:, 1],
                              arrow_dirs_z[:, 0], arrow_dirs_z[:, 1],
                              color='tab:orange', alpha=0.5, label='Wind Z', scale=WIND_Z_SCALE)
else:
    wind_arrows_z = None

# Artists updated every animation frame
actual_line, = ax.plot([], [], label='Actual Trajectory',
                       color='blue', linewidth=LINEWIDTH)
reference_horizon, = ax.plot([], [], label=f'Reference Horizon ({num_horizon}-step)',
                             linestyle='dashed', color='black',
                             linewidth=LINEWIDTH, alpha=0.8,
                             marker='o', markersize=3)

# One reachable-tube rectangle per horizon step; only the first is labeled
# so a legend shows a single entry.
tube_rects = []
for i in range(num_horizon):
    rect = Rectangle((0, 0), 1, 1, linewidth=0.8,
                     edgecolor='red', facecolor='red', alpha=0.15)
    ax.add_patch(rect)
    tube_rects.append(rect)
    if i == 0:
        rect.set_label('Reachable Tube Horizon')

# Quadrotor representation
quad_line, = ax.plot([], [], color='black', linewidth=LINEWIDTH)
current_pos, = ax.plot([], [], 'ro', markersize=5, label='Current Position')

ax.set_xlabel('y', fontsize=FONTSIZE)
ax.set_ylabel('z', fontsize=FONTSIZE)
title_text = ax.set_title(f't = {t_init-t_init:.2f}', fontsize=FONTSIZE)
ax.grid(True, alpha=0.3)
ax.set_xlim(MAIN_XLIM if MAIN_XLIM is not None else (y_min, y_max))
ax.xaxis.set_major_locator(MultipleLocator(2))

time_text = ax.text(0.02, 0.02, '', transform=ax.transAxes,
                    fontsize=FONTSIZE-2, verticalalignment='bottom',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
step_text = ax.text(0.98, 0.02, '', transform=ax.transAxes,
                    fontsize=FONTSIZE-4, verticalalignment='bottom',
                    horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

# ===== GP PLOT FOR WY (ax2): w_y as a function of z at the current time =====
gp_query_points_z = np.linspace(z_min - 1, z_max + 1, 100)
gp_query_wy = np.column_stack((np.full_like(gp_query_points_z, t_init), gp_query_points_z))
mean_pts_wy, sig_pts_wy = _gp_band(tvgp_wy_init, gp_query_wy)
observations_wy, mean_line_wy, upper_conf_wy, lower_conf_wy = _draw_gp_panel(
    ax2, init_obs_wy, gp_query_points_z, mean_pts_wy, sig_pts_wy)

ax2.set_xlabel('z', fontsize=FONTSIZE)
ax2.set_ylabel('Wind Force ($w_y$)', fontsize=FONTSIZE)
ax2.set_title('Disturbance Behavior (y)', fontsize=FONTSIZE)
ax2.grid(True, alpha=0.3)
ax2.set_xlim(GP_WY_XLIM if GP_WY_XLIM is not None else (z_min - 1, z_max + 1))
ax2.xaxis.set_major_locator(MultipleLocator(2))

# ===== GP PLOT FOR WZ (ax3): w_z as a function of y at the current time =====
gp_query_points_y = np.linspace(y_min - 1, y_max + 1, 100)
gp_query_wz = np.column_stack((np.full_like(gp_query_points_y, t_init), gp_query_points_y))
mean_pts_wz, sig_pts_wz = _gp_band(tvgp_wz_init, gp_query_wz)
observations_wz, mean_line_wz, upper_conf_wz, lower_conf_wz = _draw_gp_panel(
    ax3, init_obs_wz, gp_query_points_y, mean_pts_wz, sig_pts_wz)

ax3.set_xlabel('y', fontsize=FONTSIZE)
ax3.set_ylabel('Wind Force ($w_z$)', fontsize=FONTSIZE)
ax3.set_title('Disturbance Behavior (z)', fontsize=FONTSIZE)
ax3.grid(True, alpha=0.3)
ax3.set_xlim(GP_WZ_XLIM if GP_WZ_XLIM is not None else (y_min - 1, y_max + 1))
ax3.xaxis.set_major_locator(MultipleLocator(2))

if SHOW_LEGENDS:
    for legend_ax in (ax, ax2, ax3):
        legend_ax.legend(loc='upper right', fontsize=FONTSIZE - 4)

fig.tight_layout()

print("Figure initialized with 3 subplots")

<IPython.core.display.Javascript object>

In [None]:
# Animation update function with GP visualization
obs_window_size = 50  # Number of recent observations to keep in the GP
gp_min_obs = 5        # Refit a GP only when the window holds MORE than this many valid samples
max_tube_extent = 50  # Tube rectangles at least this wide/tall are treated as invalid and hidden
tvgpr_kwargs = dict(sigma_f=5.0, l=2.0, sigma_n=0.01, epsilon=0.25)  # TVGPR hyperparameters


def _windowed_obs(pos_all, out_all, obs_indices):
    """Return [time, position, output] rows at ``obs_indices``, dropping rows with any NaN."""
    rows = np.column_stack((time_all[obs_indices], pos_all[obs_indices], out_all[obs_indices]))
    return rows[~np.isnan(rows).any(axis=1)]


def _predict_on_arrow_grid(model, t, coord_col):
    """GP mean at ``(t, arrow_points[i, coord_col])`` for every arrow grid point ``i``."""
    return np.array([float(model.mean(jnp.array([t, p]))[0][0])
                     for p in arrow_points[:, coord_col]])


def _update_reference_and_tubes(n):
    """Show the ``num_horizon``-step reference and tube rectangles logged for step ``n``."""
    ref_start = num_horizon * n
    ref_end = ref_start + num_horizon
    if ref_end > len(data):
        reference_horizon.set_data([], [])
        for rect in tube_rects:
            rect.set_visible(False)
        return

    window = slice(ref_start, ref_end)
    reference_horizon.set_data(y_ref_all[window], z_ref_all[window])
    for rect, y_lo, y_hi, z_lo, z_hi in zip(tube_rects, pyL_all[window], pyH_all[window],
                                             pzL_all[window], pzH_all[window]):
        if np.isnan([y_lo, y_hi, z_lo, z_hi]).any():
            rect.set_visible(False)
            continue
        width = y_hi - y_lo
        height = z_hi - z_lo
        # Degenerate or implausibly large tubes are hidden rather than drawn.
        if 0 < width < max_tube_extent and 0 < height < max_tube_extent:
            rect.set_xy((y_lo, z_lo))
            rect.set_width(width)
            rect.set_height(height)
            rect.set_visible(True)
        else:
            rect.set_visible(False)


def _update_quadrotor(y_curr, z_curr, yaw_curr):
    """Draw the quadrotor glyph and position marker, or hide both when the position is NaN."""
    visible = not (np.isnan(y_curr) or np.isnan(z_curr))
    if visible:
        quad_line.set_data(*quadplot(y_curr, z_curr, yaw_curr, QUADSIZE))
        current_pos.set_data([y_curr], [z_curr])
    quad_line.set_visible(visible)
    current_pos.set_visible(visible)


def _update_gp_panel(model, obs, query_points, t, mean_line, upper_line, lower_line, scatter):
    """Redraw one GP panel at time ``t``: mean, +/-3 sigma curves, and windowed observations."""
    query = np.column_stack((np.full_like(query_points, t), query_points))
    mean_pts = jnp.array([model.mean(jnp.array(q)) for q in query]).reshape(-1)
    sig_pts = 3 * jnp.array([jnp.sqrt(model.variance(jnp.array(q))) for q in query]).reshape(-1)
    mean_line.set_data(query_points, mean_pts)
    upper_line.set_data(query_points, mean_pts + sig_pts)
    lower_line.set_data(query_points, mean_pts - sig_pts)

    scatter.set_offsets(obs[:, 1:])
    # Marker area shrinks by a factor of e for every 0.5 time units of age.
    scatter.set_sizes(OBSSIZE * np.exp(2 * (obs[:, 0] - t)))
    # Color each marker by its timestamp.
    scatter.set_array(obs[:, 0])


def update(n):
    """Update animation at time step ``n``.

    Refits both TVGPR wind models on a sliding window of samples
    ``max(0, n - obs_window_size) .. n`` (stored back into the globals
    ``tvgp_wy_init`` / ``tvgp_wz_init``). It then updates the trajectory, wind
    arrows, reference horizon, tube rectangles, quadrotor glyph, title and
    both GP panels. When ``SAVE_FRAMES_AS_PDF`` is set, the fully updated
    figure is written to ``PDF_OUTPUT_DIR``.

    ``SAVE_FRAMES_AS_PDF`` and ``PDF_OUTPUT_DIR`` are defined in the animation
    cell, which must therefore run before ``update`` is first called.

    Returns:
        List of artists FuncAnimation should redraw (required for blit=True).
    """
    global tvgp_wy_init, tvgp_wz_init

    return_list = [actual_line, reference_horizon, quad_line, current_pos,
                   time_text, step_text, title_text, mean_line_wy, upper_conf_wy, lower_conf_wy,
                   observations_wy, mean_line_wz, upper_conf_wz, lower_conf_wz, observations_wz]
    if SHOW_WIND_Y and wind_arrows_y is not None:
        return_list.append(wind_arrows_y)
    if SHOW_WIND_Z and wind_arrows_z is not None:
        return_list.append(wind_arrows_z)
    return_list.extend(tube_rects)

    # Skip if beyond valid range
    if n >= max_timestep or np.isnan(time_all[n]):
        return return_list

    # 1. Trajectory flown so far
    actual_line.set_data(y_all[:n+1], z_all[:n+1])
    t_curr = time_all[n]

    # 2. Refit both GPs on the sliding observation window
    obs_indices = np.arange(max(0, n - obs_window_size), n + 1)
    current_obs_wy = _windowed_obs(z_all, wy_all, obs_indices)
    current_obs_wz = _windowed_obs(y_all, wz_all, obs_indices)
    refit_wy = SHOW_WIND_Y and len(current_obs_wy) > gp_min_obs
    refit_wz = SHOW_WIND_Z and len(current_obs_wz) > gp_min_obs
    if refit_wy:
        tvgp_wy_init = TVGPR(jnp.array(current_obs_wy), **tvgpr_kwargs)
    if refit_wz:
        tvgp_wz_init = TVGPR(jnp.array(current_obs_wz), **tvgpr_kwargs)

    # 3. Wind arrows: wy is queried at each grid point's z, wz at its y
    if SHOW_WIND_Y and wind_arrows_y is not None:
        u = _predict_on_arrow_grid(tvgp_wy_init, t_curr, 1)
        wind_arrows_y.set_UVC(u, np.zeros_like(u))
    if SHOW_WIND_Z and wind_arrows_z is not None:
        v = _predict_on_arrow_grid(tvgp_wz_init, t_curr, 0)
        wind_arrows_z.set_UVC(np.zeros_like(v), v)

    # 4. Reference horizon and reachable tubes
    _update_reference_and_tubes(n)

    # 5. Quadrotor glyph
    _update_quadrotor(y_all[n], z_all[n], yaw_all[n])

    # 6. Title shows time elapsed since the first logged sample.
    # NOTE(review): time_text / step_text are returned for blitting but this
    # function never gives them text, so they stay empty.
    title_text.set_text(f't = {t_curr - t_init:.2f}')

    # 7. GP panels (only when the corresponding model was refit this frame)
    if refit_wy:
        _update_gp_panel(tvgp_wy_init, current_obs_wy, gp_query_points_z, t_curr,
                         mean_line_wy, upper_conf_wy, lower_conf_wy, observations_wy)
    if refit_wz:
        _update_gp_panel(tvgp_wz_init, current_obs_wz, gp_query_points_y, t_curr,
                         mean_line_wz, upper_conf_wz, lower_conf_wz, observations_wz)

    # 8. Save only after EVERY artist is updated. Saving before step 7 wrote
    #    PDFs whose GP panels still showed the previous frame.
    if SAVE_FRAMES_AS_PDF:
        filename = f'{PDF_OUTPUT_DIR}/frame_{n:04d}_t_{t_curr:.2f}.pdf'
        fig.savefig(filename, format='pdf', bbox_inches='tight', dpi=300)
        if n % 10 == 0:  # Print progress every 10 frames
            print(f'Saved frame {n}: {filename}')

    return return_list

In [None]:
# ====== ANIMATION CONTROL PARAMETERS ======
FRAME_SKIP = 10  # Higher values = faster rendering but choppier animation
# Playback speed: False -> fixed DEFAULT_INTERVAL_MS between frames;
# True -> interval follows the logged time (median step * FRAME_SKIP).
PLAYBACK_REAL_TIME = False
DEFAULT_INTERVAL_MS = 50  # 50ms between frames = 20 fps

# PDF saving options
SAVE_FRAMES_AS_PDF = True  # Set to True to save each frame as PDF
PDF_OUTPUT_DIR = 'saved_frames'  # Directory to save PDFs
# ==========================================

# Typical (median) spacing of valid, strictly increasing timestamps; 0.1 if
# fewer than two usable samples exist.
time_valid = time_all[:max_timestep]
time_valid = time_valid[~np.isnan(time_valid)]
time_diffs = np.diff(time_valid)
valid_diffs = time_diffs[time_diffs > 0]
avg_dt = np.median(valid_diffs) if valid_diffs.size > 0 else 0.1

print(f"Average time between steps: {avg_dt:.4f} seconds")
print(f"Frame skip setting: {FRAME_SKIP}")

# Create frame list
frames_to_use = list(range(0, max_timestep, FRAME_SKIP))
if PLAYBACK_REAL_TIME:
    interval_ms = max(1, int(round(avg_dt * FRAME_SKIP * 1000)))
else:
    interval_ms = DEFAULT_INTERVAL_MS

print(f"Total timesteps available: {max_timestep}")
print(f"Total frames to animate: {len(frames_to_use)}")
print(f"Animation interval: {interval_ms} ms")
print(f"Estimated animation duration: {len(frames_to_use) * interval_ms / 1000:.1f} seconds")

if SAVE_FRAMES_AS_PDF:
    os.makedirs(PDF_OUTPUT_DIR, exist_ok=True)
    print("\n*** PDF SAVING ENABLED ***")
    print(f"Will save {len(frames_to_use)} PDFs to '{PDF_OUTPUT_DIR}/' directory")
    print("This will take significantly longer than normal animation generation.")

# Create the animation.
# NOTE(review): without an init_func, FuncAnimation draws frames_to_use[0]
# for its initial frame, so frame 0 is rendered (and its PDF written) more
# than once.
print("\nCreating animation...")
ani = animation.FuncAnimation(fig, update, frames=frames_to_use,
                              interval=interval_ms, blit=True, repeat=True)

# Display as HTML
HTML(ani.to_jshtml())

Average time between steps: 0.0100 seconds
Frame skip setting: 10
Total timesteps available: 1192
Total frames to animate: 120
Animation interval: 50 ms
Estimated animation duration: 6.0 seconds

*** PDF SAVING ENABLED ***
Will save 120 PDFs to 'saved_frames/' directory
This will take significantly longer than normal animation generation.

Creating animation...
Saved frame 0: saved_frames/frame_0000_t_15.00.pdf
Saved frame 0: saved_frames/frame_0000_t_15.00.pdf
Saved frame 10: saved_frames/frame_0010_t_15.77.pdf
Saved frame 20: saved_frames/frame_0020_t_15.96.pdf
Saved frame 30: saved_frames/frame_0030_t_16.26.pdf
Saved frame 40: saved_frames/frame_0040_t_16.46.pdf
Saved frame 50: saved_frames/frame_0050_t_16.76.pdf
Saved frame 60: saved_frames/frame_0060_t_16.97.pdf
Saved frame 70: saved_frames/frame_0070_t_17.17.pdf
Saved frame 80: saved_frames/frame_0080_t_17.34.pdf
Saved frame 90: saved_frames/frame_0090_t_17.51.pdf
Saved frame 100: saved_frames/frame_0100_t_17.71.pdf
Saved frame 1

In [None]:
# OPTIONAL: Save the animation as a GIF
# Uncomment and run this cell only if you want to save the animation

# animation_name = 'log_plot_with_gp.gif'

# print("Saving animation as GIF (this may take a few minutes)...")
# ani.save(f'{animation_name}', writer='pillow', fps=10)
# print(f"Animation saved as '{animation_name}'!")

In [None]:
# # OPTIONAL: Save the animation as a GIF
# # Uncomment and run this cell only if you want to save the animation

# animation_name = 'log_plot_with_gp.gif'

# print("Saving animation as GIF (this may take a few minutes)...")
# ani.save(f'{animation_name}', writer='pillow', fps=10)
# print(f"Animation saved as '{animation_name}'!")